54 template <
typename GUM_SCALAR >
60 template <
typename GUM_SCALAR >
64 for (
auto node:
nodes())
70 template <
typename GUM_SCALAR >
75 template <
typename GUM_SCALAR >
80 template <
typename GUM_SCALAR >
86 template <
typename GUM_SCALAR >
95 template <
typename GUM_SCALAR >
100 else return _bn_.cpt(
id);
103 template <
typename GUM_SCALAR >
105 return this->
_bn_.variableNodeMap();
108 template <
typename GUM_SCALAR >
112 return _bn_.variable(
id);
115 template <
typename GUM_SCALAR >
124 template <
typename GUM_SCALAR >
133 template <
typename GUM_SCALAR >
140 return _bn_.variable(
id);
145 template <
typename GUM_SCALAR >
147 return dag().existsNode(
id);
150 template <
typename GUM_SCALAR >
152 if (!
_bn_.dag().existsNode(
id))
156 this->
dag_.addNodeWithId(
id);
159 for (
auto pa: this->
_bn_.parents(
id)) {
164 for (
auto son: this->
_bn_.children(
id))
169 template <
typename GUM_SCALAR >
174 for (
auto pa: this->
_bn_.parents(
id))
178 template <
typename GUM_SCALAR >
182 this->
dag_.eraseNode(
id);
186 template <
typename GUM_SCALAR >
188 this->
dag_.eraseArc(
Arc(from, to));
191 template <
typename GUM_SCALAR >
193 this->
dag_.addArc(from, to);
196 template <
typename GUM_SCALAR >
200 for (
auto node_it =
parents.beginSafe(); node_it !=
parents.endSafe();
204 for (
Idx i = 1; i < pot.nbrDim(); i++) {
205 NodeId parent =
_bn_.idFromName(pot.variable(i).name());
216 template <
typename GUM_SCALAR >
218 if (!
dag().existsNode(
id))
221 if (&(pot.variable(0)) != &(
variable(
id))) {
223 "The tensor is not a marginal for _bn_.variable <" <<
variable(
id).name() <<
">")
228 for (
Idx i = 1; i < pot.nbrDim(); i++) {
229 if (!
parents.contains(
_bn_.idFromName(pot.variable(i).name())))
231 "Variable <" << pot.variable(i).name() <<
"> is not in the parents of node "
238 template <
typename GUM_SCALAR >
244 template <
typename GUM_SCALAR >
250 const Tensor< GUM_SCALAR >& pot =
cpt(
id);
252 for (
Idx i = 1; i < pot.nbrDim(); i++) {
253 NodeId parent =
_bn_.idFromName(pot.variable(i).name());
260 template <
typename GUM_SCALAR >
266 if (pot.nbrDim() > 1) {
270 if (&(pot.variable(0)) != &(
_bn_.variable(
id))) {
272 "The tensor is not a marginal for _bn_.variable <" <<
_bn_.variable(
id).name()
279 template <
typename GUM_SCALAR >
284 const auto&
cpt = this->
cpt(
id);
287 for (
Idx i = 1; i <
cpt.nbrDim(); i++) {
291 return (this->
parents(
id) == cpt_parents);
294 template <
typename GUM_SCALAR >
296 for (
auto node:
nodes())
302 template <
typename GUM_SCALAR >
304 std::stringstream output;
305 output <<
"digraph \"";
309 static std::string inFragmentStyle =
"fillcolor=\"#ffffaa\","
311 "fontcolor=\"#000000\"";
312 static std::string styleWithLocalCPT =
"fillcolor=\"#ffddaa\","
314 "fontcolor=\"#000000\"";
315 static std::string notConsistantStyle =
"fillcolor=\"#ff0000\","
317 "fontcolor=\"#ffff00\"";
318 static std::string outFragmentStyle =
"fillcolor=\"#f0f0f0\","
320 "fontcolor=\"#000000\"";
323 bn_name =
_bn_.property(
"name");
324 }
catch (
NotFound const&) { bn_name =
"no_name"; }
326 bn_name =
"Fragment of " + bn_name;
328 output << bn_name <<
"\" {" << std::endl;
329 output <<
" graph [bgcolor=transparent,label=\"" << bn_name <<
"\"];" << std::endl;
330 output <<
" node [style=filled];" << std::endl << std::endl;
332 for (
auto node:
_bn_.nodes()) {
333 output <<
"\"" <<
_bn_.variable(node).name() <<
"\" [comment=\"" << node <<
":"
334 <<
_bn_.variable(node) <<
", \"";
338 output << notConsistantStyle;
339 }
else if (
_localCPTs_.exists(node)) output << styleWithLocalCPT;
340 else output << inFragmentStyle;
341 }
else output << outFragmentStyle;
343 output <<
"];" << std::endl;
348 std::string tab =
" ";
350 for (
auto node:
_bn_.nodes()) {
351 if (
_bn_.children(node).size() > 0) {
352 for (
auto child:
_bn_.children(node)) {
353 output << tab <<
"\"" <<
_bn_.variable(node).name() <<
"\" -> "
354 <<
"\"" <<
_bn_.variable(child).name() <<
"\" [";
357 else output << outFragmentStyle;
359 output <<
"];" << std::endl;
364 output <<
"}" << std::endl;
369 template <
typename GUM_SCALAR >
375 for (
const auto nod:
nodes()) {
378 for (
const auto& arc:
dag().
arcs()) {
379 res.
addArc(arc.tail(), arc.head());
381 for (
const auto nod:
nodes()) {
382 res.
cpt(nod).fillWith(
cpt(nod));
Class representing Fragment of Bayesian networks.
Class representing Bayesian networks.
The base class for all directed edges.
void uninstallNode(NodeId id)
uninstall a node referenced by its nodeId
virtual void whenNodeDeleted(const void *src, NodeId id) final
the action to take when a node has just been removed from the graph
void installNode(NodeId id)
install a node referenced by its nodeId
virtual ~BayesNetFragment()
void installMarginal(NodeId id, const Tensor< GUM_SCALAR > &pot)
install a local marginal BY COPY for a node into the fragment.
virtual void whenNodeAdded(const void *src, NodeId id) final
the action to take when a new node is inserted into the graph
void installCPT_(NodeId id, const Tensor< GUM_SCALAR > &pot)
const IBayesNet< GUM_SCALAR > & _bn_
The referred BayesNet.
gum::BayesNet< GUM_SCALAR > toBN() const
create a brand new BayesNet from a fragment.
virtual void whenArcDeleted(const void *src, NodeId from, NodeId to) final
the action to take when an arc has just been removed from the graph
virtual void whenArcAdded(const void *src, NodeId from, NodeId to) final
the action to take when a new arc is inserted into the graph
virtual const DiscreteVariable & variable(NodeId id) const final
Returns a constant reference to a variable given its node id.
bool checkConsistency() const
returns true if all nodes in the fragment are consistent
NodeProperty< const Tensor< GUM_SCALAR > * > _localCPTs_
Mapping between the variable's id and their CPT specific to this Fragment.
virtual NodeId nodeId(const DiscreteVariable &var) const final
Return id node from discrete var pointer.
bool checkConsistency(NodeId id) const
returns true if the nodeId's (local or not) cpt is consistent with its parents in the fragment
const VariableNodeMap & variableNodeMap() const final
Returns a constant reference to the VariableNodeMap of this BN.
void uninstallCPT_(NodeId id)
uninstall a local CPT.
virtual NodeId idFromName(const std::string &name) const final
Getter by name.
virtual std::string toDot() const final
creates a dot representation of the whole referred BN, highlighting the fragment.
virtual const DiscreteVariable & variableFromName(const std::string &name) const final
Getter by name.
bool isInstalledNode(NodeId id) const
check if a certain NodeId exists in the fragment
void installArc_(NodeId from, NodeId to)
void installAscendants(NodeId id)
install a node and all its ascendants
void uninstallArc_(NodeId from, NodeId to)
BayesNetFragment()=delete
const Tensor< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
void uninstallCPT(NodeId id)
uninstall a local CPT.
void installCPT(NodeId id, const Tensor< GUM_SCALAR > &pot)
install a local cpt BY COPY for a node into the fragment.
Class representing a Bayesian network.
const Tensor< GUM_SCALAR > & cpt(NodeId varId) const final
Returns the CPT of a variable.
NodeId add(const DiscreteVariable &var)
Add a variable to the gum::BayesNet.
void addArc(NodeId tail, NodeId head)
Add an arc in the BN, and update arc.head's CPT.
const DAG & dag() const
Returns a constant reference to the dag of this Bayes Net.
DAG dag_
The DAG of this Directed Graphical Model.
const ArcSet & arcs() const
returns a constant reference to the set of arcs of the DAGmodel
bool existsArc(const NodeId tail, const NodeId head) const
return true if the arc tail->head exists in the DAGmodel
const NodeSet & parents(const NodeId id) const
returns the set of nodes with arc ingoing to a given node
const NodeGraphPart & nodes() const final
Returns a constant reference to the nodes of this Bayes Net.
DiGraphListener(const DiGraph *g)
default constructor
Base class for discrete random variable.
IBayesNet()
Default constructor.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
void insert(const Key &k)
Inserts a new element into the set.
aGrUM's Tensor is a multi-dimensional array with tensor operators.
Container used to map discrete variables with nodes.
const std::string & name() const
returns the name of the variable
#define GUM_ERROR(type, msg)
Size Idx
Type for indexes.
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
gum is the global namespace for all aGrUM entities
Header of the Tensor class.