aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
BayesNetFragment_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50#include <agrum/BN/BayesNet.h>
52
53namespace gum {
54 template < typename GUM_SCALAR >
59
60 template < typename GUM_SCALAR >
62 GUM_DESTRUCTOR(BayesNetFragment)
63
64 for (auto node: nodes())
65 if (_localCPTs_.exists(node)) uninstallCPT_(node);
66 }
67
68 //============================================================
69 // signals to keep consistency with the referred BayesNet
70 template < typename GUM_SCALAR >
71 INLINE void BayesNetFragment< GUM_SCALAR >::whenNodeAdded(const void* src, NodeId id) {
72 // nothing to do
73 }
74
75 template < typename GUM_SCALAR >
77 uninstallNode(id);
78 }
79
80 template < typename GUM_SCALAR >
81 INLINE void
83 // nothing to do
84 }
85
86 template < typename GUM_SCALAR >
87 INLINE void
89 if (dag().existsArc(from, to)) uninstallArc_(from, to);
90 }
91
92 //============================================================
93 // IBayesNet interface : BayesNetFragment here is a decorator for the bn
94
95 template < typename GUM_SCALAR >
96 INLINE const Tensor< GUM_SCALAR >& BayesNetFragment< GUM_SCALAR >::cpt(NodeId id) const {
97 if (!isInstalledNode(id)) GUM_ERROR(NotFound, "NodeId " << id << " is not installed")
98
99 if (_localCPTs_.exists(id)) return *_localCPTs_[id];
100 else return _bn_.cpt(id);
101 }
102
103 template < typename GUM_SCALAR >
105 return this->_bn_.variableNodeMap();
106 }
107
108 template < typename GUM_SCALAR >
110 if (!isInstalledNode(id)) GUM_ERROR(NotFound, "NodeId " << id << " is not installed")
111
112 return _bn_.variable(id);
113 }
114
115 template < typename GUM_SCALAR >
117 NodeId id = _bn_.nodeId(var);
118
119 if (!isInstalledNode(id)) GUM_ERROR(NotFound, "variable " << var.name() << " is not installed")
120
121 return id;
122 }
123
124 template < typename GUM_SCALAR >
125 INLINE NodeId BayesNetFragment< GUM_SCALAR >::idFromName(const std::string& name) const {
126 NodeId id = _bn_.idFromName(name);
127
128 if (!isInstalledNode(id)) GUM_ERROR(NotFound, "variable " << name << " is not installed")
129
130 return id;
131 }
132
133 template < typename GUM_SCALAR >
134 INLINE const DiscreteVariable&
136 NodeId id = idFromName(name);
137
138 if (!isInstalledNode(id)) GUM_ERROR(NotFound, "variable " << name << " is not installed")
139
140 return _bn_.variable(id);
141 }
142
143 //============================================================
144 // specific API for BayesNetFragment
145 template < typename GUM_SCALAR >
147 return dag().existsNode(id);
148 }
149
150 template < typename GUM_SCALAR >
152 if (!_bn_.dag().existsNode(id))
153 GUM_ERROR(NotFound, "Node " << id << " does not exist in referred BayesNet")
154
155 if (!isInstalledNode(id)) {
156 this->dag_.addNodeWithId(id);
157
158 // adding arcs with id as a head (parent -> id)
159 for (auto pa: this->_bn_.parents(id)) {
160 if (isInstalledNode(pa)) this->dag_.addArc(pa, id);
161 }
162
163 // adding arcs with id as a tail (id -> child)
164 for (auto son: this->_bn_.children(id))
165 if (isInstalledNode(son)) this->dag_.addArc(id, son);
166 }
167 }
168
169 template < typename GUM_SCALAR >
171 installNode(id);
172
173 // bn is a dag => this will have an end ...
174 for (auto pa: this->_bn_.parents(id))
176 }
177
178 template < typename GUM_SCALAR >
180 if (isInstalledNode(id)) {
181 uninstallCPT(id);
182 this->dag_.eraseNode(id);
183 }
184 }
185
186 template < typename GUM_SCALAR >
188 this->dag_.eraseArc(Arc(from, to));
189 }
190
191 template < typename GUM_SCALAR >
193 this->dag_.addArc(from, to);
194 }
195
196 template < typename GUM_SCALAR >
197 void BayesNetFragment< GUM_SCALAR >::installCPT_(NodeId id, const Tensor< GUM_SCALAR >& pot) {
198 // topology
199 const auto& parents = this->parents(id);
200 for (auto node_it = parents.beginSafe(); node_it != parents.endSafe();
201 ++node_it) // safe iterator needed here
202 uninstallArc_(*node_it, id);
203
204 for (Idx i = 1; i < pot.nbrDim(); i++) {
205 NodeId parent = _bn_.idFromName(pot.variable(i).name());
206
207 if (isInstalledNode(parent)) installArc_(parent, id);
208 }
209
210 // local cpt
211 if (_localCPTs_.exists(id)) uninstallCPT_(id);
212
213 _localCPTs_.insert(id, new gum::Tensor< GUM_SCALAR >(pot));
214 }
215
216 template < typename GUM_SCALAR >
217 void BayesNetFragment< GUM_SCALAR >::installCPT(NodeId id, const Tensor< GUM_SCALAR >& pot) {
218 if (!dag().existsNode(id))
219 GUM_ERROR(NotFound, "Node " << id << " is not installed in the fragment")
220
221 if (&(pot.variable(0)) != &(variable(id))) {
223 "The tensor is not a marginal for _bn_.variable <" << variable(id).name() << ">")
224 }
225
226 const NodeSet& parents = _bn_.parents(id);
227
228 for (Idx i = 1; i < pot.nbrDim(); i++) {
229 if (!parents.contains(_bn_.idFromName(pot.variable(i).name())))
231 "Variable <" << pot.variable(i).name() << "> is not in the parents of node "
232 << id)
233 }
234
235 installCPT_(id, pot);
236 }
237
238 template < typename GUM_SCALAR >
240 delete _localCPTs_[id];
241 _localCPTs_.erase(id);
242 }
243
244 template < typename GUM_SCALAR >
246 if (_localCPTs_.exists(id)) {
247 uninstallCPT_(id);
248
249 // re-create arcs from referred tensor
250 const Tensor< GUM_SCALAR >& pot = cpt(id);
251
252 for (Idx i = 1; i < pot.nbrDim(); i++) {
253 NodeId parent = _bn_.idFromName(pot.variable(i).name());
254
255 if (isInstalledNode(parent)) installArc_(parent, id);
256 }
257 }
258 }
259
260 template < typename GUM_SCALAR >
261 void BayesNetFragment< GUM_SCALAR >::installMarginal(NodeId id, const Tensor< GUM_SCALAR >& pot) {
262 if (!isInstalledNode(id)) {
263 GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment")
264 }
265
266 if (pot.nbrDim() > 1) {
267 GUM_ERROR(OperationNotAllowed, "The tensor is not a marginal :" << pot)
268 }
269
270 if (&(pot.variable(0)) != &(_bn_.variable(id))) {
272 "The tensor is not a marginal for _bn_.variable <" << _bn_.variable(id).name()
273 << ">")
274 }
275
276 installCPT_(id, pot);
277 }
278
279 template < typename GUM_SCALAR >
281 if (!isInstalledNode(id))
282 GUM_ERROR(NotFound, "The node " << id << " is not part of this fragment")
283
284 const auto& cpt = this->cpt(id);
285 NodeSet cpt_parents;
286
287 for (Idx i = 1; i < cpt.nbrDim(); i++) {
288 cpt_parents.insert(_bn_.idFromName(cpt.variable(i).name()));
289 }
290
291 return (this->parents(id) == cpt_parents);
292 }
293
294 template < typename GUM_SCALAR >
296 for (auto node: nodes())
297 if (!checkConsistency(node)) return false;
298
299 return true;
300 }
301
302 template < typename GUM_SCALAR >
304 std::stringstream output;
305 output << "digraph \"";
306
307 std::string bn_name;
308
309 static std::string inFragmentStyle = "fillcolor=\"#ffffaa\","
310 "color=\"#000000\","
311 "fontcolor=\"#000000\"";
312 static std::string styleWithLocalCPT = "fillcolor=\"#ffddaa\","
313 "color=\"#000000\","
314 "fontcolor=\"#000000\"";
315 static std::string notConsistantStyle = "fillcolor=\"#ff0000\","
316 "color=\"#000000\","
317 "fontcolor=\"#ffff00\"";
318 static std::string outFragmentStyle = "fillcolor=\"#f0f0f0\","
319 "color=\"#f0f0f0\","
320 "fontcolor=\"#000000\"";
321
322 try {
323 bn_name = _bn_.property("name");
324 } catch (NotFound const&) { bn_name = "no_name"; }
325
326 bn_name = "Fragment of " + bn_name;
327
328 output << bn_name << "\" {" << std::endl;
329 output << " graph [bgcolor=transparent,label=\"" << bn_name << "\"];" << std::endl;
330 output << " node [style=filled];" << std::endl << std::endl;
331
332 for (auto node: _bn_.nodes()) {
333 output << "\"" << _bn_.variable(node).name() << "\" [comment=\"" << node << ":"
334 << _bn_.variable(node) << ", \"";
335
336 if (isInstalledNode(node)) {
337 if (!checkConsistency(node)) {
338 output << notConsistantStyle;
339 } else if (_localCPTs_.exists(node)) output << styleWithLocalCPT;
340 else output << inFragmentStyle;
341 } else output << outFragmentStyle;
342
343 output << "];" << std::endl;
344 }
345
346 output << std::endl;
347
348 std::string tab = " ";
349
350 for (auto node: _bn_.nodes()) {
351 if (_bn_.children(node).size() > 0) {
352 for (auto child: _bn_.children(node)) {
353 output << tab << "\"" << _bn_.variable(node).name() << "\" -> "
354 << "\"" << _bn_.variable(child).name() << "\" [";
355
356 if (dag().existsArc(Arc(node, child))) output << inFragmentStyle;
357 else output << outFragmentStyle;
358
359 output << "];" << std::endl;
360 }
361 }
362 }
363
364 output << "}" << std::endl;
365
366 return output.str();
367 }
368
369 template < typename GUM_SCALAR >
371 if (!checkConsistency()) {
372 GUM_ERROR(OperationNotAllowed, "The fragment contains un-consistent node(s)")
373 }
375 for (const auto nod: nodes()) {
376 res.add(variable(nod), nod);
377 }
378 for (const auto& arc: dag().arcs()) {
379 res.addArc(arc.tail(), arc.head());
380 }
381 for (const auto nod: nodes()) {
382 res.cpt(nod).fillWith(cpt(nod));
383 }
384
385 return res;
386 }
387} // namespace gum
Class representing Fragment of Bayesian networks.
Class representing Bayesian networks.
The base class for all directed edges.
void uninstallNode(NodeId id)
uninstall a node referenced by its nodeId
virtual void whenNodeDeleted(const void *src, NodeId id) final
the action to take when a node has just been removed from the graph
void installNode(NodeId id)
install a node referenced by its nodeId
void installMarginal(NodeId id, const Tensor< GUM_SCALAR > &pot)
install a local marginal BY COPY for a node into the fragment.
virtual void whenNodeAdded(const void *src, NodeId id) final
the action to take when a new node is inserted into the graph
void installCPT_(NodeId id, const Tensor< GUM_SCALAR > &pot)
const IBayesNet< GUM_SCALAR > & _bn_
The referred BayesNet.
gum::BayesNet< GUM_SCALAR > toBN() const
create a brand new BayesNet from a fragment.
virtual void whenArcDeleted(const void *src, NodeId from, NodeId to) final
the action to take when an arc has just been removed from the graph
virtual void whenArcAdded(const void *src, NodeId from, NodeId to) final
the action to take when a new arc is inserted into the graph
virtual const DiscreteVariable & variable(NodeId id) const final
Returns a constant reference to a variable given its node id.
bool checkConsistency() const
returns true if all nodes in the fragment are consistent
NodeProperty< const Tensor< GUM_SCALAR > * > _localCPTs_
Mapping between the variable's id and their CPT specific to this Fragment.
virtual NodeId nodeId(const DiscreteVariable &var) const final
Return id node from discrete var pointer.
bool checkConsistency(NodeId id) const
returns true if the nodeId's (local or not) cpt is consistent with its parents in the fragment
const VariableNodeMap & variableNodeMap() const final
Returns a constant reference to the VariableNodeMap of this BN.
void uninstallCPT_(NodeId id)
uninstall a local CPT.
virtual NodeId idFromName(const std::string &name) const final
Getter by name.
virtual std::string toDot() const final
creates a dot representation of the whole referred BN, highlighting the fragment.
virtual const DiscreteVariable & variableFromName(const std::string &name) const final
Getter by name.
bool isInstalledNode(NodeId id) const
check if a certain NodeId exists in the fragment
void installArc_(NodeId from, NodeId to)
void installAscendants(NodeId id)
install a node and all its ascendants
void uninstallArc_(NodeId from, NodeId to)
const Tensor< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
void uninstallCPT(NodeId id)
uninstall a local CPT.
void installCPT(NodeId id, const Tensor< GUM_SCALAR > &pot)
install a local cpt BY COPY for a node into the fragment.
Class representing a Bayesian network.
Definition BayesNet.h:93
const Tensor< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
NodeId add(const DiscreteVariable &var)
Add a variable to the gum::BayesNet.
void addArc(NodeId tail, NodeId head)
Add an arc in the BN, and update arc.head's CPT.
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
DAG dag_
The DAG of this Directed Graphical Model.
Definition DAGmodel.h:272
const ArcSet & arcs() const
returns the set of arcs of the DAGmodel
bool existsArc(const NodeId tail, const NodeId head) const
return true if the arc tail->head exists in the DAGmodel
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
const NodeGraphPart & nodes() const final
Returns a constant reference to the nodes (NodeGraphPart) of this model.
DiGraphListener(const DiGraph *g)
default constructor
Base class for discrete random variable.
IBayesNet()
Default constructor.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
aGrUM's Tensor is a multi-dimensional array with tensor operators.
Definition tensor.h:85
Container used to map discrete variables with nodes.
const std::string & name() const
returns the name of the variable
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Header of the Tensor class.