83 ApproximationScheme::operator=(from);
89 ApproximationScheme::operator=(std::move(from));
94 return e1.second > e2.second;
98 return std::abs(e1.second) > std::abs(e2.second);
103 double p1xz = std::get< 2 >(e1);
104 double p1yz = std::get< 3 >(e1);
105 double p2xz = std::get< 2 >(e2);
106 double p2yz = std::get< 3 >(e2);
107 double I1 = std::get< 1 >(e1);
108 double I2 = std::get< 1 >(e2);
112 if ((I1 < 0 && I2 < 0) || (I1 >= 0 && I2 >= 0)) {
113 if (std::max(p1xz, p1yz) == std::max(p2xz, p2yz)) {
114 return std::abs(I1) > std::abs(I2);
116 return std::max(p1xz, p1yz) > std::max(p2xz, p2yz);
140 initiation_(mutualInformation, graph, sep_set, rank);
142 iteration_(mutualInformation, graph, sep_set, rank);
163 initiation_(mutualInformation, graph, sep_set, rank);
165 iteration_(mutualInformation, graph, sep_set, rank);
189 template <
typename GUM_SCALAR,
typename GRAPH_CHANGES_
SELECTOR,
typename PARAM_ESTIMATOR >
191 PARAM_ESTIMATOR& estimator,
206 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
212 for (
const Edge& edge: edges) {
216 GUM_SL_EMIT(x, y,
"Remove " << x <<
" - " << y,
" Constraints : Forbidden edge")
221 double Ixy = mutualInformation.
score(x, y);
227 "Remove " << x <<
" - " << y,
228 "Independent based on Mutual Information :" << Ixy)
230 sepSet.insert(std::make_pair(x, y),
_emptySet_);
235 "Keep " << x <<
" - " << y,
236 "Dependent based on Mutual Information :" << Ixy)
255 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
263 while (rank.
top().second > 0.5) {
266 const NodeId x = std::get< 0 >(*(best.first));
267 const NodeId y = std::get< 1 >(*(best.first));
268 const NodeId z = std::get< 2 >(*(best.first));
269 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
272 const double i_xy_ui = mutualInformation.
score(x, y, ui);
277 "Remove " << x <<
" - " << y,
278 "Independent based on MutualInformation knowing Sep "
279 << ui <<
"Mutual information:" << i_xy_ui)
281 sepSet.insert(std::make_pair(x, y), std::move(ui));
312 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
332 if (graph.
existsEdge(arc.head(), arc.tail())) {
336 "Add Arc" << arc.tail() <<
"->" << arc.head(),
338 graph.
addArc(arc.tail(), arc.head());
339 marks.
insert({arc.tail(), arc.head()},
'>');
340 marks.
insert({arc.head(), arc.tail()},
'-');
343 graph.
addArc(arc.tail(), arc.head());
344 marks.
insert({arc.tail(), arc.head()},
'>');
345 marks.
insert({arc.head(), arc.tail()},
'-');
355 graph.
addArc(arc.head(), arc.tail());
358 "Add Arc" << arc.head() <<
"->" << arc.tail(),
359 "Forbidden in the other orientation")
360 marks.
insert({arc.tail(), arc.head()},
'-');
361 marks.
insert({arc.head(), arc.tail()},
'>');
365 std::vector< ProbabilisticRanking > proba_triples
368 const Size steps_orient = proba_triples.size();
372 if (steps_orient > 0) { best = proba_triples[0]; }
374 while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) >= 0.5) {
375 const NodeId x = std::get< 0 >(*std::get< 0 >(best));
376 const NodeId y = std::get< 1 >(*std::get< 0 >(best));
377 const NodeId z = std::get< 2 >(*std::get< 0 >(best));
379 const double i3 = std::get< 1 >(best);
381 const double p1 = std::get< 2 >(best);
382 const double p2 = std::get< 3 >(best);
390 delete std::get< 0 >(best);
391 proba_triples.erase(proba_triples.begin());
395 if (!proba_triples.empty()) best = proba_triples[0];
411 graph.
addArc(iter->head(), iter->tail());
413 *iter =
Arc(iter->head(), iter->tail());
421 HashTable< std::pair< NodeId, NodeId >,
char >& marks,
428 if (marks[{x, z}] ==
'o' && marks[{y, z}] ==
'o') {
433 GUM_SL_EMIT(x, z,
"Add Arc " << x <<
" -> " << z,
"V-structure Orientation")
451 GUM_SL_EMIT(z, x,
"Add Arc " << z <<
" -> " << x,
"V-structure Orientation")
462 GUM_SL_EMIT(y, z,
"Add Arc " << y <<
" -> " << z,
"V-structure Orientation")
477 GUM_SL_EMIT(z, y,
"Add Arc " << z <<
" -> " << y,
"V-structure Orientation")
483 }
else if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o') {
490 "Add Arc " << y <<
" -> " << z,
491 "V-structure Orientation | existing "
492 << x <<
" -> " << z <<
", then orienting " << y <<
" -> " << z)
509 "Add Arc " << z <<
" -> " << y,
510 "V-structure Orientation | existing "
511 << x <<
" -> " << z <<
", then orienting " << z <<
" -> " << y)
518 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o') {
523 GUM_SL_EMIT(x, z,
"Add Arc " << x <<
" -> " << z,
"V-structure Orientation")
539 GUM_SL_EMIT(z, x,
"Add Arc " << z <<
" -> " << x,
"V-structure Orientation")
549 HashTable< std::pair< NodeId, NodeId >,
char >& marks,
556 if (marks[{x, z}] ==
'>' && marks[{y, z}] ==
'o' && marks[{z, y}] !=
'-') {
563 "Add Arc " << z <<
" -> " << y,
564 "Propagation MIIC (919) | existing x -> " << z <<
" and " << z <<
" - "
575 GUM_SL_EMIT(y, z,
"Add Arc " << y <<
" -> " << z,
"Propagation MIIC line 932 ")
587 GUM_SL_EMIT(z, y,
"Add Arc " << z <<
"->" << y,
"Propagation MIIC 947 ")
597 GUM_SL_EMIT(z, y,
"Add Arc " << z <<
"->" << y,
"Propagation MIIC 959")
606 }
else if (marks[{y, z}] ==
'>' && marks[{x, z}] ==
'o' && marks[{z, x}] !=
'-') {
612 GUM_SL_EMIT(z, x,
"Add Arc " << z <<
" -> " << x,
"Propagation MIIC 977")
622 GUM_SL_EMIT(x, z,
"Add Arc " << x <<
"->" << z,
"Propagation MIIC 990")
633 GUM_SL_EMIT(z, x,
"Add Arc " << z <<
" -> " << x,
"Propagation MIIC 1004")
642 GUM_SL_EMIT(x, z,
"Add Arc " << x <<
" -> " << z,
"Propagation MIIC 1016")
669 bool withdrawFlag_L =
false;
670 for (
auto arc:
ArcSet(L)) {
673 bool withdrawFlag_arc =
false;
676 if (tail_head && !head_tail) {
679 withdrawFlag_arc =
true;
682 }
else if (!tail_head && head_tail) {
685 withdrawFlag_arc =
true;
688 }
else if (!tail_head && !head_tail) {
691 withdrawFlag_arc =
true;
695 if (withdrawFlag_arc) {
697 withdrawFlag_L =
true;
701 if (L.
empty()) {
break; }
706 if (!withdrawFlag_L) {
720 const std::vector< NodeId >& ui,
729 const double Ixy_ui = mutualInformation.
score(x, y, ui);
731 for (
const NodeId z: graph) {
733 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
738 const double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
739 double calc_expo1 = -Ixyz_ui *
M_LN2;
743 }
else if (calc_expo1 < -
_maxLog_) {
746 Pnv = 1 / (1 + std::exp(calc_expo1));
750 const double Ixz_ui = mutualInformation.
score(x, z, ui);
751 const double Iyz_ui = mutualInformation.
score(y, z, ui);
753 calc_expo1 = -(Ixz_ui - Ixy_ui) *
M_LN2;
754 double calc_expo2 = -(Iyz_ui - Ixy_ui) *
M_LN2;
766 expo1 = std::exp(calc_expo1);
771 expo2 = std::exp(calc_expo2);
773 Pb = 1 / (1 + expo1 + expo2);
777 const double min_pnv_pb = std::min(Pnv, Pb);
778 if (min_pnv_pb > maxP) {
797 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
798 std::vector< Ranking > triples;
803 std::vector< NodeId > ui;
804 std::pair< NodeId, NodeId > key = {x, y};
805 std::pair< NodeId, NodeId > rev_key = {y, x};
806 if (sepSet.exists(key)) {
808 }
else if (sepSet.exists(rev_key)) {
809 ui = sepSet[rev_key];
812 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
813 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
815 double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
819 triple.second = Ixyz_ui;
820 triples.push_back(triple);
834 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
835 HashTable< std::pair< NodeId, NodeId >,
char >& marks) {
836 std::vector< ProbabilisticRanking > triples;
841 std::vector< NodeId > ui;
842 std::pair< NodeId, NodeId > key = {x, y};
843 std::pair< NodeId, NodeId > rev_key = {y, x};
844 if (sepSet.exists(key)) {
846 }
else if (sepSet.exists(rev_key)) {
847 ui = sepSet[rev_key];
850 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
851 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
853 const double Ixyz_ui = mutualInformation.
score(x, y, z, ui);
856 triples.push_back(triple);
857 if (!marks.exists({x, z})) { marks.insert({x, z},
'o'); }
858 if (!marks.exists({z, x})) { marks.insert({z, x},
'o'); }
859 if (!marks.exists({y, z})) { marks.insert({y, z},
'o'); }
860 if (!marks.exists({z, y})) { marks.insert({z, y},
'o'); }
871 std::vector< ProbabilisticRanking >
873 std::vector< ProbabilisticRanking > probaTriples) {
874 for (
auto& triple: probaTriples) {
876 x = std::get< 0 >(*std::get< 0 >(triple));
877 y = std::get< 1 >(*std::get< 0 >(triple));
878 z = std::get< 2 >(*std::get< 0 >(triple));
879 const double Ixyz = std::get< 1 >(triple);
880 double Pxz = std::get< 2 >(triple);
881 double Pyz = std::get< 3 >(triple);
884 const double expo = std::exp(Ixyz);
885 const double P0 = (1 + expo) / (1 + 3 * expo);
887 if (Pxz == Pyz && Pyz == 0.5) {
888 std::get< 2 >(triple) = P0;
889 std::get< 3 >(triple) = P0;
891 if (graph.
existsArc(x, z) && Pxz >= P0) {
892 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
893 }
else if (graph.
existsArc(y, z) && Pyz >= P0) {
894 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
898 const double expo = std::exp(-Ixyz);
899 if (graph.
existsArc(x, z) && Pxz >= 0.5) {
900 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
901 }
else if (graph.
existsArc(y, z) && Pyz >= 0.5) {
902 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
920 for (
const auto parent: graph.
parents(n2)) {
941 while (!nodeFIFO.
empty()) {
942 current = nodeFIFO.
front();
946 for (
const auto new_one: graph.
parents(current)) {
951 if (new_one == n1) {
return true; }
968 return (std::find(lbeg, lend,
Arc(x, y)) == lend)
969 && (std::find(lbeg, lend,
Arc(y, x)) == lend);
A class that, given a structure and a parameter estimator, returns a full Bayes net.
#define GUM_SL_EMIT(x, y, action, explain)
Size current_step_
The current step.
ApproximationScheme(bool verbosity=false)
bool existsArc(const Arc &arc) const
indicates whether a given arc exists
const NodeSet & parents(NodeId id) const
returns the set of nodes with arc ingoing to a given node
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
The base class for all directed edges.
Base class for all oriented graphs.
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
virtual void eraseEdge(const Edge &edge)
removes an edge from the EdgeGraphPart
const EdgeSet & edges() const
returns the set of edges stored within the EdgeGraphPart
bool existsEdge(const Edge &edge) const
indicates whether a given edge exists
const NodeSet & neighbours(NodeId id) const
returns the set of node neighbours to a given node
The base class for all undirected edges.
The class for generic Hash Tables.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
Size size() const noexcept
Returns the number of elements in the heap.
Val pop()
Removes the top element from the heap and return it.
Size insert(const Val &val)
inserts a new element (actually a copy) in the heap and returns its index
const Val & top() const
Returns the element at the top of the heap.
Signaler3< Size, double, double > onProgress
Progression, error and time.
Generic doubly linked lists.
Val & front() const
Returns a reference to the first element of a list, if any.
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
void popFront()
Removes the first element of a List, if any.
Val & pushBack(const Val &val)
Inserts a new element (a copy) at the end of the chained list.
Base class for mixed graphs.
const NodeGraphPart & nodes() const
return *this as a NodeGraphPart
Base class for partially directed acyclic graphs.
iterator begin() const
The usual unsafe begin iterator to parse the set.
Size size() const noexcept
Returns the number of elements in the set.
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
bool contains(const Key &k) const
Indicates whether a given element belongs to the set.
void insert(const Key &k)
Inserts a new element into the set.
bool empty() const noexcept
Indicates whether the set is the empty set.
void erase(const Key &k)
Erases an element from the set.
static BayesNet< GUM_SCALAR > createBN(ParamEstimator &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
bool operator()(const Ranking &e1, const Ranking &e2) const
bool operator()(const CondRanking &e1, const CondRanking &e2) const
bool operator()(const ProbabilisticRanking &e1, const ProbabilisticRanking &e2) const
static bool _existsDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, considering double arcs as edges
std::vector< ProbabilisticRanking > updateProbaTriples_(const MixedGraph &graph, std::vector< ProbabilisticRanking > probaTriples)
Gets the orientation probabilities like MIIC for the orientation phase.
gum::DAG _mandatoryGraph_
Graph that contains the mandatory arcs.
bool isMaxIndegree_(MixedGraph graph, NodeId x)
void _orientingVstructureMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
void orientationMiic_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
Orientation phase from the MIIC algorithm, returns a mixed graph that may contain circles.
gum::DiGraph _forbiddenGraph_
Graph that contains the forbidden arcs.
void setMandatoryGraph(gum::DAG mandaGraph)
Miic & operator=(const Miic &from)
copy operator
MixedGraph learnMixedStructure(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of a MixedGraph (Meek rules not used here).
void setMaxIndegree(gum::Size n)
gum::MeekRules meekRules_
Object that can propagates orientations to non-oriented edges.
gum::Size _maxIndegree_
maximum number of parents
const std::vector< NodeId > _emptySet_
an empty conditioning set
PDAG learnPDAG(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of an Essential Graph
~Miic() override
destructor
std::vector< Arc > _latentCouples_
an empty vector of arcs
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Sets a collection of constraints for the orientation phase.
void _propagatingOrientationMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
void iteration_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Iteration phase.
bool isArcValid_(MixedGraph graph, NodeId x, NodeId y)
void findBestContributor_(NodeId x, NodeId y, const std::vector< NodeId > &ui, const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, Heap< CondRanking, GreaterPairOn2nd > &rank)
finds the best contributor node for a pair given a conditioning set
void setForbiddenGraph(gum::DiGraph forbidGraph)
Sets the ForbiddenGraph (resp. MandatoryGraph), which contains the forbidden (resp. mandatory) arcs.
std::vector< ProbabilisticRanking > unshieldedTriplesMiic_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, HashTable< std::pair< NodeId, NodeId >, char > &marks)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
bool isForbiddenArc_(NodeId x, NodeId y) const
Check constraints.
Miic()
default constructor
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
std::vector< Ranking > unshieldedTriples_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
HashTable< std::pair< NodeId, NodeId >, char > _initialMarks_
Initial marks for the orientation phase, used to convey constraints.
bool isForbiddenEdge_(NodeId x, NodeId y)
ArcProperty< double > _arcProbas_
Stores the probabilities for each arc set in the graph.
void initiation_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Initiation phase.
static bool _existsNonTrivialDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, considering double arcs like edges, not considering arc as a di...
Size _size_
size of the database
bool _isNotLatentCouple_(NodeId x, NodeId y)
void orientDoubleHeadedArcs_(MixedGraph &mg)
Orient double headed arcs to avoid cycles.
const std::vector< Arc > latentVariables() const
get the list of arcs hiding latent variables
int _maxLog_
Fixes the maximum log that we accept in exponential computations.
MixedGraph learnSkeleton(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the skeleton of a MixedGraph (no orientation).
DAG learnStructure(CorrectedMutualInformation &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then...
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Set< Edge > EdgeSet
Some typedefs and defines for shortcuts ...
Size NodeId
Type for node ids.
Set< Arc > ArcSet
Some typedefs and defines for shortcuts ...
Class hash tables iterators.
Base classes for mixed directed/undirected graphs.
include the inlined functions if necessary
std::pair< ThreePoints *, double > Ranking
std::pair< CondThreePoints *, double > CondRanking
std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > CondThreePoints
std::tuple< NodeId, NodeId, NodeId > ThreePoints
std::tuple< ThreePoints *, double, double, double > ProbabilisticRanking
gum is the global namespace for all aGrUM entities
#define GUM_EMIT3(signal, arg1, arg2, arg3)
Class used to compute response times for benchmark purposes.