aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
BNDatabaseGenerator_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
48
51
52namespace gum::learning {
53
54
56 template < typename GUM_SCALAR >
58 _bn_(bn) {
59 GUM_CONSTRUCTOR(BNDatabaseGenerator)
60
61 // get the node names => they will serve as ids
62 NodeId id = 0;
63 for (const auto& var: _bn_.dag()) {
64 auto name = _bn_.variable(var).name();
65 _names2ids_.insert(name, var);
66 ++id;
67 }
68 _nbVars_ = id;
69 _varOrder_.resize(_nbVars_);
71 std::iota(_varOrder_.begin(), _varOrder_.end(), (Idx)0);
72 }
73
75 template < typename GUM_SCALAR >
79
81 template < typename GUM_SCALAR >
83 const Instantiation inst;
84 return drawSamples(nbSamples, inst);
85 }
86
88 template < typename GUM_SCALAR >
90 const Instantiation& evs,
91 int timeout) {
92 int progress = 0;
93
94 if (onProgress.hasListener()) { GUM_EMIT2(onProgress, progress, 0.0); }
95
96 _database_.clear();
97 _database_.resize(nbSamples);
98 for (auto& row: _database_) {
99 row.resize(_nbVars_);
100 }
101 // get the order in which the nodes will be sampled
102 const auto topOrder = _bn_.topologicalOrder();
103 gum::Instantiation particule;
104
105 // create instantiations in advance
106 for (NodeId node = 0; node < _nbVars_; ++node)
107 particule.add(_bn_.variable(node));
108
109 gum::Timer timer;
110 timer.reset();
111
112 // perform the sampling
114 Idx idSample = 0;
115 while (idSample < nbSamples) {
116 if (onProgress.hasListener()) {
117 auto p = int((idSample * 100) / nbSamples);
118 if (p != progress) {
119 progress = p;
120 GUM_EMIT2(onProgress, progress, timer.step());
121 }
122 }
123 std::vector< Idx >& sample = _database_.at(idSample);
124 bool reject = false;
125 for (Idx rank = 0; rank < _nbVars_; ++rank) {
126 const NodeId node = topOrder[rank];
127 const auto& var = _bn_.variable(node);
128 const auto& cpt = _bn_.cpt(node);
129
130 const double nb = gum::randomProba();
131 double cumul = 0.0;
132 for (particule.setFirstVar(var); !particule.end(); particule.incVar(var)) {
133 cumul += cpt[particule];
134 if (cumul >= nb) break;
135 }
136 if (particule.end()) particule.setLastVar(var);
137
138 if ((!evs.empty()) && evs.contains(var) && (evs.val(var) != particule.val(var))) {
139 reject = true;
140 break;
141 }
142
143 sample.at(node) = particule.val(var);
144 _log2likelihood_ += std::log2(_bn_.cpt(node)[particule]);
145 }
146 if (timeout > 0 && timer.step() > timeout) { break; }
147 if (reject) { continue; }
148 idSample++;
149 }
150
151 if (idSample > 0) {
152 if (idSample < nbSamples) _database_.resize(idSample);
153 } else {
154 _database_.clear();
155 }
156 _drawnSamples_ = true;
157
158 if (onProgress.hasListener()) {
159 std::stringstream ss;
160 ss << "Database of size " << idSample << "(" << nbSamples << ") generated in " << timer.step()
161 << " seconds. Log2likelihood : " << _log2likelihood_;
162 GUM_EMIT1(onStop, ss.str());
163 }
164
165 return _log2likelihood_;
166 }
167
168 template < typename GUM_SCALAR >
170 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
171
172 return _database_.size();
173 }
174
175 template < typename GUM_SCALAR >
177 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
178
179 return _nbVars_;
180 }
181
182 template < typename GUM_SCALAR >
184 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
185 return _database_.at(row).at(_varOrder_.at(col));
186 }
187
188 template < typename GUM_SCALAR >
189 INLINE std::string BNDatabaseGenerator< GUM_SCALAR >::samplesLabelAt(Idx row, Idx col) const {
190 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
191 const auto j = _varOrder_.at(col);
192 return _label_(_database_.at(row), _bn_.variable(j), j);
193 }
194
195 template < typename GUM_SCALAR >
199
200 template < typename GUM_SCALAR >
204
205 template < typename GUM_SCALAR >
209
211 template < typename GUM_SCALAR >
212 void BNDatabaseGenerator< GUM_SCALAR >::toCSV(const std::string& csvFileURL,
213 bool useLabels,
214 bool append,
215 std::string csvSeparator,
216 bool checkOnAppend) const {
217 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
218
219 if (csvSeparator.find('\n') != std::string::npos) {
220 GUM_ERROR(InvalidArgument, "csvSeparator must not contain end-line characters")
221 }
222
223 bool includeHeader = true;
224 if (append) {
225 std::ifstream csvFile(csvFileURL);
226 if (csvFile) {
227 if (auto varOrder = _varOrderFromCSV_(csvFile, csvSeparator);
228 checkOnAppend && varOrder != _varOrder_)
230 "Inconsistent variable order in csvFile when appending. You "
231 "can use setVarOrderFromCSV(url) function to get the right "
232 "order. You could also set parameter checkOnAppend=false if you "
233 "know what you are doing.")
234 includeHeader = false;
235 }
236 csvFile.close();
237 }
238
239
240 auto ofstreamFlag = append ? std::ofstream::app : std::ofstream::out;
241
242 std::ofstream os(csvFileURL, ofstreamFlag);
243 bool firstCol = true;
244 if (includeHeader) {
245 for (const auto& i: _varOrder_) {
246 if (firstCol) {
247 firstCol = false;
248 } else {
249 os << csvSeparator;
250 }
251 os << _bn_.variable(i).name();
252 }
253 }
254 os << std::endl;
255
256 bool firstRow = true;
257 for (const auto& row: _database_) {
258 if (firstRow) {
259 firstRow = false;
260 } else {
261 os << std::endl;
262 }
263 firstCol = true;
264 for (const auto& i: _varOrder_) {
265 if (firstCol) {
266 firstCol = false;
267 } else {
268 os << csvSeparator;
269 }
270 if (useLabels) {
271 const auto& v = _bn_.variable(i);
272 if (v.varType() == VarType::DISCRETIZED) {
273 switch (_discretizedLabelMode_) {
274 case DiscretizedLabelMode::MEDIAN : os << v.numerical(row.at(i)); break;
276 os << static_cast< const IDiscretizedVariable& >(v).draw(row.at(i));
277 break;
278 case DiscretizedLabelMode::INTERVAL : os << v.label(row.at(i)); break;
279 }
280 } else {
281 os << v.label(row.at(i));
282 }
283 } else {
284 os << row[i];
285 }
286 }
287 }
288
289 os.close();
290 }
291
292 template < typename GUM_SCALAR >
293 std::string BNDatabaseGenerator< GUM_SCALAR >::_label_(const std::vector< Idx >& row,
294 const DiscreteVariable& v,
295 Idx i) const {
296 if (v.varType() == VarType::DISCRETIZED) {
297 switch (_discretizedLabelMode_) {
298 case DiscretizedLabelMode::MEDIAN : return std::to_string(v.numerical(row.at(i)));
300 return std::to_string(static_cast< const IDiscretizedVariable& >(v).draw(row.at(i)));
301 case DiscretizedLabelMode::INTERVAL : return v.label(row.at(i));
302 }
303 }
304
305 return v.label(row.at(i));
306 }
307
309 template < typename GUM_SCALAR >
311 if (!_drawnSamples_) GUM_ERROR(OperationNotAllowed, "proceed() must be called first.")
312
313 DatabaseTable db;
314 std::vector< std::string > varNames;
315 varNames.reserve(_nbVars_);
316 for (const auto& i: _varOrder_) {
317 varNames.push_back(_names2ids_.first(i));
318 }
319
320 // create the translators
321 for (std::size_t i = 0; i < _nbVars_; ++i) {
322 const Variable& var = _bn_.variable(_varOrder_[i]);
323 db.insertTranslator(var, i);
324 }
325
326 if (useLabels) {
327 std::vector< std::string > xrow(_nbVars_);
328 for (const auto& row: _database_) {
329 for (Idx i = 0; i < _nbVars_; ++i) {
330 const Idx j = _varOrder_.at(i);
331 xrow[i] = _label_(row, _bn_.variable(j), j);
332 }
333 db.insertRow(xrow);
334 }
335 } else {
336 std::vector< DBTranslatedValueType > translatorType(_nbVars_);
337 for (std::size_t i = 0; i < _nbVars_; ++i) {
338 translatorType[i] = db.translator(i).getValType();
339 }
341 const auto xmiss = gum::learning::DatabaseTable::IsMissing::False;
342 for (const auto& row: _database_) {
343 for (Idx i = 0; i < _nbVars_; ++i) {
344 const Idx j = _varOrder_.at(i);
345 if (translatorType[i] == DBTranslatedValueType::DISCRETE)
346 xrow[i].discr_val = std::size_t(row.at(j));
347 else xrow[i].cont_val = float(row.at(j));
348 }
349 }
350 db.insertRow(xrow, xmiss);
351 }
352
353 return db;
354 }
355
357 template < typename GUM_SCALAR >
358 std::vector< std::vector< Idx > > BNDatabaseGenerator< GUM_SCALAR >::database() const {
359 if (!_drawnSamples_) GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.")
360
361 auto db(_database_);
362 for (Idx i = 0; i < _database_.size(); ++i) {
363 for (Idx j = 0; j < _nbVars_; ++j) {
364 db.at(i).at(j) = (Idx)_database_.at(i).at(_varOrder_.at(j));
365 }
366 }
367 return db;
368 }
369
371 template < typename GUM_SCALAR >
373 if (varOrder.size() != _nbVars_)
374 GUM_ERROR(FatalError, "varOrder's size must be equal to the number of variables")
375
376 std::vector< bool > usedVars(_nbVars_, false);
377 for (const auto& i: varOrder) {
378 if (i >= _nbVars_) GUM_ERROR(FatalError, "varOrder contains invalid variables")
379 if (usedVars.at(i)) GUM_ERROR(FatalError, "varOrder must not have repeated variables")
380 usedVars.at(i) = true;
381 }
382
383 if (std::find(usedVars.begin(), usedVars.end(), false) != usedVars.end()) {
384 GUM_ERROR(FatalError, "varOrder must contain all variables")
385 }
386
388 }
389
391 template < typename GUM_SCALAR >
392 void BNDatabaseGenerator< GUM_SCALAR >::setVarOrder(const std::vector< std::string >& varOrder) {
393 std::vector< Idx > varOrderIdx;
394 varOrderIdx.reserve(varOrder.size());
395 for (const auto& vname: varOrder) {
396 varOrderIdx.push_back(_names2ids_.second(vname));
397 }
398 setVarOrder(varOrderIdx);
399 }
400
402 template < typename GUM_SCALAR >
404 const std::string& csvSeparator) {
405 setVarOrder(_varOrderFromCSV_(csvFileURL, csvSeparator));
406 }
407
409 template < typename GUM_SCALAR >
411 std::vector< Idx > varOrder;
412 varOrder.reserve(_nbVars_);
413 for (const auto& v: _bn_.topologicalOrder()) {
414 varOrder.push_back(v);
415 }
417 }
418
420 template < typename GUM_SCALAR >
422 std::vector< Idx > varOrder;
423 varOrder.reserve(_nbVars_);
424 for (const auto& v: _bn_.topologicalOrder()) {
425 varOrder.push_back(v);
426 }
427 std::reverse(varOrder.begin(), varOrder.end());
429 }
430
432 template < typename GUM_SCALAR >
434 std::vector< std::string > varOrder;
435 varOrder.reserve(_bn_.size());
436 for (const auto& var: _bn_.dag()) {
437 varOrder.push_back(_bn_.variable(var).name());
438 }
439 std::shuffle(varOrder.begin(), varOrder.end(), gum::randomGenerator());
441 }
442
444 template < typename GUM_SCALAR >
446 return _varOrder_;
447 }
448
450 template < typename GUM_SCALAR >
451 std::vector< std::string > BNDatabaseGenerator< GUM_SCALAR >::varOrderNames() const {
452 std::vector< std::string > varNames;
453 varNames.reserve(_nbVars_);
454 for (const auto& i: _varOrder_) {
455 varNames.push_back(_names2ids_.first(i));
456 }
457
458 return varNames;
459 }
460
462 template < typename GUM_SCALAR >
464 if (!_drawnSamples_) { GUM_ERROR(OperationNotAllowed, "drawSamples() must be called first.") }
465 return _log2likelihood_;
466 }
467
469 template < typename GUM_SCALAR >
470 std::vector< Idx >
472 const std::string& csvSeparator) const {
473 std::ifstream csvFile(csvFileURL);
474 std::vector< Idx > varOrder;
475 if (csvFile) {
476 varOrder = _varOrderFromCSV_(csvFile, csvSeparator);
477 csvFile.close();
478 } else {
479 GUM_ERROR(NotFound, "csvFileURL does not exist")
480 }
481
482 return varOrder;
483 }
484
486 template < typename GUM_SCALAR >
487 std::vector< Idx >
489 const std::string& csvSeparator) const {
490 std::string line;
491 std::vector< std::string > header_found;
492 header_found.reserve(_nbVars_);
493 while (std::getline(csvFile, line)) {
494 std::size_t i = 0;
495 auto pos = line.find(csvSeparator);
496 while (pos != std::string::npos) {
497 header_found.push_back(line.substr(i, pos - i));
498 pos += csvSeparator.length();
499 i = pos;
500 pos = line.find(csvSeparator, pos);
501
502 if (pos == std::string::npos) header_found.push_back(line.substr(i, line.length()));
503 }
504 break;
505 }
506
507 std::vector< Size > varOrder;
508 varOrder.reserve(_nbVars_);
509
510 for (const auto& hf: header_found) {
511 varOrder.push_back(_names2ids_.second(hf));
512 }
513
514 return varOrder;
515 }
516} // namespace gum::learning
Base class for discrete random variable.
virtual double numerical(Idx indice) const =0
get a numerical representation of the indice-th value.
VarType varType() const override=0
returns the varType of variable
virtual std::string label(Idx i) const =0
get the indice-th label. This method is pure virtual.
Exception : fatal (unknown ?) error.
A base class for discretized variables, independent of the ticks type.
Class for assigning/browsing values to tuples of discrete variables.
bool end() const
Returns true if the Instantiation reached the end.
void incVar(const DiscreteVariable &v)
Operator increment for variable v only.
void add(const DiscreteVariable &v) final
Adds a new variable in the Instantiation.
virtual bool empty() const final
Returns true if the instantiation is empty.
bool contains(const DiscreteVariable &v) const final
Indicates whether a given variable belongs to the Instantiation.
void setFirstVar(const DiscreteVariable &v)
Assign the first value in the Instantiation for var v.
Idx val(Idx i) const
Returns the current value of the variable at position i.
void setLastVar(const DiscreteVariable &v)
Assign the last value in the Instantiation for var v.
Exception: at least one argument passed to a function is not what was expected.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
Signaler2< Size, double > onProgress
Progression (percent) and time.
Signaler1< const std::string & > onStop
with a possible explanation for stopping
Class used to compute response times for benchmark purposes.
Definition timer.h:69
void reset()
Reset the timer.
Definition timer_inl.h:52
double step() const
Returns the delta time between now and the last reset() call (or the constructor).
Definition timer_inl.h:71
Base class for every random variable.
Definition variable.h:79
bool _drawnSamples_
whether drawSamples has been already called.
std::vector< Idx > varOrder() const
returns variable order indexes
std::string samplesLabelAt(Idx row, Idx col) const
returns the label of the value at (row, col) of the generated database, col following the current variable order
DatabaseTable toDatabaseTable(bool useLabels=true) const
generates a DatabaseTable from the generated samples
void setVarOrderFromCSV(const std::string &csvFileURL, const std::string &csvSeparator=",")
change columns order according to a csv file
Size samplesNbRows() const
returns the number of rows (samples) of the generated database
std::string _label_(const std::vector< Idx > &row, const DiscreteVariable &v, Idx i) const
return the final string for a label (taking into account the behavior for DiscretizedVariable) from a...
std::vector< std::vector< Idx > > database() const
generates database according to bn into a std::vector
void setDiscretizedLabelModeRandom()
set the behaviour of sampling for discretized variable to uniformly draw double value
Idx samplesAt(Idx row, Idx col) const
returns the index of the value at (row, col) of the generated database, col following the current variable order
double _log2likelihood_
log2Likelihood of generated samples
const BayesNet< GUM_SCALAR > & bn(void)
return const ref to the Bayes Net
BNDatabaseGenerator(const BayesNet< GUM_SCALAR > &bn)
constructor from a Bayesian network
void setDiscretizedLabelModeInterval()
set the behaviour of sampling for discretized variable to select the label : "[min,...
Bijection< std::string, NodeId > _names2ids_
bijection nodes names
std::vector< std::vector< Idx > > _database_
generated database
Size samplesNbCols() const
returns the number of columns (variables) of the generated database
const BayesNet< GUM_SCALAR > & _bn_
Bayesian network.
void setAntiTopologicalVarOrder()
set columns in anti-topological order
double log2likelihood() const
returns log2Likelihood of generated samples
void setTopologicalVarOrder()
set columns in topological order
void setDiscretizedLabelModeMedian()
set the behaviour of sampling for discretized variable to deterministic select double median of inter...
double drawSamples(Size nbSamples)
generate and stock database, returns log2likelihood using ProgressNotifier as notification
std::vector< Idx > _varOrder_
variable order in generated database
std::vector< std::string > varOrderNames() const
returns the variable names in the current column order.
std::vector< Idx > _varOrderFromCSV_(const std::string &csvFileURL, const std::string &csvSeparator=",") const
returns varOrder from a csv file
void setVarOrder(const std::vector< Idx > &varOrder)
change columns order
void toCSV(const std::string &csvFileURL, bool useLabels=true, bool append=false, std::string csvSeparator=",", bool checkOnAppend=false) const
generates csv representing the generated database
void setRandomVarOrder()
set columns in random order
The class for storing a record in a database.
Definition DBRow.h:75
DBTranslatedValueType getValType() const
returns the type of values handled by the translator
The class representing a tabular database as used by learning tasks.
std::size_t insertTranslator(const DBTranslator &translator, const std::size_t input_column, const bool unique_column=true)
insert a new translator into the database table
const DBTranslator & translator(const std::size_t k, const bool k_is_input_col=false) const
returns either the kth translator of the database table or the first one reading the kth column of th...
void insertRow(const std::vector< std::string > &new_row) override
insert a new row at the end of the database
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
std::mt19937 & randomGenerator()
define a random_engine with correct seed
double randomProba()
Returns a random double between 0 and 1 included (i.e.
include the inlined functions if necessary
Definition CSVParser.h:54
#define GUM_EMIT1(signal, arg1)
Definition signaler1.h:61
#define GUM_EMIT2(signal, arg1, arg2)
Definition signaler2.h:61
Class used to compute response times for benchmark purposes.