aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
DBRowGeneratorEM_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49
50#ifndef DOXYGEN_SHOULD_SKIP_THIS
51
52namespace gum {
53
54 namespace learning {
55
// Main constructor of the EM row generator.
// @param column_types the DBTranslatedValueType of each column of the rows
//        that will be passed to computeRows_ (only DISCRETE is handled there)
// @param bn the Bayesian network used to compute the distribution of the
//        unobserved values
// @param nodeId2columns bijection between BN node ids and database columns;
//        when empty, computeRows_ and setBayesNet use the column index itself
//        as the node id
// NOTE(review): this listing does not show the line naming the constructor
//   (listing line 58, `DBRowGeneratorEM< GUM_SCALAR >::DBRowGeneratorEM(`)
//   nor one base-class constructor argument between `bn` and
//   `nodeId2columns` (listing line 64) -- confirm against the real source.
57 template < typename GUM_SCALAR >
59 const std::vector< DBTranslatedValueType >& column_types,
60 const BayesNet< GUM_SCALAR >& bn,
61 const Bijection< NodeId, std::size_t >& nodeId2columns) :
62 DBRowGeneratorWithBN< GUM_SCALAR >(column_types,
63 bn,
65 nodeId2columns),
// both output rows are first built from (bn.size(), 1.0) -- presumably a
// cell count and a row weight, confirm against DBRow; setBayesNet below
// resizes them anyway
66 _filled_row1_(bn.size(), 1.0), _filled_row2_(bn.size(), 1.0) {
// validates nodeId2columns against bn (when non-empty) and resizes both
// filled rows to (largest node id / mapped column) + 1
67 setBayesNet(bn);
68
69 GUM_CONSTRUCTOR(DBRowGeneratorEM);
70 }
71
73 template < typename GUM_SCALAR >
74 DBRowGeneratorEM< GUM_SCALAR >::DBRowGeneratorEM(const DBRowGeneratorEM< GUM_SCALAR >& from) :
75 DBRowGeneratorWithBN< GUM_SCALAR >(from), _input_row_(from._input_row_),
76 _missing_cols_(from._missing_cols_), _nb_miss_(from._nb_miss_),
77 _joint_proba_(from._joint_proba_), _filled_row1_(from._filled_row1_),
78 _filled_row2_(from._filled_row2_), _use_filled_row1_(from._use_filled_row1_),
79 _original_weight_(from._original_weight_) {
80 if (from._joint_inst_ != nullptr) {
81 _joint_inst_ = new Instantiation(_joint_proba_);
82 const auto& var_seq = _joint_inst_->variablesSequence();
83 const std::size_t size = var_seq.size();
84 for (std::size_t i = std::size_t(0); i < size; ++i) {
85 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
86 }
87 }
88
89 GUM_CONS_CPY(DBRowGeneratorEM);
90 }
91
93 template < typename GUM_SCALAR >
94 DBRowGeneratorEM< GUM_SCALAR >::DBRowGeneratorEM(DBRowGeneratorEM< GUM_SCALAR >&& from) :
95 DBRowGeneratorWithBN< GUM_SCALAR >(std::move(from)), _input_row_(from._input_row_),
96 _missing_cols_(std::move(from._missing_cols_)), _nb_miss_(from._nb_miss_),
97 _joint_proba_(std::move(from._joint_proba_)), _filled_row1_(std::move(from._filled_row1_)),
98 _filled_row2_(std::move(from._filled_row2_)), _use_filled_row1_(from._use_filled_row1_),
99 _original_weight_(from._original_weight_) {
100 if (from._joint_inst_ != nullptr) {
101 _joint_inst_ = new Instantiation(_joint_proba_);
102 const auto& var_seq = _joint_inst_->variablesSequence();
103 const std::size_t size = var_seq.size();
104 for (std::size_t i = std::size_t(0); i < size; ++i) {
105 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
106 }
107 }
108
109 GUM_CONS_MOV(DBRowGeneratorEM);
110 }
111
113 template < typename GUM_SCALAR >
114 DBRowGeneratorEM< GUM_SCALAR >* DBRowGeneratorEM< GUM_SCALAR >::clone() const {
115 return new DBRowGeneratorEM< GUM_SCALAR >(*this);
116 }
117
119 template < typename GUM_SCALAR >
120 DBRowGeneratorEM< GUM_SCALAR >::~DBRowGeneratorEM() {
121 if (_joint_inst_ != nullptr) delete _joint_inst_;
122 GUM_DESTRUCTOR(DBRowGeneratorEM);
123 }
124
126 template < typename GUM_SCALAR >
127 DBRowGeneratorEM< GUM_SCALAR >&
128 DBRowGeneratorEM< GUM_SCALAR >::operator=(const DBRowGeneratorEM< GUM_SCALAR >& from) {
129 if (this != &from) {
130 DBRowGeneratorWithBN< GUM_SCALAR >::operator=(from);
131 _input_row_ = from._input_row_;
132 _missing_cols_ = from._missing_cols_;
133 _nb_miss_ = from._nb_miss_;
134 _joint_proba_ = from._joint_proba_;
135 _filled_row1_ = from._filled_row1_;
136 _filled_row2_ = from._filled_row2_;
137 _use_filled_row1_ = from._use_filled_row1_;
138 _original_weight_ = from._original_weight_;
139
140 if (_joint_inst_ != nullptr) {
141 delete _joint_inst_;
142 _joint_inst_ = nullptr;
143 }
144
145 if (from._joint_inst_ != nullptr) {
146 _joint_inst_ = new Instantiation(_joint_proba_);
147 const auto& var_seq = _joint_inst_->variablesSequence();
148 const std::size_t size = var_seq.size();
149 for (std::size_t i = std::size_t(0); i < size; ++i) {
150 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
151 }
152 }
153 }
154
155 return *this;
156 }
157
159 template < typename GUM_SCALAR >
160 DBRowGeneratorEM< GUM_SCALAR >&
161 DBRowGeneratorEM< GUM_SCALAR >::operator=(DBRowGeneratorEM< GUM_SCALAR >&& from) {
162 if (this != &from) {
163 DBRowGeneratorWithBN< GUM_SCALAR >::operator=(std::move(from));
164 _input_row_ = from._input_row_;
165 _missing_cols_ = std::move(from._missing_cols_);
166 _nb_miss_ = from._nb_miss_;
167 _joint_proba_ = std::move(from._joint_proba_);
168 _filled_row1_ = std::move(from._filled_row1_);
169 _filled_row2_ = std::move(from._filled_row2_);
170 _use_filled_row1_ = from._use_filled_row1_;
171 _original_weight_ = from._original_weight_;
172
173 if (_joint_inst_ != nullptr) {
174 delete _joint_inst_;
175 _joint_inst_ = nullptr;
176 }
177
178 if (from._joint_inst_ != nullptr) {
179 _joint_inst_ = new Instantiation(_joint_proba_);
180 const auto& var_seq = _joint_inst_->variablesSequence();
181 const std::size_t size = var_seq.size();
182 for (std::size_t i = std::size_t(0); i < size; ++i) {
183 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
184 }
185 }
186 }
187
188 return *this;
189 }
190
192 template < typename GUM_SCALAR >
193 INLINE const DBRow< DBTranslatedValue >& DBRowGeneratorEM< GUM_SCALAR >::generate() {
194 this->decreaseRemainingRows();
195
196 // if everything is observed, return the input row
197 if (_input_row_ != nullptr) return *_input_row_;
198
199 if (_use_filled_row1_) {
200 // get the weight of the row from the joint probability
201 _filled_row1_.setWeight(_joint_proba_.get(*_joint_inst_) * _original_weight_);
202
203 // fill the values of the row
204 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
205 _filled_row1_[_missing_cols_[i]].discr_val = _joint_inst_->val(i);
206
207 _joint_inst_->inc();
208 _use_filled_row1_ = false;
209
210 return _filled_row1_;
211 } else {
212 // get the weight of the row from the joint probability
213 _filled_row2_.setWeight(_joint_proba_.get(*_joint_inst_) * _original_weight_);
214
215 // fill the values of the row
216 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
217 _filled_row2_[_missing_cols_[i]].discr_val = _joint_inst_->val(i);
218
219 _joint_inst_->inc();
220 _use_filled_row1_ = true;
221
222 return _filled_row2_;
223 }
224 }
225
// Prepares the generation of the output rows derived from `row` and returns
// how many rows generate() will produce for it.
// - A discrete value equal to std::numeric_limits<std::size_t>::max() is
//   treated as missing.
// - If no discrete column of interest is missing, _input_row_ is made to point
//   to `row` and 1 is returned. `row` therefore presumably has to stay alive
//   while generate() is called -- confirm with callers.
// - Otherwise, variable elimination over bn_ computes the joint posterior of
//   the missing columns of interest. The evidence is every observed discrete
//   value of the row. The returned count is the domain size of that posterior,
//   whose configurations generate() then enumerates.
// Continuous columns and unsupported column types presumably raise an error
// via GUM_ERROR. The heads of those statements are not shown in this listing.
// NOTE(review): those error messages spell the class "BDRowGeneratorEM" and
//   lack a space after "in column"; fixing them changes runtime strings.
227 template < typename GUM_SCALAR >
228 INLINE std::size_t
229 DBRowGeneratorEM< GUM_SCALAR >::computeRows_(const DBRow< DBTranslatedValue >& row) {
230 // check if there are unobserved values among the columns of interest.
231 // If this is the case, save in _missing_cols_ all the columns with unobserved values
232 bool found_unobserved = false;
233 const auto& xrow = row.row();
234 for (const auto col: this->columns_of_interest_) {
235 switch (this->column_types_[col]) {
236 case DBTranslatedValueType::DISCRETE :
237 if (xrow[col].discr_val == std::numeric_limits< std::size_t >::max()) {
// _missing_cols_ is only cleared when the first missing value is found, so
// for fully observed rows it keeps stale content (unused in that case)
238 if (!found_unobserved) {
239 _missing_cols_.clear();
240 found_unobserved = true;
241 }
242 _missing_cols_.push_back(col);
243 }
244 break;
245
// continuous columns are not supported (error statement partly elided here)
246 case DBTranslatedValueType::CONTINUOUS :
248 "The BDRowGeneratorEM does not handle yet continuous "
249 << "variables. But the variable in column" << col << " is continuous.");
250 break;
251
252 default :
254 "DBTranslatedValueType " << int(this->column_types_[col])
255 << " is not supported yet");
256 }
257 }
258
259 // if there is no unobserved value, make the _input_row_ point to the row
260 if (!found_unobserved) {
261 _input_row_ = &row;
262 return std::size_t(1);
263 }
264
265 _input_row_ = nullptr;
266 _nb_miss_ = _missing_cols_.size();
// kept so that generate() can scale each produced row's weight by it
267 _original_weight_ = row.weight();
268
269 // here, there are missing symbols, so we should compute the distribution
270 // of the missing values. For this purpose, we use Variable Elimination
271 VariableElimination< GUM_SCALAR > ve(this->bn_);
272
273 // add the targets and fill the output row with the observed values
// _missing_cols_ was filled in columns_of_interest_ order, so the cursor i
// walks through it in step with the loops below; every non-missing column of
// interest is copied into both output buffers
274 NodeSet target_set(_nb_miss_);
275 if (this->nodeId2columns_.empty()) {
// no mapping: the column index is used directly as the BN node id
276 std::size_t i = std::size_t(0);
277 bool end_miss = false;
278 for (const auto col: this->columns_of_interest_) {
279 if (!end_miss && (col == _missing_cols_[i])) {
280 target_set.insert(NodeId(col));
281 ++i;
282 if (i == _nb_miss_) end_miss = true;
283 } else {
284 _filled_row1_[col].discr_val = xrow[col].discr_val;
285 _filled_row2_[col].discr_val = xrow[col].discr_val;
286 }
287 }
288 } else {
289 std::size_t i = std::size_t(0);
290 bool end_miss = false;
291 for (const auto col: this->columns_of_interest_) {
292 if (!end_miss && (col == _missing_cols_[i])) {
293 target_set.insert(this->nodeId2columns_.first(col));
294 ++i;
295 if (i == _nb_miss_) end_miss = true;
296 } else {
297 _filled_row1_[col].discr_val = xrow[col].discr_val;
298 _filled_row2_[col].discr_val = xrow[col].discr_val;
299 }
300 }
301 }
302
303 ve.addJointTarget(target_set);
304
305 // add the evidence and the target into variable elimination
// NOTE(review): evidence is taken from every column of the row
// (0 .. row.size()-1), not only from the columns of interest, and
// column_types_ is indexed with those columns -- confirm that column_types_
// (and, when non-empty, nodeId2columns_) covers every column of the row.
306 const std::size_t row_size = xrow.size();
307 if (this->nodeId2columns_.empty()) {
308 for (std::size_t col = std::size_t(0); col < row_size; ++col) {
309 switch (this->column_types_[col]) {
310 case DBTranslatedValueType::DISCRETE :
311 // only observed values are evidence
312 if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
313 ve.addEvidence(NodeId(col), xrow[col].discr_val);
314 }
315 break;
316
317 case DBTranslatedValueType::CONTINUOUS :
319 "The BDRowGeneratorEM does not handle yet continuous "
320 << "variables. But the variable in column" << col << " is continuous.");
321 break;
322
323 default :
325 "DBTranslatedValueType " << int(this->column_types_[col])
326 << " is not supported yet");
327 }
328 }
329 } else {
330 for (std::size_t col = std::size_t(0); col < row_size; ++col) {
331 switch (this->column_types_[col]) {
332 case DBTranslatedValueType::DISCRETE :
333 // only observed values are evidence
334 if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
335 ve.addEvidence(this->nodeId2columns_.first(col), xrow[col].discr_val);
336 }
337 break;
338
339 case DBTranslatedValueType::CONTINUOUS :
341 "The BDRowGeneratorEM does not handle yet continuous "
342 << "variables. But the variable in column" << col << " is continuous.");
343 break;
344
345 default :
347 "DBTranslatedValueType " << int(this->column_types_[col])
348 << " is not supported yet");
349 }
350 }
351 }
352
353 // get the tensor of the target set
// jointPosterior returns a const reference, presumably into the local `ve`
// (destroyed on return); the const_cast lets the tensor be moved out instead
// of copied
354 Tensor< GUM_SCALAR >& pot
355 = const_cast< Tensor< GUM_SCALAR >& >(ve.jointPosterior(target_set));
356 _joint_proba_ = std::move(pot);
357 if (_joint_inst_ != nullptr) delete _joint_inst_;
358 _joint_inst_ = new Instantiation(_joint_proba_);
359
360 // get the mapping between variables of the joint proba and the
361 // columns in the database
// after this, _missing_cols_[i] is the column of the i-th variable of
// _joint_proba_, matching generate()'s use of _joint_inst_->val(i)
362 const auto& var_sequence = _joint_proba_.variablesSequence();
363 if (this->nodeId2columns_.empty()) {
364 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
365 _missing_cols_[i] = std::size_t(this->bn_->nodeId(*(var_sequence[i])));
366 }
367 } else {
368 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
369 _missing_cols_[i] = this->nodeId2columns_.second(this->bn_->nodeId(*(var_sequence[i])));
370 }
371 }
372
373 return std::size_t(_joint_proba_.domainSize());
374 }
375
// Replaces the Bayesian network used to fill in the missing values.
// When nodeId2columns_ is non-empty, every node id it references must belong
// to new_bn's DAG. Otherwise an error is presumably raised (its GUM_ERROR head
// line is not shown in this listing) before any member is modified.
// The base class is then updated. Both filled rows are resized to
// (largest node id of new_bn, or largest mapped column when a mapping is
// given) + 1, so that computeRows_/generate() can index them by column.
377 template < typename GUM_SCALAR >
378 void DBRowGeneratorEM< GUM_SCALAR >::setBayesNet(const BayesNet< GUM_SCALAR >& new_bn) {
379 // check that if nodeId2columns is not empty, then all the columns
380 // correspond to nodes of the BN
381 if (!this->nodeId2columns_.empty()) {
382 const DAG& dag = new_bn.dag();
383 for (auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
384 ++iter) {
385 if (!dag.existsNode(iter.first())) {
387 "Column " << iter.second() << " of the database is associated to Node ID "
388 << iter.first()
389 << ", which does not belong to the Bayesian network");
390 }
391 }
392 }
393
394 DBRowGeneratorWithBN< GUM_SCALAR >::setBayesNet(new_bn);
395
396 // we determine the size of the filled rows
// without a mapping, the column index equals the node id, so the largest
// node id bounds the columns that may be written
397 std::size_t size = std::size_t(0);
398 if (this->nodeId2columns_.empty()) {
399 for (auto node: new_bn.dag())
400 if (std::size_t(node) > size) size = std::size_t(node);
401 } else {
402 for (auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
403 ++iter) {
404 if (iter.second() > size) size = iter.second();
405 }
406 }
407 _filled_row1_.resize(size + 1);
408 _filled_row2_.resize(size + 1);
409 }
410
411 } /* namespace learning */
412
413} /* namespace gum */
414
415#endif /* DOXYGEN_SHOULD_SKIP_THIS */
A DBRowGenerator class that returns exactly the rows it gets in input.
Exception : node does not exist.
Exception : there is something wrong with an implementation.
DBRowGeneratorEM(const std::vector< DBTranslatedValueType > &column_types, const BayesNet< GUM_SCALAR > &bn, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
default constructor
Base class for DBRowGenerator classes that use a BN for computing their outputs.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
DBRowGeneratorGoal
the type of things that a DBRowGenerator is designed for
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.