Approximate inference in aGrUM (pyAgrum)

aGrUM (pyAgrum) provides several approximate inference algorithms for Bayesian networks (BNs). They share the same API as exact inference.

  • Loopy Belief Propagation: LBP is an approximate inference that applies computations that are exact when the BN is a tree, even if the BN is not a tree. LBP is a special case of inference: the algorithm may not converge, and even when it converges, there is no guarantee that it converges to the exact posterior. LBP is however fast and usually gives reasonably good results.

  • Sampling inference: sampling inference uses sampling to compute the posterior. Sampling may be (very) slow, but these algorithms converge to the exact distribution. aGrUM implements:

    • Monte Carlo sampling,

    • Weighted sampling,

    • Importance sampling,

    • Gibbs sampling.

  • Finally, aGrUM proposes the so-called ‘loopy version’ of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is provided.
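
To make the idea of sampling inference concrete, here is a minimal sketch of weighted sampling (likelihood weighting) in plain Python. It does not use aGrUM and is not aGrUM's implementation: the two-node network A → B, its probabilities and the function names are all made up for illustration. The sketch estimates P(A = True | B = True) and compares it with the exact value obtained by Bayes' rule.

```python
import random

# Hypothetical two-node network A -> B (all numbers are made up for the example).
P_A = 0.3  # P(A = True)
P_B_GIVEN_A = {True: 0.9, False: 0.2}  # P(B = True | A)


def exact_posterior():
  """P(A = True | B = True) computed exactly by Bayes' rule."""
  joint_true = P_A * P_B_GIVEN_A[True]
  joint_false = (1 - P_A) * P_B_GIVEN_A[False]
  return joint_true / (joint_true + joint_false)


def weighted_sampling(n_samples, seed=0):
  """Estimate P(A = True | B = True) by likelihood weighting."""
  rng = random.Random(seed)
  weight_true = 0.0
  weight_total = 0.0
  for _ in range(n_samples):
    a = rng.random() < P_A  # sample the unobserved node from its prior
    w = P_B_GIVEN_A[a]  # weight = likelihood of the evidence B = True
    weight_total += w
    if a:
      weight_true += w
  return weight_true / weight_total


exact = exact_posterior()
for n in (100, 10_000, 100_000):
  # the absolute error tends to shrink as the number of samples grows
  print(n, abs(weighted_sampling(n) - exact))
```

The estimate is only as good as the number of samples allows, which is why the stopping criteria of the sampling algorithms matter in practice.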

In [1]:
%matplotlib inline
from pylab import *
import matplotlib.pyplot as plt


def unsharpen(bn):
  """
  Force the parameters of the BN to be a bit further from 0 or 1
  """
  for nod in bn.nodes():
    bn.cpt(nod).translate(bn.maxParam() / 10).normalizeAsCPT()


def compareInference(ie, ie2, ax=None):
  """
  Compare 2 inferences by plotting all the points (posterior(ie), posterior(ie2)).
  Relies on the global variable bn for the list of nodes.
  """
  exact = []
  appro = []
  errmax = 0
  for node in bn.nodes():
    # Tensors as list
    exact += ie.posterior(node).tolist()
    appro += ie2.posterior(node).tolist()
    errmax = max(errmax, (ie.posterior(node) - ie2.posterior(node)).abs().max())

  if errmax < 1e-10:
    errmax = 0
  if ax is None:
    plt.figure(figsize=(4, 4))  # create a new pyplot figure
    ax = plt.gca()  # default axis for plt

  ax.plot(exact, appro, "ro")
  ax.set_title(
    "{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
      str(type(ie)).split(".")[2].split("_")[0][0:-2],  # name of first inference
      str(type(ie2)).split(".")[2].split("_")[0][0:-2],  # name of second inference
      ie2.messageApproximationScheme(),
      errmax,
      ie2.currentTime(),
    )
  )
In [2]:
import pyagrum as gum
import pyagrum.lib.notebook as gnb

bn = gum.loadBN("res/alarm.dsl")
unsharpen(bn)

ie = gum.LazyPropagation(bn)
ie.makeInference()
In [3]:
gnb.showBN(bn, size="8")

First, an exact inference.

In [4]:
gnb.sideBySide(gnb.getJunctionTreeMap(bn), gnb.getInference(bn, size="8"))  # using LazyPropagation by default
print(ie.posterior("KINKEDTUBE"))
[Figure: junction tree map of the BN (left) and the BN with the exact posterior of each node (right); inference in 16.44ms]

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.1167  | 0.8833  |

Gibbs Inference

Gibbs inference with default parameters

Gibbs inference iterations can be stopped:

  • by the value of the error (Epsilon),

  • by the rate of change of epsilon (MinEpsilonRate),

  • by the number of iterations (MaxIter),

  • by the duration of the algorithm (MaxTime).
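
The four stopping criteria above can be sketched as a generic loop in plain Python. This is only an illustration under stated assumptions, not aGrUM's code: how aGrUM measures the error and its rate of change is not described here, so the sketch simply takes the error from a user-supplied `step()` function and defines the rate as the absolute change of the error between two consecutive checks. The names `run_until_stopped` and `decaying_error` are hypothetical.

```python
import time


def run_until_stopped(step, epsilon=1e-2, min_epsilon_rate=1e-5, max_iter=10_000, max_time=1.0):
  """
  Call step() repeatedly; step() returns the current error measure.
  Return (reason, iterations) where reason names the first criterion that fired.
  """
  start = time.monotonic()
  previous = None
  for it in range(1, max_iter + 1):
    error = step()
    if error < epsilon:  # the error is small enough
      return "epsilon", it
    if previous is not None and abs(previous - error) < min_epsilon_rate:
      return "rate", it  # the error has (almost) stopped changing
    if time.monotonic() - start > max_time:  # time budget (seconds) exhausted
      return "time", it
    previous = error
  return "max iteration", max_iter


def decaying_error(start, factor):
  """Toy error measure multiplied by `factor` at each call (stand-in for a sampler)."""
  state = {"error": start}

  def step():
    state["error"] *= factor
    return state["error"]

  return step


# the error halves at each step: stopped by epsilon
print(run_until_stopped(decaying_error(1.0, 0.5), epsilon=1e-2, min_epsilon_rate=0.0))
# the error never changes: stopped by the rate
print(run_until_stopped(decaying_error(1.0, 1.0)))
# the error decreases too slowly: stopped by the number of iterations
print(run_until_stopped(decaying_error(1.0, 0.99), epsilon=1e-9, min_epsilon_rate=0.0, max_iter=50))
```

Whichever criterion fires first ends the run, which is why the same sampler can report different stopping reasons depending on how the thresholds are set.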

In [5]:
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
gnb.showInference(bn, engine=ie2, size="8")
print(ie2.posterior("KINKEDTUBE"))
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)

  KINKEDTUBE       |
TRUE     |FALSE    |
---------|---------|
 0.0968  | 0.9032  |

stopped with rate=0.00673795

With these parameters (only epsilon is changed from its default), the inference stopped because the rate of change of epsilon became too low.

Changing parameters

In [6]:
ie2 = gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.8736  | 0.0664  | 0.0600  |

stopped with max iteration=1000
In [7]:
compareInference(ie, ie2)
In [8]:
ie2 = gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()

print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)

  INTUBATION                 |
NORMAL   |ESOPHAGEA|ONESIDED |
---------|---------|---------|
 0.6433  | 0.1900  | 0.1667  |

stopped with epsilon=0.201897

Looking at the convergence

In [9]:
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
ie2.setDrawnAtRandom(True)
gnb.animApproximationScheme(ie2)
ie2.makeInference()
In [10]:
compareInference(ie, ie2)
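
The burn-in and random-draw settings used above can be illustrated with a minimal Gibbs sampler in plain Python. This is a sketch under assumptions, not aGrUM's implementation: the chain A → B → C, its probabilities and all function names are made up. In particular, the `drawn_at_random` flag reflects one common meaning of such an option (resample a single randomly chosen variable per iteration instead of sweeping over all of them); whether `setDrawnAtRandom` behaves exactly this way is an assumption. The evidence is C = True and the query is P(A = True | C = True).

```python
import random

# Hypothetical chain A -> B -> C (all numbers are made up for the example).
P_A = 0.3  # P(A = True)
P_B = {True: 0.8, False: 0.1}  # P(B = True | A)
P_C = {True: 0.9, False: 0.2}  # P(C = True | B)


def bern(p, value):
  """Probability of `value` for a boolean variable with P(True) = p."""
  return p if value else 1 - p


def exact_posterior():
  """P(A = True | C = True) by summing out B."""
  num = {
    a: sum(bern(P_A, a) * bern(P_B[a], b) * P_C[b] for b in (True, False))
    for a in (True, False)
  }
  return num[True] / (num[True] + num[False])


def gibbs_sampling(n_samples, burn_in=300, drawn_at_random=True, seed=0):
  """Estimate P(A = True | C = True) with a Gibbs sampler."""
  rng = random.Random(seed)
  a, b = True, True  # arbitrary initial state
  hits = 0
  for it in range(burn_in + n_samples):
    targets = [rng.choice("ab")] if drawn_at_random else ["a", "b"]
    for t in targets:
      if t == "a":
        # P(A | B = b) is proportional to P(A) * P(B = b | A)
        wt = P_A * bern(P_B[True], b)
        wf = (1 - P_A) * bern(P_B[False], b)
        a = rng.random() < wt / (wt + wf)
      else:
        # P(B | A = a, C = True) is proportional to P(B | A = a) * P(C = True | B)
        wt = P_B[a] * P_C[True]
        wf = (1 - P_B[a]) * P_C[False]
        b = rng.random() < wt / (wt + wf)
    if it >= burn_in:  # samples drawn during the burn-in are discarded
      hits += a
  return hits / n_samples


print(exact_posterior(), gibbs_sampling(100_000))
```

Consecutive Gibbs samples are correlated, so more samples are needed than with independent draws to reach the same accuracy.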

Importance Sampling

In [11]:
ie4 = gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10)  # 10 seconds for inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie, ie4)

Loopy Gibbs Sampling

Every sampling inference has a ‘hybrid’ version, which first runs a loopy belief propagation and uses its result as a prior for the probabilities estimated by sampling.
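
One way to picture a prior combined with sampling is to treat the prior as Dirichlet pseudo-counts added to the sample counts. The sketch below is plain Python and only illustrates that idea: the function name and the `prior_weight` parameter (a hypothetical 'virtual sample size') are assumptions, and the way aGrUM actually combines the LBP result with the samples is not described here.

```python
def posterior_with_dirichlet_prior(counts, prior, prior_weight):
  """
  Combine sample counts with a prior distribution used as Dirichlet pseudo-counts.

  counts: number of samples observed for each value of the variable
  prior: prior distribution over the same values (e.g. an LBP estimate), sums to 1
  prior_weight: hypothetical virtual sample size given to the prior
  """
  total = sum(counts) + prior_weight
  return [(c + prior_weight * p) / total for c, p in zip(counts, prior)]


lbp_like_prior = [0.8, 0.2]  # made-up prior, standing in for an LBP posterior

# with few samples, the prior dominates the estimate
few = posterior_with_dirichlet_prior([3, 7], lbp_like_prior, prior_weight=10)
# with many samples, the estimate gets close to the empirical frequencies
many = posterior_with_dirichlet_prior([300, 700], lbp_like_prior, prior_weight=10)
print(few, many)
```

The prior gives a reasonable estimate from the very first samples, and its influence fades as the number of samples grows.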

In [12]:
ie3 = gum.LoopyGibbsSampling(bn)

ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10)  # 10 seconds for inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie, ie3)

Comparison of approximate inference

These computations may take a while.

In [13]:
def compareAllInference(bn, evs=None, epsilon=10**-1.6, epsilonRate=1e-8, maxTime=20):
  evs = {} if evs is None else evs  # avoid a mutable default argument

  # ies[0] is the exact inference, used as the reference for all the others
  ies = [
    gum.LazyPropagation(bn),
    gum.LoopyBeliefPropagation(bn),
    gum.GibbsSampling(bn),
    gum.LoopyGibbsSampling(bn),
    gum.WeightedSampling(bn),
    gum.LoopyWeightedSampling(bn),
    gum.ImportanceSampling(bn),
    gum.LoopyImportanceSampling(bn),
  ]

  # burn in for Gibbs samplings
  for i in [2, 3]:
    ies[i].setBurnIn(300)
    ies[i].setDrawnAtRandom(True)

  # stopping criteria for the sampling inferences (LBP keeps its own settings)
  for i in range(2, len(ies)):
    ies[i].setEpsilon(epsilon)
    ies[i].setMinEpsilonRate(epsilonRate)
    ies[i].setPeriodSize(300)
    ies[i].setMaxTime(maxTime)

  for engine in ies:
    engine.setEvidence(evs)
    engine.makeInference()

  # one plot per approximate inference, each compared with the exact one
  fig, axes = plt.subplots(1, len(ies) - 1, figsize=(35, 3), num="gpplot")
  for i in range(len(ies) - 1):
    compareInference(ies[0], ies[i + 1], axes[i])

Inference stopped by epsilon

In [14]:
compareAllInference(bn, epsilon=1e-1)
In [15]:
compareAllInference(bn, epsilon=1e-2)

Inference stopped by time

In [16]:
compareAllInference(bn, maxTime=1, epsilon=1e-8)
In [17]:
compareAllInference(bn, maxTime=2, epsilon=1e-8)

Inference with Evidence (more complex)

In [18]:
funny = {"BP": 1, "PCWP": 2, "EXPCO2": 0, "HISTORY": 0}
compareAllInference(bn, maxTime=1, evs=funny, epsilon=1e-8)
In [19]:
compareAllInference(bn, maxTime=4, evs=funny, epsilon=1e-8)
In [20]:
compareAllInference(bn, maxTime=10, evs=funny, epsilon=1e-8)