aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
gum::BarrenNodesFinder Class Reference

Detect barren nodes for inference in Bayesian networks.

#include <barrenNodesFinder.h>


Public Member Functions

Constructors / Destructors
 BarrenNodesFinder (const DAG *dag)
 constructor taking the DAG on which barren nodes are computed
 BarrenNodesFinder (const BarrenNodesFinder &from)
 copy constructor
 BarrenNodesFinder (BarrenNodesFinder &&from) noexcept
 move constructor
 ~BarrenNodesFinder ()
 destructor
Operators
BarrenNodesFinder & operator= (const BarrenNodesFinder &from)
 copy operator
BarrenNodesFinder & operator= (BarrenNodesFinder &&from)
 move operator
Accessors / Modifiers
void setDAG (const DAG *new_dag)
 sets a new DAG
void setEvidence (const NodeSet *observed_nodes)
 sets the observed nodes in the DAG
void setTargets (const NodeSet *target_nodes)
 sets the set of target nodes we are interested in
NodeSet barrenNodes ()
 returns the set of barren nodes
ArcProperty< NodeSet > barrenNodes (const CliqueGraph &junction_tree)
 returns the set of barren nodes in the messages sent in a junction tree
template<typename GUM_SCALAR>
ArcProperty< Set< const Tensor< GUM_SCALAR > * > > barrenTensors (const CliqueGraph &junction_tree, const IBayesNet< GUM_SCALAR > &bn)
 returns the set of barren tensors in messages sent in a junction tree

Private Attributes

const DAG * _dag_
 the DAG on which we compute the barren nodes
const NodeSet * _observed_nodes_
 the set of observed nodes
const NodeSet * _target_nodes_
 the set of targeted nodes

Detailed Description

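Detect barren nodes for inference in Bayesian networks. Given the observed nodes and the target nodes, a node is barren when it is neither observed, nor a target, nor an ancestor of an observed or target node; the CPTs of barren nodes can be discarded without changing the posterior distributions of the targets.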

Definition at line 65 of file barrenNodesFinder.h.
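Barren nodes can be pruned because summing a CPT P(X | parents) over the values of X gives 1: a barren leaf can be summed out for free, and repeating the argument from the bottom up removes every barren node. The standalone check below illustrates this on a hypothetical binary chain A → B → C with made-up numbers (it is not aGrUM code): P(A = 1 | B = b) is the same whether the CPT of the barren node C is summed out or simply dropped.

```cpp
#include <array>
#include <cassert>

// Hypothetical toy network A -> B -> C with binary variables. C is neither
// observed nor a target, so it is barren for the query P(A | B = b).
struct ToyChain {
  std::array<double, 2> pA;                       // P(A = a)
  std::array<std::array<double, 2>, 2> pBgivenA;  // P(B = b | A = a), indexed [a][b]
  std::array<std::array<double, 2>, 2> pCgivenB;  // P(C = c | B = b), indexed [b][c]
};

// P(A = 1 | B = b), summing over the barren node C explicitly.
double posteriorWithC(const ToyChain& n, int b) {
  assert(b == 0 || b == 1);
  std::array<double, 2> unnorm{0.0, 0.0};
  for (int a = 0; a < 2; ++a)
    for (int c = 0; c < 2; ++c)
      unnorm[a] += n.pA[a] * n.pBgivenA[a][b] * n.pCgivenB[b][c];
  return unnorm[1] / (unnorm[0] + unnorm[1]);
}

// Same query with C's CPT dropped: sum_c P(c | b) = 1, so nothing changes.
double posteriorWithoutC(const ToyChain& n, int b) {
  assert(b == 0 || b == 1);
  std::array<double, 2> unnorm{};
  for (int a = 0; a < 2; ++a) unnorm[a] = n.pA[a] * n.pBgivenA[a][b];
  return unnorm[1] / (unnorm[0] + unnorm[1]);
}
```

The two functions agree up to floating-point rounding for any valid CPTs.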

Constructor & Destructor Documentation

◆ BarrenNodesFinder() [1/3]

INLINE gum::BarrenNodesFinder::BarrenNodesFinder ( const DAG * dag)
explicit

constructor taking the DAG on which barren nodes are computed

Definition at line 46 of file barrenNodesFinder_inl.h.

46 :
47 _dag_(dag) { // for debugging purposes
48 GUM_CONSTRUCTOR(BarrenNodesFinder);
49 }

References BarrenNodesFinder(), and _dag_.

Referenced by BarrenNodesFinder(), BarrenNodesFinder(), BarrenNodesFinder(), ~BarrenNodesFinder(), operator=(), and operator=().


◆ BarrenNodesFinder() [2/3]

INLINE gum::BarrenNodesFinder::BarrenNodesFinder ( const BarrenNodesFinder & from)

copy constructor

Definition at line 52 of file barrenNodesFinder_inl.h.

52 :
53 _dag_(from._dag_), _observed_nodes_(from._observed_nodes_),
54 _target_nodes_(from._target_nodes_) { // for debugging purposes
55 GUM_CONS_CPY(BarrenNodesFinder);
56 }

References BarrenNodesFinder(), _dag_, _observed_nodes_, and _target_nodes_.


◆ BarrenNodesFinder() [3/3]

INLINE gum::BarrenNodesFinder::BarrenNodesFinder ( BarrenNodesFinder && from)
noexcept

move constructor

Definition at line 59 of file barrenNodesFinder_inl.h.

59 :
60 _dag_(from._dag_), _observed_nodes_(from._observed_nodes_),
61 _target_nodes_(from._target_nodes_) {
62 // for debugging purposes
63 GUM_CONS_MOV(BarrenNodesFinder);
64 }

References BarrenNodesFinder(), _dag_, _observed_nodes_, and _target_nodes_.


◆ ~BarrenNodesFinder()

INLINE gum::BarrenNodesFinder::~BarrenNodesFinder ( )

destructor

Definition at line 67 of file barrenNodesFinder_inl.h.

67 { // for debugging purposes
68 GUM_DESTRUCTOR(BarrenNodesFinder)
69 }

References BarrenNodesFinder().


Member Function Documentation

◆ barrenNodes() [1/2]

NodeSet gum::BarrenNodesFinder::barrenNodes ( )

returns the set of barren nodes

Definition at line 307 of file barrenNodesFinder.cpp.

307 {
308 // mark all the nodes in the dag as barren (true)
309 NodeProperty< bool > barren_mark = _dag_->nodesPropertyFromVal(true);
310
311 // mark all the ancestors of the evidence and targets as non-barren
312 List< NodeId > nodes_to_examine;
313 int nb_non_barren = 0;
314 for (const auto node: *_observed_nodes_)
315 nodes_to_examine.insert(node);
316 for (const auto node: *_target_nodes_)
317 nodes_to_examine.insert(node);
318
319 while (!nodes_to_examine.empty()) {
320 const NodeId node = nodes_to_examine.front();
321 nodes_to_examine.popFront();
322 if (barren_mark[node]) {
323 barren_mark[node] = false;
324 ++nb_non_barren;
325 for (const auto par: _dag_->parents(node))
326 nodes_to_examine.insert(par);
327 }
328 }
329
330 // here, all the nodes marked true are barren
331 NodeSet barren_nodes(_dag_->sizeNodes() - nb_non_barren);
332 for (const auto& marked_pair: barren_mark)
333 if (marked_pair.second) barren_nodes.insert(marked_pair.first);
334
335 return barren_nodes;
336 }
Types used in this listing: NodeId is an alias of Size (type for node ids); NodeProperty< VAL > is HashTable< NodeId, VAL > (property on graph elements); NodeSet is Set< NodeId >.

References _dag_, _observed_nodes_, _target_nodes_, gum::List< Val >::empty(), gum::List< Val >::front(), gum::List< Val >::insert(), gum::Set< Key >::insert(), and gum::List< Val >::popFront().

Referenced by barrenTensors(), and gum::SamplingInference< GUM_SCALAR >::contextualize().
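The listing above marks every node as barren and then clears the mark on the observed nodes, the targets, and all their ancestors. A standalone sketch of the same worklist search, using standard containers in place of aGrUM's DAG, List, and NodeSet (the names `ParentLists` and `barrenNodesSketch` are hypothetical), could look like this:

```cpp
#include <cassert>
#include <cstddef>
#include <deque>
#include <set>
#include <vector>

using NodeId = std::size_t;
// parents[n] lists the parents of node n; nodes are numbered 0..parents.size()-1.
using ParentLists = std::vector<std::vector<NodeId>>;

// Every node starts barren; a FIFO worklist seeded with the observed and
// target nodes climbs to their ancestors and clears the mark. Whatever is
// still marked at the end is barren.
std::set<NodeId> barrenNodesSketch(const ParentLists& parents,
                                   const std::set<NodeId>& observed,
                                   const std::set<NodeId>& targets) {
  std::vector<bool> barren(parents.size(), true);
  std::deque<NodeId> toExamine(observed.begin(), observed.end());
  toExamine.insert(toExamine.end(), targets.begin(), targets.end());

  while (!toExamine.empty()) {
    const NodeId node = toExamine.front();
    toExamine.pop_front();
    assert(node < parents.size());
    if (barren[node]) {  // first visit: node is a query node or one of its ancestors
      barren[node] = false;
      for (NodeId par : parents[node]) toExamine.push_back(par);
    }
  }

  std::set<NodeId> result;
  for (NodeId n = 0; n < parents.size(); ++n)
    if (barren[n]) result.insert(n);
  return result;
}
```

As in the listing, each node is expanded at most once, so the search is linear in the number of nodes and arcs.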


◆ barrenNodes() [2/2]

ArcProperty< NodeSet > gum::BarrenNodesFinder::barrenNodes ( const CliqueGraph & junction_tree)

returns the set of barren nodes in the messages sent in a junction tree

Definition at line 57 of file barrenNodesFinder.cpp.

57 {
58 // assign a mark to all the nodes
59 // and mark all the observed nodes and their ancestors as non-barren
60 NodeProperty< Size > mark(_dag_->size());
61 {
62 for (const auto node: *_dag_)
63 mark.insert(node, 0); // for the moment, 0 = possibly barren
64
65 // mark all the observed nodes and their ancestors as non barren
66 // std::numeric_limits<unsigned int>::max () will be necessarily non
67 // barren
68 // later on
69 Sequence< NodeId > observed_anc(_dag_->size());
70 const Size non_barren = std::numeric_limits< Size >::max();
71 for (const auto node: *_observed_nodes_)
72 observed_anc.insert(node);
73 for (Idx i = 0; i < observed_anc.size(); ++i) {
74 const NodeId node = observed_anc[i];
75 if (!mark[node]) {
76 mark[node] = non_barren;
77 for (const auto par: _dag_->parents(node)) {
78 if (!mark[par] && !observed_anc.exists(par)) { observed_anc.insert(par); }
79 }
80 }
81 }
82 }
83
84 // create the data structure that will contain the result of the
85 // method. By default, we assume that, for each pair of adjacent cliques,
86 // all
87 // the nodes that do not belong to their separator are possibly barren and,
88 // by sweeping the dag, we will remove the nodes that were determined
89 // above as non-barren. Structure result will assign to each (ordered) pair
90 // of adjacent cliques its set of barren nodes.
91 ArcProperty< NodeSet > result;
92 for (const auto& edge: junction_tree.edges()) {
93 const NodeSet& separator = junction_tree.separator(edge);
94
95 NodeSet non_barren1 = junction_tree.clique(edge.first());
96 for (auto iter = non_barren1.beginSafe(); iter != non_barren1.endSafe(); ++iter) {
97 if (mark[*iter] || separator.exists(*iter)) { non_barren1.erase(iter); }
98 }
99 result.insert(Arc(edge.first(), edge.second()), std::move(non_barren1));
100
101 NodeSet non_barren2 = junction_tree.clique(edge.second());
102 for (auto iter = non_barren2.beginSafe(); iter != non_barren2.endSafe(); ++iter) {
103 if (mark[*iter] || separator.exists(*iter)) { non_barren2.erase(iter); }
104 }
105 result.insert(Arc(edge.second(), edge.first()), std::move(non_barren2));
106 }
107
108 // for each node in the DAG, indicate which are the arcs in the result
109 // structure whose separator contain it: the separators are actually the
110 // targets of the queries.
111 NodeProperty< ArcSet > node2arc;
112 for (const auto node: *_dag_)
113 node2arc.insert(node, ArcSet());
114 for (const auto& elt: result) {
115 const Arc& arc = elt.first;
116 if (!result[arc].empty()) { // no need to further process cliques
117 const NodeSet& separator = // with no barren nodes
118 junction_tree.separator(Edge(arc.tail(), arc.head()));
119
120 for (const auto node: separator) {
121 node2arc[node].insert(arc);
122 }
123 }
124 }
125
126 // To determine the set of non-barren nodes w.r.t. a given single node
127 // query, we rely on the fact that those are precisely all the ancestors of
128 // this single node. To mutualize the computations, we will thus sweep the
129 // DAG from top to bottom and exploit the fact that the set of ancestors of
130 // the child of a given node A contain the ancestors of A. Therefore, we
131 // will
132 // determine sets of paths in the DAG and, for each path, compute the set of
133 // its barren nodes from the source to the destination of the path. The
134 // optimal set of paths, i.e., that which will minimize computations, is
135 // obtained by solving a "minimum path cover in directed acyclic graphs".
136 // But
137 // such an algorithm is too costly for the gain we can get from it, so we
138 // will
139 // rely on a simple heuristics.
140
141 // To compute the heuristics, we proceed as follows:
142 // 1/ we mark to 1 all the nodes that are ancestors of at least one (key)
143 // node
144 // with a non-empty arcset in node2arc and we extract from those the
145 // roots, i.e., those nodes whose set of parents, if any, have all been
146 // identified as non-barren by being marked as
147 // std::numeric_limits<unsigned int>::max (). Such nodes are
148 // thus the top of the graph to sweep.
149 // 2/ create a copy of the subgraph of the DAG w.r.t. the 1-marked nodes
150 // and, for each node, if the node has several parents and children,
151 // keep only one arc from one of the parents to the child with the
152 // smallest
153 // number of parents, and try to create a matching between parents and
154 // children and add one arc for each edge of this matching. This will
155 // allow
156 // us to create distinct paths in the DAG. Whenever a child has no more
157 // parents, it becomes the root of a new path.
158 // 3/ the sweeping will be performed from the roots of all these paths.
159
160 // perform step 1/
161 NodeSet path_roots;
162 {
163 List< NodeId > nodes_to_mark;
164 for (const auto& elt: node2arc) {
165 if (!elt.second.empty()) { // only process nodes with assigned arcs
166 nodes_to_mark.insert(elt.first);
167 }
168 }
169 while (!nodes_to_mark.empty()) {
170 NodeId node = nodes_to_mark.front();
171 nodes_to_mark.popFront();
172
173 if (!mark[node]) { // mark the node and all its ancestors
174 mark[node] = 1;
175 Size nb_par = 0;
176 for (auto par: _dag_->parents(node)) {
177 Size parent_mark = mark[par];
178 if (parent_mark != std::numeric_limits< Size >::max()) {
179 ++nb_par;
180 if (parent_mark == 0) { nodes_to_mark.insert(par); }
181 }
182 }
183
184 if (nb_par == 0) { path_roots.insert(node); }
185 }
186 }
187 }
188
189 // perform step 2/
190 DAG sweep_dag = *_dag_;
191 for (const auto node: *_dag_) { // keep only nodes marked with 1
192 if (mark[node] != 1) { sweep_dag.eraseNode(node); }
193 }
194 for (const auto node: sweep_dag) {
195 const Size nb_parents = sweep_dag.parents(node).size();
196 const Size nb_children = sweep_dag.children(node).size();
197 if ((nb_parents > 1) || (nb_children > 1)) {
198 // perform the matching
199 const auto& parents = sweep_dag.parents(node);
200
201 // if there is no child, remove all the parents except the first one
202 if (nb_children == 0) {
203 auto iter_par = parents.beginSafe();
204 for (++iter_par; iter_par != parents.endSafe(); ++iter_par) {
205 sweep_dag.eraseArc(Arc(*iter_par, node));
206 }
207 } else {
208 // find the child with the smallest number of parents
209 const auto& children = sweep_dag.children(node);
210 NodeId smallest_child = 0;
211 Size smallest_nb_par = std::numeric_limits< Size >::max();
212 for (const auto child: children) {
213 const auto new_nb = sweep_dag.parents(child).size();
214 if (new_nb < smallest_nb_par) {
215 smallest_nb_par = new_nb;
216 smallest_child = child;
217 }
218 }
219
220 // if there is no parent, just keep the link with the smallest child
221 // and remove all the other arcs
222 if (nb_parents == 0) {
223 for (auto iter = children.beginSafe(); iter != children.endSafe(); ++iter) {
224 if (*iter != smallest_child) {
225 if (sweep_dag.parents(*iter).size() == 1) { path_roots.insert(*iter); }
226 sweep_dag.eraseArc(Arc(node, *iter));
227 }
228 }
229 } else {
230 auto nb_match = Size(std::min(nb_parents, nb_children) - 1);
231 auto iter_par = parents.beginSafe();
232 ++iter_par; // skip the first parent, whose arc with node will
233 // remain
234 auto iter_child = children.beginSafe();
235 for (Idx i = 0; i < nb_match; ++i, ++iter_par, ++iter_child) {
236 if (*iter_child == smallest_child) { ++iter_child; }
237 sweep_dag.addArc(*iter_par, *iter_child);
238 sweep_dag.eraseArc(Arc(*iter_par, node));
239 sweep_dag.eraseArc(Arc(node, *iter_child));
240 }
241 for (; iter_par != parents.endSafe(); ++iter_par) {
242 sweep_dag.eraseArc(Arc(*iter_par, node));
243 }
244 for (; iter_child != children.endSafe(); ++iter_child) {
245 if (*iter_child != smallest_child) {
246 if (sweep_dag.parents(*iter_child).size() == 1) { path_roots.insert(*iter_child); }
247 sweep_dag.eraseArc(Arc(node, *iter_child));
248 }
249 }
250 }
251 }
252 }
253 }
254
255 // step 3: sweep the paths from the roots of sweep_dag
256 // here, the idea is that, for each path of sweep_dag, the mark we put
257 // to the ancestors is a given number, say N, that increases from path
258 // to path. Hence, for a given path, all the nodes that are marked with a
259 // number at least as high as N are non-barren, the others being barren.
260 Idx mark_id = 2;
261 for (NodeId path: path_roots) {
262 // perform the sweeping from the path
263 while (true) {
264 // mark all the ancestors of the node
265 List< NodeId > to_mark{path};
266 while (!to_mark.empty()) {
267 NodeId node = to_mark.front();
268 to_mark.popFront();
269 if (mark[node] < mark_id) {
270 mark[node] = mark_id;
271 for (const auto par: _dag_->parents(node)) {
272 if (mark[par] < mark_id) { to_mark.insert(par); }
273 }
274 }
275 }
276
277 // now, get all the arcs that contained node "path" in their separator.
278 // this node acts as a query target and, therefore, its ancestors
279 // shall be non-barren.
280 const ArcSet& arcs = node2arc[path];
281 for (const auto& arc: arcs) {
282 NodeSet& barren = result[arc];
283 for (auto iter = barren.beginSafe(); iter != barren.endSafe(); ++iter) {
284 if (mark[*iter] >= mark_id) {
285 // this indicates a non-barren node
286 barren.erase(iter);
287 }
288 }
289 }
290
291 // go to the next sweeping node
292 const NodeSet& sweep_children = sweep_dag.children(path);
293 if (sweep_children.size()) {
294 path = *(sweep_children.begin());
295 } else {
296 // here, the path has ended, so we shall go to the next path
297 ++mark_id;
298 break;
299 }
300 }
301 }
302
303 return result;
304 }
Types and members used in this listing: HashTable::insert(const Key &key, const Val &val) adds a copy of the element to the hash table; Set::insert(const Key &k) inserts a new element into the set (set_tpl.h, line 539). Size is std::size_t (types.h, line 74); Idx is Size, the type for indexes (types.h, line 79). ArcProperty< VAL > is HashTable< Arc, VAL > (property on graph elements); ArcSet is Set< Arc >.

References _dag_, _observed_nodes_, gum::DAG::addArc(), gum::Set< Key >::begin(), gum::Set< Key >::beginSafe(), gum::ArcGraphPart::children(), gum::CliqueGraph::clique(), gum::EdgeGraphPart::edges(), gum::List< Val >::empty(), gum::Set< Key >::endSafe(), gum::Set< Key >::erase(), gum::ArcGraphPart::eraseArc(), gum::DiGraph::eraseNode(), gum::SequenceImplementation< Key, std::is_scalar< Key >::value >::exists(), gum::Set< Key >::exists(), gum::Arc::first(), gum::List< Val >::front(), gum::Arc::head(), gum::HashTable< Key, Val >::insert(), gum::List< Val >::insert(), gum::SequenceImplementation< Key, std::is_scalar< Key >::value >::insert(), gum::Set< Key >::insert(), gum::ArcGraphPart::parents(), gum::List< Val >::popFront(), gum::CliqueGraph::separator(), gum::SequenceImplementation< Key, std::is_scalar< Key >::value >::size(), gum::Set< Key >::size(), and gum::Arc::tail().
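As we read the listing, the set computed for the message from clique Ci to an adjacent clique Cj consists of the nodes of Ci that are not ancestors of (nor equal to) an observed node or a node of their separator, since the separator nodes act as the query targets. The naive reference below, a hypothetical helper rather than aGrUM code, recomputes one ancestor search per junction-tree edge. This is the redundant work the path-sweeping heuristic is designed to share, and such an oracle can be handy when testing an optimized version.

```cpp
#include <cassert>
#include <cstddef>
#include <initializer_list>
#include <map>
#include <set>
#include <utility>
#include <vector>

using NodeId = std::size_t;
using CliqueId = std::size_t;
using NodeSet = std::set<NodeId>;
using ParentLists = std::vector<std::vector<NodeId>>;  // parents[n] = parents of node n
using DirectedPair = std::pair<CliqueId, CliqueId>;   // (sending clique, receiving clique)

// All ancestors of the seed nodes, the seeds themselves included.
NodeSet ancestorsOf(const ParentLists& parents, const NodeSet& seeds) {
  NodeSet seen;
  std::vector<NodeId> stack(seeds.begin(), seeds.end());
  while (!stack.empty()) {
    const NodeId n = stack.back();
    stack.pop_back();
    assert(n < parents.size());
    if (seen.insert(n).second)
      for (NodeId par : parents[n]) stack.push_back(par);
  }
  return seen;
}

// Naive reference: for the message from clique i to clique j, the barren
// nodes are the nodes of clique i that are not ancestors of (nor equal to)
// an observed node or a node of the separator between i and j.
std::map<DirectedPair, NodeSet> barrenPerMessage(
    const ParentLists& parents, const std::vector<NodeSet>& cliques,
    const std::vector<DirectedPair>& edges, const NodeSet& observed) {
  std::map<DirectedPair, NodeSet> result;
  for (const auto& [i, j] : edges) {
    NodeSet seeds = observed;
    for (NodeId n : cliques[i])
      if (cliques[j].count(n)) seeds.insert(n);  // separator nodes act as targets
    const NodeSet nonBarren = ancestorsOf(parents, seeds);

    // The separator is the same in both directions, so both messages share nonBarren.
    for (const DirectedPair dir : {DirectedPair{i, j}, DirectedPair{j, i}}) {
      NodeSet barren;
      for (NodeId n : cliques[dir.first])
        if (!nonBarren.count(n)) barren.insert(n);
      result[dir] = std::move(barren);
    }
  }
  return result;
}
```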

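Step 3 of the listing avoids clearing the mark table between paths by giving each path a larger stamp (mark_id): a node counts as non-barren for the current path exactly when its mark is at least the current stamp, so marks left by earlier paths (smaller stamps) behave as unvisited, while evidence ancestors, pre-marked with the maximum value, always stay non-barren. As we read it, each node of a path is a descendant of the previous one, so the marked set grows along the path and earlier work is reused. The sketch below isolates this stamping idea with standard containers; `sweepPath` and `kAlwaysNonBarren` are hypothetical names.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <set>
#include <utility>
#include <vector>

using NodeId = std::size_t;
using ParentLists = std::vector<std::vector<NodeId>>;  // parents[n] = parents of node n

// Mark reserved for ancestors of the evidence: always non-barren.
constexpr std::size_t kAlwaysNonBarren = std::numeric_limits<std::size_t>::max();

// Sweeps one path with a given stamp. `mark` is shared by all paths and never
// reset: a node is non-barren for this path iff mark[node] >= stamp. `path` is
// assumed to list nodes so that each one is a descendant of the previous one;
// the marked set after step k is then the ancestors (inclusive) of path[k]
// plus the evidence ancestors. The per-step sets are collected here only for
// inspection; the listing instead filters the per-arc barren sets directly.
std::vector<std::set<NodeId>> sweepPath(const ParentLists& parents,
                                        std::vector<std::size_t>& mark,
                                        const std::vector<NodeId>& path,
                                        std::size_t stamp) {
  assert(mark.size() == parents.size());
  std::vector<std::set<NodeId>> nonBarrenPerStep;
  for (NodeId start : path) {
    std::vector<NodeId> toMark{start};
    while (!toMark.empty()) {
      const NodeId n = toMark.back();
      toMark.pop_back();
      if (mark[n] < stamp) {  // not yet reached with this stamp
        mark[n] = stamp;
        for (NodeId par : parents[n])
          if (mark[par] < stamp) toMark.push_back(par);
      }
    }
    std::set<NodeId> nonBarren;
    for (NodeId n = 0; n < mark.size(); ++n)
      if (mark[n] >= stamp) nonBarren.insert(n);
    nonBarrenPerStep.push_back(std::move(nonBarren));
  }
  return nonBarrenPerStep;
}
```

The listing starts its stamps at 2 because marks 0 and 1 are already used by the earlier steps.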

◆ barrenTensors()

template<typename GUM_SCALAR>
ArcProperty< Set< const Tensor< GUM_SCALAR > * > > gum::BarrenNodesFinder::barrenTensors ( const CliqueGraph & junction_tree,
const IBayesNet< GUM_SCALAR > & bn )

returns the set of barren tensors in messages sent in a junction tree

Definition at line 47 of file barrenNodesFinder_tpl.h.

48 {
49 // get the barren nodes
50 ArcProperty< NodeSet > barren_nodes = this->barrenNodes(junction_tree);
51
52 // transform the node sets into sets of tensors
53 ArcProperty< Set< const Tensor< GUM_SCALAR >* > > result;
54 for (const auto& barren: barren_nodes) {
55 Set< const Tensor< GUM_SCALAR >* > tensors;
56 for (const auto node: barren.second) {
57 tensors.insert(&(bn.cpt(node)));
58 }
59 result.insert(Arc(barren.first), std::move(tensors));
60 }
61
62 return result;
63 }

References barrenNodes(), gum::IBayesNet< GUM_SCALAR >::cpt(), gum::HashTable< Key, Val >::insert(), and gum::Set< Key >::insert().
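barrenTensors() does no graph work of its own: it calls barrenNodes(junction_tree) and replaces every barren node by the address of its CPT in the Bayesian network. The sketch below reproduces that transformation with a hypothetical `Cpt` struct standing in for `Tensor< GUM_SCALAR >` and a vector of CPTs indexed by node id standing in for `IBayesNet::cpt()`.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <set>
#include <utility>
#include <vector>

using NodeId = std::size_t;
using DirectedPair = std::pair<std::size_t, std::size_t>;  // (sending clique, receiving clique)

// Hypothetical stand-in for gum::Tensor< GUM_SCALAR >: one CPT per node.
template <typename Scalar>
struct Cpt {
  std::vector<Scalar> values;
};

// Replaces each per-message set of barren nodes by the set of addresses of
// the corresponding CPTs. The pointers are non-owning.
template <typename Scalar>
std::map<DirectedPair, std::set<const Cpt<Scalar>*>> barrenCpts(
    const std::map<DirectedPair, std::set<NodeId>>& barrenNodes,
    const std::vector<Cpt<Scalar>>& cpts) {
  std::map<DirectedPair, std::set<const Cpt<Scalar>*>> result;
  for (const auto& [arc, nodes] : barrenNodes) {
    std::set<const Cpt<Scalar>*> tensors;
    for (NodeId node : nodes) {
      assert(node < cpts.size());
      tensors.insert(&cpts[node]);
    }
    result.emplace(arc, std::move(tensors));
  }
  return result;
}
```

Because the result stores non-owning pointers, it stays valid only while the CPTs it points to are alive, which matches the `const Tensor< GUM_SCALAR >*` element type above.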


◆ operator=() [1/2]

INLINE BarrenNodesFinder & gum::BarrenNodesFinder::operator= ( BarrenNodesFinder && from)

move operator

Definition at line 82 of file barrenNodesFinder_inl.h.

82 {
83 if (this != &from) {
84 _dag_ = from._dag_;
85 _observed_nodes_ = from._observed_nodes_;
86 _target_nodes_ = from._target_nodes_;
87 }
88 return *this;
89 }

References BarrenNodesFinder(), _dag_, _observed_nodes_, and _target_nodes_.


◆ operator=() [2/2]

INLINE BarrenNodesFinder & gum::BarrenNodesFinder::operator= ( const BarrenNodesFinder & from)

copy operator

Definition at line 72 of file barrenNodesFinder_inl.h.

72 {
73 if (this != &from) {
74 _dag_ = from._dag_;
75 _observed_nodes_ = from._observed_nodes_;
76 _target_nodes_ = from._target_nodes_;
77 }
78 return *this;
79 }

References BarrenNodesFinder(), _dag_, _observed_nodes_, and _target_nodes_.
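Both the copy and the move operations above only copy the three `const` pointers, so a moved-from finder still refers to the same DAG and node sets, and two copies share them. The sketch below reproduces these semantics with defaulted special members on a hypothetical `PointerHoldingFinder` class, with `Graph` standing in for `DAG`; it is an illustration, not the aGrUM class.

```cpp
#include <cassert>
#include <set>
#include <utility>

// Hypothetical stand-ins for gum::DAG and gum::NodeSet.
struct Graph {};
using NodeSet = std::set<unsigned>;

// Like BarrenNodesFinder, this class holds three non-owning const pointers;
// defaulted copy and move both just copy them.
class PointerHoldingFinder {
 public:
  explicit PointerHoldingFinder(const Graph* dag) : dag_(dag) {}
  PointerHoldingFinder(const PointerHoldingFinder&) = default;
  PointerHoldingFinder(PointerHoldingFinder&&) noexcept = default;
  PointerHoldingFinder& operator=(const PointerHoldingFinder&) = default;
  PointerHoldingFinder& operator=(PointerHoldingFinder&&) noexcept = default;

  void setEvidence(const NodeSet* observed) { observed_ = observed; }
  void setTargets(const NodeSet* targets) { targets_ = targets; }
  const Graph* dag() const { return dag_; }
  const NodeSet* evidence() const { return observed_; }
  const NodeSet* targets() const { return targets_; }

 private:
  const Graph* dag_ = nullptr;         // not owned
  const NodeSet* observed_ = nullptr;  // not owned (nullptr default is an assumption)
  const NodeSet* targets_ = nullptr;   // not owned (nullptr default is an assumption)
};

// Moving has nothing to steal: the moved-from object keeps the same pointers.
inline void demoShallowMove() {
  Graph g;
  NodeSet evidence{1, 2};
  PointerHoldingFinder a(&g);
  a.setEvidence(&evidence);
  PointerHoldingFinder b(std::move(a));
  assert(b.dag() == &g && b.evidence() == &evidence);
  assert(a.dag() == &g && a.evidence() == &evidence);
}
```

Defaulted members would behave like the hand-written ones shown above; per their comments, the explicit definitions in aGrUM exist for debugging (the GUM_CONS_CPY and GUM_CONS_MOV macros).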


◆ setDAG()

INLINE void gum::BarrenNodesFinder::setDAG ( const DAG * new_dag)

sets a new DAG

Definition at line 92 of file barrenNodesFinder_inl.h.

92 { _dag_ = new_dag; }

References _dag_.

◆ setEvidence()

INLINE void gum::BarrenNodesFinder::setEvidence ( const NodeSet * observed_nodes)

sets the observed nodes in the DAG

Definition at line 95 of file barrenNodesFinder_inl.h.

95 {
96 _observed_nodes_ = observed_nodes;
97 }

References _observed_nodes_.

Referenced by gum::SamplingInference< GUM_SCALAR >::contextualize().
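setEvidence() (and likewise setTargets()) stores the pointer it receives rather than a copy of the set. The caller must therefore keep the NodeSet alive while the finder uses it, and later modifications of that set are seen by subsequent computations without calling the setter again. The hypothetical `EvidenceView` class below shows just this aliasing behavior.

```cpp
#include <cassert>
#include <cstddef>
#include <set>

using NodeId = std::size_t;
using NodeSet = std::set<NodeId>;

// Hypothetical class mimicking how setEvidence() keeps the caller's set by
// pointer: nothing is copied, so the caller's set must outlive every use.
class EvidenceView {
 public:
  void setEvidence(const NodeSet* observed) { observed_ = observed; }

  // Reads the caller's set through the stored pointer at call time.
  std::size_t nbObserved() const {
    assert(observed_ != nullptr && "setEvidence() must be called first");
    return observed_->size();
  }

 private:
  const NodeSet* observed_ = nullptr;  // not owned
};
```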


◆ setTargets()

INLINE void gum::BarrenNodesFinder::setTargets ( const NodeSet * target_nodes)

sets the set of target nodes we are interested in

Definition at line 100 of file barrenNodesFinder_inl.h.

100 {
101 _target_nodes_ = target_nodes;
102 }

References _target_nodes_.

Referenced by gum::SamplingInference< GUM_SCALAR >::contextualize().


Member Data Documentation

◆ _dag_

const DAG* gum::BarrenNodesFinder::_dag_
private

the DAG on which we compute the barren nodes

Definition at line 130 of file barrenNodesFinder.h.

Referenced by BarrenNodesFinder(), BarrenNodesFinder(), BarrenNodesFinder(), barrenNodes(), barrenNodes(), operator=(), operator=(), and setDAG().

◆ _observed_nodes_

const NodeSet* gum::BarrenNodesFinder::_observed_nodes_
private

the set of observed nodes

Definition at line 133 of file barrenNodesFinder.h.

Referenced by BarrenNodesFinder(), BarrenNodesFinder(), barrenNodes(), barrenNodes(), operator=(), operator=(), and setEvidence().

◆ _target_nodes_

const NodeSet* gum::BarrenNodesFinder::_target_nodes_
private

the set of targeted nodes

Definition at line 136 of file barrenNodesFinder.h.

Referenced by BarrenNodesFinder(), BarrenNodesFinder(), barrenNodes(), operator=(), operator=(), and setTargets().


The documentation for this class was generated from the following files: