aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
PRMInference_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
51
52namespace gum {
53 namespace prm {
54
55 template < typename GUM_SCALAR >
57 for (const auto& elt: _evidences_) {
58 for (const auto& elt2: *elt.second)
59 delete elt2.second;
60
61 delete elt.second;
62 }
63
64 _evidences_.clear();
65 }
66
67 template < typename GUM_SCALAR >
69 prm_(source.prm_), sys_(source.sys_) {
70 GUM_CONS_CPY(PRMInference);
71
72 for (const auto& elt: source._evidences_) {
73 _evidences_.insert(elt.first, new PRMInference< GUM_SCALAR >::EMap());
74
75 for (const auto& elt2: *elt.second) {
76 Tensor< GUM_SCALAR >* e = new Tensor< GUM_SCALAR >();
77 e->add(*(elt2.second->variablesSequence().front()));
78 Instantiation i(*e);
79
80 for (i.setFirst(); !i.end(); i.inc())
81 e->set(i, elt2.second->get(i));
82
83 _evidences_[elt.first]->insert(elt2.first, e);
84 }
85 }
86 }
87
88 template < typename GUM_SCALAR >
92 prm_ = source.prm_;
93 sys_ = source.sys_;
94
95 for (const auto& elt: source._evidences_) {
96 _evidences_.insert(elt.first, new PRMInference< GUM_SCALAR >::EMap());
97
98 for (const auto& elt2: *elt.second) {
99 Tensor< GUM_SCALAR >* e = new Tensor< GUM_SCALAR >();
100 e->add(*(elt2.second->variablesSequence().front()));
101 Instantiation i(*e);
102
103 for (i.setFirst(); !i.end(); i.inc()) {
104 e->set(i, elt2.second->get(i));
105 }
106
107 _evidences_[elt.first]->insert(elt2.first, e);
108 }
109 }
110
111 return *this;
112 }
113
114 template < typename GUM_SCALAR >
117 if (_evidences_.exists(i)) {
118 return *(_evidences_[i]);
119 } else {
121 return *(_evidences_[i]);
122 }
123 }
124
125 template < typename GUM_SCALAR >
127 const Tensor< GUM_SCALAR >& p) {
128 if (chain.first->exists(chain.second->id())) {
129 if ((p.nbrDim() != 1) || (!p.contains(chain.second->type().variable())))
130 GUM_ERROR(OperationNotAllowed, "illegal evidence for the given PRMAttribute.")
131
132 Tensor< GUM_SCALAR >* e = new Tensor< GUM_SCALAR >();
133 e->add(chain.second->type().variable());
134 Instantiation i(*e);
135
136 for (i.setFirst(); !i.end(); i.inc())
137 e->set(i, p.get(i));
138
139 PRMInference< GUM_SCALAR >::EMap& emap = _EMap_(chain.first);
141 if (emap.exists(chain.second->id())) {
142 delete emap[chain.second->id()];
143 emap[chain.second->id()] = e;
144 } else {
145 emap.insert(chain.second->id(), e);
146 }
147
148 evidenceAdded_(chain);
149 } else {
151 "the given PRMAttribute does not belong to this "
152 "Instance<GUM_SCALAR>.");
153 }
154 }
155
156 template < typename GUM_SCALAR >
158 const PRMSystem< GUM_SCALAR >& system) :
159 prm_(&prm), sys_(&system) {
160 GUM_CONSTRUCTOR(PRMInference);
161 }
162
163 template < typename GUM_SCALAR >
165 GUM_DESTRUCTOR(PRMInference);
167 }
168
169 template < typename GUM_SCALAR >
170 INLINE typename PRMInference< GUM_SCALAR >::EMap&
172 try {
173 return *(_evidences_[&i]);
174 } catch (NotFound const&) { GUM_ERROR(NotFound, "this instance has no evidence.") }
175 }
176
177 template < typename GUM_SCALAR >
178 INLINE const typename PRMInference< GUM_SCALAR >::EMap&
180 try {
181 return *(_evidences_[&i]);
182 } catch (NotFound const&) { GUM_ERROR(NotFound, "this instance has no evidence.") }
183 }
184
185 template < typename GUM_SCALAR >
186 INLINE typename PRMInference< GUM_SCALAR >::EMap&
188 try {
189 return *(_evidences_[i]);
190 } catch (NotFound const&) { GUM_ERROR(NotFound, "this instance has no evidence.") }
191 }
193 template < typename GUM_SCALAR >
194 INLINE const typename PRMInference< GUM_SCALAR >::EMap&
196 try {
197 return *(_evidences_[i]);
198 } catch (NotFound const&) { GUM_ERROR(NotFound, "this instance has no evidence.") }
199 }
200
201 template < typename GUM_SCALAR >
203 return _evidences_.exists(&i);
204 }
205
206 template < typename GUM_SCALAR >
208 return _evidences_.exists(i);
209 }
210
211 template < typename GUM_SCALAR >
212 INLINE bool PRMInference< GUM_SCALAR >::hasEvidence(const Chain& chain) const {
213 return (hasEvidence(chain.first)) ? evidence(chain.first).exists(chain.second->id()) : false;
214 }
215
216 template < typename GUM_SCALAR >
218 return (_evidences_.size() != (Size)0);
219 }
220
221 template < typename GUM_SCALAR >
223 try {
224 if (_EMap_(chain.first).exists(chain.second->id())) {
225 evidenceRemoved_(chain);
226 delete _EMap_(chain.first)[chain.second->id()];
227 _EMap_(chain.first).erase(chain.second->id());
228 }
229 } catch (NotFound const&) {
230 // Ok, we are only removing
231 }
232 }
233
234 template < typename GUM_SCALAR >
236 const typename PRMInference< GUM_SCALAR >::Chain& chain,
237 Tensor< GUM_SCALAR >& m) {
238 if (m.nbrDim() > 0) { GUM_ERROR(OperationNotAllowed, "the given Tensor is not empty.") }
239
240 if (hasEvidence(chain)) {
241 m.add(chain.second->type().variable());
242 const Tensor< GUM_SCALAR >& e = *(evidence(chain.first)[chain.second->id()]);
243 Instantiation i(m), j(e);
244
245 for (i.setFirst(), j.setFirst(); !i.end(); i.inc(), j.inc())
246 m.set(i, e.get(j));
247 } else {
248 if (chain.second != &(chain.first->get(chain.second->safeName()))) {
249 typename PRMInference< GUM_SCALAR >::Chain good_chain
250 = std::make_pair(chain.first, &(chain.first->get(chain.second->safeName())));
251 m.add(good_chain.second->type().variable());
252 posterior_(good_chain, m);
253 } else {
254 m.add(chain.second->type().variable());
255 posterior_(chain, m);
256 }
257 }
258 }
259
260 template < typename GUM_SCALAR >
262 const std::vector< typename PRMInference< GUM_SCALAR >::Chain >& chains,
263 Tensor< GUM_SCALAR >& j) {
264 if (j.nbrDim() > 0) { GUM_ERROR(OperationNotAllowed, "the given Tensor is not empty.") }
265
266 for (auto chain = chains.begin(); chain != chains.end(); ++chain) {
267 j.add(chain->second->type().variable());
268 }
269
270 joint_(chains, j);
271 }
272
273 } /* namespace prm */
274} /* namespace gum */
Headers of PRMInference.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
Class for assigning/browsing values to tuples of discrete variables.
bool end() const
Returns true if the Instantiation reached the end.
void inc()
Operator increment.
void setFirst()
Assign the first values to the tuple of the Instantiation.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
std::pair< const PRMInstance< GUM_SCALAR > *, const PRMAttribute< GUM_SCALAR > * > Chain
Code alias.
void addEvidence(const Chain &chain, const Tensor< GUM_SCALAR > &p)
Add an evidence to the given instance's elt.
virtual ~PRMInference()
Destructor.
EMap & evidence(const PRMInstance< GUM_SCALAR > &i)
Returns EMap of evidences over i.
virtual void joint_(const std::vector< Chain > &queries, Tensor< GUM_SCALAR > &j)=0
Generic method to compute the joint probability of the given elements.
void posterior(const Chain &chain, Tensor< GUM_SCALAR > &m)
Compute the posterior of the formal attribute pointed to by chain and store it in m.
HashTable< const PRMInstance< GUM_SCALAR > *, EMap * > _evidences_
Mapping of evidence over PRMInstance<GUM_SCALAR>'s nodes.
PRMInference(const PRM< GUM_SCALAR > &prm, const PRMSystem< GUM_SCALAR > &system)
Default constructor.
EMap & _EMap_(const PRMInstance< GUM_SCALAR > *i)
Private getter over evidences, if necessary creates an EMap for i.
void joint(const std::vector< Chain > &chains, Tensor< GUM_SCALAR > &j)
Compute the joint probability of the formal attributes pointed to by chains and store it in j.
void removeEvidence(const Chain &chain)
Remove evidence on the given instance's elt.
PRMSystem< GUM_SCALAR > const * sys_
The PRMSystem on which inference is done.
PRMInference & operator=(const PRMInference &source)
Copy operator.
NodeProperty< const Tensor< GUM_SCALAR > * > EMap
Code alias.
PRM< GUM_SCALAR > const * prm_
The PRM<GUM_SCALAR> on which inference is done.
virtual void posterior_(const Chain &chain, Tensor< double > &m)=0
bool hasEvidence() const
Returns true if evidence is registered for at least one PRMInstance.
virtual void evidenceRemoved_(const Chain &chain)=0
This method is called whenever an evidence is removed, but BEFORE any processing made by PRMInference...
bool hasEvidence(const PRMInstance< GUM_SCALAR > &i) const
Returns true if i has evidence.
void clearEvidence()
Remove all evidences.
A PRMInstance is a Bayesian network fragment defined by a Class and used in a PRMSystem.
Definition PRMInstance.h:79
const iterator & end()
Returns a reference over the iterator at the end of the list of gum::prm::PRMAttribute<GUM_SCALAR> in...
A PRMSystem is a container of PRMInstance and describe a relational skeleton.
Definition PRMSystem.h:70
This class represents a Probabilistic Relational Model (PRM).
Definition PRM.h:74
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
namespace for all probabilistic relational models entities
Definition agrum.h:68
gum is the global namespace for all aGrUM entities
Definition agrum.h:46