51#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Constructor (excerpt). It does three things:
//  - records whether a binary join tree should be built;
//  - allocates the default triangulation algorithm;
//  - fills _node_to_factors_, so that every node of the MRF graph maps to the
//    set of factors whose scope (factor.first) contains it.
// NOTE(review): the embedded original line numbers jump (62 -> 65 -> 67 ...),
// so parts of this definition are not visible in this excerpt. Missing parts
// include the constructor name line and the declaration of `empty`.
62 template <
typename GUM_SCALAR >
65 bool use_binary_join_tree) :
67 _use_binary_join_tree_(use_binary_join_tree) {
// Heap-allocated triangulation: it is deleted by the destructor and replaced
// (delete + newFactory()) by setTriangulation().
69 _triangulation_ =
new DefaultTriangulation;
72 const auto& graph = this->MRF().graph();
73 _node_to_factors_.resize(graph.size());
// First give every node an entry (`empty`, presumably an empty factor set
// declared on a line not shown)...
75 for (
const auto node: graph)
76 _node_to_factors_.insert(node, empty);
// ...then register each factor under every node of its scope. This lets later
// code find the factors affected by (hard) evidence on a given node.
77 for (
const auto& factor: this->MRF().factors()) {
78 for (
const auto node: factor.first) {
79 _node_to_factors_[node].insert(factor.second);
84 GUM_CONSTRUCTOR(ShaferShenoyMRFInference);
// Destructor (excerpt). It frees the tensors allocated by this engine, then
// the join tree, the junction tree and the triangulation.
// NOTE(review): this excerpt does not show the bodies of several loops: the
// ones over _arc_to_created_tensors_, over the per-clique tensor sets and
// over the posterior maps.
88 template <
typename GUM_SCALAR >
89 INLINE ShaferShenoyMRFInference< GUM_SCALAR >::~ShaferShenoyMRFInference() {
91 for (
const auto& pot: _arc_to_created_tensors_)
// This engine allocates a clique's combined tensor only when the clique holds
// more than one tensor. With a single tensor, _clique_ss_tensor_ just aliases
// that tensor, which is presumably released by the _clique_tensors_ loop below.
// NOTE(review): `auto pot` / `auto potset` copy each map entry (including a
// whole set per clique); `const auto&` would avoid these copies.
99 for (
auto pot: _clique_ss_tensor_) {
100 if (_clique_tensors_[pot.first].size() > 1)
delete pot.second;
103 for (
auto potset: _clique_tensors_) {
104 for (
auto pot: potset.second)
109 for (
const auto& pot: _target_posteriors_)
111 for (
const auto& pot: _joint_target_posteriors_)
115 if (_JT_ !=
nullptr)
delete _JT_;
116 if (_junctionTree_ !=
nullptr)
delete _junctionTree_;
117 delete _triangulation_;
120 GUM_DESTRUCTOR(ShaferShenoyMRFInference);
// Replaces the triangulation algorithm used to build the join tree.
// The current triangulation is deleted and replaced by a new object obtained
// from new_triangulation.newFactory(). Only that new object is stored; the
// argument itself is not retained.
// The join tree depends on the triangulation, so a new one is flagged as
// needed and the inference structure is marked outdated.
124 template <
typename GUM_SCALAR >
125 void ShaferShenoyMRFInference< GUM_SCALAR >::setTriangulation(
126 const Triangulation& new_triangulation) {
127 delete _triangulation_;
128 _triangulation_ = new_triangulation.newFactory();
129 _is_new_jt_needed_ =
true;
130 this->setOutdatedStructureState_();
// Returns the join tree used for message passing. The tree is (re)built first
// if _is_new_jt_needed_ is set.
// NOTE(review): the return statement is not visible in this excerpt.
134 template <
typename GUM_SCALAR >
135 INLINE
const JoinTree* ShaferShenoyMRFInference< GUM_SCALAR >::joinTree() {
136 if (_is_new_jt_needed_) _createNewJT_();
// Returns the junction tree (_junctionTree_). The join-tree structures are
// (re)built first if _is_new_jt_needed_ is set.
// _junctionTree_ is built as a plain copy of the triangulation's junction
// tree, even when the inference itself runs on a binary join tree (_JT_).
142 template <
typename GUM_SCALAR >
143 INLINE
const JunctionTree* ShaferShenoyMRFInference< GUM_SCALAR >::junctionTree() {
144 if (_is_new_jt_needed_) _createNewJT_();
146 return _junctionTree_;
// Sets the projection (marginalization) operator stored in _projection_op_.
// Messages and posteriors computed so far may depend on the previous
// operator, so all of them are invalidated.
// NOTE(review): a few lines between the assignment and the invalidation are
// not visible in this excerpt.
150 template <
typename GUM_SCALAR >
151 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::_setProjectionFunction_(
152 Tensor< GUM_SCALAR > (*proj)(
const Tensor< GUM_SCALAR >&,
const gum::VariableSet&)) {
153 _projection_op_ = proj;
157 _invalidateAllMessages_();
// Sets the combination operator stored in _combination_op_. This is the
// operator given to the MultiDimCombinationDefault /
// MultiDimCombineAndProjectDefault helpers.
// Messages and posteriors computed so far may depend on the previous
// operator, so all of them are invalidated.
161 template <
typename GUM_SCALAR >
162 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::_setCombinationFunction_(
163 Tensor< GUM_SCALAR > (*comb)(
const Tensor< GUM_SCALAR >&,
const Tensor< GUM_SCALAR >&)) {
164 _combination_op_ = comb;
168 _invalidateAllMessages_();
// Discards every computed message and posterior so the next inference
// recomputes them:
//  - separator tensors are reset to nullptr and every message is marked as
//    not computed;
//  - tensors that were allocated for messages (_arc_to_created_tensors_) are
//    deleted and the map is cleared;
//  - marginal and joint posteriors are cleared (the loop bodies are not
//    visible in this excerpt);
//  - if inference was ready or done, the state is downgraded to
//    "outdated tensors".
172 template <
typename GUM_SCALAR >
173 void ShaferShenoyMRFInference< GUM_SCALAR >::_invalidateAllMessages_() {
175 for (
auto& pot: _separator_tensors_)
176 pot.second =
nullptr;
178 for (
auto& mess_computed: _messages_computed_)
179 mess_computed.second =
false;
182 for (
const auto& pot: _arc_to_created_tensors_)
183 if (pot.second !=
nullptr)
delete pot.second;
184 _arc_to_created_tensors_.clear();
187 for (
const auto& pot: _target_posteriors_)
189 _target_posteriors_.clear();
190 for (
const auto& pot: _joint_target_posteriors_)
192 _joint_target_posteriors_.clear();
195 if (this->isInferenceReady() || this->isInferenceDone()) this->setOutdatedTensorsState_();
// Hook called when evidence is added on node `id`.
// A new join tree is required in two cases:
//  - the evidence is hard (hard-evidence nodes are erased from _graph_ when
//    the join tree is built);
//  - the node is not in the current _graph_.
// The change is then recorded in _evidence_changes_.
// NOTE(review): the lines around the two assignments are not visible. The
// fallback to EVIDENCE_MODIFIED presumably applies when `id` already has a
// pending change (e.g. it was erased and is now added back) — confirm.
199 template <
typename GUM_SCALAR >
200 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceAdded_(
const NodeId
id,
201 bool isHardEvidence) {
205 if (isHardEvidence || !_graph_.exists(
id)) _is_new_jt_needed_ =
true;
208 _evidence_changes_.insert(
id, EvidenceChangeType::EVIDENCE_ADDED);
214 _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_MODIFIED;
// Hook called when the evidence on node `id` is erased.
// Erasing hard evidence changes the graph used to build the join tree, so a
// new join tree is flagged as needed.
// The change is recorded in _evidence_changes_:
//  - if `id` already had a pending EVIDENCE_ADDED, the add and the erase
//    cancel out and the entry is removed;
//  - otherwise the entry becomes EVIDENCE_ERASED.
// NOTE(review): this excerpt does not show the control flow that chooses
// between the insert and the update branch.
220 template <
typename GUM_SCALAR >
221 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceErased_(
const NodeId
id,
222 bool isHardEvidence) {
225 if (isHardEvidence) _is_new_jt_needed_ =
true;
228 _evidence_changes_.insert(
id, EvidenceChangeType::EVIDENCE_ERASED);
235 if (_evidence_changes_[
id] == EvidenceChangeType::EVIDENCE_ADDED)
236 _evidence_changes_.erase(
id);
237 else _evidence_changes_[id] = EvidenceChangeType::EVIDENCE_ERASED;
// Hook called when all evidence is erased.
// A new join tree is needed if there was hard evidence, as reported either by
// has_hard_evidence or by hardEvidenceNodes().
// Each soft-evidence node then gets per-node bookkeeping: a pending
// EVIDENCE_ADDED is dropped; otherwise the node is marked EVIDENCE_ERASED.
// NOTE(review): the branch structure between the insert and the update is not
// visible in this excerpt.
243 template <
typename GUM_SCALAR >
244 void ShaferShenoyMRFInference< GUM_SCALAR >::onAllEvidenceErased_(
bool has_hard_evidence) {
245 if (has_hard_evidence || !this->hardEvidenceNodes().empty()) _is_new_jt_needed_ =
true;
247 for (
const auto node: this->softEvidenceNodes()) {
249 _evidence_changes_.insert(node, EvidenceChangeType::EVIDENCE_ERASED);
256 if (_evidence_changes_[node] == EvidenceChangeType::EVIDENCE_ADDED)
257 _evidence_changes_.erase(node);
258 else _evidence_changes_[node] = EvidenceChangeType::EVIDENCE_ERASED;
// Hook called when the evidence on node `id` changes.
// If the evidence switched between soft and hard, the join tree must be
// rebuilt. The change is recorded as EVIDENCE_MODIFIED.
// NOTE(review): the remaining lines of this hook (presumably handling a node
// that already has a pending change) are not visible in this excerpt.
265 template <
typename GUM_SCALAR >
266 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onEvidenceChanged_(
const NodeId
id,
267 bool hasChangedSoftHard) {
268 if (hasChangedSoftHard) _is_new_jt_needed_ =
true;
271 _evidence_changes_.insert(
id, EvidenceChangeType::EVIDENCE_MODIFIED);
// Hook called when marginal target `id` is added. Nothing is done here.
// _isNewJTNeeded_() checks lazily, later, whether the current join tree can
// still answer every target.
281 template <
typename GUM_SCALAR >
282 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onMarginalTargetAdded_(
const NodeId
id) {}
// Hook called when marginal target `id` is removed. This engine does nothing
// here.
285 template <
typename GUM_SCALAR >
286 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onMarginalTargetErased_(
const NodeId
id) {}
// Hook called when a joint target is added. Nothing is done here.
// _isNewJTNeeded_() checks lazily, later, whether some clique of the current
// join tree still contains every joint target.
289 template <
typename GUM_SCALAR >
290 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onJointTargetAdded_(
const NodeSet& set) {}
// Hook called when a joint target is removed. This engine does nothing here.
293 template <
typename GUM_SCALAR >
294 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onJointTargetErased_(
const NodeSet& set) {}
// Hook called when every node becomes a marginal target. Nothing is done here;
// the targets are checked lazily by _isNewJTNeeded_().
297 template <
typename GUM_SCALAR >
298 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onAllMarginalTargetsAdded_() {}
// Hook called when all marginal targets are removed. This engine does nothing
// here.
301 template <
typename GUM_SCALAR >
302 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onAllMarginalTargetsErased_() {}
// Hook called when the MRF changes. This engine does nothing here, and the
// `mn` argument is unused.
305 template <
typename GUM_SCALAR >
306 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onMRFChanged_(
307 const IMarkovRandomField< GUM_SCALAR >* mn) {}
// Hook called when all joint targets are removed. This engine does nothing
// here.
310 template <
typename GUM_SCALAR >
311 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onAllJointTargetsErased_() {}
// Hook called when all (marginal and joint) targets are removed. This engine
// does nothing here.
314 template <
typename GUM_SCALAR >
315 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onAllTargetsErased_() {}
// Returns true when the current join tree cannot be reused and must be
// rebuilt. That is the case when:
//  - there is no join tree yet, or a rebuild was explicitly requested;
//  - some marginal target is neither in _graph_ nor a hard-evidence node;
//  - some joint target contains a node that is neither in _graph_ nor a
//    hard-evidence node;
//  - the joint target's nodes that are in _graph_ are not all contained in
//    the clique that _node_to_clique_ assigns to the first of them eliminated
//    by the triangulation;
//  - (presumably) some evidence was added on a node absent from _graph_.
// NOTE(review): several lines are missing from this excerpt. Among them are
// the declaration of unobserved_set, the branch structure around the
// unobserved_set.insert() call, and the body of the final evidence check.
318 template <
typename GUM_SCALAR >
319 bool ShaferShenoyMRFInference< GUM_SCALAR >::_isNewJTNeeded_()
const {
322 if ((_JT_ ==
nullptr) || _is_new_jt_needed_)
return true;
328 const auto& hard_ev_nodes = this->hardEvidenceNodes();
329 for (
const auto node: this->targets()) {
330 if (!_graph_.exists(node) && !hard_ev_nodes.exists(node))
return true;
// Map each node to its rank in the triangulation's elimination order.
334 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
335 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
336 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
337 elim_order.insert(JT_elim_order[i], (
int)i);
340 for (
const auto& joint_target: this->jointTargets()) {
343 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
344 int elim_number = std::numeric_limits< int >::max();
345 unobserved_set.clear();
346 for (
const auto node: joint_target) {
347 if (!_graph_.exists(node)) {
348 if (!hard_ev_nodes.exists(node))
return true;
350 unobserved_set.insert(node);
351 if (elim_order[node] < elim_number) {
352 elim_number = elim_order[node];
353 first_eliminated_node = node;
// All non-observed nodes of the joint target must lie in a single clique.
357 if (!unobserved_set.empty()) {
361 const auto clique_id = _node_to_clique_[first_eliminated_node];
362 const auto& clique = _JT_->clique(clique_id);
363 for (
const auto node: unobserved_set) {
364 if (!clique.contains(node))
return true;
371 for (
const auto& change: _evidence_changes_) {
372 if ((change.second == EvidenceChangeType::EVIDENCE_ADDED) && !_graph_.exists(change.first))
// Builds a new join tree and every data structure attached to it. The steps
// are:
//  1. Copy the MRF graph into _graph_. Connect every pair of nodes of each
//     joint target, so that each joint target ends up inside one clique.
//     Remove the hard-evidence nodes.
//  2. Triangulate _graph_ and build _JT_ (optionally binarized) and
//     _junctionTree_, replacing the previous ones.
//  3. Assign each factor, and each joint target, to the clique created when
//     its first-eliminated node was eliminated. Assign each node to the
//     smallest clique (product of domain sizes) containing it.
//  4. Compute the roots, then free and reset the per-clique tensors, the
//     message/separator maps and the posteriors.
//  5. Decide whether schedules are used (total clique size above
//     _schedule_threshold_). Attach soft evidence to cliques and initialize
//     the clique tensors.
// NOTE(review): many interior lines are missing from this excerpt (e.g. the
// declaration of iter2, some else-branches and several loop bodies).
381 template <
typename GUM_SCALAR >
382 void ShaferShenoyMRFInference< GUM_SCALAR >::_createNewJT_() {
394 const auto& mn = this->MRF();
395 _graph_ = mn.graph();
// Make every joint target a clique of _graph_ so that the triangulation keeps
// it inside a single junction-tree clique.
400 for (
const auto& nodeset: this->jointTargets()) {
401 for (
auto iter1 = nodeset.cbegin(); iter1 != nodeset.cend(); ++iter1) {
403 for (++iter2; iter2 != nodeset.cend(); ++iter2) {
404 _graph_.addEdge(*iter1, *iter2);
// Hard-evidence nodes have a fixed value, so they are removed from the graph.
410 _hard_ev_nodes_ = this->hardEvidenceNodes();
411 for (
const auto node: _hard_ev_nodes_) {
412 _graph_.eraseNode(node);
419 if (_JT_ !=
nullptr)
delete _JT_;
420 if (_junctionTree_ !=
nullptr)
delete _junctionTree_;
422 const auto& domain_sizes = this->domainSizes();
423 _triangulation_->setGraph(&_graph_, &domain_sizes);
424 const JunctionTree& triang_jt = _triangulation_->junctionTree();
425 if (_use_binary_join_tree_) {
426 BinaryJoinTreeConverterDefault bjt_converter;
428 _JT_ =
new CliqueGraph(bjt_converter.convert(triang_jt, domain_sizes, emptyset));
430 _JT_ =
new CliqueGraph(triang_jt);
432 _junctionTree_ =
new CliqueGraph(triang_jt);
// Rank of each node in the elimination order.
436 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
437 Size size_elim_order = JT_elim_order.size();
438 NodeProperty< int > elim_order(size_elim_order);
439 for (Idx i = Idx(0); i < size_elim_order; ++i)
440 elim_order.insert(JT_elim_order[i], (
int)i);
// Assign each factor to the clique created when its first-eliminated
// (non-hard-evidence) node was eliminated. Factors with no node left in
// _graph_ get no clique.
444 _factor_to_clique_.clear();
445 _factor_to_clique_.resize(mn.factors().size());
446 for (
const auto& factor: mn.factors()) {
447 const auto& nodes = factor.first;
448 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
449 int elim_number = std::numeric_limits< int >::max();
450 for (
const auto node: nodes) {
451 if (_graph_.exists(node) && (elim_order[node] < elim_number)) {
452 elim_number = elim_order[node];
453 first_eliminated_node = node;
457 if (elim_number != std::numeric_limits< int >::max()) {
461 _factor_to_clique_.insert(
463 _triangulation_->createdJunctionTreeClique(first_eliminated_node));
// Map each node of _graph_ to the smallest clique (product of domain sizes)
// containing it. overall_size accumulates the sizes of all cliques.
469 _node_to_clique_.clear();
470 _node_to_clique_.resize(_graph_.size());
471 NodeProperty< double > node_to_clique_size(_graph_.size());
472 for (
const auto node: _graph_) {
473 _node_to_clique_.insert(node, std::numeric_limits< NodeId >::max());
474 node_to_clique_size.insert(node, std::numeric_limits< double >::max());
476 double overall_size = 0;
477 for (
const auto clique_id: *_JT_) {
479 const auto& clique_nodes = _JT_->clique(clique_id);
480 double clique_size = 1.0;
481 for (
const auto node: clique_nodes)
482 clique_size *=
double(domain_sizes[node]);
483 overall_size += clique_size;
487 for (
const auto node: clique_nodes) {
488 if (clique_size < node_to_clique_size[node]) {
489 _node_to_clique_[node] = clique_id;
490 node_to_clique_size[node] = clique_size;
// Assign each joint target to the clique of its first-eliminated
// non-hard-evidence node. Fully observed joint targets get no clique.
496 _joint_target_to_clique_.clear();
497 for (
const auto& set: this->jointTargets()) {
498 NodeId first_eliminated_node = std::numeric_limits< NodeId >::max();
499 int elim_number = std::numeric_limits< int >::max();
503 for (
const auto node: set) {
504 if (!_hard_ev_nodes_.contains(node)) {
507 if (elim_order[node] < elim_number) {
508 elim_number = elim_order[node];
509 first_eliminated_node = node;
514 if (elim_number != std::numeric_limits< int >::max()) {
515 _joint_target_to_clique_.insert(
517 _triangulation_->createdJunctionTreeClique(first_eliminated_node));
522 _computeJoinTreeRoots_();
// Release the tensors attached to the previous join tree. A combined clique
// tensor is owned only when the clique held more than one tensor; otherwise
// it aliases that single tensor.
527 for (
const auto& pot: _clique_ss_tensor_) {
528 if (_clique_tensors_[pot.first].size() > 1)
delete pot.second;
530 _clique_ss_tensor_.clear();
531 for (
const auto& potlist: _clique_tensors_)
532 for (
const auto pot: potlist.second)
534 _clique_tensors_.clear();
537 for (
const auto& pot: _arc_to_created_tensors_)
539 _arc_to_created_tensors_.clear();
544 _hard_ev_projected_factors_.clear();
547 _node_to_soft_evidence_.clear();
// Every clique starts with an empty tensor set and no combined tensor.
551 _ScheduleMultiDimSet_ empty_set;
552 for (
const auto node: *_JT_) {
553 _clique_tensors_.insert(node, empty_set);
554 _clique_ss_tensor_.insert(node,
nullptr);
// One message slot per direction of every join-tree edge.
562 _separator_tensors_.clear();
563 _messages_computed_.clear();
564 for (
const auto& edge: _JT_->edges()) {
565 const Arc arc1(edge.first(), edge.second());
566 _separator_tensors_.insert(arc1,
nullptr);
567 _messages_computed_.insert(arc1,
false);
568 const Arc arc2(edge.second(), edge.first());
569 _separator_tensors_.insert(arc2,
nullptr);
570 _messages_computed_.insert(arc2,
false);
574 for (
const auto& pot: _target_posteriors_)
576 _target_posteriors_.clear();
577 for (
const auto& pot: _joint_target_posteriors_)
579 _joint_target_posteriors_.clear();
// Schedules are only used when the join tree is large enough.
584 _use_schedules_ = (overall_size > _schedule_threshold_);
// Wrap each soft-evidence tensor and attach it to the clique of its node.
587 const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
588 for (
const auto node: this->softEvidenceNodes()) {
589 if (_node_to_clique_.exists(node)) {
590 auto ev_pot =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*evidence[node],
false);
591 _node_to_soft_evidence_.insert(node, ev_pot);
592 _clique_tensors_[_node_to_clique_[node]].insert(ev_pot);
600 if (_use_schedules_) {
602 _initializeJTCliques_(schedule);
604 _initializeJTCliques_();
// The new structure already accounts for all pending evidence changes.
609 _evidence_changes_.clear();
610 _is_new_jt_needed_ =
false;
// Fills the join-tree cliques with tensors without using a schedule.
// For each factor of the MRF:
//  - no hard-evidence node: the factor is wrapped (ScheduleMultiDim with
//    `false`, presumably "do not copy") and added to its clique;
//  - every node is hard evidence: the factor reduces to a constant, which is
//    stored in _constants_;
//  - otherwise: the factor is combined with the hard-evidence tensors and the
//    hard variables are projected out. The result is added to the factor's
//    clique and remembered in _hard_ev_projected_factors_.
// Finally each clique's tensors are combined into _clique_ss_tensor_. For a
// clique with a single tensor, it aliases that tensor.
// NOTE(review): several lines are missing (e.g. the declaration of
// hard_variables, the `continue`s after the first two cases, and whatever
// frees the moved-from Tensor objects).
614 template <
typename GUM_SCALAR >
615 void ShaferShenoyMRFInference< GUM_SCALAR >::_initializeJTCliques_() {
616 const auto& mn = this->MRF();
622 const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
623 const NodeProperty< Idx >& hard_evidence = this->hardEvidence();
625 for (
const auto& factor: mn.factors()) {
626 const auto& factor_nodes = factor.first;
627 const auto& pot = *(factor.second);
628 const auto& variables = pot.variablesSequence();
// Split the factor's scope into hard-evidence nodes and nodes still in
// _graph_.
631 NodeSet hard_nodes(factor_nodes.size());
632 bool graph_contains_nodes =
false;
633 for (
const auto node: factor_nodes) {
634 if (_hard_ev_nodes_.contains(node)) hard_nodes.insert(node);
635 else if (_graph_.exists(node)) graph_contains_nodes =
true;
641 if (hard_nodes.empty()) {
642 auto sched_cpt =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(pot,
false);
643 _clique_tensors_[_factor_to_clique_[&pot]].insert(sched_cpt);
// Fully observed factor: evaluate it at the observed values. Iterating
// variables[i] up to hard_nodes.size() is valid here only because every
// variable of the factor is observed in this branch.
649 if (hard_nodes.size() == factor_nodes.size()) {
650 Instantiation inst(pot);
651 for (Size i = 0; i < hard_nodes.size(); ++i) {
652 inst.chgVal(*variables[i], hard_evidence[mn.nodeId(*(variables[i]))]);
654 _constants_.insert(&pot, pot.get(inst));
659 if (!graph_contains_nodes)
continue;
// Partially observed factor: combine with the hard evidence and project the
// observed variables out.
663 _TensorSet_ marg_factor_set(1 + hard_nodes.size());
664 marg_factor_set.insert(&pot);
665 for (
const auto node: hard_nodes) {
666 marg_factor_set.insert(evidence[node]);
667 hard_variables.
insert(&(mn.variable(node)));
671 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
675 _TensorSet_ new_factor_list
676 = combine_and_project.execute(marg_factor_set, hard_variables);
// Exactly one tensor is expected. Otherwise, free the temporaries (not the
// inputs) and report an error.
// NOTE(review): the loop variable `pot` below shadows the outer `pot`.
679 if (new_factor_list.size() != 1) {
680 for (
const auto pot: new_factor_list) {
681 if (!marg_factor_set.contains(pot))
delete pot;
684 "the projection of a tensor containing " <<
"hard evidence is empty!");
686 auto new_factor =
const_cast< Tensor< GUM_SCALAR >*
>(*(new_factor_list.begin()));
687 auto projected_factor
688 =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*new_factor));
691 _clique_tensors_[_factor_to_clique_[&pot]].insert(projected_factor);
692 _hard_ev_projected_factors_.insert(&pot, projected_factor);
// Combine each clique's tensors into its single "ss" tensor.
701 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
702 for (
const auto& xpotset: _clique_tensors_) {
703 const auto& potset = xpotset.second;
704 if (potset.size() > 0) {
709 if (potset.size() == 1) {
710 _clique_ss_tensor_[xpotset.first] = *(potset.cbegin());
712 _TensorSet_ p_potset(potset.size());
713 for (
const auto pot: potset)
715 &(
static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot)->multiDim()));
717 Tensor< GUM_SCALAR >* joint
718 =
const_cast< Tensor< GUM_SCALAR >*
>(fast_combination.execute(p_potset));
719 _clique_ss_tensor_[xpotset.first]
720 =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*joint));
// Schedule-based variant of clique initialization. The factor handling is the
// same as in the schedule-free version:
//  - unobserved factors are wrapped;
//  - fully observed factors become constants;
//  - partially observed factors are combined with the hard evidence and the
//    hard variables are projected out.
// Here the projections and the per-clique combinations are inserted into
// `schedule` and run by this->scheduler().execute().
// Results stored in the engine are flagged with makeResultsPersistent(true),
// presumably so they outlive the schedule — confirm.
// NOTE(review): declarations of hard_nodes / hard_variables and a few other
// lines are not visible in this excerpt.
728 template <
typename GUM_SCALAR >
729 void ShaferShenoyMRFInference< GUM_SCALAR >::_initializeJTCliques_(Schedule& schedule) {
730 const auto& mn = this->MRF();
736 const NodeProperty< const Tensor< GUM_SCALAR >* >& evidence = this->evidence();
737 const NodeProperty< Idx >& hard_evidence = this->hardEvidence();
739 for (
const auto& factor: mn.factors()) {
740 const auto& factor_nodes = factor.first;
741 const auto& pot = *(factor.second);
742 const auto& variables = pot.variablesSequence();
746 bool graph_contains_nodes =
false;
747 for (
const auto node: factor_nodes) {
748 if (_hard_ev_nodes_.contains(node)) hard_nodes.insert(node);
749 else if (_graph_.exists(node)) graph_contains_nodes =
true;
755 if (hard_nodes.empty()) {
756 auto sched_cpt =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(pot,
false);
757 _clique_tensors_[_factor_to_clique_[&pot]].insert(sched_cpt);
// Fully observed factor: evaluate it at the observed values.
763 if (hard_nodes.size() == factor_nodes.size()) {
764 Instantiation inst(pot);
765 for (Size i = 0; i < hard_nodes.size(); ++i) {
766 inst.chgVal(*variables[i], hard_evidence[mn.nodeId(*(variables[i]))]);
768 _constants_.insert(&pot, pot.get(inst));
773 if (!graph_contains_nodes)
continue;
// Partially observed factor: register it and the hard-evidence tensors in the
// schedule, then schedule the combine-and-project.
777 _ScheduleMultiDimSet_ marg_factor_set(1 + hard_nodes.size());
778 const IScheduleMultiDim* sched_pot
779 = schedule.insertTable< Tensor< GUM_SCALAR > >(pot,
false);
780 marg_factor_set.insert(sched_pot);
// NOTE(review): the local `pot` declared in this loop shadows the outer `pot`.
782 for (
const auto node: hard_nodes) {
783 const IScheduleMultiDim* pot
784 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node],
false);
785 marg_factor_set.insert(pot);
786 hard_variables.
insert(&(mn.variable(node)));
790 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
794 _ScheduleMultiDimSet_ new_factor_list
795 = combine_and_project.schedule(schedule, marg_factor_set, hard_variables);
798 if (new_factor_list.size() != 1) {
800 "the projection of a tensor containing " <<
"hard evidence is empty!");
802 auto projected_factor =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
803 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
804 *new_factor_list.begin()));
805 const_cast< ScheduleOperator*
>(schedule.scheduleMultiDimCreator(projected_factor))
806 ->makeResultsPersistent(
true);
808 _clique_tensors_[_factor_to_clique_[&pot]].insert(projected_factor);
809 _hard_ev_projected_factors_.insert(&pot, projected_factor);
813 this->scheduler().execute(schedule);
// Combine each clique's tensors into its single "ss" tensor, again through
// the schedule.
820 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
821 for (
const auto& xpotset: _clique_tensors_) {
822 const auto& potset = xpotset.second;
823 if (potset.size() > 0) {
828 if (potset.size() == 1) {
829 _clique_ss_tensor_[xpotset.first] = *(potset.cbegin());
832 for (
const auto pot: potset) {
833 schedule.emplaceScheduleMultiDim(*pot);
836 auto joint =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
837 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
838 fast_combination.schedule(schedule, potset)));
839 const_cast< ScheduleOperator*
>(schedule.scheduleMultiDimCreator(joint))
840 ->makeResultsPersistent(
true);
841 _clique_ss_tensor_[xpotset.first] = joint;
845 this->scheduler().execute(schedule);
// Brings the inference structure up to date. If _isNewJTNeeded_() reports
// that the join tree cannot be reused, it is presumably rebuilt (that branch
// is not visible in this excerpt); otherwise only the tensors are refreshed
// through updateOutdatedTensors_().
849 template <
typename GUM_SCALAR >
850 void ShaferShenoyMRFInference< GUM_SCALAR >::updateOutdatedStructure_() {
852 if (_isNewJTNeeded_()) {
858 updateOutdatedTensors_();
// Invalidates the message sent along arc (from_id -> to_id) and records
// to_id in invalidated_cliques.
// When that message had been computed, the following happens:
//  - its flag is reset;
//  - its separator tensor is set to nullptr;
//  - any tensor allocated for it is deleted and removed from
//    _arc_to_created_tensors_;
//  - the invalidation is propagated recursively to the neighbours of to_id,
//    except from_id.
// NOTE(review): the declarations of from_id / to_id and the closing braces
// are not visible, so the exact nesting of the recursive loop cannot be
// confirmed from this excerpt.
863 template <
typename GUM_SCALAR >
864 void ShaferShenoyMRFInference< GUM_SCALAR >::_diffuseMessageInvalidations_(
867 NodeSet& invalidated_cliques) {
869 invalidated_cliques.insert(to_id);
872 const Arc arc(from_id, to_id);
873 bool& message_computed = _messages_computed_[arc];
874 if (message_computed) {
875 message_computed =
false;
876 _separator_tensors_[arc] =
nullptr;
877 if (_arc_to_created_tensors_.exists(arc)) {
878 delete _arc_to_created_tensors_[arc];
879 _arc_to_created_tensors_.erase(arc);
883 for (
const auto node_id: _JT_->neighbours(to_id)) {
884 if (node_id != from_id) _diffuseMessageInvalidations_(to_id, node_id, invalidated_cliques);
// Updates the tensors after evidence changes, keeping the current join tree.
// The steps are:
//  1. Remember which cliques own their combined tensor (more than one tensor).
//  2. Find hard-evidence nodes whose evidence changed and the projected
//     factors that depend on them. Remove those projected tensors from their
//     cliques.
//  3. Invalidate the messages leaving the cliques of changed nodes and of the
//     cliques whose projected factors changed. Free the owned combined
//     tensors of the invalidated cliques.
//  4. Drop posteriors depending on invalidated cliques or changed hard nodes.
//  5. Recreate the soft-evidence wrappers and recompute the projected factors
//     (through a schedule if _use_schedules_).
//  6. Recombine the tensors of invalidated cliques, refresh the constants of
//     fully observed factors, and clear _evidence_changes_.
// NOTE(review): several lines are missing from this excerpt (deallocation of
// the removed projected tensors, loop increments, some declarations).
891 template <
typename GUM_SCALAR >
892 void ShaferShenoyMRFInference< GUM_SCALAR >::updateOutdatedTensors_() {
897 NodeProperty< bool > ss_tensor_to_deallocate(_clique_tensors_.size());
898 for (
const auto& potset: _clique_tensors_) {
899 ss_tensor_to_deallocate.insert(potset.first, (potset.second.size() > 1));
// Hard-evidence nodes whose evidence changed, and the projected factors
// involving them.
909 const auto& mn = this->MRF();
910 NodeSet hard_nodes_changed(_hard_ev_nodes_.size());
911 Set< const Tensor< GUM_SCALAR >* > hard_projected_factors_changed(mn.factors().size());
912 for (
const auto node: _hard_ev_nodes_) {
913 if (_evidence_changes_.exists(node)) {
914 hard_nodes_changed.insert(node);
915 for (
const auto pot: _node_to_factors_[node]) {
916 if (_hard_ev_projected_factors_.exists(pot)
917 && !hard_projected_factors_changed.exists(pot)) {
918 hard_projected_factors_changed.insert(pot);
// Detach the outdated projected tensors from their cliques.
924 NodeSet hard_cliques_changed(hard_projected_factors_changed.size());
925 for (
const auto pot: hard_projected_factors_changed) {
926 const auto chgPot = _hard_ev_projected_factors_[pot];
927 const NodeId chgClique = _factor_to_clique_[pot];
928 _clique_tensors_[chgClique].erase(chgPot);
929 _hard_ev_projected_factors_.erase(pot);
930 if (!hard_cliques_changed.contains(chgClique)) hard_cliques_changed.insert(chgClique);
// Invalidate the messages leaving the cliques affected by the changes. Hard
// nodes are not in _node_to_clique_, so they are skipped by the exists() test.
942 NodeSet invalidated_cliques(_JT_->size());
943 for (
const auto& pair: _evidence_changes_) {
944 if (_node_to_clique_.exists(pair.first)) {
945 const auto clique = _node_to_clique_[pair.first];
946 invalidated_cliques.insert(clique);
947 for (
const auto neighbor: _JT_->neighbours(clique)) {
948 _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
955 for (
const auto clique: hard_cliques_changed) {
956 invalidated_cliques.insert(clique);
957 for (
const auto neighbor: _JT_->neighbours(clique)) {
958 _diffuseMessageInvalidations_(clique, neighbor, invalidated_cliques);
// Free combined clique tensors that this engine allocated (see step 1).
964 for (
const auto clique: invalidated_cliques) {
965 if (ss_tensor_to_deallocate[clique]) {
966 delete _clique_ss_tensor_[clique];
967 _clique_ss_tensor_[clique] =
nullptr;
// Drop marginal posteriors computed from an invalidated clique, or of a hard
// node whose evidence changed.
974 if (!_target_posteriors_.empty()) {
975 for (
auto iter = _target_posteriors_.beginSafe(); iter != _target_posteriors_.endSafe();
979 if (_graph_.exists(iter.key())
980 && (invalidated_cliques.exists(_node_to_clique_[iter.key()]))) {
982 _target_posteriors_.erase(iter);
985 else if (hard_nodes_changed.contains(iter.key())) {
987 _target_posteriors_.erase(iter);
// Drop joint posteriors computed from an invalidated clique, or whose nodes
// all belong to hard_nodes_changed.
994 for (
auto iter = _joint_target_posteriors_.beginSafe();
995 iter != _joint_target_posteriors_.endSafe();
997 if (invalidated_cliques.exists(_joint_target_to_clique_[iter.key()])) {
999 _joint_target_posteriors_.erase(iter);
1002 bool has_unevidenced_node =
false;
1003 for (
const auto node: iter.key()) {
1004 if (!hard_nodes_changed.exists(node)) {
1005 has_unevidenced_node =
true;
1009 if (!has_unevidenced_node) {
1011 _joint_target_posteriors_.erase(iter);
// Rebuild the soft-evidence wrappers from scratch.
// NOTE(review): pot_pair.second is deleted before being used as the key to
// erase from the clique's set. Using a pointer value after delete is
// implementation-defined; erasing first and deleting afterwards would avoid
// it.
1018 for (
const auto& pot_pair: _node_to_soft_evidence_) {
1019 delete pot_pair.second;
1020 _clique_tensors_[_node_to_clique_[pot_pair.first]].erase(pot_pair.second);
1022 _node_to_soft_evidence_.clear();
// NOTE(review): unlike _createNewJT_, there is no _node_to_clique_.exists()
// check here. This relies on every soft-evidence node being in _graph_.
1024 const auto& evidence = this->evidence();
1025 for (
const auto node: this->softEvidenceNodes()) {
1026 auto ev_pot =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*evidence[node],
false);
1027 _node_to_soft_evidence_.insert(node, ev_pot);
1028 _clique_tensors_[_node_to_clique_[node]].insert(ev_pot);
1038 if (_use_schedules_) {
// Scheduled path: re-project the changed factors with the new hard evidence.
1040 for (
const auto pot: hard_projected_factors_changed) {
1041 _ScheduleMultiDimSet_ marg_pot_set;
1042 const auto sched_pot = schedule.insertTable< Tensor< GUM_SCALAR > >(*pot,
false);
1043 marg_pot_set.insert(sched_pot);
1044 const auto& variables = pot->variablesSequence();
1046 for (
const auto var: variables) {
1047 NodeId xnode = mn.nodeId(*var);
1048 if (_hard_ev_nodes_.exists(xnode)) {
1050 = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[xnode],
false);
1051 marg_pot_set.insert(ev_pot);
1052 hard_variables.
insert(var);
1057 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
1061 _ScheduleMultiDimSet_ new_pot_list
1062 = combine_and_project.schedule(schedule, marg_pot_set, hard_variables);
1065 if (new_pot_list.size() != 1) {
1067 "the projection of a tensor containing " <<
"hard evidence is empty!");
1069 auto projected_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1070 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(*new_pot_list.begin()));
1071 const_cast< ScheduleOperator*
>(schedule.scheduleMultiDimCreator(projected_pot))
1072 ->makeResultsPersistent(
true);
1073 _clique_tensors_[_factor_to_clique_[pot]].insert(projected_pot);
1074 _hard_ev_projected_factors_.insert(pot, projected_pot);
// Recombine the tensors of invalidated cliques.
1080 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1081 for (
const auto clique: invalidated_cliques) {
1082 const auto& potset = _clique_tensors_[clique];
1084 if (potset.size() > 0) {
1089 if (potset.size() == 1) {
1090 _clique_ss_tensor_[clique] = *(potset.cbegin());
1092 for (
const auto pot: potset)
1093 if (!schedule.existsScheduleMultiDim(pot->id()))
1094 schedule.emplaceScheduleMultiDim(*pot);
1095 auto joint =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1096 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1097 fast_combination.schedule(schedule, potset)));
1098 const_cast< ScheduleOperator*
>(schedule.scheduleMultiDimCreator(joint))
1099 ->makeResultsPersistent(
true);
1100 _clique_ss_tensor_[clique] = joint;
1104 this->scheduler().execute(schedule);
// Direct (schedule-free) path: same re-projection, executed immediately.
1106 for (
const auto pot: hard_projected_factors_changed) {
1107 _TensorSet_ marg_pot_set;
1108 marg_pot_set.insert(pot);
1109 const auto& variables = pot->variablesSequence();
1112 for (
const auto var: variables) {
1113 NodeId xnode = mn.nodeId(*var);
1114 if (_hard_ev_nodes_.exists(xnode)) {
1115 marg_pot_set.insert(evidence[xnode]);
1116 hard_variables.
insert(var);
1121 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(
1125 _TensorSet_ new_pot_list = combine_and_project.execute(marg_pot_set, hard_variables);
1128 if (new_pot_list.size() != 1) {
1130 "the projection of a tensor containing " <<
"hard evidence is empty!");
// The projected tensor's content is moved into the wrapper, then the emptied
// Tensor object is freed.
1132 Tensor< GUM_SCALAR >* xprojected_pot
1133 =
const_cast< Tensor< GUM_SCALAR >*
>(*new_pot_list.begin());
1135 =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*xprojected_pot));
1136 delete xprojected_pot;
1137 _clique_tensors_[_factor_to_clique_[pot]].insert(projected_pot);
1138 _hard_ev_projected_factors_.insert(pot, projected_pot);
// Recombine the tensors of invalidated cliques.
1144 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1145 for (
const auto clique: invalidated_cliques) {
1146 const auto& potset = _clique_tensors_[clique];
1148 if (potset.size() > 0) {
1153 if (potset.size() == 1) {
1154 _clique_ss_tensor_[clique] = *(potset.cbegin());
1156 _TensorSet_ p_potset(potset.size());
1157 for (
const auto pot: potset)
1159 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot)->multiDim()));
1161 Tensor< GUM_SCALAR >* joint
1162 =
const_cast< Tensor< GUM_SCALAR >*
>(fast_combination.execute(p_potset));
1163 _clique_ss_tensor_[clique]
1164 =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(std::move(*joint));
// Re-evaluate every fully observed factor at the current hard evidence.
1172 const auto& hard_evidence = this->hardEvidence();
1173 for (
auto& node_cst: _constants_) {
1174 const Tensor< GUM_SCALAR >& pot = *(node_cst.first);
1175 Instantiation inst(pot);
1176 for (
const auto var: pot.variablesSequence()) {
1177 inst.chgVal(*var, hard_evidence[mn.nodeId(*var)]);
1179 node_cst.second = pot.get(inst);
1183 _evidence_changes_.clear();
// Chooses the cliques from which messages are collected.
// The candidate cliques are the ones holding marginal targets or joint
// targets. A target without a clique is skipped: lookup exceptions are
// caught, e.g. for hard-evidence nodes, which are absent from
// _node_to_clique_.
// Candidates are sorted by increasing size (product of domain sizes). The
// smallest candidate of each connected component of the join tree becomes a
// root; its component is then marked by a recursive traversal.
// NOTE(review): the `try` keywords, the declarations of clique_targets, i and
// dom_size, and any clearing of _roots_ are not visible in this excerpt.
1187 template <
typename GUM_SCALAR >
1188 void ShaferShenoyMRFInference< GUM_SCALAR >::_computeJoinTreeRoots_() {
1193 for (
const auto node: this->targets()) {
1195 clique_targets.insert(_node_to_clique_[node]);
1196 }
catch (Exception
const&) {}
1198 for (
const auto& set: this->jointTargets()) {
1200 clique_targets.insert(_joint_target_to_clique_[set]);
1201 }
catch (Exception
const&) {}
// Pair each candidate clique with its size.
1205 std::vector< std::pair< NodeId, Size > > possible_roots(clique_targets.size());
1206 const auto& mn = this->MRF();
1208 for (
const auto clique_id: clique_targets) {
1209 const auto& clique = _JT_->clique(clique_id);
1211 for (
const auto node: clique) {
1212 dom_size *= mn.variable(node).domainSize();
1214 possible_roots[i] = std::pair< NodeId, Size >(clique_id, dom_size);
// Smallest cliques first, so they are preferred as roots.
1219 std::sort(possible_roots.begin(),
1220 possible_roots.end(),
1221 [](
const std::pair< NodeId, Size >& a,
const std::pair< NodeId, Size >& b) ->
bool {
1222 return a.second < b.second;
// Marks every clique reachable from `node` without going back through `from`.
1226 NodeProperty< bool > marked = _JT_->nodesPropertyFromVal(
false);
1227 std::function< void(NodeId, NodeId) > diffuse_marks
1228 = [&marked, &diffuse_marks,
this](NodeId node, NodeId from) {
1229 if (!marked[node]) {
1230 marked[node] =
true;
1231 for (
const auto neigh: _JT_->neighbours(node))
1232 if ((neigh != from) && !marked[neigh]) diffuse_marks(neigh, node);
// One root per connected component: the first (smallest) unmarked candidate.
1236 for (
const auto& xclique: possible_roots) {
1237 NodeId clique = xclique.first;
1238 if (!marked[clique]) {
1239 _roots_.insert(clique);
1240 diffuse_marks(clique, clique);
// Scheduled collect phase.
// It first ensures that every message sent to clique `id` by its neighbours
// (other than `from`) has been computed, recursing depth-first. It then
// produces the message id -> from, unless id == from (presumably the
// convention when `id` is a root) or that message is already computed.
// The operations are added to `schedule`.
1246 template <
typename GUM_SCALAR >
1247 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::_collectMessage_(Schedule& schedule,
1250 for (
const auto other: _JT_->neighbours(
id)) {
1251 if ((other != from) && !_messages_computed_[Arc(other,
id)])
1252 _collectMessage_(schedule, other,
id);
1255 if ((
id != from) && !_messages_computed_[Arc(
id, from)]) {
1256 _produceMessage_(schedule,
id, from);
// Schedule-free collect phase.
// It first ensures that every message sent to clique `id` by its neighbours
// (other than `from`) has been computed, recursing depth-first. It then
// produces the message id -> from, unless id == from or that message is
// already computed.
1261 template <
typename GUM_SCALAR >
1262 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::_collectMessage_(NodeId
id, NodeId from) {
1263 for (
const auto other: _JT_->neighbours(
id)) {
1264 if ((other != from) && !_messages_computed_[Arc(other,
id)]) _collectMessage_(other,
id);
1267 if ((
id != from) && !_messages_computed_[Arc(
id, from)]) { _produceMessage_(
id, from); }
// Schedules the combination of the tensors in pot_list and the removal of
// del_vars.
// It first makes sure every input tensor is known to `schedule`. It then
// schedules a combine-and-project. If this yields a single tensor, that
// tensor is returned; otherwise a combination of the results is scheduled and
// returned.
// NOTE(review): pot_list is taken by value (a Set copy per call), unlike the
// schedule-free overload which takes it by reference. Some parameter lines
// are not visible in this excerpt.
1271 template <
typename GUM_SCALAR >
1272 const IScheduleMultiDim* ShaferShenoyMRFInference< GUM_SCALAR >::_marginalizeOut_(
1274 Set< const IScheduleMultiDim* > pot_list,
1279 for (
const auto pot: pot_list) {
1280 if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
1285 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1287 _ScheduleMultiDimSet_ new_pot_list = combine_and_project.schedule(schedule, pot_list, del_vars);
1290 if (new_pot_list.size() == 1)
return *(new_pot_list.begin());
1291 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1292 return fast_combination.schedule(schedule, new_pot_list);
// Schedule-free marginalization: combines the tensors in pot_list and removes
// del_vars immediately.
// The ScheduleMultiDim wrappers are unwrapped to plain Tensors before
// combine-and-project is run. If several tensors remain, they are combined
// and the intermediate tensors (neither inputs nor the result) are deleted.
// The result is then wrapped:
//  - if it is one of the inputs, it is wrapped with `false` (presumably
//    without copying);
//  - otherwise its content is moved into a new ScheduleMultiDim.
// NOTE(review): the end of this definition (return statement and the cleanup
// of the moved-from Tensor, if any) is not visible in this excerpt.
1296 template <
typename GUM_SCALAR >
1297 const IScheduleMultiDim* ShaferShenoyMRFInference< GUM_SCALAR >::_marginalizeOut_(
1298 Set< const IScheduleMultiDim* >& pot_list,
1301 _TensorSet_ xpot_list(pot_list.size());
1302 for (
auto pot: pot_list)
1304 &(
static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot)->multiDim()));
1308 MultiDimCombineAndProjectDefault< Tensor< GUM_SCALAR > > combine_and_project(_combination_op_,
1310 _TensorSet_ xnew_pot_list = combine_and_project.execute(xpot_list, del_vars);
1313 const Tensor< GUM_SCALAR >* xres_pot;
1314 if (xnew_pot_list.size() == 1) {
1315 xres_pot = *(xnew_pot_list.begin());
1319 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1320 xres_pot = fast_combination.execute(xnew_pot_list);
1321 for (
const auto pot: xnew_pot_list) {
1322 if (!xpot_list.contains(pot) && (pot != xres_pot))
delete pot;
1327 ScheduleMultiDim< Tensor< GUM_SCALAR > >* res_pot;
1328 if (xpot_list.contains(xres_pot))
1329 res_pot =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(*xres_pot,
false);
1331 res_pot =
new ScheduleMultiDim< Tensor< GUM_SCALAR > >(
1332 std::move(
const_cast< Tensor< GUM_SCALAR >&
>(*xres_pot)));
// Scheduled computation of the message from clique from_id to clique to_id.
// The inputs are the clique's combined tensor plus the messages received from
// every other neighbour. Variables of from_id that are not in the separator
// are marginalized out.
// If the result is a newly created tensor (not one of the inputs):
//  - it is recorded in _arc_to_created_tensors_, so that it can be freed on
//    invalidation;
//  - its creating operation is made persistent.
// The result is stored as the separator tensor and the message is marked as
// computed.
// NOTE(review): the to_id parameter line, the del_vars / kept_vars
// declarations and some braces are not visible in this excerpt.
1340 template <
typename GUM_SCALAR >
1341 void ShaferShenoyMRFInference< GUM_SCALAR >::_produceMessage_(Schedule& schedule,
1345 _ScheduleMultiDimSet_ pot_list;
1346 if (_clique_ss_tensor_[from_id] !=
nullptr) pot_list.insert(_clique_ss_tensor_[from_id]);
1349 for (
const auto other_id: _JT_->neighbours(from_id)) {
1350 if (other_id != to_id) {
1351 const auto separator_pot = _separator_tensors_[Arc(other_id, from_id)];
1352 if (separator_pot !=
nullptr) pot_list.insert(separator_pot);
// Split the clique's variables into those to remove and those kept (the
// separator).
1357 const NodeSet& from_clique = _JT_->clique(from_id);
1358 const NodeSet& separator = _JT_->separator(from_id, to_id);
1361 const auto& mn = this->MRF();
1363 for (
const auto node: from_clique) {
1364 if (!separator.contains(node)) {
1365 del_vars.
insert(&(mn.variable(node)));
1367 kept_vars.
insert(&(mn.variable(node)));
1373 const IScheduleMultiDim* new_pot = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
// Only tensors created here are owned by the message. Inputs returned as-is
// must not be recorded, or they would be freed twice.
1376 const Arc arc(from_id, to_id);
1377 if (!pot_list.exists(new_pot)) {
1378 if (!_arc_to_created_tensors_.exists(arc)) {
1379 _arc_to_created_tensors_.insert(arc, new_pot);
1382 auto op = schedule.scheduleMultiDimCreator(new_pot);
1383 if (op !=
nullptr)
const_cast< ScheduleOperator*
>(op)->makeResultsPersistent(
true);
1387 _separator_tensors_[arc] = new_pot;
1388 _messages_computed_[arc] =
true;
1392 template <
typename GUM_SCALAR >
1393 void ShaferShenoyMRFInference< GUM_SCALAR >::_produceMessage_(NodeId from_id, NodeId to_id) {
1395 _ScheduleMultiDimSet_ pot_list;
1396 if (_clique_ss_tensor_[from_id] !=
nullptr) pot_list.insert(_clique_ss_tensor_[from_id]);
1399 for (
const auto other_id: _JT_->neighbours(from_id)) {
1400 if (other_id != to_id) {
1401 const auto separator_pot = _separator_tensors_[Arc(other_id, from_id)];
1402 if (separator_pot !=
nullptr) pot_list.insert(separator_pot);
1407 const NodeSet& from_clique = _JT_->clique(from_id);
1408 const NodeSet& separator = _JT_->separator(from_id, to_id);
1411 const auto& mn = this->MRF();
1413 for (
const auto node: from_clique) {
1414 if (!separator.contains(node)) {
1415 del_vars.
insert(&(mn.variable(node)));
1417 kept_vars.
insert(&(mn.variable(node)));
1423 const IScheduleMultiDim* new_pot = _marginalizeOut_(pot_list, del_vars, kept_vars);
1426 const Arc arc(from_id, to_id);
1427 if (!pot_list.exists(new_pot)) {
1428 if (!_arc_to_created_tensors_.exists(arc)) { _arc_to_created_tensors_.insert(arc, new_pot); }
1431 _separator_tensors_[arc] = new_pot;
1432 _messages_computed_[arc] =
true;
1436 template <
typename GUM_SCALAR >
1437 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::onModelChanged_(
const GraphicalModel* mn) {
1438 JointTargetedMRFInference< GUM_SCALAR >::onModelChanged_(mn);
1442 template <
typename GUM_SCALAR >
1443 INLINE
void ShaferShenoyMRFInference< GUM_SCALAR >::makeInference_() {
1444 if (_use_schedules_) {
1448 for (
const auto node: this->targets()) {
1452 if (_graph_.exists(node)) {
1453 _collectMessage_(schedule, _node_to_clique_[node], _node_to_clique_[node]);
1461 for (
const auto& set: _joint_target_to_clique_)
1462 _collectMessage_(schedule, set.second, set.second);
1465 this->scheduler().execute(schedule);
1468 for (
const auto node: this->targets()) {
1472 if (_graph_.exists(node)) {
1473 _collectMessage_(_node_to_clique_[node], _node_to_clique_[node]);
1481 for (
const auto& set: _joint_target_to_clique_)
1482 _collectMessage_(set.second, set.second);
1487 template <
typename GUM_SCALAR >
1488 Tensor< GUM_SCALAR >*
1489 ShaferShenoyMRFInference< GUM_SCALAR >::unnormalizedJointPosterior_(NodeId
id) {
1490 if (_use_schedules_) {
1492 return _unnormalizedJointPosterior_(schedule,
id);
1494 return _unnormalizedJointPosterior_(
id);
1499 template <
typename GUM_SCALAR >
1500 Tensor< GUM_SCALAR >*
1501 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1503 const auto& mn = this->MRF();
1507 if (this->hardEvidenceNodes().contains(
id)) {
1508 return new Tensor< GUM_SCALAR >(*(this->evidence()[
id]));
1511 auto& scheduler = this->scheduler();
1515 const NodeId clique_of_id = _node_to_clique_[id];
1516 _collectMessage_(schedule, clique_of_id, clique_of_id);
1521 _ScheduleMultiDimSet_ pot_list;
1522 if (_clique_ss_tensor_[clique_of_id] !=
nullptr)
1523 pot_list.insert(_clique_ss_tensor_[clique_of_id]);
1526 for (
const auto other: _JT_->neighbours(clique_of_id))
1527 pot_list.insert(_separator_tensors_[Arc(other, clique_of_id)]);
1530 const NodeSet& nodes = _JT_->clique(clique_of_id);
1533 for (
const auto node: nodes) {
1534 if (node !=
id) del_vars.
insert(&(mn.variable(node)));
1539 auto resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1540 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1541 _marginalizeOut_(schedule, pot_list, del_vars, kept_vars)));
1542 Tensor< GUM_SCALAR >* joint =
nullptr;
1544 scheduler.execute(schedule);
1548 if (pot_list.exists(resulting_pot)) {
1549 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1551 joint = resulting_pot->exportMultiDim();
1557 bool nonzero_found =
false;
1558 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1559 if (joint->get(inst)) {
1560 nonzero_found =
true;
1564 if (!nonzero_found) {
1568 "some evidence entered into the Markov "
1569 "net are incompatible (their joint proba = 0)");
1575 template <
typename GUM_SCALAR >
1576 Tensor< GUM_SCALAR >*
1577 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(NodeId
id) {
1578 const auto& mn = this->MRF();
1582 if (this->hardEvidenceNodes().contains(
id)) {
1583 return new Tensor< GUM_SCALAR >(*(this->evidence()[
id]));
1588 NodeId clique_of_id = _node_to_clique_[id];
1589 _collectMessage_(clique_of_id, clique_of_id);
1594 _ScheduleMultiDimSet_ pot_list;
1595 if (_clique_ss_tensor_[clique_of_id] !=
nullptr)
1596 pot_list.insert(_clique_ss_tensor_[clique_of_id]);
1599 for (
const auto other: _JT_->neighbours(clique_of_id))
1600 pot_list.insert(_separator_tensors_[Arc(other, clique_of_id)]);
1603 const NodeSet& nodes = _JT_->clique(clique_of_id);
1606 for (
const auto node: nodes) {
1607 if (node !=
id) del_vars.
insert(&(mn.variable(node)));
1612 auto resulting_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1613 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1614 _marginalizeOut_(pot_list, del_vars, kept_vars)));
1615 Tensor< GUM_SCALAR >* joint =
nullptr;
1619 if (pot_list.exists(resulting_pot)) {
1620 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1622 joint = resulting_pot->exportMultiDim();
1623 delete resulting_pot;
1629 bool nonzero_found =
false;
1630 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1631 if (joint->get(inst)) {
1632 nonzero_found =
true;
1636 if (!nonzero_found) {
1640 "some evidence entered into the Markov "
1641 "net are incompatible (their joint proba = 0)");
1647 template <
typename GUM_SCALAR >
1648 const Tensor< GUM_SCALAR >& ShaferShenoyMRFInference< GUM_SCALAR >::posterior_(NodeId
id) {
1650 if (_target_posteriors_.exists(
id)) {
return *(_target_posteriors_[id]); }
1653 auto joint = unnormalizedJointPosterior_(
id);
1654 if (joint->sum() != 1)
1656 _target_posteriors_.insert(
id, joint);
1662 template <
typename GUM_SCALAR >
1663 Tensor< GUM_SCALAR >*
1664 ShaferShenoyMRFInference< GUM_SCALAR >::unnormalizedJointPosterior_(
const NodeSet& set) {
1665 if (_use_schedules_) {
1667 return _unnormalizedJointPosterior_(schedule, set);
1669 return _unnormalizedJointPosterior_(set);
1674 template <
typename GUM_SCALAR >
1675 Tensor< GUM_SCALAR >*
1676 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(Schedule& schedule,
1677 const NodeSet& set) {
1680 NodeSet targets = set, hard_ev_nodes;
1681 for (
const auto node: this->hardEvidenceNodes()) {
1682 if (targets.contains(node)) {
1683 targets.erase(node);
1684 hard_ev_nodes.insert(node);
1688 auto& scheduler = this->scheduler();
1692 const auto& evidence = this->evidence();
1693 if (targets.empty()) {
1694 if (set.size() == 1) {
1695 return new Tensor< GUM_SCALAR >(*evidence[*set.begin()]);
1697 _ScheduleMultiDimSet_ pot_list;
1698 for (
const auto node: set) {
1699 auto new_pot_ev = schedule.insertTable< Tensor< GUM_SCALAR > >(*evidence[node],
false);
1700 pot_list.insert(new_pot_ev);
1704 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1705 const IScheduleMultiDim* pot = fast_combination.schedule(schedule, pot_list);
1706 auto schedule_pot =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1707 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(pot));
1708 scheduler.execute(schedule);
1709 auto result = schedule_pot->exportMultiDim();
1719 NodeId clique_of_set;
1721 clique_of_set = _joint_target_to_clique_[set];
1728 for (
const auto node: targets) {
1729 if (!_graph_.exists(node)) {
1731 "The variable " << this->MRF().variable(node).name() <<
"(" << node
1732 <<
") does not belong to this optimized inference.")
1738 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
1739 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
1740 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
1741 elim_order.insert(JT_elim_order[i], (
int)i);
1742 NodeId first_eliminated_node = *(targets.begin());
1743 int elim_number = elim_order[first_eliminated_node];
1744 for (
const auto node: targets) {
1745 if (elim_order[node] < elim_number) {
1746 elim_number = elim_order[node];
1747 first_eliminated_node = node;
1751 clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
1755 const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
1756 for (
const auto node: targets) {
1757 if (!clique_nodes.contains(node)) {
1759 this->MRF().names(set) <<
"(" << set <<
")"
1760 <<
" is not addressable in this optimized inference.")
1765 _joint_target_to_clique_.
insert(set, clique_of_set);
1769 _collectMessage_(schedule, clique_of_set, clique_of_set);
1774 _ScheduleMultiDimSet_ pot_list;
1775 if (_clique_ss_tensor_[clique_of_set] !=
nullptr) {
1776 auto pot = _clique_ss_tensor_[clique_of_set];
1777 if (!schedule.existsScheduleMultiDim(pot->id())) schedule.emplaceScheduleMultiDim(*pot);
1778 pot_list.insert(_clique_ss_tensor_[clique_of_set]);
1782 for (
const auto other: _JT_->neighbours(clique_of_set)) {
1783 const auto pot = _separator_tensors_[Arc(other, clique_of_set)];
1784 if (pot !=
nullptr) pot_list.insert(pot);
1789 const NodeSet& nodes = _JT_->clique(clique_of_set);
1792 const auto& mn = this->MRF();
1793 for (
const auto node: nodes) {
1794 if (!targets.contains(node)) {
1795 del_vars.
insert(&(mn.variable(node)));
1797 kept_vars.
insert(&(mn.variable(node)));
1803 const IScheduleMultiDim* new_pot = _marginalizeOut_(schedule, pot_list, del_vars, kept_vars);
1804 scheduler.execute(schedule);
1805 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot
1806 =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1807 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(new_pot));
1811 Tensor< GUM_SCALAR >* joint =
nullptr;
1812 if (pot_list.exists(resulting_pot)) {
1813 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1815 joint = resulting_pot->exportMultiDim();
1820 bool nonzero_found =
false;
1821 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1822 if ((*joint)[inst]) {
1823 nonzero_found =
true;
1827 if (!nonzero_found) {
1831 "some evidence entered into the Markov "
1832 "net are incompatible (their joint proba = 0)");
1839 template <
typename GUM_SCALAR >
1840 Tensor< GUM_SCALAR >*
1841 ShaferShenoyMRFInference< GUM_SCALAR >::_unnormalizedJointPosterior_(
const NodeSet& set) {
1844 NodeSet targets = set, hard_ev_nodes;
1845 for (
const auto node: this->hardEvidenceNodes()) {
1846 if (targets.contains(node)) {
1847 targets.erase(node);
1848 hard_ev_nodes.insert(node);
1854 const auto& evidence = this->evidence();
1855 if (targets.empty()) {
1856 if (set.size() == 1) {
1857 return new Tensor< GUM_SCALAR >(*evidence[*set.begin()]);
1859 _TensorSet_ pot_list;
1860 for (
const auto node: set) {
1861 pot_list.insert(evidence[node]);
1865 MultiDimCombinationDefault< Tensor< GUM_SCALAR > > fast_combination(_combination_op_);
1866 const Tensor< GUM_SCALAR >* pot = fast_combination.execute(pot_list);
1868 return const_cast< Tensor< GUM_SCALAR >*
>(pot);
1876 NodeId clique_of_set;
1878 clique_of_set = _joint_target_to_clique_[set];
1885 for (
const auto node: targets) {
1886 if (!_graph_.exists(node)) {
1888 "The variable " << this->MRF().variable(node).name() <<
"(" << node
1889 <<
") does not belong to this optimized inference.")
1895 const std::vector< NodeId >& JT_elim_order = _triangulation_->eliminationOrder();
1896 NodeProperty< int > elim_order(Size(JT_elim_order.size()));
1897 for (std::size_t i = std::size_t(0), size = JT_elim_order.size(); i < size; ++i)
1898 elim_order.insert(JT_elim_order[i], (
int)i);
1899 NodeId first_eliminated_node = *(targets.begin());
1900 int elim_number = elim_order[first_eliminated_node];
1901 for (
const auto node: targets) {
1902 if (elim_order[node] < elim_number) {
1903 elim_number = elim_order[node];
1904 first_eliminated_node = node;
1908 clique_of_set = _triangulation_->createdJunctionTreeClique(first_eliminated_node);
1912 const NodeSet& clique_nodes = _JT_->clique(clique_of_set);
1913 for (
const auto node: targets) {
1914 if (!clique_nodes.contains(node)) {
1920 _joint_target_to_clique_.
insert(set, clique_of_set);
1924 _collectMessage_(clique_of_set, clique_of_set);
1929 _ScheduleMultiDimSet_ pot_list;
1930 if (_clique_ss_tensor_[clique_of_set] !=
nullptr) {
1931 auto pot = _clique_ss_tensor_[clique_of_set];
1932 if (pot !=
nullptr) pot_list.insert(_clique_ss_tensor_[clique_of_set]);
1936 for (
const auto other: _JT_->neighbours(clique_of_set)) {
1937 const auto pot = _separator_tensors_[Arc(other, clique_of_set)];
1938 if (pot !=
nullptr) pot_list.insert(pot);
1943 const NodeSet& nodes = _JT_->clique(clique_of_set);
1946 const auto& mn = this->MRF();
1947 for (
const auto node: nodes) {
1948 if (!targets.contains(node)) {
1949 del_vars.
insert(&(mn.variable(node)));
1951 kept_vars.
insert(&(mn.variable(node)));
1957 const IScheduleMultiDim* new_pot = _marginalizeOut_(pot_list, del_vars, kept_vars);
1958 ScheduleMultiDim< Tensor< GUM_SCALAR > >* resulting_pot
1959 =
const_cast< ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(
1960 static_cast< const ScheduleMultiDim< Tensor< GUM_SCALAR >
>* >(new_pot));
1964 Tensor< GUM_SCALAR >* joint =
nullptr;
1965 if (pot_list.exists(resulting_pot)) {
1966 joint =
new Tensor< GUM_SCALAR >(resulting_pot->multiDim());
1968 joint = resulting_pot->exportMultiDim();
1974 bool nonzero_found =
false;
1975 for (Instantiation inst(*joint); !inst.end(); ++inst) {
1976 if ((*joint)[inst]) {
1977 nonzero_found =
true;
1981 if (!nonzero_found) {
1985 "some evidence entered into the Markov "
1986 "net are incompatible (their joint proba = 0)");
1993 template <
typename GUM_SCALAR >
1994 const Tensor< GUM_SCALAR >&
1995 ShaferShenoyMRFInference< GUM_SCALAR >::jointPosterior_(
const NodeSet& set) {
1997 if (_joint_target_posteriors_.exists(set)) {
return *(_joint_target_posteriors_[set]); }
2000 auto joint = unnormalizedJointPosterior_(set);
2002 _joint_target_posteriors_.insert(set, joint);
2008 template <
typename GUM_SCALAR >
2009 const Tensor< GUM_SCALAR >&
2010 ShaferShenoyMRFInference< GUM_SCALAR >::jointPosterior_(
const NodeSet& wanted_target,
2011 const NodeSet& declared_target) {
2013 if (_joint_target_posteriors_.exists(wanted_target))
2014 return *(_joint_target_posteriors_[wanted_target]);
2020 if (!_joint_target_posteriors_.exists(declared_target)) {
2021 return jointPosterior_(declared_target);
2025 const auto& mn = this->MRF();
2027 for (
const auto node: declared_target)
2028 if (!wanted_target.contains(node)) del_vars.
insert(&(mn.variable(node)));
2030 =
new Tensor< GUM_SCALAR >(_joint_target_posteriors_[declared_target]->sumOut(del_vars));
2033 _joint_target_posteriors_.insert(wanted_target, pot);
2038 template <
typename GUM_SCALAR >
2039 GUM_SCALAR ShaferShenoyMRFInference< GUM_SCALAR >::evidenceProbability() {
2041 this->makeInference();
2049 GUM_SCALAR prob_ev = 1;
2050 for (
const auto root: _roots_) {
2052 NodeId node = *(_JT_->clique(root).begin());
2053 Tensor< GUM_SCALAR >* tmp = unnormalizedJointPosterior_(node);
2054 prob_ev *= tmp->sum();
2058 for (
const auto& projected_cpt: _constants_)
2059 prob_ev *= projected_cpt.second;
2064 template <
typename GUM_SCALAR >
2065 bool ShaferShenoyMRFInference< GUM_SCALAR >::isExactJointComputable_(
const NodeSet& vars) {
2066 if (JointTargetedMRFInference< GUM_SCALAR >::isExactJointComputable_(vars))
return true;
2068 this->prepareInference();
2070 for (
const auto& node: this->_JT_->nodes()) {
2071 const auto clique = _JT_->clique(node);
2072 if (vars == clique)
return true;
2077 template <
typename GUM_SCALAR >
2078 NodeSet ShaferShenoyMRFInference< GUM_SCALAR >::superForJointComputable_(
const NodeSet& vars) {
2079 const auto superset = JointTargetedMRFInference< GUM_SCALAR >::superForJointComputable_(vars);
2080 if (!superset.empty())
return superset;
2082 this->prepareInference();
2084 for (
const auto& node: _JT_->nodes()) {
2085 const auto clique = _JT_->clique(node);
2086 if (vars.isStrictSubsetOf(clique))
return clique;
Implementation of Shafer-Shenoy's algorithm for inference in Markov random fields.
An algorithm for converting a join tree into a binary join tree.
Exception : a similar element already exists.
<agrum/MRF/inference/evidenceMRFInference.h>
Exception : fatal (unknown ?) error.
Class representing the minimal interface for Markov random field.
Exception : several evidence are incompatible together (proba=0).
<agrum/MRF/inference/jointTargetedMRFInference.h>
Exception : the element we looked for cannot be found.
void insert(const Key &k)
Inserts a new element into the set.
ShaferShenoyMRFInference(const IMarkovRandomField< GUM_SCALAR > *MN, bool use_binary_join_tree=true)
default constructor
Exception : a looked-for element could not be found.
#define GUM_ERROR(type, msg)
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
Header files of gum::Instantiation.
gum is the global namespace for all aGrUM entities
Set< const DiscreteVariable * > VariableSet
CliqueGraph JoinTree
a join tree is a clique graph satisfying the running intersection property (but some cliques may be i...
CliqueGraph JunctionTree
a junction tree is a clique graph satisfying the running intersection property and such that no cliqu...