87 ApproximationScheme::operator=(from);
93 ApproximationScheme::operator=(std::move(from));
112 initiation_(mutualInformation, graph, sep_set, rank);
114 iteration_(mutualInformation, graph, sep_set, rank);
131 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
137 for (
const Edge& edge: edges) {
140 double Ixy = mutualInformation.
score(x, y);
144 sepSet.insert(std::make_pair(x, y),
_emptySet_);
165 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
174 while (rank.
top().second > 0.5) {
177 const NodeId x = std::get< 0 >(*(best.first));
178 const NodeId y = std::get< 1 >(*(best.first));
179 const NodeId z = std::get< 2 >(*(best.first));
180 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
183 const double i_xy_ui = mutualInformation.
score(x, y, ui);
186 sepSet.insert(std::make_pair(x, y), std::move(ui));
217 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
218 std::vector< Ranking > triples =
unshieldedTriples_(graph, mutualInformation, sepSet);
219 Size steps_orient = triples.size();
227 while (i < triples.size()) {
231 x = std::get< 0 >(*triple.first);
232 y = std::get< 1 >(*triple.first);
233 z = std::get< 2 >(*triple.first);
235 std::vector< NodeId > ui;
236 std::pair< NodeId, NodeId > key = {x, y};
237 std::pair< NodeId, NodeId > rev_key = {y, x};
238 if (sepSet.exists(key)) {
240 }
else if (sepSet.exists(rev_key)) {
241 ui = sepSet[rev_key];
243 double Ixyz_ui = triple.second;
247 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
263 }
catch (
const gum::NotFound&) { graph.
addArc(x, z); }
271 }
catch (
const gum::NotFound&) { graph.
addArc(x, z); }
279 }
catch (
const gum::NotFound&) { graph.
addArc(y, z); }
288 }
catch (
const gum::NotFound&) { graph.
addArc(y, z); }
310 }
catch (
const gum::NotFound&) { graph.
addArc(z, y); }
320 }
catch (
const gum::NotFound&) { graph.
addArc(z, x); }
347 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
355 for (
auto iter = marks.
begin(); iter != marks.
end(); ++iter) {
356 if (graph.
existsEdge(iter.key().first, iter.key().second) && iter.val() ==
'>') {
358 graph.
addArc(iter.key().first, iter.key().second);
362 std::vector< ProbabilisticRanking > proba_triples
365 const Size steps_orient = proba_triples.size();
369 if (steps_orient > 0) { best = proba_triples[0]; }
371 while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
372 const NodeId x = std::get< 0 >(*std::get< 0 >(best));
373 const NodeId y = std::get< 1 >(*std::get< 0 >(best));
374 const NodeId z = std::get< 2 >(*std::get< 0 >(best));
376 const double i3 = std::get< 1 >(best);
378 const double p1 = std::get< 2 >(best);
379 const double p2 = std::get< 3 >(best);
386 delete std::get< 0 >(best);
387 proba_triples.erase(proba_triples.begin());
391 if (!proba_triples.empty()) best = proba_triples[0];
408 graph.
addArc(iter->head(), iter->tail());
410 *iter =
Arc(iter->head(), iter->tail());
420 const std::vector< NodeId >& ui,
429 const double Ixy_ui = mutualInformation.
score(x, y, ui);
431 for (
const NodeId z: graph) {
433 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
438 const double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
439 double calc_expo1 = -Ixyz_ui *
M_LN2;
443 }
else if (calc_expo1 < -
_maxLog_) {
446 Pnv = 1 / (1 + std::exp(calc_expo1));
450 const double Ixz_ui = mutualInformation.
score(x, z, ui);
451 const double Iyz_ui = mutualInformation.
score(y, z, ui);
453 calc_expo1 = -(Ixz_ui - Ixy_ui) *
M_LN2;
454 double calc_expo2 = -(Iyz_ui - Ixy_ui) *
M_LN2;
466 expo1 = std::exp(calc_expo1);
471 expo2 = std::exp(calc_expo2);
473 Pb = 1 / (1 + expo1 + expo2);
477 const double min_pnv_pb = std::min(Pnv, Pb);
478 if (min_pnv_pb > maxP) {
497 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
498 std::vector< Ranking > triples;
503 std::vector< NodeId > ui;
504 std::pair< NodeId, NodeId > key = {x, y};
505 std::pair< NodeId, NodeId > rev_key = {y, x};
506 if (sepSet.exists(key)) {
508 }
else if (sepSet.exists(rev_key)) {
509 ui = sepSet[rev_key];
512 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
513 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
515 double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
519 triple.second = Ixyz_ui;
520 triples.push_back(triple);
534 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
535 HashTable< std::pair< NodeId, NodeId >,
char >& marks) {
536 std::vector< ProbabilisticRanking > triples;
541 std::vector< NodeId > ui;
542 std::pair< NodeId, NodeId > key = {x, y};
543 std::pair< NodeId, NodeId > rev_key = {y, x};
544 if (sepSet.exists(key)) {
546 }
else if (sepSet.exists(rev_key)) {
547 ui = sepSet[rev_key];
550 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
551 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
553 const double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
556 triples.push_back(triple);
557 if (!marks.exists({x, z})) { marks.insert({x, z},
'o'); }
558 if (!marks.exists({z, x})) { marks.insert({z, x},
'o'); }
559 if (!marks.exists({y, z})) { marks.insert({y, z},
'o'); }
560 if (!marks.exists({z, y})) { marks.insert({z, y},
'o'); }
571 std::vector< ProbabilisticRanking >
573 std::vector< ProbabilisticRanking > probaTriples) {
574 for (
auto& triple: probaTriples) {
576 x = std::get< 0 >(*std::get< 0 >(triple));
577 y = std::get< 1 >(*std::get< 0 >(triple));
578 z = std::get< 2 >(*std::get< 0 >(triple));
579 const double Ixyz = std::get< 1 >(triple);
580 double Pxz = std::get< 2 >(triple);
581 double Pyz = std::get< 3 >(triple);
584 const double expo = std::exp(Ixyz);
585 const double P0 = (1 + expo) / (1 + 3 * expo);
587 if (Pxz == Pyz && Pyz == 0.5) {
588 std::get< 2 >(triple) = P0;
589 std::get< 3 >(triple) = P0;
591 if (graph.
existsArc(x, z) && Pxz >= P0) {
592 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
593 }
else if (graph.
existsArc(y, z) && Pyz >= P0) {
594 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
598 const double expo = std::exp(-Ixyz);
599 if (graph.
existsArc(x, z) && Pxz >= 0.5) {
600 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
601 }
else if (graph.
existsArc(y, z) && Pyz >= 0.5) {
602 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
622 const auto nei_x = essentialGraph.
neighbours(x);
630 essentialGraph.
addArc(y, x);
635 essentialGraph.
addArc(x, y);
640 bool newOrientation =
true;
641 while (newOrientation) {
642 newOrientation =
false;
649 return essentialGraph;
662 const auto nei_x = essentialGraph.
neighbours(x);
670 essentialGraph.
addArc(y, x);
675 essentialGraph.
addArc(x, y);
681 bool newOrientation =
true;
682 while (newOrientation) {
683 newOrientation =
false;
703 for (
auto node: essentialGraph) {
706 for (
const Arc& arc: essentialGraph.
arcs()) {
707 dag.
addArc(arc.tail(), arc.head());
734 for (
const auto p: graph.
parents(xj)) {
749 const auto& edge = *(essentialGraph.
edges().begin());
750 NodeId root = edge.first();
755 while (!stack.
empty()) {
758 if (visited.
contains(next))
continue;
759 if (essentialGraph.
children(next).
size() > size_children_root) {
760 size_children_root = essentialGraph.
children(next).
size();
763 for (
const auto n: essentialGraph.
neighbours(next))
771 while (!stack.
empty()) {
774 if (visited.
contains(next))
continue;
775 const auto nei = essentialGraph.
neighbours(next);
776 for (
const auto n: nei) {
781 essentialGraph.
addArc(n, next);
792 for (
auto& xi: neighbours) {
811 GUM_TRACE(
" + add arc (" << xi <<
"," << xj <<
")")
826 template <
typename GUM_SCALAR,
typename GRAPH_CHANGES_
SELECTOR,
typename PARAM_ESTIMATOR >
828 PARAM_ESTIMATOR& estimator,
841 for (
const auto parent: graph.
parents(n2)) {
865 while (!nodeFIFO.
empty()) {
866 current = nodeFIFO.
front();
870 for (
const auto new_one: graph.
parents(current)) {
875 if (new_one == n1) {
return true; }
890 HashTable< std::pair< NodeId, NodeId >,
char >& marks,
897 if (marks[{x, z}] ==
'o' && marks[{y, z}] ==
'o') {
905 GUM_TRACE(
"Adding latent couple (" << z <<
"," << x <<
")")
926 GUM_TRACE(
"Adding latent couple (" << z <<
"," << y <<
")")
939 }
else if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o') {
947 GUM_TRACE(
"Adding latent couple (" << z <<
"," << y <<
")")
960 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o') {
968 GUM_TRACE(
"Adding latent couple (" << z <<
"," << x <<
")")
986 HashTable< std::pair< NodeId, NodeId >,
char >& marks,
993 if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o' && marks[{z, y}] !=
'-') {
1000 marks[{z, y}] =
'>';
1001 marks[{y, z}] =
'-';
1005 GUM_TRACE(
"4.b Adding arc (" << y <<
"," << z <<
")")
1006 marks[{z, y}] =
'-';
1007 marks[{y, z}] =
'>';
1013 marks[{z, y}] =
'>';
1014 marks[{y, z}] =
'-';
1018 GUM_TRACE(
"4.d Adding arc (" << y <<
"," << z <<
")")
1020 marks[{z, y}] =
'-';
1021 marks[{y, z}] =
'>';
1024 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o' && marks[{z, x}] !=
'-') {
1030 marks[{z, x}] =
'>';
1031 marks[{x, z}] =
'-';
1035 GUM_TRACE(
"5.b Adding arc (" << x <<
"," << z <<
")")
1036 marks[{z, x}] =
'-';
1037 marks[{x, z}] =
'>';
1043 marks[{z, x}] =
'>';
1044 marks[{x, z}] =
'-';
1048 GUM_TRACE(
"5.d Adding arc (" << x <<
"," << z <<
")")
1049 marks[{z, x}] =
'-';
1050 marks[{x, z}] =
'>';
1061 return (std::find(lbeg, lend,
Arc(x, y)) == lend)
1062 && (std::find(lbeg, lend,
Arc(y, x)) == lend);
A class that, given a structure and a parameter estimator returns a full Bayes net.
The SimpleMiic algorithm.
Size current_step_
The current step.
ApproximationScheme(bool verbosity=false)
bool existsArc(const Arc &arc) const
indicates whether a given arc exists
const NodeSet & parents(NodeId id) const
returns the set of nodes with arc ingoing to a given node
NodeSet children(const NodeSet &ids) const
returns the set of children of a set of nodes
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
std::vector< NodeId > directedPath(NodeId node1, NodeId node2) const
returns a directed path from node1 to node2 belonging to the set of arcs
const ArcSet & arcs() const
returns the set of arcs stored within the ArcGraphPart
The base class for all directed edges.
void addArc(NodeId tail, NodeId head) final
insert a new arc into the directed graph
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Sequence< NodeId > topologicalOrder() const
Build and return a topological order.
virtual void eraseEdge(const Edge &edge)
removes an edge from the EdgeGraphPart
const EdgeSet & edges() const
returns the set of edges stored within the EdgeGraphPart
bool existsEdge(const Edge &edge) const
indicates whether a given edge exists
const NodeSet & neighbours(NodeId id) const
returns the set of node neighbours to a given node
The base class for all undirected edges.
The class for generic Hash Tables.
iterator begin()
Returns an unsafe iterator pointing to the beginning of the hashtable.
const iterator & end() noexcept
Returns the unsafe iterator pointing to the end of the hashtable.
Size size() const noexcept
Returns the number of elements in the heap.
Val pop()
Removes the top element from the heap and return it.
Size insert(const Val &val)
inserts a new element (actually a copy) in the heap and returns its index
const Val & top() const
Returns the element at the top of the heap.
Signaler3< Size, double, double > onProgress
Progression, error and time.
Generic doubly linked lists.
Val & front() const
Returns a reference to first element of a list, if any.
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
void popFront()
Removes the first element of a List, if any.
Val & pushBack(const Val &val)
Inserts a new element (a copy) at the end of the chained list.
Base class for mixed graphs.
NodeSet boundary(NodeId node) const
returns the set of node adjacent to a given node
std::vector< NodeId > mixedOrientedPath(NodeId node1, NodeId node2) const
returns a mixed edge/directed arc path from node1 to node2 in the arc/edge set
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
iterator begin() const
The usual unsafe begin iterator to parse the set.
Size size() const noexcept
Returns the number of elements in the set.
bool exists(const Key &k) const
Indicates whether a given elements belong to the set.
bool contains(const Key &k) const
Indicates whether a given elements belong to the set.
void insert(const Key &k)
Inserts a new element into the set.
bool empty() const noexcept
Indicates whether the set is the empty set.
void erase(const Key &k)
Erases an element from the set.
void clear()
Removes all the elements, if any, from the set.
static BayesNet< GUM_SCALAR > createBN(ParamEstimator &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
DAG learnStructure(CorrectedMutualInformation &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then...
bool isOrientable_(const MixedGraph &graph, NodeId xi, NodeId xj) const
const std::vector< Arc > latentVariables() const
get the list of arcs hiding latent variables
const std::vector< NodeId > _emptySet_
an empty conditioning set
MixedGraph learnMixedStructure(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of an Essential Graph
SimpleMiic & operator=(const SimpleMiic &from)
copy operator
void orientationMiic_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
Orientation phase from the MIIC algorithm, returns a mixed graph that may contain circles.
void _propagatingOrientationMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
void iteration_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Iteration phase.
bool _isNotLatentCouple_(NodeId x, NodeId y)
int _maxLog_
Fixes the maximum log that we accept in exponential computations.
void _orientingVstructureMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
~SimpleMiic() override
destructor
std::vector< ProbabilisticRanking > unshieldedTriplesMiic_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, HashTable< std::pair< NodeId, NodeId >, char > &marks)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
ArcProperty< double > _arcProbas_
Storing the propabilities for each arc set in the graph.
std::vector< Arc > _latentCouples_
an empty vector of arcs
static bool _existsDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, consider double arcs like edges
HashTable< std::pair< NodeId, NodeId >, char > _initialMarks_
Initial marks for the orientation phase, used to convey constraints.
SimpleMiic()
default constructor
Size _size_
size of the database
void propagatesOrientationInChainOfRemainingEdges_(MixedGraph &graph)
heuristic for remaining edges when everything else has been tried
void orientationLatents_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
variant trying to propagate both orientations in a bidirected arc
MixedGraph learnPDAG(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of an Essential Graph
bool propagatesRemainingOrientableEdges_(MixedGraph &graph, NodeId xj)
Propagates the orientation from a node to its neighbours.
void findBestContributor_(NodeId x, NodeId y, const std::vector< NodeId > &ui, const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, Heap< CondRanking, GreaterPairOn2nd > &rank)
finds the best contributor node for a pair given a conditioning set
bool isForbidenArc_(NodeId x, NodeId y) const
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Set a ensemble of constraints for the orientation phase.
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
static bool _existsNonTrivialDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, considering double arcs like edges, not considering arc as a di...
std::vector< Ranking > unshieldedTriples_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
std::vector< ProbabilisticRanking > updateProbaTriples_(const MixedGraph &graph, std::vector< ProbabilisticRanking > probaTriples)
Gets the orientation probabilities like MIIC for the orientation phase.
void initiation_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Initiation phase.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Set< Edge > EdgeSet
Some typdefs and define for shortcuts ...
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typdefs and define for shortcuts ...
Class hash tables iterators.
Base classes for mixed directed/undirected graphs.
include the inlined functions if necessary
std::pair< ThreePoints *, double > Ranking
std::pair< CondThreePoints *, double > CondRanking
std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > CondThreePoints
std::tuple< NodeId, NodeId, NodeId > ThreePoints
std::tuple< ThreePoints *, double, double, double > ProbabilisticRanking
gum is the global namespace for all aGrUM entities
#define GUM_EMIT3(signal, arg1, arg2, arg3)
Class used to compute response times for benchmark purposes.