aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
BNLearner_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
51#include <fstream>
52
53#ifndef DOXYGEN_SHOULD_SKIP_THIS
54
55// to help IDE parser
58
59namespace gum {
60
61 namespace learning {
    /// Default constructor: prepares learning from a CSV database file.
    /// @param filename path of the CSV file containing the learning database
    /// @param missingSymbols the tokens interpreted as missing values
    /// @param induceTypes if true, the variables' types are induced from the database
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
                                       const std::vector< std::string >& missingSymbols,
                                       const bool induceTypes) :
        IBNLearner(filename, missingSymbols, induceTypes) {
      GUM_CONSTRUCTOR(BNLearner);
    }
69
    /// Constructor from an already built database table.
    /// @param db the database used for learning
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const DatabaseTable& db) : IBNLearner(db) {
      GUM_CONSTRUCTOR(BNLearner);
    }
74
75 template < typename GUM_SCALAR >
76 BNLearner< GUM_SCALAR >::BNLearner(const std::string& filename,
78 const std::vector< std::string >& missing_symbols) :
79 IBNLearner(filename, bn, missing_symbols) {
80 GUM_CONSTRUCTOR(BNLearner);
81 }
82
    /// Copy constructor: duplicates the state of `src`.
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::BNLearner(const BNLearner< GUM_SCALAR >& src) : IBNLearner(src) {
      GUM_CONSTRUCTOR(BNLearner);
    }
88
90 template < typename GUM_SCALAR >
91 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : IBNLearner(src) {
92 GUM_CONSTRUCTOR(BNLearner);
93 }
94
    /// Destructor.
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >::~BNLearner() {
      GUM_DESTRUCTOR(BNLearner);
    }
100
102
103 // ##########################################################################
105 // ##########################################################################
107
    /// Copy assignment operator: copies the state of `src` into this learner.
    /// @return a reference to *this
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >&
        BNLearner< GUM_SCALAR >::operator=(const BNLearner< GUM_SCALAR >& src) {
      IBNLearner::operator=(src);
      return *this;
    }
115
    /// Move assignment operator: moves the state of `src` into this learner.
    /// @return a reference to *this
    template < typename GUM_SCALAR >
    BNLearner< GUM_SCALAR >&
        BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src) noexcept {
      IBNLearner::operator=(std::move(src));
      return *this;
    }
123
125 template < typename GUM_SCALAR >
126 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
127 // create the score, the prior and the estimator
128 auto notification = checkScorePriorCompatibility();
129 if (notification != "") { std::cout << "[aGrUM notification] " << notification << std::endl; }
130 createPrior_();
131 createScore_();
132
133 std::unique_ptr< ParamEstimator > param_estimator(
134 createParamEstimator_(scoreDatabase_.parser(), true));
135
136 return dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
137 }
138
139 // check that the database contains the nodes of the dag, else raise an exception
140 template < typename GUM_SCALAR >
141 void BNLearner< GUM_SCALAR >::_checkDAGCompatibility_(const DAG& dag) {
142 // if the dag contains no node, this is compatible with the database
143 if (dag.size() == 0) return;
144
145 // check that the dag corresponds to the database
146 std::vector< NodeId > ids;
147 ids.reserve(dag.sizeNodes());
148 for (const auto node: dag)
149 ids.push_back(node);
150 std::sort(ids.begin(), ids.end());
151
152 if (ids.back() >= scoreDatabase_.names().size()) {
153 std::stringstream str;
154 str << "Learning parameters corresponding to the dag is impossible "
155 << "because the database does not contain the following nodeID";
156 std::vector< NodeId > bad_ids;
157 for (const auto node: ids) {
158 if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
159 }
160 if (bad_ids.size() > 1) str << 's';
161 str << ": ";
162 bool deja = false;
163 for (const auto node: bad_ids) {
164 if (deja) str << ", ";
165 else deja = true;
166 str << node;
167 }
169 }
170 }
171
172 // learns a BN (its parameters) using a basic learning when its structure is known
173 template < typename GUM_SCALAR >
174 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::_learnParameters_(const DAG& dag,
175 bool takeIntoAccountScore) {
176 // if the dag contains no node, return an empty BN
177 if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
178
179 // be sure that the database contains dag's node ids
180 _checkDAGCompatibility_(dag);
181
182 // create the prior
183 createPrior_();
184
185 // check that the database does not contain any missing value
186 if (scoreDatabase_.databaseTable().hasMissingValues()
187 || ((priorDatabase_ != nullptr)
188 && (priorType_ == BNLearnerPriorType::DIRICHLET_FROM_DATABASE)
189 && priorDatabase_->databaseTable().hasMissingValues())) {
191 "In general, the BNLearner is unable to cope with "
192 << "missing values in databases. To learn parameters in "
193 << "such situations, you should first use method " << "useEM()");
194 }
195
196 // create the usual estimator
197 DBRowGeneratorParser parser(scoreDatabase_.databaseTable().handler(), DBRowGeneratorSet());
198 std::unique_ptr< ParamEstimator > param_estimator(
199 createParamEstimator_(parser, takeIntoAccountScore));
200
201 return dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
202 }
203
204 // initialize the parameter learning by EM
205 template < typename GUM_SCALAR >
206 std::pair< std::shared_ptr< ParamEstimator >, std::shared_ptr< ParamEstimator > >
207 BNLearner< GUM_SCALAR >::_initializeEMParameterLearning_(const DAG& dag,
208 bool takeIntoAccountScore) {
209 // be sure that the database contains dag's node ids
210 _checkDAGCompatibility_(dag);
211
212 // create the prior
213 createPrior_();
214
215 // propagate the messages of dag2BN_ to the BNLearner so that the objects that listen
216 // to the BNLearner can be informed of the progress of the EM's execution by dag2BN_
217 // BNLearnerListener listener(this, dag2BN_);
218
219 // get the column types
220 const auto& database = scoreDatabase_.databaseTable();
221 const std::size_t nb_vars = database.nbVariables();
222 const std::vector< gum::learning::DBTranslatedValueType > col_types(
223 nb_vars,
225
226 // create the bootstrap estimator
227 DBRowGenerator4CompleteRows generator_bootstrap(col_types);
228 DBRowGeneratorSet genset_bootstrap;
229 genset_bootstrap.insertGenerator(generator_bootstrap);
230 DBRowGeneratorParser parser_bootstrap(database.handler(), genset_bootstrap);
231 std::shared_ptr< ParamEstimator > param_estimator_bootstrap(
232 createParamEstimator_(parser_bootstrap, takeIntoAccountScore));
233
234 // create the EM estimator
235 BayesNet< GUM_SCALAR > dummy_bn;
236 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
237 DBRowGenerator& gen_EM = generator_EM; // fix for g++-4.8
238 DBRowGeneratorSet genset_EM;
239 genset_EM.insertGenerator(gen_EM);
240 DBRowGeneratorParser parser_EM(database.handler(), genset_EM);
241 std::shared_ptr< ParamEstimator > param_estimator_EM(
242 createParamEstimator_(parser_EM, takeIntoAccountScore));
243
244 return {param_estimator_bootstrap, param_estimator_EM};
245 }
246
247 // learns a BN (its parameters) with EM when its structure is known
248 template < typename GUM_SCALAR >
249 BayesNet< GUM_SCALAR >
250 BNLearner< GUM_SCALAR >::_learnParametersWithEM_(const DAG& dag,
251 bool takeIntoAccountScore) {
252 // if the dag contains no node, return an empty BN
253 if (dag.size() == 0) return BayesNet< GUM_SCALAR >();
254
255 // get a pair containing the bootstrap and the EM estimators
256 auto estimators = _initializeEMParameterLearning_(dag, takeIntoAccountScore);
257
258 // perform the EM algorithm
259 return dag2BN_.createBNwithEM< GUM_SCALAR >(*(estimators.first.get()),
260 *(estimators.second.get()),
261 dag);
262 }
263
265 template < typename GUM_SCALAR >
266 BayesNet< GUM_SCALAR >
267 BNLearner< GUM_SCALAR >::_learnParametersWithEM_(const BayesNet< GUM_SCALAR >& bn,
268 bool takeIntoAccountScore) {
269 // if the dag contains no node, return an empty BN
270 if (bn.dag().size() == 0) return BayesNet< GUM_SCALAR >();
271
272 // get a pair containing the bootstrap and the EM estimators
273 auto estimators = _initializeEMParameterLearning_(bn.dag(), takeIntoAccountScore);
274
275 return dag2BN_.createBNwithEM< GUM_SCALAR >(*(estimators.first.get()),
276 *(estimators.second.get()),
277 bn);
278 }
279
281 template < typename GUM_SCALAR >
282 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(const DAG& dag,
283 bool takeIntoAccountScore) {
284 if (!scoreDatabase_.databaseTable().hasMissingValues() || !useEM_) {
285 // here, we learn without EM
286 return _learnParameters_(dag, takeIntoAccountScore);
287 } else {
288 // here we learn with EM
289 return _learnParametersWithEM_(dag, takeIntoAccountScore);
290 }
291 }
292
294 template < typename GUM_SCALAR >
295 BayesNet< GUM_SCALAR >
296 BNLearner< GUM_SCALAR >::learnParameters(const BayesNet< GUM_SCALAR >& bn,
297 bool takeIntoAccountScore) {
298 if (!scoreDatabase_.databaseTable().hasMissingValues() || !useEM_) {
299 DAG dag;
300 const auto& db = scoreDatabase_.databaseTable();
301 for (const auto n: bn.nodes()) {
302 dag.addNodeWithId(db.columnFromVariableName(bn.variable(n).name()));
303 }
304 for (const auto& arc: bn.arcs()) {
305 dag.addArc(db.columnFromVariableName(bn.variable(arc.tail()).name()),
306 db.columnFromVariableName(bn.variable(arc.head()).name()));
307 }
308
309 // create le DAG en fonction des
310 return _learnParameters_(dag, takeIntoAccountScore);
311 } else {
312 return _learnParametersWithEM_(bn, takeIntoAccountScore);
313 }
314 }
315
    /// Learns the parameters of a BN by delegating to the overload taking
    /// the learner's initial dag as the structure.
    /// @param take_into_account_score whether the score's implicit prior is used
    template < typename GUM_SCALAR >
    BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(bool take_into_account_score) {
      return learnParameters(initialDag_, take_into_account_score);
    }
321
322 template < typename GUM_SCALAR >
323 NodeProperty< Sequence< std::string > >
324 BNLearner< GUM_SCALAR >::_labelsFromBN_(const std::string& filename,
325 const BayesNet< GUM_SCALAR >& src) {
326 std::ifstream in(filename, std::ifstream::in);
327
328 if ((in.rdstate() & std::ifstream::failbit) != 0) {
329 GUM_ERROR(gum::IOError, "File " << filename << " not found")
330 }
331
332 CSVParser parser(in, filename);
333 parser.next();
334 auto names = parser.current();
335
336 NodeProperty< Sequence< std::string > > modals;
337
338 for (gum::Idx col = 0; col < names.size(); col++) {
339 try {
340 gum::NodeId graphId = src.idFromName(names[col]);
341 modals.insert(col, gum::Sequence< std::string >());
342
343 for (gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
344 modals[col].insert(src.variable(graphId).label(i));
345 } catch (const gum::NotFound&) {
346 // no problem : a column not in the BN...
347 }
348 }
349
350 return modals;
351 }
352
353 template < typename GUM_SCALAR >
354 std::string BNLearner< GUM_SCALAR >::toString() const {
355 const auto st = state();
356
357 Size maxkey = 0;
358 for (const auto& tuple: st)
359 if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length();
360
361 std::stringstream s;
362 for (const auto& tuple: st) {
363 s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) << " : "
364 << std::get< 1 >(tuple);
365 if (std::get< 2 >(tuple) != "") s << " (" << std::get< 2 >(tuple) << ")";
366 s << std::endl;
367 }
368 return s.str();
369 }
370
    /// Returns the current configuration of the learner as a vector of
    /// (key, value, comment) string tuples; consumed by toString().
    /// @return the list of configuration entries, in display order
    template < typename GUM_SCALAR >
    std::vector< std::tuple< std::string, std::string, std::string > >
        BNLearner< GUM_SCALAR >::state() const {
      std::vector< std::tuple< std::string, std::string, std::string > > vals;

      std::string key;
      std::string comment;
      const auto& db = database();

      // database description: file, dimensions, variables, options
      vals.emplace_back("Filename", filename_, "");
      vals.emplace_back("Size",
                        "(" + std::to_string(nbRows()) + "," + std::to_string(nbCols()) + ")",
                        "");

      std::string vars = "";
      for (NodeId i = 0; i < db.nbVariables(); i++) {
        if (i > 0) vars += ", ";
        vars += nameFromId(i) + "[" + std::to_string(db.domainSize(i)) + "]";
      }
      vals.emplace_back("Variables", vars, "");
      vals.emplace_back("Induced types", inducedTypes_ ? "True" : "False", "");
      vals.emplace_back("Missing values", hasMissingValues() ? "True" : "False", "");

      // selected structure learning algorithm (plus its specific settings)
      key = "Algorithm";
      switch (selectedAlgo_) {
        case AlgoType::GREEDY_HILL_CLIMBING :
          vals.emplace_back(key, "Greedy Hill Climbing", "");
          break;
        case AlgoType::K2 : {
          vals.emplace_back(key, "K2", "");
          // K2 also exposes its node ordering
          const auto& k2order = algoK2_.order();
          vars = "";
          for (NodeId i = 0; i < k2order.size(); i++) {
            if (i > 0) vars += ", ";
            vars += nameFromId(k2order.atPos(i));
          }
          vals.emplace_back("K2 order", vars, "");
        } break;
        case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST :
          vals.emplace_back(key, "Local Search with Tabu List", "");
          vals.emplace_back("Tabu list size", std::to_string(nbDecreasingChanges_), "");
          break;
        case AlgoType::MIIC : vals.emplace_back(key, "MIIC", ""); break;
        default : vals.emplace_back(key, "(unknown)", "?"); break;
      }

      key = "Score";

      // the score is only meaningful for score-based algorithms
      if (isScoreBased()) {
        switch (scoreType_) {
          case ScoreType::K2 : vals.emplace_back(key, "K2", ""); break;
          case ScoreType::AIC : vals.emplace_back(key, "AIC", ""); break;
          case ScoreType::BIC : vals.emplace_back(key, "BIC", ""); break;
          case ScoreType::BD : vals.emplace_back(key, "BD", ""); break;
          case ScoreType::BDeu : vals.emplace_back(key, "BDeu", ""); break;
          case ScoreType::LOG2LIKELIHOOD : vals.emplace_back(key, "Log2Likelihood", ""); break;
          default : vals.emplace_back(key, "(unknown)", "?"); break;
        }
      }

      // the MIIC correction only applies to constraint-based algorithms
      if (isConstraintBased()) {
        key = "Correction";
        switch (kmodeMiic_) {
          case CorrectedMutualInformation::KModeTypes::MDL :
            vals.emplace_back(key, "MDL", "");
            break;
          case CorrectedMutualInformation::KModeTypes::NML :
            vals.emplace_back(key, "NML", "");
            break;
          case CorrectedMutualInformation::KModeTypes::NoCorr :
            vals.emplace_back(key, "No correction", "");
            break;
          default : vals.emplace_back(key, "(unknown)", "?"); break;
        }
      }

      // selected prior; the comment warns about score/prior incompatibilities
      key = "Prior";
      comment = checkScorePriorCompatibility();
      switch (priorType_) {
        case BNLearnerPriorType::NO_prior : vals.emplace_back(key, "-", comment); break;
        case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
          vals.emplace_back(key, "Dirichlet", comment);
          vals.emplace_back("Dirichlet from database", priorDbname_, "");
          break;
        case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
          vals.emplace_back(key, "Dirichlet", comment);
          vals.emplace_back("Dirichlet from Bayesian network : ", _prior_bn_.toString(), "");
          break;
        case BNLearnerPriorType::BDEU : vals.emplace_back(key, "BDEU", comment); break;
        case BNLearnerPriorType::SMOOTHING : vals.emplace_back(key, "Smoothing", comment); break;
        default : vals.emplace_back(key, "(unknown)", "?"); break;
      }

      if (priorType_ != BNLearnerPriorType::NO_prior)
        vals.emplace_back("Prior weight", std::to_string(priorWeight_), "");

      // only displayed when the weight differs from the default (one per row)
      if (databaseWeight() != double(nbRows())) {
        vals.emplace_back("Database weight", std::to_string(databaseWeight()), "");
      }

      // EM settings: list every enabled stopping criterion of dag2BN_
      if (useEM_) {
        comment = "";
        if (!hasMissingValues()) comment = "But no missing values in this database";
        vals.emplace_back("use EM", "True", "");
        std::stringstream s;
        s << "[";
        bool first = true;
        if (dag2BN_.isEnabledMinEpsilonRate()) {
          s << "MinRate: " << dag2BN_.minEpsilonRate();
          first = false;
        }
        if (dag2BN_.isEnabledEpsilon()) {
          if (!first) s << ", ";
          first = false;
          s << "MinDiff: " << dag2BN_.epsilon();
        }
        if (dag2BN_.isEnabledMaxIter()) {
          if (!first) s << ", ";
          first = false;
          s << "MaxIter: " << dag2BN_.maxIter();
        }
        if (dag2BN_.isEnabledMaxTime()) {
          if (!first) s << ", ";
          first = false;
          s << "MaxTime: " << dag2BN_.maxTime();
        }
        s << "]";
        vals.emplace_back("EM stopping criteria", s.str(), comment);
      }

      // structural constraints: each non-empty constraint is rendered as a
      // comma-separated set between braces
      std::string res;
      bool nofirst;
      if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) {
        vals.emplace_back("Constraint Max InDegree",
                          std::to_string(constraintIndegree_.maxIndegree()),
                          "");
      }
      if (!constraintForbiddenArcs_.arcs().empty()) {
        res = "{";
        nofirst = false;
        for (const auto& arc: constraintForbiddenArcs_.arcs()) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Forbidden Arcs", res, "");
      }
      if (!constraintMandatoryArcs_.arcs().empty()) {
        res = "{";
        nofirst = false;
        for (const auto& arc: constraintMandatoryArcs_.arcs()) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(arc.tail()) + "->" + nameFromId(arc.head());
        }
        res += "}";
        vals.emplace_back("Constraint Mandatory Arcs", res, "");
      }
      if (!constraintPossibleEdges_.edges().empty()) {
        res = "{";
        nofirst = false;
        for (const auto& edge: constraintPossibleEdges_.edges()) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(edge.first()) + "--" + nameFromId(edge.second());
        }
        res += "}";
        vals.emplace_back("Constraint Possible Edges", res, "");
      }
      if (!constraintSliceOrder_.sliceOrder().empty()) {
        res = "{";
        nofirst = false;
        const auto& order = constraintSliceOrder_.sliceOrder();
        for (const auto& p: order) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(p.first) + ":" + std::to_string(p.second);
        }
        res += "}";
        vals.emplace_back("Constraint Slice Order", res, "");
      }
      if (!constraintNoParentNodes_.nodes().empty()) {
        res = "{";
        nofirst = false;
        for (const auto& node: constraintNoParentNodes_.nodes()) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(node);
        }
        res += "}";
        vals.emplace_back("Constraint No Parent Nodes", res, "");
      }
      if (!constraintNoChildrenNodes_.nodes().empty()) {
        res = "{";
        nofirst = false;
        for (const auto& node: constraintNoChildrenNodes_.nodes()) {
          if (nofirst) res += ", ";
          else nofirst = true;
          res += nameFromId(node);
        }
        res += "}";
        vals.emplace_back("Constraint No Children Nodes", res, "");
      }
      if (initialDag_.size() != 0) {
        vals.emplace_back("Initial DAG", "True", initialDag_.toDot());
      }

      return vals;
    }
581
    /// Copies the algorithm, score, correction, prior, EM and constraint
    /// settings of `learner` into this learner. Node-based constraints are
    /// translated through variable NAMES (node ids may differ between the two
    /// learners' databases); names absent from this learner's database are
    /// silently skipped.
    /// @param learner the learner whose configuration is replicated
    template < typename GUM_SCALAR >
    void BNLearner< GUM_SCALAR >::copyState(const BNLearner< GUM_SCALAR >& learner) {
      // replicate the structure learning algorithm and its settings
      switch (learner.selectedAlgo_) {
        case AlgoType::GREEDY_HILL_CLIMBING : useGreedyHillClimbing(); break;
        case AlgoType::K2 : useK2(learner.algoK2_.order()); break;
        case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST :
          useLocalSearchWithTabuList(learner.nbDecreasingChanges_);
          break;
        case AlgoType::MIIC : useMIIC(); break;
      }

      // replicate the score
      switch (learner.scoreType_) {
        case ScoreType::K2 : useScoreK2(); break;
        case ScoreType::AIC : useScoreAIC(); break;
        case ScoreType::BIC : useScoreBIC(); break;
        case ScoreType::BD : useScoreBD(); break;
        case ScoreType::BDeu : useScoreBDeu(); break;
        case ScoreType::LOG2LIKELIHOOD : useScoreLog2Likelihood(); break;
      }

      // replicate the MIIC mutual information correction
      switch (learner.kmodeMiic_) {
        case CorrectedMutualInformation::KModeTypes::MDL : useMDLCorrection(); break;
        case CorrectedMutualInformation::KModeTypes::NML : useNMLCorrection(); break;
        case CorrectedMutualInformation::KModeTypes::NoCorr : useNoCorrection(); break;
      }

      // replicate the prior and its weight
      switch (learner.priorType_) {
        case BNLearnerPriorType::NO_prior : useNoPrior(); break;
        case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
          useDirichletPrior(learner.priorDbname_, learner.priorWeight_);
          break;
        case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
          useDirichletPrior(learner._prior_bn_);
          break;
        case BNLearnerPriorType::BDEU : useBDeuPrior(learner.priorWeight_); break;
        case BNLearnerPriorType::SMOOTHING : useSmoothingPrior(learner.priorWeight_); break;
      }

      // replicate the EM settings
      useEM_ = learner.useEM_;
      noiseEM_ = learner.noiseEM_;
      dag2BN_ = learner.dag2BN_;

      // replicate the structural constraints, mapping node ids via their names
      setMaxIndegree(learner.constraintIndegree_.maxIndegree());
      for (const auto src: learner.constraintNoParentNodes_.nodes()) {
        try {
          const auto dst = idFromName(learner.nameFromId(src));
          addNoParentNode(dst);
        } catch (const MissingVariableInDatabase&) {
          // variable not in this learner's database: skip the constraint
        }
      }
      for (const auto src: learner.constraintNoChildrenNodes_.nodes()) {
        try {
          const auto dst = idFromName(learner.nameFromId(src));
          addNoChildrenNode(dst);
        } catch (const MissingVariableInDatabase&) {
          // variable not in this learner's database: skip the constraint
        }
      }
      for (const auto& arc: learner.constraintForbiddenArcs_.arcs()) {
        try {
          const auto src = idFromName(learner.nameFromId(arc.tail()));
          const auto dst = idFromName(learner.nameFromId(arc.head()));
          addForbiddenArc(src, dst);
        } catch (const MissingVariableInDatabase&) {
          // variable not in this learner's database: skip the constraint
        }
      }
      for (const auto& arc: learner.constraintMandatoryArcs_.arcs()) {
        try {
          const auto src = idFromName(learner.nameFromId(arc.tail()));
          const auto dst = idFromName(learner.nameFromId(arc.head()));
          addMandatoryArc(src, dst);
        } catch (const MissingVariableInDatabase&) {
          // variable not in this learner's database: skip the constraint
        }
      }
      for (const auto& edge: learner.constraintPossibleEdges_.edges()) {
        try {
          const auto src = idFromName(learner.nameFromId(edge.first()));
          const auto dst = idFromName(learner.nameFromId(edge.second()));
          addPossibleEdge(src, dst);
        } catch (const MissingVariableInDatabase&) {
          // variable not in this learner's database: skip the constraint
        }
      }
      if (!learner.constraintSliceOrder_.sliceOrder().empty()) {
        NodeProperty< NodeId > slice_order;
        for (const auto& p: learner.constraintSliceOrder_.sliceOrder()) {
          try {
            slice_order.insert(idFromName(learner.nameFromId(p.first)), p.second);
          } catch (const MissingVariableInDatabase&) {
            // variable not in this learner's database: skip the constraint
          }
        }
        setSliceOrder(slice_order);
      }
    }
680
    /// (Re)creates the prior_ object according to priorType_. The previous
    /// prior is deleted only AFTER the new one has been successfully created,
    /// so prior_ is left untouched if the creation throws.
    template < typename GUM_SCALAR >
    void BNLearner< GUM_SCALAR >::createPrior_() {
      // first, save the old prior; it is deleted at the end if everything is ok
      Prior* old_prior = prior_;

      // create the new prior
      switch (priorType_) {
        case BNLearnerPriorType::NO_prior :
          prior_ = new NoPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
          break;

        case BNLearnerPriorType::SMOOTHING :
          prior_
              = new SmoothingPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
          break;

        case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
          // the Dirichlet prior owns its own database: release any previous one
          if (priorDatabase_ != nullptr) {
            delete priorDatabase_;
            priorDatabase_ = nullptr;
          }

          priorDatabase_
              = new Database(priorDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());

          prior_ = new DirichletPriorFromDatabase(scoreDatabase_.databaseTable(),
                                                  priorDatabase_->parser(),
                                                  priorDatabase_->nodeId2Columns());
          break;

        case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
          prior_
              = new DirichletPriorFromBN< GUM_SCALAR >(scoreDatabase_.databaseTable(), &_prior_bn_);
          break;

        case BNLearnerPriorType::BDEU :
          prior_ = new BDeuPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
          break;

        default : GUM_ERROR(OperationNotAllowed, "The BNLearner does not support yet this prior")
      }

      // do not forget to assign a weight to the prior
      prior_->setWeight(priorWeight_);

      // remove the old prior, if any
      if (old_prior != nullptr) delete old_prior;
    }
729
730 template < typename GUM_SCALAR >
731 INLINE std::ostream& operator<<(std::ostream& output, const BNLearner< GUM_SCALAR >& learner) {
732 output << learner.toString();
733 return output;
734 }
735 } /* namespace learning */
736
737} /* namespace gum */
738
739#endif /* DOXYGEN_SHOULD_SKIP_THIS */
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
A basic pack of learning algorithms that can easily be used.
Class representing a Bayesian network.
Definition BayesNet.h:93
Error: The database contains some missing values.
Error: A variable name was not found in the database.
Exception : operation not allowed.
BNLearner(const std::string &filename, const std::vector< std::string > &missingSymbols={"?"}, const bool induceTypes=true)
default constructor
A pack of learning algorithms that can easily be used.
Definition IBNLearner.h:98
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
std::ostream & operator<<(std::ostream &out, const TiXmlNode &base)
Definition tinyxml.cpp:1516