aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
ShaferShenoyMRFInference_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50
51#ifndef DOXYGEN_SHOULD_SKIP_THIS
52# include <algorithm>
53
59
60namespace gum {
61 // default constructor
62 template < typename GUM_SCALAR >
65 bool use_binary_join_tree) :
66 JointTargetedMRFInference< GUM_SCALAR >(MN), EvidenceMRFInference< GUM_SCALAR >(MN),
67 _use_binary_join_tree_(use_binary_join_tree) {
68 // create a default triangulation (the user can change it afterwards)
69 _triangulation_ = new DefaultTriangulation;
70
71 // for each node in the MRF, assign the set of factors that contain it
72 const auto& graph = this->MRF().graph();
73 _node_to_factors_.resize(graph.size());
74 _TensorSet_ empty;
75 for (const auto node: graph)
76 _node_to_factors_.insert(node, empty);
77 for (const auto& factor: this->MRF().factors()) {
78 for (const auto node: factor.first) {
79 _node_to_factors_[node].insert(factor.second);
80 }
81 }
82
83 // for debugging purposes
84 GUM_CONSTRUCTOR(ShaferShenoyMRFInference);
85 }
86
  // destructor: releases every tensor owned by the engine, then the join
  // trees and the triangulation algorithm
  template < typename GUM_SCALAR >
  INLINE ShaferShenoyMRFInference< GUM_SCALAR >::~ShaferShenoyMRFInference() {
    // remove all the tensors created during the last message passing
    for (const auto& pot: _arc_to_created_tensors_)
      delete pot.second;

    // remove all the tensors in _clique_ss_tensor_ that do not belong
    // to _clique_tensors_: in this case, those tensors have been
    // created by combination of the corresponding list of tensors in
    // _clique_tensors_. In other words, the size of this list is strictly
    // greater than 1.
    // NOTE: when the list has size <= 1, the entry of _clique_ss_tensor_
    // aliases the single element of _clique_tensors_ and is freed by the
    // loop just below, hence it must NOT be deleted here (double free)
    for (auto pot: _clique_ss_tensor_) {
      if (_clique_tensors_[pot.first].size() > 1) delete pot.second;
    }

    for (auto potset: _clique_tensors_) {
      for (auto pot: potset.second)
        delete pot;
    }

    // remove all the posteriors computed
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;

    // remove the junction tree and the triangulation algorithm
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;
    delete _triangulation_;

    // for debugging purposes
    GUM_DESTRUCTOR(ShaferShenoyMRFInference);
  }
122
124 template < typename GUM_SCALAR >
125 void ShaferShenoyMRFInference< GUM_SCALAR >::setTriangulation(
126 const Triangulation& new_triangulation) {
127 delete _triangulation_;
128 _triangulation_ = new_triangulation.newFactory();
129 _is_new_jt_needed_ = true;
130 this->setOutdatedStructureState_();
131 }
132
134 template < typename GUM_SCALAR >
135 INLINE const JoinTree* ShaferShenoyMRFInference< GUM_SCALAR >::joinTree() {
136 if (_is_new_jt_needed_) _createNewJT_();
137
138 return _JT_;
139 }
140
142 template < typename GUM_SCALAR >
143 INLINE const JunctionTree* ShaferShenoyMRFInference< GUM_SCALAR >::junctionTree() {
144 if (_is_new_jt_needed_) _createNewJT_();
145
146 return _junctionTree_;
147 }
148
150 template < typename GUM_SCALAR >
151 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::_setProjectionFunction_(
152 Tensor< GUM_SCALAR > (*proj)(const Tensor< GUM_SCALAR >&, const gum::VariableSet&)) {
153 _projection_op_ = proj;
154
155 // indicate that all messages need be reconstructed to take into account
156 // the change in of the projection operator
157 _invalidateAllMessages_();
158 }
159
161 template < typename GUM_SCALAR >
162 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::_setCombinationFunction_(
163 Tensor< GUM_SCALAR > (*comb)(const Tensor< GUM_SCALAR >&, const Tensor< GUM_SCALAR >&)) {
164 _combination_op_ = comb;
165
166 // indicate that all messages need be reconstructed to take into account
167 // the change of the combination operator
168 _invalidateAllMessages_();
169 }
170
  /// invalidates all the messages sent during the previous message passing,
  /// as well as all cached posteriors
  template < typename GUM_SCALAR >
  void ShaferShenoyMRFInference< GUM_SCALAR >::_invalidateAllMessages_() {
    // remove all the messages computed: the separator entries are kept but
    // now point to no tensor (the tensors themselves are owned by
    // _arc_to_created_tensors_ and are deleted below, hence the ordering)
    for (auto& pot: _separator_tensors_)
      pot.second = nullptr;

    for (auto& mess_computed: _messages_computed_)
      mess_computed.second = false;

    // remove all the created tensors kept on the arcs
    for (const auto& pot: _arc_to_created_tensors_)
      if (pot.second != nullptr) delete pot.second;
    _arc_to_created_tensors_.clear();

    // remove all the posteriors (marginal and joint) cached so far
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    _target_posteriors_.clear();
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;
    _joint_target_posteriors_.clear();

    // indicate that new messages need be computed
    if (this->isInferenceReady() || this->isInferenceDone()) this->setOutdatedTensorsState_();
  }
197
199 template < typename GUM_SCALAR >
200 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceAdded_(const NodeId id,
201 bool isHardEvidence) {
202 // if we have a new hard evidence, this modifies the undigraph over which
203 // the join tree is created. This is also the case if id is not a node of
204 // of the undigraph
205 if (isHardEvidence || !_graph_.exists(id)) _is_new_jt_needed_ = true;
206 else {
207 try {
208 _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ADDED);
209 } catch (DuplicateElement const&) {
210 // here, the evidence change already existed. This necessarily means
211 // that the current saved change is an EVIDENCE_ERASED. So if we
212 // erased the evidence and added some again, this corresponds to an
213 // EVIDENCE_MODIFIED
214 _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
215 }
216 }
217 }
218
220 template < typename GUM_SCALAR >
221 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceErased_(const NodeId id,
222 bool isHardEvidence) {
223 // if we delete a hard evidence, this modifies the undigraph over which
224 // the join tree is created.
225 if (isHardEvidence) _is_new_jt_needed_ = true;
226 else {
227 try {
228 _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_ERASED);
229 } catch (DuplicateElement const&) {
230 // here, the evidence change already existed and it is necessarily an
231 // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
232 // been added and is now erased, this is similar to not having created
233 // it. If the evidence was only modified, it already existed in the
234 // last inference and we should now indicate that it has been removed.
235 if (_evidence_changes_[id] == EvidenceChangeType::EVIDENCE_ADDED)
236 _evidence_changes_.erase(id);
237 else _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_ERASED;
238 }
239 }
240 }
241
243 template < typename GUM_SCALAR >
244 void ShaferShenoyMRFInference< GUM_SCALAR >::onAllEvidenceErased_(bool has_hard_evidence) {
245 if (has_hard_evidence || !this->hardEvidenceNodes().empty()) _is_new_jt_needed_ = true;
246 else {
247 for (const auto node: this->softEvidenceNodes()) {
248 try {
249 _evidence_changes_.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
250 } catch (DuplicateElement const&) {
251 // here, the evidence change already existed and it is necessarily an
252 // EVIDENCE_ADDED or an EVIDENCE_MODIFIED. So, if the evidence has
253 // been added and is now erased, this is similar to not having created
254 // it. If the evidence was only modified, it already existed in the
255 // last inference and we should now indicate that it has been removed.
256 if (_evidence_changes_[node] == EvidenceChangeType::EVIDENCE_ADDED)
257 _evidence_changes_.erase(node);
258 else _evidence_changes_[node] = EvidenceChangeType::EVIDENCE_ERASED;
259 }
260 }
261 }
262 }
263
265 template < typename GUM_SCALAR >
266 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceChanged_(const NodeId id,
267 bool hasChangedSoftHard) {
268 if (hasChangedSoftHard) _is_new_jt_needed_ = true;
269 else {
270 try {
271 _evidence_changes_.insert(id, EvidenceChangeType::EVIDENCE_MODIFIED);
272 } catch (DuplicateElement const&) {
273 // here, the evidence change already existed and it is necessarily an
274 // EVIDENCE_ADDED. So we should keep this state to indicate that this
275 // evidence is new w.r.t. the last inference
276 }
277 }
278 }
279
  /// fired after a new marginal target is inserted; intentionally empty:
  /// no incremental bookkeeping is needed for single-node targets
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId id) {}
283
  /// fired before a marginal target is removed; intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onMarginalTargetErased_(const NodeId id) {}
287
  /// fired after a new joint target is inserted; intentionally empty: joint
  /// targets are taken into account when the join tree is (re)constructed
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onJointTargetAdded_(const NodeSet& set) {}
291
  /// fired before a joint target is removed; intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onJointTargetErased_(const NodeSet& set) {}
295
  /// fired after all the nodes of the MRF are added as marginal targets;
  /// intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
299
  /// fired before all the marginal targets are removed; intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
303
  /// fired after a new Markov random field is assigned to the engine;
  /// intentionally empty (base classes perform the necessary resets)
  /// NOTE(review): assumption based on the empty body — confirm against the
  /// base-class implementation
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onMRFChanged_(
      const IMarkovRandomField< GUM_SCALAR >* mn) {}
308
  /// fired before all the joint targets are removed; intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onAllJointTargetsErased_() {}
312
  /// fired before all the (marginal and joint) targets are removed;
  /// intentionally empty
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onAllTargetsErased_() {}
316
  // check whether a new junction tree is really needed for the next inference
  template < typename GUM_SCALAR >
  bool ShaferShenoyMRFInference< GUM_SCALAR >::_isNewJTNeeded_() const {
    // if we do not have a JT or if _is_new_jt_needed_ is set to true, then
    // we know that we need to create a new join tree
    if ((_JT_ == nullptr) || _is_new_jt_needed_) return true;

    // if some targets do not belong to the join tree and, consequently, to the
    // undirected graph that was used to construct the join tree, then we need
    // to create a new JT: this may happen due to the nodes that received hard
    // evidence (which do not belong to the graph).
    const auto& hard_ev_nodes = this->hardEvidenceNodes();
    for (const auto node: this->targets()) {
      if (!_graph_.exists(node) && !hard_ev_nodes.exists(node)) return true;
    }

    // now, do the same for the joint targets: map each node to its rank in
    // the last triangulation's elimination order so that, for each joint
    // target, we can retrieve the first node of the target that was eliminated
    const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
    NodeProperty< int > elim_order(Size(JT_elim_order.size()));
    for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
      elim_order.insert(JT_elim_order[i], (int)i);
    NodeSet unobserved_set;

    for (const auto& joint_target: this->jointTargets()) {
      // here, we need to check that at least one clique contains all the
      // nodes of the joint target.
      NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
      int elim_number = std::numeric_limits< int >::max();
      unobserved_set.clear();
      for (const auto node: joint_target) {
        if (!_graph_.exists(node)) {
          // a target node outside the graph must have received hard evidence,
          // otherwise the current JT cannot answer the query
          if (!hard_ev_nodes.exists(node)) return true;
        } else {
          unobserved_set.insert(node);
          if (elim_order[node] < elim_number) {
            elim_number = elim_order[node];
            first_eliminated_node = node;
          }
        }
      }
      if (!unobserved_set.empty()) {
        // here, first_eliminated_node contains the first var (node or one of its
        // parents) eliminated => the clique created during its elimination
        // should contain all the nodes in unobserved_set
        const auto clique_id = _node_to_clique_[first_eliminated_node];
        const auto& clique = _JT_->clique(clique_id);
        for (const auto node: unobserved_set) {
          if (!clique.contains(node)) return true;
        }
      }
    }

    // if some new evidence have been added on nodes that do not belong
    // to _graph_, then we potentially have to reconstruct the join tree
    for (const auto& change: _evidence_changes_) {
      if ((change.second == EvidenceChangeType::EVIDENCE_ADDED) && !_graph_.exists(change.first))
        return true;
    }

    // here, the current JT is exactly what we need for the next inference
    return false;
  }
379
  /// creates a new junction tree (and binary join tree) suited to the current
  /// targets and evidence, and resets all the clique/separator data structures
  template < typename GUM_SCALAR >
  void ShaferShenoyMRFInference< GUM_SCALAR >::_createNewJT_() {
    // to create the JT, we first create the required subgraph of the MRF in the
    // following way, in order to take into account the nodes that received
    // evidence:
    // 1/ we copy the graph of the MRF
    // 2/ add edges so that joint targets form a clique of the graph
    // 3/ remove the nodes that received hard evidence
    //
    // At the end of step 3/, we have our required graph and we can triangulate
    // it to get the new junction tree

    // 1/ copy the undirected graph of the MRF
    const auto& mn = this->MRF();
    _graph_ = mn.graph();

    // 2/ if there exist some joint targets, we shall add new edges into the
    // undirected graph in order to ensure that there exists a clique containing
    // each joint target
    for (const auto& nodeset: this->jointTargets()) {
      // add an edge between every pair of nodes of the joint target
      for (auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
        auto iter2 = iter1;
        for (++iter2; iter2 != nodeset.cend(); ++iter2) {
          _graph_.addEdge(*iter1, *iter2);
        }
      }
    }

    // 3/ remove all the nodes that received hard evidence
    _hard_ev_nodes_ = this->hardEvidenceNodes();
    for (const auto node: _hard_ev_nodes_) {
      _graph_.eraseNode(node);
    }


    // now, we can compute the new junction tree. To speed-up computations
    // (essentially, those of a distribution phase), we construct from this
    // junction tree a binary join tree
    if (_JT_ != nullptr) delete _JT_;
    if (_junctionTree_ != nullptr) delete _junctionTree_;

    const auto& domain_sizes = this->domainSizes();
    _triangulation_->setGraph(&_graph_, &domain_sizes);
    const JunctionTree& triang_jt = _triangulation_->junctionTree();
    if (_use_binary_join_tree_) {
      BinaryJoinTreeConverterDefault bjt_converter;
      NodeSet emptyset;
      _JT_ = new CliqueGraph(bjt_converter.convert(triang_jt, domain_sizes, emptyset));
    } else {
      _JT_ = new CliqueGraph(triang_jt);
    }
    // keep a copy of the (non-binarized) junction tree for junctionTree()
    _junctionTree_ = new CliqueGraph(triang_jt);


    // assign to each node the order in which it was eliminated by the triangulation
    const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
    Size size_elim_order = JT_elim_order.size();
    NodeProperty< int > elim_order(size_elim_order);
    for (Idx i = Idx(0); i < size_elim_order; ++i)
      elim_order.insert(JT_elim_order[i], (int)i);

    // assign to each factor of the Markov random field a clique in _JT_
    // that can contain its conditional probability table
    _factor_to_clique_.clear();
    _factor_to_clique_.resize(mn.factors().size());
    for (const auto& factor: mn.factors()) {
      const auto& nodes = factor.first; // factor.second is the Tensor()
      NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
      int elim_number = std::numeric_limits< int >::max();
      for (const auto node: nodes) {
        if (_graph_.exists(node) && (elim_order[node] < elim_number)) {
          elim_number = elim_order[node];
          first_eliminated_node = node;
        }
      }

      // elim_number still at max means no node of the factor belongs to
      // _graph_ (all received hard evidence): no clique is assigned
      if (elim_number != std::numeric_limits< int >::max()) {
        // first_eliminated_node contains the first var/node eliminated => the
        // clique created during its elimination must contain node and all of its
        // neighbors => it necessarily contains all the nodes of factor
        _factor_to_clique_.insert(
            factor.second,
            _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // assign to each node that did not receive some hard evidence the smallest
    // clique that contains it
    _node_to_clique_.clear();
    _node_to_clique_.resize(_graph_.size());
    NodeProperty< double > node_to_clique_size(_graph_.size());
    for (const auto node: _graph_) {
      _node_to_clique_.insert(node, std::numeric_limits< NodeId >::max());
      node_to_clique_size.insert(node, std::numeric_limits< double >::max());
    }
    double overall_size = 0;
    for (const auto clique_id: *_JT_) {
      // determine the size of clique_id (product of its domain sizes)
      const auto& clique_nodes = _JT_->clique(clique_id);
      double clique_size = 1.0;
      for (const auto node: clique_nodes)
        clique_size *= double(domain_sizes[node]);
      overall_size += clique_size;

      // assign the clique to the nodes if its size is smaller than that of the
      // current cliques assigned to them
      for (const auto node: clique_nodes) {
        if (clique_size < node_to_clique_size[node]) {
          _node_to_clique_[node] = clique_id;
          node_to_clique_size[node] = clique_size;
        }
      }
    }

    // indicate for each joint_target a clique that contains it
    _joint_target_to_clique_.clear();
    for (const auto& set: this->jointTargets()) {
      NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
      int elim_number = std::numeric_limits< int >::max();

      // do not take into account the nodes that received hard evidence
      // (since they do not belong to the join tree)
      for (const auto node: set) {
        if (!_hard_ev_nodes_.contains(node)) {
          // the clique we are looking for is the one that was created when
          // the first element of nodeset was eliminated
          if (elim_order[node] < elim_number) {
            elim_number = elim_order[node];
            first_eliminated_node = node;
          }
        }
      }

      if (elim_number != std::numeric_limits< int >::max()) {
        _joint_target_to_clique_.insert(
            set,
            _triangulation_->createdJunctionTreeClique(first_eliminated_node));
      }
    }

    // compute the roots of _JT_'s connected components
    _computeJoinTreeRoots_();

    // remove all the tensors stored into the cliques. Note that these include
    // the factors resulting from the projections of hard evidence as well as the
    // CPTs of the soft evidence
    for (const auto& pot: _clique_ss_tensor_) {
      // entries whose tensor list has size <= 1 alias an element of
      // _clique_tensors_ and are freed by the loop below
      if (_clique_tensors_[pot.first].size() > 1) delete pot.second;
    }
    _clique_ss_tensor_.clear();
    for (const auto& potlist: _clique_tensors_)
      for (const auto pot: potlist.second)
        delete pot;
    _clique_tensors_.clear();

    // remove all the tensors created during the last inference
    for (const auto& pot: _arc_to_created_tensors_)
      delete pot.second;
    _arc_to_created_tensors_.clear();

    // remove all the tensors created to take into account hard evidence
    // during the last inference (they have already been deleted from memory
    // by the clearing of _clique_tensors_).
    _hard_ev_projected_factors_.clear();

    // remove all the soft evidence.
    _node_to_soft_evidence_.clear();

    // create empty tensor lists into the cliques of the joint tree as well
    // as empty lists of evidence
    _ScheduleMultiDimSet_ empty_set;
    for (const auto node: *_JT_) {
      _clique_tensors_.insert(node, empty_set);
      _clique_ss_tensor_.insert(node, nullptr);
    }

    // remove all the constants created due to projections of CPTs that were
    // defined over only hard evidence nodes
    _constants_.clear();

    // create empty messages and indicate that no message has been computed yet
    // (one message per direction of each join tree edge)
    _separator_tensors_.clear();
    _messages_computed_.clear();
    for (const auto& edge: _JT_->edges()) {
      const Arc arc1(edge.first(), edge.second());
      _separator_tensors_.insert(arc1, nullptr);
      _messages_computed_.insert(arc1, false);
      const Arc arc2(edge.second(), edge.first());
      _separator_tensors_.insert(arc2, nullptr);
      _messages_computed_.insert(arc2, false);
    }

    // remove all the posteriors computed so far
    for (const auto& pot: _target_posteriors_)
      delete pot.second;
    _target_posteriors_.clear();
    for (const auto& pot: _joint_target_posteriors_)
      delete pot.second;
    _joint_target_posteriors_.clear();

    // here, we determine whether we should use schedules during the inference.
    // the rule is: if the sum of the domain sizes of the cliques is greater
    // than a threshold, use schedules
    _use_schedules_ = (overall_size > _schedule_threshold_);

    // we shall now add all the tensors of the soft evidence to the cliques
    const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
    for (const auto node: this->softEvidenceNodes()) {
      if (_node_to_clique_.exists(node)) {
        auto ev_pot = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*evidence[node], false);
        _node_to_soft_evidence_.insert(node, ev_pot);
        _clique_tensors_[_node_to_clique_[node]].insert(ev_pot);
      }
    }

    // put all the factors of the MRF into the cliques
    // here, beware: all the tensors that are defined over some nodes
    // including hard evidence must be projected so that these nodes are
    // removed from the tensor
    if (_use_schedules_) {
      Schedule schedule;
      _initializeJTCliques_(schedule);
    } else {
      _initializeJTCliques_();
    }


    // indicate that the data structures are up to date.
    _evidence_changes_.clear();
    _is_new_jt_needed_ = false;
  }
612
614 template < typename GUM_SCALAR >
615 void ShaferShenoyMRFInference< GUM_SCALAR >::_initializeJTCliques_() {
616 const auto& mn = this->MRF();
617
618 // put all the factors of the MRF into the cliques
619 // here, beware: all the tensors that are defined over some nodes
620 // including hard evidence must be projected so that these nodes are
621 // removed from the tensor
622 const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
623 const NodeProperty< Idx >& hard_evidence = this->hardEvidence();
624
625 for (const auto& factor: mn.factors()) {
626 const auto& factor_nodes = factor.first;
627 const auto& pot = *(factor.second);
628 const auto& variables = pot.variablesSequence();
629
630 // get the list of nodes with hard evidence in the factor
631 NodeSet hard_nodes(factor_nodes.size());
632 bool graph_contains_nodes = false;
633 for (const auto node: factor_nodes) {
634 if (_hard_ev_nodes_.contains(node)) hard_nodes.insert(node);
635 else if (_graph_.exists(node)) graph_contains_nodes = true;
636 }
637
638 // if hard_nodes contains hard evidence nodes, perform a projection
639 // and insert the result into the appropriate clique, else insert
640 // directly pot into the clique
641 if (hard_nodes.empty()) {
642 auto sched_cpt = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(pot, false);
643 _clique_tensors_[_factor_to_clique_[&pot]].insert(sched_cpt);
644 } else {
645 // marginalize out the hard evidence nodes: if factor_nodes is defined
646 // only over nodes that received hard evidence, do not consider it
647 // as a tensor anymore but as a constant
648 // TODO substitute constants by 0-dimensional tensors
649 if (hard_nodes.size() == factor_nodes.size()) {
650 Instantiation inst(pot);
651 for (Size i = 0; i < hard_nodes.size(); ++i) {
652 inst.chgVal(*variables[i], hard_evidence[mn.nodeId(*(variables[i]))]);
653 }
654 _constants_.insert(&pot, pot.get(inst));
655 } else {
656 // here, we have a factor defined over some nodes that received hard
657 // evidence and other nodes that did not receive it. If none of the
658 // latter belong to the graph, then the factor is useless for inference
659 if (!graph_contains_nodes) continue;
660
661 // prepare the projection with a combine and project instance
662 gum::VariableSet hard_variables;
663 _TensorSet_ marg_factor_set(1 + hard_nodes.size());
664 marg_factor_set.insert(&pot);
665 for (const auto node: hard_nodes) {
666 marg_factor_set.insert(evidence[node]);
667 hard_variables.insert(&(mn.variable(node)));
668 }
669
670 // perform the combination of those tensors and their projection
671 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
672 _combination_op_,
673 _projection_op_);
674
675 _TensorSet_ new_factor_list
676 = combine_and_project.execute(marg_factor_set, hard_variables);
677
678 // there should be only one tensor in new_factor_list
679 if (new_factor_list.size() != 1) {
680 for (const auto pot: new_factor_list) {
681 if (!marg_factor_set.contains(pot)) delete pot;
682 }
684 "the projection of a tensor containing " << "hard evidence is empty!");
685 }
686 auto new_factor = const_cast< Tensor< GUM_SCALAR >* >(*(new_factor_list.begin()));
687 auto projected_factor
688 = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*new_factor));
689 delete new_factor;
690
691 _clique_tensors_[_factor_to_clique_[&pot]].insert(projected_factor);
692 _hard_ev_projected_factors_.insert(&pot, projected_factor);
693 }
694 }
695 }
696
697 // now, in _clique_tensors_, for each clique, we have the list of
698 // tensors that must be combined in order to produce the Shafer-Shenoy's
699 // tensor stored into the clique. So, perform this combination and
700 // store the result in _clique_ss_tensor_
701 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
702 for (const auto& xpotset: _clique_tensors_) {
703 const auto& potset = xpotset.second;
704 if (potset.size() > 0) {
705 // here, there will be an entry in _clique_ss_tensor_
706 // If there is only one element in potset, this element shall be
707 // stored into _clique_ss_tensor_, else all the elements of potset
708 // shall be combined and their result shall be stored
709 if (potset.size() == 1) {
710 _clique_ss_tensor_[xpotset.first] = *(potset.cbegin());
711 } else {
712 _TensorSet_ p_potset(potset.size());
713 for (const auto pot: potset)
714 p_potset.insert(
715 &(static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot)->multiDim()));
716
717 Tensor< GUM_SCALAR >* joint
718 = const_cast< Tensor< GUM_SCALAR >* >(fast_combination.execute(p_potset));
719 _clique_ss_tensor_[xpotset.first]
720 = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*joint));
721 delete joint;
722 }
723 }
724 }
725 }
726
728 template < typename GUM_SCALAR >
729 void ShaferShenoyMRFInference< GUM_SCALAR >::_initializeJTCliques_(Schedule& schedule) {
730 const auto& mn = this->MRF();
731
732 // put all the factors of the MRF into the cliques
733 // here, beware: all the tensors that are defined over some nodes
734 // including hard evidence must be projected so that these nodes are
735 // removed from the tensor
736 const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
737 const NodeProperty< Idx >& hard_evidence = this->hardEvidence();
738
739 for (const auto& factor: mn.factors()) {
740 const auto& factor_nodes = factor.first;
741 const auto& pot = *(factor.second);
742 const auto& variables = pot.variablesSequence();
743
744 // get the list of nodes with hard evidence in the factor
745 NodeSet hard_nodes;
746 bool graph_contains_nodes = false;
747 for (const auto node: factor_nodes) {
748 if (_hard_ev_nodes_.contains(node)) hard_nodes.insert(node);
749 else if (_graph_.exists(node)) graph_contains_nodes = true;
750 }
751
752 // if hard_nodes contains hard evidence nodes, perform a projection
753 // and insert the result into the appropriate clique, else insert
754 // directly pot into the clique
755 if (hard_nodes.empty()) {
756 auto sched_cpt = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(pot, false);
757 _clique_tensors_[_factor_to_clique_[&pot]].insert(sched_cpt);
758 } else {
759 // marginalize out the hard evidence nodes: if factor_nodes is defined
760 // only over nodes that received hard evidence, do not consider it
761 // as a tensor anymore but as a constant
762 // TODO substitute constants by 0-dimensional tensors
763 if (hard_nodes.size() == factor_nodes.size()) {
764 Instantiation inst(pot);
765 for (Size i = 0; i < hard_nodes.size(); ++i) {
766 inst.chgVal(*variables[i], hard_evidence[mn.nodeId(*(variables[i]))]);
767 }
768 _constants_.insert(&pot, pot.get(inst));
769 } else {
770 // here, we have a factor defined over some nodes that received hard
771 // evidence and other nodes that did not receive it. If none of the
772 // latter belong to the graph, then the factor is useless for inference
773 if (!graph_contains_nodes) continue;
774
775 // prepare the projection with a combine and project instance
776 gum::VariableSet hard_variables;
777 _ScheduleMultiDimSet_ marg_factor_set(1 + hard_nodes.size());
778 const IScheduleMultiDim* sched_pot
779 = schedule.insertTable< Tensor< GUM_SCALAR > >(pot, false);
780 marg_factor_set.insert(sched_pot);
781
782 for (const auto node: hard_nodes) {
783 const IScheduleMultiDim* pot
784 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node], false);
785 marg_factor_set.insert(pot);
786 hard_variables.insert(&(mn.variable(node)));
787 }
788
789 // perform the combination of those tensors and their projection
790 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
791 _combination_op_,
792 _projection_op_);
793
794 _ScheduleMultiDimSet_ new_factor_list
795 = combine_and_project.schedule(schedule, marg_factor_set, hard_variables);
796
797 // there should be only one tensor in new_factor_list
798 if (new_factor_list.size() != 1) {
800 "the projection of a tensor containing " << "hard evidence is empty!");
801 }
802 auto projected_factor = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
803 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
804 *new_factor_list.begin()));
805 const_cast< ScheduleOperator* >(schedule.scheduleMultiDimCreator(projected_factor))
806 ->makeResultsPersistent(true);
807
808 _clique_tensors_[_factor_to_clique_[&pot]].insert(projected_factor);
809 _hard_ev_projected_factors_.insert(&pot, projected_factor);
810 }
811 }
812 }
813 this->scheduler().execute(schedule);
814
815 // now, in _clique_tensors_, for each clique, we have the list of
816 // tensors that must be combined in order to produce the Shafer-Shenoy's
817 // tensor stored into the clique. So, perform this combination and
818 // store the result in _clique_ss_tensor_
819 schedule.clear();
820 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
821 for (const auto& xpotset: _clique_tensors_) {
822 const auto& potset = xpotset.second;
823 if (potset.size() > 0) {
824 // here, there will be an entry in _clique_ss_tensor_
825 // If there is only one element in potset, this element shall be
826 // stored into _clique_ss_tensor_, else all the elements of potset
827 // shall be combined and their result shall be stored
828 if (potset.size() == 1) {
829 _clique_ss_tensor_[xpotset.first] = *(potset.cbegin());
830 } else {
831 // add the tables to combine into the schedule
832 for (const auto pot: potset) {
833 schedule.emplaceScheduleMultiDim(*pot);
834 }
835
836 auto joint = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
837 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
838 fast_combination.schedule(schedule, potset)));
839 const_cast< ScheduleOperator* >(schedule.scheduleMultiDimCreator(joint))
840 ->makeResultsPersistent(true);
841 _clique_ss_tensor_[xpotset.first] = joint;
842 }
843 }
844 }
845 this->scheduler().execute(schedule);
846 }
847
849 template < typename GUM_SCALAR >
850 void ShaferShenoyMRFInference< GUM_SCALAR >::updateOutdatedStructure_() {
851 // check if a new JT is really needed. If so, create it
852 if (_isNewJTNeeded_()) {
853 _createNewJT_();
854 } else {
855 // here, we can answer the next queries without reconstructing all the
856 // junction tree. All we need to do is to indicate that we should
857 // update the tensors and messages for these queries
858 updateOutdatedTensors_();
859 }
860 }
861
863 template < typename GUM_SCALAR >
864 void ShaferShenoyMRFInference< GUM_SCALAR >::_diffuseMessageInvalidations_(
865 NodeId from_id,
866 NodeId to_id,
867 NodeSet& invalidated_cliques) {
868 // invalidate the current clique
869 invalidated_cliques.insert(to_id);
870
871 // invalidate the current arc
872 const Arc arc(from_id, to_id);
873 bool& message_computed = _messages_computed_[arc];
874 if (message_computed) {
875 message_computed = false;
876 _separator_tensors_[arc] = nullptr;
877 if (_arc_to_created_tensors_.exists(arc)) {
878 delete _arc_to_created_tensors_[arc];
879 _arc_to_created_tensors_.erase(arc);
880 }
881
882 // go on with the diffusion
883 for (const auto node_id: _JT_->neighbours(to_id)) {
884 if (node_id != from_id) _diffuseMessageInvalidations_(to_id, node_id, invalidated_cliques);
885 }
886 }
887 }
888
891 template < typename GUM_SCALAR >
892 void ShaferShenoyMRFInference< GUM_SCALAR >::updateOutdatedTensors_() {
893 // for each clique, indicate whether the tensor stored into
894 // _clique_ss_tensor_[clique] is the result of a combination. In this
895 // case, it has been allocated by the combination and will need to be
896 // deallocated if its clique has been invalidated
897 NodeProperty< bool > ss_tensor_to_deallocate(_clique_tensors_.size());
898 for (const auto& potset: _clique_tensors_) {
899 ss_tensor_to_deallocate.insert(potset.first, (potset.second.size() > 1));
900 }
901
902 // compute the set of factors that were projected due to hard evidence and
903 // whose hard evidence have changed, so that they need a new projection.
904 // By the way, remove these factors since they are no more needed
905 // Here only the values of the hard evidence can have changed (else a
906 // fully new join tree would have been computed).
907 // Note also that we know that the factors still contain some variable(s) after
908 // the projection (else they should be constants)
909 const auto& mn = this->MRF();
910 NodeSet hard_nodes_changed(_hard_ev_nodes_.size());
911 Set< const Tensor< GUM_SCALAR >* > hard_projected_factors_changed(mn.factors().size());
912 for (const auto node: _hard_ev_nodes_) {
913 if (_evidence_changes_.exists(node)) {
914 hard_nodes_changed.insert(node);
915 for (const auto pot: _node_to_factors_[node]) {
916 if (_hard_ev_projected_factors_.exists(pot)
917 && !hard_projected_factors_changed.exists(pot)) {
918 hard_projected_factors_changed.insert(pot);
919 }
920 }
921 }
922 }
923
924 NodeSet hard_cliques_changed(hard_projected_factors_changed.size());
925 for (const auto pot: hard_projected_factors_changed) {
926 const auto chgPot = _hard_ev_projected_factors_[pot];
927 const NodeId chgClique = _factor_to_clique_[pot];
928 _clique_tensors_[chgClique].erase(chgPot);
929 _hard_ev_projected_factors_.erase(pot);
930 if (!hard_cliques_changed.contains(chgClique)) hard_cliques_changed.insert(chgClique);
931 delete chgPot;
932 }
933
934
935 // invalidate all the messages that are no more correct: start from each of
936 // the nodes whose soft evidence has changed and perform a diffusion from
937 // the clique into which the soft evidence has been entered, indicating that
938 // the messages spreading from this clique are now invalid. At the same time,
939 // if there were tensors created on the arcs over which the messages were
940 // sent, remove them from memory. For all the cliques that received some
941 // projected factors that should now be changed, do the same.
942 NodeSet invalidated_cliques(_JT_->size());
943 for (const auto& pair: _evidence_changes_) {
944 if (_node_to_clique_.exists(pair.first)) {
945 const auto clique = _node_to_clique_[pair.first];
946 invalidated_cliques.insert(clique);
947 for (const auto neighbor: _JT_->neighbours(clique)) {
948 _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
949 }
950 }
951 }
952
953 // now, add to the set of invalidated cliques those that contain projected
954 // factors that were changed.
955 for (const auto clique: hard_cliques_changed) {
956 invalidated_cliques.insert(clique);
957 for (const auto neighbor: _JT_->neighbours(clique)) {
958 _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
959 }
960 }
961
962 // now that we know the cliques whose set of tensors have been changed,
963 // we can discard their corresponding Shafer-Shenoy tensor
964 for (const auto clique: invalidated_cliques) {
965 if (ss_tensor_to_deallocate[clique]) {
966 delete _clique_ss_tensor_[clique];
967 _clique_ss_tensor_[clique] = nullptr;
968 }
969 }
970
971
972 // now we shall remove all the posteriors that belong to the
973 // invalidated cliques.
974 if (!_target_posteriors_.empty()) {
975 for (auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
976 ++iter) {
977 // first, cope only with the nodes that did not receive hard evidence
978 // since the other nodes do not belong to the join tree
979 if (_graph_.exists(iter.key())
980 && (invalidated_cliques.exists(_node_to_clique_[iter.key()]))) {
981 delete iter.val();
982 _target_posteriors_.erase(iter);
983 }
984 // now cope with the nodes that received hard evidence
985 else if (hard_nodes_changed.contains(iter.key())) {
986 delete iter.val();
987 _target_posteriors_.erase(iter);
988 }
989 }
990 }
991
992 // finally, cope with joint targets. Notably, remove the joint posteriors whose
993 // nodes have all received changed evidence
994 for (auto iter = _joint_target_posteriors_.beginSafe();
995 iter != _joint_target_posteriors_.endSafe();
996 ++iter) {
997 if (invalidated_cliques.exists(_joint_target_to_clique_[iter.key()])) {
998 delete iter.val();
999 _joint_target_posteriors_.erase(iter);
1000 } else {
1001 // check for sets in which all nodes have received evidence
1002 bool has_unevidenced_node = false;
1003 for (const auto node: iter.key()) {
1004 if (!hard_nodes_changed.exists(node)) {
1005 has_unevidenced_node = true;
1006 break;
1007 }
1008 }
1009 if (!has_unevidenced_node) {
1010 delete iter.val();
1011 _joint_target_posteriors_.erase(iter);
1012 }
1013 }
1014 }
1015
1016 // remove all the evidence that were entered into _node_to_soft_evidence_
1017 // and _clique_ss_tensor_ and add the new soft ones
1018 for (const auto& pot_pair: _node_to_soft_evidence_) {
1019 delete pot_pair.second;
1020 _clique_tensors_[_node_to_clique_[pot_pair.first]].erase(pot_pair.second);
1021 }
1022 _node_to_soft_evidence_.clear();
1023
1024 const auto& evidence = this->evidence();
1025 for (const auto node: this->softEvidenceNodes()) {
1026 auto ev_pot = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*evidence[node], false);
1027 _node_to_soft_evidence_.insert(node, ev_pot);
1028 _clique_tensors_[_node_to_clique_[node]].insert(ev_pot);
1029 }
1030
1031
1032 // Now add the projections of the factors due to newly changed hard evidence:
1033 // if we are performing updateOutdatedTensors_, this means that the
1034 // set of nodes that received hard evidence has not changed, only
1035 // their instantiations can have changed. So, if there is an entry
1036 // for node in _constants_, there will still be such an entry after
1037 // performing the new projections. Idem for _hard_ev_projected_factors_
1038 if (_use_schedules_) {
1039 Schedule schedule;
1040 for (const auto pot: hard_projected_factors_changed) {
1041 _ScheduleMultiDimSet_ marg_pot_set;
1042 const auto sched_pot = schedule.insertTable< Tensor< GUM_SCALAR > >(*pot, false);
1043 marg_pot_set.insert(sched_pot);
1044 const auto& variables = pot->variablesSequence();
1045 gum::VariableSet hard_variables(variables.size());
1046 for (const auto var: variables) {
1047 NodeId xnode = mn.nodeId(*var);
1048 if (_hard_ev_nodes_.exists(xnode)) {
1049 const auto ev_pot
1050 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[xnode], false);
1051 marg_pot_set.insert(ev_pot);
1052 hard_variables.insert(var);
1053 }
1054 }
1055
1056 // perform the combination of those tensors and their projection
1057 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
1058 _combination_op_,
1059 _projection_op_);
1060
1061 _ScheduleMultiDimSet_ new_pot_list
1062 = combine_and_project.schedule(schedule, marg_pot_set, hard_variables);
1063
1064 // there should be only one tensor in new_cpt_list
1065 if (new_pot_list.size() != 1) {
1067 "the projection of a tensor containing " << "hard evidence is empty!");
1068 }
1069 auto projected_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1070 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(*new_pot_list.begin()));
1071 const_cast< ScheduleOperator* >(schedule.scheduleMultiDimCreator(projected_pot))
1072 ->makeResultsPersistent(true);
1073 _clique_tensors_[_factor_to_clique_[pot]].insert(projected_pot);
1074 _hard_ev_projected_factors_.insert(pot, projected_pot);
1075 }
1076
1077 // here, the list of tensors stored in the invalidated cliques have
1078 // been updated. So, now, we can combine them to produce the Shafer-Shenoy
1079 // tensor stored into the clique
1080 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1081 for (const auto clique: invalidated_cliques) {
1082 const auto& potset = _clique_tensors_[clique];
1083
1084 if (potset.size() > 0) {
1085 // here, there will be an entry in _clique_ss_tensor_
1086 // If there is only one element in potset, this element shall be
1087 // stored into _clique_ss_tensor_, else all the elements of potset
1088 // shall be combined and their result shall be stored
1089 if (potset.size() == 1) {
1090 _clique_ss_tensor_[clique] = *(potset.cbegin());
1091 } else {
1092 for (const auto pot: potset)
1093 if (!schedule.existsScheduleMultiDim(pot->id()))
1094 schedule.emplaceScheduleMultiDim(*pot);
1095 auto joint = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1096 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1097 fast_combination.schedule(schedule, potset)));
1098 const_cast< ScheduleOperator* >(schedule.scheduleMultiDimCreator(joint))
1099 ->makeResultsPersistent(true);
1100 _clique_ss_tensor_[clique] = joint;
1101 }
1102 }
1103 }
1104 this->scheduler().execute(schedule);
1105 } else {
1106 for (const auto pot: hard_projected_factors_changed) {
1107 _TensorSet_ marg_pot_set;
1108 marg_pot_set.insert(pot);
1109 const auto& variables = pot->variablesSequence();
1110
1111 gum::VariableSet hard_variables(variables.size());
1112 for (const auto var: variables) {
1113 NodeId xnode = mn.nodeId(*var);
1114 if (_hard_ev_nodes_.exists(xnode)) {
1115 marg_pot_set.insert(evidence[xnode]);
1116 hard_variables.insert(var);
1117 }
1118 }
1119
1120 // perform the combination of those tensors and their projection
1121 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
1122 _combination_op_,
1123 _projection_op_);
1124
1125 _TensorSet_ new_pot_list = combine_and_project.execute(marg_pot_set, hard_variables);
1126
1127 // there should be only one tensor in new_cpt_list
1128 if (new_pot_list.size() != 1) {
1130 "the projection of a tensor containing " << "hard evidence is empty!");
1131 }
1132 Tensor< GUM_SCALAR >* xprojected_pot
1133 = const_cast< Tensor< GUM_SCALAR >* >(*new_pot_list.begin());
1134 auto projected_pot
1135 = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*xprojected_pot));
1136 delete xprojected_pot;
1137 _clique_tensors_[_factor_to_clique_[pot]].insert(projected_pot);
1138 _hard_ev_projected_factors_.insert(pot, projected_pot);
1139 }
1140
1141 // here, the list of tensors stored in the invalidated cliques have
1142 // been updated. So, now, we can combine them to produce the Shafer-Shenoy
1143 // tensor stored into the clique
1144 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1145 for (const auto clique: invalidated_cliques) {
1146 const auto& potset = _clique_tensors_[clique];
1147
1148 if (potset.size() > 0) {
1149 // here, there will be an entry in _clique_ss_tensor_
1150 // If there is only one element in potset, this element shall be
1151 // stored into _clique_ss_tensor_, else all the elements of potset
1152 // shall be combined and their result shall be stored
1153 if (potset.size() == 1) {
1154 _clique_ss_tensor_[clique] = *(potset.cbegin());
1155 } else {
1156 _TensorSet_ p_potset(potset.size());
1157 for (const auto pot: potset)
1158 p_potset.insert(&(
1159 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot)->multiDim()));
1160
1161 Tensor< GUM_SCALAR >* joint
1162 = const_cast< Tensor< GUM_SCALAR >* >(fast_combination.execute(p_potset));
1163 _clique_ss_tensor_[clique]
1164 = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*joint));
1165 delete joint;
1166 }
1167 }
1168 }
1169 }
1170
1171 // update the constants
1172 const auto& hard_evidence = this->hardEvidence();
1173 for (auto& node_cst: _constants_) {
1174 const Tensor< GUM_SCALAR >& pot = *(node_cst.first);
1175 Instantiation inst(pot);
1176 for (const auto var: pot.variablesSequence()) {
1177 inst.chgVal(*var, hard_evidence[mn.nodeId(*var)]);
1178 }
1179 node_cst.second = pot.get(inst);
1180 }
1181
1182 // indicate that all changes have been performed
1183 _evidence_changes_.clear();
1184 }
1185
1187 template < typename GUM_SCALAR >
1188 void ShaferShenoyMRFInference< GUM_SCALAR >::_computeJoinTreeRoots_() {
1189 // get the set of cliques in which we can find the targets and joint_targets.
1190 // Due to hard evidence, the cliques related to a given target node
1191 // might not exist, hence the try..catch.
1192 NodeSet clique_targets;
1193 for (const auto node: this->targets()) {
1194 try {
1195 clique_targets.insert(_node_to_clique_[node]);
1196 } catch (Exception const&) {}
1197 }
1198 for (const auto& set: this->jointTargets()) {
1199 try {
1200 clique_targets.insert(_joint_target_to_clique_[set]);
1201 } catch (Exception const&) {}
1202 }
1203
1204 // put in a vector these cliques and their sizes
1205 std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
1206 const auto& mn = this->MRF();
1207 std::size_t i = 0;
1208 for (const auto clique_id: clique_targets) {
1209 const auto& clique = _JT_->clique(clique_id);
1210 Size dom_size = 1;
1211 for (const auto node: clique) {
1212 dom_size *= mn.variable(node).domainSize();
1213 }
1214 possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
1215 ++i;
1216 }
1217
1218 // sort the cliques by increasing domain size
1219 std::sort(possible_roots.begin(),
1220 possible_roots.end(),
1221 [](const std::pair< NodeId, Size >& a, const std::pair< NodeId, Size >& b) -> bool {
1222 return a.second < b.second;
1223 });
1224
1225 // pick up the clique with the smallest size in each connected component
1226 NodeProperty< bool > marked = _JT_->nodesPropertyFromVal(false);
1227 std::function< void(NodeId, NodeId) > diffuse_marks
1228 = [&marked, &diffuse_marks, this](NodeId node, NodeId from) {
1229 if (!marked[node]) {
1230 marked[node] = true;
1231 for (const auto neigh: _JT_->neighbours(node))
1232 if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
1233 }
1234 };
1235 _roots_.clear();
1236 for (const auto& xclique: possible_roots) {
1237 NodeId clique = xclique.first;
1238 if (!marked[clique]) {
1239 _roots_.insert(clique);
1240 diffuse_marks(clique, clique);
1241 }
1242 }
1243 }
1244
1245 // performs the collect phase of Shafer-Shenoy using schedules
1246 template < typename GUM_SCALAR >
1247 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::_collectMessage_(Schedule& schedule,
1248 NodeId id,
1249 NodeId from) {
1250 for (const auto other: _JT_->neighbours(id)) {
1251 if ((other != from) && !_messages_computed_[Arc(other, id)])
1252 _collectMessage_(schedule, other, id);
1253 }
1254
1255 if ((id != from) && !_messages_computed_[Arc(id, from)]) {
1256 _produceMessage_(schedule, id, from);
1257 }
1258 }
1259
1260 // performs the collect phase of Shafer-Shenoy without schedules
1261 template < typename GUM_SCALAR >
1262 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) {
1263 for (const auto other: _JT_->neighbours(id)) {
1264 if ((other != from) && !_messages_computed_[Arc(other, id)]) _collectMessage_(other, id);
1265 }
1266
1267 if ((id != from) && !_messages_computed_[Arc(id, from)]) { _produceMessage_(id, from); }
1268 }
1269
1270 // remove variables del_vars from the list of tensors pot_list
1271 template < typename GUM_SCALAR >
1272 const IScheduleMultiDim* ShaferShenoyMRFInference< GUM_SCALAR >::_marginalizeOut_(
1273 Schedule& schedule,
1274 Set< const IScheduleMultiDim* > pot_list,
1275 gum::VariableSet& del_vars,
1276 gum::VariableSet& kept_vars) {
1277 // let's guarantee that all the tensors to be combined and projected
1278 // belong to the schedule
1279 for (const auto pot: pot_list) {
1280 if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
1281 }
1282
1283 // create a combine and project operator that will perform the
1284 // marginalization
1285 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1286 _projection_op_);
1287 _ScheduleMultiDimSet_ new_pot_list = combine_and_project.schedule(schedule, pot_list, del_vars);
1288
1289 // combine all the remaining tensors in order to create only one resulting tensor
1290 if (new_pot_list.size() == 1) return *(new_pot_list.begin());
1291 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1292 return fast_combination.schedule(schedule, new_pot_list);
1293 }
1294
  // remove variables del_vars from the list of tensors pot_list
  // (schedule-free version: combines/projects immediately and wraps the
  // result into a freshly allocated ScheduleMultiDim owned by the caller)
  template < typename GUM_SCALAR >
  const IScheduleMultiDim* ShaferShenoyMRFInference< GUM_SCALAR >::_marginalizeOut_(
      Set< const IScheduleMultiDim* >& pot_list,
      gum::VariableSet& del_vars,
      gum::VariableSet& kept_vars) {
    // unwrap the ScheduleMultiDims into their underlying Tensors, on which
    // the combine/project operators work directly
    _TensorSet_ xpot_list(pot_list.size());
    for (auto pot: pot_list)
      xpot_list.insert(
          &(static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot)->multiDim()));

    // create a combine and project operator that will perform the
    // marginalization
    MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
                                                                                 _projection_op_);
    _TensorSet_ xnew_pot_list = combine_and_project.execute(xpot_list, del_vars);

    // combine all the remaining tensors in order to create only one resulting tensor
    const Tensor< GUM_SCALAR >* xres_pot;
    if (xnew_pot_list.size() == 1) {
      xres_pot = *(xnew_pot_list.begin());
    } else {
      // combine all the tensors that resulted from the above combine and
      // project execution; the intermediate tensors allocated by that step
      // (i.e., those not coming from xpot_list and distinct from the final
      // result) are freed here to avoid leaks
      MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
      xres_pot = fast_combination.execute(xnew_pot_list);
      for (const auto pot: xnew_pot_list) {
        if (!xpot_list.contains(pot) && (pot != xres_pot)) delete pot;
      }
    }

    // transform xres_pot into a ScheduleMultiDim:
    // - if xres_pot is one of the input tensors, reference it (no ownership
    //   transfer: 'false' means the ScheduleMultiDim does not copy it);
    // - otherwise, move its content into the ScheduleMultiDim and delete the
    //   now-empty temporary
    ScheduleMultiDim< Tensor< GUM_SCALAR > >* res_pot;
    if (xpot_list.contains(xres_pot))
      res_pot = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*xres_pot, false);
    else {
      res_pot = new ScheduleMultiDim< Tensor< GUM_SCALAR > >(
          std::move(const_cast< Tensor< GUM_SCALAR >& >(*xres_pot)));
      delete xres_pot;
    }

    return res_pot;
  }
1338
  // creates the message sent by clique from_id to clique to_id
  // (schedule version: operations are only scheduled here; the actual
  // computations happen when the schedule is executed)
  template < typename GUM_SCALAR >
  void ShaferShenoyMRFInference< GUM_SCALAR >::_produceMessage_(Schedule& schedule,
                                                                NodeId from_id,
                                                                NodeId to_id) {
    // get the tensors of the clique.
    _ScheduleMultiDimSet_ pot_list;
    if (_clique_ss_tensor_[from_id] != nullptr) pot_list.insert(_clique_ss_tensor_[from_id]);

    // add the messages sent by adjacent nodes to from_id.
    for (const auto other_id: _JT_->neighbours(from_id)) {
      if (other_id != to_id) {
        const auto separator_pot = _separator_tensors_[Arc(other_id, from_id)];
        if (separator_pot != nullptr) pot_list.insert(separator_pot);
      }
    }

    // get the set of variables that need be removed from the tensors:
    // variables of the clique that do not belong to the separator are summed
    // out, the others are kept
    const NodeSet& from_clique = _JT_->clique(from_id);
    const NodeSet& separator = _JT_->separator(from_id, to_id);
    gum::VariableSet del_vars(from_clique.size());
    gum::VariableSet kept_vars(separator.size());
    const auto& mn = this->MRF();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(mn.variable(node)));
      } else {
        kept_vars.insert(&(mn.variable(node)));
      }
    }

    // pot_list now contains all the tensors to multiply and marginalize
    // => combine the messages
    const IScheduleMultiDim* new_pot = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);

    // keep track of the newly created tensor: if new_pot is not one of the
    // inputs, it was allocated by the marginalization and must be remembered
    // for later deallocation when the message is invalidated
    const Arc arc(from_id, to_id);
    if (!pot_list.exists(new_pot)) {
      if (!_arc_to_created_tensors_.exists(arc)) {
        _arc_to_created_tensors_.insert(arc, new_pot);

        // do not forget to make the ScheduleMultiDim persistent, so that the
        // message survives the execution/destruction of the schedule
        auto op = schedule.scheduleMultiDimCreator(new_pot);
        if (op != nullptr) const_cast< ScheduleOperator* >(op)->makeResultsPersistent(true);
      }
    }

    // record the message and mark it as computed
    _separator_tensors_[arc] = new_pot;
    _messages_computed_[arc] = true;
  }
1390
  // creates the message sent by clique from_id to clique to_id
  // (schedule-free version: the marginalization is performed immediately)
  template < typename GUM_SCALAR >
  void ShaferShenoyMRFInference< GUM_SCALAR >::_produceMessage_(NodeId from_id, NodeId to_id) {
    // get the tensors of the clique.
    _ScheduleMultiDimSet_ pot_list;
    if (_clique_ss_tensor_[from_id] != nullptr) pot_list.insert(_clique_ss_tensor_[from_id]);

    // add the messages sent by adjacent nodes to from_id.
    for (const auto other_id: _JT_->neighbours(from_id)) {
      if (other_id != to_id) {
        const auto separator_pot = _separator_tensors_[Arc(other_id, from_id)];
        if (separator_pot != nullptr) pot_list.insert(separator_pot);
      }
    }

    // get the set of variables that need be removed from the tensors:
    // variables outside the separator are summed out, the others are kept
    const NodeSet& from_clique = _JT_->clique(from_id);
    const NodeSet& separator = _JT_->separator(from_id, to_id);
    gum::VariableSet del_vars(from_clique.size());
    gum::VariableSet kept_vars(separator.size());
    const auto& mn = this->MRF();

    for (const auto node: from_clique) {
      if (!separator.contains(node)) {
        del_vars.insert(&(mn.variable(node)));
      } else {
        kept_vars.insert(&(mn.variable(node)));
      }
    }

    // pot_list now contains all the tensors to multiply and marginalize
    // => combine the messages
    const IScheduleMultiDim* new_pot = _marginalizeOut_(pot_list, del_vars, kept_vars);

    // keep track of the newly created tensor so it can be deallocated when
    // the message gets invalidated
    // NOTE(review): if an entry already exists for this arc, the fresh
    // new_pot is not recorded — presumably the entry was cleared by
    // _diffuseMessageInvalidations_ before reaching this point; verify.
    const Arc arc(from_id, to_id);
    if (!pot_list.exists(new_pot)) {
      if (!_arc_to_created_tensors_.exists(arc)) { _arc_to_created_tensors_.insert(arc, new_pot); }
    }

    // record the message and mark it as computed
    _separator_tensors_[arc] = new_pot;
    _messages_computed_[arc] = true;
  }
1434
  // fired after a new Markov net has been assigned to the inference engine:
  // simply forwards the notification to the parent class
  template < typename GUM_SCALAR >
  INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::onModelChanged_(const GraphicalModel* mn) {
    JointTargetedMRFInference< GUM_SCALAR >::onModelChanged_(mn);
  }
1440
1441 // performs a whole inference
1442 template < typename GUM_SCALAR >
1443 INLINE void ShaferShenoyMRFInference< GUM_SCALAR >::makeInference_() {
1444 if (_use_schedules_) {
1445 Schedule schedule;
1446
1447 // collect messages for all single targets
1448 for (const auto node: this->targets()) {
1449 // perform only collects in the join tree for nodes that have
1450 // not received hard evidence (those that received hard evidence were
1451 // not included into the join tree for speed-up reasons)
1452 if (_graph_.exists(node)) {
1453 _collectMessage_(schedule, _node_to_clique_[node], _node_to_clique_[node]);
1454 }
1455 }
1456
1457 // collect messages for all set targets
1458 // by parsing _joint_target_to_clique_, we ensure that the cliques that
1459 // are referenced belong to the join tree (even if some of the nodes in
1460 // their associated joint_target do not belong to _graph_)
1461 for (const auto& set: _joint_target_to_clique_)
1462 _collectMessage_(schedule, set.second, set.second);
1463
1464 // really perform the computations
1465 this->scheduler().execute(schedule);
1466 } else {
1467 // collect messages for all single targets
1468 for (const auto node: this->targets()) {
1469 // perform only collects in the join tree for nodes that have
1470 // not received hard evidence (those that received hard evidence were
1471 // not included into the join tree for speed-up reasons)
1472 if (_graph_.exists(node)) {
1473 _collectMessage_(_node_to_clique_[node], _node_to_clique_[node]);
1474 }
1475 }
1476
1477 // collect messages for all set targets
1478 // by parsing _joint_target_to_clique_, we ensure that the cliques that
1479 // are referenced belong to the join tree (even if some of the nodes in
1480 // their associated joint_target do not belong to _graph_)
1481 for (const auto& set: _joint_target_to_clique_)
1482 _collectMessage_(set.second, set.second);
1483 }
1484 }
1485
1487 template < typename GUM_SCALAR >
1488 Tensor< GUM_SCALAR >*
1489 ShaferShenoyMRFInference< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
1490 if (_use_schedules_) {
1491 Schedule schedule;
1492 return _unnormalizedJointPosterior_(schedule, id);
1493 } else {
1494 return _unnormalizedJointPosterior_(id);
1495 }
1496 }
1497
1499 template < typename GUM_SCALAR >
1500 Tensor< GUM_SCALAR >*
1501 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1502 NodeId id) {
1503 const auto& mn = this->MRF();
1504
1505 // hard evidence do not belong to the join tree
1506 // # TODO: check for sets of inconsistent hard evidence
1507 if (this->hardEvidenceNodes().contains(id)) {
1508 return new Tensor< GUM_SCALAR >(*(this->evidence()[id]));
1509 }
1510
1511 auto& scheduler = this->scheduler();
1512
1513 // if we still need to perform some inference task, do it (this should
1514 // already have been done by makeInference_)
1515 const NodeId clique_of_id = _node_to_clique_[id];
1516 _collectMessage_(schedule, clique_of_id, clique_of_id);
1517
1518 // now we just need to create the product of the tensors of the clique
1519 // containing id with the messages received by this clique and
1520 // marginalize out all variables except id
1521 _ScheduleMultiDimSet_ pot_list;
1522 if (_clique_ss_tensor_[clique_of_id] != nullptr)
1523 pot_list.insert(_clique_ss_tensor_[clique_of_id]);
1524
1525 // add the messages sent by adjacent nodes to targetClique
1526 for (const auto other: _JT_->neighbours(clique_of_id))
1527 pot_list.insert(_separator_tensors_[Arc(other, clique_of_id)]);
1528
1529 // get the set of variables that need be removed from the tensors
1530 const NodeSet& nodes = _JT_->clique(clique_of_id);
1531 gum::VariableSet kept_vars{&(mn.variable(id))};
1532 gum::VariableSet del_vars(nodes.size());
1533 for (const auto node: nodes) {
1534 if (node != id) del_vars.insert(&(mn.variable(node)));
1535 }
1536
1537 // pot_list now contains all the tensors to multiply and marginalize
1538 // => combine the messages
1539 auto resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1540 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1541 _marginalizeOut_(schedule, pot_list, del_vars, kept_vars)));
1542 Tensor< GUM_SCALAR >* joint = nullptr;
1543
1544 scheduler.execute(schedule);
1545
1546 // if pot already existed, create a copy, so that we can put it into
1547 // the _target_posteriors_ property
1548 if (pot_list.exists(resulting_pot)) {
1549 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1550 } else {
1551 joint = resulting_pot->exportMultiDim();
1552 }
1553
1554 // check that the joint posterior is different from a 0 vector: this would
1555 // indicate that some hard evidence are not compatible (their joint
1556 // probability is equal to 0)
1557 bool nonzero_found = false;
1558 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1559 if (joint->get(inst)) {
1560 nonzero_found = true;
1561 break;
1562 }
1563 }
1564 if (!nonzero_found) {
1565 // remove joint from memory to avoid memory leaks
1566 delete joint;
1568 "some evidence entered into the Markov "
1569 "net are incompatible (their joint proba = 0)");
1570 }
1571 return joint;
1572 }
1573
1575 template < typename GUM_SCALAR >
1576 Tensor< GUM_SCALAR >*
1577 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(NodeId id) {
1578 const auto& mn = this->MRF();
1579
1580 // hard evidence do not belong to the join tree
1581 // # TODO: check for sets of inconsistent hard evidence
1582 if (this->hardEvidenceNodes().contains(id)) {
1583 return new Tensor< GUM_SCALAR >(*(this->evidence()[id]));
1584 }
1585
1586 // if we still need to perform some inference task, do it (this should
1587 // already have been done by makeInference_)
1588 NodeId clique_of_id = _node_to_clique_[id];
1589 _collectMessage_(clique_of_id, clique_of_id);
1590
1591 // now we just need to create the product of the tensors of the clique
1592 // containing id with the messages received by this clique and
1593 // marginalize out all variables except id
1594 _ScheduleMultiDimSet_ pot_list;
1595 if (_clique_ss_tensor_[clique_of_id] != nullptr)
1596 pot_list.insert(_clique_ss_tensor_[clique_of_id]);
1597
1598 // add the messages sent by adjacent nodes to targetClique
1599 for (const auto other: _JT_->neighbours(clique_of_id))
1600 pot_list.insert(_separator_tensors_[Arc(other, clique_of_id)]);
1601
1602 // get the set of variables that need be removed from the tensors
1603 const NodeSet& nodes = _JT_->clique(clique_of_id);
1604 gum::VariableSet kept_vars{&(mn.variable(id))};
1605 gum::VariableSet del_vars(nodes.size());
1606 for (const auto node: nodes) {
1607 if (node != id) del_vars.insert(&(mn.variable(node)));
1608 }
1609
1610 // pot_list now contains all the tensors to multiply and marginalize
1611 // => combine the messages
1612 auto resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1613 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1614 _marginalizeOut_(pot_list, del_vars, kept_vars)));
1615 Tensor< GUM_SCALAR >* joint = nullptr;
1616
1617 // if pot already existed, create a copy, so that we can put it into
1618 // the _target_posteriors_ property
1619 if (pot_list.exists(resulting_pot)) {
1620 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1621 } else {
1622 joint = resulting_pot->exportMultiDim();
1623 delete resulting_pot;
1624 }
1625
1626 // check that the joint posterior is different from a 0 vector: this would
1627 // indicate that some hard evidence are not compatible (their joint
1628 // probability is equal to 0)
1629 bool nonzero_found = false;
1630 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1631 if (joint->get(inst)) {
1632 nonzero_found = true;
1633 break;
1634 }
1635 }
1636 if (!nonzero_found) {
1637 // remove joint from memory to avoid memory leaks
1638 delete joint;
1640 "some evidence entered into the Markov "
1641 "net are incompatible (their joint proba = 0)");
1642 }
1643 return joint;
1644 }
1645
1647 template < typename GUM_SCALAR >
1648 const Tensor< GUM_SCALAR >& ShaferShenoyMRFInference< GUM_SCALAR >::posterior_(NodeId id) {
1649 // check if we have already computed the posterior
1650 if (_target_posteriors_.exists(id)) { return *(_target_posteriors_[id]); }
1651
1652 // compute the joint posterior and normalize
1653 auto joint = unnormalizedJointPosterior_(id);
1654 if (joint->sum() != 1) // hard test for ReadOnly CPT (as aggregator)
1655 joint->normalize();
1656 _target_posteriors_.insert(id, joint);
1657
1658 return *joint;
1659 }
1660
1662 template < typename GUM_SCALAR >
1663 Tensor< GUM_SCALAR >*
1664 ShaferShenoyMRFInference< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) {
1665 if (_use_schedules_) {
1666 Schedule schedule;
1667 return _unnormalizedJointPosterior_(schedule, set);
1668 } else {
1669 return _unnormalizedJointPosterior_(set);
1670 }
1671 }
1672
1674 template < typename GUM_SCALAR >
1675 Tensor< GUM_SCALAR >*
1676 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1677 const NodeSet& set) {
1678 // hard evidence do not belong to the join tree, so extract the nodes
1679 // from targets that are not hard evidence
1680 NodeSet targets = set, hard_ev_nodes;
1681 for (const auto node: this->hardEvidenceNodes()) {
1682 if (targets.contains(node)) {
1683 targets.erase(node);
1684 hard_ev_nodes.insert(node);
1685 }
1686 }
1687
1688 auto& scheduler = this->scheduler();
1689
1690 // if all the nodes have received hard evidence, then compute the
1691 // joint posterior directly by multiplying the hard evidence tensors
1692 const auto& evidence = this->evidence();
1693 if (targets.empty()) {
1694 if (set.size() == 1) {
1695 return new Tensor< GUM_SCALAR >(*evidence[*set.begin()]);
1696 } else {
1697 _ScheduleMultiDimSet_ pot_list;
1698 for (const auto node: set) {
1699 auto new_pot_ev = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node], false);
1700 pot_list.insert(new_pot_ev);
1701 }
1702
1703 // combine all the tensors of the nodes in set
1704 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1705 const IScheduleMultiDim* pot = fast_combination.schedule(schedule, pot_list);
1706 auto schedule_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1707 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot));
1708 scheduler.execute(schedule);
1709 auto result = schedule_pot->exportMultiDim();
1710
1711 return result;
1712 }
1713 }
1714
1715
1716 // if we still need to perform some inference task, do it: so, first,
1717 // determine the clique on which we should perform collect to compute
1718 // the unnormalized joint posterior of a set of nodes containing "targets"
1719 NodeId clique_of_set;
1720 try {
1721 clique_of_set = _joint_target_to_clique_[set];
1722 } catch (NotFound const&) {
1723 // here, the precise set of targets does not belong to the set of targets
1724 // defined by the user. So we will try to find a clique in the junction
1725 // tree that contains "targets":
1726
1727 // 1/ we should check that all the nodes belong to the join tree
1728 for (const auto node: targets) {
1729 if (!_graph_.exists(node)) {
1731 "The variable " << this->MRF().variable(node).name() << "(" << node
1732 << ") does not belong to this optimized inference.")
1733 }
1734 }
1735
1736 // 2/ the clique created by the first eliminated node among target is the
1737 // one we are looking for
1738 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
1739 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
1740 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
1741 elim_order.insert(JT_elim_order[i], (int)i);
1742 NodeId first_eliminated_node = *(targets.begin());
1743 int elim_number = elim_order[first_eliminated_node];
1744 for (const auto node: targets) {
1745 if (elim_order[node] < elim_number) {
1746 elim_number = elim_order[node];
1747 first_eliminated_node = node;
1748 }
1749 }
1750
1751 clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
1752
1753
1754 // 3/ check that clique_of_set contains the all the nodes in the target
1755 const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
1756 for (const auto node: targets) {
1757 if (!clique_nodes.contains(node)) {
1759 this->MRF().names(set) << "(" << set << ")"
1760 << " is not addressable in this optimized inference.")
1761 }
1762 }
1763
1764 // add the discovered clique to _joint_target_to_clique_
1765 _joint_target_to_clique_.insert(set, clique_of_set);
1766 }
1767
1768 // now perform a collect on the clique
1769 _collectMessage_(schedule, clique_of_set, clique_of_set);
1770
1771 // now we just need to create the product of the tensors of the clique
1772 // containing set with the messages received by this clique and
1773 // marginalize out all variables except set
1774 _ScheduleMultiDimSet_ pot_list;
1775 if (_clique_ss_tensor_[clique_of_set] != nullptr) {
1776 auto pot = _clique_ss_tensor_[clique_of_set];
1777 if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
1778 pot_list.insert(_clique_ss_tensor_[clique_of_set]);
1779 }
1780
1781 // add the messages sent by adjacent nodes to targetClique
1782 for (const auto other: _JT_->neighbours(clique_of_set)) {
1783 const auto pot = _separator_tensors_[Arc(other, clique_of_set)];
1784 if (pot != nullptr) pot_list.insert(pot);
1785 }
1786
1787
1788 // get the set of variables that need be removed from the tensors
1789 const NodeSet& nodes = _JT_->clique(clique_of_set);
1790 gum::VariableSet del_vars(nodes.size());
1791 gum::VariableSet kept_vars(targets.size());
1792 const auto& mn = this->MRF();
1793 for (const auto node: nodes) {
1794 if (!targets.contains(node)) {
1795 del_vars.insert(&(mn.variable(node)));
1796 } else {
1797 kept_vars.insert(&(mn.variable(node)));
1798 }
1799 }
1800
1801 // pot_list now contains all the tensors to multiply and marginalize
1802 // => combine the messages
1803 const IScheduleMultiDim* new_pot = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1804 scheduler.execute(schedule);
1805 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot
1806 = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1807 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(new_pot));
1808
1809 // if pot already existed, create a copy, so that we can put it into
1810 // the _target_posteriors_ property
1811 Tensor< GUM_SCALAR >* joint = nullptr;
1812 if (pot_list.exists(resulting_pot)) {
1813 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1814 } else {
1815 joint = resulting_pot->exportMultiDim();
1816 }
1817
1818 // check that the joint posterior is different from a 0 vector: this would
1819 // indicate that some hard evidence are not compatible
1820 bool nonzero_found = false;
1821 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1822 if ((*joint)[inst]) {
1823 nonzero_found = true;
1824 break;
1825 }
1826 }
1827 if (!nonzero_found) {
1828 // remove joint from memory to avoid memory leaks
1829 delete joint;
1831 "some evidence entered into the Markov "
1832 "net are incompatible (their joint proba = 0)");
1833 }
1834
1835 return joint;
1836 }
1837
1839 template < typename GUM_SCALAR >
1840 Tensor< GUM_SCALAR >*
1841 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(const NodeSet& set) {
1842 // hard evidence do not belong to the join tree, so extract the nodes
1843 // from targets that are not hard evidence
1844 NodeSet targets = set, hard_ev_nodes;
1845 for (const auto node: this->hardEvidenceNodes()) {
1846 if (targets.contains(node)) {
1847 targets.erase(node);
1848 hard_ev_nodes.insert(node);
1849 }
1850 }
1851
1852 // if all the nodes have received hard evidence, then compute the
1853 // joint posterior directly by multiplying the hard evidence tensors
1854 const auto& evidence = this->evidence();
1855 if (targets.empty()) {
1856 if (set.size() == 1) {
1857 return new Tensor< GUM_SCALAR >(*evidence[*set.begin()]);
1858 } else {
1859 _TensorSet_ pot_list;
1860 for (const auto node: set) {
1861 pot_list.insert(evidence[node]);
1862 }
1863
1864 // combine all the tensors of the nodes in set
1865 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1866 const Tensor< GUM_SCALAR >* pot = fast_combination.execute(pot_list);
1867
1868 return const_cast< Tensor< GUM_SCALAR >* >(pot);
1869 }
1870 }
1871
1872
1873 // if we still need to perform some inference task, do it: so, first,
1874 // determine the clique on which we should perform collect to compute
1875 // the unnormalized joint posterior of a set of nodes containing "targets"
1876 NodeId clique_of_set;
1877 try {
1878 clique_of_set = _joint_target_to_clique_[set];
1879 } catch (NotFound const&) {
1880 // here, the precise set of targets does not belong to the set of targets
1881 // defined by the user. So we will try to find a clique in the junction
1882 // tree that contains "targets":
1883
1884 // 1/ we should check that all the nodes belong to the join tree
1885 for (const auto node: targets) {
1886 if (!_graph_.exists(node)) {
1888 "The variable " << this->MRF().variable(node).name() << "(" << node
1889 << ") does not belong to this optimized inference.")
1890 }
1891 }
1892
1893 // 2/ the clique created by the first eliminated node among target is the
1894 // one we are looking for
1895 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
1896 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
1897 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
1898 elim_order.insert(JT_elim_order[i], (int)i);
1899 NodeId first_eliminated_node = *(targets.begin());
1900 int elim_number = elim_order[first_eliminated_node];
1901 for (const auto node: targets) {
1902 if (elim_order[node] < elim_number) {
1903 elim_number = elim_order[node];
1904 first_eliminated_node = node;
1905 }
1906 }
1907
1908 clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
1909
1910
1911 // 3/ check that clique_of_set contains the all the nodes in the target
1912 const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
1913 for (const auto node: targets) {
1914 if (!clique_nodes.contains(node)) {
1915 GUM_ERROR(UndefinedElement, set << " is not a joint target")
1916 }
1917 }
1918
1919 // add the discovered clique to _joint_target_to_clique_
1920 _joint_target_to_clique_.insert(set, clique_of_set);
1921 }
1922
1923 // now perform a collect on the clique
1924 _collectMessage_(clique_of_set, clique_of_set);
1925
1926 // now we just need to create the product of the tensors of the clique
1927 // containing set with the messages received by this clique and
1928 // marginalize out all variables except set
1929 _ScheduleMultiDimSet_ pot_list;
1930 if (_clique_ss_tensor_[clique_of_set] != nullptr) {
1931 auto pot = _clique_ss_tensor_[clique_of_set];
1932 if (pot != nullptr) pot_list.insert(_clique_ss_tensor_[clique_of_set]);
1933 }
1934
1935 // add the messages sent by adjacent nodes to targetClique
1936 for (const auto other: _JT_->neighbours(clique_of_set)) {
1937 const auto pot = _separator_tensors_[Arc(other, clique_of_set)];
1938 if (pot != nullptr) pot_list.insert(pot);
1939 }
1940
1941
1942 // get the set of variables that need be removed from the tensors
1943 const NodeSet& nodes = _JT_->clique(clique_of_set);
1944 gum::VariableSet del_vars(nodes.size());
1945 gum::VariableSet kept_vars(targets.size());
1946 const auto& mn = this->MRF();
1947 for (const auto node: nodes) {
1948 if (!targets.contains(node)) {
1949 del_vars.insert(&(mn.variable(node)));
1950 } else {
1951 kept_vars.insert(&(mn.variable(node)));
1952 }
1953 }
1954
1955 // pot_list now contains all the tensors to multiply and marginalize
1956 // => combine the messages
1957 const IScheduleMultiDim* new_pot = _marginalizeOut_(pot_list, del_vars, kept_vars);
1958 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot
1959 = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1960 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(new_pot));
1961
1962 // if pot already existed, create a copy, so that we can put it into
1963 // the _target_posteriors_ property
1964 Tensor< GUM_SCALAR >* joint = nullptr;
1965 if (pot_list.exists(resulting_pot)) {
1966 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1967 } else {
1968 joint = resulting_pot->exportMultiDim();
1969 delete new_pot;
1970 }
1971
1972 // check that the joint posterior is different from a 0 vector: this would
1973 // indicate that some hard evidence are not compatible
1974 bool nonzero_found = false;
1975 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1976 if ((*joint)[inst]) {
1977 nonzero_found = true;
1978 break;
1979 }
1980 }
1981 if (!nonzero_found) {
1982 // remove joint from memory to avoid memory leaks
1983 delete joint;
1985 "some evidence entered into the Markov "
1986 "net are incompatible (their joint proba = 0)");
1987 }
1988
1989 return joint;
1990 }
1991
1993 template < typename GUM_SCALAR >
1994 const Tensor< GUM_SCALAR >&
1995 ShaferShenoyMRFInference< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
1996 // check if we have already computed the posterior
1997 if (_joint_target_posteriors_.exists(set)) { return *(_joint_target_posteriors_[set]); }
1998
1999 // compute the joint posterior and normalize
2000 auto joint = unnormalizedJointPosterior_(set);
2001 joint->normalize();
2002 _joint_target_posteriors_.insert(set, joint);
2003
2004 return *joint;
2005 }
2006
2008 template < typename GUM_SCALAR >
2009 const Tensor< GUM_SCALAR >&
2010 ShaferShenoyMRFInference< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
2011 const NodeSet& declared_target) {
2012 // check if we have already computed the posterior of wanted_target
2013 if (_joint_target_posteriors_.exists(wanted_target))
2014 return *(_joint_target_posteriors_[wanted_target]);
2015
2016 // here, we will have to compute the posterior of declared_target and
2017 // marginalize out all the variables that do not belong to wanted_target
2018
2019 // check if we have already computed the posterior of declared_target
2020 if (!_joint_target_posteriors_.exists(declared_target)) {
2021 return jointPosterior_(declared_target);
2022 }
2023
2024 // marginalize out all the variables that do not belong to wanted_target
2025 const auto& mn = this->MRF();
2026 gum::VariableSet del_vars;
2027 for (const auto node: declared_target)
2028 if (!wanted_target.contains(node)) del_vars.insert(&(mn.variable(node)));
2029 auto pot
2030 = new Tensor< GUM_SCALAR >(_joint_target_posteriors_[declared_target]->sumOut(del_vars));
2031
2032 // save the result into the cache
2033 _joint_target_posteriors_.insert(wanted_target, pot);
2034
2035 return *pot;
2036 }
2037
2038 template < typename GUM_SCALAR >
2039 GUM_SCALAR ShaferShenoyMRFInference< GUM_SCALAR >::evidenceProbability() {
2040 // perform inference in each connected component
2041 this->makeInference();
2042
2043 // for each connected component, select a variable X and compute the
2044 // joint probability of X and evidence e. Then marginalize-out X to get
2045 // p(e) in this connected component. Finally, multiply all the p(e) that
2046 // we got and the elements in _constants_. The result is the probability
2047 // of evidence
2048
2049 GUM_SCALAR prob_ev = 1;
2050 for (const auto root: _roots_) {
2051 // get a node in the clique
2052 NodeId node = *(_JT_->clique(root).begin());
2053 Tensor< GUM_SCALAR >* tmp = unnormalizedJointPosterior_(node);
2054 prob_ev *= tmp->sum();
2055 delete tmp;
2056 }
2057
2058 for (const auto& projected_cpt: _constants_)
2059 prob_ev *= projected_cpt.second;
2060
2061 return prob_ev;
2062 }
2063
2064 template < typename GUM_SCALAR >
2065 bool ShaferShenoyMRFInference< GUM_SCALAR >::isExactJointComputable_(const NodeSet& vars) {
2066 if (JointTargetedMRFInference< GUM_SCALAR >::isExactJointComputable_(vars)) return true;
2067
2068 this->prepareInference();
2069
2070 for (const auto& node: this->_JT_->nodes()) {
2071 const auto clique = _JT_->clique(node);
2072 if (vars == clique) return true;
2073 }
2074 return false;
2075 }
2076
2077 template < typename GUM_SCALAR >
2078 NodeSet ShaferShenoyMRFInference< GUM_SCALAR >::superForJointComputable_(const NodeSet& vars) {
2079 const auto superset = JointTargetedMRFInference< GUM_SCALAR >::superForJointComputable_(vars);
2080 if (!superset.empty()) return superset;
2081
2082 this->prepareInference();
2083
2084 for (const auto& node: _JT_->nodes()) {
2085 const auto clique = _JT_->clique(node);
2086 if (vars.isStrictSubsetOf(clique)) return clique;
2087 }
2088
2089
2090 return NodeSet();
2091 }
2092
2093} /* namespace gum */
2094
2095#endif // DOXYGEN_SHOULD_SKIP_THIS
Implementation of Shafer-Shenoy's algorithm for inference in Markov random fields.
An algorithm for converting a join tree into a binary join tree.
Exception : a similar element already exists.
<agrum/MRF/inference/evidenceMRFInference.h>
Exception : fatal (unknown ?) error.
Class representing the minimal interface for Markov random field.
Exception : several evidence are incompatible together (proba=0).
<agrum/MRF/inference/jointTargetedMRFInference.h>
Exception : the element we looked for cannot be found.
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
ShaferShenoyMRFInference(const IMarkovRandomField< GUM_SCALAR > *MN, bool use_binary_join_tree=true)
default constructor
Exception : a looked-for element could not be found.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
Set< NodeId > NodeSet
Some typedefs and defines used as shortcuts ...
Header files of gum::Instantiation.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...