53#ifndef DOXYGEN_SHOULD_SKIP_THIS
62 template <
typename GUM_SCALAR >
64 const std::vector< std::string >& missingSymbols,
65 const bool induceTypes) :
66 IBNLearner(filename, missingSymbols, induceTypes) {
67 GUM_CONSTRUCTOR(BNLearner);
70 template <
typename GUM_SCALAR >
71 BNLearner< GUM_SCALAR >::BNLearner(
const DatabaseTable& db) : IBNLearner(db) {
72 GUM_CONSTRUCTOR(BNLearner);
75 template <
typename GUM_SCALAR >
76 BNLearner< GUM_SCALAR >::BNLearner(
const std::string& filename,
78 const std::vector< std::string >& missing_symbols) :
79 IBNLearner(filename, bn, missing_symbols) {
80 GUM_CONSTRUCTOR(BNLearner);
84 template <
typename GUM_SCALAR >
85 BNLearner< GUM_SCALAR >::BNLearner(
const BNLearner< GUM_SCALAR >& src) : IBNLearner(src) {
86 GUM_CONSTRUCTOR(BNLearner);
90 template <
typename GUM_SCALAR >
91 BNLearner< GUM_SCALAR >::BNLearner(BNLearner< GUM_SCALAR >&& src) : IBNLearner(src) {
92 GUM_CONSTRUCTOR(BNLearner);
96 template <
typename GUM_SCALAR >
97 BNLearner< GUM_SCALAR >::~BNLearner() {
98 GUM_DESTRUCTOR(BNLearner);
109 template <
typename GUM_SCALAR >
110 BNLearner< GUM_SCALAR >&
111 BNLearner< GUM_SCALAR >::operator=(
const BNLearner< GUM_SCALAR >& src) {
112 IBNLearner::operator=(src);
117 template <
typename GUM_SCALAR >
118 BNLearner< GUM_SCALAR >&
119 BNLearner< GUM_SCALAR >::operator=(BNLearner< GUM_SCALAR >&& src)
noexcept {
120 IBNLearner::operator=(std::move(src));
125 template <
typename GUM_SCALAR >
126 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnBN() {
128 auto notification = checkScorePriorCompatibility();
129 if (notification !=
"") { std::cout <<
"[aGrUM notification] " << notification << std::endl; }
133 std::unique_ptr< ParamEstimator > param_estimator(
134 createParamEstimator_(scoreDatabase_.parser(),
true));
136 return dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), learnDag_());
// Checks that every node id of the dag corresponds to a column of the
// score database; builds an explanatory error message otherwise.
// NOTE(review): this block is damaged by extraction — the leading integers
// fused into the lines are the original file's line numbers, and several
// statements (the ids.push_back in the first loop, the GUM_ERROR raise,
// `deja` handling, closing braces) are missing from this view. Do not
// assume the visible text is the complete function.
140 template <
typename GUM_SCALAR >
141 void BNLearner< GUM_SCALAR >::_checkDAGCompatibility_(
const DAG& dag) {
// an empty dag is trivially compatible
143 if (dag.size() == 0)
return;
// collect and sort the dag's node ids to find the largest one
146 std::vector< NodeId > ids;
147 ids.reserve(dag.sizeNodes());
148 for (
const auto node: dag)
150 std::sort(ids.begin(), ids.end());
// if some id exceeds the number of database columns, report all bad ids
152 if (ids.back() >= scoreDatabase_.names().size()) {
153 std::stringstream str;
154 str <<
"Learning parameters corresponding to the dag is impossible "
155 <<
"because the database does not contain the following nodeID";
156 std::vector< NodeId > bad_ids;
157 for (
const auto node: ids) {
158 if (node >= scoreDatabase_.names().size()) bad_ids.push_back(node);
// pluralize the message when several ids are out of range
160 if (bad_ids.size() > 1) str <<
's';
163 for (
const auto node: bad_ids) {
// `deja` presumably tracks whether a separator is needed — TODO confirm,
// its declaration is not visible in this view
164 if (deja) str <<
", ";
// Learns the parameters of a BN whose structure is given by `dag`, without
// EM: refuses databases containing missing values.
// NOTE(review): extraction damage — the opening of the GUM_ERROR(...) call
// (its macro name and exception type, around original line 190) and
// closing braces are missing from this view.
173 template <
typename GUM_SCALAR >
174 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::_learnParameters_(
const DAG& dag,
175 bool takeIntoAccountScore) {
// an empty dag yields an empty BN
177 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
// make sure the dag's node ids all exist in the database
180 _checkDAGCompatibility_(dag);
// missing values (in the score database or a Dirichlet prior database)
// cannot be handled here: the user must enable EM instead
186 if (scoreDatabase_.databaseTable().hasMissingValues()
187 || ((priorDatabase_ !=
nullptr)
188 && (priorType_ == BNLearnerPriorType::DIRICHLET_FROM_DATABASE)
189 && priorDatabase_->databaseTable().hasMissingValues())) {
191 "In general, the BNLearner is unable to cope with "
192 <<
"missing values in databases. To learn parameters in "
193 <<
"such situations, you should first use method " <<
"useEM()");
// parse the database without any row generator and estimate parameters
197 DBRowGeneratorParser parser(scoreDatabase_.databaseTable().handler(), DBRowGeneratorSet());
198 std::unique_ptr< ParamEstimator > param_estimator(
199 createParamEstimator_(parser, takeIntoAccountScore));
201 return dag2BN_.createBN< GUM_SCALAR >(*(param_estimator.get()), dag);
// Builds the two parameter estimators needed by EM: a "bootstrap"
// estimator working only on complete rows, and an "EM" estimator whose
// row generator fills missing values from a BN.
// NOTE(review): extraction damage — the initializer of `col_types`
// (original lines 223-226) and several closing braces are missing from
// this view.
205 template <
typename GUM_SCALAR >
206 std::pair< std::shared_ptr< ParamEstimator >, std::shared_ptr< ParamEstimator > >
207 BNLearner< GUM_SCALAR >::_initializeEMParameterLearning_(
const DAG& dag,
208 bool takeIntoAccountScore) {
// make sure the dag's node ids all exist in the database
210 _checkDAGCompatibility_(dag);
220 const auto& database = scoreDatabase_.databaseTable();
221 const std::size_t nb_vars = database.nbVariables();
// column types of the database (initializer lost in extraction)
222 const std::vector< gum::learning::DBTranslatedValueType > col_types(
// bootstrap estimator: only the complete rows of the database are used
227 DBRowGenerator4CompleteRows generator_bootstrap(col_types);
228 DBRowGeneratorSet genset_bootstrap;
229 genset_bootstrap.insertGenerator(generator_bootstrap);
230 DBRowGeneratorParser parser_bootstrap(database.handler(), genset_bootstrap);
231 std::shared_ptr< ParamEstimator > param_estimator_bootstrap(
232 createParamEstimator_(parser_bootstrap, takeIntoAccountScore));
// EM estimator: missing cells are completed by a DBRowGeneratorEM; the
// dummy BN is presumably replaced later by the EM loop — TODO confirm
235 BayesNet< GUM_SCALAR > dummy_bn;
236 DBRowGeneratorEM< GUM_SCALAR > generator_EM(col_types, dummy_bn);
237 DBRowGenerator& gen_EM = generator_EM;
238 DBRowGeneratorSet genset_EM;
239 genset_EM.insertGenerator(gen_EM);
240 DBRowGeneratorParser parser_EM(database.handler(), genset_EM);
241 std::shared_ptr< ParamEstimator > param_estimator_EM(
242 createParamEstimator_(parser_EM, takeIntoAccountScore));
244 return {param_estimator_bootstrap, param_estimator_EM};
// Learns the parameters of the BN whose structure is `dag` using EM.
// NOTE(review): extraction damage — the trailing arguments of the
// createBNwithEM call (original lines 261+) and the closing braces are
// missing from this view.
248 template <
typename GUM_SCALAR >
249 BayesNet< GUM_SCALAR >
250 BNLearner< GUM_SCALAR >::_learnParametersWithEM_(
const DAG& dag,
251 bool takeIntoAccountScore) {
// an empty dag yields an empty BN
253 if (dag.size() == 0)
return BayesNet< GUM_SCALAR >();
// get the bootstrap and EM estimators, then run EM
256 auto estimators = _initializeEMParameterLearning_(dag, takeIntoAccountScore);
259 return dag2BN_.createBNwithEM< GUM_SCALAR >(*(estimators.first.get()),
260 *(estimators.second.get()),
// Learns the parameters of `bn` (structure taken from bn.dag()) using EM.
// NOTE(review): extraction damage — the trailing arguments of the
// createBNwithEM call (original lines 277+) and the closing braces are
// missing from this view.
265 template <
typename GUM_SCALAR >
266 BayesNet< GUM_SCALAR >
267 BNLearner< GUM_SCALAR >::_learnParametersWithEM_(
const BayesNet< GUM_SCALAR >& bn,
268 bool takeIntoAccountScore) {
// an empty BN yields an empty BN
270 if (bn.dag().size() == 0)
return BayesNet< GUM_SCALAR >();
// get the bootstrap and EM estimators, then run EM
273 auto estimators = _initializeEMParameterLearning_(bn.dag(), takeIntoAccountScore);
275 return dag2BN_.createBNwithEM< GUM_SCALAR >(*(estimators.first.get()),
276 *(estimators.second.get()),
281 template <
typename GUM_SCALAR >
282 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
const DAG& dag,
283 bool takeIntoAccountScore) {
284 if (!scoreDatabase_.databaseTable().hasMissingValues() || !useEM_) {
286 return _learnParameters_(dag, takeIntoAccountScore);
289 return _learnParametersWithEM_(dag, takeIntoAccountScore);
294 template <
typename GUM_SCALAR >
295 BayesNet< GUM_SCALAR >
296 BNLearner< GUM_SCALAR >::learnParameters(
const BayesNet< GUM_SCALAR >& bn,
297 bool takeIntoAccountScore) {
298 if (!scoreDatabase_.databaseTable().hasMissingValues() || !useEM_) {
300 const auto& db = scoreDatabase_.databaseTable();
301 for (
const auto n: bn.nodes()) {
302 dag.addNodeWithId(db.columnFromVariableName(bn.variable(n).name()));
304 for (
const auto& arc: bn.arcs()) {
305 dag.addArc(db.columnFromVariableName(bn.variable(arc.tail()).name()),
306 db.columnFromVariableName(bn.variable(arc.head()).name()));
310 return _learnParameters_(dag, takeIntoAccountScore);
312 return _learnParametersWithEM_(bn, takeIntoAccountScore);
317 template <
typename GUM_SCALAR >
318 BayesNet< GUM_SCALAR > BNLearner< GUM_SCALAR >::learnParameters(
bool take_into_account_score) {
319 return learnParameters(initialDag_, take_into_account_score);
// Reads the header of the CSV file `filename` and, for each column whose
// name matches a variable of `src`, collects that variable's labels.
// NOTE(review): extraction damage — the `try` block opening, the lookup
// that sets `graphId` (presumably src.idFromName on the column name —
// TODO confirm), the catch body and closing braces are missing from this
// view.
322 template <
typename GUM_SCALAR >
323 NodeProperty< Sequence< std::string > >
324 BNLearner< GUM_SCALAR >::_labelsFromBN_(
const std::string& filename,
325 const BayesNet< GUM_SCALAR >& src) {
// open the CSV file; raise IOError if it cannot be read
326 std::ifstream in(filename, std::ifstream::in);
328 if ((in.rdstate() & std::ifstream::failbit) != 0) {
329 GUM_ERROR(gum::IOError,
"File " << filename <<
" not found")
// the first CSV row gives the column (variable) names
332 CSVParser parser(in, filename);
334 auto names = parser.current();
336 NodeProperty< Sequence< std::string > > modals;
338 for (
gum::Idx col = 0; col < names.size(); col++) {
// copy all labels of the matching BN variable into modals[col]
343 for (
gum::Size i = 0; i < src.variable(graphId).domainSize(); ++i)
344 modals[col].insert(src.variable(graphId).label(i));
345 }
// columns with no matching variable in src are skipped
catch (
const gum::NotFound&) {
// Pretty-prints the learner's state() as aligned "key : value (comment)"
// lines.
// NOTE(review): extraction damage — the declarations of `maxkey` and of
// the output stream `s`, the end-of-line emission, the final return and
// closing braces are missing from this view.
353 template <
typename GUM_SCALAR >
354 std::string BNLearner< GUM_SCALAR >::toString()
const {
355 const auto st = state();
// first pass: find the widest key so values can be column-aligned
358 for (
const auto& tuple: st)
359 if (std::get< 0 >(tuple).length() > maxkey) maxkey = std::get< 0 >(tuple).length();
// second pass: emit "key : value" plus an optional "(comment)" suffix
362 for (
const auto& tuple: st) {
363 s << std::setiosflags(std::ios::left) << std::setw(maxkey) << std::get< 0 >(tuple) <<
" : "
364 << std::get< 1 >(tuple);
365 if (std::get< 2 >(tuple) !=
"") s <<
" (" << std::get< 2 >(tuple) <<
")";
// Returns the learner's configuration as (key, value, comment) tuples:
// database summary, selected algorithm, score, MIIC correction, prior,
// weights, EM settings and all structural constraints.
// NOTE(review): this block is heavily damaged by extraction — the leading
// integers fused into the lines are the original file's line numbers, and
// many statements are missing from this view: the declarations of `key`,
// `comment`, the stringstream `s`, `res`, `first`/`nofirst`, numerous
// `break`s and closing braces, and several `vals.emplace_back` argument
// lists. Do not assume the visible text is the complete function.
371 template <
typename GUM_SCALAR >
372 std::vector< std::tuple< std::string, std::string, std::string > >
373 BNLearner< GUM_SCALAR >::state()
const {
374 std::vector< std::tuple< std::string, std::string, std::string > > vals;
// ---- database summary -------------------------------------------------
378 const auto& db = database();
380 vals.emplace_back(
"Filename", filename_,
"");
381 vals.emplace_back(
"Size",
382 "(" + std::to_string(nbRows()) +
"," + std::to_string(nbCols()) +
")",
// list every variable with its domain size
385 std::string vars =
"";
386 for (NodeId i = 0; i < db.nbVariables(); i++) {
387 if (i > 0) vars +=
", ";
388 vars += nameFromId(i) +
"[" + std::to_string(db.domainSize(i)) +
"]";
390 vals.emplace_back(
"Variables", vars,
"");
391 vals.emplace_back(
"Induced types", inducedTypes_ ?
"True" :
"False",
"");
392 vals.emplace_back(
"Missing values", hasMissingValues() ?
"True" :
"False",
"");
// ---- structure-learning algorithm -------------------------------------
395 switch (selectedAlgo_) {
396 case AlgoType::GREEDY_HILL_CLIMBING :
397 vals.emplace_back(key,
"Greedy Hill Climbing",
"");
399 case AlgoType::K2 : {
400 vals.emplace_back(key,
"K2",
"");
// also report the node ordering imposed on K2
401 const auto& k2order = algoK2_.order();
403 for (NodeId i = 0; i < k2order.size(); i++) {
404 if (i > 0) vars +=
", ";
405 vars += nameFromId(k2order.atPos(i));
407 vals.emplace_back(
"K2 order", vars,
"");
409 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST :
410 vals.emplace_back(key,
"Local Search with Tabu List",
"");
411 vals.emplace_back(
"Tabu list size", std::to_string(nbDecreasingChanges_),
"");
413 case AlgoType::MIIC : vals.emplace_back(key,
"MIIC",
"");
break;
414 default : vals.emplace_back(key,
"(unknown)",
"?");
break;
// ---- score (only meaningful for score-based algorithms) ---------------
419 if (isScoreBased()) {
420 switch (scoreType_) {
421 case ScoreType::K2 : vals.emplace_back(key,
"K2",
"");
break;
422 case ScoreType::AIC : vals.emplace_back(key,
"AIC",
"");
break;
423 case ScoreType::BIC : vals.emplace_back(key,
"BIC",
"");
break;
424 case ScoreType::BD : vals.emplace_back(key,
"BD",
"");
break;
425 case ScoreType::BDeu : vals.emplace_back(key,
"BDeu",
"");
break;
426 case ScoreType::LOG2LIKELIHOOD : vals.emplace_back(key,
"Log2Likelihood",
"");
break;
427 default : vals.emplace_back(key,
"(unknown)",
"?");
break;
// ---- MIIC mutual-information correction (constraint-based only) -------
431 if (isConstraintBased()) {
433 switch (kmodeMiic_) {
434 case CorrectedMutualInformation::KModeTypes::MDL :
435 vals.emplace_back(key,
"MDL",
"");
437 case CorrectedMutualInformation::KModeTypes::NML :
438 vals.emplace_back(key,
"NML",
"");
440 case CorrectedMutualInformation::KModeTypes::NoCorr :
441 vals.emplace_back(key,
"No correction",
"");
443 default : vals.emplace_back(key,
"(unknown)",
"?");
break;
// ---- prior -------------------------------------------------------------
// the comment column warns about score/prior incompatibilities
448 comment = checkScorePriorCompatibility();
449 switch (priorType_) {
450 case BNLearnerPriorType::NO_prior : vals.emplace_back(key,
"-", comment);
break;
451 case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
452 vals.emplace_back(key,
"Dirichlet", comment);
453 vals.emplace_back(
"Dirichlet from database", priorDbname_,
"");
455 case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
456 vals.emplace_back(key,
"Dirichlet", comment);
457 vals.emplace_back(
"Dirichlet from Bayesian network : ", _prior_bn_.toString(),
"");
459 case BNLearnerPriorType::BDEU : vals.emplace_back(key,
"BDEU", comment);
break;
460 case BNLearnerPriorType::SMOOTHING : vals.emplace_back(key,
"Smoothing", comment);
break;
461 default : vals.emplace_back(key,
"(unknown)",
"?");
break;
464 if (priorType_ != BNLearnerPriorType::NO_prior)
465 vals.emplace_back(
"Prior weight", std::to_string(priorWeight_),
"");
// ---- database weight (only when it differs from the row count) --------
467 if (databaseWeight() !=
double(nbRows())) {
468 vals.emplace_back(
"Database weight", std::to_string(databaseWeight()),
"");
// ---- EM settings -------------------------------------------------------
473 if (!hasMissingValues()) comment =
"But no missing values in this database";
474 vals.emplace_back(
"use EM",
"True",
"");
// describe every enabled EM stopping criterion, comma-separated
478 if (dag2BN_.isEnabledMinEpsilonRate()) {
479 s <<
"MinRate: " << dag2BN_.minEpsilonRate();
482 if (dag2BN_.isEnabledEpsilon()) {
483 if (!first) s <<
", ";
485 s <<
"MinDiff: " << dag2BN_.epsilon();
487 if (dag2BN_.isEnabledMaxIter()) {
488 if (!first) s <<
", ";
490 s <<
"MaxIter: " << dag2BN_.maxIter();
492 if (dag2BN_.isEnabledMaxTime()) {
493 if (!first) s <<
", ";
495 s <<
"MaxTime: " << dag2BN_.maxTime();
498 vals.emplace_back(
"EM stopping criteria", s.str(), comment);
// ---- structural constraints -------------------------------------------
503 if (constraintIndegree_.maxIndegree() < std::numeric_limits< Size >::max()) {
504 vals.emplace_back(
"Constraint Max InDegree",
505 std::to_string(constraintIndegree_.maxIndegree()),
508 if (!constraintForbiddenArcs_.arcs().empty()) {
511 for (
const auto& arc: constraintForbiddenArcs_.arcs()) {
512 if (nofirst) res +=
", ";
514 res += nameFromId(arc.tail()) +
"->" + nameFromId(arc.head());
517 vals.emplace_back(
"Constraint Forbidden Arcs", res,
"");
519 if (!constraintMandatoryArcs_.arcs().empty()) {
522 for (
const auto& arc: constraintMandatoryArcs_.arcs()) {
523 if (nofirst) res +=
", ";
525 res += nameFromId(arc.tail()) +
"->" + nameFromId(arc.head());
528 vals.emplace_back(
"Constraint Mandatory Arcs", res,
"");
530 if (!constraintPossibleEdges_.edges().empty()) {
533 for (
const auto& edge: constraintPossibleEdges_.edges()) {
534 if (nofirst) res +=
", ";
536 res += nameFromId(edge.first()) +
"--" + nameFromId(edge.second());
539 vals.emplace_back(
"Constraint Possible Edges", res,
"");
541 if (!constraintSliceOrder_.sliceOrder().empty()) {
544 const auto& order = constraintSliceOrder_.sliceOrder();
545 for (
const auto& p: order) {
546 if (nofirst) res +=
", ";
548 res += nameFromId(p.first) +
":" + std::to_string(p.second);
551 vals.emplace_back(
"Constraint Slice Order", res,
"");
553 if (!constraintNoParentNodes_.nodes().empty()) {
556 for (
const auto& node: constraintNoParentNodes_.nodes()) {
557 if (nofirst) res +=
", ";
559 res += nameFromId(node);
562 vals.emplace_back(
"Constraint No Parent Nodes", res,
"");
564 if (!constraintNoChildrenNodes_.nodes().empty()) {
567 for (
const auto& node: constraintNoChildrenNodes_.nodes()) {
568 if (nofirst) res +=
", ";
570 res += nameFromId(node);
573 vals.emplace_back(
"Constraint No Children Nodes", res,
"");
// ---- initial DAG -------------------------------------------------------
575 if (initialDag_.size() != 0) {
576 vals.emplace_back(
"Initial DAG",
"True", initialDag_.toDot());
// Copies another learner's configuration (algorithm, score, correction,
// prior, EM settings and constraints) into this learner, translating node
// ids through variable names so both learners may use different databases.
// NOTE(review): extraction damage — several `break`s and the closing
// braces of the switches/loops are missing from this view.
582 template <
typename GUM_SCALAR >
583 void BNLearner< GUM_SCALAR >::copyState(
const BNLearner< GUM_SCALAR >& learner) {
// replicate the structure-learning algorithm selection
584 switch (learner.selectedAlgo_) {
585 case AlgoType::GREEDY_HILL_CLIMBING : useGreedyHillClimbing();
break;
586 case AlgoType::K2 : useK2(learner.algoK2_.order());
break;
587 case AlgoType::LOCAL_SEARCH_WITH_TABU_LIST :
588 useLocalSearchWithTabuList(learner.nbDecreasingChanges_);
590 case AlgoType::MIIC : useMIIC();
break;
// replicate the score selection
593 switch (learner.scoreType_) {
594 case ScoreType::K2 : useScoreK2();
break;
595 case ScoreType::AIC : useScoreAIC();
break;
596 case ScoreType::BIC : useScoreBIC();
break;
597 case ScoreType::BD : useScoreBD();
break;
598 case ScoreType::BDeu : useScoreBDeu();
break;
599 case ScoreType::LOG2LIKELIHOOD : useScoreLog2Likelihood();
break;
// replicate the MIIC mutual-information correction
602 switch (learner.kmodeMiic_) {
603 case CorrectedMutualInformation::KModeTypes::MDL : useMDLCorrection();
break;
604 case CorrectedMutualInformation::KModeTypes::NML : useNMLCorrection();
break;
605 case CorrectedMutualInformation::KModeTypes::NoCorr : useNoCorrection();
break;
// replicate the prior selection
608 switch (learner.priorType_) {
609 case BNLearnerPriorType::NO_prior : useNoPrior();
break;
610 case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
611 useDirichletPrior(learner.priorDbname_, learner.priorWeight_);
613 case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
614 useDirichletPrior(learner._prior_bn_);
616 case BNLearnerPriorType::BDEU : useBDeuPrior(learner.priorWeight_);
break;
617 case BNLearnerPriorType::SMOOTHING : useSmoothingPrior(learner.priorWeight_);
break;
// replicate EM settings
620 useEM_ = learner.useEM_;
621 noiseEM_ = learner.noiseEM_;
622 dag2BN_ = learner.dag2BN_;
// replicate all structural constraints, translating ids via names
624 setMaxIndegree(learner.constraintIndegree_.maxIndegree());
625 for (
const auto src: learner.constraintNoParentNodes_.nodes()) {
627 const auto dst = idFromName(learner.nameFromId(src));
628 addNoParentNode(dst);
633 for (
const auto src: learner.constraintNoChildrenNodes_.nodes()) {
635 const auto dst = idFromName(learner.nameFromId(src));
636 addNoChildrenNode(dst);
641 for (
const auto& arc: learner.constraintForbiddenArcs_.arcs()) {
643 const auto src = idFromName(learner.nameFromId(arc.tail()));
644 const auto dst = idFromName(learner.nameFromId(arc.head()));
645 addForbiddenArc(src, dst);
650 for (
const auto& arc: learner.constraintMandatoryArcs_.arcs()) {
652 const auto src = idFromName(learner.nameFromId(arc.tail()));
653 const auto dst = idFromName(learner.nameFromId(arc.head()));
654 addMandatoryArc(src, dst);
659 for (
const auto& edge: learner.constraintPossibleEdges_.edges()) {
661 const auto src = idFromName(learner.nameFromId(edge.first()));
662 const auto dst = idFromName(learner.nameFromId(edge.second()));
663 addPossibleEdge(src, dst);
668 if (!learner.constraintSliceOrder_.sliceOrder().empty()) {
669 NodeProperty< NodeId > slice_order;
670 for (
const auto& p: learner.constraintSliceOrder_.sliceOrder()) {
672 slice_order.insert(idFromName(learner.nameFromId(p.first)), p.second);
677 setSliceOrder(slice_order);
// (Re)creates the prior object matching priorType_, sets its weight, and
// deletes the previously owned prior.
// NOTE(review): extraction damage — the left-hand sides of several
// assignments (the bare `=` lines presumably assign `prior_` or
// `priorDatabase_` — TODO confirm), the `break`s and the closing braces
// are missing from this view.
681 template <
typename GUM_SCALAR >
682 void BNLearner< GUM_SCALAR >::createPrior_() {
// keep the old prior alive until the new one is built (exception safety)
684 Prior* old_prior = prior_;
687 switch (priorType_) {
688 case BNLearnerPriorType::NO_prior :
689 prior_ =
new NoPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
692 case BNLearnerPriorType::SMOOTHING :
694 =
new SmoothingPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
697 case BNLearnerPriorType::DIRICHLET_FROM_DATABASE :
// drop any previously loaded prior database before reloading it
698 if (priorDatabase_ !=
nullptr) {
699 delete priorDatabase_;
700 priorDatabase_ =
nullptr;
704 =
new Database(priorDbname_, scoreDatabase_, scoreDatabase_.missingSymbols());
706 prior_ =
new DirichletPriorFromDatabase(scoreDatabase_.databaseTable(),
707 priorDatabase_->parser(),
708 priorDatabase_->nodeId2Columns());
711 case BNLearnerPriorType::DIRICHLET_FROM_BAYESNET :
713 =
new DirichletPriorFromBN< GUM_SCALAR >(scoreDatabase_.databaseTable(), &_prior_bn_);
716 case BNLearnerPriorType::BDEU :
717 prior_ =
new BDeuPrior(scoreDatabase_.databaseTable(), scoreDatabase_.nodeId2Columns());
// apply the user-selected weight to the freshly created prior
724 prior_->setWeight(priorWeight_);
727 if (old_prior !=
nullptr)
delete old_prior;
730 template <
typename GUM_SCALAR >
731 INLINE std::ostream&
operator<<(std::ostream& output,
const BNLearner< GUM_SCALAR >& learner) {
732 output << learner.toString();
A listener that allows BNLearner to be used as a proxy for its inner algorithms.
A basic pack of learning algorithms that can easily be used.
Class representing a Bayesian network.
Error: The database contains some missing values.
Error: A name of variable is not found in the database.
Exception : operation not allowed.
BNLearner(const std::string &filename, const std::vector< std::string > &missingSymbols={"?"}, const bool induceTypes=true)
default constructor
A pack of learning algorithms that can easily be used.
#define GUM_ERROR(type, msg)
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Size Idx
Type for indexes.
Size NodeId
Type for node ids.
include the inlined functions if necessary
gum is the global namespace for all aGrUM entities
std::ostream & operator<<(std::ostream &out, const TiXmlNode &base)