Approximate inference in aGrUM (pyAgrum)
aGrUM (pyAgrum) provides several approximate inference algorithms for Bayesian networks. They share the same API as exact inference.
Loopy Belief Propagation: LBP is an approximate inference algorithm that applies the message-passing computations of exact inference (correct when the BN is a tree) even when the BN is not a tree. LBP comes with no guarantees: the algorithm may not converge and, even when it does, it may converge to anything but the exact posterior. However, LBP is fast and usually gives reasonably good results.
Sampling inference: sampling inference uses sampling to estimate the posterior. Sampling may be (very) slow, but these algorithms converge to the exact distribution. aGrUM implements:
Monte Carlo sampling,
Weighted sampling,
Importance sampling,
Gibbs sampling.
Finally, aGrUM proposes the so-called 'loopy versions' of the sampling algorithms: the idea is to use LBP as a Dirichlet prior for the sampling algorithm. A loopy version of each sampling algorithm is provided.
In [1]:
%matplotlib inline
from pylab import *
import matplotlib.pyplot as plt
def unsharpen(bn):
    """
    Force the parameters of the BN not to be too close to 0 or 1
    """
    for nod in bn.nodes():
        bn.cpt(nod).translate(bn.maxParam() / 10).normalizeAsCPT()
def compareInference(ie, ie2, ax=None):
    """
    Compare two inference engines by plotting all the points (posterior(ie), posterior(ie2))
    """
    exact = []
    appro = []
    errmax = 0
    for node in bn.nodes():
        # tensors as lists
        exact += ie.posterior(node).tolist()
        appro += ie2.posterior(node).tolist()
        errmax = max(errmax, (ie.posterior(node) - ie2.posterior(node)).abs().max())
    if errmax < 1e-10:
        errmax = 0
    if ax is None:
        fig = plt.Figure(figsize=(4, 4))
        ax = plt.gca()  # default axis for plt
    ax.plot(exact, appro, "ro")
    ax.set_title(
        "{} vs {}\n {}\nMax error {:2.4} in {:2.4} seconds".format(
            str(type(ie)).split(".")[2].split("_")[0][0:-2],  # name of the first inference
            str(type(ie2)).split(".")[2].split("_")[0][0:-2],  # name of the second inference
            ie2.messageApproximationScheme(),
            errmax,
            ie2.currentTime(),
        )
    )
In [2]:
import pyagrum as gum
import pyagrum.lib.notebook as gnb
bn = gum.loadBN("res/alarm.dsl")
unsharpen(bn)
ie = gum.LazyPropagation(bn)
ie.makeInference()
In [3]:
gnb.showBN(bn, size="8")
First, an exact inference.
In [4]:
gnb.sideBySide(gnb.getJunctionTreeMap(bn), gnb.getInference(bn, size="8")) # using LazyPropagation by default
print(ie.posterior("KINKEDTUBE"))
KINKEDTUBE |
TRUE |FALSE |
---------|---------|
0.1167 | 0.8833 |
Gibbs Inference
Gibbs inference with default parameters
Gibbs inference iterations can be stopped:
by the value of the error (epsilon),
by the rate of change of epsilon (MinEpsilonRate),
by the number of iterations (MaxIteration),
by the duration of the algorithm (MaxTime).
In [5]:
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(1e-2)
gnb.showInference(bn, engine=ie2, size="8")
print(ie2.posterior("KINKEDTUBE"))
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)
KINKEDTUBE |
TRUE |FALSE |
---------|---------|
0.0968 | 0.9032 |
stopped with rate=0.00673795
With default parameters, this inference was stopped by a low rate of change of epsilon.
Changing parameters
In [6]:
ie2 = gum.GibbsSampling(bn)
ie2.setMaxIter(1000)
ie2.setEpsilon(5e-3)
ie2.makeInference()
print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
INTUBATION |
NORMAL |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.8736 | 0.0664 | 0.0600 |
stopped with max iteration=1000
In [7]:
compareInference(ie, ie2)
In [8]:
ie2 = gum.GibbsSampling(bn)
ie2.setMaxTime(3)
ie2.makeInference()
print(ie2.posterior(2))
print(ie2.messageApproximationScheme())
compareInference(ie, ie2)
INTUBATION |
NORMAL |ESOPHAGEA|ONESIDED |
---------|---------|---------|
0.6433 | 0.1900 | 0.1667 |
stopped with epsilon=0.201897
Looking at the convergence
In [9]:
ie2 = gum.GibbsSampling(bn)
ie2.setEpsilon(10**-1.8)
ie2.setBurnIn(300)
ie2.setPeriodSize(300)
ie2.setDrawnAtRandom(True)
gnb.animApproximationScheme(ie2)
ie2.makeInference()
In [10]:
compareInference(ie, ie2)
Importance Sampling
In [11]:
ie4 = gum.ImportanceSampling(bn)
ie4.setEpsilon(10**-1.8)
ie4.setMaxTime(10) # 10 seconds for inference
ie4.setPeriodSize(300)
ie4.makeInference()
compareInference(ie, ie4)
Loopy Gibbs Sampling
Every sampling inference has a 'hybrid' version, which consists in first running loopy belief propagation and using its result as a prior for the probability estimations by sampling.
In [12]:
ie3 = gum.LoopyGibbsSampling(bn)
ie3.setEpsilon(10**-1.8)
ie3.setMaxTime(10) # 10 seconds for inference
ie3.setPeriodSize(300)
ie3.makeInference()
compareInference(ie, ie3)
Comparison of approximate inference
These computations may take a while.
In [13]:
def compareAllInference(bn, evs={}, epsilon=10**-1.6, epsilonRate=1e-8, maxTime=20):
    ies = [
        gum.LazyPropagation(bn),
        gum.LoopyBeliefPropagation(bn),
        gum.GibbsSampling(bn),
        gum.LoopyGibbsSampling(bn),
        gum.WeightedSampling(bn),
        gum.LoopyWeightedSampling(bn),
        gum.ImportanceSampling(bn),
        gum.LoopyImportanceSampling(bn),
    ]
    # burn-in for the Gibbs samplers
    for i in [2, 3]:
        ies[i].setBurnIn(300)
        ies[i].setDrawnAtRandom(True)
    # common parameters for all sampling engines
    for i in range(2, len(ies)):
        ies[i].setEpsilon(epsilon)
        ies[i].setMinEpsilonRate(epsilonRate)
        ies[i].setPeriodSize(300)
        ies[i].setMaxTime(maxTime)
    for i in range(len(ies)):
        ies[i].setEvidence(evs)
        ies[i].makeInference()
    # compare every approximate engine with the exact one
    fig, axes = plt.subplots(1, len(ies) - 1, figsize=(35, 3), num="gpplot")
    for i in range(len(ies) - 1):
        compareInference(ies[0], ies[i + 1], axes[i])
Inference stopped by epsilon
In [14]:
compareAllInference(bn, epsilon=1e-1)
In [15]:
compareAllInference(bn, epsilon=1e-2)
Inference stopped by time
In [16]:
compareAllInference(bn, maxTime=1, epsilon=1e-8)
In [17]:
compareAllInference(bn, maxTime=2, epsilon=1e-8)
Inference with Evidence (more complex)
In [18]:
funny = {"BP": 1, "PCWP": 2, "EXPCO2": 0, "HISTORY": 0}
compareAllInference(bn, maxTime=1, evs=funny, epsilon=1e-8)
In [19]:
compareAllInference(bn, maxTime=4, evs=funny, epsilon=1e-8)
In [20]:
compareAllInference(bn, maxTime=10, evs=funny, epsilon=1e-8)

