aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
structuredInference_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49
51
52namespace gum {
53 namespace prm {
54
55 template < typename GUM_SCALAR >
58 const PRMSystem< GUM_SCALAR >& system,
60 PRMInference< GUM_SCALAR >(prm, system), _gspan_(0), _pdata_(0), _mining_(false),
61 _dot_(".") {
62 GUM_CONSTRUCTOR(StructuredInference);
63 _gspan_ = new GSpan< GUM_SCALAR >(prm, system, strategy);
64 triang_time = 0.0;
65 mining_time = 0.0;
66 pattern_time = 0.0;
67 inner_time = 0.0;
68 obs_time = 0.0;
69 full_time = 0.0;
70 }
71
72 template < typename GUM_SCALAR >
75 PRMInference< GUM_SCALAR >(source), _gspan_(0), _pdata_(0), _mining_(source._mining_),
76 _found_query_(false), _dot_(".") {
77 GUM_CONS_CPY(StructuredInference);
78 _gspan_ = new GSpan< GUM_SCALAR >(*(this->prm_), *(this->sys_));
79 }
80
81 template < typename GUM_SCALAR >
83 GUM_DESTRUCTOR(StructuredInference);
84 delete this->_gspan_;
85
86 for (const auto& elt: _elim_map_)
87 delete elt.second;
88
89 for (const auto& elt: _cdata_map_)
90 delete elt.second;
91
92 for (const auto elt: _trash_)
93 delete (elt);
94
95 for (const auto& elt: _outputs_)
96 delete elt.second;
97
98 if (_pdata_) delete _pdata_;
99 }
100
101 template < typename GUM_SCALAR >
103 const StructuredInference< GUM_SCALAR >& source) {
104 this->prm_ = source.prm_;
105 this->sys_ = source.sys_;
106
107 if (this->_gspan_) delete this->_gspan_;
108
109 this->_gspan_ = new GSpan< GUM_SCALAR >(*(this->prm_), *(this->sys_));
110 return *this;
111 }
112
113 template < typename GUM_SCALAR >
116
117 template < typename GUM_SCALAR >
120
121 template < typename GUM_SCALAR >
123 const typename PRMInference< GUM_SCALAR >::Chain& chain,
124 Tensor< GUM_SCALAR >& m) {
125 timer.reset();
126 _found_query_ = false;
127 _query_ = chain;
129
130 if (!this->hasEvidence() && (chain.second->cpf().nbrDim() == 1)) {
131 Instantiation i(m);
132
133 for (i.setFirst(); !i.end(); i.inc())
134 m.set(i, chain.second->cpf().get(i));
135
136 return;
137 } else if (this->hasEvidence(chain)) {
138 Instantiation i(m);
139 const Tensor< GUM_SCALAR >* e = this->evidence(_query_.first)[_query_.second->id()];
140
141 for (i.setFirst(); !i.end(); i.inc())
142 m.set(i, e->get(i));
143
144 return;
145 }
146
147 _buildReduceGraph_(data);
149
150 if (data.pool.size() > 1) {
151 for (const auto pot: data.pool)
152 if (pot->contains(_query_.second->type().variable())) pots.insert(pot);
153
154 if (pots.size() == 1) {
155 Tensor< GUM_SCALAR >* pot = const_cast< Tensor< GUM_SCALAR >* >(*(pots.begin()));
156 GUM_ASSERT(pot->contains(_query_.second->type().variable()));
157 GUM_ASSERT(pot->variablesSequence().size() == 1);
158 Instantiation i(*pot), j(m);
159
160 for (i.setFirst(), j.setFirst(); !i.end(); i.inc(), j.inc())
161 m.set(j, pot->get(i));
162 } else {
164 Tensor< GUM_SCALAR >* tmp = Comb.execute(pots);
165 Instantiation i(m), j(*tmp);
166
167 for (i.setFirst(), j.setFirst(); !i.end(); i.inc(), j.inc())
168 m.set(i, tmp->get(j));
169
170 delete tmp;
171 }
172 } else {
173 Tensor< GUM_SCALAR >* pot = *(data.pool.begin());
174 GUM_ASSERT(pot->contains(_query_.second->type().variable()));
175 GUM_ASSERT(pot->variablesSequence().size() == 1);
176 Instantiation i(*pot), j(m);
177
178 for (i.setFirst(), j.setFirst(); !i.end(); i.inc(), j.inc())
179 m.set(j, pot->get(i));
180 }
181
182 m.normalize();
183
184 if (_pdata_) {
185 delete _pdata_;
186 _pdata_ = 0;
187 }
188
189 full_time = timer.step();
190 }
191
192 template < typename GUM_SCALAR >
194 const std::vector< typename PRMInference< GUM_SCALAR >::Chain >& queries,
195 Tensor< GUM_SCALAR >& j) {
196 GUM_ERROR(FatalError, "not implemented")
197 }
198
199 template < typename GUM_SCALAR >
201 std::stringstream s;
202 s << "Triangulation time: " << triang_time << std::endl;
203 s << "Pattern mining time: " << mining_time << std::endl;
204 s << "Pattern elimination time: " << pattern_time << std::endl;
205 s << "Inner node elimination time: " << inner_time << std::endl;
206 s << "Observed node elimination time: " << obs_time << std::endl;
207 s << "Full inference time: " << full_time << std::endl;
208 s << "#patterns: " << _gspan_->patterns().size() << std::endl;
209 Size count = 0;
210 using Iter = std::vector< gspan::Pattern* >::const_iterator;
211
212 for (Iter p = _gspan_->patterns().begin(); p != _gspan_->patterns().end(); ++p) {
213 if (_gspan_->matches(**p).size()) {
214 s << "Pattern n°" << count++ << " match count: " << _gspan_->matches(**p).size()
215 << std::endl;
216 s << "Pattern n°" << count++ << " instance count: " << (**p).size() << std::endl;
217 }
218 }
219
220 return s.str();
221 }
222
223 template < typename GUM_SCALAR >
226 // Launch the pattern mining
227 plopTimer.reset();
228
229 if (_mining_) _gspan_->discoverPatterns();
230
231 mining_time = plopTimer.step();
232 // Reducing each used pattern
233 plopTimer.reset();
234 using Iter = std::vector< gspan::Pattern* >::const_iterator;
235
236 for (Iter p = _gspan_->patterns().begin(); p != _gspan_->patterns().end(); ++p)
237 if (_gspan_->matches(**p).size()) _reducePattern_(*p);
238
239 pattern_time = plopTimer.step();
240 // reducing instance not already reduced in a pattern
242 // Adding edges using the pools
244 // Placing the query where it belongs
245 NodeId id = data.var2node.second(&(_query_.second->type().variable()));
246 data.outputs().erase(id);
247 data.queries().insert(id);
248 // Triangulating, then eliminating
249 PartialOrderedTriangulation t(&(data.reducedGraph), &(data.mods), &(data.partial_order));
250 const std::vector< NodeId >& elim_order = t.eliminationOrder();
251
252 for (size_t i = 0; i < data.outputs().size(); ++i)
253 eliminateNode(data.var2node.first(elim_order[i]), data.pool, _trash_);
254 }
255
256 template < typename GUM_SCALAR >
259 typename StructuredInference< GUM_SCALAR >::PData data(*p, _gspan_->matches(*p));
260 _buildPatternGraph_(data, pool, **(data.matches.begin()));
261 _removeBarrenNodes_(data, pool);
262 PartialOrderedTriangulation t(&(data.graph), &(data.mod), data.partial_order());
263 const std::vector< NodeId >& elim_order = t.eliminationOrder();
264
265 for (size_t i = 0; i < data.inners().size(); ++i)
266 if (!data.barren.exists(elim_order[i]))
267 eliminateNode(data.vars.second(elim_order[i]), pool, _trash_);
268
269 typename GSpan< GUM_SCALAR >::MatchedInstances fake_patterns;
271
272 for (const auto elt: **iter)
273 _reducedInstances_.insert(elt);
274
275 if (data.obs().size())
276 _elim_map_.insert(*iter, _eliminateObservedNodesInSource_(data, pool, **iter, elim_order));
277 else _elim_map_.insert(*iter, new Set< Tensor< GUM_SCALAR >* >(pool));
278
279 ++iter;
280
281 if (data.obs().size()) {
282 for (; iter != data.matches.end(); ++iter) {
283 try {
284 _elim_map_.insert(*iter, _eliminateObservedNodes_(data, pool, **iter, elim_order));
285 } catch (OperationNotAllowed const&) { fake_patterns.insert(*iter); }
286 }
287 } else {
288 for (; iter != data.matches.end(); ++iter) {
289 try {
290 _elim_map_.insert(*iter, _translatePotSet_(data, pool, **iter));
291 } catch (OperationNotAllowed const&) { fake_patterns.insert(*iter); }
292 }
293 }
294
295 for (const auto pat: fake_patterns) {
296 for (const auto elt: *pat)
297 _reducedInstances_.erase(elt);
298
299 data.matches.erase(pat);
300 }
301
302 obs_time += plopTimer.step();
303
304 if (data.queries().size())
305 for (const auto m: data.matches)
306 if (!(m->exists(const_cast< PRMInstance< GUM_SCALAR >* >(_query_.first))))
308 &(m->atPos(_query_data_.first)->get(_query_data_.second).type().variable()),
309 *(_elim_map_[m]),
310 _trash_);
311 }
312
313 template < typename GUM_SCALAR >
316 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
319 NodeId id,
320 std::pair< Idx, std::string >& v) {
321 if ((*inst).hasRefAttr((*inst).get(v.second).id())) {
322 std::vector< std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& refs
323 = inst->getRefAttr(inst->get(v.second).id());
324
325 for (auto r = refs.begin(); r != refs.end(); ++r) {
326 if (!match.exists(r->first)) {
327 data.outputs().insert(id);
328 break;
329 }
330 }
331 }
332
333 if (!(data.outputs().size() && (data.outputs().exists(id)))) {
334 for (const auto m: data.matches) {
335 if (this->hasEvidence(std::make_pair((*m)[v.first], &((*m)[v.first]->get(v.second))))) {
336 GUM_ASSERT(inst->type().name() == (*m)[v.first]->type().name());
337 GUM_ASSERT(inst->get(v.second).safeName() == (*m)[v.first]->get(v.second).safeName());
338 data.obs().insert(id);
339 break;
340 }
341 }
342
343 if (!(data.obs().size() && (data.obs().exists(id)))) data.inners().insert(id);
344 }
345 }
346
347 template < typename GUM_SCALAR >
350 Set< Tensor< GUM_SCALAR >* >& pool,
351 const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
352 std::pair< Idx, std::string > v;
353 Tensor< GUM_SCALAR >* pot = 0;
354
355 for (const auto inst: match) {
356 for (const auto& elt: *inst) {
357 NodeId id = data.graph.addNode();
358 v = std::make_pair(match.pos(inst), elt.second->safeName());
359 data.map.insert(id, v);
360 data.node2attr.insert(id, _str_(inst, elt.second));
361 data.mod.insert(id, elt.second->type()->domainSize());
362 data.vars.insert(id, &(elt.second->type().variable()));
363 pool.insert(const_cast< Tensor< GUM_SCALAR >* >(&(elt.second->cpf())));
364 pot = &(const_cast< Tensor< GUM_SCALAR >& >(inst->get(v.second).cpf()));
365
366 for (const auto var: pot->variablesSequence()) {
367 try {
368 if (id != data.vars.first(var)) data.graph.addEdge(id, data.vars.first(var));
369 } catch (DuplicateElement const&) {
370 } catch (NotFound const&) {}
371 }
372
373 _insertNodeInElimLists_(data, match, inst, elt.second, id, v);
374
375 if (data.inners().exists(id)
376 && (inst->type().containerDag().children(elt.second->id()).size() == 0)
377 && _allInstanceNoRefAttr_(data, v))
378 data.barren.insert(id);
379 }
380 }
381
382 if (!_found_query_) {
383 for (const auto mat: data.matches) {
384 if (mat->exists(const_cast< PRMInstance< GUM_SCALAR >* >(_query_.first))) {
385 Idx pos = mat->pos(const_cast< PRMInstance< GUM_SCALAR >* >(_query_.first));
387 = match.atPos(pos)->get(_query_.second->safeName()).type().variable();
388 NodeId id = data.vars.first(&var);
389 data.barren.erase(id);
390 data.inners().erase(id);
391 data.obs().erase(id);
392 data.outputs().erase(id);
393 data.queries().insert(id);
394 _found_query_ = true;
395 _query_data_ = std::make_pair(pos, _query_.second->safeName());
396 break;
397 }
398 }
399 }
400 }
401
402 template < typename GUM_SCALAR >
405 std::pair< Idx, std::string > attr) {
406 for (const auto mat: data.matches)
407 if (mat->atPos(attr.first)->hasRefAttr(mat->atPos(attr.first)->get(attr.second).id()))
408 return false;
409
410 return true;
411 }
412
413 template < typename GUM_SCALAR >
416 Set< Tensor< GUM_SCALAR >* >& pool) {
417 Sequence< NodeId > candidates;
418
419 for (const auto node: data.barren) {
420 for (const auto pot: pool)
421 if (pot->contains(*data.vars.second(node))) {
422 pool.erase(pot);
423 break;
424 }
425
426 for (const auto nei: data.graph.neighbours(node))
427 if (data.inners().exists(nei)) {
428 try {
429 candidates.insert(nei);
430 } catch (DuplicateElement const&) {}
431 }
432 }
433
434 NodeId node;
435 Tensor< GUM_SCALAR >* my_pot = nullptr;
436 short count = 0;
437
438 while (candidates.size()) {
439 node = candidates.back();
440 candidates.erase(node);
441 count = 0;
442
443 for (const auto pot: pool) {
444 if (pot->contains(*data.vars.second(node))) {
445 ++count;
446 my_pot = pot;
447 }
448 }
449
450 if (count == 1) {
451 pool.erase(my_pot);
452 data.barren.insert(node);
453
454 for (const auto nei: data.graph.neighbours(node)) {
455 if (data.inners().exists(nei)) {
456 try {
457 candidates.insert(nei);
458 } catch (DuplicateElement const&) {}
459 }
460 }
461 }
462 }
463 }
464
465 template < typename GUM_SCALAR >
469 const Set< Tensor< GUM_SCALAR >* >& pool,
470 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
471 const std::vector< NodeId >& elim_order) {
472 Set< Tensor< GUM_SCALAR >* >* my_pool = new Set< Tensor< GUM_SCALAR >* >(pool);
473 std::pair< Idx, std::string > target;
474 size_t end = data.inners().size() + data.obs().size();
475
476 for (size_t idx = data.inners().size(); idx < end; ++idx) {
477 target = data.map[data.vars.first(data.vars.second(elim_order[idx]))];
478 eliminateNode(&(match[target.first]->get(target.second).type().variable()),
479 *my_pool,
480 _trash_);
481 }
482
483 return my_pool;
484 }
485
486 template < typename GUM_SCALAR >
489 const Set< Tensor< GUM_SCALAR >* >& pool,
490 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
491 const std::vector< NodeId >& elim_order) {
492 Set< Tensor< GUM_SCALAR >* >* my_pool = _translatePotSet_(data, pool, match);
493 std::pair< Idx, std::string > target;
494 size_t end = data.inners().size() + data.obs().size();
495
496 for (size_t idx = data.inners().size(); idx < end; ++idx) {
497 target = data.map[data.vars.first(data.vars.second(elim_order[idx]))];
498 eliminateNode(&(match[target.first]->get(target.second).type().variable()),
499 *my_pool,
500 _trash_);
501 }
502
503 return my_pool;
504 }
505
506 template < typename GUM_SCALAR >
509 const Set< Tensor< GUM_SCALAR >* >& pool,
510 const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
511#ifdef DEBUG
512
513 for (const auto iter = data.matches.begin(); iter != data.matches.end(); ++iter) {
514 GUM_ASSERT((**iter).size() == match.size());
515
516 for (Size idx = 0; idx < match.size(); ++idx) {
517 GUM_ASSERT((**iter).atPos(idx)->type() == match.atPos(idx)->type());
518 }
519 }
520
521#endif
523 std::pair< Idx, std::string > target;
525 const Sequence< PRMInstance< GUM_SCALAR >* >& source = **(data.matches.begin());
526
527 for (Size idx = 0; idx < match.size(); ++idx) {
528 _reducedInstances_.insert(match[idx]);
529 const auto& chains = source[idx]->type().slotChains();
530
531 for (const auto sc: chains) {
532#ifdef DEBUG
533 GUM_ASSERT(!(sc->isMultiple()));
534#endif
535
536 try {
537 bij.insert(&(source[idx]
538 ->getInstance(sc->id())
539 .get(sc->lastElt().safeName())
540 .type()
541 .variable()),
542 &(match[idx]
543 ->getInstance(sc->id())
544 .get(sc->lastElt().safeName())
545 .type()
546 .variable()));
547 } catch (DuplicateElement const&) {
548 try {
549 if (bij.first(&(match[idx]
550 ->getInstance(sc->id())
551 .get(sc->lastElt().safeName())
552 .type()
553 .variable()))
554 != &(source[idx]
555 ->getInstance(sc->id())
556 .get(sc->lastElt().safeName())
557 .type()
558 .variable())) {
559 delete my_pool;
560 GUM_ERROR(OperationNotAllowed, "fake pattern")
561 }
562 } catch (NotFound const&) {
563 delete my_pool;
564 GUM_ERROR(OperationNotAllowed, "fake pattern")
565 }
566 }
567 }
568 }
569
570 for (const auto p: pool) {
571 for (const auto v: p->variablesSequence()) {
572 try {
573 target = data.map[data.vars.first(v)];
574 bij.insert(v, &(match[target.first]->get(target.second).type().variable()));
575 } catch (NotFound const&) {
576 GUM_ASSERT(bij.existsFirst(v));
577 } catch (DuplicateElement const&) {}
578 }
579
580 try {
581 my_pool->insert(copyTensor(bij, *p));
582 } catch (Exception const&) {
583 for (const auto pot: *my_pool)
584 delete pot;
585
586 delete my_pool;
587 GUM_ERROR(OperationNotAllowed, "fake pattern")
588 }
589 }
590
591 return my_pool;
592 }
593
594 template < typename GUM_SCALAR >
598 Tensor< GUM_SCALAR >* pot = nullptr;
599 PRMInstance< GUM_SCALAR >* inst = nullptr;
600
601 for (const auto& elt: *this->sys_) {
602 inst = elt.second;
603
604 if (!_reducedInstances_.exists(inst)) {
605 // Checking if it's not an empty class
606 if (inst->size()) {
608
609 try {
610 data = _cdata_map_[&(inst->type())];
611 } catch (NotFound const&) {
613 _cdata_map_.insert(&(inst->type()), data);
614 }
615
616 data->instances.insert(inst);
617 // Filling up the partial ordering
618 List< NodeSet > partial_order;
619
620 if (data->inners().size()) partial_order.push_back(data->inners());
621
622 if (data->aggregators().size())
623 for (const auto agg: data->aggregators())
624 partial_order[0].insert(agg);
625
626 if (data->outputs().size()) partial_order.push_back(data->outputs());
627
628 if (_query_.first == inst) {
629 // First case, the instance contains the query
630 partial_order[0].erase(_query_.second->id());
631
632 if (partial_order[0].empty()) partial_order.erase(0);
633
634 if (partial_order.size() > 1) {
635 partial_order[1].erase(_query_.second->id());
636
637 if (partial_order[1].empty()) partial_order.erase(1);
638 }
639
640 NodeSet query_set;
641 query_set.insert(_query_.second->id());
642 partial_order.insert(query_set);
643
644 // Adding the tensors
645 for (auto attr = inst->begin(); attr != inst->end(); ++attr)
646 pool.insert(&(const_cast< Tensor< GUM_SCALAR >& >((*(attr.val())).cpf())));
647
648 // Adding evidences if any
649 if (this->hasEvidence(inst))
650 for (const auto& elt: this->evidence(inst))
651 pool.insert(const_cast< Tensor< GUM_SCALAR >* >(elt.second));
652
653 PartialOrderedTriangulation t(&(data->moral_graph), &(data->mods), &(partial_order));
654 const std::vector< NodeId >& v = t.eliminationOrder();
655
656 if (partial_order.size() > 1)
657 for (size_t idx = 0; idx < partial_order[0].size(); ++idx)
658 eliminateNode(&(inst->get(v[idx]).type().variable()), pool, _trash_);
659 } else if (this->hasEvidence(inst)) {
660 // Second case, the instance has evidences
661 // Adding the tensors
662 for (const auto& elt: *inst)
663 pool.insert(&const_cast< Tensor< GUM_SCALAR >& >(elt.second->cpf()));
664
665 // Adding evidences
666 for (const auto& elt: this->evidence(inst))
667 pool.insert(const_cast< Tensor< GUM_SCALAR >* >(elt.second));
668
669 PartialOrderedTriangulation t(&(data->moral_graph), &(data->mods), &(partial_order));
670
671 for (size_t idx = 0; idx < partial_order[0].size(); ++idx)
672 eliminateNode(&(inst->get(t.eliminationOrder()[idx]).type().variable()),
673 pool,
674 _trash_);
675 } else {
676 // Last case, the instance contains neither evidences nor
677 // the query
678 // We translate the class level tensors into the instance ones
679 // and
680 // proceed with elimination
681 for (const auto srcPot: data->pool) {
682 pot = copyTensor(inst->bijection(), *srcPot);
683 pool.insert(pot);
684 _trash_.insert(pot);
685 }
686
687 for (const auto agg: data->c.aggregates())
688 pool.insert(&(const_cast< Tensor< GUM_SCALAR >& >(inst->get(agg->id()).cpf())));
689
690 // We eliminate inner aggregators with their parents if necessary
691 // (see
692 // CData constructor)
693 Size size = data->inners().size() + data->aggregators().size();
694
695 for (size_t idx = data->inners().size(); idx < size; ++idx)
696 eliminateNode(&(inst->get(data->elim_order()[idx]).type().variable()),
697 pool,
698 _trash_);
699 }
700
701 for (const auto pot: pool)
702 rg_data.pool.insert(pot);
703 }
704 }
705 }
706 }
707
708 template < typename GUM_SCALAR >
711 // We first add edges between variables already in pool (i.e. those of the
712 // reduced instances)
713 NodeId id_1, id_2;
714
715 for (const auto pot: data.pool) {
716 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
717
718 for (Size var_1 = 0; var_1 < vars.size(); ++var_1) {
719 if (data.var2node.existsFirst(vars.atPos(var_1))) {
720 id_1 = data.var2node.second(vars.atPos(var_1));
721 } else {
722 id_1 = data.reducedGraph.addNode();
723 data.var2node.insert(vars.atPos(var_1), id_1);
724 data.mods.insert(id_1, vars.atPos(var_1)->domainSize());
725 data.outputs().insert(id_1);
726 }
727
728 for (Size var_2 = var_1 + 1; var_2 < vars.size(); ++var_2) {
729 if (data.var2node.existsFirst(vars.atPos(var_2))) {
730 id_2 = data.var2node.second(vars.atPos(var_2));
731 } else {
732 id_2 = data.reducedGraph.addNode();
733 data.var2node.insert(vars.atPos(var_2), id_2);
734 data.mods.insert(id_2, vars.atPos(var_2)->domainSize());
735 data.outputs().insert(id_2);
736 }
737
738 try {
739 data.reducedGraph.addEdge(id_1, id_2);
740 } catch (DuplicateElement const&) {}
741 }
742 }
743 }
744
745 // Adding tensors obtained from reduced patterns
746 for (const auto& elt: _elim_map_) {
747 // We add edges between variables in the same reduced patterns
748 for (const auto pot: *elt.second) {
749 data.pool.insert(pot);
750 const Sequence< const DiscreteVariable* >& vars = pot->variablesSequence();
751
752 for (Size var_1 = 0; var_1 < vars.size(); ++var_1) {
753 if (data.var2node.existsFirst(vars.atPos(var_1))) {
754 id_1 = data.var2node.second(vars.atPos(var_1));
755 } else {
756 id_1 = data.reducedGraph.addNode();
757 data.var2node.insert(vars.atPos(var_1), id_1);
758 data.mods.insert(id_1, vars.atPos(var_1)->domainSize());
759 data.outputs().insert(id_1);
760 }
761
762 for (Size var_2 = var_1 + 1; var_2 < vars.size(); ++var_2) {
763 if (data.var2node.existsFirst(vars.atPos(var_2))) {
764 id_2 = data.var2node.second(vars.atPos(var_2));
765 } else {
766 id_2 = data.reducedGraph.addNode();
767 data.var2node.insert(vars.atPos(var_2), id_2);
768 data.mods.insert(id_2, vars.atPos(var_2)->domainSize());
769 data.outputs().insert(id_2);
770 }
771
772 try {
773 data.reducedGraph.addEdge(id_1, id_2);
774 } catch (DuplicateElement const&) {}
775 }
776 }
777 }
778 }
779 }
780
781 template < typename GUM_SCALAR >
787
788 template < typename GUM_SCALAR >
790 const gspan::Pattern& p,
792 pattern(p), matches(m), _real_order_(0) {
794
795 for (int i = 0; i < 4; ++i)
796 _partial_order_.push_front(NodeSet());
797 }
798
799 template < typename GUM_SCALAR >
801 const typename StructuredInference< GUM_SCALAR >::PData& source) :
802 pattern(source.pattern), matches(source.matches), graph(source.graph), mod(source.mod),
803 node2attr(source.node2attr), vars(source.vars), _partial_order_(source._partial_order_),
804 _real_order_(0) {
806 }
807
808 template < typename GUM_SCALAR >
810 if (!_real_order_) {
812
813 for (const auto& set: _partial_order_)
814 if (set.size() > 0) _real_order_->insert(set);
815 }
816
817 return _real_order_;
818 }
819
820 template < typename GUM_SCALAR >
822 c(a_class), _elim_order_(0) {
824
825 // First step we add Attributes and Aggregators
826 for (const auto node: c.containerDag().nodes()) {
827 switch (c.get(node).elt_type()) {
829 pool.insert(&(const_cast< Tensor< GUM_SCALAR >& >(c.get(node).cpf())));
830 // break omitted : We want to execute the next block
831 // for attributes
832 }
833
835 moral_graph.addNodeWithId(node);
836 mods.insert(node, c.get(node).type()->domainSize());
837 break;
838 }
839
840 default : { /* do nothing */
841 }
842 }
843 }
844
845 // Second, we add edges, moralise the graph and build the partial ordering
846 for (const auto node: moral_graph.nodes()) {
847 const auto& parents = c.containerDag().parents(node);
848
849 // Adding edges and marrying parents
850 for (auto tail = parents.begin(); tail != parents.end(); ++tail) {
853 moral_graph.addEdge(*tail, node);
854 NodeSet::const_iterator marry = tail;
855 ++marry;
856
857 while (marry != parents.end()) {
860 moral_graph.addEdge(*tail, *marry);
861
862 ++marry;
863 }
864 }
865 }
866
867 // Adding nodes to the partial ordering
868 switch (c.get(node).elt_type()) {
870 if (c.isOutputNode(c.get(node))) outputs().insert(node);
871 else aggregators().insert(node);
872
873 // If the aggregator is not an output and has parents which are
874 // not outputs, we must eliminate the parents after adding the
875 // aggregator's CPT
876 for (const auto par: c.containerDag().parents(node)) {
877 const auto& prnt = c.get(par);
878
879 if ((!c.isOutputNode(prnt))
882 inners().erase(prnt.id());
883 aggregators().insert(prnt.id());
884 }
885 }
886
887 break;
888 }
889
891 pool.insert(const_cast< Tensor< GUM_SCALAR >* >(&(c.get(node).cpf())));
892
893 if (c.isOutputNode(c.get(node))) outputs().insert(node);
894 else if (!aggregators().exists(node)) inners().insert(node);
895
896 break;
897 }
898
899 default : { /* Do nothing */
900 }
901 }
902 }
903
904 if (inners().size()) partial_order.insert(inners());
905
906 if (aggregators().size()) partial_order.insert(aggregators());
907
908 if (outputs().size()) partial_order.insert(outputs());
909
910 GUM_ASSERT(partial_order.size());
913
914 for (size_t i = 0; i < inners().size(); ++i)
915 eliminateNode(&(c.get(_elim_order_[i]).type().variable()), pool, _trash_);
916 }
917
918 template < typename GUM_SCALAR >
921
922 for (const auto pot: _trash_)
923 delete pot;
924 }
925
926 template < typename GUM_SCALAR >
928 const PRMInstance< GUM_SCALAR >* i = (this->sys_->begin()).val();
929 _query_ = std::make_pair(i, i->begin().val());
930 _found_query_ = false;
932 _buildReduceGraph_(data);
933 }
934
935 template < typename GUM_SCALAR >
939
940 template < typename GUM_SCALAR >
941 INLINE std::string
946
947 template < typename GUM_SCALAR >
948 INLINE std::string
953
954 template < typename GUM_SCALAR >
955 INLINE std::string
957 const PRMSlotChain< GUM_SCALAR >& a) const {
958 return i->name() + _dot_ + a.lastElt().safeName();
959 }
960
961 template < typename GUM_SCALAR >
965
966 template < typename GUM_SCALAR >
970
// Returns the constant string "StructuredInference".
971 template < typename GUM_SCALAR >
972 INLINE std::string StructuredInference< GUM_SCALAR >::name() const {
973 return "StructuredInference";
974 }
975
976 template < typename GUM_SCALAR >
980
981 template < typename GUM_SCALAR >
983 return *_gspan_;
984 }
985
986 template < typename GUM_SCALAR >
989 NodeId id,
990 Set< Tensor< GUM_SCALAR >* >& pool) {
991 data.graph.eraseNode(id);
992 GUM_ASSERT(!data.graph.exists(id));
993 data.mod.erase(id);
994 GUM_ASSERT(!data.mod.exists(id));
995 data.node2attr.eraseFirst(id);
996 GUM_ASSERT(!data.node2attr.existsFirst(id));
997 data.map.erase(id);
998 GUM_ASSERT(!data.map.exists(id));
999 data.vars.eraseFirst(id);
1000 GUM_ASSERT(!data.vars.existsFirst(id));
1001 data.inners().erase(id);
1002 GUM_ASSERT(!data.inners().exists(id));
1003 pool.erase(data.pots[id]);
1004 GUM_ASSERT(!pool.exists(data.pots[id]));
1005 data.pots.erase(id);
1006 GUM_ASSERT(!data.pots.exists(id));
1007 }
1008
1009 } /* namespace prm */
1010} /* namespace gum */
bool existsFirst(const T1 &first) const
Returns true if first is the first element in a pair in the gum::Bijection.
const T1 & first(const T2 &second) const
Returns the first value of a pair given its second value.
Set of pairs of elements with fast search for both elements.
Definition bijection.h:1594
Base class for discrete random variable.
Exception : a similar element already exists.
const NodeSet & neighbours(NodeId id) const
returns the set of node neighbours to a given node
Base class for all aGrUM's exceptions.
Definition exceptions.h:118
Exception : fatal (unknown ?) error.
Class for assigning/browsing values to tuples of discrete variables.
bool end() const
Returns true if the Instantiation reached the end.
void inc()
Operator increment.
void setFirst()
Assign the first values to the tuple of the Instantiation.
Generic doubly linked lists.
Definition list.h:379
Size size() const noexcept
Returns the number of elements in the list.
Definition list_tpl.h:1719
Val & push_back(Args &&... args)
An alias for pushBack used for STL compliance.
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Definition list_tpl.h:1515
void erase(Size i)
Erases the ith element of the List (the first one is in position 0).
Definition list_tpl.h:1772
A class to combine efficiently several MultiDim tables.
TABLE * execute(const Set< const TABLE * > &set) const final
Creates and returns the result of the combination of the tables within set.
bool exists(const NodeId id) const
alias for existsNode
virtual NodeId addNode()
insert a new node and return its id
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
class for graph triangulations for which we enforce a given partial ordering on the nodes elimination...
const Key & atPos(Idx i) const
Returns the object at the pos i.
The generic class for storing (ordered) sequences of objects.
Definition sequence.h:972
Representation of a set.
Definition set.h:131
iterator begin() const
The usual unsafe begin iterator to parse the set.
Definition set_tpl.h:438
SetIterator< Sequence< PRMInstance< GUM_SCALAR > * > * > const_iterator
Definition set.h:143
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
bool exists(const Key &k) const
Indicates whether a given elements belong to the set.
Definition set_tpl.h:533
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
void erase(const Key &k)
Erases an element from the set.
Definition set_tpl.h:582
const std::vector< NodeId > & eliminationOrder()
returns an elimination ordering compatible with the triangulated graph
void addEdge(NodeId first, NodeId second) override
insert a new edge into the undirected graph
void eraseNode(NodeId id) override
remove a node and its adjacent edges from the graph
This class discovers pattern in a PRM<GUM_SCALAR>'s PRMSystem<GUM_SCALAR> to speed up structured infe...
Definition gspan.h:86
Set< Sequence< PRMInstance< GUM_SCALAR > * > * > MatchedInstances
Code alias.
Definition gspan.h:185
PRMAttribute is a member of a Class in a PRM.
static INLINE bool isAggregate(const PRMClassElement< GUM_SCALAR > &elt)
Return true if obj is of type PRMAggregate.
const std::string & safeName() const
Returns the safe name of this PRMClassElement, if any.
static INLINE bool isAttribute(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if obj_ptr is of type PRMAttribute.
A PRMClass is an object of a PRM representing a fragment of a Bayesian network which can be instantia...
Definition PRMClass.h:75
std::pair< const PRMInstance< GUM_SCALAR > *, const PRMAttribute< GUM_SCALAR > * > Chain
Code alias.
EMap & evidence(const PRMInstance< GUM_SCALAR > &i)
Returns EMap of evidences over i.
PRMInference(const PRM< GUM_SCALAR > &prm, const PRMSystem< GUM_SCALAR > &system)
Default constructor.
PRMSystem< GUM_SCALAR > const * sys_
The Model on which inference is done.
PRM< GUM_SCALAR > const * prm_
The PRM<GUM_SCALAR> on which inference is done.
bool hasEvidence(const PRMInstance< GUM_SCALAR > &i) const
Returns true if i has evidence.
A PRMInstance is a Bayesian network fragment defined by a Class and used in a PRMSystem.
Definition PRMInstance.h:79
const iterator & end()
Returns a reference over the iterator at the end of the list of gum::prm::PRMAttribute<GUM_SCALAR> in...
const Bijection< const DiscreteVariable *, const DiscreteVariable * > & bijection() const
Returns a mapping between DiscreteVariable used in this and the ones used in this PRMInstance<GUM_SCA...
std::vector< std::pair< PRMInstance< GUM_SCALAR > *, std::string > > & getRefAttr(NodeId id)
Returns a vector of pairs of referring attributes of id.
PRMClass< GUM_SCALAR > & type()
Returns the type of this instance.
PRMAttribute< GUM_SCALAR > & get(NodeId id)
Getter on a PRMAttribute<GUM_SCALAR> of this PRMInstance<GUM_SCALAR>.
iterator begin()
Returns an iterator at the beginning of the list of gum::prm::PRMAttribute<GUM_SCALAR> in this PRMInst...
Size size() const
Returns the number of attributes in this PRMInstance<GUM_SCALAR>.
const std::string & name() const
Returns the name of this object.
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
PRMClassElement< GUM_SCALAR > & lastElt()
Returns the last element of the slot chain, typically this is a gum::PRMAttribute or a gum::PRMAggre...
A PRMSystem is a container of PRMInstance and describes a relational skeleton.
Definition PRMSystem.h:70
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
Definition PRM.h:74
void searchPatterns()
Search for patterns without doing any computations.
Set< Tensor< GUM_SCALAR > * > * _eliminateObservedNodes_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match, const std::vector< NodeId > &elim_order)
Add in data.queries() any queried variable in one of data.pattern matches.
HashTable< const PRMClass< GUM_SCALAR > *, CData * > _cdata_map_
Mapping between a Class<GUM_SCALAR> and data about instances reduced using only Class<GUM_SCALAR> lev...
void _reducePattern_(const gspan::Pattern *p)
Proceed with the elimination of all inner variables (observed or not) of all usable matches of Patter...
std::string _dot_
Unreduce the match containing the query.
void setPatternMining(bool b)
Tells this algorithm to use pattern mining or not.
void _buildReduceGraph_(RGData &data)
This calls reducePattern() over each pattern and then builds the reduced graph which is used for infer...
void _removeBarrenNodes_(typename StructuredInference::PData &data, Set< Tensor< GUM_SCALAR > * > &pool)
HashTable< const Sequence< PRMInstance< GUM_SCALAR > * > *, Set< Tensor< GUM_SCALAR > * > * > _elim_map_
Mapping between a Pattern's match and its tensor pool after inner variables were eliminated.
HashTable< const PRMClass< GUM_SCALAR > *, std::vector< NodeId > * > _outputs_
PRMInference< GUM_SCALAR >::Chain _query_
The query.
virtual void posterior_(const typename PRMInference< GUM_SCALAR >::Chain &chain, Tensor< GUM_SCALAR > &m)
See PRMInference::posterior_().
void _removeNode_(typename StructuredInference::PData &data, NodeId id, Set< Tensor< GUM_SCALAR > * > &pool)
void _buildPatternGraph_(PData &data, Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
Build the DAG corresponding to Pattern data.pattern, initialize pool with all the Tensors of all vari...
PData * _pdata_
The data of the pattern one of whose matches contains the query.
Set< Tensor< GUM_SCALAR > * > * _translatePotSet_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
Translate a given Tensor Set into one w.r.t. variables in match.
void _insertNodeInElimLists_(typename StructuredInference::PData &data, const Sequence< PRMInstance< GUM_SCALAR > * > &match, PRMInstance< GUM_SCALAR > *inst, PRMAttribute< GUM_SCALAR > *attr, NodeId id, std::pair< Idx, std::string > &v)
std::string _str_(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
StructuredInference & operator=(const StructuredInference &source)
Copy operator.
void _addEdgesInReducedGraph_(RGData &data)
Add the nodes in the reduced graph.
bool _mining_
Flag which tells to use pattern mining or not.
GSpan< GUM_SCALAR > * _gspan_
Pointer over the GSpan<GUM_SCALAR> instance used by this class.
GSpan< GUM_SCALAR > & gspan()
Returns the instance of gspan used to search patterns.
bool _found_query_
Flag with an explicit name.
virtual void evidenceAdded_(const typename PRMInference< GUM_SCALAR >::Chain &chain)
See PRMInference::evidenceAdded_().
bool _allInstanceNoRefAttr_(typename StructuredInference::PData &data, std::pair< Idx, std::string > attr)
virtual void joint_(const std::vector< typename PRMInference< GUM_SCALAR >::Chain > &queries, Tensor< GUM_SCALAR > &j)
See PRMInference::joint_().
virtual std::string name() const
Tells this algorithm to use pattern mining or not.
virtual void evidenceRemoved_(const typename PRMInference< GUM_SCALAR >::Chain &chain)
See PRMInference::evidenceRemoved_().
Set< const PRMInstance< GUM_SCALAR > * > _reducedInstances_
This keeps track of reduced instances.
StructuredInference(const PRM< GUM_SCALAR > &prm, const PRMSystem< GUM_SCALAR > &system, gspan::SearchStrategy< GUM_SCALAR > *strategy=0)
Default constructor.
Set< Tensor< GUM_SCALAR > * > * _eliminateObservedNodesInSource_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match, const std::vector< NodeId > &elim_order)
Set< Tensor< GUM_SCALAR > * > _trash_
Keeping track of created tensors to delete them after inference.
void _reduceAloneInstances_(RGData &data)
Add the reduced tensors of instances not in any used patterns.
std::pair< Idx, std::string > _query_data_
This contains all the information we want for a node in a DFSTree.
Definition pattern.h:90
This is an abstract class used to tune search strategies in the gspan algorithm.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
namespace for all probabilistic relational models entities
Definition agrum.h:68
Tensor< GUM_SCALAR > multTensor(const Tensor< GUM_SCALAR > &t1, const Tensor< GUM_SCALAR > &t2)
void eliminateNode(const DiscreteVariable *var, Set< Tensor< GUM_SCALAR > * > &pool, Set< Tensor< GUM_SCALAR > * > &trash)
Proceeds with the elimination of var in pool.
Tensor< GUM_SCALAR > * copyTensor(const Bijection< const DiscreteVariable *, const DiscreteVariable * > &bij, const Tensor< GUM_SCALAR > &source)
Returns a copy of a Tensor after applying a bijection over the variables in source.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Private structure to represent data about a Class<GUM_SCALAR>.
NodeSet & aggregators()
Returns the set of aggregators and their parents.
Set< Tensor< GUM_SCALAR > * > _trash_
List< NodeSet > partial_order
The partial order used for variable elimination.
Set< const PRMInstance< GUM_SCALAR > * > instances
The Set of Instances reduced at class level.
NodeSet & outputs()
Returns the set of outputs nodes.
UndiGraph moral_graph
The class moral graph. NodeId matches those in c.
CData(const PRMClass< GUM_SCALAR > &c)
Default constructor.
std::vector< NodeId > & elim_order()
The elimination order for nodes of this class.
Set< Tensor< GUM_SCALAR > * > pool
The tensor pool obtained by C elimination of inner nodes.
NodeSet & inners()
Returns the set of inner nodes.
const PRMClass< GUM_SCALAR > & c
The class this data is about.
NodeProperty< Size > mods
The class variables modalities.
Private structure to represent data about a pattern.
List< NodeSet > _partial_order_
We'll use a PartialOrderedTriangulation with three sets: output, nodes and obs with children outside ...
NodeSet & obs()
Returns the set of inner and observed nodes given all the matches of pattern.
NodeSet & queries()
Returns the set of queried nodes given all the matches of pattern.
Set< NodeId > barren
Set of barren nodes.
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track between graph and attributes, it is of the form instance_name DOT attr...
NodeProperty< std::pair< Idx, std::string > > map
To ease translating tensors from one match to another.
UndiGraph graph
A yet to be triangulated undigraph.
NodeSet & outputs()
Returns the set of outputs nodes given all the matches of pattern.
const gspan::Pattern & pattern
The pattern this data is about.
NodeSet & inners()
Returns the set of inner nodes.
NodeProperty< Tensor< GUM_SCALAR > * > pots
To handle barren nodes.
Bijection< NodeId, const DiscreteVariable * > vars
Bijection between graph's nodes and their corresponding DiscreteVariable, for inference purpose.
List< NodeSet > * _real_order_
A copy of partial_order without empty sets.
PData(const gspan::Pattern &p, typename GSpan< GUM_SCALAR >::MatchedInstances &m)
Default constructor.
NodeProperty< Size > mod
The pattern's variables modalities.
GSpan< GUM_SCALAR >::MatchedInstances & matches
A reference over the usable matches of pattern.
Private structure to represent data about a reduced graph.
List< NodeSet > partial_order
Partial order used for triangulation, first is outputs nodes, second query nodes.
Set< Tensor< GUM_SCALAR > * > pool
The pool of tensors matching the reduced graph.
Bijection< const DiscreteVariable *, NodeId > var2node
Mapping between DiscreteVariable and NodeId.
NodeSet & queries()
Returns the set of query nodes (which will not be eliminated).
UndiGraph reducedGraph
The reduced graph.
NodeSet & outputs()
Returns the set of outputs nodes (which will be eliminated).
NodeProperty< Size > mods
Mapping between NodeId and modalities.
Headers of StructuredInference.