51#ifndef DOXYGEN_SHOULD_SKIP_THIS
67 template <
typename GUM_SCALAR >
74 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_;
75 setRelevantTensorsFinderType(relevant_type);
76 setFindBarrenNodesType(barren_type);
79 _triangulation_ =
new DefaultTriangulation;
82 GUM_CONSTRUCTOR(VariableElimination);
86 template <
typename GUM_SCALAR >
87 INLINE VariableElimination< GUM_SCALAR >::~VariableElimination() {
89 if (_JT_ !=
nullptr)
delete _JT_;
90 delete _triangulation_;
91 if (_target_posterior_ !=
nullptr)
delete _target_posterior_;
94 GUM_DESTRUCTOR(VariableElimination);
98 template <
typename GUM_SCALAR >
99 void VariableElimination< GUM_SCALAR >::setTriangulation(
const Triangulation& new_triangulation) {
100 delete _triangulation_;
101 _triangulation_ = new_triangulation.newFactory();
105 template <
typename GUM_SCALAR >
106 INLINE
const JunctionTree* VariableElimination< GUM_SCALAR >::junctionTree(NodeId
id) {
107 _createNewJT_(NodeSet{
id});
113 template <
typename GUM_SCALAR >
114 void VariableElimination< GUM_SCALAR >::setRelevantTensorsFinderType(
115 RelevantTensorsFinderType type) {
116 if (type != _find_relevant_tensor_type_) {
118 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
119 _findRelevantTensors_
120 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_;
123 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES :
124 _findRelevantTensors_
125 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation_;
128 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 :
129 _findRelevantTensors_
130 = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation3_;
133 case RelevantTensorsFinderType::FIND_ALL :
134 _findRelevantTensors_ = &VariableElimination< GUM_SCALAR >::_findRelevantTensorsGetAll_;
139 "setRelevantTensorsFinderType for type " << (
unsigned int)type
140 <<
" is not implemented yet");
143 _find_relevant_tensor_type_ = type;
148 template <
typename GUM_SCALAR >
149 INLINE
void VariableElimination< GUM_SCALAR >::_setProjectionFunction_(
150 Tensor< GUM_SCALAR > (*proj)(
const Tensor< GUM_SCALAR >&,
const gum::VariableSet&)) {
151 _projection_op_ = proj;
155 template <
typename GUM_SCALAR >
156 INLINE
void VariableElimination< GUM_SCALAR >::_setCombinationFunction_(
157 Tensor< GUM_SCALAR > (*comb)(
const Tensor< GUM_SCALAR >&,
const Tensor< GUM_SCALAR >&)) {
158 _combination_op_ = comb;
162 template <
typename GUM_SCALAR >
163 void VariableElimination< GUM_SCALAR >::setFindBarrenNodesType(FindBarrenNodesType type) {
164 if (type != _barren_nodes_type_) {
168 case FindBarrenNodesType::FIND_BARREN_NODES :
169 case FindBarrenNodesType::FIND_NO_BARREN_NODES :
break;
173 "setFindBarrenNodesType for type " << (
unsigned int)type
174 <<
" is not implemented yet");
177 _barren_nodes_type_ = type;
182 template <
typename GUM_SCALAR >
183 INLINE
void VariableElimination< GUM_SCALAR >::onEvidenceAdded_(
const NodeId,
bool) {}
186 template <
typename GUM_SCALAR >
187 INLINE
void VariableElimination< GUM_SCALAR >::onEvidenceErased_(
const NodeId,
bool) {}
190 template <
typename GUM_SCALAR >
191 void VariableElimination< GUM_SCALAR >::onAllEvidenceErased_(
bool) {}
194 template <
typename GUM_SCALAR >
195 INLINE
void VariableElimination< GUM_SCALAR >::onEvidenceChanged_(
const NodeId,
bool) {}
198 template <
typename GUM_SCALAR >
199 INLINE
void VariableElimination< GUM_SCALAR >::onMarginalTargetAdded_(
const NodeId) {}
202 template <
typename GUM_SCALAR >
203 INLINE
void VariableElimination< GUM_SCALAR >::onMarginalTargetErased_(
const NodeId) {}
206 template <
typename GUM_SCALAR >
207 INLINE
void VariableElimination< GUM_SCALAR >::onModelChanged_(
const GraphicalModel* bn) {}
210 template <
typename GUM_SCALAR >
211 INLINE
void VariableElimination< GUM_SCALAR >::onJointTargetAdded_(
const NodeSet&) {}
214 template <
typename GUM_SCALAR >
215 INLINE
void VariableElimination< GUM_SCALAR >::onJointTargetErased_(
const NodeSet&) {}
218 template <
typename GUM_SCALAR >
219 INLINE
void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
222 template <
typename GUM_SCALAR >
223 INLINE
void VariableElimination< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
226 template <
typename GUM_SCALAR >
227 INLINE
void VariableElimination< GUM_SCALAR >::onAllJointTargetsErased_() {}
230 template <
typename GUM_SCALAR >
231 INLINE
void VariableElimination< GUM_SCALAR >::onAllTargetsErased_() {}
234 template <
typename GUM_SCALAR >
235 void VariableElimination< GUM_SCALAR >::_createNewJT_(
const NodeSet& targets) {
251 const auto& bn = this->BN();
253 for (
const auto node: bn.dag())
254 _graph_.addNodeWithId(node);
261 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
264 if (targets.size() != bn.size()) {
265 BarrenNodesFinder finder(&(bn.dag()));
266 finder.setTargets(&targets);
268 NodeSet evidence_nodes(this->evidence().size());
269 for (
const auto& pair: this->evidence()) {
270 evidence_nodes.insert(pair.first);
272 finder.setEvidence(&evidence_nodes);
274 NodeSet barren_nodes = finder.barrenNodes();
277 for (
const auto node: barren_nodes) {
278 _graph_.eraseNode(node);
287 bool dsep_analysis =
false;
288 switch (_find_relevant_tensor_type_) {
289 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
290 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES : {
291 BayesBall::requisiteNodes(bn.dag(),
293 this->hardEvidenceNodes(),
294 this->softEvidenceNodes(),
296 dsep_analysis =
true;
299 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 : {
300 dSeparationAlgorithm dsep;
301 dsep.requisiteNodes(bn.dag(),
303 this->hardEvidenceNodes(),
304 this->softEvidenceNodes(),
306 dsep_analysis =
true;
309 case RelevantTensorsFinderType::FIND_ALL :
break;
316 for (
auto iter = _graph_.beginSafe(); iter != _graph_.endSafe(); ++iter) {
317 if (!requisite_nodes.contains(*iter) && !this->hardEvidenceNodes().contains(*iter)) {
318 _graph_.eraseNode(*iter);
325 for (
const auto node: _graph_) {
326 const NodeSet& parents = bn.parents(node);
327 for (
auto iter1 = parents.cbegin(); iter1 != parents.cend(); ++iter1) {
332 if (_graph_.existsNode(*iter1)) {
333 _graph_.addEdge(*iter1, node);
336 for (++iter2; iter2 != parents.cend(); ++iter2) {
341 if (_graph_.existsNode(*iter2)) _graph_.addEdge(*iter1, *iter2);
350 for (
auto iter1 = targets.cbegin(); iter1 != targets.cend(); ++iter1) {
352 for (++iter2; iter2 != targets.cend(); ++iter2) {
353 _graph_.addEdge(*iter1, *iter2);
358 const auto& hard_ev_nodes = this->hardEvidenceNodes();
359 for (
const auto node: hard_ev_nodes) {
360 _graph_.eraseNode(node);
365 if (_JT_ !=
nullptr)
delete _JT_;
366 _triangulation_->setGraph(&_graph_, &(this->domainSizes()));
367 const JunctionTree& triang_jt = _triangulation_->junctionTree();
368 _JT_ =
new CliqueGraph(triang_jt);
372 _node_to_clique_.clear();
373 _clique_to_nodes_.clear();
374 NodeSet emptyset(_JT_->size());
375 for (
auto clique: *_JT_)
376 _clique_to_nodes_.insert(clique, emptyset);
377 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
378 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
379 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
380 elim_order.insert(JT_elim_order[i], (
int)i);
381 const DAG& dag = bn.dag();
383 for (
const auto node: _graph_) {
385 NodeId first_eliminated_node = node;
386 int elim_number = elim_order[first_eliminated_node];
388 for (
const auto parent: dag.parents(node)) {
389 if (_graph_.existsNode(parent) && (elim_order[parent] < elim_number)) {
390 elim_number = elim_order[parent];
391 first_eliminated_node = parent;
399 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
400 _node_to_clique_.insert(node, clique);
401 _clique_to_nodes_[clique].insert(node);
407 for (
const auto node: hard_ev_nodes) {
408 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
409 int elim_number = std::numeric_limits< int >::max();
411 for (
const auto parent: dag.parents(node)) {
412 if (_graph_.exists(parent) && (elim_order[parent] < elim_number)) {
413 elim_number = elim_order[parent];
414 first_eliminated_node = parent;
422 if (elim_number != std::numeric_limits< int >::max()) {
423 NodeId clique = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
424 _node_to_clique_.insert(node, clique);
425 _clique_to_nodes_[clique].insert(node);
431 _targets2clique_ = std::numeric_limits< NodeId >::max();
435 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
436 int elim_number = std::numeric_limits< int >::max();
438 for (
const auto node: targets) {
439 if (!hard_ev_nodes.contains(node) && (elim_order[node] < elim_number)) {
440 elim_number = elim_order[node];
441 first_eliminated_node = node;
445 if (elim_number != std::numeric_limits< int >::max()) {
446 _targets2clique_ = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
452 template <
typename GUM_SCALAR >
453 void VariableElimination< GUM_SCALAR >::updateOutdatedStructure_() {}
457 template <
typename GUM_SCALAR >
458 void VariableElimination< GUM_SCALAR >::updateOutdatedTensors_() {}
461 template <
typename GUM_SCALAR >
462 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsGetAll_(
463 Set< const IScheduleMultiDim* >& pot_list,
467 template <
typename GUM_SCALAR >
468 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation_(
469 Set< const IScheduleMultiDim* >& pot_list,
473 const auto& bn = this->BN();
474 for (
const auto var: kept_vars) {
475 kept_ids.insert(bn.nodeId(*var));
480 BayesBall::requisiteNodes(bn.dag(),
482 this->hardEvidenceNodes(),
483 this->softEvidenceNodes(),
485 for (
auto iter = pot_list.beginSafe(); iter != pot_list.endSafe(); ++iter) {
486 const Sequence< const DiscreteVariable* >& vars = (*iter)->variablesSequence();
488 for (
const auto var: vars) {
489 if (requisite_nodes.exists(bn.nodeId(*var))) {
495 if (!found) { pot_list.erase(iter); }
500 template <
typename GUM_SCALAR >
501 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation2_(
502 Set< const IScheduleMultiDim* >& pot_list,
506 const auto& bn = this->BN();
507 for (
const auto var: kept_vars) {
508 kept_ids.insert(bn.nodeId(*var));
512 BayesBall::relevantTensors(bn,
514 this->hardEvidenceNodes(),
515 this->softEvidenceNodes(),
520 template <
typename GUM_SCALAR >
521 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsWithdSeparation3_(
522 Set< const IScheduleMultiDim* >& pot_list,
526 const auto& bn = this->BN();
527 for (
const auto var: kept_vars) {
528 kept_ids.insert(bn.nodeId(*var));
532 dSeparationAlgorithm dsep;
533 dsep.relevantTensors(bn,
535 this->hardEvidenceNodes(),
536 this->softEvidenceNodes(),
541 template <
typename GUM_SCALAR >
542 void VariableElimination< GUM_SCALAR >::_findRelevantTensorsXX_(
543 Set< const IScheduleMultiDim* >& pot_list,
545 switch (_find_relevant_tensor_type_) {
546 case RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS :
547 _findRelevantTensorsWithdSeparation2_(pot_list, kept_vars);
550 case RelevantTensorsFinderType::DSEP_BAYESBALL_NODES :
551 _findRelevantTensorsWithdSeparation_(pot_list, kept_vars);
554 case RelevantTensorsFinderType::DSEP_KOLLER_FRIEDMAN_2009 :
555 _findRelevantTensorsWithdSeparation3_(pot_list, kept_vars);
558 case RelevantTensorsFinderType::FIND_ALL :
559 _findRelevantTensorsGetAll_(pot_list, kept_vars);
567 template <
typename GUM_SCALAR >
568 Set< const IScheduleMultiDim* >
569 VariableElimination< GUM_SCALAR >::_removeBarrenVariables_(Schedule& schedule,
570 _ScheduleMultiDimSet_& pot_list,
575 for (
auto iter = the_del_vars.
beginSafe(); iter != the_del_vars.
endSafe(); ++iter) {
576 NodeId
id = this->BN().nodeId(**iter);
577 if (this->hardEvidenceNodes().exists(
id) || this->softEvidenceNodes().exists(
id)) {
578 the_del_vars.
erase(iter);
583 HashTable< const DiscreteVariable*, _ScheduleMultiDimSet_ > var2pots(the_del_vars.
size());
584 _ScheduleMultiDimSet_ empty_pot_set;
585 for (
const auto pot: pot_list) {
586 const auto& vars = pot->variablesSequence();
587 for (
const auto var: vars) {
588 if (the_del_vars.
exists(var)) {
589 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
590 var2pots[var].insert(pot);
597 HashTable< const IScheduleMultiDim*, gum::VariableSet > pot2barren_var;
599 for (
const auto& elt: var2pots) {
600 if (elt.second.size() == 1) {
601 const IScheduleMultiDim* pot = *(elt.second.begin());
602 if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
603 pot2barren_var[pot].insert(elt.first);
610 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
611 _ScheduleMultiDimSet_ projected_pots;
612 for (
const auto& elt: pot2barren_var) {
614 const IScheduleMultiDim* pot = elt.first;
619 if (pot->variablesSequence().size() != elt.second.size()) {
620 const IScheduleMultiDim* new_pot = projector.schedule(schedule, pot, elt.second);
624 pot_list.insert(new_pot);
625 projected_pots.insert(new_pot);
629 return projected_pots;
633 template <
typename GUM_SCALAR >
634 Set< const Tensor< GUM_SCALAR >* >
635 VariableElimination< GUM_SCALAR >::_removeBarrenVariables_(_TensorSet_& pot_list,
640 for (
auto iter = the_del_vars.
beginSafe(); iter != the_del_vars.
endSafe(); ++iter) {
641 NodeId
id = this->BN().nodeId(**iter);
642 if (this->hardEvidenceNodes().exists(
id) || this->softEvidenceNodes().exists(
id)) {
643 the_del_vars.
erase(iter);
648 HashTable< const DiscreteVariable*, _TensorSet_ > var2pots;
649 _TensorSet_ empty_pot_set;
650 for (
const auto pot: pot_list) {
651 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
652 for (
const auto var: vars) {
653 if (the_del_vars.
exists(var)) {
654 if (!var2pots.exists(var)) { var2pots.insert(var, empty_pot_set); }
655 var2pots[var].insert(pot);
664 for (
const auto& elt: var2pots) {
665 if (elt.second.size() == 1) {
666 const Tensor< GUM_SCALAR >* pot = *(elt.second.begin());
667 if (!pot2barren_var.exists(pot)) { pot2barren_var.insert(pot, empty_var_set); }
668 pot2barren_var[pot].insert(elt.first);
675 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
676 _TensorSet_ projected_pots;
677 for (
const auto& elt: pot2barren_var) {
679 const Tensor< GUM_SCALAR >* pot = elt.first;
684 if (pot->variablesSequence().size() != elt.second.size()) {
685 const Tensor< GUM_SCALAR >* new_pot = projector.execute(*pot, elt.second);
686 pot_list.insert(new_pot);
687 projected_pots.insert(new_pot);
691 return projected_pots;
695 template <
typename GUM_SCALAR >
696 Set< const IScheduleMultiDim* >
697 VariableElimination< GUM_SCALAR >::_collectMessage_(Schedule& schedule,
701 _ScheduleMultiDimSet_ collected_messages;
702 for (
const auto other: _JT_->neighbours(
id)) {
704 _ScheduleMultiDimSet_ message(_collectMessage_(schedule, other,
id));
705 collected_messages += message;
710 return _produceMessage_(schedule,
id, from, std::move(collected_messages));
714 template <
typename GUM_SCALAR >
715 std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
716 VariableElimination< GUM_SCALAR >::_collectMessage_(NodeId
id, NodeId from) {
718 std::pair< _TensorSet_, _TensorSet_ > collected_messages;
719 for (
const auto other: _JT_->neighbours(
id)) {
721 std::pair< _TensorSet_, _TensorSet_ > message(_collectMessage_(other,
id));
722 collected_messages.first += message.first;
723 collected_messages.second += message.second;
728 return _produceMessage_(
id, from, std::move(collected_messages));
732 template <
typename GUM_SCALAR >
733 Set< const IScheduleMultiDim* >
734 VariableElimination< GUM_SCALAR >::_NodeTensors_(Schedule& schedule, NodeId node) {
735 _ScheduleMultiDimSet_ res;
736 const auto& bn = this->BN();
747 const auto& evidence = this->evidence();
748 const auto& hard_evidence = this->hardEvidence();
749 const auto& hard_ev_nodes = this->hardEvidenceNodes();
750 if (_graph_.exists(node) || hard_ev_nodes.contains(node)) {
751 const Tensor< GUM_SCALAR >& cpt = bn.cpt(node);
752 const auto& variables = cpt.variablesSequence();
757 if (hard_ev_nodes.contains(node)) {
758 for (
const auto var: variables) {
759 NodeId xnode = bn.nodeId(*var);
760 if (!hard_ev_nodes.contains(xnode) && !_graph_.existsNode(xnode))
return res;
765 NodeSet hard_nodes(variables.size());
766 for (
const auto var: variables) {
767 NodeId xnode = bn.nodeId(*var);
768 if (hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
774 if (hard_nodes.empty()) {
775 const IScheduleMultiDim* sched_cpt
776 = schedule.insertTable< Tensor< GUM_SCALAR > >(cpt,
false);
777 res.insert(sched_cpt);
782 if (hard_nodes.size() != variables.size()) {
785 _ScheduleMultiDimSet_ marg_cpt_set(1 + hard_nodes.size());
786 const IScheduleMultiDim* sched_cpt
787 = schedule.insertTable< Tensor< GUM_SCALAR > >(cpt,
false);
788 marg_cpt_set.insert(sched_cpt);
790 for (
const auto xnode: hard_nodes) {
791 const IScheduleMultiDim* pot
792 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[xnode],
false);
793 marg_cpt_set.insert(pot);
794 hard_variables.
insert(&(bn.variable(xnode)));
798 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
801 _ScheduleMultiDimSet_ new_cpt_list
802 = combine_and_project.schedule(schedule, marg_cpt_set, hard_variables);
805 if (new_cpt_list.size() != 1) {
807 "the projection of a tensor containing " <<
"hard evidence is empty!");
809 auto projected_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
810 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
811 *new_cpt_list.begin()));
812 res.insert(projected_pot);
817 if (evidence.exists(node) && !hard_evidence.exists(node)) {
818 const IScheduleMultiDim* pot
819 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node],
false);
828 template <
typename GUM_SCALAR >
829 std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
830 VariableElimination< GUM_SCALAR >::_NodeTensors_(NodeId node) {
831 std::pair< _TensorSet_, _TensorSet_ > res;
832 const auto& bn = this->BN();
843 const auto& evidence = this->evidence();
844 const auto& hard_evidence = this->hardEvidence();
845 const auto& hard_ev_nodes = this->hardEvidenceNodes();
846 if (_graph_.exists(node) || hard_ev_nodes.contains(node)) {
847 const Tensor< GUM_SCALAR >& cpt = bn.cpt(node);
848 const auto& variables = cpt.variablesSequence();
853 if (hard_ev_nodes.contains(node)) {
854 for (
const auto var: variables) {
855 NodeId xnode = bn.nodeId(*var);
856 if (!hard_ev_nodes.contains(xnode) && !_graph_.existsNode(xnode))
return res;
861 NodeSet hard_nodes(variables.size());
862 for (
const auto var: variables) {
863 NodeId xnode = bn.nodeId(*var);
864 if (hard_ev_nodes.contains(xnode)) hard_nodes.insert(xnode);
870 if (hard_nodes.empty()) {
871 res.first.insert(&cpt);
876 if (hard_nodes.size() != variables.size()) {
879 _TensorSet_ marg_cpt_set(1 + hard_nodes.size());
880 marg_cpt_set.insert(&cpt);
882 for (
const auto xnode: hard_nodes) {
883 marg_cpt_set.insert(evidence[xnode]);
884 hard_variables.
insert(&(bn.variable(xnode)));
887 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
890 _TensorSet_ new_cpt_list = combine_and_project.execute(marg_cpt_set, hard_variables);
893 if (new_cpt_list.size() != 1) {
895 for (
auto pot: new_cpt_list) {
896 if (!marg_cpt_set.contains(pot))
delete pot;
899 "the projection of a tensor containing " <<
"hard evidence is empty!");
901 const Tensor< GUM_SCALAR >* projected_cpt = *(new_cpt_list.begin());
902 res.first.insert(projected_cpt);
903 res.second.insert(projected_cpt);
908 if (evidence.exists(node) && !hard_evidence.exists(node)) {
909 res.first.insert(this->evidence()[node]);
917 template <
typename GUM_SCALAR >
918 std::pair< Set< const Tensor< GUM_SCALAR >* >, Set< const Tensor< GUM_SCALAR >* > >
919 VariableElimination< GUM_SCALAR >::_produceMessage_(
922 std::pair< Set<
const Tensor< GUM_SCALAR >* >, Set<
const Tensor< GUM_SCALAR >* > >&&
925 std::pair< _TensorSet_, _TensorSet_ > pot_list(std::move(incoming_messages));
928 for (
const auto node: _clique_to_nodes_[from_id]) {
929 auto new_pots = _NodeTensors_(node);
930 pot_list.first += new_pots.first;
931 pot_list.second += new_pots.second;
935 if (!_JT_->existsEdge(from_id, to_id)) {
939 const NodeSet& from_clique = _JT_->clique(from_id);
940 const NodeSet& separator = _JT_->separator(from_id, to_id);
943 const auto& bn = this->BN();
945 for (
const auto node: from_clique) {
946 if (!separator.contains(node)) {
947 del_vars.
insert(&(bn.variable(node)));
949 kept_vars.
insert(&(bn.variable(node)));
955 _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
958 for (
auto iter = pot_list.second.beginSafe(); iter != pot_list.second.endSafe(); ++iter) {
959 if (!new_pot_list.contains(*iter)) {
961 pot_list.second.erase(iter);
966 for (
const auto pot: new_pot_list) {
967 if (!pot_list.first.contains(pot)) { pot_list.second.insert(pot); }
971 return std::pair< _TensorSet_, _TensorSet_ >(std::move(new_pot_list),
972 std::move(pot_list.second));
977 template <
typename GUM_SCALAR >
978 Set< const IScheduleMultiDim* > VariableElimination< GUM_SCALAR >::_produceMessage_(
982 Set< const IScheduleMultiDim* >&& incoming_messages) {
984 _ScheduleMultiDimSet_ pot_list(std::move(incoming_messages));
987 for (
const auto node: _clique_to_nodes_[from_id]) {
988 pot_list += _NodeTensors_(schedule, node);
992 if (!_JT_->existsEdge(from_id, to_id)) {
996 const NodeSet& from_clique = _JT_->clique(from_id);
997 const NodeSet& separator = _JT_->separator(from_id, to_id);
1000 const auto& bn = this->BN();
1002 for (
const auto node: from_clique) {
1003 if (!separator.contains(node)) {
1004 del_vars.
insert(&(bn.variable(node)));
1006 kept_vars.
insert(&(bn.variable(node)));
1012 _ScheduleMultiDimSet_ new_pot_list
1013 = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1016 for (
auto pot: pot_list) {
1017 if (!new_pot_list.contains(pot)) {
1018 const auto sched_pot
1019 =
static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot);
1020 schedule.emplaceDeletion(*sched_pot);
1025 return new_pot_list;
1030 template <
typename GUM_SCALAR >
1031 Set< const Tensor< GUM_SCALAR >* > VariableElimination< GUM_SCALAR >::_marginalizeOut_(
1032 Set<
const Tensor< GUM_SCALAR >* > pot_list,
1036 if (pot_list.empty()) {
return _TensorSet_(); }
1043 _TensorSet_ barren_projected_tensors;
1044 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
1045 barren_projected_tensors = _removeBarrenVariables_(pot_list, del_vars);
1049 _TensorSet_ new_pot_list;
1050 if (pot_list.size() == 1) {
1051 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
1052 auto pot = projector.execute(**(pot_list.begin()), del_vars);
1053 new_pot_list.insert(pot);
1054 }
else if (pot_list.size() > 1) {
1057 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1059 new_pot_list = combine_and_project.execute(pot_list, del_vars);
1065 for (
auto iter = barren_projected_tensors.beginSafe();
1066 iter != barren_projected_tensors.endSafe();
1068 if (!new_pot_list.exists(*iter))
delete *iter;
1071 return new_pot_list;
1075 template <
typename GUM_SCALAR >
1076 Set< const IScheduleMultiDim* >
1077 VariableElimination< GUM_SCALAR >::_marginalizeOut_(Schedule& schedule,
1078 Set< const IScheduleMultiDim* > pot_list,
1082 if (pot_list.empty()) {
return _ScheduleMultiDimSet_(); }
1089 for (
const auto pot: pot_list) {
1090 if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
1095 _ScheduleMultiDimSet_ barren_projected_tensors;
1096 if (_barren_nodes_type_ == FindBarrenNodesType::FIND_BARREN_NODES) {
1097 barren_projected_tensors = _removeBarrenVariables_(schedule, pot_list, del_vars);
1101 _ScheduleMultiDimSet_ new_pot_list;
1102 if (pot_list.size() == 1) {
1103 MultiDimProjection< Tensor< GUM_SCALAR > > projector(_projection_op_);
1104 auto xpot = projector.schedule(schedule, *(pot_list.begin()), del_vars);
1105 new_pot_list.insert(xpot);
1106 }
else if (pot_list.size() > 1) {
1109 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1111 new_pot_list = combine_and_project.schedule(schedule, pot_list, del_vars);
1117 for (
auto pot: barren_projected_tensors) {
1118 if (!new_pot_list.exists(pot)) {
1119 const auto sched_pot =
static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot);
1120 schedule.emplaceDeletion(*sched_pot);
1124 return new_pot_list;
1128 template <
typename GUM_SCALAR >
1129 INLINE
void VariableElimination< GUM_SCALAR >::makeInference_() {}
1132 template <
typename GUM_SCALAR >
1133 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId
id) {
1136 if (this->hardEvidenceNodes().contains(
id)) {
1137 return new Tensor< GUM_SCALAR >(*(this->evidence()[
id]));
1141 _createNewJT_(NodeSet{
id});
1146 double overall_size = 0;
1147 for (
const auto clique: *_JT_) {
1148 double clique_size = 1.0;
1149 for (
const auto node: _JT_->clique(clique))
1150 clique_size *= this->domainSizes()[node];
1151 overall_size += clique_size;
1153 const bool use_schedules = (overall_size > _schedule_threshold_);
1155 if (use_schedules) {
1157 return _unnormalizedJointPosterior_(schedule,
id);
1159 return _unnormalizedJointPosterior_(
id);
1164 template <
typename GUM_SCALAR >
1165 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(NodeId
id) {
1166 const auto& bn = this->BN();
1168 NodeId clique_of_id = _node_to_clique_[id];
1169 std::pair< _TensorSet_, _TensorSet_ > pot_list = _collectMessage_(clique_of_id, clique_of_id);
1172 const NodeSet& nodes = _JT_->clique(clique_of_id);
1175 for (
const auto node: nodes) {
1176 if (node !=
id) del_vars.
insert(&(bn.variable(node)));
1181 _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
1182 Tensor< GUM_SCALAR >* joint =
nullptr;
1184 if (new_pot_list.size() == 0) {
1185 joint =
new Tensor< GUM_SCALAR >;
1186 for (
const auto var: kept_vars)
1189 if (new_pot_list.size() == 1) {
1190 joint =
const_cast< Tensor< GUM_SCALAR >*
>(*(new_pot_list.begin()));
1193 if (pot_list.first.exists(joint)) {
1194 joint =
new Tensor< GUM_SCALAR >(*joint);
1198 new_pot_list.clear();
1201 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1202 joint = fast_combination.execute(new_pot_list);
1207 for (
auto pot: new_pot_list)
1208 if (!pot_list.first.exists(pot))
delete pot;
1211 for (
auto pot: pot_list.second)
1217 bool nonzero_found =
false;
1218 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1219 if ((*joint)[inst]) {
1220 nonzero_found =
true;
1224 if (!nonzero_found) {
1228 "some evidence entered into the Bayes "
1229 "net are incompatible (their joint proba = 0)");
1236 template <
typename GUM_SCALAR >
1237 Tensor< GUM_SCALAR >*
1238 VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1240 const auto& bn = this->BN();
1242 NodeId clique_of_id = _node_to_clique_[id];
1243 _ScheduleMultiDimSet_ pot_list = _collectMessage_(schedule, clique_of_id, clique_of_id);
1246 const NodeSet& nodes = _JT_->clique(clique_of_id);
1249 for (
const auto node: nodes) {
1250 if (node !=
id) del_vars.
insert(&(bn.variable(node)));
1255 _ScheduleMultiDimSet_ new_pot_list = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1256 Tensor< GUM_SCALAR >* joint =
nullptr;
1257 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot =
nullptr;
1259 if (new_pot_list.size() == 0) {
1260 joint =
new Tensor< GUM_SCALAR >;
1261 for (
const auto var: kept_vars)
1264 auto& scheduler = this->scheduler();
1265 if (new_pot_list.size() == 1) {
1266 scheduler.execute(schedule);
1267 resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1268 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(*new_pot_list.begin()));
1270 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1271 const IScheduleMultiDim* pot = fast_combination.schedule(schedule, new_pot_list);
1272 scheduler.execute(schedule);
1273 resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1274 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot));
1279 if (pot_list.exists(resulting_pot)) {
1280 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1282 joint = resulting_pot->exportMultiDim();
1289 bool nonzero_found =
false;
1290 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1291 if ((*joint)[inst]) {
1292 nonzero_found =
true;
1296 if (!nonzero_found) {
1300 "some evidence entered into the Bayes "
1301 "net are incompatible (their joint proba = 0)");
1308 template <
typename GUM_SCALAR >
1309 const Tensor< GUM_SCALAR >& VariableElimination< GUM_SCALAR >::posterior_(NodeId
id) {
1311 auto joint = unnormalizedJointPosterior_(
id);
1312 if (joint->sum() != 1)
1315 if (_target_posterior_ !=
nullptr)
delete _target_posterior_;
1316 _target_posterior_ = joint;
1322 template <
typename GUM_SCALAR >
1323 Tensor< GUM_SCALAR >*
1324 VariableElimination< GUM_SCALAR >::unnormalizedJointPosterior_(
const NodeSet& set) {
1327 NodeSet targets = set, hard_ev_nodes(this->hardEvidenceNodes().size());
1328 for (
const auto node: this->hardEvidenceNodes()) {
1329 if (targets.contains(node)) {
1330 targets.erase(node);
1331 hard_ev_nodes.insert(node);
1337 const auto& evidence = this->evidence();
1338 if (targets.empty()) {
1339 _TensorSet_ pot_list;
1340 for (
const auto node: set) {
1341 pot_list.insert(evidence[node]);
1343 if (pot_list.size() == 1) {
1344 return new Tensor< GUM_SCALAR >(**(pot_list.begin()));
1346 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1347 return fast_combination.execute(pot_list);
1357 double overall_size = 0;
1358 for (
const auto clique: *_JT_) {
1359 double clique_size = 1.0;
1360 for (
const auto node: _JT_->clique(clique))
1361 clique_size *= this->domainSizes()[node];
1362 overall_size += clique_size;
1364 const bool use_schedules = (overall_size > _schedule_threshold_);
1366 if (use_schedules) {
1368 return _unnormalizedJointPosterior_(schedule, set, targets, hard_ev_nodes);
1370 return _unnormalizedJointPosterior_(set, targets, hard_ev_nodes);
// Non-scheduled computation of an unnormalized joint posterior over `targets`.
// Steps visible in this listing:
//  1. collect the messages into the clique _targets2clique_ (pot_list);
//  2. sum out of the collected tensors the clique variables not in `targets`;
//  3. if a single tensor remains and there is no hard evidence, reuse it;
//     otherwise combine the remaining tensors with the hard-evidence tensors
//     of `hard_ev_nodes`;
//  4. free the temporary tensors and reject an all-zero result.
// The returned Tensor is meant to be owned by the caller: jointPosterior_
// stores the result (presumably obtained through the dispatching
// unnormalizedJointPosterior_) in _target_posterior_ and deletes it later.
// That is why a tensor belonging to pot_list.first is copied, not returned.
// NOTE(review): this is a documentation-page extract. Several source lines are
// not shown: the first parameter, the declarations of del_vars/kept_vars, the
// else branches, the closing braces, the exception-raising call and the final
// return. Do not compile from this listing.
1375 template <
typename GUM_SCALAR >
1376 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(
1378 const NodeSet& targets,
1379 const NodeSet& hard_ev_nodes) {
// pot_list.first holds the tensors to marginalize/combine. pot_list.second is
// iterated by a cleanup loop further below.
1380 std::pair< _TensorSet_, _TensorSet_ > pot_list
1381 = _collectMessage_(_targets2clique_, _targets2clique_);
// Split the variables of the target clique: those not in `targets` go into
// del_vars (to be summed out), the others into kept_vars.
1384 const NodeSet& nodes = _JT_->clique(_targets2clique_);
1387 const auto& bn = this->BN();
1388 for (
const auto node: nodes) {
1389 if (!targets.contains(node)) {
1390 del_vars.
insert(&(bn.variable(node)));
1392 kept_vars.
insert(&(bn.variable(node)));
// Eliminate del_vars from the collected tensors.
1398 _TensorSet_ new_pot_list = _marginalizeOut_(pot_list.first, del_vars, kept_vars);
1399 Tensor< GUM_SCALAR >* joint =
nullptr;
// Fast path: a single tensor remains and no hard evidence has to be multiplied in.
1401 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1402 joint =
const_cast< Tensor< GUM_SCALAR >*
>(*(new_pot_list.begin()));
// joint is one of the collected tensors of pot_list.first (not created here),
// so hand back a copy instead of the shared tensor.
1405 if (pot_list.first.exists(joint)) {
1406 joint =
new Tensor< GUM_SCALAR >(*joint);
// NOTE(review): presumably the else-branch (not shown). joint was created by
// _marginalizeOut_, so clearing new_pot_list keeps the cleanup loop below
// from deleting the tensor that is being returned.
1410 new_pot_list.clear();
// General case (presumably the else-branch of the fast path, not shown):
// combine the remaining tensors with the hard-evidence tensors.
1415 const auto& evidence = this->evidence();
1416 _TensorSet_ new_new_pot_list = new_pot_list;
1417 for (
const auto node: hard_ev_nodes) {
1418 new_new_pot_list.insert(evidence[node]);
1420 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1421 joint = fast_combination.execute(new_new_pot_list);
// Free the tensors created by _marginalizeOut_, i.e. those not in pot_list.first.
1425 for (
auto pot: new_pot_list)
1426 if (!pot_list.first.exists(pot))
delete pot;
// Loop over the temporaries of pot_list.second. Its body is not shown
// (presumably it deletes each one).
1429 for (
auto pot: pot_list.second)
// An all-zero joint means the evidence has probability 0, i.e. the entered
// evidence is incompatible.
1434 bool nonzero_found =
false;
1435 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1436 if ((*joint)[inst]) {
1437 nonzero_found =
true;
// NOTE(review): the lines raising the exception are not shown. The tooltips
// suggest GUM_ERROR with IncompatibleEvidence; confirm this. Also confirm that
// joint is deleted before throwing, otherwise it leaks.
1441 if (!nonzero_found) {
1445 "some evidence entered into the Bayes "
1446 "net are incompatible (their joint proba = 0)");
// Scheduled counterpart of the non-scheduled _unnormalizedJointPosterior_.
// The algorithm is the same, but the message collection, the marginalization
// and the combination with hard evidence are recorded into `schedule`. They
// are then run by this->scheduler().execute(schedule).
// The result is turned into a standalone heap-allocated Tensor, either copied
// or exported from the resulting ScheduleMultiDim.
// NOTE(review): documentation-page extract. Several source lines are not
// shown: the leading parameters (presumably `schedule` and `set`), the
// del_vars/kept_vars declarations, the else branches, the closing braces, the
// exception-raising call and the final return. Do not compile from this listing.
1453 template <
typename GUM_SCALAR >
1454 Tensor< GUM_SCALAR >* VariableElimination< GUM_SCALAR >::_unnormalizedJointPosterior_(
1457 const NodeSet& targets,
1458 const NodeSet& hard_ev_nodes) {
1459 _ScheduleMultiDimSet_ pot_list = _collectMessage_(schedule, _targets2clique_, _targets2clique_);
// Split the variables of the target clique: those not in `targets` go into
// del_vars (to be summed out), the others into kept_vars.
1462 const NodeSet& nodes = _JT_->clique(_targets2clique_);
1465 const auto& bn = this->BN();
1466 for (
const auto node: nodes) {
1467 if (!targets.contains(node)) {
1468 del_vars.
insert(&(bn.variable(node)));
1470 kept_vars.
insert(&(bn.variable(node)));
// Schedule the elimination of del_vars from the collected tables.
1476 _ScheduleMultiDimSet_ new_pot_list = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1477 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot =
nullptr;
1478 auto& scheduler = this->scheduler();
// Fast path: a single table remains and there is no hard evidence. Run the
// schedule and take that table as the result.
1480 if ((new_pot_list.size() == 1) && hard_ev_nodes.empty()) {
1481 scheduler.execute(schedule);
1482 resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1483 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(*new_pot_list.begin()));
// General case (presumably the else-branch, not shown): insert the hard-evidence
// tensors into the schedule, schedule their combination with new_pot_list,
// then execute the whole schedule.
1487 const auto& evidence = this->evidence();
1488 for (
const auto node: hard_ev_nodes) {
// NOTE(review): the `false` argument presumably tells insertTable not to copy
// the evidence tensor; confirm against the Schedule API.
1489 auto new_pot_ev = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node],
false);
1490 new_pot_list.insert(new_pot_ev);
1492 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1493 const auto pot = fast_combination.schedule(schedule, new_pot_list);
1494 scheduler.execute(schedule);
1495 resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1496 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot));
// Build the Tensor to return. If resulting_pot is one of the collected tables
// in pot_list, copy its content. Otherwise take its table via exportMultiDim,
// which presumably transfers ownership of the table out of the schedule; confirm.
1501 Tensor< GUM_SCALAR >* joint =
nullptr;
1502 if (pot_list.exists(resulting_pot)) {
1503 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1505 joint = resulting_pot->exportMultiDim();
// An all-zero joint means the evidence has probability 0, i.e. the entered
// evidence is incompatible.
1510 bool nonzero_found =
false;
1511 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1512 if ((*joint)[inst]) {
1513 nonzero_found =
true;
// NOTE(review): the lines raising the exception are not shown. Confirm that
// joint is deleted before throwing, otherwise it leaks.
1517 if (!nonzero_found) {
1521 "some evidence entered into the Bayes "
1522 "net are incompatible (their joint proba = 0)");
// Computes the joint posterior of `set` and caches the resulting tensor in
// _target_posterior_. The previously cached tensor, if any, is deleted here,
// and the destructor deletes the last one.
// NOTE(review): documentation-page extract. Lines 1532, 1534-1535 and the
// return statement are not shown. Presumably joint is normalized there and
// *_target_posterior_ is returned; confirm. If so, a reference obtained from
// an earlier call dangles once this function is called again.
1529 template <
typename GUM_SCALAR >
1530 const Tensor< GUM_SCALAR >&
1531 VariableElimination< GUM_SCALAR >::jointPosterior_(
const NodeSet& set) {
1533 auto joint = unnormalizedJointPosterior_(set);
// Replace (and free) the previously cached posterior.
1536 if (_target_posterior_ !=
nullptr)
delete _target_posterior_;
1537 _target_posterior_ = joint;
// Two-argument overload. It ignores `declared_target` and forwards
// `wanted_target` to the single-argument jointPosterior_. The result is
// therefore cached in the same member (_target_posterior_) and is subject to
// the same invalidation on the next call.
// NOTE(review): the closing brace is not shown in this listing.
1543 template <
typename GUM_SCALAR >
1544 const Tensor< GUM_SCALAR >&
1545 VariableElimination< GUM_SCALAR >::jointPosterior_(
const NodeSet& wanted_target,
1546 const NodeSet& declared_target) {
1547 return jointPosterior_(wanted_target);
The BayesBall algorithm (as described by Shachter).
Detect barren nodes for inference in Bayesian networks.
An algorithm for converting a join tree into a binary join tree.
Exception: fatal (unknown?) error.
Class representing the minimal interface for a Bayesian network with no numerical data.
Exception: several pieces of evidence are incompatible together (proba=0).
Exception: at least one argument passed to a function is not what was expected.
<agrum/BN/inference/jointTargetedInference.h>
Size size() const noexcept
Returns the number of elements in the set.
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
void insert(const Key &k)
Inserts a new element into the set.
void erase(const Key &k)
Erases an element from the set.
VariableElimination(const IBayesNet< GUM_SCALAR > *BN, RelevantTensorsFinderType=RelevantTensorsFinderType::DSEP_BAYESBALL_TENSORS, FindBarrenNodesType=FindBarrenNodesType::FIND_BARREN_NODES)
default constructor
d-separation analysis (as described in Koller & Friedman 2009)
#define GUM_ERROR(type, msg)
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
Header files of gum::Instantiation.
gum is the global namespace for all aGrUM entities
FindBarrenNodesType
type of algorithm to determine barren nodes
Set< const DiscreteVariable * > VariableSet
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...
RelevantTensorsFinderType
type of algorithm for determining the relevant tensors for combinations using some d-separation analy...
Implementation of a variable elimination algorithm for inference in Bayesian networks.