50#ifndef DOXYGEN_SHOULD_SKIP_THIS
// Main constructor of DBRowGeneratorEM: receives the types of the database
// columns, the Bayesian network and a NodeId <-> column-index bijection.
// NOTE(review): this listing omits several lines of the definition (the
// qualified constructor name and part of the member-initializer list), so only
// the visible initializers are described here.
57 template <
typename GUM_SCALAR >
59 const std::vector< DBTranslatedValueType >& column_types,
60 const BayesNet< GUM_SCALAR >& bn,
61 const Bijection< NodeId, std::size_t >& nodeId2columns) :
// Both output rows are built from bn.size() and 1.0 (presumably the number of
// cells and the initial weight -- confirm against the DBRow constructor).
66 _filled_row1_(bn.size(), 1.0), _filled_row2_(bn.size(), 1.0) {
// Construction-tracking macro (presumably only active in debug builds).
69 GUM_CONSTRUCTOR(DBRowGeneratorEM);
73 template <
typename GUM_SCALAR >
74 DBRowGeneratorEM< GUM_SCALAR >::DBRowGeneratorEM(
const DBRowGeneratorEM< GUM_SCALAR >& from) :
75 DBRowGeneratorWithBN< GUM_SCALAR >(from), _input_row_(from._input_row_),
76 _missing_cols_(from._missing_cols_), _nb_miss_(from._nb_miss_),
77 _joint_proba_(from._joint_proba_), _filled_row1_(from._filled_row1_),
78 _filled_row2_(from._filled_row2_), _use_filled_row1_(from._use_filled_row1_),
79 _original_weight_(from._original_weight_) {
80 if (from._joint_inst_ !=
nullptr) {
81 _joint_inst_ = new Instantiation(_joint_proba_);
82 const auto& var_seq = _joint_inst_->variablesSequence();
83 const std::size_t size = var_seq.size();
84 for (std::size_t i = std::size_t(0); i < size; ++i) {
85 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
89 GUM_CONS_CPY(DBRowGeneratorEM);
93 template <
typename GUM_SCALAR >
94 DBRowGeneratorEM< GUM_SCALAR >::DBRowGeneratorEM(DBRowGeneratorEM< GUM_SCALAR >&& from) :
95 DBRowGeneratorWithBN< GUM_SCALAR >(
std::move(from)), _input_row_(from._input_row_),
96 _missing_cols_(
std::move(from._missing_cols_)), _nb_miss_(from._nb_miss_),
97 _joint_proba_(
std::move(from._joint_proba_)), _filled_row1_(
std::move(from._filled_row1_)),
98 _filled_row2_(
std::move(from._filled_row2_)), _use_filled_row1_(from._use_filled_row1_),
99 _original_weight_(from._original_weight_) {
100 if (from._joint_inst_ !=
nullptr) {
101 _joint_inst_ = new Instantiation(_joint_proba_);
102 const auto& var_seq = _joint_inst_->variablesSequence();
103 const std::size_t size = var_seq.size();
104 for (std::size_t i = std::size_t(0); i < size; ++i) {
105 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
109 GUM_CONS_MOV(DBRowGeneratorEM);
113 template <
typename GUM_SCALAR >
114 DBRowGeneratorEM< GUM_SCALAR >* DBRowGeneratorEM< GUM_SCALAR >::clone()
const {
115 return new DBRowGeneratorEM< GUM_SCALAR >(*
this);
119 template <
typename GUM_SCALAR >
120 DBRowGeneratorEM< GUM_SCALAR >::~DBRowGeneratorEM() {
121 if (_joint_inst_ !=
nullptr)
delete _joint_inst_;
122 GUM_DESTRUCTOR(DBRowGeneratorEM);
126 template <
typename GUM_SCALAR >
127 DBRowGeneratorEM< GUM_SCALAR >&
128 DBRowGeneratorEM< GUM_SCALAR >::operator=(
const DBRowGeneratorEM< GUM_SCALAR >& from) {
130 DBRowGeneratorWithBN< GUM_SCALAR >::operator=(from);
131 _input_row_ = from._input_row_;
132 _missing_cols_ = from._missing_cols_;
133 _nb_miss_ = from._nb_miss_;
134 _joint_proba_ = from._joint_proba_;
135 _filled_row1_ = from._filled_row1_;
136 _filled_row2_ = from._filled_row2_;
137 _use_filled_row1_ = from._use_filled_row1_;
138 _original_weight_ = from._original_weight_;
140 if (_joint_inst_ !=
nullptr) {
142 _joint_inst_ =
nullptr;
145 if (from._joint_inst_ !=
nullptr) {
146 _joint_inst_ =
new Instantiation(_joint_proba_);
147 const auto& var_seq = _joint_inst_->variablesSequence();
148 const std::size_t size = var_seq.size();
149 for (std::size_t i = std::size_t(0); i < size; ++i) {
150 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
159 template <
typename GUM_SCALAR >
160 DBRowGeneratorEM< GUM_SCALAR >&
161 DBRowGeneratorEM< GUM_SCALAR >::operator=(DBRowGeneratorEM< GUM_SCALAR >&& from) {
163 DBRowGeneratorWithBN< GUM_SCALAR >::operator=(std::move(from));
164 _input_row_ = from._input_row_;
165 _missing_cols_ = std::move(from._missing_cols_);
166 _nb_miss_ = from._nb_miss_;
167 _joint_proba_ = std::move(from._joint_proba_);
168 _filled_row1_ = std::move(from._filled_row1_);
169 _filled_row2_ = std::move(from._filled_row2_);
170 _use_filled_row1_ = from._use_filled_row1_;
171 _original_weight_ = from._original_weight_;
173 if (_joint_inst_ !=
nullptr) {
175 _joint_inst_ =
nullptr;
178 if (from._joint_inst_ !=
nullptr) {
179 _joint_inst_ =
new Instantiation(_joint_proba_);
180 const auto& var_seq = _joint_inst_->variablesSequence();
181 const std::size_t size = var_seq.size();
182 for (std::size_t i = std::size_t(0); i < size; ++i) {
183 _joint_inst_->chgVal(Idx(i), from._joint_inst_->val(i));
192 template <
typename GUM_SCALAR >
193 INLINE
const DBRow< DBTranslatedValue >& DBRowGeneratorEM< GUM_SCALAR >::generate() {
194 this->decreaseRemainingRows();
197 if (_input_row_ !=
nullptr)
return *_input_row_;
199 if (_use_filled_row1_) {
201 _filled_row1_.setWeight(_joint_proba_.get(*_joint_inst_) * _original_weight_);
204 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
205 _filled_row1_[_missing_cols_[i]].discr_val = _joint_inst_->val(i);
208 _use_filled_row1_ =
false;
210 return _filled_row1_;
213 _filled_row2_.setWeight(_joint_proba_.get(*_joint_inst_) * _original_weight_);
216 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i)
217 _filled_row2_[_missing_cols_[i]].discr_val = _joint_inst_->val(i);
220 _use_filled_row1_ =
true;
222 return _filled_row2_;
// Prepares the generation of output rows for the input row `row` and returns
// the number of rows that will be generated.
//
// Visible steps:
//  1. scan the columns of interest for discrete cells whose value equals
//     std::numeric_limits< std::size_t >::max() (the marker tested below for a
//     missing value) and record those columns in _missing_cols_;
//  2. if no such cell exists, return 1;
//  3. otherwise run a VariableElimination on the Bayesian network, with the
//     missing columns as joint target and the observed cells as evidence;
//  4. store the joint posterior in _joint_proba_, create a fresh Instantiation
//     over it, and return its domain size.
//
// NOTE(review): this listing omits several lines of the definition (return
// type, error-macro heads, break/default labels, closing braces, ...). The
// comments below describe only what is visible.
227 template <
typename GUM_SCALAR >
229 DBRowGeneratorEM< GUM_SCALAR >::computeRows_(
const DBRow< DBTranslatedValue >& row) {
// --- step 1: look for unobserved (missing) discrete values
232 bool found_unobserved =
false;
233 const auto& xrow = row.row();
234 for (
const auto col: this->columns_of_interest_) {
235 switch (this->column_types_[col]) {
236 case DBTranslatedValueType::DISCRETE :
237 if (xrow[col].discr_val == std::numeric_limits< std::size_t >::max()) {
// _missing_cols_ is cleared only when the first missing value is found
238 if (!found_unobserved) {
239 _missing_cols_.clear();
240 found_unobserved =
true;
242 _missing_cols_.push_back(col);
// continuous columns are reported as not handled
246 case DBTranslatedValueType::CONTINUOUS :
248 "The BDRowGeneratorEM does not handle yet continuous "
249 <<
"variables. But the variable in column" << col <<
" is continuous.");
254 "DBTranslatedValueType " <<
int(this->column_types_[col])
255 <<
" is not supported yet");
// --- step 2: nothing is missing, a single row is announced
260 if (!found_unobserved) {
262 return std::size_t(1);
// --- step 3: some values are missing; reset the state used by generate()
265 _input_row_ =
nullptr;
266 _nb_miss_ = _missing_cols_.size();
267 _original_weight_ = row.weight();
271 VariableElimination< GUM_SCALAR > ve(this->bn_);
// Build the joint target and copy the row's values into both output buffers.
// With an empty nodeId2columns_, the column index is used directly as NodeId.
275 if (this->nodeId2columns_.empty()) {
276 std::size_t i = std::size_t(0);
277 bool end_miss =
false;
278 for (
const auto col: this->columns_of_interest_) {
279 if (!end_miss && (col == _missing_cols_[i])) {
280 target_set.insert(NodeId(col));
282 if (i == _nb_miss_) end_miss =
true;
284 _filled_row1_[col].discr_val = xrow[col].discr_val;
285 _filled_row2_[col].discr_val = xrow[col].discr_val;
// Same loop when a bijection is provided: the column is translated into its
// node id through nodeId2columns_.
289 std::size_t i = std::size_t(0);
290 bool end_miss =
false;
291 for (
const auto col: this->columns_of_interest_) {
292 if (!end_miss && (col == _missing_cols_[i])) {
293 target_set.insert(this->nodeId2columns_.first(col));
295 if (i == _nb_miss_) end_miss =
true;
297 _filled_row1_[col].discr_val = xrow[col].discr_val;
298 _filled_row2_[col].discr_val = xrow[col].discr_val;
303 ve.addJointTarget(target_set);
// Every observed discrete cell of the row is entered as evidence.
306 const std::size_t row_size = xrow.size();
307 if (this->nodeId2columns_.empty()) {
308 for (std::size_t col = std::size_t(0); col < row_size; ++col) {
309 switch (this->column_types_[col]) {
310 case DBTranslatedValueType::DISCRETE :
312 if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
313 ve.addEvidence(NodeId(col), xrow[col].discr_val);
317 case DBTranslatedValueType::CONTINUOUS :
319 "The BDRowGeneratorEM does not handle yet continuous "
320 <<
"variables. But the variable in column" << col <<
" is continuous.");
325 "DBTranslatedValueType " <<
int(this->column_types_[col])
326 <<
" is not supported yet");
// Same evidence loop, with columns translated through nodeId2columns_.
330 for (std::size_t col = std::size_t(0); col < row_size; ++col) {
331 switch (this->column_types_[col]) {
332 case DBTranslatedValueType::DISCRETE :
334 if (xrow[col].discr_val != std::numeric_limits< std::size_t >::max()) {
335 ve.addEvidence(this->nodeId2columns_.first(col), xrow[col].discr_val);
339 case DBTranslatedValueType::CONTINUOUS :
341 "The BDRowGeneratorEM does not handle yet continuous "
342 <<
"variables. But the variable in column" << col <<
" is continuous.");
347 "DBTranslatedValueType " <<
int(this->column_types_[col])
348 <<
" is not supported yet");
// --- step 4: fetch the joint posterior and take it over.
// The const_cast allows the tensor to be moved into _joint_proba_ rather than
// copied. NOTE(review): presumably acceptable because `ve` is a local object
// that is not used afterwards -- confirm.
354 Tensor< GUM_SCALAR >& pot
355 =
const_cast< Tensor< GUM_SCALAR >&
>(ve.jointPosterior(target_set));
356 _joint_proba_ = std::move(pot);
// The instantiation is rebuilt over the new joint tensor.
357 if (_joint_inst_ !=
nullptr)
delete _joint_inst_;
358 _joint_inst_ =
new Instantiation(_joint_proba_);
// _missing_cols_ is rewritten to follow the order of the variables in the
// joint tensor, so that _missing_cols_[i] matches the i-th variable of the
// instantiation (this is how generate() indexes both).
362 const auto& var_sequence = _joint_proba_.variablesSequence();
363 if (this->nodeId2columns_.empty()) {
364 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
365 _missing_cols_[i] = std::size_t(this->bn_->nodeId(*(var_sequence[i])));
368 for (std::size_t i = std::size_t(0); i < _nb_miss_; ++i) {
369 _missing_cols_[i] = this->nodeId2columns_.second(this->bn_->nodeId(*(var_sequence[i])));
// The value returned is the domain size of the joint tensor.
373 return std::size_t(_joint_proba_.domainSize());
// Assigns a new Bayesian network to the generator.
//
// Visible steps: when a NodeId <-> column bijection is in use, each of its node
// ids is looked up in the DAG of new_bn and an error message is built for ids
// that are absent; the base class setBayesNet is then called; finally both
// output buffers are resized to (largest index + 1), the index being a node id
// when the bijection is empty and a column index otherwise.
//
// NOTE(review): this listing omits several lines of the definition (loop
// increments, error-macro heads, closing braces, ...).
377 template <
typename GUM_SCALAR >
378 void DBRowGeneratorEM< GUM_SCALAR >::setBayesNet(
const BayesNet< GUM_SCALAR >& new_bn) {
// consistency check between the bijection and the nodes of the new network
381 if (!this->nodeId2columns_.empty()) {
382 const DAG& dag = new_bn.dag();
383 for (
auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
385 if (!dag.existsNode(iter.first())) {
387 "Column " << iter.second() <<
" of the database is associated to Node ID "
389 <<
", which does not belong to the Bayesian network");
394 DBRowGeneratorWithBN< GUM_SCALAR >::setBayesNet(new_bn);
// find the largest index that the output rows must be able to hold
397 std::size_t size = std::size_t(0);
398 if (this->nodeId2columns_.empty()) {
399 for (
auto node: new_bn.dag())
400 if (std::size_t(node) > size) size = std::size_t(node);
402 for (
auto iter = this->nodeId2columns_.begin(); iter != this->nodeId2columns_.end();
404 if (iter.second() > size) size = iter.second();
// indices start at 0, hence the "+ 1"
407 _filled_row1_.resize(size + 1);
408 _filled_row2_.resize(size + 1);
A DBRowGenerator class that returns exactly the rows it gets in input.
Exception : node does not exist.
Exception : there is something wrong with an implementation.
DBRowGeneratorEM(const std::vector< DBTranslatedValueType > &column_types, const BayesNet< GUM_SCALAR > &bn, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
default constructor
Base class for DBRowGenerator classes that use a BN for computing their outputs.
#define GUM_ERROR(type, msg)
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
DBRowGeneratorGoal
the type of things that a DBRowGenerator is designed for
@ ONLY_REMOVE_MISSING_VALUES
include the inlined functions if necessary
gum is the global namespace for all aGrUM entities