aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
recordCounter.cpp
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
47
49
50#ifndef DOXYGEN_SHOULD_SKIP_THIS
51
53# ifdef GUM_NO_INLINE
55# endif /* GUM_NO_INLINE */
56
57namespace gum {
58
59 namespace learning {
60
        const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
        const Bijection< NodeId, std::size_t >& nodeId2columns) :
        _nodeId2columns_(nodeId2columns) {
      // check that the columns in nodeId2columns do belong to the database
      const std::size_t db_nb_cols = parser.database().nbVariables();
      for (auto iter = nodeId2columns.cbegin(); iter != nodeId2columns.cend(); ++iter) {
        if (iter.second() >= db_nb_cols) {
          GUM_ERROR(OutOfBounds,
                    "the mapping between ids and database columns "
                        << "is incorrect because Column " << iter.second()
                        << " does not belong to the database.");
        }
      }

      // create the parsers. There should always be at least one parser
      // (one per thread potentially usable; _countFromDatabase_ grows this
      // pool lazily if more threads are requested later)
      const auto max_nb_threads = ThreadNumberManager::getNumberOfThreads();
      _parsers_.reserve(max_nb_threads);
      for (std::size_t i = std::size_t(0); i < max_nb_threads; ++i)
        _parsers_.push_back(parser);

      // check that the ranges are within the bounds of the database and
      // save them (an invalid range raises OutOfBounds before any state
      // beyond the parsers has been committed)
      _checkRanges_(ranges);
      _ranges_.reserve(ranges.size());
      for (const auto& range: ranges)
        _ranges_.push_back(range);

      // dispatch the ranges for the threads
      _dispatchRangesToThreads_();

      GUM_CONSTRUCTOR(RecordCounter);
    }
95
97 RecordCounter::RecordCounter(const DBRowGeneratorParser& parser,
98 const Bijection< NodeId, std::size_t >& nodeId2columns) :
99 RecordCounter(parser,
100 std::vector< std::pair< std::size_t, std::size_t > >(),
101 nodeId2columns) {}
102
    /// copy constructor: duplicates the parsers, the ranges and all the
    /// counting caches of `from` (member-init order follows the member
    /// declaration order of the class, declared elsewhere)
    RecordCounter::RecordCounter(const RecordCounter& from) :
        ThreadNumberManager(from), _parsers_(from._parsers_), _ranges_(from._ranges_),
        _thread_ranges_(from._thread_ranges_), _nodeId2columns_(from._nodeId2columns_),
        _last_DB_counting_(from._last_DB_counting_), _last_DB_ids_(from._last_DB_ids_),
        _last_nonDB_counting_(from._last_nonDB_counting_), _last_nonDB_ids_(from._last_nonDB_ids_),
        _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
      GUM_CONS_CPY(RecordCounter);
    }
112
    /// move constructor: steals the parsers, ranges and caches of `from`;
    /// `from` is left in a valid but unspecified state
    RecordCounter::RecordCounter(RecordCounter&& from) :
        ThreadNumberManager(std::move(from)), _parsers_(std::move(from._parsers_)),
        _ranges_(std::move(from._ranges_)), _thread_ranges_(std::move(from._thread_ranges_)),
        _nodeId2columns_(std::move(from._nodeId2columns_)),
        _last_DB_counting_(std::move(from._last_DB_counting_)),
        _last_DB_ids_(std::move(from._last_DB_ids_)),
        _last_nonDB_counting_(std::move(from._last_nonDB_counting_)),
        _last_nonDB_ids_(std::move(from._last_nonDB_ids_)),
        _min_nb_rows_per_thread_(from._min_nb_rows_per_thread_) {
      GUM_CONS_MOV(RecordCounter);
    }
125
127 RecordCounter* RecordCounter::clone() const { return new RecordCounter(*this); }
128
130 RecordCounter::~RecordCounter() { GUM_DESTRUCTOR(RecordCounter); }
131
133 RecordCounter& RecordCounter::operator=(const RecordCounter& from) {
134 if (this != &from) {
135 ThreadNumberManager::operator=(from);
136 _parsers_ = from._parsers_;
137 _ranges_ = from._ranges_;
138 _thread_ranges_ = from._thread_ranges_;
139 _nodeId2columns_ = from._nodeId2columns_;
140 _last_DB_counting_ = from._last_DB_counting_;
141 _last_DB_ids_ = from._last_DB_ids_;
142 _last_nonDB_counting_ = from._last_nonDB_counting_;
143 _last_nonDB_ids_ = from._last_nonDB_ids_;
144 _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
145 }
146 return *this;
147 }
148
150 RecordCounter& RecordCounter::operator=(RecordCounter&& from) {
151 if (this != &from) {
152 ThreadNumberManager::operator=(std::move(from));
153 _parsers_ = std::move(from._parsers_);
154 _ranges_ = std::move(from._ranges_);
155 _thread_ranges_ = std::move(from._thread_ranges_);
156 _nodeId2columns_ = std::move(from._nodeId2columns_);
157 _last_DB_counting_ = std::move(from._last_DB_counting_);
158 _last_DB_ids_ = std::move(from._last_DB_ids_);
159 _last_nonDB_counting_ = std::move(from._last_nonDB_counting_);
160 _last_nonDB_ids_ = std::move(from._last_nonDB_ids_);
161 _min_nb_rows_per_thread_ = from._min_nb_rows_per_thread_;
162 }
163 return *this;
164 }
165
167 void RecordCounter::clear() {
168 _last_DB_counting_.clear();
169 _last_DB_ids_.clear();
170 _last_nonDB_counting_.clear();
171 _last_nonDB_ids_.clear();
172 }
173
174 // changes the number min of rows a thread should process in a
175 // multithreading context
176 void RecordCounter::setMinNbRowsPerThread(const std::size_t nb) const {
177 if (nb == std::size_t(0)) _min_nb_rows_per_thread_ = std::size_t(1);
178 else _min_nb_rows_per_thread_ = nb;
179 }
180
182 void RecordCounter::_raiseCheckException_(const std::vector< std::string >& bad_vars) const {
183 // generate the exception
184 std::stringstream msg;
185 msg << "Counts cannot be performed on continuous variables. ";
186 msg << "Unfortunately the following variable";
187 if (bad_vars.size() == 1) msg << " is continuous: " << bad_vars[0];
188 else {
189 msg << "s are continuous: ";
190 bool deja = false;
191 for (const auto& name: bad_vars) {
192 if (deja) msg << ", ";
193 else deja = true;
194 msg << name;
195 }
196 }
197 GUM_ERROR(TypeError, msg.str())
198 }
199
201 void RecordCounter::_checkDiscreteVariables_(const IdCondSet& ids) const {
202 const std::size_t size = ids.size();
203 const DatabaseTable& database = _parsers_[0].data.database();
204
205 if (_nodeId2columns_.empty()) {
206 // check all the ids
207 for (std::size_t i = std::size_t(0); i < size; ++i) {
208 if (database.variable(i).varType() == VarType::CONTINUOUS) {
209 // here, var i does not correspond to a discrete variable.
210 // we check whether there are other non discrete variables, so that
211 // we can generate an exception mentioning all these variables
212 std::vector< std::string > bad_vars{database.variable(i).name()};
213 for (++i; i < size; ++i) {
214 if (database.variable(i).varType() == VarType::CONTINUOUS)
215 bad_vars.push_back(database.variable(i).name());
216 }
217 _raiseCheckException_(bad_vars);
218 }
219 }
220 } else {
221 // check all the ids
222 for (std::size_t i = std::size_t(0); i < size; ++i) {
223 // get the position of the variable in the database
224 std::size_t pos = _nodeId2columns_.second(ids[i]);
225
226 if (database.variable(pos).varType() == VarType::CONTINUOUS) {
227 // here, id does not correspond to a discrete variable.
228 // we check whether there are other non discrete variables, so that
229 // we can generate an exception mentioning all these variables
230 std::vector< std::string > bad_vars{database.variable(pos).name()};
231 for (++i; i < size; ++i) {
232 pos = _nodeId2columns_.second(ids[i]);
233 if (database.variable(pos).varType() == VarType::CONTINUOUS)
234 bad_vars.push_back(database.variable(pos).name());
235 }
236 _raiseCheckException_(bad_vars);
237 }
238 }
239 }
240 }
241
242 // returns a mapping from the nodes ids to the columns of the database
243 // for a given sequence of ids
244 HashTable< NodeId, std::size_t >
245 RecordCounter::_getNodeIds2Columns_(const IdCondSet& ids) const {
246 HashTable< NodeId, std::size_t > res(ids.size());
247 if (_nodeId2columns_.empty()) {
248 for (const auto id: ids) {
249 res.insert(id, std::size_t(id));
250 }
251 } else {
252 for (const auto id: ids) {
253 res.insert(id, _nodeId2columns_.second(id));
254 }
255 }
256 return res;
257 }
258
    /// computes the counts of `subset_ids` by marginalizing an already
    /// available counting vector computed over `superset_ids`, i.e., by
    /// summing superset_vect over the variables not in subset_ids.
    /// The result is cached into _last_nonDB_counting_ / _last_nonDB_ids_
    /// and a reference to the cache is returned.
    /// Three cases are handled, from cheapest to most general:
    /// 1) subset_ids is a prefix of superset_ids, 2) it is a suffix,
    /// 3) it is an arbitrary sub-sequence.
    std::vector< double >&
        RecordCounter::_extractFromCountings_(const IdCondSet&             subset_ids,
                                              const IdCondSet&             superset_ids,
                                              const std::vector< double >& superset_vect) {
      // get a mapping between the node Ids and their columns in the database.
      // This should be stored into _nodeId2columns_, except if the latter is
      // empty, in which case there is an identity mapping
      const auto nodeId2columns = _getNodeIds2Columns_(superset_ids);

      // we first determine the size of the output vector, the domain of
      // each of its variables and their offsets in the output vector
      const auto& database         = _parsers_[0].data.database();
      std::size_t result_vect_size = std::size_t(1);
      for (const auto id: subset_ids) {
        result_vect_size *= database.domainSize(nodeId2columns[id]);
      }

      // we create the output vector
      std::vector< double > result_vect(result_vect_size, 0.0);

      // check if the subset_ids is the beginning of the sequence of superset_ids
      // if this is the case, then we can outer loop over the variables not in
      // subset_ids and, for each iteration of this loop add a vector of size
      // result_size to result_vect
      bool              subset_begin    = true;
      const std::size_t subset_ids_size = std::size_t(subset_ids.size());
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i]) != i) {
          subset_begin = false;
          break;
        }
      }

      if (subset_begin) {
        // prefix case: superset_vect is a concatenation of blocks of size
        // result_vect_size; accumulate the blocks component-wise
        const std::size_t superset_vect_size = superset_vect.size();
        std::size_t       i                  = std::size_t(0);
        while (i < superset_vect_size) {
          for (std::size_t j = std::size_t(0); j < result_vect_size; ++j, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        // (on failure, leave the cache empty rather than inconsistent)
        try {
          _last_nonDB_ids_      = subset_ids;
          _last_nonDB_counting_ = std::move(result_vect);
          return _last_nonDB_counting_;
        } catch (...) {
          _last_nonDB_ids_.clear();
          _last_nonDB_counting_.clear();
          throw;
        }
      }


      // check if subset_ids is the end of the sequence of superset_ids.
      // In this case, as above, there are two simple loops to perform the
      // counts
      bool              subset_end        = true;
      const std::size_t superset_ids_size = std::size_t(superset_ids.size());
      for (std::size_t i = 0; i < subset_ids_size; ++i) {
        if (superset_ids.pos(subset_ids[i]) != i + superset_ids_size - subset_ids_size) {
          subset_end = false;
          break;
        }
      }

      if (subset_end) {
        // determine the size of the vector corresponding to the variables
        // not belonging to subset_ids
        std::size_t vect_not_subset_size = std::size_t(1);
        for (std::size_t i = std::size_t(0); i < superset_ids_size - subset_ids_size; ++i)
          vect_not_subset_size *= database.domainSize(nodeId2columns[superset_ids[i]]);

        // perform the two loops: each cell j of result_vect is the sum of a
        // contiguous run of vect_not_subset_size cells of superset_vect
        std::size_t i = std::size_t(0);
        for (std::size_t j = std::size_t(0); j < result_vect_size; ++j) {
          for (std::size_t k = std::size_t(0); k < vect_not_subset_size; ++k, ++i) {
            result_vect[j] += superset_vect[i];
          }
        }

        // save the subset_ids and the result vector
        // (on failure, leave the cache empty rather than inconsistent)
        try {
          _last_nonDB_ids_      = subset_ids;
          _last_nonDB_counting_ = std::move(result_vect);
          return _last_nonDB_counting_;
        } catch (...) {
          _last_nonDB_ids_.clear();
          _last_nonDB_counting_.clear();
          throw;
        }
      }

      // here subset_ids is a subset of superset_ids neither prefixing nor
      // postfixing it. So the computation is somewhat more complicated.

      // We will parse the superset_vect sequentially (using ++ operator).
      // Sometimes, we will need to change the offset of the cell of result_vect
      // that will be affected, sometimes not. Vector before_incr will indicate
      // whether we need to change the offset (value = 0) or not (value different
      // from 0). Vectors result_domain will indicate how this offset should be
      // computed. Here is an example of the values of these vectors. Assume that
      // superset_ids = <A,B,C,D,E> and subset_ids = <A,D,C>. Then, the three
      // vectors before_incr, result_domain and result_offset are indexed w.r.t.
      // A,C,D, i.e., w.r.t. to the variables in subset_ids but order w.r.t.
      // superset_ids (this is convenient as we will parse superset_vect
      // sequentially. For a variable or a set of variables X, let M_X denote the
      // domain size of X. Then the contents of the three vectors are as follows:
      // before_incr = {0, M_B, 0} (this means that whenever we iterate over B's
      // values, the offset in result_vect does not change)
      // result_domain = { M_A, M_C, M_D } (i.e., the domain sizes of the variables
      // in subset_ids, order w.r.t. superset_ids)
      // result_offset = { 1, M_A*M_D, M_A } (this corresponds to the offsets
      // in result_vect of variables A, C and D)
      // Vector superset_order = { 0, 2, 1} : this is a map from the indices of
      // the variables in subset_ids to the indices of these variables in the
      // three vectors described above. For instance, the "2" means that variable
      // D (which is at index 1 in subset_ids) is located at index 2 in vector
      // before_incr
      std::vector< std::size_t > before_incr(subset_ids_size);
      std::vector< std::size_t > result_domain(subset_ids_size);
      std::vector< std::size_t > result_offset(subset_ids_size);
      {
        std::size_t                result_domain_size = std::size_t(1);
        std::size_t                tmp_before_incr    = std::size_t(1);
        std::vector< std::size_t > superset_order(subset_ids_size);

        // walk superset_ids in order; h indexes superset_ids, j indexes the
        // subset variables already encountered
        for (std::size_t h = std::size_t(0), j = std::size_t(0); j < subset_ids_size; ++h) {
          if (subset_ids.exists(superset_ids[h])) {
            before_incr[j] = tmp_before_incr - 1;
            superset_order[subset_ids.pos(superset_ids[h])] = j;
            tmp_before_incr = 1;
            ++j;
          } else {
            tmp_before_incr *= database.domainSize(nodeId2columns[superset_ids[h]]);
          }
        }

        // compute the offsets in the order of the superset_ids
        for (std::size_t i = 0; i < subset_ids.size(); ++i) {
          const std::size_t domain_size = database.domainSize(nodeId2columns[subset_ids[i]]);
          const std::size_t j           = superset_order[i];
          result_domain[j]              = domain_size;
          result_offset[j]              = result_domain_size;
          result_domain_size *= domain_size;
        }
      }

      // working copies: result_value counts down the remaining values of each
      // subset variable, current_incr counts down the iterations for which
      // the result offset stays unchanged
      std::vector< std::size_t > result_value(result_domain);
      std::vector< std::size_t > current_incr(before_incr);
      std::vector< std::size_t > result_down(result_offset);

      // result_down[j] is the amount to subtract from the offset when
      // variable j wraps around from its last value back to its first one
      for (std::size_t j = std::size_t(0); j < result_down.size(); ++j) {
        result_down[j] *= (result_domain[j] - 1);
      }

      // now we can loop over the superset_vect to fill result_vect
      const std::size_t superset_vect_size = superset_vect.size();
      std::size_t       the_result_offset  = std::size_t(0);
      for (std::size_t h = std::size_t(0); h < superset_vect_size; ++h) {
        result_vect[the_result_offset] += superset_vect[h];

        // update the offset of result_vect
        for (std::size_t k = 0; k < current_incr.size(); ++k) {
          // check if we need modify result_offset
          if (current_incr[k]) {
            --current_incr[k];
            break;
          }

          current_incr[k] = before_incr[k];

          // here we shall modify result_offset
          --result_value[k];

          if (result_value[k]) {
            the_result_offset += result_offset[k];
            break;
          }

          // variable k wrapped around: reset it and propagate to the next one
          result_value[k] = result_domain[k];
          the_result_offset -= result_down[k];
        }
      }

      // save the subset_ids and the result vector
      // (on failure, leave the cache empty rather than inconsistent)
      try {
        _last_nonDB_ids_      = subset_ids;
        _last_nonDB_counting_ = std::move(result_vect);
        return _last_nonDB_counting_;
      } catch (...) {
        _last_nonDB_ids_.clear();
        _last_nonDB_counting_.clear();
        throw;
      }
    }
457
    /// parses the database to compute the counts of the values of `ids`,
    /// splitting the work over the available threads. The result is cached
    /// into _last_DB_counting_ / _last_DB_ids_ and a reference to the cache
    /// is returned.
    std::vector< double >& RecordCounter::_countFromDatabase_(const IdCondSet& ids) {
      // if the ids vector is empty or the database is empty, return an
      // empty vector
      const auto& database = _parsers_[0].data.database();
      if (ids.empty() || database.empty() || _thread_ranges_.empty()) {
        _last_nonDB_counting_.clear();
        _last_nonDB_ids_.clear();
        return _last_nonDB_counting_;
      }

      // we translate the ids into their corresponding columns in the
      // DatabaseTable
      const auto nodeId2columns = _getNodeIds2Columns_(ids);

      // we first determine the size of the counting vector, the domain of
      // each of its variables and their offsets in the output vector
      // (cols_offsets[i] = {database column of ids[i], multiplier of that
      // variable's value in the flat counting-vector index})
      const std::size_t                                    ids_size           = ids.size();
      std::size_t                                          counting_vect_size = std::size_t(1);
      std::vector< std::size_t >                           domain_sizes(ids_size);
      std::vector< std::pair< std::size_t, std::size_t > > cols_offsets(ids_size);
      {
        std::size_t i = std::size_t(0);
        for (const auto id: ids) {
          const std::size_t domain_size = database.domainSize(nodeId2columns[id]);
          domain_sizes[i]               = domain_size;
          cols_offsets[i].first         = nodeId2columns[id];
          cols_offsets[i].second        = counting_vect_size;
          counting_vect_size *= domain_size;
          ++i;
        }
      }

      // we sort the columns and offsets by increasing column index. This
      // may speed up threaded counts by improving the cacheline hits
      std::sort(
          cols_offsets.begin(),
          cols_offsets.end(),
          [](const std::pair< std::size_t, std::size_t >& a,
             const std::pair< std::size_t, std::size_t >& b) -> bool { return a.first < b.first; });

      // create parsers if needed (the pool created by the constructor may be
      // smaller than the number of threads requested now)
      const std::size_t nb_ranges      = _thread_ranges_.size();
      const auto        max_nb_threads = ThreadNumberManager::getNumberOfThreads();
      const std::size_t nb_threads = nb_ranges <= max_nb_threads ? nb_ranges : max_nb_threads;
      while (_parsers_.size() < nb_threads) {
        ThreadData< DBRowGeneratorParser > new_parser(_parsers_[0]);
        _parsers_.push_back(std::move(new_parser));
      }

      // set the columns of interest for each parser. This specifies to the
      // parser which columns are used for the counts. This is important
      // for parsers like the EM parser that complete unobserved variables.
      std::vector< std::size_t > cols_of_interest(ids_size);
      for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
        cols_of_interest[i] = cols_offsets[i].first;
      }
      for (auto& parser: _parsers_) {
        parser.data.setColumnsOfInterest(cols_of_interest);
      }

      // allocate all the counting vectors, including that which will add
      // all the results provided by the threads. We initialize once and
      // for all these vectors with zeroes
      std::vector< double >                               counting_vect(counting_vect_size, 0.0);
      std::vector< ThreadData< std::vector< double > > > thread_countings(
          nb_threads,
          ThreadData< std::vector< double > >(counting_vect));

      // here, we create a lambda that will be executed by all the threads
      // to perform the counts in a parallel manner. Each thread accumulates
      // into its own counting vector, so no synchronization is needed.
      auto threadedCount = [this, nb_ranges, ids_size, &thread_countings, cols_offsets](
                               const std::size_t this_thread,
                               const std::size_t nb_threads,
                               const std::size_t nb_loop) -> void {
        if (this_thread + nb_loop < nb_ranges) {
          // get the database parser and the contingency table to fill
          DBRowGeneratorParser& parser = this->_parsers_[this_thread].data;
          parser.setRange(this->_thread_ranges_[this_thread + nb_loop].first,
                          this->_thread_ranges_[this_thread + nb_loop].second);
          std::vector< double >& counts = thread_countings[this_thread].data;

          // parse the database
          try {
            while (parser.hasRows()) {
              // get the observed rows
              const DBRow< DBTranslatedValue >& row = parser.row();

              // fill the counts for the current row: compute the flat index
              // of the cell corresponding to the row's values
              std::size_t offset = std::size_t(0);
              for (std::size_t i = std::size_t(0); i < ids_size; ++i) {
                offset += row[cols_offsets[i].first].discr_val * cols_offsets[i].second;
              }

              counts[offset] += row.weight();
            }
          } catch (NotFound const&) {}   // this exception is raised by the row filter
                                         // if the row generators create no output row
                                         // from the last rows of the database
        }
      };


      // launch the threads: if there are more ranges than threads, run
      // several rounds, each round processing nb_threads ranges
      for (std::size_t i = std::size_t(0); i < nb_ranges; i += nb_threads) {
        ThreadExecutor::execute(nb_threads, threadedCount, i);
      }


      // add the counts to counting_vect
      for (std::size_t k = std::size_t(0); k < nb_threads; ++k) {
        const auto& thread_counting = thread_countings[k].data;
        for (std::size_t r = std::size_t(0); r < counting_vect_size; ++r) {
          counting_vect[r] += thread_counting[r];
        }
      }

      // save the final results
      _last_DB_ids_      = ids;
      _last_DB_counting_ = std::move(counting_vect);

      return _last_DB_counting_;
    }
581
583 void RecordCounter::_checkRanges_(
584 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges) const {
585 const std::size_t dbsize = _parsers_[0].data.database().nbRows();
586 std::vector< std::pair< std::size_t, std::size_t > > incorrect_ranges;
587 for (const auto& range: new_ranges) {
588 if ((range.first >= range.second) || (range.second > dbsize)) {
589 incorrect_ranges.push_back(range);
590 }
591 }
592 if (!incorrect_ranges.empty()) {
593 std::stringstream str;
594 str << "It is impossible to set the ranges because the following one";
595 if (incorrect_ranges.size() > 1) str << "s are incorrect: ";
596 else str << " is incorrect: ";
597 bool deja = false;
598 for (const auto& range: incorrect_ranges) {
599 if (deja) str << ", ";
600 else deja = true;
601 str << '[' << range.first << ';' << range.second << ')';
602 }
603
604 GUM_ERROR(OutOfBounds, str.str())
605 }
606 }
607
    /// splits the user-requested ranges (_ranges_, or the whole database
    /// when none were requested) into per-thread sub-ranges stored in
    /// _thread_ranges_, honoring _min_nb_rows_per_thread_
    void RecordCounter::_dispatchRangesToThreads_() {
      _thread_ranges_.clear();

      // ensure that _ranges_ contains the ranges asked by the user; when it
      // is empty, temporarily use a single range covering the whole database
      bool add_range = false;
      if (_ranges_.empty()) {
        const auto& database = _parsers_[0].data.database();
        _ranges_.push_back(
            std::pair< std::size_t, std::size_t >(std::size_t(0), database.nbRows()));
        add_range = true;
      }

      // dispatch the ranges
      const auto max_nb_threads = ThreadNumberManager::getNumberOfThreads();
      for (const auto& range: _ranges_) {
        if (range.second > range.first) {
          // choose the number of threads for this range so that each thread
          // gets at least _min_nb_rows_per_thread_ rows, capped by the
          // maximal number of threads allowed
          const std::size_t range_size = range.second - range.first;
          std::size_t       nb_threads = range_size / _min_nb_rows_per_thread_;
          if (nb_threads < 1) nb_threads = 1;
          else if (nb_threads > max_nb_threads) nb_threads = max_nb_threads;
          // distribute the rows evenly; the first `rest_rows` sub-ranges get
          // one extra row so that all rows are covered
          std::size_t nb_rows_par_thread = range_size / nb_threads;
          std::size_t rest_rows          = range_size - nb_rows_par_thread * nb_threads;

          std::size_t begin_index = range.first;
          for (std::size_t i = std::size_t(0); i < nb_threads; ++i) {
            std::size_t end_index = begin_index + nb_rows_par_thread;
            if (rest_rows != std::size_t(0)) {
              ++end_index;
              --rest_rows;
            }
            _thread_ranges_.push_back(
                std::pair< std::size_t, std::size_t >(begin_index, end_index));
            begin_index = end_index;
          }
        }
      }
      // restore _ranges_ to its "whole database" empty-state if needed
      if (add_range) _ranges_.clear();

      // sort ranges by decreasing range size, so that if the number of
      // ranges exceeds the number of threads allowed, we start a first round of
      // threads with the highest range, then another round with lower ranges,
      // and so on until all the ranges have been processed
      std::sort(_thread_ranges_.begin(),
                _thread_ranges_.end(),
                [](const std::pair< std::size_t, std::size_t >& a,
                   const std::pair< std::size_t, std::size_t >& b) -> bool {
                  return (a.second - a.first) > (b.second - b.first);
                });
    }
658
660 void RecordCounter::setRanges(
661 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges) {
662 // first, we check that all ranges are within the database's bounds
663 _checkRanges_(new_ranges);
664
665 // since the ranges are OK, save them and clear the counting caches
666 const std::size_t new_size = new_ranges.size();
667 std::vector< std::pair< std::size_t, std::size_t > > ranges(new_size);
668 for (std::size_t i = std::size_t(0); i < new_size; ++i) {
669 ranges[i].first = new_ranges[i].first;
670 ranges[i].second = new_ranges[i].second;
671 }
672
673 clear();
674 _ranges_ = std::move(ranges);
675
676 // dispatch the ranges to the threads
677 _dispatchRangesToThreads_();
678 }
679
681 void RecordCounter::clearRanges() {
682 if (_ranges_.empty()) return;
683 clear();
684 _ranges_.clear();
685 _dispatchRangesToThreads_();
686 }
687
688
689 } /* namespace learning */
690
691} /* namespace gum */
692
693#endif /* DOXYGEN_SHOULD_SKIP_THIS */
Exception : the element we looked for cannot be found.
Exception : out of bound.
Exception : wrong type for this operation.
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
RecordCounter(const DBRowGeneratorParser &parser, const std::vector< std::pair< std::size_t, std::size_t > > &ranges, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
default constructor
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.
The class that computes counting of observations from the database.