aGrUM 2.3.2
a C++ library for (probabilistic) graphical models

Implementation of Shachter's Bayes-Ball algorithm. More...

#include <agrum/BN/inference/BayesBall.h>

Static Public Member Functions

Accessors / Modifiers
static void requisiteNodes (const DAG &dag, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, NodeSet &requisite)
 Fills the 'requisite' node set with the nodes of dag that are requisite for the given query and evidence.
template<typename GUM_SCALAR, class TABLE>
static void relevantTensors (const IBayesNet< GUM_SCALAR > &bn, const NodeSet &query, const NodeSet &hardEvidence, const NodeSet &softEvidence, Set< const TABLE * > &tensors)
 Updates a set of tensors, keeping only those d-connected to the query variables given the evidence.

Private Member Functions

Constructors / Destructors
 BayesBall ()
 Default constructor.
 ~BayesBall ()
 Destructor.

Detailed Description

Implementation of Shachter's Bayes-Ball algorithm.

Definition at line 67 of file BayesBall.h.

Constructor & Destructor Documentation

◆ BayesBall()

INLINE gum::BayesBall::BayesBall ( )
private

Default constructor.

Definition at line 54 of file BayesBall_inl.h.

54 { GUM_CONSTRUCTOR(BayesBall) }
BayesBall()
Default constructor.

References BayesBall().

Referenced by BayesBall(), and ~BayesBall().

Here is the call graph for this function:
Here is the caller graph for this function:

◆ ~BayesBall()

INLINE gum::BayesBall::~BayesBall ( )
private

Destructor.

Definition at line 56 of file BayesBall_inl.h.

56 { GUM_DESTRUCTOR(BayesBall) }

References BayesBall().

Here is the call graph for this function:

Member Function Documentation

◆ relevantTensors()

template<typename GUM_SCALAR, class TABLE>
void gum::BayesBall::relevantTensors ( const IBayesNet< GUM_SCALAR > & bn,
const NodeSet & query,
const NodeSet & hardEvidence,
const NodeSet & softEvidence,
Set< const TABLE * > & tensors )
static

Updates a set of tensors, keeping only those d-connected to the query variables given the evidence.

Definition at line 53 of file BayesBall_tpl.h.

57 {
58 const DAG& dag = bn.dag();
59
60 // create the marks (top = first and bottom = second)
61 NodeProperty< std::pair< bool, bool > > marks(dag.size());
62 const std::pair< bool, bool > empty_mark(false, false);
63
66 HashTable< NodeId, Set< const TABLE* > > node2tensors;
67 for (const auto pot: tensors) {
68 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
69 for (const auto var: vars) {
70 const NodeId id = bn.nodeId(*var);
71 if (!node2tensors.exists(id)) { node2tensors.insert(id, Set< const TABLE* >()); }
72 node2tensors[id].insert(pot);
73 }
74 }
75
76 // send the ball to all the query nodes as if it came from one of their
77 // children: in list nodes_to_visit, the first element is the next node to
78 // send the ball to and the Boolean indicates whether we shall reach it from
79 // one of its children (true) or from one of its parents (false)
80 List< std::pair< NodeId, bool > > nodes_to_visit;
81 for (const auto node: query) {
82 nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
83 }
84
85 // bounce the ball until _node2tensors_ becomes empty (which means that
86 // all the tensors have been reached and are, therefore, d-connected to
87 // the query) or until there is no node left in the graph to send the
88 // ball to
89 while (!nodes_to_visit.empty() && !node2tensors.empty()) {
90 // get the next node to visit
91 NodeId node = nodes_to_visit.front().first;
92
93 // if the marks of the node do not exist, create them
94 if (!marks.exists(node)) marks.insert(node, empty_mark);
95
96 // the ball has reached this node: the tensors containing it are
97 // d-connected to the query, so remove them from _node2tensors_
98 if (node2tensors.exists(node)) {
99 auto& pot_set = node2tensors[node];
100 for (const auto pot: pot_set) {
101 const auto& vars = pot->variablesSequence();
102 for (const auto var: vars) {
103 const NodeId id = bn.nodeId(*var);
104 if (id != node) {
105 node2tensors[id].erase(pot);
106 if (node2tensors[id].empty()) { node2tensors.erase(id); }
107 }
108 }
109 }
110 node2tensors.erase(node);
111
112 // if _node2tensors_ is empty, no need to go on: all the tensors
113 // are d-connected to the query
114 if (node2tensors.empty()) return;
115 }
116
117
118 // bounce the ball toward the neighbors
119 if (nodes_to_visit.front().second) { // visit from a child
120 nodes_to_visit.popFront();
121
122 if (hardEvidence.exists(node)) { continue; }
123
124 if (!marks[node].first) {
125 marks[node].first = true; // top marked
126 for (const auto par: dag.parents(node)) {
127 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
128 }
129 }
130
131 if (!marks[node].second) {
132 marks[node].second = true; // bottom marked
133 for (const auto chi: dag.children(node)) {
134 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
135 }
136 }
137 } else { // visit from a parent
138 nodes_to_visit.popFront();
139
140 const bool is_hard_evidence = hardEvidence.exists(node);
141 const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
142
143 if (is_evidence && !marks[node].first) {
144 marks[node].first = true;
145
146 for (const auto par: dag.parents(node)) {
147 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
148 }
149 }
150
151 if (!is_hard_evidence && !marks[node].second) {
152 marks[node].second = true;
153
154 for (const auto chi: dag.children(node)) {
155 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
156 }
157 }
158 }
159 }
160
161
162 // here, all the tensors that belong to _node2tensors_ are d-separated
163 // from the query
164 for (const auto& elt: node2tensors) {
165 for (const auto pot: elt.second) {
166 tensors.erase(pot);
167 }
168 }
169 }
Size NodeId
Type for node ids.
HashTable< NodeId, VAL > NodeProperty
Property on graph elements.

References gum::ArcGraphPart::children(), gum::DAGmodel::dag(), gum::HashTable< Key, Val >::empty(), gum::List< Val >::empty(), gum::HashTable< Key, Val >::erase(), gum::Set< Key >::erase(), gum::HashTable< Key, Val >::exists(), gum::Set< Key >::exists(), gum::List< Val >::front(), gum::HashTable< Key, Val >::insert(), gum::List< Val >::insert(), gum::IBayesNet< GUM_SCALAR >::nodeId(), gum::ArcGraphPart::parents(), gum::List< Val >::popFront(), and gum::NodeGraphPart::size().

Here is the call graph for this function:

◆ requisiteNodes()

void gum::BayesBall::requisiteNodes ( const DAG & dag,
const NodeSet & query,
const NodeSet & hardEvidence,
const NodeSet & softEvidence,
NodeSet & requisite )
static

Fill the 'requisite' nodeset with the requisite nodes in dag given a query and evidence.

Requisite nodes are those that are d-connected to at least one of the query nodes given a set of hard and soft evidence.

Definition at line 55 of file BayesBall.cpp.

59 {
60 // for the moment, no node is requisite
61 requisite.clear();
62
63 // create the marks (top = first and bottom = second )
64 NodeProperty< std::pair< bool, bool > > marks(dag.size());
65 const std::pair< bool, bool > empty_mark(false, false);
66
67 // send the ball to all the query nodes as if it came from one of their
68 // children: in list nodes_to_visit, the first element is the next node to
69 // send the ball to and the Boolean indicates whether we shall reach it from
70 // one of its children (true) or from one of its parents (false)
71 List< std::pair< NodeId, bool > > nodes_to_visit;
72 for (const auto node: query) {
73 nodes_to_visit.insert(std::pair< NodeId, bool >(node, true));
74 }
75
76 // bounce the ball until there is no node left in the graph to send
77 // the ball to
78 while (!nodes_to_visit.empty()) {
79 // get the next node to visit
80 NodeId node = nodes_to_visit.front().first;
81
82 // if the marks of the node do not exist, create them
83 if (!marks.exists(node)) marks.insert(node, empty_mark);
84
85 // bounce the ball toward the neighbors
86 if (nodes_to_visit.front().second) { // visit from a child
87 nodes_to_visit.popFront();
88 requisite.insert(node);
89
90 if (hardEvidence.exists(node)) { continue; }
91
92 if (!marks[node].first) {
93 marks[node].first = true; // top marked
94 for (const auto par: dag.parents(node)) {
95 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
96 }
97 }
98
99 if (!marks[node].second) {
100 marks[node].second = true; // bottom marked
101 for (const auto chi: dag.children(node)) {
102 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
103 }
104 }
105 } else { // visit from a parent
106 nodes_to_visit.popFront();
107
108 const bool is_hard_evidence = hardEvidence.exists(node);
109 const bool is_evidence = is_hard_evidence || softEvidence.exists(node);
110
111 if (is_evidence && !marks[node].first) {
112 marks[node].first = true;
113 requisite.insert(node);
114
115 for (const auto par: dag.parents(node)) {
116 nodes_to_visit.insert(std::pair< NodeId, bool >(par, true));
117 }
118 }
119
120 if (!is_hard_evidence && !marks[node].second) {
121 marks[node].second = true;
122
123 for (const auto chi: dag.children(node)) {
124 nodes_to_visit.insert(std::pair< NodeId, bool >(chi, false));
125 }
126 }
127 }
128 }
129 }

References gum::ArcGraphPart::children(), gum::Set< Key >::clear(), gum::List< Val >::empty(), gum::HashTable< Key, Val >::exists(), gum::Set< Key >::exists(), gum::List< Val >::front(), gum::HashTable< Key, Val >::insert(), gum::List< Val >::insert(), gum::Set< Key >::insert(), gum::ArcGraphPart::parents(), gum::List< Val >::popFront(), and gum::NodeGraphPart::size().

Here is the call graph for this function:

The documentation for this class was generated from the following files: