aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
searchStrategy_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50
51namespace gum {
52 namespace prm {
53 namespace gspan {
54
55 template < typename GUM_SCALAR >
57 double cost = 0;
59 = *(this->tree_->data(p).iso_map.begin().val());
61
62 for (const auto inst: seq) {
63 for (const auto input: inst->type().slotChains())
64 for (const auto inst2: inst->getInstances(input->id()))
65 if ((!seq.exists(inst2))
66 && (!input_set.exists(&(inst2->get(input->lastElt().safeName()))))) {
67 cost += std::log(input->type().variable().domainSize());
68 input_set.insert(&(inst2->get(input->lastElt().safeName())));
69 }
70
71 for (auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
72 for (const auto& inverse: *vec.val())
73 if (!seq.exists(inverse.first)) {
74 cost += std::log(inst->get(vec.key()).type().variable().domainSize());
75 break;
76 }
77 }
78
79 return cost;
80 }
81
82 template < typename GUM_SCALAR >
85 Set< Tensor< GUM_SCALAR >* >& pool,
86 const Sequence< PRMInstance< GUM_SCALAR >* >& match) {
87 for (const auto inst: match) {
88 for (const auto& elt: *inst) {
89 // Adding the node
90 NodeId id = data.graph.addNode();
91 data.node2attr.insert(id, _str_(inst, elt.second));
92 data.mod.insert(id, elt.second->type()->domainSize());
93 data.vars.insert(id, &elt.second->type().variable());
94 pool.insert(const_cast< Tensor< GUM_SCALAR >* >(&(elt.second->cpf())));
95 }
96 }
97
98 // Second we add edges and nodes to inners or outputs
99 for (const auto inst: match)
100 for (const auto& elt: *inst) {
101 NodeId node = data.node2attr.first(_str_(inst, elt.second));
102 bool found = false; // If this is set at true, then node is an outer node
103
104 // Children existing in the instance type's DAG
105 for (const auto chld: inst->type().containerDag().children(elt.second->id())) {
106 data.graph.addEdge(node, data.node2attr.first(_str_(inst, inst->get(chld))));
107 }
108
109 // Parents existing in the instance type's DAG
110 for (const auto par: inst->type().containerDag().parents(elt.second->id())) {
111 switch (inst->type().get(par).elt_type()) {
114 data.graph.addEdge(node, data.node2attr.first(_str_(inst, inst->get(par))));
115 break;
116 }
117
119 for (const auto inst2: inst->getInstances(par))
120 if (match.exists(inst2))
121 data.graph.addEdge(node,
122 data.node2attr.first(
123 _str_(inst2,
124 static_cast< const PRMSlotChain< GUM_SCALAR >& >(
125 inst->type().get(par)))));
126
127 break;
128 }
129
130 default : { /* Do nothing */
131 }
132 }
133 }
134
135 // Referring PRMAttribute<GUM_SCALAR>
136 if (inst->hasRefAttr(elt.second->id())) {
137 const std::vector< std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& ref_attr
138 = inst->getRefAttr(elt.second->id());
139
140 for (auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
141 if (match.exists(pair->first)) {
142 NodeId id = pair->first->type().get(pair->second).id();
143
144 for (const auto child: pair->first->type().containerDag().children(id))
145 data.graph.addEdge(
146 node,
147 data.node2attr.first(_str_(pair->first, pair->first->get(child))));
148 } else {
149 found = true;
150 }
151 }
152 }
153
154 if (found) data.outputs.insert(node);
155 else data.inners.insert(node);
156 }
157 }
158
159 template < typename GUM_SCALAR >
162 Set< Tensor< GUM_SCALAR >* >& pool) {
163 List< NodeSet > partial_order;
164
165 if (data.inners.size()) partial_order.insert(data.inners);
166
167 if (data.outputs.size()) partial_order.insert(data.outputs);
168
169 PartialOrderedTriangulation t(&(data.graph), &(data.mod), &partial_order);
170 const std::vector< NodeId >& elim_order = t.eliminationOrder();
171 Size max(0), max_count(1);
173 Tensor< GUM_SCALAR >* pot = 0;
174
175 for (size_t idx = 0; idx < data.inners.size(); ++idx) {
176 pot = new Tensor< GUM_SCALAR >(new MultiDimSparse< GUM_SCALAR >(0));
177 pot->add(*(data.vars.second(elim_order[idx])));
178 trash.insert(pot);
179 Set< Tensor< GUM_SCALAR >* > toRemove;
180
181 for (const auto p: pool)
182 if (p->contains(*(data.vars.second(elim_order[idx])))) {
183 for (auto var = p->variablesSequence().begin(); var != p->variablesSequence().end();
184 ++var) {
185 try {
186 pot->add(**var);
187 } catch (DuplicateElement const&) {}
188 }
189
190 toRemove.insert(p);
191 }
192
193 if (pot->domainSize() > max) {
194 max = pot->domainSize();
195 max_count = 1;
196 } else if (pot->domainSize() == max) {
197 ++max_count;
198 }
199
200 for (const auto p: toRemove)
201 pool.erase(p);
202
203 pot->erase(*(data.vars.second(elim_order[idx])));
204 }
205
206 for (const auto pot: trash)
207 delete pot;
208
209 return std::make_pair(max, max_count);
210 }
211
212 // The SearchStrategy class
213 template < typename GUM_SCALAR >
215 GUM_CONSTRUCTOR(SearchStrategy);
216 }
217
218 template < typename GUM_SCALAR >
219 INLINE
224
225 template < typename GUM_SCALAR >
229
230 template < typename GUM_SCALAR >
233 this->tree_ = from.tree_;
234 return *this;
235 }
236
// Stores the given DFSTree pointer in tree_; only the pointer is assigned, the tree itself is not copied.
237 template < typename GUM_SCALAR >
238 INLINE void SearchStrategy< GUM_SCALAR >::setTree(DFSTree< GUM_SCALAR >* tree) {
239 this->tree_ = tree;
240 }
241
242 // FrequenceSearch
243
244 // The FrequenceSearch class
245 template < typename GUM_SCALAR >
247 SearchStrategy< GUM_SCALAR >(), _freq_(freq) {
248 GUM_CONSTRUCTOR(FrequenceSearch);
249 }
250
251 template < typename GUM_SCALAR >
253 const FrequenceSearch< GUM_SCALAR >& from) :
254 SearchStrategy< GUM_SCALAR >(from), _freq_(from._freq_) {
255 GUM_CONS_CPY(FrequenceSearch);
256 }
257
258 template < typename GUM_SCALAR >
262
263 template < typename GUM_SCALAR >
269
270 template < typename GUM_SCALAR >
272 return this->tree_->frequency(*r) >= _freq_;
273 }
274
275 template < typename GUM_SCALAR >
276 INLINE bool
278 const Pattern* child,
279 const EdgeGrowth< GUM_SCALAR >& growh) {
280 return this->tree_->frequency(*child) >= _freq_;
281 }
282
283 template < typename GUM_SCALAR >
285 // We want a descending order
286 return this->tree_->frequency(*i) > this->tree_->frequency(*j);
287 }
288
289 template < typename GUM_SCALAR >
291 return (this->tree_->graph().size(i) > this->tree_->graph().size(j));
292 }
293
294 // StrictSearch
295
296 // The StrictSearch class
297 template < typename GUM_SCALAR >
299 SearchStrategy< GUM_SCALAR >(), _freq_(freq), _dot_(".") {
300 GUM_CONSTRUCTOR(StrictSearch);
301 }
302
303 template < typename GUM_SCALAR >
305 SearchStrategy< GUM_SCALAR >(from), _freq_(from._freq_) {
306 GUM_CONS_CPY(StrictSearch);
307 }
308
309 template < typename GUM_SCALAR >
311 GUM_DESTRUCTOR(StrictSearch);
312 }
313
314 template < typename GUM_SCALAR >
317 _freq_ = from._freq_;
318 return *this;
319 }
320
321 template < typename GUM_SCALAR >
323 return (this->tree_->frequency(*r) >= _freq_);
324 }
325
326 template < typename GUM_SCALAR >
327 INLINE bool
329 const Pattern* child,
330 const EdgeGrowth< GUM_SCALAR >& growth) {
331 return _inner_cost_(child) + this->tree_->frequency(*child) * _outer_cost_(child)
332 < this->tree_->frequency(*child) * _outer_cost_(parent);
333 }
334
335 template < typename GUM_SCALAR >
337 return _inner_cost_(i) + this->tree_->frequency(*i) * _outer_cost_(i)
338 < _inner_cost_(j) + this->tree_->frequency(*j) * _outer_cost_(j);
339 }
340
341 template < typename GUM_SCALAR >
343 return i->tree_width * this->tree_->graph().size(i)
344 < j->tree_width * this->tree_->graph().size(j);
345 }
346
347 template < typename GUM_SCALAR >
349 try {
350 return _map_[p].first;
351 } catch (NotFound const&) {
353 return _map_[p].first;
354 }
355 }
356
357 template < typename GUM_SCALAR >
359 try {
360 return _map_[p].second;
361 } catch (NotFound const&) {
363 return _map_[p].second;
364 }
365 }
366
367 template < typename GUM_SCALAR >
368 INLINE std::string
370 const PRMAttribute< GUM_SCALAR >* a) const {
371 return i->name() + _dot_ + a->safeName();
372 }
373
374 template < typename GUM_SCALAR >
375 INLINE std::string
377 const PRMAttribute< GUM_SCALAR >& a) const {
378 return i->name() + _dot_ + a.safeName();
379 }
380
381 template < typename GUM_SCALAR >
382 INLINE std::string
384 const PRMSlotChain< GUM_SCALAR >& a) const {
385 return i->name() + _dot_ + a.lastElt().safeName();
386 }
387
388 template < typename GUM_SCALAR >
392 _buildPatternGraph_(data, pool, *(this->tree_->data(*p).iso_map.begin().val()));
393 double inner = std::log(_elimination_cost_(data, pool).first);
394 double outer = this->computeCost_(*p);
395 _map_.insert(p, std::make_pair(inner, outer));
396 }
397
398 // TreeWidthSearch
399
400 template < typename GUM_SCALAR >
402 GUM_CONSTRUCTOR(TreeWidthSearch);
403 }
404
405 template < typename GUM_SCALAR >
407 const TreeWidthSearch< GUM_SCALAR >& from) : SearchStrategy< GUM_SCALAR >(from) {
408 GUM_CONS_CPY(TreeWidthSearch);
409 }
410
411 template < typename GUM_SCALAR >
415
416 template < typename GUM_SCALAR >
421
422 template < typename GUM_SCALAR >
424 try {
425 return _map_[&p];
426 } catch (NotFound const&) {
427 _map_.insert(&p, this->computeCost_(p));
428 return _map_[&p];
429 }
430 }
431
432 template < typename GUM_SCALAR >
434 Size tree_width = 0;
435
436 for (const auto n: r->nodes())
437 tree_width += r->label(n).tree_width;
438
439 return tree_width >= cost(*r);
440 }
441
442 template < typename GUM_SCALAR >
443 INLINE bool
445 const Pattern* child,
446 const EdgeGrowth< GUM_SCALAR >& growth) {
447 return cost(*parent) >= cost(*child);
448 }
449
450 template < typename GUM_SCALAR >
452 return cost(*i) < cost(*j);
453 }
454
455 template < typename GUM_SCALAR >
459
460 } /* namespace gspan */
461 } /* namespace prm */
462} /* namespace gum */
Exception : a similar element already exists.
Generic doubly linked lists.
Definition list.h:379
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Definition list_tpl.h:1515
Multidimensional matrix stored as a sparse array in memory.
virtual NodeId addNode()
insert a new node and return its id
Exception : the element we looked for cannot be found.
class for graph triangulations for which we enforce a given partial ordering on the nodes elimination...
bool exists(const Key &k) const
Check the existence of k in the sequence.
void insert(const Key &k)
Insert an element at the end of the sequence.
The generic class for storing (ordered) sequences of objects.
Definition sequence.h:972
Representation of a set.
Definition set.h:131
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
void erase(const Key &k)
Erases an element from the set.
Definition set_tpl.h:582
const std::vector< NodeId > & eliminationOrder()
returns an elimination ordering compatible with the triangulated graph
void addEdge(NodeId first, NodeId second) override
insert a new edge into the undirected graph
PRMAttribute is a member of a Class in a PRM.
const std::string & safeName() const
Returns the safe name of this PRMClassElement, if any.
A PRMInstance is a Bayesian network fragment defined by a Class and used in a PRMSystem.
Definition PRMInstance.h:79
const std::string & name() const
Returns the name of this object.
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
PRMClassElement< GUM_SCALAR > & lastElt()
Returns the last element of the slot chain; typically this is a gum::PRMAttribute or a gum::PRMAggre...
This class is used to define an edge growth of a pattern in this DFSTree.
Definition edgeGrowth.h:73
This class is an implementation of a simple search strategy for the gspan algorithm: it accepts a g...
FrequenceSearch(Size freq)
Default constructor.
virtual bool operator()(LabelData *i, LabelData *j)
virtual bool accept_root(const Pattern *r)
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
FrequenceSearch & operator=(const FrequenceSearch &from)
Copy operator.
This contains all the information we want for a node in a DFSTree.
Definition pattern.h:90
const NodeGraphPart & nodes() const
LabelData & label(NodeId node)
Returns the LabelData assigned to node.
Definition pattern_inl.h:75
This is an abstract class used to tune search strategies in the gspan algorithm.
double computeCost_(const Pattern &p)
SearchStrategy< GUM_SCALAR > & operator=(const SearchStrategy< GUM_SCALAR > &from)
Copy operator.
DFSTree< GUM_SCALAR > * tree_
void setTree(DFSTree< GUM_SCALAR > *tree)
This class is an implementation of a strict strategy for the GSpan algorithm.
double _outer_cost_(const Pattern *p)
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
StrictSearch(Size freq=2)
Default constructor.
virtual bool accept_root(const Pattern *r)
StrictSearch & operator=(const StrictSearch &from)
Copy operator.
double _inner_cost_(const Pattern *p)
HashTable< const Pattern *, std::pair< double, double > > _map_
virtual bool operator()(LabelData *i, LabelData *j)
void _compute_costs_(const Pattern *p)
void _buildPatternGraph_(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
std::pair< Size, Size > _elimination_cost_(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Tensor< GUM_SCALAR > * > &pool)
std::string _str_(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
A growth is accepted if and only if the new growth has a tree width less than or equal to that of its fath...
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
HashTable< const Pattern *, double > _map_
TreeWidthSearch & operator=(const TreeWidthSearch &from)
Copy operator.
virtual bool operator()(LabelData *i, LabelData *j)
virtual bool accept_root(const Pattern *r)
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size NodeId
Type for node ids.
namespace for all probabilistic relational models entities
Definition agrum.h:68
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Headers of the SearchStrategy class and child.
Inner class to handle data about labels in this interface graph.
Size tree_width
The size in terms of tree width of the given label.
Private structure to represent data about a pattern.
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track between graph and attributes; it is of the form instance_name DOT attr...
NodeProperty< Size > mod
The pattern's variables modalities.
UndiGraph graph
A yet to be triangulated undigraph.
NodeSet outputs
Returns the set of output nodes given all the matches of the pattern.
NodeSet inners
Returns the set of inner nodes.
Bijection< NodeId, const DiscreteVariable * > vars
Bijection between graph's nodes and their corresponding DiscreteVariable, for inference purpose.