aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
CNLoopyPropagation_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
44
45namespace gum::credal {
46
// ---------------------------------------------------------------------------
// saveInference — writes the current inference results to a ".res" file.
// NOTE(review): this is a scraped doxygen listing; the numbers starting each
// line are doxygen artifacts, and the extraction lost the signature line
// (original line 48, presumably
//   void CNLoopyPropagation<GUM_SCALAR>::saveInference_(const std::string& path)
// — TODO confirm against the real header) as well as the GUM_ERROR(IOError,…
// openings at original lines 55 and 68 (only their message strings remain).
// Behaviour visible from the body: the 4-char extension of `path` is replaced
// by ".res"; if `path` ends in "evi" the evidence file is copied verbatim into
// the output, then a "[RESULTATS]" section lists, for every node, the
// posterior interval for both binary labels.
// ---------------------------------------------------------------------------
 47 template < typename GUM_SCALAR >
 49 std::string path_name = path.substr(0, path.size() - 4);
 50 path_name = path_name + ".res";
 51
 52 std::ofstream res(path_name.c_str(), std::ios::out | std::ios::trunc);
 53
 54 if (!res.good()) {
 56 "CNLoopyPropagation<GUM_SCALAR>::saveInference(std::"
 57 "string & path) : could not open file : "
 58 + path_name)
 59 }
 60
 61
 62 if (std::string ext = path.substr(path.size() - 3, path.size());
 63 std::strcmp(ext.c_str(), "evi") == 0) {
 64 std::ifstream evi(path.c_str(), std::ios::in);
 65 std::string ligne;
 66
 67 if (!evi.good()) {
 69 "CNLoopyPropagation<GUM_SCALAR>::saveInference(std::"
 70 "string & path) : could not open file : "
 71 + ext)
 72 }
 73
 74 while (evi.good()) {
 75 getline(evi, ligne);
 76 res << ligne << "\n";
 77 }
 78
 79 evi.close();
 80 }
 81
 82 res << "[RESULTATS]"
 83 << "\n";
 84
 85 for (auto node: _bnet_->nodes()) {
 86 // compute the posterior distribution
 87 GUM_SCALAR msg_p_min = 1.0;
 88 GUM_SCALAR msg_p_max = 0.0;
 89
 90 // evidence case: the posterior is determined immediately
 91 if (_infE_::evidence_.exists(node)) {
 92 if (_infE_::evidence_[node][1] == 0.) {
 93 msg_p_min = 0.;
 94 } else if (_infE_::evidence_[node][1] == 1.) {
 95 msg_p_min = 1.;
 96 }
 97
 98 msg_p_max = msg_p_min;
 99 }
 100 // otherwise derive it from the node's P (pi) and L (lambda) values
 101 else {
 102 GUM_SCALAR min = NodesP_min_[node];
 103 GUM_SCALAR max;
 104
 105 if (NodesP_max_.exists(node)) {
 106 max = NodesP_max_[node];
 107 } else {
 108 max = min;
 109 }
 110
 111 GUM_SCALAR lmin = NodesL_min_[node];
 112 GUM_SCALAR lmax;
 113
 114 if (NodesL_max_.exists(node)) {
 115 lmax = NodesL_max_[node];
 116 } else {
 117 lmax = lmin;
 118 }
 119
 120 // limit cases for the min bound
 121 if (min == INF_ && lmin == 0.) {
 122 std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
 123 }
 124
 125 if (lmin == INF_) { // infinite case
 126 msg_p_min = GUM_SCALAR(1.);
 127 } else if (min == 0. || lmin == 0.) {
 128 msg_p_min = GUM_SCALAR(0.);
 129 } else {
 130 msg_p_min = GUM_SCALAR(1. / (1. + ((1. / min - 1.) * 1. / lmin)));
 131 }
 132
 133 // limit cases for the max bound
 134 if (max == INF_ && lmax == 0.) {
 135 std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
 136 }
 137
 138 if (lmax == INF_) { // infinite case
 139 msg_p_max = GUM_SCALAR(1.);
 140 } else if (max == 0. || lmax == 0.) {
 141 msg_p_max = GUM_SCALAR(0.);
 142 } else {
 143 msg_p_max = GUM_SCALAR(1. / (1. + ((1. / max - 1.) * 1. / lmax)));
 144 }
 145 }
 146
 // x != x is the classic NaN test: when only one bound is NaN, fall back
 // on the other one.
 147 if (msg_p_min != msg_p_min && msg_p_max == msg_p_max) { msg_p_min = msg_p_max; }
 148
 149 if (msg_p_max != msg_p_max && msg_p_min == msg_p_min) { msg_p_max = msg_p_min; }
 150
 151 if (msg_p_max != msg_p_max && msg_p_min != msg_p_min) {
 152 std::cout << std::endl;
 153 std::cout << "pas de proba calculable (verifier observations)" << std::endl;
 154 }
 155
 156 res << "P(" << _bnet_->variable(node).name() << " | e) = ";
 157
 158 if (_infE_::evidence_.exists(node)) {
 159 res << "(observe)" << std::endl;
 160 } else {
 161 res << std::endl;
 162 }
 163
 // label(0) interval is the complement of the label(1) interval, hence
 // the 1 - msg_p_max / 1 - msg_p_min swap.
 164 res << "\t\t" << _bnet_->variable(node).label(0) << " [ " << (GUM_SCALAR)1. - msg_p_max;
 165
 166 if (msg_p_min != msg_p_max) {
 167 res << ", " << (GUM_SCALAR)1. - msg_p_min << " ] | ";
 168 } else {
 169 res << " ] | ";
 170 }
 171
 172 res << _bnet_->variable(node).label(1) << " [ " << msg_p_min;
 173
 174 if (msg_p_min != msg_p_max) {
 175 res << ", " << msg_p_max << " ]" << std::endl;
 176 } else {
 177 res << " ]" << std::endl;
 178 }
 179 } // end of : for each node
 180
 181 res.close();
 182 }
183
189
// ---------------------------------------------------------------------------
// compute_ext_ (scalar version) — widens the lambda-message interval
// [msg_l_min, msg_l_max] over every candidate likelihood value in `lx`,
// using interval arithmetic on the numerator/denominator bounds
// (num_min/num_max for X = 1, den_min/den_max for X = 0).
// NOTE(review): the declaration line (original line 194, carrying the method
// name and the msg_l_min parameter) was lost by the doc extractor — TODO
// confirm the exact signature against the real header.
// Sentinels used by the callers: -2 means "not yet set", INF_ means infinite.
// ---------------------------------------------------------------------------
 193 template < typename GUM_SCALAR >
 195 GUM_SCALAR& msg_l_max,
 196 std::vector< GUM_SCALAR >& lx,
 197 GUM_SCALAR& num_min,
 198 GUM_SCALAR& num_max,
 199 GUM_SCALAR& den_min,
 200 GUM_SCALAR& den_max) {
 201 GUM_SCALAR num_min_tmp = 1.;
 202 GUM_SCALAR den_min_tmp = 1.;
 203 GUM_SCALAR num_max_tmp = 1.;
 204 GUM_SCALAR den_max_tmp = 1.;
 205
 206 GUM_SCALAR res_min = 1.0;
 207 GUM_SCALAR res_max = 0.0;
 208
 209 auto lsize = lx.size();
 210
 211 for (decltype(lsize) i = 0; i < lsize; i++) {
 212 bool non_defini_min = false;
 213 bool non_defini_max = false;
 214
 // pick numerator/denominator combinations depending on the position of
 // lx[i] relative to 1 (the neutral likelihood); for lx[i] != 1 the term
 // li = 1/(lx[i]-1) shifts both ratios.
 215 if (lx[i] == INF_) {
 216 num_min_tmp = num_min;
 217 den_min_tmp = den_max;
 218 num_max_tmp = num_max;
 219 den_max_tmp = den_min;
 220 } else if (lx[i] == (GUM_SCALAR)1.) {
 221 num_min_tmp = GUM_SCALAR(1.);
 222 den_min_tmp = GUM_SCALAR(1.);
 223 num_max_tmp = GUM_SCALAR(1.);
 224 den_max_tmp = GUM_SCALAR(1.);
 225 } else if (lx[i] > (GUM_SCALAR)1.) {
 226 GUM_SCALAR li = GUM_SCALAR(1.) / (lx[i] - GUM_SCALAR(1.));
 227 num_min_tmp = num_min + li;
 228 den_min_tmp = den_max + li;
 229 num_max_tmp = num_max + li;
 230 den_max_tmp = den_min + li;
 231 } else if (lx[i] < (GUM_SCALAR)1.) {
 // here li < 0, so the min/max roles of numerator and denominator swap
 232 GUM_SCALAR li = GUM_SCALAR(1.) / (lx[i] - GUM_SCALAR(1.));
 233 num_min_tmp = num_max + li;
 234 den_min_tmp = den_min + li;
 235 num_max_tmp = num_min + li;
 236 den_max_tmp = den_max + li;
 237 }
 238
 // 0/0 is undefined, x/0 is infinite; only divide when the ratio is finite
 239 if (den_min_tmp == 0. && num_min_tmp == 0.) {
 240 non_defini_min = true;
 241 } else if (den_min_tmp == 0. && num_min_tmp != 0.) {
 242 res_min = INF_;
 243 } else if (den_min_tmp != INF_ || num_min_tmp != INF_) {
 244 res_min = num_min_tmp / den_min_tmp;
 245 }
 246
 247 if (den_max_tmp == 0. && num_max_tmp == 0.) {
 248 non_defini_max = true;
 249 } else if (den_max_tmp == 0. && num_max_tmp != 0.) {
 250 res_max = INF_;
 251 } else if (den_max_tmp != INF_ || num_max_tmp != INF_) {
 252 res_max = num_max_tmp / den_max_tmp;
 253 }
 254
 255 if (non_defini_max && non_defini_min) {
 256 std::cout << "undefined msg" << std::endl;
 257 continue;
 258 } else if (non_defini_min && !non_defini_max) {
 259 res_min = res_max;
 260 } else if (non_defini_max && !non_defini_min) {
 261 res_max = res_min;
 262 }
 263
 // clamp numerical noise: a likelihood ratio cannot be negative
 264 if (res_min < 0.) { res_min = 0.; }
 265
 266 if (res_max < 0.) { res_max = 0.; }
 267
 // -2 is the "uninitialized interval" sentinel set by msgL_
 268 if (msg_l_min == msg_l_max && msg_l_min == -2.) {
 269 msg_l_min = res_min;
 270 msg_l_max = res_max;
 271 }
 272
 273 if (res_max > msg_l_max) { msg_l_max = res_max; }
 274
 275 if (res_min < msg_l_min) { msg_l_min = res_min; }
 276
 277 } // end of : for each lx
 278 }
279
// ---------------------------------------------------------------------------
// compute_ext_ (combination version, lambda message) — for ONE combination of
// parent messages `combi_msg_p`, marginalizes the binary CPT interval bounds
// of node `id` into numerator bounds (entries offset by `pos`, i.e. the
// receiving parent set to 1) and denominator bounds (receiving parent set to
// 0), then delegates to the scalar compute_ext_ to update the message
// interval.
// NOTE(review): the declaration line (original line 284) was lost by the doc
// extractor — TODO confirm the exact signature against the real header.
// ---------------------------------------------------------------------------
 283 template < typename GUM_SCALAR >
 285 std::vector< std::vector< GUM_SCALAR > >& combi_msg_p,
 286 const NodeId& id,
 287 GUM_SCALAR& msg_l_min,
 288 GUM_SCALAR& msg_l_max,
 289 std::vector< GUM_SCALAR >& lx,
 290 const Idx& pos) {
 291 GUM_SCALAR num_min = 0.;
 292 GUM_SCALAR num_max = 0.;
 293 GUM_SCALAR den_min = 0.;
 294 GUM_SCALAR den_max = 0.;
 295
 296 auto taille = combi_msg_p.size();
 297
 // one iterator per parent distribution, advanced odometer-style below
 298 std::vector< typename std::vector< GUM_SCALAR >::iterator > it(taille);
 299
 300 for (decltype(taille) i = 0; i < taille; i++) {
 301 it[i] = combi_msg_p[i].begin();
 302 }
 303
 304 Size pp = pos;
 305
 // combi_den indexes CPT rows where the receiving parent is 0, combi_num
 // the rows (offset pp) where it is 1
 306 Size combi_den = 0;
 307 Size combi_num = pp;
 308
 309 // marginalisation
 310 while (it[taille - 1] != combi_msg_p[taille - 1].end()) {
 311 GUM_SCALAR prod = 1.;
 312
 313 for (decltype(taille) k = 0; k < taille; k++) {
 314 prod *= *it[k];
 315 }
 316
 317 den_min += (_cn_->get_binaryCPT_min()[id][combi_den] * prod);
 318 den_max += (_cn_->get_binaryCPT_max()[id][combi_den] * prod);
 319
 320 num_min += (_cn_->get_binaryCPT_min()[id][combi_num] * prod);
 321 num_max += (_cn_->get_binaryCPT_max()[id][combi_num] * prod);
 322
 323 combi_den++;
 324 combi_num++;
 325
 // skip the block of rows belonging to the other value of the
 // receiving parent
 326 if (pp != 0) {
 327 if (combi_den % pp == 0) {
 328 combi_den += pp;
 329 combi_num += pp;
 330 }
 331 }
 332
 333 // odometer-style increment of the parent-message iterators
 334 ++it[0];
 335
 336 for (decltype(taille) i = 0; (i < taille - 1) && (it[i] == combi_msg_p[i].end()); ++i) {
 337 it[i] = combi_msg_p[i].begin();
 338 ++it[i + 1];
 339 }
 340 } // end of : marginalisation
 341
 342 compute_ext_(msg_l_min, msg_l_max, lx, num_min, num_max, den_min, den_max);
 343 }
344
// ---------------------------------------------------------------------------
// compute_ext_ (combination version, pi message) — for ONE combination of
// parent messages `combi_msg_p`, marginalizes the binary CPT interval bounds
// of node `id` into [min, max] = bounds on P(id = 1), then tightens the
// running extrema msg_p_min / msg_p_max.
// NOTE(review): the declaration line (original line 350) was lost by the doc
// extractor — TODO confirm the exact signature against the real header.
// ---------------------------------------------------------------------------
 349 template < typename GUM_SCALAR >
 351 std::vector< std::vector< GUM_SCALAR > >& combi_msg_p,
 352 const NodeId& id,
 353 GUM_SCALAR& msg_p_min,
 354 GUM_SCALAR& msg_p_max) {
 355 GUM_SCALAR min = 0.;
 356 GUM_SCALAR max = 0.;
 357
 358 auto taille = combi_msg_p.size();
 359
 // one iterator per parent distribution, advanced odometer-style below
 360 std::vector< typename std::vector< GUM_SCALAR >::iterator > it(taille);
 361
 362 for (decltype(taille) i = 0; i < taille; i++) {
 363 it[i] = combi_msg_p[i].begin();
 364 }
 365
 366 int combi = 0;
 367 auto theEnd = combi_msg_p[taille - 1].end();
 368
 369 while (it[taille - 1] != theEnd) {
 370 GUM_SCALAR prod = 1.;
 371
 372 for (decltype(taille) k = 0; k < taille; k++) {
 373 prod *= *it[k];
 374 }
 375
 376 min += (_cn_->get_binaryCPT_min()[id][combi] * prod);
 377 max += (_cn_->get_binaryCPT_max()[id][combi] * prod);
 378
 379 combi++;
 380
 381 // odometer-style increment of the parent-message iterators
 382 ++it[0];
 383
 384 for (decltype(taille) i = 0; (i < taille - 1) && (it[i] == combi_msg_p[i].end()); ++i) {
 385 it[i] = combi_msg_p[i].begin();
 386 ++it[i + 1];
 387 }
 388 }
 389
 // keep the extrema over all combinations seen so far
 390 if (min < msg_p_min) { msg_p_min = min; }
 391
 392 if (max > msg_p_max) { msg_p_max = max; }
 393 }
394
// ---------------------------------------------------------------------------
// enum_combi_ (pi version) — enumerates all combinations of the per-parent
// message distributions `msgs_p` (each parent contributes 1 or 2 candidate
// distributions) and calls the combination compute_ext_ on each, in parallel,
// tightening [msg_p_min, msg_p_max].
// NOTE(review): the declaration line (original line 399) and the line that
// initializes `nb_threads` (original line 420 — presumably a check on
// ThreadExecutor for nested multithreading, given the ?: continuation right
// below) were lost by the doc extractor — TODO confirm against the header.
// ---------------------------------------------------------------------------
 398 template < typename GUM_SCALAR >
 400 std::vector< std::vector< std::vector< GUM_SCALAR > > >& msgs_p,
 401 const NodeId& id,
 402 GUM_SCALAR& msg_p_min,
 403 GUM_SCALAR& msg_p_max) {
 404 auto taille = msgs_p.size();
 405
 406 // source node
 407 if (taille == 0) {
 408 msg_p_min = _cn_->get_binaryCPT_min()[id][0];
 409 msg_p_max = _cn_->get_binaryCPT_max()[id][0];
 410 return;
 411 }
 412
 // total number of message combinations = product of per-parent counts
 413 Size msgPerm = 1;
 414 for (Size i = 0; i < taille; i++) {
 415 msgPerm *= msgs_p[i].size();
 416 }
 417
 418 // dispatch the messages among the threads and prepare the data
 419 // they will process
 421 ? this->getNumberOfThreads()
 422 : 1; // no nested multithreading
 423 nb_threads = std::min(msgPerm * taille / this->threadMinimalNbOps_, nb_threads);
 424 if (nb_threads < 1) nb_threads = 1;
 425
 426 const auto ranges = gum::dispatchRangeToThreads(0, msgPerm, (unsigned int)(nb_threads));
 427 const auto real_nb_threads = ranges.size();
 // per-thread accumulators, merged after the join below
 428 std::vector< GUM_SCALAR > msg_pmin(real_nb_threads, msg_p_min);
 429 std::vector< GUM_SCALAR > msg_pmax(real_nb_threads, msg_p_max);
 430
 431 // create the function to be executed by the threads
 432 auto threadedExec
 433 = [this, &msg_pmin, &msg_pmax, msgs_p, taille, ranges, id](const std::size_t this_thread,
 434 const std::size_t nb_threads) {
 435 std::vector< std::vector< GUM_SCALAR > > combi_msg_p(taille);
 436
 437 const auto& [first, second] = ranges[this_thread];
 438 for (Idx j = first; j < second; ++j) {
 439 // get jth msg : decode j as a mixed-radix index, one bit per
 // parent that has two candidate distributions
 440 auto jvalue = j;
 441
 442 for (Idx i = 0; i < taille; i++) {
 443 if (msgs_p[i].size() == 2) {
 444 combi_msg_p[i] = (jvalue & 1) ? msgs_p[i][1] : msgs_p[i][0];
 445 jvalue /= 2;
 446 } else {
 447 combi_msg_p[i] = msgs_p[i][0];
 448 }
 449 }
 450
 451 compute_ext_(combi_msg_p, id, msg_pmin[this_thread], msg_pmax[this_thread]);
 452 }
 453 };
 454
 455 // launch the threads
 456 ThreadExecutor::execute(real_nb_threads, threadedExec);
 457
 // merge the per-thread extrema
 458 for (Idx j = 0; j < real_nb_threads; ++j) {
 459 if (msg_p_min > msg_pmin[j]) { msg_p_min = msg_pmin[j]; }
 460 if (msg_p_max < msg_pmax[j]) { msg_p_max = msg_pmax[j]; }
 461 }
 462 }
463
// ---------------------------------------------------------------------------
// enum_combi_ (lambda version) — enumerates all combinations of the
// per-parent message distributions `msgs_p` and calls the combination
// compute_ext_ (L flavour) on each, in parallel, updating the lambda-message
// interval [real_msg_l_min, real_msg_l_max]. `lx` holds the candidate
// likelihood values of the receiving parent, `pos` its offset in the CPT.
// NOTE(review): the declaration line (original line 469) and the line that
// initializes `nb_threads` (original line 502) were lost by the doc
// extractor — TODO confirm against the header.
// ---------------------------------------------------------------------------
 468 template < typename GUM_SCALAR >
 470 std::vector< std::vector< std::vector< GUM_SCALAR > > >& msgs_p,
 471 const NodeId& id,
 472 GUM_SCALAR& real_msg_l_min,
 473 GUM_SCALAR& real_msg_l_max,
 474 std::vector< GUM_SCALAR >& lx,
 475 const Idx& pos) {
 // work on local copies; results are written back at the end
 476 GUM_SCALAR msg_l_min = real_msg_l_min;
 477 GUM_SCALAR msg_l_max = real_msg_l_max;
 478
 479 auto taille = msgs_p.size();
 480
 481 // one parent node, the one receiving the message
 482 if (taille == 0) {
 483 GUM_SCALAR num_min = _cn_->get_binaryCPT_min()[id][1];
 484 GUM_SCALAR num_max = _cn_->get_binaryCPT_max()[id][1];
 485 GUM_SCALAR den_min = _cn_->get_binaryCPT_min()[id][0];
 486 GUM_SCALAR den_max = _cn_->get_binaryCPT_max()[id][0];
 487
 488 compute_ext_(msg_l_min, msg_l_max, lx, num_min, num_max, den_min, den_max);
 489
 490 real_msg_l_min = msg_l_min;
 491 real_msg_l_max = msg_l_max;
 492 return;
 493 }
 494
 // total number of message combinations = product of per-parent counts
 495 Size msgPerm = 1;
 496 for (Size i = 0; i < taille; i++) {
 497 msgPerm *= msgs_p[i].size();
 498 }
 499
 500 // dispatch the messages among the threads and prepare the data
 501 // they will process
 503 ? this->getNumberOfThreads()
 504 : 1; // no nested multithreading
 505 nb_threads = std::min(msgPerm * taille / this->threadMinimalNbOps_, nb_threads);
 506 if (nb_threads < 1) nb_threads = 1;
 507
 508 const auto ranges = gum::dispatchRangeToThreads(0, msgPerm, (unsigned int)(nb_threads));
 509 const auto real_nb_threads = ranges.size();
 // per-thread accumulators, merged after the join below
 510 std::vector< GUM_SCALAR > msg_lmin(real_nb_threads, msg_l_min);
 511 std::vector< GUM_SCALAR > msg_lmax(real_nb_threads, msg_l_max);
 512
 513 // create the function to be executed by the threads
 514 auto threadedExec = [this, &msg_lmin, &msg_lmax, msgs_p, taille, ranges, id, &lx, pos](
 515 const std::size_t this_thread,
 516 const std::size_t nb_threads) {
 517 std::vector< std::vector< GUM_SCALAR > > combi_msg_p(taille);
 518
 519 const auto& [first, second] = ranges[this_thread];
 520 for (Idx j = first; j < second; ++j) {
 521 // get jth msg : decode j bit by bit, one bit per parent with two
 // candidate distributions
 522 auto jvalue = j;
 523
 524 for (Idx i = 0; i < taille; i++) {
 525 if (msgs_p[i].size() == 2) {
 526 combi_msg_p[i] = (jvalue & 1) ? msgs_p[i][1] : msgs_p[i][0];
 527 jvalue /= 2;
 528 } else {
 529 combi_msg_p[i] = msgs_p[i][0];
 530 }
 531 }
 532 compute_ext_(combi_msg_p, id, msg_lmin[this_thread], msg_lmax[this_thread], lx, pos);
 533 }
 534 };
 535
 536 // launch the threads
 537 ThreadExecutor::execute(real_nb_threads, threadedExec);
 538
 // merge the per-thread results; -2 is the "uninitialized" sentinel and
 // non-positive candidates are ignored
 539 for (Idx j = 0; j < real_nb_threads; ++j) {
 540 if ((msg_l_min > msg_lmin[j] || msg_l_min == -2) && msg_lmin[j] > 0) {
 541 msg_l_min = msg_lmin[j];
 542 }
 543 if ((msg_l_max < msg_lmax[j] || msg_l_max == -2) && msg_lmax[j] > 0) {
 544 msg_l_max = msg_lmax[j];
 545 }
 546 }
 547
 548 real_msg_l_min = msg_l_min;
 549 real_msg_l_max = msg_l_max;
 550 }
551
// ---------------------------------------------------------------------------
// makeInference — top-level entry point: initializes the structures, runs the
// message-passing loop selected by _inferenceType_, then refreshes the
// indicator nodes (which refreshes the marginals) before marking the result
// as up to date.
// NOTE(review): the doc extractor lost the signature (original line 553) and
// all the switch-case bodies (original lines 558, 561-565, 571) — presumably
// the case labels and the calls to the makeInferenceXxx_ variants defined
// below, plus a final expectations/marginals update. TODO recover from the
// real source.
// ---------------------------------------------------------------------------
 552 template < typename GUM_SCALAR >
 554 if (InferenceUpToDate_) { return; }
 555
 556 initialize_();
 557
 559
 560 switch (_inferenceType_) {
 562
 564
 566 }
 567
 568 //_updateMarginals();
 569 updateIndicatrices_(); // will call updateMarginals_()
 570
 572
 573 InferenceUpToDate_ = true;
 574 }
575
// ---------------------------------------------------------------------------
// Teardown/reset — clears every arc- and node-level message structure, frees
// the per-node msg_l_sent_ sets, and invalidates the cached inference result.
// NOTE(review): the declaration lines (original lines 577-578) were lost by
// the doc extractor, so the method's name is unknown from here — likely the
// destructor or an eraseEvidence/reset method; TODO confirm.
// NOTE(review): the delete loop assumes msg_l_sent_ holds an entry for every
// node of _bnet_ whenever it is non-empty (which initialize_ guarantees).
// ---------------------------------------------------------------------------
 576 template < typename GUM_SCALAR >
 579
 580 ArcsL_min_.clear();
 581 ArcsL_max_.clear();
 582 ArcsP_min_.clear();
 583 ArcsP_max_.clear();
 584 NodesL_min_.clear();
 585 NodesL_max_.clear();
 586 NodesP_min_.clear();
 587 NodesP_max_.clear();
 588
 589 InferenceUpToDate_ = false;
 590
 // free the NodeSet* allocated per node in initialize_
 591 if (!msg_l_sent_.empty()) {
 592 for (auto node: _bnet_->nodes()) {
 593 delete msg_l_sent_[node];
 594 }
 595 }
 596
 597 msg_l_sent_.clear();
 598 update_l_.clear();
 599 update_p_.clear();
 600
 601 active_nodes_set.clear();
 602 next_active_nodes_set.clear();
 603 }
604
// ---------------------------------------------------------------------------
// initialize_ — prepares all message structures before propagation: walks the
// nodes in topological order, seeds the pi/lambda bounds (immediate values
// for hard-evidence nodes, CPT-marginalized intervals otherwise), marks
// roots/leaves/evidence nodes as active, then copies the node bounds onto
// every outgoing arc.
// NOTE(review): the doc extractor lost the signature (original line 606) and
// original lines 659-662 (probably a comment block before the CPT access
// below) — TODO confirm against the real source.
// ---------------------------------------------------------------------------
 605 template < typename GUM_SCALAR >
 607 const DAG& graphe = _bnet_->dag();
 608
 609 // use const iterators with cbegin when available
 610 for (auto node: _bnet_->topologicalOrder()) {
 611 update_p_.set(node, false);
 612 update_l_.set(node, false);
 // owned NodeSet, freed in the reset/teardown method
 613 auto parents_ = new NodeSet();
 614 msg_l_sent_.set(node, parents_);
 615
 616 // fast path: evidence nodes can be initialized immediately
 617 if (_infE_::evidence_.exists(node)) {
 618 if (_infE_::evidence_[node][1] != 0. && _infE_::evidence_[node][1] != 1.) {
 619 GUM_ERROR(OperationNotAllowed, "CNLoopyPropagation can only handle HARD evidences")
 620 }
 621
 622 active_nodes_set.insert(node);
 623 update_l_.set(node, true);
 624 update_p_.set(node, true);
 625
 626 if (_infE_::evidence_[node][1] == (GUM_SCALAR)1.) {
 627 NodesL_min_.set(node, INF_);
 628 NodesP_min_.set(node, (GUM_SCALAR)1.);
 629 } else if (_infE_::evidence_[node][1] == (GUM_SCALAR)0.) {
 630 NodesL_min_.set(node, (GUM_SCALAR)0.);
 631 NodesP_min_.set(node, (GUM_SCALAR)0.);
 632 }
 633
 634 std::vector< GUM_SCALAR > marg(2);
 635 marg[1] = NodesP_min_[node];
 636 marg[0] = 1 - marg[1];
 637
 638 _infE_::oldMarginalMin_.set(node, marg);
 639 _infE_::oldMarginalMax_.set(node, marg);
 640
 641 continue;
 642 }
 643
 644 NodeSet par_ = graphe.parents(node);
 645 NodeSet enf_ = graphe.children(node);
 646
 // roots and leaves start active: they can emit messages right away
 647 if (par_.empty()) {
 648 active_nodes_set.insert(node);
 649 update_p_.set(node, true);
 650 update_l_.set(node, true);
 651 }
 652
 653 if (enf_.empty()) {
 654 active_nodes_set.insert(node);
 655 update_p_.set(node, true);
 656 update_l_.set(node, true);
 657 }
 658
 663 const auto parents = &_bnet_->cpt(node).variablesSequence();
 664
 665 std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
 666 std::vector< std::vector< GUM_SCALAR > > msg_p;
 667 std::vector< GUM_SCALAR > distri(2);
 668
 669 // +1 from start to avoid counting_ itself
 670 // use const iterators when available with cbegin
 671 for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
 672 // compute probability distribution to avoid doing it multiple times
 673 // (at
 674 // each combination of messages)
 675 distri[1] = NodesP_min_[_bnet_->nodeId(**jt)];
 676 distri[0] = (GUM_SCALAR)1. - distri[1];
 677 msg_p.push_back(distri);
 678
 // a second candidate distribution only exists when the parent's
 // bounds differ
 679 if (NodesP_max_.exists(_bnet_->nodeId(**jt))) {
 680 distri[1] = NodesP_max_[_bnet_->nodeId(**jt)];
 681 distri[0] = (GUM_SCALAR)1. - distri[1];
 682 msg_p.push_back(distri);
 683 }
 684
 685 msgs_p.push_back(msg_p);
 686 msg_p.clear();
 687 }
 688
 689 GUM_SCALAR msg_p_min = 1.;
 690 GUM_SCALAR msg_p_max = 0.;
 691
 // indicator nodes are skipped; their bounds stay at the init values
 692 if (_cn_->currentNodeType(node) != CredalNet< GUM_SCALAR >::NodeType::Indic) {
 693 enum_combi_(msgs_p, node, msg_p_min, msg_p_max);
 694 }
 695
 696 if (msg_p_min <= (GUM_SCALAR)0.) { msg_p_min = (GUM_SCALAR)0.; }
 697
 698 if (msg_p_max <= (GUM_SCALAR)0.) { msg_p_max = (GUM_SCALAR)0.; }
 699
 700 NodesP_min_.set(node, msg_p_min);
 701 std::vector< GUM_SCALAR > marg(2);
 702 marg[1] = msg_p_min;
 703 marg[0] = 1 - msg_p_min;
 704
 705 _infE_::oldMarginalMin_.set(node, marg);
 706
 // NodesP_max_ only stores an entry when the interval is non-degenerate
 707 if (msg_p_min != msg_p_max) {
 708 marg[1] = msg_p_max;
 709 marg[0] = 1 - msg_p_max;
 710 NodesP_max_.insert(node, msg_p_max);
 711 }
 712
 713 _infE_::oldMarginalMax_.set(node, marg);
 714
 715 NodesL_min_.set(node, (GUM_SCALAR)1.);
 716 }
 717
 // seed every arc with its tail node's current bounds
 718 for (auto arc: _bnet_->arcs()) {
 719 ArcsP_min_.set(arc, NodesP_min_[arc.tail()]);
 720
 721 if (NodesP_max_.exists(arc.tail())) { ArcsP_max_.set(arc, NodesP_max_[arc.tail()]); }
 722
 723 ArcsL_min_.set(arc, NodesL_min_[arc.tail()]);
 724 }
 725 }
726
// ---------------------------------------------------------------------------
// Node-to-neighbours propagation loop — every active node sends pi messages
// to its (non-indicator) children and lambda messages to its parents, until
// the approximation scheme declares convergence or no node remains active.
// NOTE(review): the doc extractor lost the signature (original line 728,
// presumably makeInferenceNodeToNeighbours_) and original lines 733, 752 and
// 755. Line 755, between the two clear() calls, almost certainly copied
// next_active_nodes_set into active_nodes_set (otherwise the while condition
// on active_nodes_set.size() could never hold) — TODO confirm.
// ---------------------------------------------------------------------------
 727 template < typename GUM_SCALAR >
 729 const DAG& graphe = _bnet_->dag();
 730
 731 GUM_SCALAR eps;
 732 // to validate TestSuite
 734
 735 do {
 736 for (auto node: active_nodes_set) {
 737 for (auto chil: graphe.children(node)) {
 738 if (_cn_->currentNodeType(chil) == CredalNet< GUM_SCALAR >::NodeType::Indic) { continue; }
 739
 740 msgP_(node, chil);
 741 }
 742
 743 for (auto par: graphe.parents(node)) {
 744 if (_cn_->currentNodeType(node) == CredalNet< GUM_SCALAR >::NodeType::Indic) { continue; }
 745
 746 msgL_(node, par);
 747 }
 748 }
 749
 750 eps = calculateEpsilon_();
 751
 753
 754 active_nodes_set.clear();
 756 next_active_nodes_set.clear();
 757
 758 } while (_infE_::continueApproximationScheme(eps) && active_nodes_set.size() > 0);
 759
 760 _infE_::stopApproximationScheme(); // just to be sure the
 761 // approximation scheme has been notified of
 762 // the end of the loop
 763 }
764
// ---------------------------------------------------------------------------
// Random-order arc propagation loop — collects all arcs, partially shuffles
// them each round (nbrArcs/2 random swaps), then sends a pi message along
// each non-indicator arc and a lambda message back, until convergence.
// NOTE(review): the doc extractor lost the signature (original line 766,
// presumably makeInferenceByRandomOrder_, given the shuffling) and original
// lines 778, 802 and 804 — line 804 must have carried the do-loop's closing
// `} while (continueApproximationScheme(eps))`; TODO confirm against the
// real source.
// ---------------------------------------------------------------------------
 765 template < typename GUM_SCALAR >
 767 Size nbrArcs = _bnet_->dag().sizeArcs();
 768
 769 std::vector< cArcP > seq;
 770 seq.reserve(nbrArcs);
 771
 772 for (const auto& arc: _bnet_->arcs()) {
 773 seq.push_back(&arc);
 774 }
 775
 776 GUM_SCALAR eps;
 777 // validate TestSuite
 779
 780 do {
 // partial Fisher-Yates-style shuffle: nbrArcs/2 random pair swaps
 781 for (Size j = 0, theEnd = nbrArcs / 2; j < theEnd; j++) {
 782 auto w1 = randomValue(nbrArcs);
 783 auto w2 = randomValue(nbrArcs);
 784
 785 if (w1 == w2) { continue; }
 786
 787 std::swap(seq[w1], seq[w2]);
 788 }
 789
 790 for (const auto it: seq) {
 791 if (_cn_->currentNodeType(it->tail()) == CredalNet< GUM_SCALAR >::NodeType::Indic
 792 || _cn_->currentNodeType(it->head()) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
 793 continue;
 794 }
 795
 796 msgP_(it->tail(), it->head());
 797 msgL_(it->head(), it->tail());
 798 }
 799
 800 eps = calculateEpsilon_();
 801
 803
 805 }
806
807 // gives slightly worse results for some variable/modalities than other
808 // inference
809 // types (node D on 2U network loose 0.03 precision)
// ---------------------------------------------------------------------------
// Ordered-arc propagation loop — same as the random-order variant but without
// shuffling: the arcs are visited in their collection order each round.
// NOTE(review): the doc extractor lost the signature (original line 811,
// presumably makeInferenceByOrderedArcs_) and original lines 823, 838 and
// 840 — line 840 must have carried the do-loop's closing
// `} while (continueApproximationScheme(eps))`; TODO confirm against the
// real source.
// ---------------------------------------------------------------------------
 810 template < typename GUM_SCALAR >
 812 Size nbrArcs = _bnet_->dag().sizeArcs();
 813
 814 std::vector< cArcP > seq;
 815 seq.reserve(nbrArcs);
 816
 817 for (const auto& arc: _bnet_->arcs()) {
 818 seq.push_back(&arc);
 819 }
 820
 821 GUM_SCALAR eps;
 822 // validate TestSuite
 824
 825 do {
 826 for (const auto it: seq) {
 827 if (_cn_->currentNodeType(it->tail()) == CredalNet< GUM_SCALAR >::NodeType::Indic
 828 || _cn_->currentNodeType(it->head()) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
 829 continue;
 830 }
 831
 832 msgP_(it->tail(), it->head());
 833 msgL_(it->head(), it->tail());
 834 }
 835
 836 eps = calculateEpsilon_();
 837
 839
 841 }
842
// ---------------------------------------------------------------------------
// msgL_ — sends a lambda (likelihood) message interval from node Y to its
// parent X: refreshes Y's own lambda bounds from its children's arcs, then
// marginalizes the CPT over all other parents' pi messages (enum_combi_) and
// stores the resulting interval on arc (X, Y), activating X if it changed.
// NOTE(review): the doc extractor lost the signature (original line 844,
// presumably void msgL_(const NodeId Y, const NodeId X)) and original lines
// 920-922 (probably comments) — TODO confirm against the real source.
// ---------------------------------------------------------------------------
 843 template < typename GUM_SCALAR >
 845 NodeSet const& children = _bnet_->children(Y);
 846 NodeSet const& parents_ = _bnet_->parents(Y);
 847
 848 const auto parents = &_bnet_->cpt(Y).variablesSequence();
 849
 // single-neighbour node without evidence: nothing informative to send
 850 if (((children.size() + parents->size() - 1) == 1) && (!_infE_::evidence_.exists(Y))) {
 851 return;
 852 }
 853
 854 bool update_l = update_l_[Y];
 855 bool update_p = update_p_[Y];
 856
 857 if (!update_p && !update_l) { return; }
 858
 859 msg_l_sent_[Y]->insert(X);
 860
 861 // for future refresh LM/PI
 862 if (msg_l_sent_[Y]->size() == parents_.size()) {
 863 msg_l_sent_[Y]->clear();
 864 update_l_[Y] = false;
 865 }
 866
 867 // refresh LM_part
 868 if (update_l) {
 869 if (!children.empty() && !_infE_::evidence_.exists(Y)) {
 870 GUM_SCALAR lmin = 1.;
 871 GUM_SCALAR lmax = 1.;
 872
 873 for (auto chil: children) {
 874 lmin *= ArcsL_min_[Arc(Y, chil)];
 875
 876 if (ArcsL_max_.exists(Arc(Y, chil))) {
 877 lmax *= ArcsL_max_[Arc(Y, chil)];
 878 } else {
 879 lmax *= ArcsL_min_[Arc(Y, chil)];
 880 }
 881 }
 882
 // NOTE(review): this unconditional overwrite looks suspicious — msgP_
 // has the symmetric NaN-repair `if (lmin != lmin && lmax == lmax)
 // { lmin = lmax; }` here instead. Possibly an extraction artifact or a
 // genuine bug; confirm against the upstream source before touching it.
 883 lmin = lmax;
 884
 // x != x is the NaN test
 885 if (lmax != lmax && lmin == lmin) { lmax = lmin; }
 886
 887 if (lmax != lmax && lmin != lmin) {
 888 std::cout << "no likelihood defined [lmin, lmax] (incompatibles "
 889 "evidence ?)"
 890 << std::endl;
 891 }
 892
 893 if (lmin < 0.) { lmin = 0.; }
 894
 895 if (lmax < 0.) { lmax = 0.; }
 896
 897 // no need to update nodeL if evidence since nodeL will never be used
 898
 899 NodesL_min_[Y] = lmin;
 900
 // NodesL_max_ only stores an entry when the interval is non-degenerate
 901 if (lmin != lmax) {
 902 NodesL_max_.set(Y, lmax);
 903 } else if (NodesL_max_.exists(Y)) {
 904 NodesL_max_.erase(Y);
 905 }
 906
 907 } // end of : node has children & no evidence
 908
 909 } // end of : if update_l
 910
 911 GUM_SCALAR lmin = NodesL_min_[Y];
 912 GUM_SCALAR lmax;
 913
 914 if (NodesL_max_.exists(Y)) {
 915 lmax = NodesL_max_[Y];
 916 } else {
 917 lmax = lmin;
 918 }
 919
 923
 // neutral likelihood: the message carries no information, store 1 and
 // drop any stale max entry
 924 if (lmin == lmax && lmin == 1.) {
 925 ArcsL_min_[Arc(X, Y)] = lmin;
 926
 927 if (ArcsL_max_.exists(Arc(X, Y))) { ArcsL_max_.erase(Arc(X, Y)); }
 928
 929 return;
 930 }
 931
 932 // keep, for each node, a table of the parents already updated; once all
 933 // are updated, stop
 934 // until notified by an L or P message
 935
 936 if (update_p || update_l) {
 937 std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
 938 std::vector< std::vector< GUM_SCALAR > > msg_p;
 939 std::vector< GUM_SCALAR > distri(2);
 940
 // NOTE(review): pos is only assigned inside the loop below when X is
 // found among Y's parents; if X were ever absent it would be read
 // uninitialized by enum_combi_ — confirm X is always a parent of Y.
 941 Idx pos;
 942
 943 // +1 from start to avoid counting_ itself
 944 // use const iterators with cbegin when available
 945 for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
 946 if (_bnet_->nodeId(**jt) == X) {
 947 // exclude the current variable (X) from the enumeration; keep its
 // CPT offset for the numerator/denominator split
 948 pos = parents->pos(*jt) - 1;
 949 continue;
 950 }
 951
 952 // compute probability distribution to avoid doing it multiple times
 953 // (at each combination of messages)
 954 distri[1] = ArcsP_min_[Arc(_bnet_->nodeId(**jt), Y)];
 955 distri[0] = GUM_SCALAR(1.) - distri[1];
 956 msg_p.push_back(distri);
 957
 958 if (ArcsP_max_.exists(Arc(_bnet_->nodeId(**jt), Y))) {
 959 distri[1] = ArcsP_max_[Arc(_bnet_->nodeId(**jt), Y)];
 960 distri[0] = GUM_SCALAR(1.) - distri[1];
 961 msg_p.push_back(distri);
 962 }
 963
 964 msgs_p.push_back(msg_p);
 965 msg_p.clear();
 966 }
 967
 // -2 is the "not yet computed" sentinel consumed by compute_ext_
 968 GUM_SCALAR min = -2.;
 969 GUM_SCALAR max = -2.;
 970
 971 std::vector< GUM_SCALAR > lx;
 972 lx.push_back(lmin);
 973
 974 if (lmin != lmax) { lx.push_back(lmax); }
 975
 976 enum_combi_(msgs_p, Y, min, max, lx, pos);
 977
 // repair half-computed intervals; bail out when nothing was computable
 978 if (min == -2. || max == -2.) {
 979 if (min != -2.) {
 980 max = min;
 981 } else if (max != -2.) {
 982 min = max;
 983 } else {
 984 std::cout << std::endl;
 985 std::cout << "!!!! pas de message L calculable !!!!" << std::endl;
 986 return;
 987 }
 988 }
 989
 990 if (min < 0.) { min = 0.; }
 991
 992 if (max < 0.) { max = 0.; }
 993
 994 bool update = false;
 995
 996 if (min != ArcsL_min_[Arc(X, Y)]) {
 997 ArcsL_min_[Arc(X, Y)] = min;
 998 update = true;
 999 }
 1000
 // ArcsL_max_ only keeps an entry for non-degenerate intervals
 1001 if (ArcsL_max_.exists(Arc(X, Y))) {
 1002 if (max != ArcsL_max_[Arc(X, Y)]) {
 1003 if (max != min) {
 1004 ArcsL_max_[Arc(X, Y)] = max;
 1005 } else { // if ( max == min )
 1006 ArcsL_max_.erase(Arc(X, Y));
 1007 }
 1008
 1009 update = true;
 1010 }
 1011 } else {
 1012 if (max != min) {
 1013 ArcsL_max_.insert(Arc(X, Y), max);
 1014 update = true;
 1015 }
 1016 }
 1017
 // a changed message re-activates the receiving parent for the next round
 1018 if (update) {
 1019 update_l_.set(X, true);
 1020 next_active_nodes_set.insert(X);
 1021 }
 1022
 1023 } // end of update_p || update_l
 1024 }
1025
// ---------------------------------------------------------------------------
// msgP_ — sends a pi (prediction) message interval from node X to one child
// `demanding_child`: combines the lambda messages of all OTHER children with
// X's pi bounds (recomputed from its parents' messages when needed) into
// P(X=1 | evidence above and beside) bounds, stores them on the arc
// (X, demanding_child) and re-activates the child if the message changed.
// Hard evidence on X short-circuits everything. This is the only block of
// this listing with no lines lost by the doc extractor.
// ---------------------------------------------------------------------------
 1026 template < typename GUM_SCALAR >
 1027 void CNLoopyPropagation< GUM_SCALAR >::msgP_(const NodeId X, const NodeId demanding_child) {
 1028 NodeSet const& children = _bnet_->children(X);
 1029
 1030 const auto parents = &_bnet_->cpt(X).variablesSequence();
 1031
 // single-neighbour node without evidence: nothing informative to send
 1032 if (((children.size() + parents->size() - 1) == 1) && (!_infE_::evidence_.exists(X))) {
 1033 return;
 1034 }
 1035
 1036 // LM_part ---- from all children but one --- the lonely one will get the
 1037 // message
 1038
 // hard evidence fixes the message exactly (degenerate interval)
 1039 if (_infE_::evidence_.exists(X)) {
 1040 ArcsP_min_[Arc(X, demanding_child)] = _infE_::evidence_[X][1];
 1041
 1042 if (ArcsP_max_.exists(Arc(X, demanding_child))) { ArcsP_max_.erase(Arc(X, demanding_child)); }
 1043
 1044 return;
 1045 }
 1046
 1047 bool update_l = update_l_[X];
 1048 bool update_p = update_p_[X];
 1049
 1050 if (!update_p && !update_l) { return; }
 1051
 1052 GUM_SCALAR lmin = 1.;
 1053 GUM_SCALAR lmax = 1.;
 1054
 1055 // use cbegin if available
 1056 for (auto chil: children) {
 1057 if (chil == demanding_child) { continue; }
 1058
 1059 lmin *= ArcsL_min_[Arc(X, chil)];
 1060
 1061 if (ArcsL_max_.exists(Arc(X, chil))) {
 1062 lmax *= ArcsL_max_[Arc(X, chil)];
 1063 } else {
 1064 lmax *= ArcsL_min_[Arc(X, chil)];
 1065 }
 1066 }
 1067
 // x != x is the NaN test: repair one bound from the other when possible
 1068 if (lmin != lmin && lmax == lmax) { lmin = lmax; }
 1069
 1070 if (lmax != lmax && lmin == lmin) { lmax = lmin; }
 1071
 1072 if (lmax != lmax && lmin != lmin) {
 1073 std::cout << "pas de vraisemblance definie [lmin, lmax] (observations "
 1074 "incompatibles ?)"
 1075 << std::endl;
 1076 return;
 1077 }
 1078
 1079 if (lmin < 0.) { lmin = 0.; }
 1080
 1081 if (lmax < 0.) { lmax = 0.; }
 1082
 1083 // refresh PI_part
 1084 GUM_SCALAR min = INF_;
 1085 GUM_SCALAR max = 0.;
 1086
 1087 if (update_p) {
 1088 std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
 1089 std::vector< std::vector< GUM_SCALAR > > msg_p;
 1090 std::vector< GUM_SCALAR > distri(2);
 1091
 1092 // +1 from start to avoid counting_ itself
 1093 // use const_iterators if available
 1094 for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
 1095 // compute probability distribution to avoid doing it multiple times
 1096 // (at
 1097 // each combination of messages)
 1098 distri[1] = ArcsP_min_[Arc(_bnet_->nodeId(**jt), X)];
 1099 distri[0] = GUM_SCALAR(1.) - distri[1];
 1100 msg_p.push_back(distri);
 1101
 // a second candidate distribution only exists when the bounds differ
 1102 if (ArcsP_max_.exists(Arc(_bnet_->nodeId(**jt), X))) {
 1103 distri[1] = ArcsP_max_[Arc(_bnet_->nodeId(**jt), X)];
 1104 distri[0] = GUM_SCALAR(1.) - distri[1];
 1105 msg_p.push_back(distri);
 1106 }
 1107
 1108 msgs_p.push_back(msg_p);
 1109 msg_p.clear();
 1110 }
 1111
 1112 enum_combi_(msgs_p, X, min, max);
 1113
 1114 if (min < 0.) { min = 0.; }
 1115
 1116 if (max < 0.) { max = 0.; }
 1117
 1118 if (min == INF_ || max == INF_) {
 1119 std::cout << " ERREUR msg P min = max = INF " << std::endl;
 1120 std::cout.flush();
 1121 return;
 1122 }
 1123
 1124 NodesP_min_[X] = min;
 1125
 // NodesP_max_ only keeps an entry for non-degenerate intervals
 1126 if (min != max) {
 1127 NodesP_max_.set(X, max);
 1128 } else if (NodesP_max_.exists(X)) {
 1129 NodesP_max_.erase(X);
 1130 }
 1131
 1132 update_p_.set(X, false);
 1133
 1134 } // end of update_p
 1135 else {
 // pi part unchanged: reuse the cached node bounds
 1136 min = NodesP_min_[X];
 1137
 1138 if (NodesP_max_.exists(X)) {
 1139 max = NodesP_max_[X];
 1140 } else {
 1141 max = min;
 1142 }
 1143 }
 1144
 1145 if (update_p || update_l) {
 1146 GUM_SCALAR msg_p_min;
 1147 GUM_SCALAR msg_p_max;
 1148
 1149 // limit cases for the min bound
 1150 if (min == INF_ && lmin == 0.) {
 1151 std::cout << "MESSAGE P ERR (negatif) : pi = inf, l = 0" << std::endl;
 1152 }
 1153
 1154 if (lmin == INF_) { // infinite case
 1155 msg_p_min = GUM_SCALAR(1.);
 1156 } else if (min == 0. || lmin == 0.) {
 1157 msg_p_min = 0;
 1158 } else {
 // combine pi and lambda: p = 1 / (1 + (1/pi - 1)/lambda)
 1159 msg_p_min = GUM_SCALAR(1. / (1. + ((1. / min - 1.) * 1. / lmin)));
 1160 }
 1161
 1162 // limit cases for the max bound
 1163 if (max == INF_ && lmax == 0.) {
 1164 std::cout << "MESSAGE P ERR (negatif) : pi = inf, l = 0" << std::endl;
 1165 }
 1166
 1167 if (lmax == INF_) { // infinite case
 1168 msg_p_max = GUM_SCALAR(1.);
 1169 } else if (max == 0. || lmax == 0.) {
 1170 msg_p_max = 0;
 1171 } else {
 1172 msg_p_max = GUM_SCALAR(1. / (1. + ((1. / max - 1.) * 1. / lmax)));
 1173 }
 1174
 // x != x is the NaN test: repair one bound from the other when possible
 1175 if (msg_p_min != msg_p_min && msg_p_max == msg_p_max) {
 1176 msg_p_min = msg_p_max;
 1177 std::cout << std::endl;
 1178 std::cout << "msg_p_min is NaN" << std::endl;
 1179 }
 1180
 1181 if (msg_p_max != msg_p_max && msg_p_min == msg_p_min) {
 1182 msg_p_max = msg_p_min;
 1183 std::cout << std::endl;
 1184 std::cout << "msg_p_max is NaN" << std::endl;
 1185 }
 1186
 1187 if (msg_p_max != msg_p_max && msg_p_min != msg_p_min) {
 1188 std::cout << std::endl;
 1189 std::cout << "pas de message P calculable (verifier observations)" << std::endl;
 1190 return;
 1191 }
 1192
 1193 if (msg_p_min < 0.) { msg_p_min = 0.; }
 1194
 1195 if (msg_p_max < 0.) { msg_p_max = 0.; }
 1196
 1197 bool update = false;
 1198
 1199 if (msg_p_min != ArcsP_min_[Arc(X, demanding_child)]) {
 1200 ArcsP_min_[Arc(X, demanding_child)] = msg_p_min;
 1201 update = true;
 1202 }
 1203
 // ArcsP_max_ only keeps an entry for non-degenerate intervals
 1204 if (ArcsP_max_.exists(Arc(X, demanding_child))) {
 1205 if (msg_p_max != ArcsP_max_[Arc(X, demanding_child)]) {
 1206 if (msg_p_max != msg_p_min) {
 1207 ArcsP_max_[Arc(X, demanding_child)] = msg_p_max;
 1208 } else { // if ( msg_p_max == msg_p_min )
 1209 ArcsP_max_.erase(Arc(X, demanding_child));
 1210 }
 1211
 1212 update = true;
 1213 }
 1214 } else {
 1215 if (msg_p_max != msg_p_min) {
 1216 ArcsP_max_.insert(Arc(X, demanding_child), msg_p_max);
 1217 update = true;
 1218 }
 1219 }
 1220
 // a changed message re-activates the child for the next round
 1221 if (update) {
 1222 update_p_.set(demanding_child, true);
 1223 next_active_nodes_set.insert(demanding_child);
 1224 }
 1225
 1226 } // end of : update_l || update_p
 1227 }
1228
1229 template < typename GUM_SCALAR >
1231 for (auto node: _bnet_->nodes()) {
1232 if ((!refreshIndic)
1233 && _cn_->currentNodeType(node) == CredalNet< GUM_SCALAR >::NodeType::Indic) {
1234 continue;
1235 }
1236
1237 NodeSet const& children = _bnet_->children(node);
1238
1239 auto parents = &_bnet_->cpt(node).variablesSequence();
1240
1241 if (update_l_[node]) {
1242 GUM_SCALAR lmin = 1.;
1243 GUM_SCALAR lmax = 1.;
1244
1245 if (!children.empty() && !_infE_::evidence_.exists(node)) {
1246 for (auto chil: children) {
1247 lmin *= ArcsL_min_[Arc(node, chil)];
1248
1249 if (ArcsL_max_.exists(Arc(node, chil))) {
1250 lmax *= ArcsL_max_[Arc(node, chil)];
1251 } else {
1252 lmax *= ArcsL_min_[Arc(node, chil)];
1253 }
1254 }
1255
1256 if (lmin != lmin && lmax == lmax) { lmin = lmax; }
1257
1258 lmax = lmin;
1259
1260 if (lmax != lmax && lmin != lmin) {
1261 std::cout << "pas de vraisemblance definie [lmin, lmax] (observations "
1262 "incompatibles ?)"
1263 << std::endl;
1264 return;
1265 }
1266
1267 if (lmin < 0.) { lmin = 0.; }
1268
1269 if (lmax < 0.) { lmax = 0.; }
1270
1271 NodesL_min_[node] = lmin;
1272
1273 if (lmin != lmax) {
1274 NodesL_max_.set(node, lmax);
1275 } else if (NodesL_max_.exists(node)) {
1276 NodesL_max_.erase(node);
1277 }
1278 }
1279
1280 } // end of : update_l
1281
1282 if (update_p_[node]) {
1283 if ((parents->size() - 1) > 0 && !_infE_::evidence_.exists(node)) {
1284 std::vector< std::vector< std::vector< GUM_SCALAR > > > msgs_p;
1285 std::vector< std::vector< GUM_SCALAR > > msg_p;
1286 std::vector< GUM_SCALAR > distri(2);
1287
1288 // +1 from start to avoid counting_ itself
1289 // cbegin
1290 for (auto jt = ++parents->begin(), theEnd = parents->end(); jt != theEnd; ++jt) {
1291 // compute probability distribution to avoid doing it multiple
1292 // times
1293 // (at each combination of messages)
1294 distri[1] = ArcsP_min_[Arc(_bnet_->nodeId(**jt), node)];
1295 distri[0] = GUM_SCALAR(1.) - distri[1];
1296 msg_p.push_back(distri);
1297
1298 if (ArcsP_max_.exists(Arc(_bnet_->nodeId(**jt), node))) {
1299 distri[1] = ArcsP_max_[Arc(_bnet_->nodeId(**jt), node)];
1300 distri[0] = GUM_SCALAR(1.) - distri[1];
1301 msg_p.push_back(distri);
1302 }
1303
1304 msgs_p.push_back(msg_p);
1305 msg_p.clear();
1306 }
1307
1308 GUM_SCALAR min = INF_;
1309 GUM_SCALAR max = 0.;
1310
1311 enum_combi_(msgs_p, node, min, max);
1312
1313 if (min < 0.) { min = 0.; }
1314
1315 if (max < 0.) { max = 0.; }
1316
1317 NodesP_min_[node] = min;
1318
1319 if (min != max) {
1320 NodesP_max_.set(node, max);
1321 } else if (NodesP_max_.exists(node)) {
1322 NodesP_max_.erase(node);
1323 }
1324
1325 update_p_[node] = false;
1326 }
1327 } // end of update_p
1328
1329 } // end of : for each node
1330 }
1331
1332 template < typename GUM_SCALAR >
1334 for (auto node: _bnet_->nodes()) {
1335 GUM_SCALAR msg_p_min = 1.;
1336 GUM_SCALAR msg_p_max = 0.;
1337
1338 if (_infE_::evidence_.exists(node)) {
1339 if (_infE_::evidence_[node][1] == 0.) {
1340 msg_p_min = (GUM_SCALAR)0.;
1341 } else if (_infE_::evidence_[node][1] == 1.) {
1342 msg_p_min = 1.;
1343 }
1344
1345 msg_p_max = msg_p_min;
1346 } else {
1347 GUM_SCALAR min = NodesP_min_[node];
1348 GUM_SCALAR max;
1349
1350 if (NodesP_max_.exists(node)) {
1351 max = NodesP_max_[node];
1352 } else {
1353 max = min;
1354 }
1355
1356 GUM_SCALAR lmin = NodesL_min_[node];
1357 GUM_SCALAR lmax;
1358 if (NodesL_max_.exists(node)) {
1359 lmax = NodesL_max_[node];
1360 } else {
1361 lmax = lmin;
1362 }
1363
1364 if (min == INF_ || max == INF_) {
1365 std::cout << " min ou max === INF_ !!!!!!!!!!!!!!!!!!!!!!!!!! " << std::endl;
1366 return;
1367 }
1368
1369 if (min == INF_ && lmin == 0.) {
1370 std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
1371 return;
1372 }
1373
1374 if (lmin == INF_) {
1375 msg_p_min = GUM_SCALAR(1.);
1376 } else if (min == 0. || lmin == 0.) {
1377 msg_p_min = GUM_SCALAR(0.);
1378 } else {
1379 msg_p_min = GUM_SCALAR(1. / (1. + ((1. / min - 1.) * 1. / lmin)));
1380 }
1381
1382 if (max == INF_ && lmax == 0.) {
1383 std::cout << "proba ERR (negatif) : pi = inf, l = 0" << std::endl;
1384 return;
1385 }
1386
1387 if (lmax == INF_) {
1388 msg_p_max = GUM_SCALAR(1.);
1389 } else if (max == 0. || lmax == 0.) {
1390 msg_p_max = GUM_SCALAR(0.);
1391 } else {
1392 msg_p_max = GUM_SCALAR(1. / (1. + ((1. / max - 1.) * 1. / lmax)));
1393 }
1394 }
1395
1396 if (msg_p_min != msg_p_min && msg_p_max == msg_p_max) {
1397 msg_p_min = msg_p_max;
1398 std::cout << std::endl;
1399 std::cout << "msg_p_min is NaN" << std::endl;
1400 }
1401
1402 if (msg_p_max != msg_p_max && msg_p_min == msg_p_min) {
1403 msg_p_max = msg_p_min;
1404 std::cout << std::endl;
1405 std::cout << "msg_p_max is NaN" << std::endl;
1406 }
1407
1408 if (msg_p_max != msg_p_max && msg_p_min != msg_p_min) {
1409 std::cout << std::endl;
1410 std::cout << "Please check the observations (no proba can be computed)" << std::endl;
1411 return;
1412 }
1413
1414 if (msg_p_min < 0.) { msg_p_min = 0.; }
1415
1416 if (msg_p_max < 0.) { msg_p_max = 0.; }
1417
1418 _infE_::marginalMin_[node][0] = 1 - msg_p_max;
1419 _infE_::marginalMax_[node][0] = 1 - msg_p_min;
1420 _infE_::marginalMin_[node][1] = msg_p_min;
1421 _infE_::marginalMax_[node][1] = msg_p_max;
1422 }
1423 }
1424
1425 template < typename GUM_SCALAR >
1432
1433 template < typename GUM_SCALAR >
1435 for (auto node: _bnet_->nodes()) {
1436 if (_cn_->currentNodeType(node) != CredalNet< GUM_SCALAR >::NodeType::Indic) { continue; }
1437
1438 for (auto pare: _bnet_->parents(node)) {
1439 msgP_(pare, node);
1440 }
1441 }
1442
1443 refreshLMsPIs_(true);
1445 }
1446
1447 template < typename GUM_SCALAR >
1449 if (_infE_::modal_.empty()) { return; }
1450
1451 std::vector< std::vector< GUM_SCALAR > > vertices(2, std::vector< GUM_SCALAR >(2));
1452
1453 for (auto node: _bnet_->nodes()) {
1454 vertices[0][0] = _infE_::marginalMin_[node][0];
1455 vertices[0][1] = _infE_::marginalMax_[node][1];
1456
1457 vertices[1][0] = _infE_::marginalMax_[node][0];
1458 vertices[1][1] = _infE_::marginalMin_[node][1];
1459
1460 for (auto vertex = 0, vend = 2; vertex != vend; vertex++) {
1462 // test credal sets vertices elim
1463 // remove with L2U since variables are binary
1464 // but does the user know that ?
1466 vertices[vertex]); // no redundancy elimination with 2 vertices
1467 }
1468 }
1469 }
1470
1471 template < typename GUM_SCALAR >
1473 InferenceEngine< GUM_SCALAR >::InferenceEngine(credalNet) {
1474 if (!credalNet.isSeparatelySpecified()) {
1475 GUM_ERROR(OperationNotAllowed,
1476 "CNLoopyPropagation is only available "
1477 "with separately specified nets");
1478 }
1479
1480 // test for binary cn
1481 for (auto node: credalNet.current_bn().nodes())
1482 if (credalNet.current_bn().variable(node).domainSize() != 2) {
1483 GUM_ERROR(OperationNotAllowed,
1484 "CNLoopyPropagation is only available "
1485 "with binary credal networks")
1486 }
1487
1488 // test if compute CPTMinMax has been called
1489 if (!credalNet.hasComputedBinaryCPTMinMax()) {
1490 GUM_ERROR(OperationNotAllowed,
1491 "CNLoopyPropagation only works when "
1492 "\"computeBinaryCPTMinMax()\" has been called for "
1493 "this credal net")
1494 }
1495
1496 _cn_ = &credalNet;
1497 _bnet_ = &credalNet.current_bn();
1498
1500 InferenceUpToDate_ = false;
1501
1502 GUM_CONSTRUCTOR(CNLoopyPropagation)
1503 }
1504
1505 template < typename GUM_SCALAR >
1507 InferenceUpToDate_ = false;
1508
1509 if (!msg_l_sent_.empty()) {
1510 for (auto node: _bnet_->nodes()) {
1511 delete msg_l_sent_[node];
1512 }
1513 }
1514
1515 GUM_DESTRUCTOR(CNLoopyPropagation)
1516 }
1517
1518 template < typename GUM_SCALAR >
1522
1523 template < typename GUM_SCALAR >
1528} // namespace gum::credal
Class implementing loopy-propagation with binary networks - L2U algorithm.
#define INF_
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t the new error and increment steps.
void initApproximationScheme()
Initialise the scheme.
void stopApproximationScheme()
Stop the approximation scheme.
bool continueApproximationScheme(double error)
Update the scheme w.r.t the new error.
const NodeSet & parents(NodeId id) const
returns the set of nodes with arc ingoing to a given node
NodeSet children(const NodeSet &ids) const
returns the set of children of a set of nodes
The base class for all directed edges.
Base class for dag.
Definition DAG.h:121
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
bool empty() const noexcept
Indicates whether the set is the empty set.
Definition set_tpl.h:642
NodeProperty< GUM_SCALAR > NodesL_min_
"Lower" node information obtained by combination of children's messages.
NodeProperty< GUM_SCALAR > NodesP_min_
"Lower" node information obtained by combination of parents' messages.
NodeProperty< GUM_SCALAR > NodesL_max_
"Upper" node information obtained by combination of children's messages.
void msgL_(const NodeId X, const NodeId demanding_parent)
Sends a message to one's parent, i.e.
NodeProperty< NodeSet * > msg_l_sent_
Used to keep track of the messages a node has sent to its parents.
InferenceType _inferenceType_
The chosen inference type.
void compute_ext_(GUM_SCALAR &msg_l_min, GUM_SCALAR &msg_l_max, std::vector< GUM_SCALAR > &lx, GUM_SCALAR &num_min, GUM_SCALAR &num_max, GUM_SCALAR &den_min, GUM_SCALAR &den_max)
Used by msgL_.
NodeProperty< bool > update_p_
Used to keep track of which node needs to update its information coming from its parents.
void refreshLMsPIs_(bool refreshIndic=false)
Get the last messages from one's parents and children.
NodeProperty< bool > update_l_
Used to keep track of which node needs to update its information coming from its children.
void makeInferenceNodeToNeighbours_()
Starts the inference with this inference type.
void initialize_()
Topological forward propagation to initialize old marginals & messages.
GUM_SCALAR calculateEpsilon_()
Compute epsilon.
void makeInferenceByRandomOrder_()
Starts the inference with this inference type.
const IBayesNet< GUM_SCALAR > * _bnet_
A pointer to it's IBayesNet used as a DAG.
ArcProperty< GUM_SCALAR > ArcsP_min_
"Lower" information coming from one's parent.
InferenceType
Inference type to be used by the algorithm.
@ nodeToNeighbours
Uses a node-set so we don't iterate on nodes that can't send a new message.
@ randomOrder
Chooses a random arc ordering and sends messages accordingly.
@ ordered
Chooses an arc ordering and sends messages accordingly at all steps.
void msgP_(const NodeId X, const NodeId demanding_child)
Sends a message to one's child, i.e.
ArcProperty< GUM_SCALAR > ArcsL_max_
"Upper" information coming from one's children.
bool InferenceUpToDate_
TRUE if inference has already been performed, FALSE otherwise.
void updateMarginals_()
Compute marginals from up-to-date messages.
const CredalNet< GUM_SCALAR > * _cn_
A pointer to the CredalNet to be used.
void computeExpectations_()
Since the network is binary, expectations can be computed from the final marginals which give us the ...
NodeProperty< GUM_SCALAR > NodesP_max_
"Upper" node information obtained by combination of parents' messages.
void enum_combi_(std::vector< std::vector< std::vector< GUM_SCALAR > > > &msgs_p, const NodeId &id, GUM_SCALAR &msg_l_min, GUM_SCALAR &msg_l_max, std::vector< GUM_SCALAR > &lx, const Idx &pos)
Used by msgL_.
void makeInference()
Starts the inference.
InferenceType inferenceType()
Get the inference type.
void saveInference(const std::string &path)
void makeInferenceByOrderedArcs_()
Starts the inference with this inference type.
void eraseAllEvidence()
Erase all inference related data to perform another one.
void updateIndicatrices_()
Only update indicatrices variables at the end of computations ( calls msgP_ ).
NodeSet active_nodes_set
The current node-set to iterate through at this current step.
NodeSet next_active_nodes_set
The next node-set, i.e.
CNLoopyPropagation(const CredalNet< GUM_SCALAR > &credalNet)
Constructor.
ArcProperty< GUM_SCALAR > ArcsL_min_
"Lower" information coming from one's children.
ArcProperty< GUM_SCALAR > ArcsP_max_
"Upper" information coming from one's parent.
Class template representing a Credal Network.
Definition credalNet.h:97
void updateExpectations_(const NodeId &id, const std::vector< GUM_SCALAR > &vertex)
Given a node id and one of its possible vertices obtained during inference, update this node's lower and...
margi oldMarginalMax_
Old upper marginals used to compute epsilon.
margi evidence_
Holds observed variables states.
margi marginalMax_
Upper marginals.
void updateCredalSets_(const NodeId &id, const std::vector< GUM_SCALAR > &vertex, const bool &elimRedund=false)
Given a node id and one of its possible vertices, update its credal set.
virtual const GUM_SCALAR computeEpsilon_()
Compute approximation scheme epsilon using the old marginals and the new ones.
const std::vector< std::vector< GUM_SCALAR > > & vertices(const NodeId id) const
Get the vertices of a given node id.
InferenceEngine(const CredalNet< GUM_SCALAR > &credalNet)
Constructor.
margi oldMarginalMin_
Old lower marginals used to compute epsilon.
virtual void eraseAllEvidence()
removes all the evidence entered into the network
const CredalNet< GUM_SCALAR > & credalNet() const
Get this credal network.
margi marginalMin_
Lower marginals.
dynExpe modal_
Variables modalities used to compute expectations.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
Idx randomValue(const Size max=2)
Returns a random Idx between 0 and max-1 included.
namespace for all credal networks entities
Definition agrum.h:61
std::vector< std::pair< Idx, Idx > > dispatchRangeToThreads(Idx beg, Idx end, unsigned int nb_threads)
returns a vector equally splitting elements of a range among threads
Definition threads.cpp:76
unsigned int getNumberOfThreads()
returns the max number of threads used by default when entering the next parallel region
static void execute(std::size_t nb_threads, FUNCTION exec_func, ARGS &&... func_args)
executes a function using several threads
static int nbRunningThreadsExecutors()
indicates how many threadExecutors are currently running