aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
imddi_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50// =======================================================
53
55// =======================================================
58// =======================================================
60
61// =======================================================
62
63
64namespace gum {
65
66 // ############################################################################
67 // Constructor & destructor.
68 // ############################################################################
69
70 // ============================================================================
71 // Variable Learner constructor
72 // ============================================================================
73 template < TESTNAME AttributeSelection, bool isScalar >
75 double attributeSelectionThreshold,
76 double pairSelectionThreshold,
77 gum::VariableSet attributeListe,
78 const DiscreteVariable* learnedValue) :
79 IncrementalGraphLearner< AttributeSelection, isScalar >(target, attributeListe, learnedValue),
80 _lg_(&(this->model_), pairSelectionThreshold), _nbTotalObservation_(0),
81 _attributeSelectionThreshold_(attributeSelectionThreshold) {
82 GUM_CONSTRUCTOR(IMDDI);
83 _addLeaf_(this->root_);
84 }
85
86 // ============================================================================
87 // Reward Learner constructor
88 // ============================================================================
89 template < TESTNAME AttributeSelection, bool isScalar >
91 double attributeSelectionThreshold,
92 double pairSelectionThreshold,
93 gum::VariableSet attributeListe) :
94 IncrementalGraphLearner< AttributeSelection, isScalar >(
95 target,
96 attributeListe,
97 new LabelizedVariable("Reward", "", 2)),
98 _lg_(&(this->model_), pairSelectionThreshold), _nbTotalObservation_(0),
99 _attributeSelectionThreshold_(attributeSelectionThreshold) {
100 GUM_CONSTRUCTOR(IMDDI);
101 _addLeaf_(this->root_);
102 }
103
104 // ============================================================================
 105 // Default destructor
106 // ============================================================================
107 template < TESTNAME AttributeSelection, bool isScalar >
109 GUM_DESTRUCTOR(IMDDI);
111 leafIter != _leafMap_.endSafe();
112 ++leafIter)
113 delete leafIter.val();
114 }
115
116 // ############################################################################
 117 // Incremental methods
118 // ############################################################################
119
120 template < TESTNAME AttributeSelection, bool isScalar >
125
126 template < TESTNAME AttributeSelection, bool isScalar >
128 NodeId currentNodeId) {
130 newObs,
131 currentNodeId);
132 if (this->nodeVarMap_[currentNodeId] == this->value_) _lg_.updateLeaf(_leafMap_[currentNodeId]);
133 }
134
135 // ============================================================================
136 // Updates the tree after a new observation has been added
137 // ============================================================================
138 template < TESTNAME AttributeSelection, bool isScalar >
140 _varOrder_.clear();
141
142 // First xe initialize the node set which will give us the scores
143 Set< NodeId > currentNodeSet;
144 currentNodeSet.insert(this->root_);
145
146 // Then we initialize the pool of variables to consider
148 for (vs.begin(); vs.hasNext(); vs.next()) {
149 _updateScore_(vs.current(), this->root_, vs);
150 }
151
152 // Then, until there's no node remaining
153 while (!vs.isEmpty()) {
154 // We select the best var
155 const DiscreteVariable* selectedVar = vs.select();
156 _varOrder_.insert(selectedVar);
157
158 // Then we decide if we update each node according to this var
159 _updateNodeSet_(currentNodeSet, selectedVar, vs);
160 }
161
162 // If there are remaining node that are not leaves after we establish the
163 // var order
164 // these nodes are turned into leaf.
165 for (SetIteratorSafe< NodeId > nodeIter = currentNodeSet.beginSafe();
166 nodeIter != currentNodeSet.endSafe();
167 ++nodeIter)
168 this->convertNode2Leaf_(*nodeIter);
169
170
171 if (_lg_.needsUpdate()) _lg_.update();
172 }
173
174 // ############################################################################
175 // Updating methods
176 // ############################################################################
177
178
179 // ###################################################################
180 // Select the most relevant variable
181 //
 182 // First parameter is the set of variables among which the most
 183 // relevant one is chosen
 184 // Second parameter is the set of nodes that will attribute a score
 185 // to each variable so that we choose the best.
186 // ###################################################################
187 template < TESTNAME AttributeSelection, bool isScalar >
189 NodeId nody,
190 VariableSelector& vs) {
191 if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
192 double weight = (double)this->nodeId2Database_[nody]->nbObservation()
194 vs.updateScore(var,
195 weight * this->nodeId2Database_[nody]->testValue(var),
196 weight * this->nodeId2Database_[nody]->testOtherCriterion(var));
197 }
198
199 template < TESTNAME AttributeSelection, bool isScalar >
201 NodeId nody,
202 VariableSelector& vs) {
203 if (!this->nodeId2Database_[nody]->isTestRelevant(var)) return;
204 double weight = (double)this->nodeId2Database_[nody]->nbObservation()
206 vs.downdateScore(var,
207 weight * this->nodeId2Database_[nody]->testValue(var),
208 weight * this->nodeId2Database_[nody]->testOtherCriterion(var));
209 }
210
211 // ============================================================================
 212 // For each node in the given set, this method checks whether or not
 213 // we should install the given variable as a test.
214 // If so, the node is updated
215 // ============================================================================
216 template < TESTNAME AttributeSelection, bool isScalar >
218 const DiscreteVariable* selectedVar,
219 VariableSelector& vs) {
220 Set< NodeId > oldNodeSet(nodeSet);
221 nodeSet.clear();
222 for (SetIteratorSafe< NodeId > nodeIter = oldNodeSet.beginSafe();
223 nodeIter != oldNodeSet.endSafe();
224 ++nodeIter) {
225 if (this->nodeId2Database_[*nodeIter]->isTestRelevant(selectedVar)
226 && this->nodeId2Database_[*nodeIter]->testValue(selectedVar)
228 this->transpose_(*nodeIter, selectedVar);
229
230 // Then we subtract the from the score given to each variables the
231 // quantity given by this node
232 for (vs.begin(); vs.hasNext(); vs.next()) {
233 _downdateScore_(vs.current(), *nodeIter, vs);
234 }
235
236 // And finally we add all its child to the new set of nodes
237 // and updates the remaining var's score
238 for (Idx modality = 0; modality < this->nodeVarMap_[*nodeIter]->domainSize(); ++modality) {
239 NodeId sonId = this->nodeSonsMap_[*nodeIter][modality];
240 nodeSet << sonId;
241
242 for (vs.begin(); vs.hasNext(); vs.next()) {
243 _updateScore_(vs.current(), sonId, vs);
244 }
245 }
246 } else {
247 nodeSet << *nodeIter;
248 }
249 }
250 }
251
252 // ============================================================================
253 // Insert a new node with given associated database, var and maybe sons
254 // ============================================================================
255 template < TESTNAME AttributeSelection, bool isScalar >
258 const DiscreteVariable* boundVar,
260 NodeId currentNodeId
262 boundVar,
263 obsSet);
264
265 _addLeaf_(currentNodeId);
266
267 return currentNodeId;
268 }
269
270 // ============================================================================
271 // Changes var associated to a node
272 // ============================================================================
273 template < TESTNAME AttributeSelection, bool isScalar >
275 const DiscreteVariable* desiredVar) {
276 if (this->nodeVarMap_[currentNodeId] == this->value_) _removeLeaf_(currentNodeId);
277
279 desiredVar);
280
281 if (desiredVar == this->value_) _addLeaf_(currentNodeId);
282 }
283
284 // ============================================================================
285 // Remove node from graph
286 // ============================================================================
287 template < TESTNAME AttributeSelection, bool isScalar >
289 if (this->nodeVarMap_[currentNodeId] == this->value_) _removeLeaf_(currentNodeId);
291 }
292
293 // ============================================================================
294 // Add leaf to aggregator
295 // ============================================================================
296 template < TESTNAME AttributeSelection, bool isScalar >
298 _leafMap_.insert(
299 currentNodeId,
301 this->nodeId2Database_[currentNodeId],
302 &(this->valueAssumed_)));
303 _lg_.addLeaf(_leafMap_[currentNodeId]);
304 }
305
306 // ============================================================================
307 // Remove leaf from aggregator
308 // ============================================================================
309 template < TESTNAME AttributeSelection, bool isScalar >
311 _lg_.removeLeaf(_leafMap_[currentNodeId]);
312 delete _leafMap_[currentNodeId];
313 _leafMap_.erase(currentNodeId);
314 }
315
316 // ============================================================================
317 // Computes the Reduced and Ordered Function Graph associated to this ordered
318 // tree
319 // ============================================================================
320 template < TESTNAME AttributeSelection, bool isScalar >
322 // if( _lg_.needsUpdate() || this->needUpdate_ ){
324 this->needUpdate_ = false;
325 // }
326 }
327
328 // ============================================================================
 329 // Rebuilds the reduced and ordered function graph from the learned tree
330 // ============================================================================
331 template < TESTNAME AttributeSelection, bool isScalar >
333 // *******************************************************************************************************
334 // Mise à jour de l'aggregateur de feuille
335 _lg_.update();
336
337 // *******************************************************************************************************
338 // Reinitialisation du Graphe de Décision
339 this->target_->clear();
340 for (auto varIter = _varOrder_.beginSafe(); varIter != _varOrder_.endSafe(); ++varIter)
341 this->target_->add(**varIter);
342 this->target_->add(*this->value_);
343
345
346 // *******************************************************************************************************
347 // Insertion des feuilles
348 HashTable< NodeId, AbstractLeaf* > treeNode2leaf = _lg_.leavesMap();
351 = treeNode2leaf.cbeginSafe();
352 treeNodeIter != treeNode2leaf.cendSafe();
353 ++treeNodeIter) {
354 if (!leaf2DGNode.exists(treeNodeIter.val()))
355 leaf2DGNode.insert(treeNodeIter.val(),
357
358 toTarget.insert(treeNodeIter.key(), leaf2DGNode[treeNodeIter.val()]);
359 }
360
361 // *******************************************************************************************************
362 // Insertion des noeuds internes (avec vérification des possibilités de
363 // fusion)
365 varIter != _varOrder_.rendSafe();
366 --varIter) {
367 for (Link< NodeId >* curNodeIter = this->var2Node_[*varIter]->list(); curNodeIter;
368 curNodeIter = curNodeIter->nextLink()) {
369 NodeId* sonsMap
370 = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * (*varIter)->domainSize()));
371 for (Idx modality = 0; modality < (*varIter)->domainSize(); ++modality)
372 sonsMap[modality] = toTarget[this->nodeSonsMap_[curNodeIter->element()][modality]];
373 toTarget.insert(curNodeIter->element(),
374 this->target_->manager()->addInternalNode(*varIter, sonsMap));
375 }
376 }
377
378 // *******************************************************************************************************
379 // Polish
380 this->target_->manager()->setRootNode(toTarget[this->root_]);
381 this->target_->manager()->clean();
382 }
383
384 // ============================================================================
 385 // Inserts a leaf into the target function graph (scalar value case)
386 // ============================================================================
387 template < TESTNAME AttributeSelection, bool isScalar >
390 double value = 0.0;
391 for (Idx moda = 0; moda < leaf->nbModa(); moda++) {
392 value += (double)leaf->effectif(moda) * this->valueAssumed_.atPos(moda);
393 }
394 if (leaf->total()) value /= (double)leaf->total();
395 return this->target_->manager()->addTerminalNode(value);
396 }
397
398 // ============================================================================
 399 // Inserts a leaf into the target function graph (distribution case)
400 // ============================================================================
401 template < TESTNAME AttributeSelection, bool isScalar >
404 NodeId* sonsMap
405 = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
406 for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
407 double newVal = 0.0;
408 if (leaf->total()) newVal = (double)leaf->effectif(modality) / (double)leaf->total();
409 sonsMap[modality] = this->target_->manager()->addTerminalNode(newVal);
410 }
411 return this->target_->manager()->addInternalNode(this->value_, sonsMap);
412 }
413} // namespace gum
Headers of the ChiSquare class.
Safe Iterators for hashtables.
<agrum/FMDP/learning/datastructure/leaves/abstractLeaf.h>
virtual Idx nbModa() const =0
virtual double total() const =0
virtual double effectif(Idx) const =0
Gives the leaf effectif (count) for the given modality.
<agrum/FMDP/learning/datastructure/leaves/concreteLeaf.h>
Base class for discrete random variable.
Safe Const Iterators for hashtables.
Definition hashTable.h:1602
const const_iterator_safe & cendSafe() const noexcept
Returns the safe const_iterator pointing to the end of the hashtable.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
const_iterator_safe cbeginSafe() const
Returns the safe const_iterator pointing to the beginning of the hashtable.
void _addLeaf_(NodeId)
Adds a leaf to the leaf aggregator.
Definition imddi_tpl.h:297
NodeId _insertLeafInFunctionGraph_(AbstractLeaf *, Int2Type< true >)
Inserts a leaf into the target function graph.
Definition imddi_tpl.h:388
void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Updates the given node with the new observation.
Definition imddi_tpl.h:127
void updateFunctionGraph()
Computes the reduced and ordered function graph associated to the learned tree.
Definition imddi_tpl.h:321
void _updateNodeSet_(Set< NodeId > &, const DiscreteVariable *, VariableSelector &)
For each node in the given set, this method checks whether or not we should install the given vari...
Definition imddi_tpl.h:217
Idx _nbTotalObservation_
The total number of observation added to this tree.
Definition imddi.h:189
IMDDI(MultiDimFunctionGraph< double > *target, double attributeSelectionThreshold, double pairSelectionThreshold, gum::VariableSet attributeListe, const DiscreteVariable *learnedValue)
Variable Learner constructor.
Definition imddi_tpl.h:74
void _removeLeaf_(NodeId)
Removes a leaf from the leaf aggregator.
Definition imddi_tpl.h:310
~IMDDI()
Default destructor.
Definition imddi_tpl.h:108
void _updateScore_(const DiscreteVariable *, NodeId, VariableSelector &vs)
Computes the score of the given variables for the given node.
Definition imddi_tpl.h:188
LeafAggregator _lg_
Definition imddi.h:184
void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
Definition imddi_tpl.h:288
Sequence< const DiscreteVariable * > _varOrder_
Definition imddi.h:182
void _downdateScore_(const DiscreteVariable *, NodeId, VariableSelector &vs)
Subtracts the given node's contribution from the score of the given variable.
Definition imddi_tpl.h:200
double _attributeSelectionThreshold_
The threshold above which we consider variables to be dependent.
Definition imddi.h:192
void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
Definition imddi_tpl.h:274
HashTable< NodeId, AbstractLeaf * > _leafMap_
Definition imddi.h:186
void updateGraph()
Updates the tree after a new observation has been added.
Definition imddi_tpl.h:139
void _rebuildFunctionGraph_()
Rebuilds the reduced and ordered function graph.
Definition imddi_tpl.h:332
NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *sonsMap)
Inserts a new leaf node in the internal graphs.
Definition imddi_tpl.h:256
void addObservation(const Observation *)
Adds a new observation to the structure.
Definition imddi_tpl.h:121
virtual void transpose_(NodeId, const DiscreteVariable *)
virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *obsSet)
inserts a new leaf node in internal graphs
HashTable< const DiscreteVariable *, LinkedList< NodeId > * > var2Node_
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > * > nodeId2Database_
virtual void addObservation(const Observation *obs)
Inserts a new observation.
virtual void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Will update internal graph's NodeDatabase of given node with the new observation.
HashTable< NodeId, const DiscreteVariable * > nodeVarMap_
IncrementalGraphLearner(MultiDimFunctionGraph< double > *target, gum::VariableSet attributesSet, const DiscreteVariable *learnVariable)
virtual void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
class LabelizedVariable
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
Safe iterators for Sequence.
Definition sequence.h:1134
Safe iterators for the Set class.
Definition set.h:601
iterator_safe beginSafe() const
The usual safe begin iterator to parse the set.
Definition set_tpl.h:414
const iterator_safe & endSafe() const noexcept
The usual safe end iterator to parse the set.
Definition set_tpl.h:426
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
void clear()
Removes all the elements, if any, from the set.
Definition set_tpl.h:338
<agrum/FMDP/planning/FunctionGraph/variableselector.h>
void updateScore(const DiscreteVariable *var, double score, double secondaryscore)
The set of remaining vars to select among.
void next()
The set of remaining vars to select among.
void downdateScore(const DiscreteVariable *var, double score, double secondaryscore)
The set of remaining vars to select among.
bool isEmpty() const
The set of remaining vars to select among.
const DiscreteVariable * select()
Select the most relevant variable.
void begin()
The set of remaining vars to select among.
const DiscreteVariable * current() const
The set of remaining vars to select among.
bool hasNext() const
The set of remaining vars to select among.
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Headers of the IMDDI class.
Base class for labelized discrete random variables.
Useful macros for maths.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
priority queues (in which an element cannot appear more than once)
#define SOA_ALLOCATE(x)
Provides basic types used in aGrUM.