aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
multiDimCombineAndProjectDefault_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49
50#ifndef DOXYGEN_SHOULD_SKIP_THIS
51
52# include <limits>
53
54# include <agrum/agrum.h>
55
57
58namespace gum {
59
60 // default constructor
61 template < class TABLE >
63 TABLE (*combine)(const TABLE&, const TABLE&),
64 TABLE (*project)(const TABLE&, const gum::VariableSet&)) :
66 _combination_(new MultiDimCombinationDefault< TABLE >(combine)),
67 _projection_(new MultiDimProjection< TABLE >(project)) {
68 // for debugging purposes
69 GUM_CONSTRUCTOR(MultiDimCombineAndProjectDefault);
70 }
71
72 // copy constructor
73 template < class TABLE >
74 MultiDimCombineAndProjectDefault< TABLE >::MultiDimCombineAndProjectDefault(
75 const MultiDimCombineAndProjectDefault< TABLE >& from) :
76 MultiDimCombineAndProject< TABLE >(), _combination_(from._combination_->clone()),
77 _projection_(from._projection_->clone()) {
78 // for debugging purposes
79 GUM_CONS_CPY(MultiDimCombineAndProjectDefault);
80 }
81
82 // destructor
83 template < class TABLE >
84 MultiDimCombineAndProjectDefault< TABLE >::~MultiDimCombineAndProjectDefault() {
85 // for debugging purposes
86 GUM_DESTRUCTOR(MultiDimCombineAndProjectDefault);
87 delete _combination_;
88 delete _projection_;
89 }
90
91 // virtual constructor
92 template < class TABLE >
93 MultiDimCombineAndProjectDefault< TABLE >*
94 MultiDimCombineAndProjectDefault< TABLE >::clone() const {
95 return new MultiDimCombineAndProjectDefault< TABLE >(*this);
96 }
97
98 // combine and project
99 template < class TABLE >
100 Set< const TABLE* >
101 MultiDimCombineAndProjectDefault< TABLE >::execute(const Set< const TABLE* >& table_set,
102 const gum::VariableSet& del_vars) {
103 // create a vector with all the tables stored as multidims
104 std::vector< const IScheduleMultiDim* > tables;
105 tables.reserve(table_set.size());
106 for (const auto table: table_set) {
107 tables.push_back(new ScheduleMultiDim< TABLE >(*table, false));
108 }
109
110 // get the set of operations to perform and execute them
111 auto ops_plus_res = operations(tables, del_vars, false);
112 for (auto op: ops_plus_res.first) {
113 op->execute();
114 }
115
116 // get the schedule multidims resulting from the computations and save them
117 Set< const TABLE* > result(ops_plus_res.second.size());
118 for (const auto pot: ops_plus_res.second) {
119 auto& schedule_result = const_cast< ScheduleMultiDim< TABLE >& >(
120 static_cast< const ScheduleMultiDim< TABLE >& >(*pot));
121 auto potres = new TABLE(std::move(schedule_result.multiDim()));
122 result.insert(potres);
123 }
124
125 // delete all the operations created as well as all the schedule tables
126 _freeData_(tables, ops_plus_res.first);
127
128 return result;
129 }
130
131 // changes the function used for combining two TABLES
132 template < class TABLE >
133 INLINE void MultiDimCombineAndProjectDefault< TABLE >::setCombinationFunction(
134 TABLE (*combine)(const TABLE&, const TABLE&)) {
135 _combination_->setCombinationFunction(combine);
136 }
137
138 // returns the current combination function
139 template < class TABLE >
140 INLINE TABLE (*MultiDimCombineAndProjectDefault< TABLE >::combinationFunction())(const TABLE&,
141 const TABLE&) {
142 return _combination_->combinationFunction();
143 }
144
145 // changes the class that performs the combinations
146 template < class TABLE >
147 INLINE void MultiDimCombineAndProjectDefault< TABLE >::setCombinationClass(
148 const MultiDimCombination< TABLE >& comb_class) {
149 delete _combination_;
150 _combination_ = comb_class.clone();
151 }
152
153 // changes the function used for projecting TABLES
154 template < class TABLE >
155 INLINE void MultiDimCombineAndProjectDefault< TABLE >::setProjectionFunction(
156 TABLE (*proj)(const TABLE&, const gum::VariableSet&)) {
157 _projection_->setProjectionFunction(proj);
158 }
159
160 // returns the current projection function
161 template < class TABLE >
162 INLINE TABLE (*MultiDimCombineAndProjectDefault< TABLE >::projectionFunction())(
163 const TABLE&,
164 const gum::VariableSet&) {
165 return _projection_->projectionFunction();
166 }
167
168 // changes the class that performs the projections
169 template < class TABLE >
170 INLINE void MultiDimCombineAndProjectDefault< TABLE >::setProjectionClass(
171 const MultiDimProjection< TABLE >& proj_class) {
172 delete _projection_;
173 _projection_ = proj_class.clone();
174 }
175
178 template < class TABLE >
179 double MultiDimCombineAndProjectDefault< TABLE >::nbOperations(
180 const Set< const Sequence< const DiscreteVariable* >* >& table_set,
181 const gum::VariableSet& del_vars) const {
182 // create a vector with all the tables stored as multidims
183 std::vector< const IScheduleMultiDim* > tables;
184 tables.reserve(table_set.size());
185 for (const auto vars: table_set) {
186 tables.push_back(new ScheduleMultiDim< TABLE >(*vars, false));
187 }
188
189 // get the set of operations to perform and compute their number of operations
190 auto ops_plus_res = operations(tables, del_vars, false);
191 double nb_operations = 0.0;
192 for (auto op: ops_plus_res.first) {
193 nb_operations += op->nbOperations();
194 }
195
196 // delete all the operations created as well as all the schedule tables
197 _freeData_(tables, ops_plus_res.first);
198
199 return nb_operations;
200 }
201
204 template < class TABLE >
205 double MultiDimCombineAndProjectDefault< TABLE >::nbOperations(
206 const Set< const TABLE* >& set,
207 const gum::VariableSet& del_vars) const {
208 // create the set of sets of discrete variables involved in the tables
209 Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
210
211 for (const auto ptrTab: set) {
212 var_set << &(ptrTab->variablesSequence());
213 }
214
215 return nbOperations(var_set, del_vars);
216 }
217
218 // returns the memory consumption used during the combinations and
219 // projections
220 template < class TABLE >
221 std::pair< double, double > MultiDimCombineAndProjectDefault< TABLE >::memoryUsage(
222 const Set< const Sequence< const DiscreteVariable* >* >& table_set,
223 const gum::VariableSet& del_vars) const {
224 // create a vector with all the tables stored as multidims
225 std::vector< const IScheduleMultiDim* > tables;
226 tables.reserve(table_set.size());
227 for (const auto vars: table_set) {
228 tables.push_back(new ScheduleMultiDim< TABLE >(*vars, false));
229 }
230
231 // get the set of operations to perform and compute their number of operations
232 auto ops_plus_res = operations(tables, del_vars, false);
233
234 // the resulting memory consumtions
235 double max_memory = 0.0;
236 double end_memory = 0.0;
237 for (const auto op: ops_plus_res.first) {
238 const auto usage = op->memoryUsage();
239 if (end_memory + usage.first > max_memory) max_memory = end_memory + usage.first;
240 end_memory += usage.second;
241 }
242
243 // delete all the operations created as well as all the schedule tables
244 _freeData_(tables, ops_plus_res.first);
245
246 return {max_memory, end_memory};
247 }
248
249 // returns the memory consumption used during the combinations and
250 // projections
251 template < class TABLE >
252 std::pair< double, double > MultiDimCombineAndProjectDefault< TABLE >::memoryUsage(
253 const Set< const TABLE* >& set,
254 const gum::VariableSet& del_vars) const {
255 // create the set of sets of discrete variables involved in the tables
256 Set< const Sequence< const DiscreteVariable* >* > var_set(set.size());
257
258 for (const auto ptrTab: set) {
259 var_set << &(ptrTab->variablesSequence());
260 }
261
262 return memoryUsage(var_set, del_vars);
263 }
264
267 template < class TABLE >
268 std::pair< std::vector< ScheduleOperator* >, Set< const IScheduleMultiDim* > >
269 MultiDimCombineAndProjectDefault< TABLE >::operations(
270 const std::vector< const IScheduleMultiDim* >& original_tables,
271 const gum::VariableSet& del_vars,
272 const bool is_result_persistent) const {
273 Set< const IScheduleMultiDim* > tables_set(original_tables.size());
274 for (const auto table: original_tables) {
275 tables_set.insert(table);
276 }
277 return operations(tables_set, del_vars, is_result_persistent);
278 }
279
282 template < class TABLE >
283 std::pair< std::vector< ScheduleOperator* >, Set< const IScheduleMultiDim* > >
284 MultiDimCombineAndProjectDefault< TABLE >::operations(
285 const Set< const IScheduleMultiDim* >& original_tables,
286 const gum::VariableSet& original_del_vars,
287 const bool is_result_persistent) const {
288 // check if we need to combine and/or project something
289 const Size tabsize = original_tables.size();
290 if (tabsize < 2) {
291 if (tabsize == 1) {
292 auto res = _projection_->operations(*original_tables.begin(), original_del_vars);
293 return std::pair< std::vector< ScheduleOperator* >, Set< const IScheduleMultiDim* > >(
294 {res.first},
295 {res.second});
296 } else {
297 std::string names;
298 for (const auto& v: original_del_vars) {
299 names += v->name() + ", ";
300 }
302 "MultiDimCombineAndProject need at least one table to "
303 "have some work to do (original_del_vars ="
304 << names << ").");
305 }
306 }
307
308 // we copy the set of tables to be combined and the set of variables to
309 // delete because we will modify them during the combination/projection process
310 Set< const IScheduleMultiDim* > tables = original_tables;
311 gum::VariableSet del_vars = original_del_vars;
312
313 // when we remove a variable, we need to combine all the tables containing
314 // this variable in order to produce a new unique table containing this
315 // variable. Removing a variable is then performed by marginalizing it out of
316 // the table. In the combineAndProjectDefault algorithm, we wish to remove
317 // first variables that would produce small tables. This should speed up the
318 // whole marginalizing process.
319
320 Size nb_vars;
321 {
322 // determine the set of all the variables involved in the tables.
323 // this should help sizing correctly the hashtables used hereafter
324 gum::VariableSet all_vars;
325
326 for (const auto table: tables) {
327 for (const auto ptrVar: table->variablesSequence()) {
328 all_vars.insert(ptrVar);
329 }
330 }
331
332 nb_vars = all_vars.size();
333 }
334
335 // the tables containing a given variable
336 HashTable< const DiscreteVariable*, Set< const IScheduleMultiDim* > > tables_per_var(nb_vars);
337
338 // for a given variable X to be deleted, the list of all the variables of
339 // the tables containing X (actually, we also count the number of tables
340 // containing the variable. This is more efficient for computing and
341 // updating the product_size priority queue (see below) when some tables
342 // are removed)
343 HashTable< const DiscreteVariable*, HashTable< const DiscreteVariable*, unsigned int > >
344 clique_vars_per_var(nb_vars);
345
346 // initialize clique_vars_per_var and tables_per_var
347 {
348 Set< const IScheduleMultiDim* > empty_set(tables.size());
349 HashTable< const DiscreteVariable*, unsigned int > empty_hash(nb_vars);
350
351 for (const auto ptrVar: del_vars) {
352 tables_per_var.insert(ptrVar, empty_set);
353 clique_vars_per_var.insert(ptrVar, empty_hash);
354 }
355
356 // update properly tables_per_var and clique_vars_per_var
357 for (const auto ptrTab: tables) {
358 const auto& vars = ptrTab->variablesSequence();
359
360 for (const auto ptrVar: vars) {
361 if (del_vars.contains(ptrVar)) {
362 // add the table to the set of tables related to vars[i]
363 tables_per_var[ptrVar].insert(ptrTab);
364
365 // add the variables of the table to clique_vars_per_var[vars[i]]
366 auto& comb_vars = clique_vars_per_var[ptrVar];
367 for (const auto xptrVar: vars) {
368 try {
369 ++comb_vars[xptrVar];
370 } catch (const NotFound&) { comb_vars.insert(xptrVar, 1); }
371 }
372 }
373 }
374 }
375 }
376
377 // create the set of operations to execute to perform the combinations and
378 // projections
379 std::vector< ScheduleOperator* > ops;
380 ops.reserve(2 * tables.size() + del_vars.size());
381
382 // keep track of the operations that created new tables. This is useful
383 // when requiring that results are persistent
384 HashTable< const IScheduleMultiDim*, ScheduleOperator* > multidim2op(tables.size());
385
386 // the sizes of the tables produced when removing a given discrete variable
387 PriorityQueue< const DiscreteVariable*, double > product_size;
388
389 // initialize properly product_size
390 for (const auto& elt: clique_vars_per_var) {
391 double size = 1.0;
392 const auto ptrVar = elt.first;
393 const auto& hashvars = elt.second; // HashTable<DiscreteVariable*, int>
394
395 if (!hashvars.empty()) {
396 for (const auto& xelt: hashvars) {
397 size *= (double)xelt.first->domainSize();
398 }
399
400 product_size.insert(ptrVar, size);
401 }
402 }
403
404 // now, remove all the variables in del_vars, starting from those that
405 // produce the smallest tables
406 while (!product_size.empty()) {
407 // get the best variable to remove
408 const DiscreteVariable* del_var = product_size.pop();
409 del_vars.erase(del_var);
410
411 // get the set of tables to combine
412 auto& tables_to_combine = tables_per_var[del_var];
413
414 // if there is no tables to combine, do nothing
415 if (tables_to_combine.empty()) continue;
416
417 // compute the combination of all the tables: if there is only one table,
418 // there is nothing to do, else we shall use the MultiDimCombination
419 // to perform the combination
420 const IScheduleMultiDim* joint = nullptr;
421 bool joint_to_delete;
422 if (tables_to_combine.size() == 1) {
423 joint = *(tables_to_combine.begin());
424 joint_to_delete = false;
425 } else {
426 // get the operations to perform to make the combination as well as
427 // the result of the combination
428 auto comb_ops = _combination_->operations(tables_to_combine);
429 ops.insert(ops.cend(), comb_ops.first.begin(), comb_ops.first.end());
430 joint = comb_ops.second;
431 joint_to_delete = true;
432 }
433
434 // compute the table resulting from marginalizing out del_var from joint
435 // and add the projection to the set of operations. Here, we know that the
436 // joint contains del_var, hence there is a nonempty projection to perform
437 gum::VariableSet del_one_var;
438 del_one_var << del_var;
439 auto proj_ops = _projection_->operations(joint, del_one_var);
440 ops.push_back(proj_ops.first);
441 const IScheduleMultiDim* marginal = proj_ops.second;
442 if (is_result_persistent) multidim2op.insert(marginal, proj_ops.first);
443
444 // remove the temporary joint if needed
445 if (joint_to_delete) {
446 auto deletion = new ScheduleDeletion< TABLE >(
447 static_cast< const ScheduleMultiDim< TABLE >& >(*joint));
448 ops.push_back(deletion);
449 }
450
451 // update clique_vars_per_var : remove the variables of the tables we
452 // combined from this hashtable
453 // update accordingly tables_per_vars : remove these tables
454 // update accordingly product_size : when a variable is no more used by
455 // any table, divide product_size by its domain size
456 for (const auto ptrTab: tables_to_combine) {
457 const auto& table_vars = ptrTab->variablesSequence();
458 const Size tab_vars_size = table_vars.size();
459
460 for (Size i = 0; i < tab_vars_size; ++i) {
461 if (del_vars.contains(table_vars[i])) {
462 // here we have a variable that needed to be removed => update
463 // product_size, tables_per_var and clique_vars_per_var: here,
464 // the update corresponds to removing table PtrTab
465 auto& table_vars_of_var_i = clique_vars_per_var[table_vars[i]];
466 double div_size = 1.0;
467
468 for (Size j = 0; j < tab_vars_size; ++j) {
469 unsigned int k = --table_vars_of_var_i[table_vars[j]];
470
471 if (k == 0) {
472 div_size *= table_vars[j]->domainSize();
473 table_vars_of_var_i.erase(table_vars[j]);
474 }
475 }
476
477 tables_per_var[table_vars[i]].erase(ptrTab);
478
479 if (div_size != 1.0) {
480 product_size.setPriority(table_vars[i],
481 product_size.priority(table_vars[i]) / div_size);
482 }
483 }
484 }
485
486 // if ptrTab is a table resulting from preceding combinations/projections,
487 // it is temporary and, therefore, it should be deleted
488 if (!original_tables.contains(ptrTab)) {
489 auto deletion = new ScheduleDeletion< TABLE >(
490 static_cast< const ScheduleMultiDim< TABLE >& >(*ptrTab));
491 ops.push_back(deletion);
492 }
493
494 tables.erase(ptrTab);
495 }
496
497 tables_per_var.erase(del_var);
498
499 // add the new projected marginal to the list of tables
500 const auto& marginal_vars = marginal->variablesSequence();
501 for (const auto mvar: marginal_vars) {
502 if (del_vars.contains(mvar)) {
503 // add the new marginal table to the set of tables of mvar
504 tables_per_var[mvar].insert(marginal);
505
506 // add the variables of the table to clique_vars_per_var[mvar]
507 auto& iter_vars = clique_vars_per_var[mvar];
508 double mult_size = 1.0;
509 for (const auto var: marginal_vars) {
510 try {
511 ++iter_vars[var];
512 } catch (const NotFound&) {
513 iter_vars.insert(var, 1);
514 mult_size *= (double)var->domainSize();
515 }
516 }
517
518 if (mult_size != 1.0) {
519 product_size.setPriority(mvar, product_size.priority(mvar) * mult_size);
520 }
521 }
522 }
523
524 tables.insert(marginal);
525 }
526
527 // here, Set "tables" contains the list of the tables resulting from
528 // marginalizing out of del_vars of the combination of the tables
529 // of original_tables. Note in particular that it will contain all the
530 // tensors with no dimension (constants)
531
532 // if we require persistent results, update the operations that produced some
533 // of the tables in Set "tables"
534 if (is_result_persistent) {
535 for (const auto table: tables) {
536 if (multidim2op.exists(table)) multidim2op[table]->makeResultsPersistent(true);
537 }
538 }
539
540 return {ops, tables};
541 }
542
544 template < class TABLE >
545 INLINE void MultiDimCombineAndProjectDefault< TABLE >::_freeData_(
546 std::vector< const IScheduleMultiDim* >& tables,
547 std::vector< ScheduleOperator* >& operations) const {
548 for (auto op: operations)
549 delete op;
550
551 for (auto table: tables)
552 delete table;
553 }
554
555} /* namespace gum */
556
557#endif /* DOXYGEN_SHOULD_SKIP_THIS */
A class to combine efficiently several MultiDim tables.
MultiDimCombineAndProjectDefault(TABLE(*combine)(const TABLE &, const TABLE &), TABLE(*project)(const TABLE &, const gum::VariableSet &))
Default constructor.
A generic interface to combine and project efficiently MultiDim tables.
A generic class to project efficiently a MultiDim table over a subset of its variables.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet