aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
structuredInference.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
48
49#ifndef GUM_STRUCTURED_INFERENCE_H
50#define GUM_STRUCTURED_INFERENCE_H
51
52#include <string>
53
56
57namespace gum {
58 namespace prm {
59
68 template < typename GUM_SCALAR >
69 class StructuredInference: public PRMInference< GUM_SCALAR > {
70 public:
71 // ========================================================================
73 // ========================================================================
75
78 const PRMSystem< GUM_SCALAR >& system,
80
83
85 virtual ~StructuredInference();
86
89
91 // ========================================================================
93 // ========================================================================
95
97 void setPatternMining(bool b);
98
99 virtual std::string name() const;
100
103
105 const GSpan< GUM_SCALAR >& gspan() const;
106
108 void searchPatterns();
109
111
112 protected:
113 // ========================================================================
115 // ========================================================================
117
119 virtual void evidenceAdded_(const typename PRMInference< GUM_SCALAR >::Chain& chain);
120
122 virtual void evidenceRemoved_(const typename PRMInference< GUM_SCALAR >::Chain& chain);
123
125 virtual void posterior_(const typename PRMInference< GUM_SCALAR >::Chain& chain,
126 Tensor< GUM_SCALAR >& m);
127
129 virtual void joint_(const std::vector< typename PRMInference< GUM_SCALAR >::Chain >& queries,
130 Tensor< GUM_SCALAR >& j);
131
133
134 private:
159
161 struct PData {
189 PData(const PData& source);
191 ~PData();
192
194 inline NodeSet& inners() { return _partial_order_[0]; }
195
198 inline NodeSet& obs() { return _partial_order_[1]; }
199
201 inline NodeSet& outputs() { return _partial_order_[2]; }
202
204 inline NodeSet& queries() { return _partial_order_[3]; }
205
206 // We use the first match for computations
207 // inline const Sequence<PRMInstance<GUM_SCALAR>*>& match() const {
208 // return
209 // **(matches.begin());}
210 // Remove any empty set in partial_order
212
213 private:
221 };
222
261
264
270
276
279
281
284
287
291
294
297 std::pair< Idx, std::string > _query_data_;
298
304 void _buildReduceGraph_(RGData& data);
305
307 // MSVC void _addNodesInReducedGraph_( RGData& data );
308
311
312 void _removeNode_(typename StructuredInference::PData& data,
313 NodeId id,
314 Set< Tensor< GUM_SCALAR >* >& pool);
315
317 void _reduceAloneInstances_(RGData& data);
318
324 void _reducePattern_(const gspan::Pattern* p);
325
331 void _buildPatternGraph_(PData& data,
332 Set< Tensor< GUM_SCALAR >* >& pool,
333 const Sequence< PRMInstance< GUM_SCALAR >* >& match);
334
336 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
339 NodeId id,
340 std::pair< Idx, std::string >& v);
341
343 std::pair< Idx, std::string > attr);
344
346 Set< Tensor< GUM_SCALAR >* >& pool);
347
350 // MVSC void _buildQuerySet_( PData& data );
351
356 const Set< Tensor< GUM_SCALAR >* >& pool,
357 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
358 const std::vector< NodeId >& elim_order);
359
362 const Set< Tensor< GUM_SCALAR >* >& pool,
363 const Sequence< PRMInstance< GUM_SCALAR >* >& match,
364 const std::vector< NodeId >& elim_order);
365
369 const Set< Tensor< GUM_SCALAR >* >& pool,
370 const Sequence< PRMInstance< GUM_SCALAR >* >& match);
371
373 // MVSC void _unreduceMatchWithQuery_();
374
375 // MVSC std::vector<NodeId>* _getClassOutputs_( const
376 // PRMClass<GUM_SCALAR>* c );
378 std::string _dot_;
379 std::string _str_(const PRMInstance< GUM_SCALAR >* i,
380 const PRMAttribute< GUM_SCALAR >* a) const;
381 std::string _str_(const PRMInstance< GUM_SCALAR >* i,
382 const PRMAttribute< GUM_SCALAR >& a) const;
383 std::string _str_(const PRMInstance< GUM_SCALAR >* i,
384 const PRMSlotChain< GUM_SCALAR >& a) const;
385
386 public:
387 // For bench/debug purpose.
394 double obs_time;
395 double full_time;
396 std::string info() const;
397 };
398
399
400#ifndef GUM_NO_EXTERN_TEMPLATE_CLASS
401 extern template class StructuredInference< double >;
402#endif
403
404
405 } /* namespace prm */
406} /* namespace gum */
407
409
410#endif /* GUM_STRUCTURED_INFERENCE_H */
Headers of PRMInference.
Set of pairs of elements with fast search for both elements.
Definition bijection.h:1594
The class for generic Hash Tables.
Definition hashTable.h:637
Generic doubly linked lists.
Definition list.h:379
The generic class for storing (ordered) sequences of objects.
Definition sequence.h:972
Representation of a set.
Definition set.h:131
Class used to compute response times for benchmark purposes.
Definition timer.h:69
Base class for undirected graphs.
Definition undiGraph.h:128
This class discovers patterns in a PRM<GUM_SCALAR>'s PRMSystem<GUM_SCALAR> to speed up structured infe...
Definition gspan.h:86
Set< Sequence< PRMInstance< GUM_SCALAR > * > * > MatchedInstances
Code alias.
Definition gspan.h:185
PRMAttribute is a member of a Class in a PRM.
A PRMClass is an object of a PRM representing a fragment of a Bayesian network which can be instantia...
Definition PRMClass.h:75
std::pair< const PRMInstance< GUM_SCALAR > *, const PRMAttribute< GUM_SCALAR > * > Chain
Code alias.
PRMInference(const PRM< GUM_SCALAR > &prm, const PRMSystem< GUM_SCALAR > &system)
Default constructor.
A PRMInstance is a Bayesian network fragment defined by a Class and used in a PRMSystem.
Definition PRMInstance.h:79
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
A PRMSystem is a container of PRMInstance and describe a relational skeleton.
Definition PRMSystem.h:70
This class represents a Probabilistic Relational Model (PRM).
Definition PRM.h:74
<agrum/PRM/structuredInference.h>
void searchPatterns()
Search for patterns without doing any computations.
Set< Tensor< GUM_SCALAR > * > * _eliminateObservedNodes_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match, const std::vector< NodeId > &elim_order)
Add in data.queries() any queried variable in one of data.pattern matches.
HashTable< const PRMClass< GUM_SCALAR > *, CData * > _cdata_map_
Mapping between a Class<GUM_SCALAR> and data about instances reduced using only Class<GUM_SCALAR> lev...
void _reducePattern_(const gspan::Pattern *p)
Proceed with the elimination of all inner variables (observed or not) of all usable matches of Patter...
std::string _dot_
Unreduce the match containing the query.
void setPatternMining(bool b)
Tells this algorithm to use pattern mining or not.
void _buildReduceGraph_(RGData &data)
This calls reducePattern() over each pattern and then build the reduced graph which is used for infer...
void _removeBarrenNodes_(typename StructuredInference::PData &data, Set< Tensor< GUM_SCALAR > * > &pool)
HashTable< const Sequence< PRMInstance< GUM_SCALAR > * > *, Set< Tensor< GUM_SCALAR > * > * > _elim_map_
Mapping between a Pattern's match and its tensor pool after inner variables were eliminated.
HashTable< const PRMClass< GUM_SCALAR > *, std::vector< NodeId > * > _outputs_
PRMInference< GUM_SCALAR >::Chain _query_
The query.
virtual void posterior_(const typename PRMInference< GUM_SCALAR >::Chain &chain, Tensor< GUM_SCALAR > &m)
See PRMInference::posterior_().
void _removeNode_(typename StructuredInference::PData &data, NodeId id, Set< Tensor< GUM_SCALAR > * > &pool)
void _buildPatternGraph_(PData &data, Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
Build the DAG corresponding to Pattern data.pattern, initialize pool with all the Tensors of all vari...
PData * _pdata_
The pattern data of the pattern which one of its matches contains the query.
Set< Tensor< GUM_SCALAR > * > * _translatePotSet_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
Translate a given Tensor Set into one w.r.t. variables in match.
void _insertNodeInElimLists_(typename StructuredInference::PData &data, const Sequence< PRMInstance< GUM_SCALAR > * > &match, PRMInstance< GUM_SCALAR > *inst, PRMAttribute< GUM_SCALAR > *attr, NodeId id, std::pair< Idx, std::string > &v)
std::string _str_(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
StructuredInference & operator=(const StructuredInference &source)
Copy operator.
void _addEdgesInReducedGraph_(RGData &data)
Add the edges in the reduced graph.
bool _mining_
Flag which tells to use pattern mining or not.
GSpan< GUM_SCALAR > * _gspan_
Pointer to the GSpan<GUM_SCALAR> instance used by this class.
GSpan< GUM_SCALAR > & gspan()
Returns the instance of gspan used to search patterns.
bool _found_query_
Flag with an explicit name.
virtual void evidenceAdded_(const typename PRMInference< GUM_SCALAR >::Chain &chain)
See PRMInference::evidenceAdded_().
bool _allInstanceNoRefAttr_(typename StructuredInference::PData &data, std::pair< Idx, std::string > attr)
virtual void joint_(const std::vector< typename PRMInference< GUM_SCALAR >::Chain > &queries, Tensor< GUM_SCALAR > &j)
See PRMInference::joint_().
virtual std::string name() const
Returns the name of this inference algorithm.
virtual void evidenceRemoved_(const typename PRMInference< GUM_SCALAR >::Chain &chain)
See PRMInference::evidenceRemoved_().
Set< const PRMInstance< GUM_SCALAR > * > _reducedInstances_
This keeps track of reduced instances.
StructuredInference(const PRM< GUM_SCALAR > &prm, const PRMSystem< GUM_SCALAR > &system, gspan::SearchStrategy< GUM_SCALAR > *strategy=0)
Default constructor.
Set< Tensor< GUM_SCALAR > * > * _eliminateObservedNodesInSource_(typename StructuredInference::PData &data, const Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match, const std::vector< NodeId > &elim_order)
Set< Tensor< GUM_SCALAR > * > _trash_
Keeps track of created tensors so they can be deleted after inference.
void _reduceAloneInstances_(RGData &data)
Add the reduced tensors of instances not in any used patterns.
std::pair< Idx, std::string > _query_data_
This contains all the information we want for a node in a DFSTree.
Definition pattern.h:90
This is an abstract class used to tune search strategies in the gspan algorithm.
Size NodeId
Type for node ids.
HashTable< NodeId, VAL > NodeProperty
Property on graph elements.
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
Headers of gspan.
namespace for all probabilistic relational models entities
Definition agrum.h:68
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Private structure to represent data about a Class<GUM_SCALAR>.
NodeSet & aggregators()
Returns the set of aggregators and their parents.
Set< Tensor< GUM_SCALAR > * > _trash_
List< NodeSet > partial_order
The partial order used for variable elimination.
Set< const PRMInstance< GUM_SCALAR > * > instances
The set of instances reduced at class level.
NodeSet & outputs()
Returns the set of outputs nodes.
UndiGraph moral_graph
The class moral graph. NodeId matches those in c.
CData(const PRMClass< GUM_SCALAR > &c)
Default constructor.
std::vector< NodeId > & elim_order()
The elimination order for nodes of this class.
Set< Tensor< GUM_SCALAR > * > pool
The tensor pool obtained by C elimination of inner nodes.
NodeSet & inners()
Returns the set of inner nodes.
const PRMClass< GUM_SCALAR > & c
The class this data is about.
NodeProperty< Size > mods
The class variables modalities.
Private structure to represent data about a pattern.
List< NodeSet > _partial_order_
We'll use a PartialOrderedTriangulation with three sets: output, nodes and obs with children outside ...
PData(const PData &source)
Copy constructor.
NodeSet & obs()
Returns the set of inner and observed nodes given all the matches of pattern.
NodeSet & queries()
Returns the set of queried nodes given all the matches of pattern.
Set< NodeId > barren
Set of barren nodes.
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track of the mapping between graph nodes and attributes; it's of the form instance_name DOT attr...
NodeProperty< std::pair< Idx, std::string > > map
To ease translating tensors from one match to another.
UndiGraph graph
A yet to be triangulated undigraph.
NodeSet & outputs()
Returns the set of outputs nodes given all the matches of pattern.
const gspan::Pattern & pattern
The pattern for which this represents data about it.
NodeSet & inners()
Returns the set of inner nodes.
NodeProperty< Tensor< GUM_SCALAR > * > pots
To handle barren nodes.
Bijection< NodeId, const DiscreteVariable * > vars
Bijection between graph's nodes and their corresponding DiscreteVariable, for inference purpose.
List< NodeSet > * _real_order_
A copy of partial_order without empty sets.
PData(const gspan::Pattern &p, typename GSpan< GUM_SCALAR >::MatchedInstances &m)
Default constructor.
NodeProperty< Size > mod
The pattern's variables modalities.
GSpan< GUM_SCALAR >::MatchedInstances & matches
A reference over the usable matches of pattern.
Private structure to represent data about a reduced graph.
List< NodeSet > partial_order
Partial order used for triangulation, first is outputs nodes, second query nodes.
Set< Tensor< GUM_SCALAR > * > pool
The pool of tensors matching the reduced graph.
Bijection< const DiscreteVariable *, NodeId > var2node
Mapping between DiscreteVariable and NodeId.
NodeSet & queries()
Returns the set of query nodes (which will not be eliminated).
UndiGraph reducedGraph
The reduced graph.
NodeSet & outputs()
Returns the set of outputs nodes (which will be eliminated).
NodeProperty< Size > mods
Mapping between NodeId and modalities.
Inline implementation of StructuredInference.