aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
incrementalGraphLearner.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
49
50// =========================================================================
51#ifndef GUM_INCREMENTAL_GRAPH_LEARNER_H
52#define GUM_INCREMENTAL_GRAPH_LEARNER_H
53// =========================================================================
54// =========================================================================
55// =========================================================================
58
59// =========================================================================
60// =========================================================================
61
62namespace gum {
63
82 template < TESTNAME AttributeSelection, bool isScalar = false >
85
86 public:
87 // ###################################################################
89 // ###################################################################
91
92 // ==========================================================================
103 // ==========================================================================
105 gum::VariableSet attributesSet,
106 const DiscreteVariable* learnVariable);
107
108 // ==========================================================================
110 // ==========================================================================
111 virtual ~IncrementalGraphLearner();
112
113 private:
114 // ==========================================================================
116 // ==========================================================================
118
119 // ==========================================================================
122 // ==========================================================================
124
125 // ==========================================================================
128 // ==========================================================================
130
132
133
134 // ###################################################################
136 // ###################################################################
138
139 public:
140 // ==========================================================================
144 // ==========================================================================
145 virtual void addObservation(const Observation* obs);
146
147 private:
148 // ==========================================================================
152 // ==========================================================================
154
156 if (!valueAssumed_.exists(obs->reward())) valueAssumed_ << obs->reward();
157 }
158
160 if (!valueAssumed_.exists(obs->modality(value_))) valueAssumed_ << obs->modality(value_);
161 }
162
163 // ==========================================================================
167 // ==========================================================================
168 Idx _branchObs_(const Observation* obs, const DiscreteVariable* var) {
169 return _branchObs_(obs, var, Int2Type< isScalar >());
170 }
171
173 return obs->rModality(var);
174 }
175
177 return obs->modality(var);
178 }
179
180 protected:
181 // ==========================================================================
188 // ==========================================================================
189 virtual void updateNodeWithObservation_(const Observation* newObs, NodeId currentNodeId) {
190 nodeId2Database_[currentNodeId]->addObservation(newObs);
191 }
192
194
195 // ###################################################################
197 // ###################################################################
199
200 public:
201 // ==========================================================================
206 // ==========================================================================
207 virtual void updateVar(const DiscreteVariable*);
208
209 // ==========================================================================
211 // ==========================================================================
212 virtual void updateGraph() = 0;
213
214 protected:
215 // ==========================================================================
224 // ==========================================================================
225 void updateNode_(NodeId nody, gum::VariableSet& bestVars);
226
227 // ==========================================================================
229 // ==========================================================================
230 virtual void convertNode2Leaf_(NodeId);
231
232 // ==========================================================================
235 // ==========================================================================
236 virtual void transpose_(NodeId, const DiscreteVariable*);
237
238 // ==========================================================================
245 // ==========================================================================
247 const DiscreteVariable* boundVar);
248
249 // ==========================================================================
257 // ==========================================================================
259 const DiscreteVariable* boundVar,
260 NodeId* sonsMap);
261
262 // ==========================================================================
270 // ==========================================================================
272 const DiscreteVariable* boundVar,
274
275 // ==========================================================================
281 // ==========================================================================
282 virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable* desiredVar);
283
284 // ==========================================================================
289 // ==========================================================================
290 virtual void removeNode_(NodeId removedNodeId);
291
293
294
295 // ###################################################################
297 // ###################################################################
299
300 public:
301 // ==========================================================================
303 // ==========================================================================
304 virtual void updateFunctionGraph() = 0;
305
307
308
309 public:
310 // ==========================================================================
312 // ==========================================================================
313 Size size() { return nodeVarMap_.size(); }
314
315 // ###################################################################
317 // ###################################################################
319
320 public:
321 // ==========================================================================
323 // ==========================================================================
324 NodeId root() const { return this->root_; }
325
326 // ==========================================================================
328 // ==========================================================================
329 bool isTerminal(NodeId ni) const { return !this->nodeSonsMap_.exists(ni); }
330
331 // ==========================================================================
333 // ==========================================================================
334 const DiscreteVariable* nodeVar(NodeId ni) const { return this->nodeVarMap_[ni]; }
335
336 // ==========================================================================
338 // ==========================================================================
339 NodeId nodeSon(NodeId ni, Idx modality) const { return this->nodeSonsMap_[ni][modality]; }
340
341 // ==========================================================================
343 // ==========================================================================
344 Idx nodeNbObservation(NodeId ni) const { return this->nodeId2Database_[ni]->nbObservation(); }
345
346 // ==========================================================================
348 // ==========================================================================
351 varIter != setOfVars_.endSafe();
352 ++varIter)
353 ret->add(**varIter);
354 }
355
357
358 protected:
360
361 // ###################################################################
363 // ###################################################################
365
366 // ==========================================================================
368 // ==========================================================================
370
371 // ==========================================================================
373 // ==========================================================================
375
376 // ==========================================================================
378 // ==========================================================================
380
381 // ==========================================================================
384 // ==========================================================================
386
387 // ==========================================================================
390 // ==========================================================================
392
393 // ==========================================================================
396 // ==========================================================================
398
399 // ==========================================================================
402 // ==========================================================================
404
406
407
410
412
415
417 };
418
419
420} /* namespace gum */
421
423
424#endif // GUM_INCREMENTAL_GRAPH_LEARNER_H
Headers of the Learning Strategy interface.
Base class for discrete random variable.
The class for generic Hash Tables.
Definition hashTable.h:637
<agrum/FMDP/SDyna/IVisitableGraphLearner.h>
const DiscreteVariable * nodeVar(NodeId ni) const
virtual void transpose_(NodeId, const DiscreteVariable *)
Installs given variable to the given node, ensuring that the variable is not present in its subtree.
virtual void updateVar(const DiscreteVariable *)
If a new modality appears to exist for the given variable, call this method to turn every associated nod...
NodeId nodeSon(NodeId ni, Idx modality) const
virtual ~IncrementalGraphLearner()
Default destructor.
virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *obsSet)
Inserts a new leaf node in the internal graphs.
HashTable< const DiscreteVariable *, LinkedList< NodeId > * > var2Node_
Associates to any variable the list of all nodes associated to this variable.
NodeGraphPart model_
The source of nodeId.
Idx _branchObs_(const Observation *obs, const DiscreteVariable *var)
Seek modality assumed in obs for given var.
NodeId root_
The root of the ordered tree.
HashTable< NodeId, NodeId * > nodeSonsMap_
A table giving, for any internal node, the array of its sons; the array index is the modality of the node's associated variable.
void _clearValue_(Int2Type< true >)
In the case where we're learning a function of real values this has to be wiped out upon destruction ...
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > * > nodeId2Database_
This hashtable binds every node to an associated NodeDatabase which handles every observation that co...
Idx _branchObs_(const Observation *obs, const DiscreteVariable *var, Int2Type< true >)
Tag-dispatched overload of _branchObs_: seeks the modality assumed in obs for the given var.
virtual void insertSetOfVars(MultiDimFunctionGraph< double > *ret) const
void _assumeValue_(const Observation *obs, Int2Type< true >)
Tag-dispatched overload of _assumeValue_: records the value assumed by the studied variable for the given observation.
virtual void addObservation(const Observation *obs)
Inserts a new observation.
typename ValueSelect< isScalar, double, Idx >::type ValueType
void updateNode_(NodeId nody, gum::VariableSet &bestVars)
From the given set of variables, selects one at random and installs it on the given node.
virtual void convertNode2Leaf_(NodeId)
Turns the given node into a leaf if not already so.
virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
void _assumeValue_(const Observation *obs)
Get value assumed by studied variable for current observation.
virtual void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
Will update internal graph's NodeDatabase of given node with the new observation.
MultiDimFunctionGraph< double > * target_
The final diagram we're building.
void _assumeValue_(const Observation *obs, Int2Type< false >)
Tag-dispatched overload of _assumeValue_: records the value assumed by the studied variable for the given observation.
HashTable< NodeId, Set< const Observation * > * > leafDatabase_
This hashtable binds to every leaf an associated set of all the observations compatible with it.
void _clearValue_()
Template function dispatcher.
virtual void updateFunctionGraph()=0
Updates target to currently learned graph structure.
HashTable< NodeId, const DiscreteVariable * > nodeVarMap_
Gives for any node its associated variable.
Idx _branchObs_(const Observation *obs, const DiscreteVariable *var, Int2Type< false >)
Tag-dispatched overload of _branchObs_: seeks the modality assumed in obs for the given var.
IncrementalGraphLearner(MultiDimFunctionGraph< double > *target, gum::VariableSet attributesSet, const DiscreteVariable *learnVariable)
Constructor: builds a learner for the given target function graph, attribute set and learned variable.
virtual void removeNode_(NodeId removedNodeId)
Removes a node from the internal graph.
virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
Changes the associated variable of a node.
virtual void updateGraph()=0
Updates the tree after a new observation has been added.
void _clearValue_(Int2Type< false >)
In case where we're learning function of variable behaviour, this should do nothing.
virtual NodeId insertInternalNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, NodeId *sonsMap)
inserts a new internal node in internal graph
virtual void add(const DiscreteVariable &v)
Adds a new var to the variables of the multidimensional matrix.
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
Class for node sets in graph.
INLINE Idx modality(const DiscreteVariable *var) const
Returns the modality assumed by the given variable in this observation.
INLINE Idx rModality(const DiscreteVariable *var) const
Returns the modality assumed by the given variable in this observation.
double reward() const
Returns the reward value carried by this observation.
Safe iterators for the Set class.
Definition set.h:601
Representation of a set.
Definition set.h:131
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
Headers of the NodeDatabase class.