50#ifndef DOXYGEN_SHOULD_SKIP_THIS
63 const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
64 const Bijection< NodeId, std::size_t >& nodeId2columns) :
65 _nodeId2columns_(nodeId2columns) {
67 const std::size_t db_nb_cols = parser.database().nbVariables();
68 for (
auto iter = nodeId2columns.cbegin(); iter != nodeId2columns.cend(); ++iter) {
69 if (iter.second() >= db_nb_cols) {
70 GUM_ERROR(OutOfBounds,
71 "the mapping between ids and database columns "
72 <<
"is incorrect because Column " << iter.second()
73 <<
" does not belong to the database.");
78 const auto max_nb_threads = ThreadNumberManager::getNumberOfThreads();
79 _parsers_.reserve(max_nb_threads);
80 for (std::size_t i = std::size_t(0); i < max_nb_threads; ++i)
81 _parsers_.push_back(parser);
85 _checkRanges_(ranges);
86 _ranges_.reserve(ranges.size());
87 for (
const auto& range: ranges)
88 _ranges_.push_back(range);
91 _dispatchRangesToThreads_();
93 GUM_CONSTRUCTOR(RecordCounter);
/// Constructor overload taking a row-generator parser and a NodeId <-> column
/// bijection, but no explicit list of row ranges.
///
/// @param parser the parser used to read the database rows
/// @param nodeId2columns mapping between node ids and database columns
///
/// NOTE(review): the member-initializer list and the body (original lines
/// 99-102) are not part of this listing. Presumably this overload delegates to
/// the ranges-taking constructor with an empty range vector -- confirm against
/// the real file.
97 RecordCounter::RecordCounter(
const DBRowGeneratorParser& parser,
98 const Bijection< NodeId, std::size_t >& nodeId2columns) :
/// Copy constructor.
///
/// Copy-constructs the ThreadNumberManager base and every data member from
/// `from`: the parsers, the user ranges, the per-thread ranges, the
/// id/column mapping, both cached counting vectors together with the id sets
/// they belong to, and the minimum number of rows per thread.
///
/// @param from the RecordCounter to copy
104 RecordCounter::RecordCounter(
const RecordCounter& from) :
105 ThreadNumberManager(from), _parsers_(from._parsers_), _ranges_(from._ranges_),
106 _thread_ranges_(from._thread_ranges_), _nodeId2columns_(from._nodeId2columns_),
107 _last_DB_counting_(from._last_DB_counting_), _last_DB_ids_(from._last_DB_ids_),
108 _last_nonDB_counting_(from._last_nonDB_counting_), _last_nonDB_ids_(from._last_nonDB_ids_),
109 _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
// aGrUM macro defined elsewhere (presumably object-lifetime bookkeeping for
// debug builds -- confirm).
110 GUM_CONS_CPY(RecordCounter);
/// Move constructor.
///
/// Move-constructs the ThreadNumberManager base and all container members from
/// `from`. `_min_nb_rows_per_thread_` is simply copied (no std::move is
/// applied to it).
///
/// @param from the RecordCounter whose content is moved into this object
114 RecordCounter::RecordCounter(RecordCounter&& from) :
115 ThreadNumberManager(
std::move(from)), _parsers_(
std::move(from._parsers_)),
116 _ranges_(
std::move(from._ranges_)), _thread_ranges_(
std::move(from._thread_ranges_)),
117 _nodeId2columns_(
std::move(from._nodeId2columns_)),
118 _last_DB_counting_(
std::move(from._last_DB_counting_)),
119 _last_DB_ids_(
std::move(from._last_DB_ids_)),
120 _last_nonDB_counting_(
std::move(from._last_nonDB_counting_)),
121 _last_nonDB_ids_(
std::move(from._last_nonDB_ids_)),
122 _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
// aGrUM macro defined elsewhere (presumably object-lifetime bookkeeping for
// debug builds -- confirm).
123 GUM_CONS_MOV(RecordCounter);
/// Returns a pointer to a new RecordCounter allocated with `new` and
/// copy-constructed from *this.
///
/// The allocation is a raw `new` and nothing in this function releases it, so
/// the caller is the one who has to delete the returned object.
///
/// @return pointer to the freshly allocated copy
127 RecordCounter* RecordCounter::clone()
const {
return new RecordCounter(*
this); }
/// Destructor. The body only invokes the GUM_DESTRUCTOR macro (defined
/// elsewhere); the data members are released by their own destructors.
130 RecordCounter::~RecordCounter() { GUM_DESTRUCTOR(RecordCounter); }
/// Copy assignment operator.
///
/// Assigns the ThreadNumberManager base and then every data member from
/// `from`, in the order listed below.
///
/// @param from the RecordCounter to copy
///
/// NOTE(review): original lines 134 and 145-148 are absent from this listing;
/// presumably they hold a self-assignment guard and `return *this;` --
/// confirm against the real file.
133 RecordCounter& RecordCounter::operator=(
const RecordCounter& from) {
135 ThreadNumberManager::operator=(from);
136 _parsers_ = from._parsers_;
137 _ranges_ = from._ranges_;
138 _thread_ranges_ = from._thread_ranges_;
139 _nodeId2columns_ = from._nodeId2columns_;
140 _last_DB_counting_ = from._last_DB_counting_;
141 _last_DB_ids_ = from._last_DB_ids_;
142 _last_nonDB_counting_ = from._last_nonDB_counting_;
143 _last_nonDB_ids_ = from._last_nonDB_ids_;
144 _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
/// Move assignment operator.
///
/// Move-assigns the ThreadNumberManager base and all container members from
/// `from`; `_min_nb_rows_per_thread_` is copied.
///
/// @param from the RecordCounter whose content is moved into this object
///
/// NOTE(review): original lines 151 and 162-165 are absent from this listing;
/// presumably they hold a self-assignment guard and `return *this;` --
/// confirm against the real file.
150 RecordCounter& RecordCounter::operator=(RecordCounter&& from) {
152 ThreadNumberManager::operator=(std::move(from));
153 _parsers_ = std::move(from._parsers_);
154 _ranges_ = std::move(from._ranges_);
155 _thread_ranges_ = std::move(from._thread_ranges_);
156 _nodeId2columns_ = std::move(from._nodeId2columns_);
157 _last_DB_counting_ = std::move(from._last_DB_counting_);
158 _last_DB_ids_ = std::move(from._last_DB_ids_);
159 _last_nonDB_counting_ = std::move(from._last_nonDB_counting_);
160 _last_nonDB_ids_ = std::move(from._last_nonDB_ids_);
// plain copy: no std::move is applied to this member
161 _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
/// Empties the two cached counting vectors (`_last_DB_counting_`,
/// `_last_nonDB_counting_`) and the id sets recorded alongside them.
///
/// The statements visible here do not touch the parsers, the ranges or the
/// id/column mapping.
167 void RecordCounter::clear() {
168 _last_DB_counting_.clear();
169 _last_DB_ids_.clear();
170 _last_nonDB_counting_.clear();
171 _last_nonDB_ids_.clear();
/// Sets the minimum number of database rows a thread should be given.
///
/// A value of 0 is replaced by 1, so the stored value is never zero (it is
/// used as a divisor when ranges are dispatched to threads).
///
/// The method is declared `const` yet assigns `_min_nb_rows_per_thread_`, so
/// that member has to be declared `mutable` in the class definition.
///
/// @param nb the requested minimum number of rows per thread
176 void RecordCounter::setMinNbRowsPerThread(
const std::size_t nb)
const {
177 if (nb == std::size_t(0)) _min_nb_rows_per_thread_ = std::size_t(1);
178 else _min_nb_rows_per_thread_ = nb;
/// Builds the error message reporting that counts were requested on
/// continuous variables.
///
/// The message uses a singular form when `bad_vars` holds exactly one name
/// and a plural form followed by a comma-separated list otherwise.
///
/// @param bad_vars names of the offending (continuous) variables
///
/// NOTE(review): several original lines (188, 190, 193-199) are absent from
/// this listing: the `else` introducing the plural branch, the declaration of
/// `deja`, the part of the loop that appends each name, and the statement
/// that actually raises the exception. Confirm them against the real file.
182 void RecordCounter::_raiseCheckException_(
const std::vector< std::string >& bad_vars)
const {
184 std::stringstream msg;
185 msg <<
"Counts cannot be performed on continuous variables. ";
186 msg <<
"Unfortunately the following variable";
// singular form: a single offending variable
187 if (bad_vars.size() == 1) msg <<
" is continuous: " << bad_vars[0];
// plural form: the names are listed after this prefix
189 msg <<
"s are continuous: ";
191 for (
const auto& name: bad_vars) {
// `deja` guards the separator so that it is only written between names
// (its declaration and update are not visible here)
192 if (deja) msg <<
", ";
201 void RecordCounter::_checkDiscreteVariables_(
const IdCondSet& ids)
const {
202 const std::size_t size = ids.size();
203 const DatabaseTable& database = _parsers_[0].data.database();
205 if (_nodeId2columns_.empty()) {
207 for (std::size_t i = std::size_t(0); i < size; ++i) {
208 if (database.variable(i).varType() == VarType::CONTINUOUS) {
212 std::vector< std::string > bad_vars{database.variable(i).name()};
213 for (++i; i < size; ++i) {
214 if (database.variable(i).varType() == VarType::CONTINUOUS)
215 bad_vars.push_back(database.variable(i).name());
217 _raiseCheckException_(bad_vars);
222 for (std::size_t i = std::size_t(0); i < size; ++i) {
224 std::size_t pos = _nodeId2columns_.second(ids[i]);
226 if (database.variable(pos).varType() == VarType::CONTINUOUS) {
230 std::vector< std::string > bad_vars{database.variable(pos).name()};
231 for (++i; i < size; ++i) {
232 pos = _nodeId2columns_.second(ids[i]);
233 if (database.variable(pos).varType() == VarType::CONTINUOUS)
234 bad_vars.push_back(database.variable(pos).name());
236 _raiseCheckException_(bad_vars);
/// Builds a hash table giving, for each id of `ids`, a database column index.
///
/// When `_nodeId2columns_` is empty, each id is mapped to itself (converted
/// to std::size_t); the second loop maps each id to
/// `_nodeId2columns_.second(id)`.
///
/// @param ids the ids to translate
///
/// NOTE(review): the `else` separating the two loops and the final
/// `return res;` (original lines 250-251, 254-257) are absent from this
/// listing -- confirm against the real file.
244 HashTable< NodeId, std::size_t >
245 RecordCounter::_getNodeIds2Columns_(
const IdCondSet& ids)
const {
// the table is constructed with ids.size() as its size argument
246 HashTable< NodeId, std::size_t > res(ids.size());
247 if (_nodeId2columns_.empty()) {
// no explicit mapping: the id is used as the column index
248 for (
const auto id: ids) {
249 res.insert(
id, std::size_t(
id));
// explicit mapping: look the column up in the bijection
252 for (
const auto id: ids) {
253 res.insert(
id, _nodeId2columns_.second(
id));
/// Computes the counting vector of `subset_ids` by summing entries of an
/// already available counting vector over `superset_ids`.
///
/// The result is stored in `_last_nonDB_counting_` (with `subset_ids` saved in
/// `_last_nonDB_ids_`) and a reference to that member is returned. Three code
/// paths are visible: the subset ids are the first ids of the superset, they
/// are the last ids of the superset, or the general case.
///
/// @param subset_ids the ids whose countings are wanted
/// @param superset_ids the ids over which `superset_vect` was computed
/// @param superset_vect the countings over `superset_ids`
/// @return reference to `_last_nonDB_counting_`
///
/// NOTE(review): many original lines are absent from this listing (the `if`
/// tests on `subset_begin` / `subset_end`, `break` statements, the counter
/// decrements of the general case, and what look like try/catch frames around
/// the three "save and return" sequences). The comments below only describe
/// what the visible statements do.
260 std::vector< double >&
261 RecordCounter::_extractFromCountings_(
const IdCondSet& subset_ids,
262 const IdCondSet& superset_ids,
263 const std::vector< double >& superset_vect) {
// id -> database column table for all the ids of the superset
267 const auto nodeId2columns = _getNodeIds2Columns_(superset_ids);
// size of the result = product of the domain sizes of the subset's variables
271 const auto& database = _parsers_[0].data.database();
272 std::size_t result_vect_size = std::size_t(1);
273 for (
const auto id: subset_ids) {
274 result_vect_size *= database.domainSize(nodeId2columns[
id]);
278 std::vector< double > result_vect(result_vect_size, 0.0);
// --- case 1: do the subset ids sit, in the same order, at the first
// positions of the superset? Any mismatch of position clears the flag.
284 bool subset_begin =
true;
285 const std::size_t subset_ids_size = std::size_t(subset_ids.size());
286 for (std::size_t i = 0; i < subset_ids_size; ++i) {
287 if (superset_ids.pos(subset_ids[i]) != i) {
288 subset_begin =
false;
// the superset vector is read as consecutive chunks of result_vect_size
// entries; chunk entry j is added to result_vect[j]
294 const std::size_t superset_vect_size = superset_vect.size();
295 std::size_t i = std::size_t(0);
296 while (i < superset_vect_size) {
297 for (std::size_t j = std::size_t(0); j < result_vect_size; ++j, ++i) {
298 result_vect[j] += superset_vect[i];
// cache the result and return a reference to the cached vector
304 _last_nonDB_ids_ = subset_ids;
305 _last_nonDB_counting_ = std::move(result_vect);
306 return _last_nonDB_counting_;
// reset of the cache; these statements follow a `return`, so they presumably
// belong to an exception handler that is not visible here -- TODO confirm
308 _last_nonDB_ids_.clear();
309 _last_nonDB_counting_.clear();
// --- case 2: do the subset ids sit, in the same order, at the last
// positions of the superset?
318 bool subset_end =
true;
319 const std::size_t superset_ids_size = std::size_t(superset_ids.size());
320 for (std::size_t i = 0; i < subset_ids_size; ++i) {
321 if (superset_ids.pos(subset_ids[i]) != i + superset_ids_size - subset_ids_size) {
// product of the domain sizes of the leading superset ids that are not part
// of the subset
330 std::size_t vect_not_subset_size = std::size_t(1);
331 for (std::size_t i = std::size_t(0); i < superset_ids_size - subset_ids_size; ++i)
332 vect_not_subset_size *= database.domainSize(nodeId2columns[superset_ids[i]]);
// each result entry j is the sum of vect_not_subset_size consecutive entries
// of the superset vector
335 std::size_t i = std::size_t(0);
336 for (std::size_t j = std::size_t(0); j < result_vect_size; ++j) {
337 for (std::size_t k = std::size_t(0); k < vect_not_subset_size; ++k, ++i) {
338 result_vect[j] += superset_vect[i];
// cache the result and return a reference to the cached vector
344 _last_nonDB_ids_ = subset_ids;
345 _last_nonDB_counting_ = std::move(result_vect);
346 return _last_nonDB_counting_;
// presumably an exception handler, as above -- TODO confirm
348 _last_nonDB_ids_.clear();
349 _last_nonDB_counting_.clear();
// --- case 3: general case. One slot per subset id in each helper vector.
380 std::vector< std::size_t > before_incr(subset_ids_size);
381 std::vector< std::size_t > result_domain(subset_ids_size);
382 std::vector< std::size_t > result_offset(subset_ids_size);
384 std::size_t result_domain_size = std::size_t(1);
385 std::size_t tmp_before_incr = std::size_t(1);
386 std::vector< std::size_t > superset_order(subset_ids_size);
// walk the superset ids in order (index h); j indexes the helper slots and
// the loop stops once all subset ids have been handled (the update of j is
// not visible here). For a superset id that belongs to the subset:
// before_incr[j] receives tmp_before_incr - 1 and superset_order records, at
// the id's position in the subset, the slot j.
388 for (std::size_t h = std::size_t(0), j = std::size_t(0); j < subset_ids_size; ++h) {
389 if (subset_ids.exists(superset_ids[h])) {
390 before_incr[j] = tmp_before_incr - 1;
391 superset_order[subset_ids.pos(superset_ids[h])] = j;
// accumulate the domain size of superset id h into tmp_before_incr
395 tmp_before_incr *= database.domainSize(nodeId2columns[superset_ids[h]]);
// for each subset id (in subset order), fill slot j = superset_order[i] with
// its domain size and with the running product of the domain sizes of the
// preceding subset ids
400 for (std::size_t i = 0; i < subset_ids.size(); ++i) {
401 const std::size_t domain_size = database.domainSize(nodeId2columns[subset_ids[i]]);
402 const std::size_t j = superset_order[i];
403 result_domain[j] = domain_size;
404 result_offset[j] = result_domain_size;
405 result_domain_size *= domain_size;
// working copies that the main loop below updates
409 std::vector< std::size_t > result_value(result_domain);
410 std::vector< std::size_t > current_incr(before_incr);
411 std::vector< std::size_t > result_down(result_offset);
// result_down[j] = result_offset[j] * (result_domain[j] - 1)
413 for (std::size_t j = std::size_t(0); j < result_down.size(); ++j) {
414 result_down[j] *= (result_domain[j] - 1);
// main loop: every superset entry is added to the result entry currently
// designated by the_result_offset, then the per-slot counters are updated
// (the statements that modify current_incr / result_value inside the tests
// are not all visible in this listing)
418 const std::size_t superset_vect_size = superset_vect.size();
419 std::size_t the_result_offset = std::size_t(0);
420 for (std::size_t h = std::size_t(0); h < superset_vect_size; ++h) {
421 result_vect[the_result_offset] += superset_vect[h];
424 for (std::size_t k = 0; k < current_incr.size(); ++k) {
426 if (current_incr[k]) {
// reload the counter of slot k with its initial value
431 current_incr[k] = before_incr[k];
436 if (result_value[k]) {
// move forward by the offset of slot k in the result vector
437 the_result_offset += result_offset[k];
// slot k is reset to its domain size and the result offset is moved back
// by result_down[k]
441 result_value[k] = result_domain[k];
442 the_result_offset -= result_down[k];
// cache the result and return a reference to the cached vector
448 _last_nonDB_ids_ = subset_ids;
449 _last_nonDB_counting_ = std::move(result_vect);
450 return _last_nonDB_counting_;
// presumably an exception handler, as above -- TODO confirm
452 _last_nonDB_ids_.clear();
453 _last_nonDB_counting_.clear();
/// Computes the countings of `ids` by parsing the database, splitting the
/// work across threads, and stores the result in `_last_DB_counting_`.
///
/// @param ids the ids whose joint countings are computed
/// @return reference to `_last_DB_counting_`, or to an emptied
///         `_last_nonDB_counting_` when there is nothing to count
///
/// NOTE(review): several original lines are absent from this listing (the
/// head of the sort call, the increment of `i` in the first loop, the size
/// argument of `thread_countings`, the advance to the parser's next row, an
/// update of `_last_DB_ids_` if there is one). The comments below only
/// describe what the visible statements do.
459 std::vector< double >& RecordCounter::_countFromDatabase_(
const IdCondSet& ids) {
// nothing to count: empty the "non DB" cache and return it as an empty vector
462 const auto& database = _parsers_[0].data.database();
463 if (ids.empty() || database.empty() || _thread_ranges_.empty()) {
464 _last_nonDB_counting_.clear();
465 _last_nonDB_ids_.clear();
466 return _last_nonDB_counting_;
// id -> database column table for the ids to count
471 const auto nodeId2columns = _getNodeIds2Columns_(ids);
// for each id (in the order of `ids`): its domain size, its column, and the
// multiplier of its value in the flattened counting vector (= product of the
// domain sizes of the ids before it)
475 const std::size_t ids_size = ids.size();
476 std::size_t counting_vect_size = std::size_t(1);
477 std::vector< std::size_t > domain_sizes(ids_size);
478 std::vector< std::pair< std::size_t, std::size_t > > cols_offsets(ids_size);
480 std::size_t i = std::size_t(0);
481 for (
const auto id: ids) {
482 const std::size_t domain_size = database.domainSize(nodeId2columns[
id]);
483 domain_sizes[i] = domain_size;
484 cols_offsets[i].first = nodeId2columns[id];
485 cols_offsets[i].second = counting_vect_size;
486 counting_vect_size *= domain_size;
// arguments of a call whose first line is not visible (presumably std::sort
// over cols_offsets -- TODO confirm); the comparator orders the
// (column, multiplier) pairs by increasing column index
494 cols_offsets.begin(),
496 [](
const std::pair< std::size_t, std::size_t >& a,
497 const std::pair< std::size_t, std::size_t >& b) ->
bool { return a.first < b.first; });
// number of threads = min(number of thread ranges, allowed number of threads);
// make sure there is one parser per thread by copying the first parser
500 const std::size_t nb_ranges = _thread_ranges_.size();
501 const auto max_nb_threads = ThreadNumberManager::getNumberOfThreads();
502 const std::size_t nb_threads = nb_ranges <= max_nb_threads ? nb_ranges : max_nb_threads;
503 while (_parsers_.size() < nb_threads) {
504 ThreadData< DBRowGeneratorParser > new_parser(_parsers_[0]);
505 _parsers_.push_back(std::move(new_parser));
// tell every parser which columns are needed (in cols_offsets order)
511 std::vector< std::size_t > cols_of_interest(ids_size);
512 for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
513 cols_of_interest[i] = cols_offsets[i].first;
515 for (
auto& parser: _parsers_) {
516 parser.data.setColumnsOfInterest(cols_of_interest);
// global counting vector (zero-filled) and the per-thread counting vectors,
// each initialised as a copy of it
522 std::vector< double > counting_vect(counting_vect_size, 0.0);
523 std::vector< ThreadData< std::vector< double > > > thread_countings(
525 ThreadData< std::vector< double > >(counting_vect));
// work executed by one thread for one batch: thread `this_thread` handles
// the range of index this_thread + nb_loop, if that range exists.
// cols_offsets is captured by copy, thread_countings by reference.
529 auto threadedCount = [
this, nb_ranges, ids_size, &thread_countings, cols_offsets](
530 const std::size_t this_thread,
531 const std::size_t nb_threads,
532 const std::size_t nb_loop) ->
void {
533 if (this_thread + nb_loop < nb_ranges) {
// each thread uses its own parser and its own counting vector
535 DBRowGeneratorParser& parser = this->_parsers_[this_thread].data;
536 parser.setRange(this->_thread_ranges_[this_thread + nb_loop].first,
537 this->_thread_ranges_[this_thread + nb_loop].second);
538 std::vector< double >& counts = thread_countings[this_thread].data;
542 while (parser.hasRows()) {
544 const DBRow< DBTranslatedValue >& row = parser.row();
// index of the row's joint value in the flattened counting vector:
// sum over the columns of (discrete value * multiplier)
547 std::size_t offset = std::size_t(0);
548 for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
549 offset += row[cols_offsets[i].first].discr_val * cols_offsets[i].second;
// rows are weighted: add the row's weight rather than 1
552 counts[offset] += row.weight();
// process the thread ranges by batches of nb_threads ranges
562 for (std::size_t i = std::size_t(0); i < nb_ranges; i += nb_threads) {
563 ThreadExecutor::execute(nb_threads, threadedCount, i);
// add the per-thread countings into the global counting vector
568 for (std::size_t k = std::size_t(0); k < nb_threads; ++k) {
569 const auto& thread_counting = thread_countings[k].data;
570 for (std::size_t r = std::size_t(0); r < counting_vect_size; ++r) {
571 counting_vect[r] += thread_counting[r];
// cache the result and return a reference to the cached vector
577 _last_DB_counting_ = std::move(counting_vect);
579 return _last_DB_counting_;
/// Validates a list of [first; second) row ranges against the database.
///
/// A range is considered incorrect when `first >= second` or when `second`
/// exceeds the number of rows of the database. When at least one incorrect
/// range is found, an error message listing all of them is built.
///
/// @param new_ranges the ranges to validate
///
/// NOTE(review): the declaration/update of `deja` and the statement raising
/// the error (original lines 597, 600, 602-606) are absent from this listing
/// -- confirm against the real file.
583 void RecordCounter::_checkRanges_(
584 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges)
const {
585 const std::size_t dbsize = _parsers_[0].data.database().nbRows();
// collect every range that is empty/reversed or that ends past the last row
586 std::vector< std::pair< std::size_t, std::size_t > > incorrect_ranges;
587 for (
const auto& range: new_ranges) {
588 if ((range.first >= range.second) || (range.second > dbsize)) {
589 incorrect_ranges.push_back(range);
// build the message, with a singular or plural wording
592 if (!incorrect_ranges.empty()) {
593 std::stringstream str;
594 str <<
"It is impossible to set the ranges because the following one";
595 if (incorrect_ranges.size() > 1) str <<
"s are incorrect: ";
596 else str <<
" is incorrect: ";
598 for (
const auto& range: incorrect_ranges) {
// `deja` guards the separator so that it is only written between ranges
// (its declaration and update are not visible here)
599 if (deja) str <<
", ";
// each range is printed as a half-open interval
601 str <<
'[' << range.first <<
';' << range.second <<
')';
/// Recomputes `_thread_ranges_`, i.e. the row ranges handed to the threads,
/// from the user ranges stored in `_ranges_`.
///
/// Each non-empty range is cut into at most `getNumberOfThreads()` pieces,
/// the number of pieces being range_size / _min_nb_rows_per_thread_ clamped
/// to at least 1. The resulting pieces are finally sorted by decreasing size.
///
/// NOTE(review): some original lines are absent from this listing (616, 618:
/// the call receiving the whole-database pair and presumably the setting of
/// `add_range`; 636-638: how `rest_rows` is consumed). The comments below
/// only describe what the visible statements do.
609 void RecordCounter::_dispatchRangesToThreads_() {
610 _thread_ranges_.clear();
// when no user range is defined, a pair covering the whole database
// [0; nbRows) is built (the call it is passed to is not visible here)
613 bool add_range =
false;
614 if (_ranges_.empty()) {
615 const auto& database = _parsers_[0].data.database();
617 std::pair< std::size_t, std::size_t >(std::size_t(0), database.nbRows()));
622 const auto max_nb_threads = ThreadNumberManager::getNumberOfThreads();
623 for (
const auto& range: _ranges_) {
// empty or reversed ranges are skipped
624 if (range.second > range.first) {
// number of pieces for this range, clamped to [1; max_nb_threads]
625 const std::size_t range_size = range.second - range.first;
626 std::size_t nb_threads = range_size / _min_nb_rows_per_thread_;
627 if (nb_threads < 1) nb_threads = 1;
628 else if (nb_threads > max_nb_threads) nb_threads = max_nb_threads;
// rows per piece, and the remainder left by the integer division
629 std::size_t nb_rows_par_thread = range_size / nb_threads;
630 std::size_t rest_rows = range_size - nb_rows_par_thread * nb_threads;
// cut the range into consecutive pieces [begin_index; end_index)
632 std::size_t begin_index = range.first;
633 for (std::size_t i = std::size_t(0); i < nb_threads; ++i) {
634 std::size_t end_index = begin_index + nb_rows_par_thread;
// remainder handling: body not visible here (presumably one extra row is
// given to this piece while rest_rows is not exhausted -- TODO confirm)
635 if (rest_rows != std::size_t(0)) {
639 _thread_ranges_.push_back(
640 std::pair< std::size_t, std::size_t >(begin_index, end_index));
641 begin_index = end_index;
// `_ranges_` is emptied again when add_range is set
645 if (add_range) _ranges_.clear();
// sort the pieces by decreasing number of rows
651 std::sort(_thread_ranges_.begin(),
652 _thread_ranges_.end(),
653 [](
const std::pair< std::size_t, std::size_t >& a,
654 const std::pair< std::size_t, std::size_t >& b) ->
bool {
655 return (a.second - a.first) > (b.second - b.first);
/// Replaces the user ranges by `new_ranges` and recomputes the per-thread
/// ranges.
///
/// The new ranges are first validated with _checkRanges_, then copied into a
/// local vector which is moved into `_ranges_`; finally
/// _dispatchRangesToThreads_ is called.
///
/// @param new_ranges the [first; second) row ranges to use from now on
660 void RecordCounter::setRanges(
661 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges) {
// validation happens before any member is modified
663 _checkRanges_(new_ranges);
// element-wise copy of the ranges into a local vector
666 const std::size_t new_size = new_ranges.size();
667 std::vector< std::pair< std::size_t, std::size_t > > ranges(new_size);
668 for (std::size_t i = std::size_t(0); i < new_size; ++i) {
669 ranges[i].first = new_ranges[i].first;
670 ranges[i].second = new_ranges[i].second;
// install the new ranges
674 _ranges_ = std::move(ranges);
// recompute the ranges assigned to the threads
677 _dispatchRangesToThreads_();
/// Resets the ranges; does nothing when `_ranges_` is already empty.
///
/// Otherwise the per-thread ranges are recomputed with
/// _dispatchRangesToThreads_.
///
/// NOTE(review): original lines 683-684 are absent from this listing;
/// presumably `_ranges_` is cleared there before the dispatch -- confirm
/// against the real file.
681 void RecordCounter::clearRanges() {
682 if (_ranges_.empty())
return;
685 _dispatchRangesToThreads_();
Exception : the element we looked for cannot be found.
Exception : out of bound.
Exception : wrong type for this operation.
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
RecordCounter(const DBRowGeneratorParser &parser, const std::vector< std::pair< std::size_t, std::size_t > > &ranges, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
default constructor
#define GUM_ERROR(type, msg)
include the inlined functions if necessary
gum is the global namespace for all aGrUM entities
The class that computes counting of observations from the database.