aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
ShaferShenoyMRFInference.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
49#ifndef GUM_SHAFER_SHENOY_MN_INFERENCE_H
50#define GUM_SHAFER_SHENOY_MN_INFERENCE_H
51
52#include <utility>
53
54#include <agrum/agrum.h>
55
60
61namespace gum {
62
63
64 // the function used to combine two tables
65 template < typename GUM_SCALAR >
66 INLINE static Tensor< GUM_SCALAR > SSNewMNmultiTensor(const Tensor< GUM_SCALAR >& t1,
67 const Tensor< GUM_SCALAR >& t2) {
68 return t1 * t2;
69 }
70
71 // the function used to combine two tables
72 template < typename GUM_SCALAR >
73 INLINE static Tensor< GUM_SCALAR > SSNewMNprojTensor(const Tensor< GUM_SCALAR >& t1,
74 const gum::VariableSet& del_vars) {
75 return t1.sumOut(del_vars);
76 }
77
85 template < typename GUM_SCALAR >
87 public JointTargetedMRFInference< GUM_SCALAR >,
88 public EvidenceMRFInference< GUM_SCALAR >,
89 public ScheduledInference {
90 public:
91 // ############################################################################
93 // ############################################################################
95
98 bool use_binary_join_tree = true);
99
102
104
105
106 // ############################################################################
108 // ############################################################################
110
112 void setTriangulation(const Triangulation& new_triangulation);
113
115
119
121
126
128 GUM_SCALAR evidenceProbability() final;
129
131
132
133 protected:
136 virtual bool isExactJointComputable_(const NodeSet& vars) final;
137 virtual NodeSet superForJointComputable_(const NodeSet& vars) final;
138
140 void onEvidenceAdded_(const NodeId id, bool isHardEvidence) final;
141
143 void onEvidenceErased_(const NodeId id, bool isHardEvidence) final;
144
146 void onAllEvidenceErased_(bool has_hard_evidence) final;
147
155 void onEvidenceChanged_(const NodeId id, bool hasChangedSoftHard) final;
156
158
159 void onMarginalTargetAdded_(const NodeId id) final;
160
162
163 void onMarginalTargetErased_(const NodeId id) final;
164
166 void onModelChanged_(const GraphicalModel* mn) final;
167
169 virtual void onMRFChanged_(const IMarkovRandomField< GUM_SCALAR >* mn) final;
170
172
173 void onJointTargetAdded_(const NodeSet& set) final;
174
176
177 void onJointTargetErased_(const NodeSet& set) final;
178
181
184
187
190
192 void onStateChanged_() final {};
193
195
199
201
205
207
208 void makeInference_() final;
209
210
212
213 const Tensor< GUM_SCALAR >& posterior_(NodeId id) final;
214
216
218 const Tensor< GUM_SCALAR >& jointPosterior_(const NodeSet& set) final;
219
227 const Tensor< GUM_SCALAR >& jointPosterior_(const NodeSet& wanted_target,
228 const NodeSet& declared_target) final;
229
231 Tensor< GUM_SCALAR >* unnormalizedJointPosterior_(NodeId id) final;
232
234 Tensor< GUM_SCALAR >* unnormalizedJointPosterior_(const NodeSet& set) final;
235
236
237 private:
238 using _TensorSet_ = Set< const Tensor< GUM_SCALAR >* >;
240 using _TensorSetIterator_ = SetIteratorSafe< const Tensor< GUM_SCALAR >* >;
241
243 Tensor< GUM_SCALAR > (*_projection_op_)(const Tensor< GUM_SCALAR >&,
245
247 Tensor< GUM_SCALAR > (*_combination_op_)(const Tensor< GUM_SCALAR >&,
248 const Tensor< GUM_SCALAR >&){SSNewMNmultiTensor};
249
252
256
258
261
263 JoinTree* _JT_{nullptr};
264
267
269
273
275
282
284
289
292
294
297
300
302
310
312
317
319
323
325
332
334
336
338
340
346
348
352
354
361
363
368
371
374
378
380 bool _use_schedules_{false};
381
383 static constexpr double _schedule_threshold_{1000000.0};
384
386 static constexpr GUM_SCALAR _one_minus_epsilon_{GUM_SCALAR(1.0 - 1e-6)};
387
388
390 bool _isNewJTNeeded_() const;
391
394
397
400
402 void _setProjectionFunction_(Tensor< GUM_SCALAR > (*proj)(const Tensor< GUM_SCALAR >&,
403 const gum::VariableSet&));
404
406 void _setCombinationFunction_(Tensor< GUM_SCALAR > (*comb)(const Tensor< GUM_SCALAR >&,
407 const Tensor< GUM_SCALAR >&));
408
410 void _diffuseMessageInvalidations_(NodeId from_id, NodeId to_id, NodeSet& invalidated_cliques);
411
414
417
421 _ScheduleMultiDimSet_ pot_list,
422 gum::VariableSet& del_vars,
423 gum::VariableSet& kept_vars);
424
428 gum::VariableSet& del_vars,
429 gum::VariableSet& kept_vars);
430
432 void _produceMessage_(Schedule& schedule, NodeId from_id, NodeId to_id);
433
435 void _produceMessage_(NodeId from_id, NodeId to_id);
436
438 void _collectMessage_(Schedule& schedule, NodeId id, NodeId from);
439
442
444 Tensor< GUM_SCALAR >* _unnormalizedJointPosterior_(Schedule& schedule, NodeId id);
445
447 Tensor< GUM_SCALAR >* _unnormalizedJointPosterior_(NodeId id);
448
450 Tensor< GUM_SCALAR >* _unnormalizedJointPosterior_(Schedule& schedule, const NodeSet& set);
451
453 Tensor< GUM_SCALAR >* _unnormalizedJointPosterior_(const NodeSet& set);
454
455
458
461 = delete;
462 };
463
464
// Explicit instantiation declaration: prevents ShaferShenoyMRFInference<double>
// from being implicitly instantiated in every translation unit that includes
// this header. A matching explicit instantiation definition must exist in some
// other translation unit (presumably the library's .cpp for this class; verify).
// Defining GUM_NO_EXTERN_TEMPLATE_CLASS disables this declaration.
465#ifndef GUM_NO_EXTERN_TEMPLATE_CLASS
466 extern template class ShaferShenoyMRFInference< double >;
467#endif
468
469
470} /* namespace gum */
471
473
474#endif /* GUM_SHAFER_SHENOY_MN_INFERENCE_H */
Implementation of Shafer-Shenoy's propagation for inference in Markov random fields.
EvidenceMRFInference(const IMarkovRandomField< GUM_SCALAR > *mn)
default constructor
Virtual base class for probabilistic graphical models.
The class for generic Hash Tables.
Definition hashTable.h:637
Class representing the minimal interface for Markov random field.
The Table-agnostic base class of scheduleMultiDim.
JointTargetedMRFInference(const IMarkovRandomField< GUM_SCALAR > *mn)
default constructor
Class containing a schedule of operations to perform on multidims.
Definition schedule.h:80
ScheduledInference(Size max_nb_threads=0, double max_megabyte_memory=0.0)
default constructor
Safe iterators for the Set class.
Definition set.h:601
Representation of a set.
Definition set.h:131
<agrum/MRF/inference/ShaferShenoyMRFInference.h>
void onAllJointTargetsErased_() final
fired before all the joint targets are removed
HashTable< const Tensor< GUM_SCALAR > *, GUM_SCALAR > _constants_
the constants resulting from the projections of CPTs defined over only hard evidence nodes @TODO remo...
void _setProjectionFunction_(Tensor< GUM_SCALAR >(*proj)(const Tensor< GUM_SCALAR > &, const gum::VariableSet &))
sets the operator for performing the projections
UndiGraph _graph_
the undigraph extracted from the MRF and used to construct the join tree
bool _use_binary_join_tree_
indicates whether we should transform junction trees into binary join trees
void onAllEvidenceErased_(bool has_hard_evidence) final
fired before all the evidence is erased
Triangulation * _triangulation_
the triangulation class creating the junction tree used for inference
void updateOutdatedTensors_() final
prepares inference when the latter is in OutdatedTensors state
NodeProperty< EvidenceChangeType > _evidence_changes_
indicates which nodes of the MRF have evidence that changed since the last inference
void setTriangulation(const Triangulation &new_triangulation)
use a new triangulation algorithm
void _setCombinationFunction_(Tensor< GUM_SCALAR >(*comb)(const Tensor< GUM_SCALAR > &, const Tensor< GUM_SCALAR > &))
sets the operator for performing the combinations
ShaferShenoyMRFInference< GUM_SCALAR > & operator=(const ShaferShenoyMRFInference< GUM_SCALAR > &)=delete
avoid copy operators
NodeProperty< const IScheduleMultiDim * > _clique_ss_tensor_
the tensors stored into the cliques by Shafer-Shenoy
NodeProperty< _ScheduleMultiDimSet_ > _clique_tensors_
the list of all tensors stored in the cliques
HashTable< const Tensor< GUM_SCALAR > *, const IScheduleMultiDim * > _hard_ev_projected_factors_
the factors that were projected due to hard evidence nodes
bool _is_new_jt_needed_
indicates whether a new join tree is needed for the next inference
void _initializeJTCliques_(Schedule &schedule)
put all the CPTs into the cliques when creating the JT using a schedule
void makeInference_() final
called when the inference has to be performed effectively
void onStateChanged_() final
fired when the state of the inference engine is changed
JoinTree * _JT_
the join (or junction) tree used to answer the last inference query
const JoinTree * joinTree()
returns the current join tree used
void _initializeJTCliques_()
put all the CPTs into the cliques when creating the JT without using a schedule
void onAllMarginalTargetsAdded_() final
fired after all the nodes of the MRF are added as single targets
const Tensor< GUM_SCALAR > & jointPosterior_(const NodeSet &set) final
returns the posterior of a declared target set
virtual bool isExactJointComputable_(const NodeSet &vars) final
check if the vars form a possible computable joint (can be redefined by subclass)
Tensor< GUM_SCALAR > * _unnormalizedJointPosterior_(Schedule &schedule, NodeId id)
computes the unnormalized posterior of a node using schedules
const IScheduleMultiDim * _marginalizeOut_(_ScheduleMultiDimSet_ &pot_list, gum::VariableSet &del_vars, gum::VariableSet &kept_vars)
removes variables del_vars from a list of tensors and returns the resulting list directly without sch...
void _createNewJT_()
create a new junction tree as well as its related data structures
HashTable< NodeSet, const Tensor< GUM_SCALAR > * > _joint_target_posteriors_
the posteriors of the joint (set) targets computed during the last inference
void onMarginalTargetAdded_(const NodeId id) final
fired after a new single target is inserted
HashTable< NodeSet, NodeId > _joint_target_to_clique_
for each set target, assign a clique in the JT that contains it
ArcProperty< bool > _messages_computed_
indicates whether a message (from one clique to another) has been computed
NodeProperty< _TensorSet_ > _node_to_factors_
assign to each node the set of factors containing it
Set< const IScheduleMultiDim * > _ScheduleMultiDimSet_
void _invalidateAllMessages_()
invalidate all messages, posteriors and created tensors
void updateOutdatedStructure_() final
prepares inference when the latter is in OutdatedStructure state
void onJointTargetAdded_(const NodeSet &set) final
fired after a new joint target is inserted
void _produceMessage_(NodeId from_id, NodeId to_id)
creates the message sent by clique from_id to clique to_id without schedules
void _collectMessage_(Schedule &schedule, NodeId id, NodeId from)
perform the collect phase using schedules
HashTable< const Tensor< GUM_SCALAR > *, NodeId > _factor_to_clique_
assign to each factor in the MRF the clique that will contain it
Tensor< GUM_SCALAR >(* _projection_op_)(const Tensor< GUM_SCALAR > &, const gum::VariableSet &)
the operator for performing the projections
GUM_SCALAR evidenceProbability() final
returns the probability of evidence
bool _use_schedules_
indicates whether we should use schedules for inference
static constexpr double _schedule_threshold_
minimal number of operations to perform in the JT to use schedules
NodeSet _hard_ev_nodes_
the hard evidence nodes which were projected in factors
Tensor< GUM_SCALAR > * _unnormalizedJointPosterior_(NodeId id)
computes the unnormalized posterior of a node without using schedules
void onModelChanged_(const GraphicalModel *mn) final
fired after a new Markov net has been assigned to the inference engine
NodeProperty< NodeId > _node_to_clique_
for each node of graph (~ in the Markov net), associate an ID in the JT
void onAllMarginalTargetsErased_() final
fired before all the single targets are removed
virtual NodeSet superForJointComputable_(const NodeSet &vars) final
void onEvidenceChanged_(const NodeId id, bool hasChangedSoftHard) final
fired after an evidence is changed, in particular when its status (soft/hard) changes
const Tensor< GUM_SCALAR > & posterior_(NodeId id) final
returns the posterior of a given variable
Tensor< GUM_SCALAR > * unnormalizedJointPosterior_(NodeId id) final
returns a fresh tensor equal to P(argument,evidence)
void onEvidenceErased_(const NodeId id, bool isHardEvidence) final
fired before an evidence is removed
void onMarginalTargetErased_(const NodeId id) final
fired before a single target is removed
EvidenceChangeType
the possible types of evidence changes
void onEvidenceAdded_(const NodeId id, bool isHardEvidence) final
fired after a new evidence is inserted
void _diffuseMessageInvalidations_(NodeId from_id, NodeId to_id, NodeSet &invalidated_cliques)
invalidate all the messages sent from a given clique
void _computeJoinTreeRoots_()
compute a root for each connected component of JT
void onJointTargetErased_(const NodeSet &set) final
fired before a joint target is removed
Set< const Tensor< GUM_SCALAR > * > _TensorSet_
SetIteratorSafe< const Tensor< GUM_SCALAR > * > _TensorSetIterator_
const IScheduleMultiDim * _marginalizeOut_(Schedule &schedule, _ScheduleMultiDimSet_ pot_list, gum::VariableSet &del_vars, gum::VariableSet &kept_vars)
removes variables del_vars from a list of tensors and returns the resulting list using schedules
ArcProperty< const IScheduleMultiDim * > _arc_to_created_tensors_
the set of tensors created for the last inference messages
NodeSet _roots_
a clique node used as a root in each connected component of JT
void onAllTargetsErased_() final
fired before all single and joint targets are removed
Tensor< GUM_SCALAR >(* _combination_op_)(const Tensor< GUM_SCALAR > &, const Tensor< GUM_SCALAR > &)
the operator for performing the combinations
bool _isNewJTNeeded_() const
check whether a new join tree is really needed for the next inference
void _collectMessage_(NodeId id, NodeId from)
actually perform the collect phase directly without schedules
Tensor< GUM_SCALAR > * _unnormalizedJointPosterior_(const NodeSet &set)
returns a fresh tensor equal to P(argument,evidence) without using schedules
ArcProperty< const IScheduleMultiDim * > _separator_tensors_
the list of all tensors stored in the separators after inferences
void _produceMessage_(Schedule &schedule, NodeId from_id, NodeId to_id)
creates the message sent by clique from_id to clique to_id using schedules
NodeProperty< const IScheduleMultiDim * > _node_to_soft_evidence_
the soft evidence stored in the cliques per their assigned node in the MRF
JunctionTree * _junctionTree_
the junction tree to answer the last inference query
Tensor< GUM_SCALAR > * _unnormalizedJointPosterior_(Schedule &schedule, const NodeSet &set)
returns a fresh tensor equal to P(argument,evidence) using schedules
ShaferShenoyMRFInference(const IMarkovRandomField< GUM_SCALAR > *MN, bool use_binary_join_tree=true)
default constructor
const JunctionTree * junctionTree()
returns the current junction tree
NodeProperty< const Tensor< GUM_SCALAR > * > _target_posteriors_
the set of single posteriors computed during the last inference
virtual void onMRFChanged_(const IMarkovRandomField< GUM_SCALAR > *mn) final
fired after a new Markov net has been assigned to the inference engine
static constexpr GUM_SCALAR _one_minus_epsilon_
for comparisons with 1 - epsilon
ShaferShenoyMRFInference(const ShaferShenoyMRFInference< GUM_SCALAR > &)=delete
avoid copy constructors
aGrUM's Tensor is a multi-dimensional array with tensor operators.
Definition tensor.h:85
Interface for all the triangulation methods.
Base class for undirected graphs.
Definition undiGraph.h:128
Class for computing default triangulations of graphs.
This file contains the abstract class definition for computing the probability of evidence entered in...
Size NodeId
Type for node ids.
HashTable< Arc, VAL > ArcProperty
Property on graph elements.
HashTable< NodeId, VAL > NodeProperty
Property on graph elements.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
This file contains the abstract inference class definition for computing (incrementally) joint poster...
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
static INLINE Tensor< GUM_SCALAR > SSNewMNmultiTensor(const Tensor< GUM_SCALAR > &t1, const Tensor< GUM_SCALAR > &t2)
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
static INLINE Tensor< GUM_SCALAR > SSNewMNprojTensor(const Tensor< GUM_SCALAR > &t1, const gum::VariableSet &del_vars)
The class enabling flexible inferences using schedules.