aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
variableElimination_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50
51#ifndef DOXYGEN_SHOULD_SKIP_THIS
52# include <algorithm>
53
62
63namespace gum {
64
65
66 // default constructor
67 template < typename GUM_SCALAR >
70 RelevantTensorsFinderType relevant_type,
71 FindBarrenNodesType barren_type) : JointTargetedInference< GUM_SCALAR >(BN) {
72 // sets the relevant tensor and the barren nodes finding algorithm
73 _findRelevantTensors_
74 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_;
75 setRelevantTensorsFinderType(relevant_type);
76 setFindBarrenNodesType(barren_type);
77
78 // create a default triangulation (the user can change it afterwards)
79 _triangulation_ = new DefaultTriangulation;
80
81 // for debugging purposes
82 GUM_CONSTRUCTOR(VariableElimination);
83 }
84
85 // destructor
86 template < typename GUM_SCALAR >
87 INLINE VariableElimination< GUM_SCALAR >::~VariableElimination() {
88 // remove the junction tree and the triangulation algorithm
89 if (_JT_ != nullptr) delete _JT_;
90 delete _triangulation_;
91 if (_target_posterior_ != nullptr) delete _target_posterior_;
92
93 // for debugging purposes
94 GUM_DESTRUCTOR(VariableElimination);
95 }
96
98 template < typename GUM_SCALAR >
99 void VariableElimination< GUM_SCALAR >::setTriangulation(const Triangulation& new_triangulation) {
100 delete _triangulation_;
101 _triangulation_ = new_triangulation.newFactory();
102 }
103
105 template < typename GUM_SCALAR >
106 INLINE const JunctionTree* VariableElimination< GUM_SCALAR >::junctionTree(NodeId id) {
107 _createNewJT_(NodeSet{id});
108
109 return _JT_;
110 }
111
113 template < typename GUM_SCALAR >
114 void VariableElimination< GUM_SCALAR >::setRelevantTensorsFinderType(
115 RelevantTensorsFinderType type) {
116 if (type != _find_relevant_tensor_type_) {
117 switch (type) {
118 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
119 _findRelevantTensors_
120 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_;
121 break;
122
123 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES :
124 _findRelevantTensors_
125 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation_;
126 break;
127
128 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 :
129 _findRelevantTensors_
130 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation3_;
131 break;
132
133 case RelevantTensorsFinderType::FIND_ALL :
134 _findRelevantTensors_ = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsGetAll_;
135 break;
136
137 default :
139 "setRelevantTensorsFinderType for type " << (unsigned int)type
140 << " is not implemented yet");
141 }
142
143 _find_relevant_tensor_type_ = type;
144 }
145 }
146
148 template < typename GUM_SCALAR >
149 INLINE void VariableElimination< GUM_SCALAR >::_setProjectionFunction_(
150 Tensor< GUM_SCALAR > (*proj)(const Tensor< GUM_SCALAR >&, const gum::VariableSet&)) {
151 _projection_op_ = proj;
152 }
153
155 template < typename GUM_SCALAR >
156 INLINE void VariableElimination< GUM_SCALAR >::_setCombinationFunction_(
157 Tensor< GUM_SCALAR > (*comb)(const Tensor< GUM_SCALAR >&, const Tensor< GUM_SCALAR >&)) {
158 _combination_op_ = comb;
159 }
160
162 template < typename GUM_SCALAR >
163 void VariableElimination< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) {
164 if (type != _barren_nodes_type_) {
165 // WARNING: if a new type is added here, method _createJT_ should certainly
166 // be updated as well, in particular its step 2.
167 switch (type) {
168 case FindBarrenNodesType::FIND_BARREN_NODES :
169 case FindBarrenNodesType::FIND_NO_BARREN_NODES : break;
170
171 default :
173 "setFindBarrenNodesType for type " << (unsigned int)type
174 << " is not implemented yet");
175 }
176
177 _barren_nodes_type_ = type;
178 }
179 }
180
182 template < typename GUM_SCALAR >
183 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceAdded_(const NodeId, bool) {}
184
186 template < typename GUM_SCALAR >
187 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceErased_(const NodeId, bool) {}
188
190 template < typename GUM_SCALAR >
191 void VariableElimination< GUM_SCALAR >::onAllEvidenceErased_(bool) {}
192
194 template < typename GUM_SCALAR >
195 INLINE void VariableElimination< GUM_SCALAR >::onEvidenceChanged_(const NodeId, bool) {}
196
198 template < typename GUM_SCALAR >
199 INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetAdded_(const NodeId) {}
200
202 template < typename GUM_SCALAR >
203 INLINE void VariableElimination< GUM_SCALAR >::onMarginalTargetErased_(const NodeId) {}
204
206 template < typename GUM_SCALAR >
207 INLINE void VariableElimination< GUM_SCALAR >::onModelChanged_(const GraphicalModel* bn) {}
208
210 template < typename GUM_SCALAR >
211 INLINE void VariableElimination< GUM_SCALAR >::onJointTargetAdded_(const NodeSet&) {}
212
214 template < typename GUM_SCALAR >
215 INLINE void VariableElimination< GUM_SCALAR >::onJointTargetErased_(const NodeSet&) {}
216
218 template < typename GUM_SCALAR >
219 INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
220
222 template < typename GUM_SCALAR >
223 INLINE void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
224
226 template < typename GUM_SCALAR >
227 INLINE void VariableElimination< GUM_SCALAR >::onAllJointTargetsErased_() {}
228
230 template < typename GUM_SCALAR >
231 INLINE void VariableElimination< GUM_SCALAR >::onAllTargetsErased_() {}
232
234 template < typename GUM_SCALAR >
235 void VariableElimination< GUM_SCALAR >::_createNewJT_(const NodeSet& targets) {
236 // to create the JT, we first create the moral graph of the BN in the
237 // following way in order to take into account the barren nodes and the
238 // nodes that received evidence:
239 // 1/ we create an undirected graph containing only the nodes and no edge
240 // 2/ if we take into account barren nodes, remove them from the graph
241 // 3/ if we take d-separation into account, remove the d-separated nodes
242 // 4/ add edges so that each node and its parents in the BN form a clique
243 // 5/ add edges so that joint targets form a clique of the moral graph
244 // 6/ remove the nodes that received hard evidence (by step 4/, their
245 // parents are linked by edges, which is necessary for inference)
246 //
247 // At the end of step 6/, we have our moral graph and we can triangulate it
248 // to get the new junction tree
249
250 // 1/ create an undirected graph containing only the nodes and no edge
251 const auto& bn = this->BN();
252 _graph_.clear();
253 for (const auto node: bn.dag())
254 _graph_.addNodeWithId(node);
255
256 // 2/ if we wish to exploit barren nodes, we shall remove them from the
257 // BN. To do so: we identify all the nodes that are not targets and have
258 // received no evidence and such that their descendants are neither
259 // targets nor evidence nodes. Such nodes can be safely discarded from
260 // the BN without altering the inference output
261 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
262 // check that all the nodes are not targets, otherwise, there is no
263 // barren node
264 if (targets.size() != bn.size()) {
265 BarrenNodesFinder finder(&(bn.dag()));
266 finder.setTargets(&targets);
267
268 NodeSet evidence_nodes(this->evidence().size());
269 for (const auto& pair: this->evidence()) {
270 evidence_nodes.insert(pair.first);
271 }
272 finder.setEvidence(&evidence_nodes);
273
274 NodeSet barren_nodes = finder.barrenNodes();
275
276 // remove the barren nodes from the moral graph
277 for (const auto node: barren_nodes) {
278 _graph_.eraseNode(node);
279 }
280 }
281 }
282
283 // 3/ if we wish to exploit d-separation, remove all the nodes that are
284 // d-separated from our targets
285 {
286 NodeSet requisite_nodes;
287 bool dsep_analysis = false;
288 switch (_find_relevant_tensor_type_) {
289 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
290 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES : {
291 BayesBall::requisiteNodes(bn.dag(),
292 targets,
293 this->hardEvidenceNodes(),
294 this->softEvidenceNodes(),
295 requisite_nodes);
296 dsep_analysis = true;
297 } break;
298
299 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 : {
300 dSeparationAlgorithm dsep;
301 dsep.requisiteNodes(bn.dag(),
302 targets,
303 this->hardEvidenceNodes(),
304 this->softEvidenceNodes(),
305 requisite_nodes);
306 dsep_analysis = true;
307 } break;
308
309 case RelevantTensorsFinderType::FIND_ALL : break;
310
311 default : GUM_ERROR(FatalError, "not implemented yet")
312 }
313
314 // remove all the nodes that are not requisite
315 if (dsep_analysis) {
316 for (auto iter = _graph_.beginSafe(); iter != _graph_.endSafe(); ++iter) {
317 if (!requisite_nodes.contains(*iter) && !this->hardEvidenceNodes().contains(*iter)) {
318 _graph_.eraseNode(*iter);
319 }
320 }
321 }
322 }
323
324 // 4/ add edges so that each node and its parents in the BN form a clique
325 for (const auto node: _graph_) {
326 const NodeSet& parents = bn.parents(node);
327 for (auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
328 // before adding an edge between node and its parent, check that the
329 // parent belong to the graph. Actually, when d-separated nodes are
330 // removed, it may be the case that the parents of hard evidence nodes
331 // are removed. But the latter still exist in the graph.
332 if (_graph_.existsNode(*iter1)) {
333 _graph_.addEdge(*iter1, node);
334
335 auto iter2 = iter1;
336 for (++iter2; iter2 != parents.cend(); ++iter2) {
337 // before adding an edge, check that both extremities belong to
338 // the graph. Actually, when d-separated nodes are removed, it may
339 // be the case that the parents of hard evidence nodes are removed.
340 // But the latter still exist in the graph.
341 if (_graph_.existsNode(*iter2)) _graph_.addEdge(*iter1, *iter2);
342 }
343 }
344 }
345 }
346
347 // 5/ if targets contains several nodes, we shall add new edges into the
348 // moral graph in order to ensure that there exists a clique containing
349 // their joint distribution
350 for (auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
351 auto iter2 = iter1;
352 for (++iter2; iter2 != targets.cend(); ++iter2) {
353 _graph_.addEdge(*iter1, *iter2);
354 }
355 }
356
357 // 6/ remove all the nodes that received hard evidence
358 const auto& hard_ev_nodes = this->hardEvidenceNodes();
359 for (const auto node: hard_ev_nodes) {
360 _graph_.eraseNode(node);
361 }
362
363
364 // now, we can compute the new junction tree.
365 if (_JT_ != nullptr) delete _JT_;
366 _triangulation_->setGraph(&_graph_, &(this->domainSizes()));
367 const JunctionTree& triang_jt = _triangulation_->junctionTree();
368 _JT_ = new CliqueGraph(triang_jt);
369
370 // indicate, for each node of the moral graph, a clique in _JT_ that can
371 // contain its conditional probability table
372 _node_to_clique_.clear();
373 _clique_to_nodes_.clear();
374 NodeSet emptyset(_JT_->size());
375 for (auto clique: *_JT_)
376 _clique_to_nodes_.insert(clique, emptyset);
377 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
378 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
379 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
380 elim_order.insert(JT_elim_order[i], (int)i);
381 const DAG& dag = bn.dag();
382
383 for (const auto node: _graph_) {
384 // get the variables in the tensor of node (and its parents)
385 NodeId first_eliminated_node = node;
386 int elim_number = elim_order[first_eliminated_node];
387
388 for (const auto parent: dag.parents(node)) {
389 if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) {
390 elim_number = elim_order[parent];
391 first_eliminated_node = parent;
392 }
393 }
394
395 // first_eliminated_node contains the first var (node or one of its
396 // parents) eliminated => the clique created during its elimination
397 // contains node and all of its parents => it can contain the tensor
398 // assigned to the node in the BN
399 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
400 _node_to_clique_.insert(node, clique);
401 _clique_to_nodes_[clique].insert(node);
402 }
403
404 // do the same for the nodes that received hard evidence. Here, we only store
405 // the nodes for which at least one parent belongs to _graph_ (otherwise
406 // their CPT is just a constant real number).
407 for (const auto node: hard_ev_nodes) {
408 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
409 int elim_number = std::numeric_limits< int >::max();
410
411 for (const auto parent: dag.parents(node)) {
412 if (_graph_.exists(parent) && (elim_order[parent] < elim_number)) {
413 elim_number = elim_order[parent];
414 first_eliminated_node = parent;
415 }
416 }
417
418 // first_eliminated_node contains the first var (node or one of its
419 // parents) eliminated => the clique created during its elimination
420 // contains node and all of its parents => it can contain the tensor
421 // assigned to the node in the BN
422 if (elim_number != std::numeric_limits< int >::max()) {
423 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
424 _node_to_clique_.insert(node, clique);
425 _clique_to_nodes_[clique].insert(node);
426 }
427 }
428
429
430 // indicate a clique that contains all the nodes of targets
431 _targets2clique_ = std::numeric_limits< NodeId >::max();
432 {
433 // note that we remove from set all the nodes that received hard evidence
434 // (since they do not belong to the join tree)
435 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
436 int elim_number = std::numeric_limits< int >::max();
437
438 for (const auto node: targets) {
439 if (!hard_ev_nodes.contains(node) && (elim_order[node] < elim_number)) {
440 elim_number = elim_order[node];
441 first_eliminated_node = node;
442 }
443 }
444
445 if (elim_number != std::numeric_limits< int >::max()) {
446 _targets2clique_ = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
447 }
448 }
449 }
450
452 template < typename GUM_SCALAR >
453 void VariableElimination< GUM_SCALAR >::updateOutdatedStructure_() {}
454
457 template < typename GUM_SCALAR >
458 void VariableElimination< GUM_SCALAR >::updateOutdatedTensors_() {}
459
460 // find the tensors d-connected to a set of variables
461 template < typename GUM_SCALAR >
462 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsGetAll_(
463 Set< const IScheduleMultiDim* >& pot_list,
464 gum::VariableSet& kept_vars) {}
465
466 // find the tensors d-connected to a set of variables
467 template < typename GUM_SCALAR >
468 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation_(
469 Set< const IScheduleMultiDim* >& pot_list,
470 gum::VariableSet& kept_vars) {
471 // find the node ids of the kept variables
472 NodeSet kept_ids(kept_vars.size());
473 const auto& bn = this->BN();
474 for (const auto var: kept_vars) {
475 kept_ids.insert(bn.nodeId(*var));
476 }
477
478 // determine the set of tensors d-connected with the kept variables
479 NodeSet requisite_nodes;
480 BayesBall::requisiteNodes(bn.dag(),
481 kept_ids,
482 this->hardEvidenceNodes(),
483 this->softEvidenceNodes(),
484 requisite_nodes);
485 for (auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
486 const Sequence< const DiscreteVariable* >& vars = (*iter)->variablesSequence();
487 bool found = false;
488 for (const auto var: vars) {
489 if (requisite_nodes.exists(bn.nodeId(*var))) {
490 found = true;
491 break;
492 }
493 }
494
495 if (!found) { pot_list.erase(iter); }
496 }
497 }
498
499 // find the tensors d-connected to a set of variables
500 template < typename GUM_SCALAR >
501 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_(
502 Set< const IScheduleMultiDim* >& pot_list,
503 gum::VariableSet& kept_vars) {
504 // find the node ids of the kept variables
505 NodeSet kept_ids(kept_vars.size());
506 const auto& bn = this->BN();
507 for (const auto var: kept_vars) {
508 kept_ids.insert(bn.nodeId(*var));
509 }
510
511 // determine the set of tensors d-connected with the kept variables
512 BayesBall::relevantTensors(bn,
513 kept_ids,
514 this->hardEvidenceNodes(),
515 this->softEvidenceNodes(),
516 pot_list);
517 }
518
519 // find the tensors d-connected to a set of variables
520 template < typename GUM_SCALAR >
521 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation3_(
522 Set< const IScheduleMultiDim* >& pot_list,
523 gum::VariableSet& kept_vars) {
524 // find the node ids of the kept variables
525 NodeSet kept_ids(kept_vars.size());
526 const auto& bn = this->BN();
527 for (const auto var: kept_vars) {
528 kept_ids.insert(bn.nodeId(*var));
529 }
530
531 // determine the set of tensors d-connected with the kept variables
532 dSeparationAlgorithm dsep;
533 dsep.relevantTensors(bn,
534 kept_ids,
535 this->hardEvidenceNodes(),
536 this->softEvidenceNodes(),
537 pot_list);
538 }
539
540 // find the tensors d-connected to a set of variables
541 template < typename GUM_SCALAR >
542 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsXX_(
543 Set< const IScheduleMultiDim* >& pot_list,
544 gum::VariableSet& kept_vars) {
545 switch (_find_relevant_tensor_type_) {
546 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
547 _findRelevantTensorsWithdSeparation2_(pot_list, kept_vars);
548 break;
549
550 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES :
551 _findRelevantTensorsWithdSeparation_(pot_list, kept_vars);
552 break;
553
554 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 :
555 _findRelevantTensorsWithdSeparation3_(pot_list, kept_vars);
556 break;
557
558 case RelevantTensorsFinderType::FIND_ALL :
559 _findRelevantTensorsGetAll_(pot_list, kept_vars);
560 break;
561
562 default : GUM_ERROR(FatalError, "not implemented yet")
563 }
564 }
565
566 // remove barren variables using schedules
567 template < typename GUM_SCALAR >
568 Set< const IScheduleMultiDim* >
569 VariableElimination< GUM_SCALAR >::_removeBarrenVariables_(Schedule& schedule,
570 _ScheduleMultiDimSet_& pot_list,
571 gum::VariableSet& del_vars) {
572 // remove from del_vars the variables that received some evidence:
573 // only those that did not receive evidence can be barren variables
574 gum::VariableSet the_del_vars = del_vars;
575 for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) {
576 NodeId id = this->BN().nodeId(**iter);
577 if (this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) {
578 the_del_vars.erase(iter);
579 }
580 }
581
582 // assign to each random variable the set of tensors that contain it
583 HashTable< const DiscreteVariable*, _ScheduleMultiDimSet_ > var2pots(the_del_vars.size());
584 _ScheduleMultiDimSet_ empty_pot_set;
585 for (const auto pot: pot_list) {
586 const auto& vars = pot->variablesSequence();
587 for (const auto var: vars) {
588 if (the_del_vars.exists(var)) {
589 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
590 var2pots[var].insert(pot);
591 }
592 }
593 }
594
595 // each variable with only one tensor is necessarily a barren variable
596 // assign to each tensor with barren nodes its set of barren variables
597 HashTable< const IScheduleMultiDim*, gum::VariableSet > pot2barren_var;
598 gum::VariableSet empty_var_set;
599 for (const auto& elt: var2pots) {
600 if (elt.second.size() == 1) { // here we have a barren variable
601 const IScheduleMultiDim* pot = *(elt.second.begin());
602 if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
603 pot2barren_var[pot].insert(elt.first); // insert the barren variable
604 }
605 }
606
607 // for each tensor with barren variables, marginalize them.
608 // if the tensor has only barren variables, simply remove them from the
609 // set of tensors, else just project the tensor
610 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
611 _ScheduleMultiDimSet_ projected_pots;
612 for (const auto& elt: pot2barren_var) {
613 // remove the current tensor from pot_list as, anyway, we will change it
614 const IScheduleMultiDim* pot = elt.first;
615 pot_list.erase(pot);
616
617 // check whether we need to add a projected new tensor or not (i.e.,
618 // whether there exist non-barren variables or not)
619 if (pot->variablesSequence().size() != elt.second.size()) {
620 const IScheduleMultiDim* new_pot = projector.schedule(schedule, pot, elt.second);
621 // here, there is no need to enforce that new_pot is persistent since,
622 // if this is needed, the function that called _removeBarrenVariables_ will
623 // do it
624 pot_list.insert(new_pot);
625 projected_pots.insert(new_pot);
626 }
627 }
628
629 return projected_pots;
630 }
631
632 // remove barren variables directly without schedules
633 template < typename GUM_SCALAR >
634 Set< const Tensor< GUM_SCALAR >* >
635 VariableElimination< GUM_SCALAR >::_removeBarrenVariables_(_TensorSet_& pot_list,
636 gum::VariableSet& del_vars) {
637 // remove from del_vars the variables that received some evidence:
638 // only those that did not receive evidence can be barren variables
639 gum::VariableSet the_del_vars = del_vars;
640 for (auto iter = the_del_vars.beginSafe(); iter != the_del_vars.endSafe(); ++iter) {
641 NodeId id = this->BN().nodeId(**iter);
642 if (this->hardEvidenceNodes().exists(id) || this->softEvidenceNodes().exists(id)) {
643 the_del_vars.erase(iter);
644 }
645 }
646
647 // assign to each random variable the set of tensors that contain it
648 HashTable< const DiscreteVariable*, _TensorSet_ > var2pots;
649 _TensorSet_ empty_pot_set;
650 for (const auto pot: pot_list) {
651 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
652 for (const auto var: vars) {
653 if (the_del_vars.exists(var)) {
654 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
655 var2pots[var].insert(pot);
656 }
657 }
658 }
659
660 // each variable with only one tensor is a barren variable
661 // assign to each tensor with barren nodes its set of barren variables
662 HashTable< const Tensor< GUM_SCALAR >*, gum::VariableSet > pot2barren_var;
663 gum::VariableSet empty_var_set;
664 for (const auto& elt: var2pots) {
665 if (elt.second.size() == 1) { // here we have a barren variable
666 const Tensor< GUM_SCALAR >* pot = *(elt.second.begin());
667 if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
668 pot2barren_var[pot].insert(elt.first); // insert the barren variable
669 }
670 }
671
672 // for each tensor with barren variables, marginalize them.
673 // if the tensor has only barren variables, simply remove them from the
674 // set of tensors, else just project the tensor
675 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
676 _TensorSet_ projected_pots;
677 for (const auto& elt: pot2barren_var) {
678 // remove the current tensor from pot_list as, anyway, we will change it
679 const Tensor< GUM_SCALAR >* pot = elt.first;
680 pot_list.erase(pot);
681
682 // check whether we need to add a projected new tensor or not (i.e.,
683 // whether there exist non-barren variables or not)
684 if (pot->variablesSequence().size() != elt.second.size()) {
685 const Tensor< GUM_SCALAR >* new_pot = projector.execute(*pot, elt.second);
686 pot_list.insert(new_pot);
687 projected_pots.insert(new_pot);
688 }
689 }
690
691 return projected_pots;
692 }
693
694 // performs the collect phase of Variable Elimination
695 template < typename GUM_SCALAR >
696 Set< const IScheduleMultiDim* >
697 VariableElimination< GUM_SCALAR >::_collectMessage_(Schedule& schedule,
698 NodeId id,
699 NodeId from) {
700 // collect messages from all the neighbors
701 _ScheduleMultiDimSet_ collected_messages;
702 for (const auto other: _JT_->neighbours(id)) {
703 if (other != from) {
704 _ScheduleMultiDimSet_ message(_collectMessage_(schedule, other, id));
705 collected_messages += message;
706 }
707 }
708
709 // combine the collect messages with those of id's clique
710 return _produceMessage_(schedule, id, from, std::move(collected_messages));
711 }
712
713 // performs the collect phase of Variable Elimination
714 template < typename GUM_SCALAR >
715 std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
716 VariableElimination< GUM_SCALAR >::_collectMessage_(NodeId id, NodeId from) {
717 // collect messages from all the neighbors
718 std::pair< _TensorSet_, _TensorSet_ > collected_messages;
719 for (const auto other: _JT_->neighbours(id)) {
720 if (other != from) {
721 std::pair< _TensorSet_, _TensorSet_ > message(_collectMessage_(other, id));
722 collected_messages.first += message.first;
723 collected_messages.second += message.second;
724 }
725 }
726
727 // combine the collect messages with those of id's clique
728 return _produceMessage_(id, from, std::move(collected_messages));
729 }
730
731 // get the CPT + evidence of a node projected w.r.t. hard evidence
732 template < typename GUM_SCALAR >
733 Set< const IScheduleMultiDim* >
734 VariableElimination< GUM_SCALAR >::_NodeTensors_(Schedule& schedule, NodeId node) {
735 _ScheduleMultiDimSet_ res;
736 const auto& bn = this->BN();
737
738 // get the CPT of the node
739 // Beware: all the tensors that are defined over some nodes that
740 // received hard evidence must be projected so that these nodes are
741 // removed from the tensor.
742 // Also beware that the CPT of a hard evidence node may be defined over
743 // parents that do not belong to _graph_ and that are not hard evidence.
744 // In this case, those parents have been removed by d-separation and it is
745 // easy to show that, in this case, all the parents have been removed, so
746 // that the CPT does not need to be taken into account
747 const auto& evidence = this->evidence();
748 const auto& hard_evidence = this->hardEvidence();
749 const auto& hard_ev_nodes = this->hardEvidenceNodes();
750 if (_graph_.exists(node) || hard_ev_nodes.contains(node)) {
751 const Tensor< GUM_SCALAR >& cpt = bn.cpt(node);
752 const auto& variables = cpt.variablesSequence();
753
754 // check if the parents of a hard evidence node do not belong to _graph_
755 // and are not themselves hard evidence. In this case, discard the CPT as
756 // it is useless for inference (see the above comment)
757 if (hard_ev_nodes.contains(node)) {
758 for (const auto var: variables) {
759 NodeId xnode = bn.nodeId(*var);
760 if (!hard_ev_nodes.contains(xnode) && !_graph_.existsNode(xnode)) return res;
761 }
762 }
763
764 // get the list of nodes with hard evidence in cpt
765 NodeSet hard_nodes(variables.size());
766 for (const auto var: variables) {
767 NodeId xnode = bn.nodeId(*var);
768 if (hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
769 }
770
771 // if hard_nodes contains hard evidence nodes, perform a projection
772 // and insert the result into the appropriate clique, else insert
773 // directly cpt into the clique
774 if (hard_nodes.empty()) {
775 const IScheduleMultiDim* sched_cpt
776 = schedule.insertTable< Tensor< GUM_SCALAR > >(cpt, false);
777 res.insert(sched_cpt);
778 } else {
779 // marginalize out the hard evidence nodes: if the cpt is defined
780 // only over nodes that received hard evidence, do not consider it
781 // as a tensor anymore
782 if (hard_nodes.size() != variables.size()) {
783 // perform the projection with a combine and project instance
784 gum::VariableSet hard_variables;
785 _ScheduleMultiDimSet_ marg_cpt_set(1 + hard_nodes.size());
786 const IScheduleMultiDim* sched_cpt
787 = schedule.insertTable< Tensor< GUM_SCALAR > >(cpt, false);
788 marg_cpt_set.insert(sched_cpt);
789
790 for (const auto xnode: hard_nodes) {
791 const IScheduleMultiDim* pot
792 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[xnode], false);
793 marg_cpt_set.insert(pot);
794 hard_variables.insert(&(bn.variable(xnode)));
795 }
796
797 // perform the combination of those tensors and their projection
798 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
799 _combination_op_,
800 _projection_op_);
801 _ScheduleMultiDimSet_ new_cpt_list
802 = combine_and_project.schedule(schedule, marg_cpt_set, hard_variables);
803
804 // there should be only one tensor in new_cpt_list
805 if (new_cpt_list.size() != 1) {
807 "the projection of a tensor containing " << "hard evidence is empty!");
808 }
809 auto projected_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
810 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
811 *new_cpt_list.begin()));
812 res.insert(projected_pot);
813 }
814 }
815
816 // if the node received some soft evidence, add it
817 if (evidence.exists(node) && !hard_evidence.exists(node)) {
818 const IScheduleMultiDim* pot
819 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node], false);
820 res.insert(pot);
821 }
822 }
823
824 return res;
825 }
826
827 // get the CPT + evidence of a node projected w.r.t. hard evidence
828 template < typename GUM_SCALAR >
829 std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
830 VariableElimination< GUM_SCALAR >::_NodeTensors_(NodeId node) {
831 std::pair< _TensorSet_, _TensorSet_ > res;
832 const auto& bn = this->BN();
833
834 // get the CPT's of the node
835 // beware: all the tensors that are defined over some nodes
836 // including hard evidence must be projected so that these nodes are
837 // removed from the tensor
838 // also beware that the CPT of a hard evidence node may be defined over
839 // parents that do not belong to _graph_ and that are not hard evidence.
840 // In this case, those parents have been removed by d-separation and it is
841 // easy to show that, in this case all the parents have been removed, so
842 // that the CPT does not need to be taken into account
843 const auto& evidence = this->evidence();
844 const auto& hard_evidence = this->hardEvidence();
845 const auto& hard_ev_nodes = this->hardEvidenceNodes();
846 if (_graph_.exists(node) || hard_ev_nodes.contains(node)) {
847 const Tensor< GUM_SCALAR >& cpt = bn.cpt(node);
848 const auto& variables = cpt.variablesSequence();
849
850 // check if the parents of a hard evidence node do not belong to _graph_
851 // and are not themselves hard evidence. In this case, discard the CPT as
852 // it is useless for inference (see the above comment)
853 if (hard_ev_nodes.contains(node)) {
854 for (const auto var: variables) {
855 NodeId xnode = bn.nodeId(*var);
856 if (!hard_ev_nodes.contains(xnode) && !_graph_.existsNode(xnode)) return res;
857 }
858 }
859
860 // get the list of nodes with hard evidence in cpt
861 NodeSet hard_nodes(variables.size());
862 for (const auto var: variables) {
863 NodeId xnode = bn.nodeId(*var);
864 if (hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
865 }
866
867 // if hard_nodes contains hard evidence nodes, perform a projection
868 // and insert the result into the appropriate clique, else insert
869 // directly cpt into the clique
870 if (hard_nodes.empty()) {
871 res.first.insert(&cpt);
872 } else {
873 // marginalize out the hard evidence nodes: if the cpt is defined
874 // only over nodes that received hard evidence, do not consider it
875 // as a tensor anymore
876 if (hard_nodes.size() != variables.size()) {
877 // perform the projection with a combine and project instance
878 gum::VariableSet hard_variables;
879 _TensorSet_ marg_cpt_set(1 + hard_nodes.size());
880 marg_cpt_set.insert(&cpt);
881
882 for (const auto xnode: hard_nodes) {
883 marg_cpt_set.insert(evidence[xnode]);
884 hard_variables.insert(&(bn.variable(xnode)));
885 }
886 // perform the combination of those tensors and their projection
887 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
888 _combination_op_,
889 VENewprojTensor);
890 _TensorSet_ new_cpt_list = combine_and_project.execute(marg_cpt_set, hard_variables);
891
892 // there should be only one tensor in new_cpt_list
893 if (new_cpt_list.size() != 1) {
894 // remove the CPT created to avoid memory leaks
895 for (auto pot: new_cpt_list) {
896 if (!marg_cpt_set.contains(pot)) delete pot;
897 }
899 "the projection of a tensor containing " << "hard evidence is empty!");
900 }
901 const Tensor< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
902 res.first.insert(projected_cpt);
903 res.second.insert(projected_cpt);
904 }
905 }
906
907 // if the node received some soft evidence, add it
908 if (evidence.exists(node) && !hard_evidence.exists(node)) {
909 res.first.insert(this->evidence()[node]);
910 }
911 }
912
913 return res;
914 }
915
  // creates the message sent by clique from_id to clique to_id.
  // @return a pair whose first element is the set of tensors constituting the
  //         message and whose second element tracks the tensors allocated
  //         during its computation (the caller must deallocate them)
  template < typename GUM_SCALAR >
  std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
      VariableElimination< GUM_SCALAR >::_produceMessage_(
          NodeId from_id,
          NodeId to_id,
          std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >&&
              incoming_messages) {
    // get the messages sent by adjacent nodes to from_id
    std::pair< _TensorSet_, _TensorSet_ > pot_list(std::move(incoming_messages));

    // get the tensors of the clique: first = tensors to combine,
    // second = tensors created on the fly (to be deallocated later on)
    for (const auto node: _clique_to_nodes_[from_id]) {
      auto new_pots = _NodeTensors_(node);
      pot_list.first += new_pots.first;
      pot_list.second += new_pots.second;
    }

    // if from_id = to_id: this is the endpoint of a collect
    if (!_JT_->existsEdge(from_id, to_id)) {
      return pot_list;
    } else {
      // get the set of variables that need to be removed from the tensors:
      // those of clique from_id that do not belong to the separator
      const NodeSet& from_clique = _JT_->clique(from_id);
      const NodeSet& separator = _JT_->separator(from_id, to_id);
      gum::VariableSet del_vars(from_clique.size());
      gum::VariableSet kept_vars(separator.size());
      const auto& bn = this->BN();

      for (const auto node: from_clique) {
        if (!separator.contains(node)) {
          del_vars.insert(&(bn.variable(node)));
        } else {
          kept_vars.insert(&(bn.variable(node)));
        }
      }

      // pot_list now contains all the tensors to multiply and marginalize
      // => combine the messages
      _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);

      // remove the unnecessary temporary messages (safe iterators are needed
      // here because we erase elements from the set while iterating over it)
      for (auto iter = pot_list.second.beginSafe(); iter != pot_list.second.endSafe(); ++iter) {
        if (!new_pot_list.contains(*iter)) {
          delete *iter;
          pot_list.second.erase(iter);
        }
      }

      // keep track of all the newly created tensors
      for (const auto pot: new_pot_list) {
        if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
      }

      // return the new set of tensors
      return std::pair< _TensorSet_, _TensorSet_ >(std::move(new_pot_list),
                                                   std::move(pot_list.second));
    }
  }
975
976 // creates the message sent by clique from_id to clique to_id
977 template < typename GUM_SCALAR >
978 Set< const IScheduleMultiDim* > VariableElimination< GUM_SCALAR >::_produceMessage_(
979 Schedule& schedule,
980 NodeId from_id,
981 NodeId to_id,
982 Set< const IScheduleMultiDim* >&& incoming_messages) {
983 // get the messages sent by adjacent nodes to from_id
984 _ScheduleMultiDimSet_ pot_list(std::move(incoming_messages));
985
986 // get the tensors of the clique
987 for (const auto node: _clique_to_nodes_[from_id]) {
988 pot_list += _NodeTensors_(schedule, node);
989 }
990
991 // if from_id = to_id: this is the endpoint of a collect
992 if (!_JT_->existsEdge(from_id, to_id)) {
993 return pot_list;
994 } else {
995 // get the set of variables that need be removed from the tensors
996 const NodeSet& from_clique = _JT_->clique(from_id);
997 const NodeSet& separator = _JT_->separator(from_id, to_id);
998 gum::VariableSet del_vars(from_clique.size());
999 gum::VariableSet kept_vars(separator.size());
1000 const auto& bn = this->BN();
1001
1002 for (const auto node: from_clique) {
1003 if (!separator.contains(node)) {
1004 del_vars.insert(&(bn.variable(node)));
1005 } else {
1006 kept_vars.insert(&(bn.variable(node)));
1007 }
1008 }
1009
1010 // pot_list now contains all the tensors to multiply and marginalize
1011 // => combine the messages
1012 _ScheduleMultiDimSet_ new_pot_list
1013 = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1014
1015 // remove the unnecessary temporary messages
1016 for (auto pot: pot_list) {
1017 if (!new_pot_list.contains(pot)) {
1018 const auto sched_pot
1019 = static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot);
1020 schedule.emplaceDeletion(*sched_pot);
1021 }
1022 }
1023
1024 // return the new set of tensors
1025 return new_pot_list;
1026 }
1027 }
1028
1029 // remove variables del_vars from the list of tensors pot_list
1030 template < typename GUM_SCALAR >
1031 Set< const Tensor< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_marginalizeOut_(
1032 Set< const Tensor< GUM_SCALAR >* > pot_list,
1033 gum::VariableSet& del_vars,
1034 gum::VariableSet& kept_vars) {
1035 // if pot list is empty, do nothing. This may happen when there are many barren variables
1036 if (pot_list.empty()) { return _TensorSet_(); }
1037
1038 // use d-separation analysis to check which tensors shall be combined
1039 // _findRelevantTensorsXX_(pot_list, kept_vars);
1040
1041 // remove the tensors corresponding to barren variables if we want
1042 // to exploit barren nodes
1043 _TensorSet_ barren_projected_tensors;
1044 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
1045 barren_projected_tensors = _removeBarrenVariables_(pot_list, del_vars);
1046 }
1047
1048 // Combine and project the remaining tensors
1049 _TensorSet_ new_pot_list;
1050 if (pot_list.size() == 1) {
1051 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
1052 auto pot = projector.execute(**(pot_list.begin()), del_vars);
1053 new_pot_list.insert(pot);
1054 } else if (pot_list.size() > 1) {
1055 // create a combine and project operator that will perform the
1056 // marginalization
1057 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1058 _projection_op_);
1059 new_pot_list = combine_and_project.execute(pot_list, del_vars);
1060 }
1061
1062 // remove all the tensors that were created due to projections of
1063 // barren nodes and that are not part of the new_pot_list: these
1064 // tensors were just temporary tensors
1065 for (auto iter = barren_projected_tensors.beginSafe();
1066 iter != barren_projected_tensors.endSafe();
1067 ++iter) {
1068 if (!new_pot_list.exists(*iter)) delete *iter;
1069 }
1070
1071 return new_pot_list;
1072 }
1073
  // remove variables del_vars from the list of tensors pot_list (schedule
  // version: the combination/projection operations are only recorded in the
  // schedule, they are not executed here)
  template < typename GUM_SCALAR >
  Set< const IScheduleMultiDim* >
      VariableElimination< GUM_SCALAR >::_marginalizeOut_(Schedule& schedule,
                                                          Set< const IScheduleMultiDim* > pot_list,
                                                          gum::VariableSet& del_vars,
                                                          gum::VariableSet& kept_vars) {
    // if pot list is empty, do nothing. This may happen when there are only barren variables
    if (pot_list.empty()) { return _ScheduleMultiDimSet_(); }

    // use d-separation analysis to check which tensors shall be combined
    // _findRelevantTensorsXX_(pot_list, kept_vars);
    // note: kept_vars is currently only needed by the disabled analysis above

    // now, let's guarantee that all the tensors to be combined and projected
    // belong to the schedule
    for (const auto pot: pot_list) {
      if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
    }

    // remove the tensors corresponding to barren variables if we want
    // to exploit barren nodes
    _ScheduleMultiDimSet_ barren_projected_tensors;
    if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
      barren_projected_tensors = _removeBarrenVariables_(schedule, pot_list, del_vars);
    }

    // Combine and project the tensors
    _ScheduleMultiDimSet_ new_pot_list;
    if (pot_list.size() == 1) { // only one tensor, so just project it
      MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
      auto xpot = projector.schedule(schedule, *(pot_list.begin()), del_vars);
      new_pot_list.insert(xpot);
    } else if (pot_list.size() > 1) {
      // create a combine and project operator that will perform the
      // marginalization
      MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
                                                                                   _projection_op_);
      new_pot_list = combine_and_project.schedule(schedule, pot_list, del_vars);
    }

    // remove all the tensors that were created due to projections of
    // barren nodes and that are not part of the new_pot_list: these
    // tensors were just temporary tensors (their deletion is scheduled,
    // not performed immediately)
    for (auto pot: barren_projected_tensors) {
      if (!new_pot_list.exists(pot)) {
        const auto sched_pot = static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot);
        schedule.emplaceDeletion(*sched_pot);
      }
    }

    return new_pot_list;
  }
1126
  // performs a whole inference
  // nothing to do here: VariableElimination performs its computations lazily,
  // when posterior_/jointPosterior_ are called (each query builds its own
  // join tree, see unnormalizedJointPosterior_)
  template < typename GUM_SCALAR >
  INLINE void VariableElimination< GUM_SCALAR >::makeInference_() {}
1130
1132 template < typename GUM_SCALAR >
1133 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId id) {
1134 // hard evidence do not belong to the join tree
1135 // # TODO: check for sets of inconsistent hard evidence
1136 if (this->hardEvidenceNodes().contains(id)) {
1137 return new Tensor< GUM_SCALAR >(*(this->evidence()[id]));
1138 }
1139
1140 // if we still need to perform some inference task, do it
1141 _createNewJT_(NodeSet{id});
1142
1143 // here, we determine whether we should use schedules during the inference.
1144 // the rule is: if the sum of the domain sizes of the cliques is greater
1145 // than a threshold, use schedules
1146 double overall_size = 0;
1147 for (const auto clique: *_JT_) {
1148 double clique_size = 1.0;
1149 for (const auto node: _JT_->clique(clique))
1150 clique_size *= this->domainSizes()[node];
1151 overall_size += clique_size;
1152 }
1153 const bool use_schedules = (overall_size > _schedule_threshold_);
1154
1155 if (use_schedules) {
1156 Schedule schedule;
1157 return _unnormalizedJointPosterior_(schedule, id);
1158 } else {
1159 return _unnormalizedJointPosterior_(id);
1160 }
1161 }
1162
1164 template < typename GUM_SCALAR >
1165 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(NodeId id) {
1166 const auto& bn = this->BN();
1167
1168 NodeId clique_of_id = _node_to_clique_[id];
1169 std::pair< _TensorSet_, _TensorSet_ > pot_list = _collectMessage_(clique_of_id, clique_of_id);
1170
1171 // get the set of variables that need be removed from the tensors
1172 const NodeSet& nodes = _JT_->clique(clique_of_id);
1173 gum::VariableSet kept_vars{&(bn.variable(id))};
1174 gum::VariableSet del_vars(nodes.size());
1175 for (const auto node: nodes) {
1176 if (node != id) del_vars.insert(&(bn.variable(node)));
1177 }
1178
1179 // pot_list now contains all the tensors to multiply and marginalize
1180 // => combine the messages
1181 _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
1182 Tensor< GUM_SCALAR >* joint = nullptr;
1183
1184 if (new_pot_list.size() == 0) {
1185 joint = new Tensor< GUM_SCALAR >;
1186 for (const auto var: kept_vars)
1187 *joint << *var;
1188 } else {
1189 if (new_pot_list.size() == 1) {
1190 joint = const_cast< Tensor< GUM_SCALAR >* >(*(new_pot_list.begin()));
1191 // if joint already existed, create a copy, so that we can put it into
1192 // the _target_posterior_ property
1193 if (pot_list.first.exists(joint)) {
1194 joint = new Tensor< GUM_SCALAR >(*joint);
1195 } else {
1196 // remove the joint from new_pot_list so that it will not be
1197 // removed just after the else block
1198 new_pot_list.clear();
1199 }
1200 } else {
1201 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1202 joint = fast_combination.execute(new_pot_list);
1203 }
1204 }
1205
1206 // remove the tensors that were created in new_pot_list
1207 for (auto pot: new_pot_list)
1208 if (!pot_list.first.exists(pot)) delete pot;
1209
1210 // remove all the temporary tensors created in pot_list
1211 for (auto pot: pot_list.second)
1212 delete pot;
1213
1214 // check that the joint posterior is different from a 0 vector: this would
1215 // indicate that some hard evidence are not compatible (their joint
1216 // probability is equal to 0)
1217 bool nonzero_found = false;
1218 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1219 if ((*joint)[inst]) {
1220 nonzero_found = true;
1221 break;
1222 }
1223 }
1224 if (!nonzero_found) {
1225 // remove joint from memory to avoid memory leaks
1226 delete joint;
1228 "some evidence entered into the Bayes "
1229 "net are incompatible (their joint proba = 0)");
1230 }
1231
1232 return joint;
1233 }
1234
1236 template < typename GUM_SCALAR >
1237 Tensor< GUM_SCALAR >*
1238 VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1239 NodeId id) {
1240 const auto& bn = this->BN();
1241
1242 NodeId clique_of_id = _node_to_clique_[id];
1243 _ScheduleMultiDimSet_ pot_list = _collectMessage_(schedule, clique_of_id, clique_of_id);
1244
1245 // get the set of variables that need be removed from the tensors
1246 const NodeSet& nodes = _JT_->clique(clique_of_id);
1247 gum::VariableSet kept_vars{&(bn.variable(id))};
1248 gum::VariableSet del_vars(nodes.size());
1249 for (const auto node: nodes) {
1250 if (node != id) del_vars.insert(&(bn.variable(node)));
1251 }
1252
1253 // pot_list now contains all the tensors to multiply and marginalize
1254 // => combine the messages
1255 _ScheduleMultiDimSet_ new_pot_list = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1256 Tensor< GUM_SCALAR >* joint = nullptr;
1257 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot = nullptr;
1258
1259 if (new_pot_list.size() == 0) {
1260 joint = new Tensor< GUM_SCALAR >;
1261 for (const auto var: kept_vars)
1262 *joint << *var;
1263 } else {
1264 auto& scheduler = this->scheduler();
1265 if (new_pot_list.size() == 1) {
1266 scheduler.execute(schedule);
1267 resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1268 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(*new_pot_list.begin()));
1269 } else {
1270 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1271 const IScheduleMultiDim* pot = fast_combination.schedule(schedule, new_pot_list);
1272 scheduler.execute(schedule);
1273 resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1274 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot));
1275 }
1276
1277 // if resulting_pot already existed, create a copy, so that we can put it into
1278 // the _target_posteriors_ property
1279 if (pot_list.exists(resulting_pot)) {
1280 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1281 } else {
1282 joint = resulting_pot->exportMultiDim();
1283 }
1284 }
1285
1286 // check that the joint posterior is different from a 0 vector: this would
1287 // indicate that some hard evidence are not compatible (their joint
1288 // probability is equal to 0)
1289 bool nonzero_found = false;
1290 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1291 if ((*joint)[inst]) {
1292 nonzero_found = true;
1293 break;
1294 }
1295 }
1296 if (!nonzero_found) {
1297 // remove joint from memory to avoid memory leaks
1298 delete joint;
1300 "some evidence entered into the Bayes "
1301 "net are incompatible (their joint proba = 0)");
1302 }
1303
1304 return joint;
1305 }
1306
1308 template < typename GUM_SCALAR >
1309 const Tensor< GUM_SCALAR >& VariableElimination< GUM_SCALAR >::posterior_(NodeId id) {
1310 // compute the joint posterior and normalize
1311 auto joint = unnormalizedJointPosterior_(id);
1312 if (joint->sum() != 1) // hard test for ReadOnly CPT (as aggregator)
1313 joint->normalize();
1314
1315 if (_target_posterior_ != nullptr) delete _target_posterior_;
1316 _target_posterior_ = joint;
1317
1318 return *joint;
1319 }
1320
1321 // returns the marginal a posteriori proba of a given node
1322 template < typename GUM_SCALAR >
1323 Tensor< GUM_SCALAR >*
1324 VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(const NodeSet& set) {
1325 // hard evidence do not belong to the join tree, so extract the nodes
1326 // from targets that are not hard evidence
1327 NodeSet targets = set, hard_ev_nodes(this->hardEvidenceNodes().size());
1328 for (const auto node: this->hardEvidenceNodes()) {
1329 if (targets.contains(node)) {
1330 targets.erase(node);
1331 hard_ev_nodes.insert(node);
1332 }
1333 }
1334
1335 // if all the nodes have received hard evidence, then compute the
1336 // joint posterior directly by multiplying the hard evidence tensors
1337 const auto& evidence = this->evidence();
1338 if (targets.empty()) {
1339 _TensorSet_ pot_list;
1340 for (const auto node: set) {
1341 pot_list.insert(evidence[node]);
1342 }
1343 if (pot_list.size() == 1) {
1344 return new Tensor< GUM_SCALAR >(**(pot_list.begin()));
1345 } else {
1346 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1347 return fast_combination.execute(pot_list);
1348 }
1349 }
1350
1351 // if we still need to perform some inference task, do it
1352 _createNewJT_(set);
1353
1354 // here, we determine whether we should use schedules during the inference.
1355 // the rule is: if the sum of the domain sizes of the cliques is greater
1356 // than a threshold, use schedules
1357 double overall_size = 0;
1358 for (const auto clique: *_JT_) {
1359 double clique_size = 1.0;
1360 for (const auto node: _JT_->clique(clique))
1361 clique_size *= this->domainSizes()[node];
1362 overall_size += clique_size;
1363 }
1364 const bool use_schedules = (overall_size > _schedule_threshold_);
1365
1366 if (use_schedules) {
1367 Schedule schedule;
1368 return _unnormalizedJointPosterior_(schedule, set, targets, hard_ev_nodes);
1369 } else {
1370 return _unnormalizedJointPosterior_(set, targets, hard_ev_nodes);
1371 }
1372 }
1373
1374 // returns the marginal a posteriori proba of a given node
1375 template < typename GUM_SCALAR >
1376 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(
1377 const NodeSet& set,
1378 const NodeSet& targets,
1379 const NodeSet& hard_ev_nodes) {
1380 std::pair< _TensorSet_, _TensorSet_ > pot_list
1381 = _collectMessage_(_targets2clique_, _targets2clique_);
1382
1383 // get the set of variables that need be removed from the tensors
1384 const NodeSet& nodes = _JT_->clique(_targets2clique_);
1385 gum::VariableSet del_vars(nodes.size());
1386 gum::VariableSet kept_vars(targets.size());
1387 const auto& bn = this->BN();
1388 for (const auto node: nodes) {
1389 if (!targets.contains(node)) {
1390 del_vars.insert(&(bn.variable(node)));
1391 } else {
1392 kept_vars.insert(&(bn.variable(node)));
1393 }
1394 }
1395
1396 // pot_list now contains all the tensors to multiply and marginalize
1397 // => combine the messages
1398 _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
1399 Tensor< GUM_SCALAR >* joint = nullptr;
1400
1401 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1402 joint = const_cast< Tensor< GUM_SCALAR >* >(*(new_pot_list.begin()));
1403 // if pot already existed, create a copy, so that we can put it into
1404 // the _target_posteriors_ property
1405 if (pot_list.first.exists(joint)) {
1406 joint = new Tensor< GUM_SCALAR >(*joint);
1407 } else {
1408 // remove the joint from new_pot_list so that it will not be
1409 // removed just after the next else block
1410 new_pot_list.clear();
1411 }
1412 } else {
1413 // combine all the tensors in new_pot_list with all the hard evidence
1414 // of the nodes in set
1415 const auto& evidence = this->evidence();
1416 _TensorSet_ new_new_pot_list = new_pot_list;
1417 for (const auto node: hard_ev_nodes) {
1418 new_new_pot_list.insert(evidence[node]);
1419 }
1420 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1421 joint = fast_combination.execute(new_new_pot_list);
1422 }
1423
1424 // remove the tensors that were created in new_pot_list
1425 for (auto pot: new_pot_list)
1426 if (!pot_list.first.exists(pot)) delete pot;
1427
1428 // remove all the temporary tensors created in pot_list
1429 for (auto pot: pot_list.second)
1430 delete pot;
1431
1432 // check that the joint posterior is different from a 0 vector: this would
1433 // indicate that some hard evidence are not compatible
1434 bool nonzero_found = false;
1435 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1436 if ((*joint)[inst]) {
1437 nonzero_found = true;
1438 break;
1439 }
1440 }
1441 if (!nonzero_found) {
1442 // remove joint from memory to avoid memory leaks
1443 delete joint;
1445 "some evidence entered into the Bayes "
1446 "net are incompatible (their joint proba = 0)");
1447 }
1448
1449 return joint;
1450 }
1451
1452 // returns the marginal a posteriori proba of a given node
1453 template < typename GUM_SCALAR >
1454 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(
1455 Schedule& schedule,
1456 const NodeSet& set,
1457 const NodeSet& targets,
1458 const NodeSet& hard_ev_nodes) {
1459 _ScheduleMultiDimSet_ pot_list = _collectMessage_(schedule, _targets2clique_, _targets2clique_);
1460
1461 // get the set of variables that need be removed from the tensors
1462 const NodeSet& nodes = _JT_->clique(_targets2clique_);
1463 gum::VariableSet del_vars(nodes.size());
1464 gum::VariableSet kept_vars(targets.size());
1465 const auto& bn = this->BN();
1466 for (const auto node: nodes) {
1467 if (!targets.contains(node)) {
1468 del_vars.insert(&(bn.variable(node)));
1469 } else {
1470 kept_vars.insert(&(bn.variable(node)));
1471 }
1472 }
1473
1474 // pot_list now contains all the tensors to multiply and marginalize
1475 // => combine the messages
1476 _ScheduleMultiDimSet_ new_pot_list = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1477 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot = nullptr;
1478 auto& scheduler = this->scheduler();
1479
1480 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1481 scheduler.execute(schedule);
1482 resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1483 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(*new_pot_list.begin()));
1484 } else {
1485 // combine all the tensors in new_pot_list with all the hard evidence
1486 // of the nodes in set
1487 const auto& evidence = this->evidence();
1488 for (const auto node: hard_ev_nodes) {
1489 auto new_pot_ev = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node], false);
1490 new_pot_list.insert(new_pot_ev);
1491 }
1492 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1493 const auto pot = fast_combination.schedule(schedule, new_pot_list);
1494 scheduler.execute(schedule);
1495 resulting_pot = const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(
1496 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR > >* >(pot));
1497 }
1498
1499 // if pot already existed, create a copy, so that we can put it into
1500 // the _target_posteriors_ property
1501 Tensor< GUM_SCALAR >* joint = nullptr;
1502 if (pot_list.exists(resulting_pot)) {
1503 joint = new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1504 } else {
1505 joint = resulting_pot->exportMultiDim();
1506 }
1507
1508 // check that the joint posterior is different from a 0 vector: this would
1509 // indicate that some hard evidence are not compatible
1510 bool nonzero_found = false;
1511 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1512 if ((*joint)[inst]) {
1513 nonzero_found = true;
1514 break;
1515 }
1516 }
1517 if (!nonzero_found) {
1518 // remove joint from memory to avoid memory leaks
1519 delete joint;
1521 "some evidence entered into the Bayes "
1522 "net are incompatible (their joint proba = 0)");
1523 }
1524
1525 return joint;
1526 }
1527
1529 template < typename GUM_SCALAR >
1530 const Tensor< GUM_SCALAR >&
1531 VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& set) {
1532 // compute the joint posterior and normalize
1533 auto joint = unnormalizedJointPosterior_(set);
1534 joint->normalize();
1535
1536 if (_target_posterior_ != nullptr) delete _target_posterior_;
1537 _target_posterior_ = joint;
1538
1539 return *joint;
1540 }
1541
  // returns the joint posterior of wanted_target
  // declared_target is ignored here: VariableElimination builds a new join
  // tree for each query (see unnormalizedJointPosterior_), hence only the
  // wanted_target matters
  template < typename GUM_SCALAR >
  const Tensor< GUM_SCALAR >&
      VariableElimination< GUM_SCALAR >::jointPosterior_(const NodeSet& wanted_target,
                                                         const NodeSet& declared_target) {
    return jointPosterior_(wanted_target);
  }
1549
1550
1551} /* namespace gum */
1552
1553#endif // DOXYGEN_SHOULD_SKIP_THIS
The BayesBall algorithm (as described by Schachter).
Detect barren nodes for inference in Bayesian networks.
An algorithm for converting a join tree into a binary join tree.
Exception : fatal (unknown ?) error.
Class representing the minimal interface for Bayesian network with no numerical data.
Definition IBayesNet.h:75
Exception : several evidence are incompatible together (proba=0).
Exception: at least one argument passed to a function is not what was expected.
<agrum/BN/inference/jointTargetedInference.h>
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition set_tpl.h:414
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition set_tpl.h:426
bool exists(const Key &k) const
Indicates whether a given elements belong to the set.
Definition set_tpl.h:533
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
void erase(const Key &k)
Erases an element from the set.
Definition set_tpl.h:582
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantTensorsFinderType=RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
d-separation analysis (as described in Koller & Friedman 2009)
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
Header files of gum::Instantiation.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
FindBarrenNodesType
type of algorithm to determine barren nodes
Set< const DiscreteVariable * > VariableSet
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
RelevantTensorsFinderType
type of algorithm for determining the relevant tensors for combinations using some d-separation analy...
Implementation of a variable elimination algorithm for inference in Bayesian networks.