aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
iti_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50// =======================================================
53
55// =======================================================
58// =======================================================
60
61// =======================================================
62
63
64namespace gum {
65
66 // ==========================================================================
68 // ==========================================================================
69
70 // ###################################################################
82 // ###################################################################
83 template < TESTNAME AttributeSelection, bool isScalar >
85 double attributeSelectionThreshold,
86 gum::VariableSet attributeListe,
87 const DiscreteVariable* learnedValue) :
88 IncrementalGraphLearner< AttributeSelection, isScalar >(target, attributeListe, learnedValue),
89 _nbTotalObservation_(0), _attributeSelectionThreshold_(attributeSelectionThreshold) {
90 GUM_CONSTRUCTOR(ITI);
91 _staleTable_.insert(this->root_, false);
92 }
93
94 // ###################################################################
105 // ###################################################################
106 template < TESTNAME AttributeSelection, bool isScalar >
108 double attributeSelectionThreshold,
109 gum::VariableSet attributeListe) :
110 IncrementalGraphLearner< AttributeSelection, isScalar >(
111 target,
112 attributeListe,
113 new LabelizedVariable("Reward", "", 2)),
114 _nbTotalObservation_(0), _attributeSelectionThreshold_(attributeSelectionThreshold) {
115 GUM_CONSTRUCTOR(ITI);
116 _staleTable_.insert(this->root_, false);
117 }
118
119 // ==========================================================================
121 // ==========================================================================
122
123 // ############################################################################
128 // ############################################################################
129 template < TESTNAME AttributeSelection, bool isScalar >
134
135 // ############################################################################
142 // ############################################################################
143 template < TESTNAME AttributeSelection, bool isScalar >
151
152 // ============================================================================
154 // ============================================================================
155
156 // ############################################################################
158 // ############################################################################
159 template < TESTNAME AttributeSelection, bool isScalar >
161 std::vector< NodeId > filo;
162 filo.push_back(this->root_);
164 tensorVars.insert(this->root_, new gum::VariableSet(this->setOfVars_));
165
166
167 while (!filo.empty()) {
168 NodeId currentNodeId = filo.back();
169 filo.pop_back();
170
171 // First we look for the best var to install on the node
172 double bestValue = _attributeSelectionThreshold_;
173 gum::VariableSet bestVars;
174
175 for (auto varIter = tensorVars[currentNodeId]->cbeginSafe();
176 varIter != tensorVars[currentNodeId]->cendSafe();
177 ++varIter)
178 if (this->nodeId2Database_[currentNodeId]->isTestRelevant(*varIter)) {
179 double varValue = this->nodeId2Database_[currentNodeId]->testValue(*varIter);
180 if (varValue >= bestValue) {
181 if (varValue > bestValue) {
182 bestValue = varValue;
183 bestVars.clear();
184 }
185 bestVars.insert(*varIter);
186 }
187 }
188
 189      // Then we install a test on that node, using the best variable(s) found
190 this->updateNode_(currentNodeId, bestVars);
191
 192      // Then we move on to the children if needed
193 if (this->nodeVarMap_[currentNodeId] != this->value_) {
194 for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize(); moda++) {
195 gum::VariableSet* itsTensorVars = new gum::VariableSet(*tensorVars[currentNodeId]);
196 itsTensorVars->erase(this->nodeVarMap_[currentNodeId]);
197 NodeId sonId = this->nodeSonsMap_[currentNodeId][moda];
198 if (_staleTable_[sonId]) {
199 filo.push_back(sonId);
200 tensorVars.insert(sonId, itsTensorVars);
201 }
202 }
203 }
204 }
205
207 nodeIter != tensorVars.endSafe();
208 ++nodeIter)
209 delete nodeIter.val();
210 }
211
212 // ############################################################################
219 // ############################################################################
220 template < TESTNAME AttributeSelection, bool isScalar >
228
229 // ############################################################################
235 // ############################################################################
236 template < TESTNAME AttributeSelection, bool isScalar >
238 const DiscreteVariable* desiredVar) {
239 if (this->nodeVarMap_[currentNodeId] != desiredVar) {
240 _staleTable_[currentNodeId] = true;
242 desiredVar);
243 }
244 }
245
246 // ############################################################################
251 // ############################################################################
252 template < TESTNAME AttributeSelection, bool isScalar >
257
258 // ============================================================================
260 // ============================================================================
261
262 // ############################################################################
264 // ############################################################################
265 template < TESTNAME AttributeSelection, bool isScalar >
267 this->target_->clear();
268 this->target_->manager()->setRootNode(this->_insertNodeInFunctionGraph_(this->root_));
269 }
270
271 // ############################################################################
277 // ############################################################################
278 template < TESTNAME AttributeSelection, bool isScalar >
280 if (this->nodeVarMap_[currentNodeId] == this->value_) {
281 NodeId nody = _insertTerminalNode_(currentNodeId);
282 return nody;
283 }
284
285 if (!this->target_->variablesSequence().exists(this->nodeVarMap_[currentNodeId])) {
286 this->target_->add(*(this->nodeVarMap_[currentNodeId]));
287 }
288
289 NodeId nody = this->target_->manager()->addInternalNode(this->nodeVarMap_[currentNodeId]);
290 for (Idx moda = 0; moda < this->nodeVarMap_[currentNodeId]->domainSize(); ++moda) {
291 NodeId son = this->_insertNodeInFunctionGraph_(this->nodeSonsMap_[currentNodeId][moda]);
292 this->target_->manager()->setSon(nody, moda, son);
293 }
294
295 return nody;
296 }
297
298 // ############################################################################
306 // ############################################################################
307 template < TESTNAME AttributeSelection, bool isScalar >
310 if (!this->target_->variablesSequence().exists(this->value_))
311 this->target_->add(*(this->value_));
312
313 Size tot = this->nodeId2Database_[currentNodeId]->nbObservation();
314 if (tot == Size(0)) return this->target_->manager()->addTerminalNode(0.0);
315
316 NodeId* sonsMap
317 = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * this->value_->domainSize()));
318 for (Idx modality = 0; modality < this->value_->domainSize(); ++modality) {
319 double newVal = 0.0;
320 newVal = (double)this->nodeId2Database_[currentNodeId]->effectif(modality) / (double)tot;
321 sonsMap[modality] = this->target_->manager()->addTerminalNode(newVal);
322 }
323 NodeId nody = this->target_->manager()->addInternalNode(this->value_, sonsMap);
324 return nody;
325 }
326
327 // ############################################################################
335 // ############################################################################
336 template < TESTNAME AttributeSelection, bool isScalar >
339 double value = 0.0;
340 for (auto valIter = this->nodeId2Database_[currentNodeId]->cbeginValues();
341 valIter != this->nodeId2Database_[currentNodeId]->cendValues();
342 ++valIter) {
343 value += (double)valIter.key() * valIter.val();
344 }
345 if (this->nodeId2Database_[currentNodeId]->nbObservation())
346 value /= (double)this->nodeId2Database_[currentNodeId]->nbObservation();
347 NodeId nody = this->target_->manager()->addTerminalNode(value);
348 return nody;
349 }
350} // namespace gum
Headers of the ChiSquare class.
Safe Iterators for hashtables.
Base class for discrete random variable.
The class for generic Hash Tables.
Definition hashTable.h:637
const iterator_safe & endSafe() noexcept
Returns the safe iterator pointing to the end of the hashtable.
const const_iterator_safe & cendSafe() const noexcept
Returns the safe const_iterator pointing to the end of the hashtable.
value_type & insert(const Key &key, const Val &val)
Adds a new element (actually a copy of this element) into the hash table.
iterator_safe beginSafe()
Returns the safe iterator pointing to the beginning of the hashtable.
void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
Definition iti_tpl.h:237
NodeId _insertTerminalNode_(NodeId src)
Insert a terminal node in the target.
Definition iti.h:229
double _attributeSelectionThreshold_
The threshold above which we consider variables to be dependent.
Definition iti.h:282
void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
Definition iti_tpl.h:253
HashTable< NodeId, bool > _staleTable_
Hashtable indicating if given node has been modified (upon receiving new exemple or through a transpo...
Definition iti.h:276
NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
Definition iti_tpl.h:221
NodeId _insertNodeInFunctionGraph_(NodeId src)
Inserts an internal node in the target.
Definition iti_tpl.h:279
void updateGraph()
Updates the internal graph after a new observation has been added.
Definition iti_tpl.h:160
void updateFunctionGraph()
Updates target to currently learned graph structure.
Definition iti_tpl.h:266
Idx _nbTotalObservation_
The total number of observation added to this tree.
Definition iti.h:279
void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Will update internal graph's NodeDatabase of given node with the new observation.
Definition iti_tpl.h:144
ITI(MultiDimFunctionGraph< double > *target, double attributeSelectionThreshold, gum::VariableSet attributeListe, const DiscreteVariable *learnedValue)
ITI constructor for functions describing the behaviour of one variable according to a set of other va...
Definition iti_tpl.h:84
void addObservation(const Observation *obs)
Inserts a new observation.
Definition iti_tpl.h:130
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > * > nodeId2Database_
virtual void addObservation(const Observation *obs)
Inserts a new observation.
virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
virtual void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Will update internal graph's NodeDatabase of given node with the new observation.
HashTable< NodeId, const DiscreteVariable * > nodeVarMap_
IncrementalGraphLearner(MultiDimFunctionGraph< double > *target, gum::VariableSet attributesSet, const DiscreteVariable *learnVariable)
virtual void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
class LabelizedVariable
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
void erase(const Key &k)
Erases an element from the set.
Definition set_tpl.h:582
void clear()
Removes all the elements, if any, from the set.
Definition set_tpl.h:338
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Headers of the ITI class.
Base class for labelized discrete random variables.
Useful macros for maths.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
priority queues (in which an element cannot appear more than once)
#define SOA_ALLOCATE(x)
Provides basic types used in aGrUM.