aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
jointTargetedMRFInference_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
51
52namespace gum {
53
54
55 // Default Constructor
56 template < typename GUM_SCALAR >
59 // assign a MRF if this has not been done before (due to virtual inheritance)
61 GUM_CONSTRUCTOR(JointTargetedMRFInference);
62 }
63
64 // Destructor
65 template < typename GUM_SCALAR >
69
70 // assigns a new MRF to the inference engine
71 template < typename GUM_SCALAR >
77
78 // ##############################################################################
79 // Targets
80 // ##############################################################################
81
82 // return true if target is a nodeset target.
83 template < typename GUM_SCALAR >
85 if (this->hasNoModel_())
87 "No Markov net has been assigned to the "
88 "inference algorithm");
89
90 const auto& gra = this->MRF().graph();
91 for (const auto var: vars) {
92 if (!gra.exists(var)) {
93 GUM_ERROR(UndefinedElement, var << " is not a NodeId in the Markov random field")
94 }
95 }
96
97 return _joint_targets_.contains(vars);
98 }
99
100 // Clear all previously defined single targets
101 template < typename GUM_SCALAR >
105
106 // Clear all previously defined targets (single targets and sets of targets)
107 template < typename GUM_SCALAR >
109 if (_joint_targets_.size() > 0) {
110 // we already are in target mode. So no this->setTargetedMode_(); is needed
112 _joint_targets_.clear();
114 }
115 }
116
117 // Clear all previously defined targets (single and joint targets)
118 template < typename GUM_SCALAR >
123
124 // Add a set of nodes as a new target
125 template < typename GUM_SCALAR >
127 // check if the nodes in the target belong to the Markov random field
128 if (this->hasNoModel_())
130 "No Markov net has been assigned to the "
131 "inference algorithm");
132
133 const auto& dag = this->MRF().graph();
134 for (const auto node: joint_target) {
135 if (!dag.exists(node)) {
137 "at least one one in " << joint_target << " does not belong to the mn");
138 }
139 }
140
141 if (isExactJointComputable_(joint_target)) return;
142 if (!superForJointComputable_(joint_target).empty()) return;
143
144 // check if joint_target is a subset of an already existing target
145 for (const auto& target: _joint_targets_) {
146 if (target.isStrictSupersetOf(joint_target)) return;
147 }
148
149 // check if joint_target is a superset of an already existing target:
150 // in this case, we need to remove the old existing target
151 for (auto iter = _joint_targets_.beginSafe(); iter != _joint_targets_.endSafe(); ++iter) {
152 if (iter->isStrictSubsetOf(joint_target)) eraseJointTarget(*iter);
153 }
154
155 this->setTargetedMode_(); // does nothing if already in targeted mode
156 _joint_targets_.insert(joint_target);
157 onJointTargetAdded_(joint_target);
159 }
160
161 // removes an existing set target
162 template < typename GUM_SCALAR >
164 // check if the nodes in the target belong to the Markov random field
165 if (this->hasNoModel_())
167 "No Markov net has been assigned to the "
168 "inference algorithm");
169
170 const auto& dag = this->MRF().graph();
171 for (const auto node: joint_target) {
172 if (!dag.exists(node)) {
174 "at least one one in " << joint_target << " does not belong to the mn");
175 }
176 }
177
178 // erase joint_target only if it belongs to the set of joint targets
179 if (_joint_targets_.contains(joint_target)) {
180 // note that we have to be in target mode when we are here
181 // so, no this->setTargetedMode_(); is necessary
182 onJointTargetErased_(joint_target);
183 _joint_targets_.erase(joint_target);
185 }
186 }
187
189 template < typename GUM_SCALAR >
190 INLINE const Set< NodeSet >&
194
196 template < typename GUM_SCALAR >
198 return _joint_targets_.size();
199 }
200
201 // ##############################################################################
202 // Inference
203 // ##############################################################################
204
205 // Compute the posterior of a nodeset.
206 template < typename GUM_SCALAR >
207 const Tensor< GUM_SCALAR >&
209 NodeSet real_nodes;
210 for (const auto& node: nodes) {
211 if (!this->hasHardEvidence(node)) { real_nodes.insert(node); }
212 }
213 // try to get the smallest set of targets that contains "nodes"
214 bool found_exact_target = false;
215 NodeSet super_target;
216
217 if (isExactJointComputable_(real_nodes)) {
218 found_exact_target = true;
219 } else {
220 super_target = superForJointComputable_(real_nodes);
221 if (super_target.empty()) {
223 "No joint target containing " << real_nodes << " could be found among "
224 << _joint_targets_);
225 }
226 }
227
228 if (!this->isInferenceDone()) { this->makeInference(); }
229
230 if (found_exact_target) return jointPosterior_(real_nodes);
231 else { return jointPosterior_(real_nodes, super_target); }
232 }
233
234 // Compute the posterior of a node
235 template < typename GUM_SCALAR >
238 else return jointPosterior(NodeSet{node});
239 }
240
241 // Compute the posterior of a node
242 template < typename GUM_SCALAR >
243 const Tensor< GUM_SCALAR >&
245 return posterior(this->MRF().idFromName(nodeName));
246 }
247
248 // ##############################################################################
249 // Entropy
250 // ##############################################################################
251
252 template < typename GUM_SCALAR >
253 Tensor< GUM_SCALAR >
255 const NodeSet& evs) {
256 if (!(evs * targets).empty()) {
258 "Targets (" << targets << ") can not intersect evs (" << evs << ").");
259 }
260 auto condset = this->MRF().minimalCondSet(targets, evs);
261
262 this->eraseAllTargets();
263 this->eraseAllEvidence();
264
265 Instantiation iTarget;
266 Tensor< GUM_SCALAR > res;
267 for (const auto& target: targets) {
268 res.add(this->MRF().variable(target));
269 iTarget.add(this->MRF().variable(target));
270 }
271 this->addJointTarget(targets);
272
273 for (const auto& n: condset) {
274 res.add(this->MRF().variable(n));
275 this->addEvidence(n, 0);
276 }
277
278 Instantiation inst(res);
279 for (inst.setFirstOut(iTarget); !inst.end(); inst.incOut(iTarget)) {
280 // inferring
281 for (const auto& n: condset)
282 this->chgEvidence(n, inst.val(this->MRF().variable(n)));
283 this->makeInference();
284 // populate res
285 for (inst.setFirstIn(iTarget); !inst.end(); inst.incIn(iTarget)) {
286 res.set(inst, this->jointPosterior(targets)[inst]);
287 }
288 inst.setFirstIn(iTarget); // remove inst.end() flag
289 }
290
291 return res;
292 }
293
294 template < typename GUM_SCALAR >
296 const std::vector< std::string >& targets,
297 const std::vector< std::string >& evs) {
298 const auto& mn = this->MRF();
299 return evidenceJointImpact(mn.nodeset(targets), mn.nodeset(evs));
300 }
301
302 template < typename GUM_SCALAR >
303 GUM_SCALAR
305 const auto& mn = this->MRF();
306 const Size siz = targets.size();
307 if (siz <= 1) {
309 "jointMutualInformation needs at least 2 variables (targets=" << targets << ")");
310 }
311
312 this->eraseAllTargets();
313 this->eraseAllEvidence();
314 this->addJointTarget(targets);
315 this->makeInference();
316 const auto po = this->jointPosterior(targets);
317
318 gum::Instantiation caracteristic;
319 gum::Instantiation variables;
320 for (const auto nod: targets) {
321 const auto& var = mn.variable(nod);
322 auto pv = new gum::RangeVariable(var.name(), "", 0, 1);
323 caracteristic.add(*pv);
324 variables.add(var);
325 }
326
328
329 const GUM_SCALAR start = (siz % 2 == 0) ? GUM_SCALAR(-1.0) : GUM_SCALAR(1.0);
330 GUM_SCALAR sign;
331 GUM_SCALAR res = GUM_SCALAR(0.0);
332
333 caracteristic.setFirst();
334 for (caracteristic.inc(); !caracteristic.end(); caracteristic.inc()) {
335 sov.clear();
336 sign = start;
337 for (Idx i = 0; i < caracteristic.nbrDim(); i++) {
338 if (caracteristic.val(i) == 1) {
339 sign = -sign;
340 sov.insert(&variables.variable(i));
341 }
342 }
343 res += sign * po.sumIn(sov).entropy();
344 }
345
346 for (Idx i = 0; i < caracteristic.nbrDim(); i++) {
347 delete &caracteristic.variable(i);
348 }
349
350 return res;
351 }
352
353 template < typename GUM_SCALAR >
355 const std::vector< std::string >& targets) {
356 return jointMutualInformation(this->MRF().nodeset(targets));
357 }
358
359 template < typename GUM_SCALAR >
361 if (_joint_targets_.contains(vars)) return true;
362
363 return false;
364 }
365
366 template < typename GUM_SCALAR >
368 for (const auto& target: _joint_targets_)
369 if (vars.isSubsetOrEqual(target)) return target;
370
371 for (const auto& factor: this->MRF().factors()) {
372 if (vars.isSubsetOrEqual(factor.first)) return factor.first;
373 }
374
375 return NodeSet();
376 }
377
378} /* namespace gum */
virtual bool hasHardEvidence(NodeId id) const final
indicates whether node id has received a hard evidence
virtual void setState_(const StateOfInference state) final
set the state of the inference engine and call the notification onStateChanged_ when necessary (i....
virtual void chgEvidence(NodeId id, const Idx val) final
change the value of an already existing hard evidence
virtual void addEvidence(NodeId id, const Idx val) final
adds a new hard evidence on node id
virtual void eraseAllEvidence() final
removes all the evidence entered into the network
virtual void makeInference() final
perform the heavy computations needed to compute the targets' posteriors
virtual bool isInferenceDone() const noexcept final
returns whether the inference object is in a InferenceDone state
Virtual base class for probabilistic graphical models.
Class representing the minimal interface for Markov random field.
Class for assigning/browsing values to tuples of discrete variables.
void incIn(const Instantiation &i)
Operator increment for the variables in i.
bool end() const
Returns true if the Instantiation reached the end.
void incOut(const Instantiation &i)
Operator increment for the variables not in i.
void setFirstIn(const Instantiation &i)
Assign the first values in the Instantiation for the variables in i.
void add(const DiscreteVariable &v) final
Adds a new variable in the Instantiation.
void inc()
Operator increment.
Idx val(Idx i) const
Returns the current value of the variable at position i.
void setFirst()
Assign the first values to the tuple of the Instantiation.
void setFirstOut(const Instantiation &i)
Assign the first values in the Instantiation for the variables not in i.
const DiscreteVariable & variable(Idx i) const final
Returns the variable at position i in the tuple.
Idx nbrDim() const final
Returns the number of variables in the Instantiation.
Exception: at least one argument passed to a function is not what was expected.
virtual void eraseJointTarget(const NodeSet &joint_target) final
removes an existing joint target
virtual void eraseAllJointTargets() final
Clear all previously defined joint targets.
virtual Size nbrJointTargets() const noexcept final
returns the number of joint targets
GUM_SCALAR jointMutualInformation(const NodeSet &targets)
Mutual information between targets.
virtual void eraseAllMarginalTargets() final
Clear all the previously defined marginal targets.
Set< NodeSet > _joint_targets_
the set of joint targets
virtual bool isJointTarget(const NodeSet &vars) const final
return true if target is a joint target.
virtual const Tensor< GUM_SCALAR > & jointPosterior(const NodeSet &nodes) final
Compute the joint posterior of a set of nodes.
virtual void addJointTarget(const NodeSet &joint_target) final
Add a set of nodes as a new joint target. As a collateral effect, every node is added as a marginal t...
virtual const Set< NodeSet > & jointTargets() const noexcept final
returns the list of joint targets
virtual void onJointTargetErased_(const NodeSet &set)=0
fired before a joint target is removed
virtual const Tensor< GUM_SCALAR > & posterior(NodeId node) final
Computes and returns the posterior of a node.
virtual const Tensor< GUM_SCALAR > & jointPosterior_(const NodeSet &set)=0
asks derived classes for the joint posterior of a declared target set
virtual void onAllJointTargetsErased_()=0
fired before all the joint targets are removed
virtual void onJointTargetAdded_(const NodeSet &set)=0
fired after a new joint target is inserted
virtual bool isExactJointComputable_(const NodeSet &vars)
check if the vars form a possible computable joint (can be redefined by subclass)
Tensor< GUM_SCALAR > evidenceJointImpact(const NodeSet &targets, const NodeSet &evs)
Create a gum::Tensor for P(joint targets|evs) (for all instantiations of targets and evs).
JointTargetedMRFInference(const IMarkovRandomField< GUM_SCALAR > *mn)
default constructor
virtual NodeSet superForJointComputable_(const NodeSet &vars)
virtual void eraseAllTargets() final
Clear all previously defined targets (marginal and joint targets).
virtual void onModelChanged_(const GraphicalModel *mn) override
fired after a new Markov net has been assigned to the engine
virtual const IMarkovRandomField< GUM_SCALAR > & MRF() const final
Returns a constant reference over the IMarkovRandomField referenced by this class.
void _setMRFDuringConstruction_(const IMarkovRandomField< GUM_SCALAR > *mn)
assigns a MRF during the inference engine construction
virtual bool isTarget(NodeId node) const final
return true if variable is a (marginal) target
virtual const NodeSet & targets() const noexcept final
returns the list of marginal targets
virtual void onModelChanged_(const GraphicalModel *mn)
fired after a new Markov net has been assigned to the engine
MarginalTargetedMRFInference(const IMarkovRandomField< GUM_SCALAR > *mn)
default constructor
virtual const Tensor< GUM_SCALAR > & posterior(NodeId node)
Computes and returns the posterior of a node.
virtual void eraseAllTargets()
Clear all previously defined targets.
Exception : a pointer or a reference on a nullptr (0) object.
Defines a discrete random variable over an integer interval.
Representation of a set.
Definition set.h:131
bool isSubsetOrEqual(const Set< Key > &s) const
Definition set_tpl.h:517
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
bool empty() const noexcept
Indicates whether the set is the empty set.
Definition set_tpl.h:642
void clear()
Removes all the elements, if any, from the set.
Definition set_tpl.h:338
Exception : a looked-for element could not be found.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
Class encapsulating computations of notions from Information Theory.
This file contains the abstract inference class definition for computing (incrementally) joint poster...
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
Header of gumRangeVariable.