aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
SimpleMiic.cpp
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
47
55
57
58namespace gum {
59
60 namespace learning {
61
63 SimpleMiic::SimpleMiic() : _maxLog_(100), _size_(0) { GUM_CONSTRUCTOR(SimpleMiic); }
64
66 SimpleMiic::SimpleMiic(int maxLog) : _maxLog_(maxLog), _size_(0) {
67 GUM_CONSTRUCTOR(SimpleMiic);
68 }
69
72 ApproximationScheme(from), _size_(from._size_) {
73 GUM_CONS_CPY(SimpleMiic);
74 }
75
78 ApproximationScheme(std::move(from)), _size_(from._size_) {
79 GUM_CONS_MOV(SimpleMiic);
80 }
81
83 SimpleMiic::~SimpleMiic() { GUM_DESTRUCTOR(SimpleMiic); }
84
87 ApproximationScheme::operator=(from);
88 return *this;
89 }
90
93 ApproximationScheme::operator=(std::move(from));
94 return *this;
95 }
96
99 MixedGraph graph) {
100 timer_.reset();
101 current_step_ = 0;
102
103 // clear the vector of latent arcs to be sure
104 _latentCouples_.clear();
105
108
110 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > sep_set;
111
112 initiation_(mutualInformation, graph, sep_set, rank);
113
114 iteration_(mutualInformation, graph, sep_set, rank);
115
116 orientationMiic_(mutualInformation, graph, sep_set);
117
118 return graph;
119 }
120
121 /*
122 * PHASE 1 : INITIATION
123 *
124 * We go over all edges and test if the variables are independent. If they
125 * are,
126 * the edge is deleted. If not, the best contributor is found.
127 */
129 CorrectedMutualInformation& mutualInformation,
130 MixedGraph& graph,
131 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
133 NodeId x, y;
134 EdgeSet edges = graph.edges();
135 Size steps_init = edges.size();
136
137 for (const Edge& edge: edges) {
138 x = edge.first();
139 y = edge.second();
140 double Ixy = mutualInformation.score(x, y);
141
142 if (Ixy <= 0) { //< K
143 graph.eraseEdge(edge);
144 sepSet.insert(std::make_pair(x, y), _emptySet_);
145 } else {
146 findBestContributor_(x, y, _emptySet_, graph, mutualInformation, rank);
147 }
148
150 if (onProgress.hasListener()) {
151 GUM_EMIT3(onProgress, (current_step_ * 33) / steps_init, 0., timer_.step());
152 }
153 }
154 }
155
156 /*
157 * PHASE 2 : ITERATION
158 *
159 * As long as we find important nodes for edges, we go over them to see if
160 * we can assess the independence of the variables.
161 */
163 CorrectedMutualInformation& mutualInformation,
164 MixedGraph& graph,
165 HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
167 // if no triples to further examine pass
168 CondRanking best;
169
170 Size steps_init = current_step_;
171 Size steps_iter = rank.size();
172
173 try {
174 while (rank.top().second > 0.5) {
175 best = rank.pop();
176
177 const NodeId x = std::get< 0 >(*(best.first));
178 const NodeId y = std::get< 1 >(*(best.first));
179 const NodeId z = std::get< 2 >(*(best.first));
180 std::vector< NodeId > ui = std::move(std::get< 3 >(*(best.first)));
181
182 ui.push_back(z);
183 const double i_xy_ui = mutualInformation.score(x, y, ui);
184 if (i_xy_ui < 0) {
185 graph.eraseEdge(Edge(x, y));
186 sepSet.insert(std::make_pair(x, y), std::move(ui));
187 } else {
188 findBestContributor_(x, y, ui, graph, mutualInformation, rank);
189 }
190
191 delete best.first;
192
194 if (onProgress.hasListener()) {
196 (current_step_ * 66) / (steps_init + steps_iter),
197 0.,
198 timer_.step());
199 }
200 }
201 } catch (...) {} // here, rank is empty
202 current_step_ = steps_init + steps_iter;
203 if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 66, 0., timer_.step()); }
204 current_step_ = steps_init + steps_iter;
205 }
206
207 /*
208 * PHASE 3 : ORIENTATION
209 *
210 * Try to assess v-structures and propagate them.
211 */
212
215 CorrectedMutualInformation& mutualInformation,
216 MixedGraph& graph,
217 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
218 std::vector< Ranking > triples = unshieldedTriples_(graph, mutualInformation, sepSet);
219 Size steps_orient = triples.size();
220 Size past_steps = current_step_;
221
222 NodeId i = 0;
223 // list of elements that we shouldnt read again, ie elements that are
224 // eligible to
225 // rule 0 after the first time they are tested, and elements on which rule 1
226 // has been applied
227 while (i < triples.size()) {
228 // if i not in do_not_reread
229 Ranking triple = triples[i];
230 NodeId x, y, z;
231 x = std::get< 0 >(*triple.first);
232 y = std::get< 1 >(*triple.first);
233 z = std::get< 2 >(*triple.first);
234
235 std::vector< NodeId > ui;
236 std::pair< NodeId, NodeId > key = {x, y};
237 std::pair< NodeId, NodeId > rev_key = {y, x};
238 if (sepSet.exists(key)) {
239 ui = sepSet[key];
240 } else if (sepSet.exists(rev_key)) {
241 ui = sepSet[rev_key];
242 }
243 double Ixyz_ui = triple.second;
244 // try Rule 0
245 if (Ixyz_ui < 0) {
246 // if ( z not in Sep[x,y])
247 if (std::find(ui.begin(), ui.end(), z) == ui.end()) {
248 // if what we want to add already exists : pass
249 if ((graph.existsArc(x, z) || graph.existsArc(z, x))
250 && (graph.existsArc(y, z) || graph.existsArc(z, y))) {
251 ++i;
252 } else {
253 i = 0;
254 graph.eraseEdge(Edge(x, z));
255 graph.eraseEdge(Edge(y, z));
256 // checking for cycles
257 if (graph.existsArc(z, x)) {
258 graph.eraseArc(Arc(z, x));
259 try {
260 std::vector< NodeId > path = graph.directedPath(z, x);
261 // if we find a cycle, we force the competing edge
262 _latentCouples_.emplace_back(z, x);
263 } catch (const gum::NotFound&) { graph.addArc(x, z); }
264 graph.addArc(z, x);
265 } else {
266 try {
267 std::vector< NodeId > path = graph.directedPath(z, x);
268 // if we find a cycle, we force the competing edge
269 graph.addArc(z, x);
270 _latentCouples_.emplace_back(z, x);
271 } catch (const gum::NotFound&) { graph.addArc(x, z); }
272 }
273 if (graph.existsArc(z, y)) {
274 graph.eraseArc(Arc(z, y));
275 try {
276 std::vector< NodeId > path = graph.directedPath(z, y);
277 // if we find a cycle, we force the competing edge
278 _latentCouples_.emplace_back(z, y);
279 } catch (const gum::NotFound&) { graph.addArc(y, z); }
280 graph.addArc(z, y);
281 } else {
282 try {
283 std::vector< NodeId > path = graph.directedPath(z, y);
284 // if we find a cycle, we force the competing edge
285 graph.addArc(z, y);
286 _latentCouples_.emplace_back(z, y);
287
288 } catch (const gum::NotFound&) { graph.addArc(y, z); }
289 }
290 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
291 _latentCouples_.emplace_back(z, x);
292 }
293 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
294 _latentCouples_.emplace_back(z, y);
295 }
296 }
297 } else {
298 ++i;
299 }
300 } else { // try Rule 1
301 bool reset{false};
302 if (graph.existsArc(x, z) && !graph.existsArc(z, y) && !graph.existsArc(y, z)) {
303 reset = true;
304 graph.eraseEdge(Edge(z, y));
305 try {
306 std::vector< NodeId > path = graph.directedPath(y, z);
307 // if we find a cycle, we force the competing edge
308 graph.addArc(y, z);
309 _latentCouples_.emplace_back(y, z);
310 } catch (const gum::NotFound&) { graph.addArc(z, y); }
311 }
312 if (graph.existsArc(y, z) && !graph.existsArc(z, x) && !graph.existsArc(x, z)) {
313 reset = true;
314 graph.eraseEdge(Edge(z, x));
315 try {
316 std::vector< NodeId > path = graph.directedPath(x, z);
317 // if we find a cycle, we force the competing edge
318 graph.addArc(x, z);
319 _latentCouples_.emplace_back(x, z);
320 } catch (const gum::NotFound&) { graph.addArc(z, x); }
321 }
322
323 if (reset) {
324 i = 0;
325 } else {
326 ++i;
327 }
328 } // if rule 0 or rule 1
329 if (onProgress.hasListener()) {
331 ((current_step_ + i) * 100) / (past_steps + steps_orient),
332 0.,
333 timer_.step());
334 }
335 } // while
336
337 // erasing the the double headed arcs
338 for (const Arc& arc: _latentCouples_) {
339 graph.eraseArc(Arc(arc.head(), arc.tail()));
340 }
341 }
342
345 CorrectedMutualInformation& mutualInformation,
346 MixedGraph& graph,
347 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
348 // structure to store the orientations marks -, o, or >,
349 // Considers the head of the arc/edge first node -* second node
351
352 // marks always correspond to the head of the arc/edge. - is for a forbidden
353 // arc, > for a mandatory arc
354 // we start by adding the mandatory arcs
355 for (auto iter = marks.begin(); iter != marks.end(); ++iter) {
356 if (graph.existsEdge(iter.key().first, iter.key().second) && iter.val() == '>') {
357 graph.eraseEdge(Edge(iter.key().first, iter.key().second));
358 graph.addArc(iter.key().first, iter.key().second);
359 }
360 }
361
362 std::vector< ProbabilisticRanking > proba_triples
363 = unshieldedTriplesMiic_(graph, mutualInformation, sepSet, marks);
364
365 const Size steps_orient = proba_triples.size();
366 Size past_steps = current_step_;
367
369 if (steps_orient > 0) { best = proba_triples[0]; }
370
371 while (!proba_triples.empty() && std::max(std::get< 2 >(best), std::get< 3 >(best)) > 0.5) {
372 const NodeId x = std::get< 0 >(*std::get< 0 >(best));
373 const NodeId y = std::get< 1 >(*std::get< 0 >(best));
374 const NodeId z = std::get< 2 >(*std::get< 0 >(best));
375
376 const double i3 = std::get< 1 >(best);
377
378 const double p1 = std::get< 2 >(best);
379 const double p2 = std::get< 3 >(best);
380 if (i3 <= 0) {
381 _orientingVstructureMiic_(graph, marks, x, y, z, p1, p2);
382 } else {
383 _propagatingOrientationMiic_(graph, marks, x, y, z, p1, p2);
384 }
385
386 delete std::get< 0 >(best);
387 proba_triples.erase(proba_triples.begin());
388 // actualisation of the list of triples
389 proba_triples = updateProbaTriples_(graph, proba_triples);
390
391 if (!proba_triples.empty()) best = proba_triples[0];
392
394 if (onProgress.hasListener()) {
396 (current_step_ * 100) / (steps_orient + past_steps),
397 0.,
398 timer_.step());
399 }
400 } // while
401
402 // erasing the double headed arcs
403 GUM_TRACE(_latentCouples_)
404 for (auto iter = _latentCouples_.rbegin(); iter != _latentCouples_.rend(); ++iter) {
405 graph.eraseArc(Arc(iter->head(), iter->tail()));
406 if (_existsDirectedPath_(graph, iter->head(), iter->tail())) {
407 // if we find a cycle, we force the competing edge
408 graph.addArc(iter->head(), iter->tail());
409 graph.eraseArc(Arc(iter->tail(), iter->head()));
410 *iter = Arc(iter->head(), iter->tail());
411 }
412 }
413
414 if (onProgress.hasListener()) { GUM_EMIT3(onProgress, 100, 0., timer_.step()); }
415 }
416
419 NodeId y,
420 const std::vector< NodeId >& ui,
421 const MixedGraph& graph,
422 CorrectedMutualInformation& mutualInformation,
424 double maxP = -1.0;
425 NodeId maxZ = 0;
426
427 // compute N
428 // __N = I.N();
429 const double Ixy_ui = mutualInformation.score(x, y, ui);
430
431 for (const NodeId z: graph) {
432 // if z!=x and z!=y and z not in ui
433 if (z != x && z != y && std::find(ui.begin(), ui.end(), z) == ui.end()) {
434 double Pnv;
435 double Pb;
436
437 // Computing Pnv
438 const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
439 double calc_expo1 = -Ixyz_ui * M_LN2;
440 // if exponential are too high or to low, crop them at _maxLog_
441 if (calc_expo1 > _maxLog_) {
442 Pnv = 0.0;
443 } else if (calc_expo1 < -_maxLog_) {
444 Pnv = 1.0;
445 } else {
446 Pnv = 1 / (1 + std::exp(calc_expo1));
447 }
448
449 // Computing Pb
450 const double Ixz_ui = mutualInformation.score(x, z, ui);
451 const double Iyz_ui = mutualInformation.score(y, z, ui);
452
453 calc_expo1 = -(Ixz_ui - Ixy_ui) * M_LN2;
454 double calc_expo2 = -(Iyz_ui - Ixy_ui) * M_LN2;
455
456 // if exponential are too high or to low, crop them at _maxLog_
457 if (calc_expo1 > _maxLog_ || calc_expo2 > _maxLog_) {
458 Pb = 0.0;
459 } else if (calc_expo1 < -_maxLog_ && calc_expo2 < -_maxLog_) {
460 Pb = 1.0;
461 } else {
462 double expo1, expo2;
463 if (calc_expo1 < -_maxLog_) {
464 expo1 = 0.0;
465 } else {
466 expo1 = std::exp(calc_expo1);
467 }
468 if (calc_expo2 < -_maxLog_) {
469 expo2 = 0.0;
470 } else {
471 expo2 = std::exp(calc_expo2);
472 }
473 Pb = 1 / (1 + expo1 + expo2);
474 }
475
476 // Getting max(min(Pnv, pb))
477 const double min_pnv_pb = std::min(Pnv, Pb);
478 if (min_pnv_pb > maxP) {
479 maxP = min_pnv_pb;
480 maxZ = z;
481 }
482 } // if z not in (x, y)
483 } // for z in graph.nodes
484 // storing best z in rank_
485 CondRanking final;
486 auto tup = new CondThreePoints{x, y, maxZ, ui};
487 final.first = tup;
488 final.second = maxP;
489 rank.insert(final);
490 }
491
494 std::vector< Ranking > SimpleMiic::unshieldedTriples_(
495 const MixedGraph& graph,
496 CorrectedMutualInformation& mutualInformation,
497 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet) {
498 std::vector< Ranking > triples;
499 for (NodeId z: graph) {
500 for (NodeId x: graph.neighbours(z)) {
501 for (NodeId y: graph.neighbours(z)) {
502 if (y < x && !graph.existsEdge(x, y)) {
503 std::vector< NodeId > ui;
504 std::pair< NodeId, NodeId > key = {x, y};
505 std::pair< NodeId, NodeId > rev_key = {y, x};
506 if (sepSet.exists(key)) {
507 ui = sepSet[key];
508 } else if (sepSet.exists(rev_key)) {
509 ui = sepSet[rev_key];
510 }
511 // remove z from ui if it's present
512 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
513 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
514
515 double Ixyz_ui = mutualInformation.score(x, y, z, ui);
516 Ranking triple;
517 auto tup = new ThreePoints{x, y, z};
518 triple.first = tup;
519 triple.second = Ixyz_ui;
520 triples.push_back(triple);
521 }
522 }
523 }
524 }
525 std::sort(triples.begin(), triples.end(), GreaterAbsPairOn2nd());
526 return triples;
527 }
528
531 std::vector< ProbabilisticRanking > SimpleMiic::unshieldedTriplesMiic_(
532 const MixedGraph& graph,
533 CorrectedMutualInformation& mutualInformation,
534 const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > >& sepSet,
535 HashTable< std::pair< NodeId, NodeId >, char >& marks) {
536 std::vector< ProbabilisticRanking > triples;
537 for (NodeId z: graph) {
538 for (NodeId x: graph.neighbours(z)) {
539 for (NodeId y: graph.neighbours(z)) {
540 if (y < x && !graph.existsEdge(x, y)) {
541 std::vector< NodeId > ui;
542 std::pair< NodeId, NodeId > key = {x, y};
543 std::pair< NodeId, NodeId > rev_key = {y, x};
544 if (sepSet.exists(key)) {
545 ui = sepSet[key];
546 } else if (sepSet.exists(rev_key)) {
547 ui = sepSet[rev_key];
548 }
549 // remove z from ui if it's present
550 const auto iter_z_place = std::find(ui.begin(), ui.end(), z);
551 if (iter_z_place != ui.end()) { ui.erase(iter_z_place); }
552
553 const double Ixyz_ui = mutualInformation.score(x, y, z, ui);
554 auto tup = new ThreePoints{x, y, z};
555 ProbabilisticRanking triple{tup, Ixyz_ui, 0.5, 0.5};
556 triples.push_back(triple);
557 if (!marks.exists({x, z})) { marks.insert({x, z}, 'o'); }
558 if (!marks.exists({z, x})) { marks.insert({z, x}, 'o'); }
559 if (!marks.exists({y, z})) { marks.insert({y, z}, 'o'); }
560 if (!marks.exists({z, y})) { marks.insert({z, y}, 'o'); }
561 }
562 }
563 }
564 }
565 triples = updateProbaTriples_(graph, triples);
566 std::sort(triples.begin(), triples.end(), GreaterTupleOnLast());
567 return triples;
568 }
569
571 std::vector< ProbabilisticRanking >
573 std::vector< ProbabilisticRanking > probaTriples) {
574 for (auto& triple: probaTriples) {
575 NodeId x, y, z;
576 x = std::get< 0 >(*std::get< 0 >(triple));
577 y = std::get< 1 >(*std::get< 0 >(triple));
578 z = std::get< 2 >(*std::get< 0 >(triple));
579 const double Ixyz = std::get< 1 >(triple);
580 double Pxz = std::get< 2 >(triple);
581 double Pyz = std::get< 3 >(triple);
582
583 if (Ixyz <= 0) {
584 const double expo = std::exp(Ixyz);
585 const double P0 = (1 + expo) / (1 + 3 * expo);
586 // distinguish between the initialization and the update process
587 if (Pxz == Pyz && Pyz == 0.5) {
588 std::get< 2 >(triple) = P0;
589 std::get< 3 >(triple) = P0;
590 } else {
591 if (graph.existsArc(x, z) && Pxz >= P0) {
592 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
593 } else if (graph.existsArc(y, z) && Pyz >= P0) {
594 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
595 }
596 }
597 } else {
598 const double expo = std::exp(-Ixyz);
599 if (graph.existsArc(x, z) && Pxz >= 0.5) {
600 std::get< 3 >(triple) = Pxz * (1 / (1 + expo) - 0.5) + 0.5;
601 } else if (graph.existsArc(y, z) && Pyz >= 0.5) {
602 std::get< 2 >(triple) = Pyz * (1 / (1 + expo) - 0.5) + 0.5;
603 }
604 }
605 }
606 std::sort(probaTriples.begin(), probaTriples.end(), GreaterTupleOnLast());
607 return probaTriples;
608 }
609
613
615 MixedGraph essentialGraph = learnMixedStructure(I, initialGraph);
616
617 // orientate remaining edges
618 const Sequence< NodeId > order = essentialGraph.topologicalOrder();
619
620 // first, forbidden arcs force arc in the other direction
621 for (NodeId x: order) {
622 const auto nei_x = essentialGraph.neighbours(x);
623 for (NodeId y: nei_x)
624 if (isForbidenArc_(x, y)) {
625 essentialGraph.eraseEdge(Edge(x, y));
626 if (isForbidenArc_(y, x)) {
627 // GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")")
628 } else {
629 // GUM_TRACE("Forced orientation : " << y << "->" << x)
630 essentialGraph.addArc(y, x);
631 }
632 } else if (isForbidenArc_(y, x)) {
633 essentialGraph.eraseEdge(Edge(x, y));
634 // GUM_TRACE("Forced orientation : " << x << "->" << y)
635 essentialGraph.addArc(x, y);
636 }
637 }
638
639 // then propagates existing orientations thanks to Meek rules
640 bool newOrientation = true;
641 while (newOrientation) {
642 newOrientation = false;
643 for (NodeId x: order) {
644 if (!essentialGraph.parents(x).empty()) {
645 newOrientation |= propagatesRemainingOrientableEdges_(essentialGraph, x);
646 }
647 }
648 }
649 return essentialGraph;
650 }
651
655 MixedGraph essentialGraph = learnMixedStructure(I, initialGraph);
656 // orientate remaining edges
657
658 const Sequence< NodeId > order = essentialGraph.topologicalOrder();
659
660 // first, forbidden arcs force arc in the other direction
661 for (NodeId x: order) {
662 const auto nei_x = essentialGraph.neighbours(x);
663 for (NodeId y: nei_x)
664 if (isForbidenArc_(x, y)) {
665 essentialGraph.eraseEdge(Edge(x, y));
666 if (isForbidenArc_(y, x)) {
667 // GUM_TRACE("Neither arc allowed for edge (" << x << "," << y << ")")
668 } else {
669 // GUM_TRACE("Forced orientation : " << y << "->" << x)
670 essentialGraph.addArc(y, x);
671 }
672 } else if (isForbidenArc_(y, x)) {
673 essentialGraph.eraseEdge(Edge(x, y));
674 // GUM_TRACE("Forced orientation : " << x << "->" << y)
675 essentialGraph.addArc(x, y);
676 }
677 }
678 // GUM_TRACE(essentialGraph.toDot());
679
680 // first, propagate existing orientations
681 bool newOrientation = true;
682 while (newOrientation) {
683 newOrientation = false;
684 for (NodeId x: order) {
685 if (!essentialGraph.parents(x).empty()) {
686 newOrientation |= propagatesRemainingOrientableEdges_(essentialGraph, x);
687 }
688 }
689 }
690 // GUM_TRACE(essentialGraph.toDot());
692 // GUM_TRACE(essentialGraph.toDot());
693
694 // then decide the orientation for double arcs
695 for (NodeId x: order)
696 for (NodeId y: essentialGraph.parents(x))
697 if (essentialGraph.parents(y).contains(x)) {
698 // GUM_TRACE(" + Resolving double arcs (poorly)")
699 essentialGraph.eraseArc(Arc(y, x));
700 }
701
702 DAG dag;
703 for (auto node: essentialGraph) {
704 dag.addNodeWithId(node);
705 }
706 for (const Arc& arc: essentialGraph.arcs()) {
707 dag.addArc(arc.tail(), arc.head());
708 }
709
710 return dag;
711 }
712
713 bool SimpleMiic::isOrientable_(const MixedGraph& graph, NodeId xi, NodeId xj) const {
714 // no cycle
715 if (_existsDirectedPath_(graph, xj, xi)) {
716 // GUM_TRACE("cycle(" << xi << "-" << xj << ")")
717 return false;
718 }
719
720 // R1
721 if (!(graph.parents(xi) - graph.boundary(xj)).empty()) {
722 // GUM_TRACE("R1(" << xi << "-" << xj << ")")
723 return true;
724 }
725
726 // R2
727 if (_existsDirectedPath_(graph, xi, xj)) {
728 // GUM_TRACE("R2(" << xi << "-" << xj << ")")
729 return true;
730 }
731
732 // R3
733 int nbr = 0;
734 for (const auto p: graph.parents(xj)) {
735 if (!graph.mixedOrientedPath(xi, p).empty()) {
736 nbr += 1;
737 if (nbr == 2) {
738 // GUM_TRACE("R3(" << xi << "-" << xj << ")")
739 return true;
740 }
741 }
742 }
743 return false;
744 }
745
747 // then decide the orientation for remaining edges
748 while (!essentialGraph.edges().empty()) {
749 const auto& edge = *(essentialGraph.edges().begin());
750 NodeId root = edge.first();
751 Size size_children_root = essentialGraph.children(root).size();
752 NodeSet visited;
753 NodeSet stack{root};
754 // check the best root for the set of neighbours
755 while (!stack.empty()) {
756 NodeId next = *(stack.begin());
757 stack.erase(next);
758 if (visited.contains(next)) continue;
759 if (essentialGraph.children(next).size() > size_children_root) {
760 size_children_root = essentialGraph.children(next).size();
761 root = next;
762 }
763 for (const auto n: essentialGraph.neighbours(next))
764 if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
765 visited.insert(next);
766 }
767 // orientation now
768 visited.clear();
769 stack.clear();
770 stack.insert(root);
771 while (!stack.empty()) {
772 NodeId next = *(stack.begin());
773 stack.erase(next);
774 if (visited.contains(next)) continue;
775 const auto nei = essentialGraph.neighbours(next);
776 for (const auto n: nei) {
777 if (!stack.contains(n) && !visited.contains(n)) stack.insert(n);
778 // GUM_TRACE(" + amap reasonably orientation for " << n << "->" << next);
779 if (propagatesRemainingOrientableEdges_(essentialGraph, next)) continue;
780 else essentialGraph.eraseEdge(Edge(n, next));
781 essentialGraph.addArc(n, next);
782 }
783 visited.insert(next);
784 }
785 }
786 }
787
790 bool res = false;
791 const auto neighbours = graph.neighbours(xj);
792 for (auto& xi: neighbours) {
793 bool i_j = isOrientable_(graph, xi, xj);
794 bool j_i = isOrientable_(graph, xj, xi);
795 if (i_j || j_i) {
796 // GUM_TRACE(" + Removing edge (" << xi << "," << xj << ")")
797 graph.eraseEdge(Edge(xi, xj));
798 res = true;
799 }
800 if (i_j) {
801 // GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
802 graph.addArc(xi, xj);
804 }
805 if (j_i) {
806 // GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
807 graph.addArc(xj, xi);
809 }
810 if (i_j && j_i) {
811 GUM_TRACE(" + add arc (" << xi << "," << xj << ")")
812 _latentCouples_.emplace_back(xi, xj);
813 }
814 }
815
816 return res;
817 }
818
820 const std::vector< Arc > SimpleMiic::latentVariables() const {
821 GUM_CHECKPOINT
822 return _latentCouples_;
823 }
824
826 template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
827 BayesNet< GUM_SCALAR > SimpleMiic::learnBN(GRAPH_CHANGES_SELECTOR& selector,
828 PARAM_ESTIMATOR& estimator,
829 DAG initial_dag) {
831 learnStructure(selector, initial_dag));
832 }
833
834 void SimpleMiic::addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints) {
835 this->_initialMarks_ = constraints;
836 }
837
839 const NodeId n1,
840 const NodeId n2) {
841 for (const auto parent: graph.parents(n2)) {
842 if (graph.existsArc(parent,
843 n2)) // if there is a double arc, pass
844 continue;
845 if (parent == n1) // trivial directed path => not recognized
846 continue;
847 if (_existsDirectedPath_(graph, n1, parent)) return true;
848 }
849 return false;
850 }
851
853 const NodeId n1,
854 const NodeId n2) {
855 // not recursive version => use a FIFO for simulating the recursion
856 List< NodeId > nodeFIFO;
857 // mark[node] = successor if visited, else mark[node] does not exist
858 Set< NodeId > mark;
859
860 mark.insert(n2);
861 nodeFIFO.pushBack(n2);
862
863 NodeId current;
864
865 while (!nodeFIFO.empty()) {
866 current = nodeFIFO.front();
867 nodeFIFO.popFront();
868
869 // check the parents
870 for (const auto new_one: graph.parents(current)) {
871 if (graph.existsArc(current,
872 new_one)) // if there is a double arc, pass
873 continue;
874
875 if (new_one == n1) { return true; }
876
877 if (mark.exists(new_one)) // if this node is already marked, do not
878 continue; // check it again
879
880 mark.insert(new_one);
881 nodeFIFO.pushBack(new_one);
882 }
883 }
884
885 return false;
886 }
887
888 void
890 HashTable< std::pair< NodeId, NodeId >, char >& marks,
891 NodeId x,
892 NodeId y,
893 NodeId z,
894 double p1,
895 double p2) {
896 // v-structure discovery
897 if (marks[{x, z}] == 'o' && marks[{y, z}] == 'o') { // If x-z-y
898 if (!_existsNonTrivialDirectedPath_(graph, z, x)) {
899 graph.eraseEdge(Edge(x, z));
900 graph.addArc(x, z);
901 // GUM_TRACE("1.a Removing edge (" << x << "," << z << ")")
902 // GUM_TRACE("1.a Adding arc (" << x << "," << z << ")")
903 marks[{x, z}] = '>';
904 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
905 GUM_TRACE("Adding latent couple (" << z << "," << x << ")")
906 _latentCouples_.emplace_back(z, x);
907 }
908 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
909 } else {
910 graph.eraseEdge(Edge(x, z));
911 // GUM_TRACE("1.b Adding arc (" << x << "," << z << ")")
912 if (!_existsNonTrivialDirectedPath_(graph, x, z)) {
913 graph.addArc(z, x);
914 // GUM_TRACE("1.b Removing edge (" << x << "," << z << ")")
915 marks[{z, x}] = '>';
916 }
917 }
918
919 if (!_existsNonTrivialDirectedPath_(graph, z, y)) {
920 graph.eraseEdge(Edge(y, z));
921 graph.addArc(y, z);
922 // GUM_TRACE("1.c Removing edge (" << y << "," << z << ")")
923 // GUM_TRACE("1.c Adding arc (" << y << "," << z << ")")
924 marks[{y, z}] = '>';
925 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
926 GUM_TRACE("Adding latent couple (" << z << "," << y << ")")
927 _latentCouples_.emplace_back(z, y);
928 }
929 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
930 } else {
931 graph.eraseEdge(Edge(y, z));
932 // GUM_TRACE("1.d Removing edge (" << y << "," << z << ")")
933 if (!_existsNonTrivialDirectedPath_(graph, y, z)) {
934 graph.addArc(z, y);
935 // GUM_TRACE("1.d Adding arc (" << z << "," << y << ")")
936 marks[{z, y}] = '>';
937 }
938 }
939 } else if (marks[{x, z}] == '>' && marks[{y, z}] == 'o') { // If x->z-y
940 if (!_existsNonTrivialDirectedPath_(graph, z, y)) {
941 graph.eraseEdge(Edge(y, z));
942 graph.addArc(y, z);
943 // GUM_TRACE("2.a Removing edge (" << y << "," << z << ")")
944 // GUM_TRACE("2.a Adding arc (" << y << "," << z << ")")
945 marks[{y, z}] = '>';
946 if (graph.existsArc(z, y) && _isNotLatentCouple_(z, y)) {
947 GUM_TRACE("Adding latent couple (" << z << "," << y << ")")
948 _latentCouples_.emplace_back(z, y);
949 }
950 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
951 } else {
952 graph.eraseEdge(Edge(y, z));
953 // GUM_TRACE("2.b Removing edge (" << y << "," << z << ")")
954 if (!_existsNonTrivialDirectedPath_(graph, y, z)) {
955 graph.addArc(z, y);
956 // GUM_TRACE("2.b Adding arc (" << y << "," << z << ")")
957 marks[{z, y}] = '>';
958 }
959 }
960 } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o') {
961 if (!_existsNonTrivialDirectedPath_(graph, z, x)) {
962 graph.eraseEdge(Edge(x, z));
963 graph.addArc(x, z);
964 // GUM_TRACE("3.a Removing edge (" << x << "," << z << ")")
965 // GUM_TRACE("3.a Adding arc (" << x << "," << z << ")")
966 marks[{x, z}] = '>';
967 if (graph.existsArc(z, x) && _isNotLatentCouple_(z, x)) {
968 GUM_TRACE("Adding latent couple (" << z << "," << x << ")")
969 _latentCouples_.emplace_back(z, x);
970 }
971 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
972 } else {
973 graph.eraseEdge(Edge(x, z));
974 // GUM_TRACE("3.b Removing edge (" << x << "," << z << ")")
975 if (!_existsNonTrivialDirectedPath_(graph, x, z)) {
976 graph.addArc(z, x);
977 // GUM_TRACE("3.b Adding arc (" << x << "," << z << ")")
978 marks[{z, x}] = '>';
979 }
980 }
981 }
982 }
983
985 MixedGraph& graph,
986 HashTable< std::pair< NodeId, NodeId >, char >& marks,
987 NodeId x,
988 NodeId y,
989 NodeId z,
990 double p1,
991 double p2) {
992 // orientation propagation
993 if (marks[{x, z}] == '>' && marks[{y, z}] == 'o' && marks[{z, y}] != '-') {
994 graph.eraseEdge(Edge(z, y));
995 // std::cout << "4. Removing edge (" << z << "," << y << ")" <<
996 // std::endl;
997 if (!_existsDirectedPath_(graph, y, z) && graph.parents(y).empty()) {
998 graph.addArc(z, y);
999 // GUM_TRACE("4.a Adding arc (" << z << "," << y << ")")
1000 marks[{z, y}] = '>';
1001 marks[{y, z}] = '-';
1002 if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1003 } else if (!_existsDirectedPath_(graph, z, y) && graph.parents(z).empty()) {
1004 graph.addArc(y, z);
1005 GUM_TRACE("4.b Adding arc (" << y << "," << z << ")")
1006 marks[{z, y}] = '-';
1007 marks[{y, z}] = '>';
1008 _latentCouples_.emplace_back(y, z);
1009 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1010 } else if (!_existsDirectedPath_(graph, y, z)) {
1011 graph.addArc(z, y);
1012 // GUM_TRACE("4.c Adding arc (" << z << "," << y << ")")
1013 marks[{z, y}] = '>';
1014 marks[{y, z}] = '-';
1015 if (!_arcProbas_.exists(Arc(z, y))) _arcProbas_.insert(Arc(z, y), p2);
1016 } else if (!_existsDirectedPath_(graph, z, y)) {
1017 graph.addArc(y, z);
1018 GUM_TRACE("4.d Adding arc (" << y << "," << z << ")")
1019 _latentCouples_.emplace_back(y, z);
1020 marks[{z, y}] = '-';
1021 marks[{y, z}] = '>';
1022 if (!_arcProbas_.exists(Arc(y, z))) _arcProbas_.insert(Arc(y, z), p2);
1023 }
1024 } else if (marks[{y, z}] == '>' && marks[{x, z}] == 'o' && marks[{z, x}] != '-') {
1025 graph.eraseEdge(Edge(z, x));
1026 // GUM_TRACE("5. Removing edge (" << z << "," << x << ")")
1027 if (!_existsDirectedPath_(graph, x, z) && graph.parents(x).empty()) {
1028 graph.addArc(z, x);
1029 // GUM_TRACE("5.a Adding arc (" << z << "," << x << ")")
1030 marks[{z, x}] = '>';
1031 marks[{x, z}] = '-';
1032 if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1033 } else if (!_existsDirectedPath_(graph, z, x) && graph.parents(z).empty()) {
1034 graph.addArc(x, z);
1035 GUM_TRACE("5.b Adding arc (" << x << "," << z << ")")
1036 marks[{z, x}] = '-';
1037 marks[{x, z}] = '>';
1038 _latentCouples_.emplace_back(x, z);
1039 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1040 } else if (!_existsDirectedPath_(graph, x, z)) {
1041 graph.addArc(z, x);
1042 // GUM_TRACE("5.c Adding arc (" << z << "," << x << ")")
1043 marks[{z, x}] = '>';
1044 marks[{x, z}] = '-';
1045 if (!_arcProbas_.exists(Arc(z, x))) _arcProbas_.insert(Arc(z, x), p1);
1046 } else if (!_existsDirectedPath_(graph, z, x)) {
1047 graph.addArc(x, z);
1048 GUM_TRACE("5.d Adding arc (" << x << "," << z << ")")
1049 marks[{z, x}] = '-';
1050 marks[{x, z}] = '>';
1051 _latentCouples_.emplace_back(x, z);
1052 if (!_arcProbas_.exists(Arc(x, z))) _arcProbas_.insert(Arc(x, z), p1);
1053 }
1054 }
1055 }
1056
1058 const auto& lbeg = _latentCouples_.begin();
1059 const auto& lend = _latentCouples_.end();
1060
1061 return (std::find(lbeg, lend, Arc(x, y)) == lend)
1062 && (std::find(lbeg, lend, Arc(y, x)) == lend);
1063 }
1064
1066 return (_initialMarks_.exists({x, y}) && _initialMarks_[{x, y}] == '-');
1067 }
1068 } /* namespace learning */
1069
1070} /* namespace gum */
A class that, given a structure and a parameter estimator, returns a full Bayes net.
The SimpleMiic algorithm.
Size current_step_
The current step.
ApproximationScheme(bool verbosity=false)
bool existsArc(const Arc &arc) const
indicates whether a given arc exists
const NodeSet & parents(NodeId id) const
returns the set of nodes with arc ingoing to a given node
NodeSet children(const NodeSet &ids) const
returns the set of children of a set of nodes
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
std::vector< NodeId > directedPath(NodeId node1, NodeId node2) const
returns a directed path from node1 to node2 belonging to the set of arcs
const ArcSet & arcs() const
returns the set of arcs stored within the ArcGraphPart
The base class for all directed edges.
Base class for dag.
Definition DAG.h:121
void addArc(NodeId tail, NodeId head) final
insert a new arc into the directed graph
Definition DAG_inl.h:63
virtual void addArc(const NodeId tail, const NodeId head)
insert a new arc into the directed graph
Definition diGraph_inl.h:55
Sequence< NodeId > topologicalOrder() const
Build and return a topological order.
Definition diGraph.cpp:111
virtual void eraseEdge(const Edge &edge)
removes an edge from the EdgeGraphPart
const EdgeSet & edges() const
returns the set of edges stored within the EdgeGraphPart
bool existsEdge(const Edge &edge) const
indicates whether a given edge exists
const NodeSet & neighbours(NodeId id) const
returns the set of node neighbours to a given node
The base class for all undirected edges.
The class for generic Hash Tables.
Definition hashTable.h:637
iterator begin()
Returns an unsafe iterator pointing to the beginning of the hashtable.
const iterator & end() noexcept
Returns the unsafe iterator pointing to the end of the hashtable.
Heap data structure.
Definition heap.h:141
Size size() const noexcept
Returns the number of elements in the heap.
Definition heap_tpl.h:148
Val pop()
Removes the top element from the heap and return it.
Definition heap_tpl.h:213
Size insert(const Val &val)
inserts a new element (actually a copy) in the heap and returns its index
Definition heap_tpl.h:239
const Val & top() const
Returns the element at the top of the heap.
Definition heap_tpl.h:140
Signaler3< Size, double, double > onProgress
Progression, error and time.
Generic doubly linked lists.
Definition list.h:379
Val & front() const
Returns a reference to the first element of a list, if any.
Definition list_tpl.h:1703
bool empty() const noexcept
Returns a boolean indicating whether the chained list is empty.
Definition list_tpl.h:1831
void popFront()
Removes the first element of a List, if any.
Definition list_tpl.h:1825
Val & pushBack(const Val &val)
Inserts a new element (a copy) at the end of the chained list.
Definition list_tpl.h:1488
Base class for mixed graphs.
Definition mixedGraph.h:146
NodeSet boundary(NodeId node) const
returns the set of nodes adjacent to a given node
std::vector< NodeId > mixedOrientedPath(NodeId node1, NodeId node2) const
returns a mixed edge/directed arc path from node1 to node2 in the arc/edge set
virtual void addNodeWithId(const NodeId id)
try to insert a node with the given id
iterator begin() const
The usual unsafe begin iterator to parse the set.
Definition set_tpl.h:438
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
bool exists(const Key &k) const
Indicates whether a given element belongs to the set.
Definition set_tpl.h:533
bool contains(const Key &k) const
Indicates whether a given element belongs to the set.
Definition set_tpl.h:497
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
bool empty() const noexcept
Indicates whether the set is the empty set.
Definition set_tpl.h:642
void erase(const Key &k)
Erases an element from the set.
Definition set_tpl.h:582
void clear()
Removes all the elements, if any, from the set.
Definition set_tpl.h:338
The class computing n times the corrected mutual information, as used in the MIIC algorithm.
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
static BayesNet< GUM_SCALAR > createBN(ParamEstimator &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
DAG learnStructure(CorrectedMutualInformation &I, MixedGraph graph)
learns the structure of a Bayesian network, i.e. a DAG, by first learning an Essential graph and then...
bool isOrientable_(const MixedGraph &graph, NodeId xi, NodeId xj) const
const std::vector< Arc > latentVariables() const
get the list of arcs hiding latent variables
const std::vector< NodeId > _emptySet_
an empty conditioning set
Definition SimpleMiic.h:291
MixedGraph learnMixedStructure(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of an Essential Graph
SimpleMiic & operator=(const SimpleMiic &from)
copy assignment operator
void orientationMiic_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
Orientation phase from the MIIC algorithm; updates the given mixed graph in place, which may then contain circles.
void _propagatingOrientationMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
void iteration_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Iteration phase.
bool _isNotLatentCouple_(NodeId x, NodeId y)
int _maxLog_
Fixes the maximum log that we accept in exponential computations.
Definition SimpleMiic.h:289
void _orientingVstructureMiic_(MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, char > &marks, NodeId x, NodeId y, NodeId z, double p1, double p2)
~SimpleMiic() override
destructor
std::vector< ProbabilisticRanking > unshieldedTriplesMiic_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, HashTable< std::pair< NodeId, NodeId >, char > &marks)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
ArcProperty< double > _arcProbas_
Stores the probabilities for each arc set in the graph.
Definition SimpleMiic.h:299
std::vector< Arc > _latentCouples_
the vector of arcs hiding latent variables (initially empty, filled during orientation)
Definition SimpleMiic.h:293
static bool _existsDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, considering double arcs as edges
HashTable< std::pair< NodeId, NodeId >, char > _initialMarks_
Initial marks for the orientation phase, used to convey constraints.
Definition SimpleMiic.h:302
SimpleMiic()
default constructor
Size _size_
size of the database
Definition SimpleMiic.h:296
void propagatesOrientationInChainOfRemainingEdges_(MixedGraph &graph)
heuristic for remaining edges when everything else has been tried
void orientationLatents_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
variant trying to propagate both orientations in a bidirected arc
MixedGraph learnPDAG(CorrectedMutualInformation &mutualInformation, MixedGraph graph)
learns the structure of an Essential Graph
bool propagatesRemainingOrientableEdges_(MixedGraph &graph, NodeId xj)
Propagates the orientation from a node to its neighbours.
void findBestContributor_(NodeId x, NodeId y, const std::vector< NodeId > &ui, const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, Heap< CondRanking, GreaterPairOn2nd > &rank)
finds the best contributor node for a pair given a conditioning set
bool isForbidenArc_(NodeId x, NodeId y) const
void addConstraints(HashTable< std::pair< NodeId, NodeId >, char > constraints)
Sets a collection of constraints for the orientation phase.
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
static bool _existsNonTrivialDirectedPath_(const MixedGraph &graph, NodeId n1, NodeId n2)
checks for directed paths in a graph, considering double arcs like edges, not considering arc as a di...
std::vector< Ranking > unshieldedTriples_(const MixedGraph &graph, CorrectedMutualInformation &mutualInformation, const HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet)
gets the list of unshielded triples in the graph in decreasing value of |I'(x, y, z|{ui}...
std::vector< ProbabilisticRanking > updateProbaTriples_(const MixedGraph &graph, std::vector< ProbabilisticRanking > probaTriples)
Gets the orientation probabilities like MIIC for the orientation phase.
void initiation_(CorrectedMutualInformation &mutualInformation, MixedGraph &graph, HashTable< std::pair< NodeId, NodeId >, std::vector< NodeId > > &sepSet, Heap< CondRanking, GreaterPairOn2nd > &rank)
Initiation phase.
The class computing n times the corrected mutual information (where n is the size (or the weight) of ...
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Set< Edge > EdgeSet
Some typedefs and defines for shortcuts ...
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
Class hash tables iterators.
Heaps definition.
Useful macros for maths.
#define M_LN2
Definition math_utils.h:63
Base classes for mixed directed/undirected graphs.
include the inlined functions if necessary
Definition CSVParser.h:54
std::pair< ThreePoints *, double > Ranking
Definition Miic.h:90
std::pair< CondThreePoints *, double > CondRanking
Definition Miic.h:87
std::tuple< NodeId, NodeId, NodeId, std::vector< NodeId > > CondThreePoints
Definition Miic.h:86
std::tuple< NodeId, NodeId, NodeId > ThreePoints
Definition Miic.h:89
std::tuple< ThreePoints *, double, double, double > ProbabilisticRanking
Definition Miic.h:92
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.
#define GUM_EMIT3(signal, arg1, arg2, arg3)
Definition signaler3.h:61
Class used to compute response times for benchmark purposes.