aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
gum::dSeparationAlgorithm Class Reference

the d-separation algorithm as described in Koller & Friedman (2009) More...

#include <dSeparationAlgorithm.h>

Public Member Functions

Constructors / Destructors
 dSeparationAlgorithm ()
 default constructor
 dSeparationAlgorithm (const dSeparationAlgorithm &from)
 copy constructor
 dSeparationAlgorithm (dSeparationAlgorithm &&from)
 move constructor
 ~dSeparationAlgorithm ()
 destructor
Operators
dSeparationAlgorithm & operator= (const dSeparationAlgorithm &from)
 copy operator
dSeparationAlgorithm & operator= (dSeparationAlgorithm &&from)
 move operator
Accessors / Modifiers
void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite) const
 Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.
template<typename GUM_SCALAR, class TABLE>
void relevantTensors (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE * > &tensors)
 Update a set of tensors in place, keeping only those d-connected to the query variables given the evidence.

Detailed Description

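The d-separation algorithm as described in Koller & Friedman (2009). It is implemented as a "bouncing ball" traversal of the DAG that starts from the query nodes and uses the hard and soft evidence to decide in which directions the ball may travel.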

Definition at line 63 of file dSeparationAlgorithm.h.

Constructor & Destructor Documentation

◆ dSeparationAlgorithm() [1/3]

INLINE gum::dSeparationAlgorithm::dSeparationAlgorithm ( )

default constructor

Definition at line 53 of file dSeparationAlgorithm_inl.h.

53 {
54 GUM_CONSTRUCTOR(dSeparationAlgorithm);
55 ;
56 }
dSeparationAlgorithm()
default constructor

References dSeparationAlgorithm().

Referenced by dSeparationAlgorithm(), dSeparationAlgorithm(), dSeparationAlgorithm(), ~dSeparationAlgorithm(), operator=(), and operator=().

Here is the call graph for this function:
Here is the caller graph for this function:

◆ dSeparationAlgorithm() [2/3]

INLINE gum::dSeparationAlgorithm::dSeparationAlgorithm ( const dSeparationAlgorithm & from)

copy constructor

Definition at line 59 of file dSeparationAlgorithm_inl.h.

59 {
60 GUM_CONS_CPY(dSeparationAlgorithm);
61 }

References dSeparationAlgorithm().

Here is the call graph for this function:

◆ dSeparationAlgorithm() [3/3]

INLINE gum::dSeparationAlgorithm::dSeparationAlgorithm ( dSeparationAlgorithm && from)

move constructor

Definition at line 64 of file dSeparationAlgorithm_inl.h.

64 {
65 GUM_CONS_MOV(dSeparationAlgorithm);
66 }

References dSeparationAlgorithm().

Here is the call graph for this function:

◆ ~dSeparationAlgorithm()

INLINE gum::dSeparationAlgorithm::~dSeparationAlgorithm ( )

destructor

Definition at line 69 of file dSeparationAlgorithm_inl.h.

69 {
70 GUM_DESTRUCTOR(dSeparationAlgorithm);
71 ;
72 }

References dSeparationAlgorithm().

Here is the call graph for this function:

Member Function Documentation

◆ operator=() [1/2]

INLINE dSeparationAlgorithm & gum::dSeparationAlgorithm::operator= ( const dSeparationAlgorithm & from)

copy operator

Definition at line 75 of file dSeparationAlgorithm_inl.h.

75 {
76 return *this;
77 }

References dSeparationAlgorithm().

Here is the call graph for this function:

◆ operator=() [2/2]

INLINE dSeparationAlgorithm & gum::dSeparationAlgorithm::operator= ( dSeparationAlgorithm && from)

move operator

Definition at line 80 of file dSeparationAlgorithm_inl.h.

80 {
81 return *this;
82 }

References dSeparationAlgorithm().

Here is the call graph for this function:

◆ relevantTensors()

template<typename GUM_SCALAR, class TABLE>
void gum::dSeparationAlgorithm::relevantTensors ( const IBayesNet< GUM_SCALAR > & bn,
const NodeSet & query,
const NodeSet & hardEvidence,
const NodeSet & softEvidence,
Set< const TABLE * > & tensors )

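Update the 'tensors' set in place, keeping only the tensors that are d-connected to the query variables given the evidence. A tensor is kept as soon as one of its variables is reached by the bouncing-ball traversal started from the query nodes. All other tensors are erased from 'tensors'.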

Definition at line 56 of file dSeparationAlgorithm_tpl.h.

60 {
61 const DAG& dag = bn.dag();
62
63 // mark the set of ancestors of the evidence
64 NodeSet ev_ancestors(dag.size());
65 {
66 List< NodeId > anc_to_visit;
67 for (const auto node: hardEvidence)
68 anc_to_visit.insert(node);
69 for (const auto node: softEvidence)
70 anc_to_visit.insert(node);
71 while (!anc_to_visit.empty()) {
72 const NodeId node = anc_to_visit.front();
73 anc_to_visit.popFront();
74
75 if (!ev_ancestors.exists(node)) {
76 ev_ancestors.insert(node);
77 for (const auto par: dag.parents(node)) {
78 anc_to_visit.insert(par);
79 }
80 }
81 }
82 }
83
84 // create the marks indicating that we have visited a node
85 NodeSet visited_from_child(dag.size());
86 NodeSet visited_from_parent(dag.size());
87
90 HashTable< NodeId, Set< const TABLE* > > node2tensors;
91 for (const auto pot: tensors) {
92 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
93 for (const auto var: vars) {
94 const NodeId id = bn.nodeId(*var);
95 if (!node2tensors.exists(id)) { node2tensors.insert(id, Set< const TABLE* >()); }
96 node2tensors[id].insert(pot);
97 }
98 }
99
100 // indicate that we will send the ball to all the query nodes (as children):
101 // in list nodes_to_visit, the first element is the next node to send the
102 // ball to and the Boolean indicates whether we shall reach it from one of
103 // its children (true) or from one parent (false)
104 List< std::pair< NodeId, bool > > nodes_to_visit;
105 for (const auto node: query) {
106 nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
107 }
108
109 // perform the bouncing ball until there is no node in the graph to send
110 // the ball to
111 while (!nodes_to_visit.empty() && !node2tensors.empty()) {
112 // get the next node to visit
113 const NodeId node = nodes_to_visit.front().first;
114 const bool direction = nodes_to_visit.front().second;
115 nodes_to_visit.popFront();
116
117 // check if the node has not already been visited in the same direction
118 bool already_visited;
119 if (direction) {
120 already_visited = visited_from_child.exists(node);
121 if (!already_visited) { visited_from_child.insert(node); }
122 } else {
123 already_visited = visited_from_parent.exists(node);
124 if (!already_visited) { visited_from_parent.insert(node); }
125 }
126
127 // if the node belongs to the query, update _node2tensors_: remove all
128 // the tensors containing the node
129 if (node2tensors.exists(node)) {
130 auto& pot_set = node2tensors[node];
131 for (const auto pot: pot_set) {
132 const auto& vars = pot->variablesSequence();
133 for (const auto var: vars) {
134 const NodeId id = bn.nodeId(*var);
135 if (id != node) {
136 node2tensors[id].erase(pot);
137 if (node2tensors[id].empty()) { node2tensors.erase(id); }
138 }
139 }
140 }
141 node2tensors.erase(node);
142
143 // if _node2tensors_ is empty, no need to go on: all the tensors
144 // are d-connected to the query
145 if (node2tensors.empty()) return;
146 }
147
148 // if this is the first time we meet the node, then visit it
149 if (!already_visited) {
150 // mark the node as reachable if this is not a hard evidence
151 const bool is_hard_evidence = hardEvidence.exists(node);
152
153 // bounce the ball toward the neighbors
154 if (direction && !is_hard_evidence) { // visit from a child
155 // visit the parents
156 for (const auto par: dag.parents(node)) {
157 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
158 }
159
160 // visit the children
161 for (const auto chi: dag.children(node)) {
162 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
163 }
164 } else { // visit from a parent
165 if (!hardEvidence.exists(node)) {
166 // visit the children
167 for (const auto chi: dag.children(node)) {
168 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
169 }
170 }
171 if (ev_ancestors.exists(node)) {
172 // visit the parents
173 for (const auto par: dag.parents(node)) {
174 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
175 }
176 }
177 }
178 }
179 }
180
181 // here, all the tensors that belong to _node2tensors_ are d-separated
182 // from the query
183 for (const auto& elt: node2tensors) {
184 for (const auto pot: elt.second) {
185 tensors.erase(pot);
186 }
187 }
188 }
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...

References gum::ArcGraphPart::children(), gum::DAGmodel::dag(), gum::HashTable< Key, Val >::empty(), gum::List< Val >::empty(), gum::HashTable< Key, Val >::erase(), gum::Set< Key >::erase(), gum::HashTable< Key, Val >::exists(), gum::Set< Key >::exists(), gum::List< Val >::front(), gum::HashTable< Key, Val >::insert(), gum::List< Val >::insert(), gum::Set< Key >::insert(), gum::IBayesNet< GUM_SCALAR >::nodeId(), gum::ArcGraphPart::parents(), gum::List< Val >::popFront(), and gum::NodeGraphPart::size().

Here is the call graph for this function:

◆ requisiteNodes()

void gum::dSeparationAlgorithm::requisiteNodes ( const DAG & dag,
const NodeSet & query,
const NodeSet & hardEvidence,
const NodeSet & softEvidence,
NodeSet & requisite ) const

Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.

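Requisite nodes are those that are d-connected to at least one of the query nodes given a set of hard and soft evidence. Any previous content of 'requisite' is cleared first, and hard-evidence nodes themselves are never included in it.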

Definition at line 60 of file dSeparationAlgorithm.cpp.

64 {
65 // for the moment, no node is requisite
66 requisite.clear();
67
68 // mark the set of ancestors of the evidence
69 NodeSet ev_ancestors(dag.size());
70 {
71 List< NodeId > anc_to_visit;
72 for (const auto node: hardEvidence)
73 anc_to_visit.insert(node);
74 for (const auto node: softEvidence)
75 anc_to_visit.insert(node);
76 while (!anc_to_visit.empty()) {
77 const NodeId node = anc_to_visit.front();
78 anc_to_visit.popFront();
79
80 if (!ev_ancestors.exists(node)) {
81 ev_ancestors.insert(node);
82 for (const auto par: dag.parents(node)) {
83 anc_to_visit.insert(par);
84 }
85 }
86 }
87 }
88
89 // create the marks indicating that we have visited a node
90 NodeSet visited_from_child(dag.size());
91 NodeSet visited_from_parent(dag.size());
92
93 // indicate that we will send the ball to all the query nodes (as children):
94 // in list nodes_to_visit, the first element is the next node to send the
95 // ball to and the Boolean indicates whether we shall reach it from one of
96 // its children (true) or from one parent (false)
97 List< std::pair< NodeId, bool > > nodes_to_visit;
98 for (const auto node: query) {
99 nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
100 }
101
102 // perform the bouncing ball until there is no node in the graph to send
103 // the ball to
104 while (!nodes_to_visit.empty()) {
105 // get the next node to visit
106 const NodeId node = nodes_to_visit.front().first;
107 const bool direction = nodes_to_visit.front().second;
108 nodes_to_visit.popFront();
109
110 // check if the node has not already been visited in the same direction
111 bool already_visited;
112 if (direction) {
113 already_visited = visited_from_child.exists(node);
114 if (!already_visited) { visited_from_child.insert(node); }
115 } else {
116 already_visited = visited_from_parent.exists(node);
117 if (!already_visited) { visited_from_parent.insert(node); }
118 }
119
120 // if this is the first time we meet the node, then visit it
121 if (!already_visited) {
122 // mark the node as reachable if this is not a hard evidence
123 const bool is_hard_evidence = hardEvidence.exists(node);
124 if (!is_hard_evidence) { requisite.insert(node); }
125
126 // bounce the ball toward the neighbors
127 if (direction && !is_hard_evidence) { // visit from a child
128 // visit the parents
129 for (const auto par: dag.parents(node)) {
130 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
131 }
132
133 // visit the children
134 for (const auto chi: dag.children(node)) {
135 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
136 }
137 } else { // visit from a parent
138 if (!hardEvidence.exists(node)) {
139 // visit the children
140 for (const auto chi: dag.children(node)) {
141 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
142 }
143 }
144 if (ev_ancestors.exists(node)) {
145 // visit the parents
146 for (const auto par: dag.parents(node)) {
147 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
148 }
149 }
150 }
151 }
152 }
153 }

References gum::ArcGraphPart::children(), gum::Set< Key >::clear(), gum::List< Val >::empty(), gum::Set< Key >::exists(), gum::List< Val >::front(), gum::List< Val >::insert(), gum::Set< Key >::insert(), gum::ArcGraphPart::parents(), gum::List< Val >::popFront(), and gum::NodeGraphPart::size().

Referenced by gum::SamplingInference< GUM_SCALAR >::contextualize().

Here is the call graph for this function:
Here is the caller graph for this function:

The documentation for this class was generated from the following files: