aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
incrementalGraphLearner_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50// =======================================================
51#include <queue>
52// =======================================================
55
58// =======================================================
61// =======================================================
63
64// =======================================================
65
66namespace gum {
67
68 // ============================================================================
70 // ============================================================================
71
72 // ############################################################################
83 // ############################################################################
84 template < TESTNAME AttributeSelection, bool isScalar >
101
102 // ############################################################################
104 // ############################################################################
// Destructor of IncrementalGraphLearner (the GUM_DESTRUCTOR call below names the class).
// NOTE(review): the signature line (listing line 106) is missing from this extract.
// Releases every heap object reachable from the learner's hash tables, then
// calls _clearValue_().
105 template < TESTNAME AttributeSelection, bool isScalar >
// Delete the NodeDatabase attached to each node.
107 for (auto nodeIter = nodeId2Database_.beginSafe(); nodeIter != nodeId2Database_.endSafe();
108 ++nodeIter)
109 delete nodeIter.val();
110
// Release each sons array; its size is one NodeId per value of the variable
// bound to the owning node.
111 for (auto nodeIter = nodeSonsMap_.beginSafe(); nodeIter != nodeSonsMap_.endSafe(); ++nodeIter)
112 SOA_DEALLOCATE(nodeIter.val(), sizeof(NodeId) * nodeVarMap_[nodeIter.key()]->domainSize());
113
// Delete the per-variable node lists.
114 for (auto varIter = var2Node_.beginSafe(); varIter != var2Node_.endSafe(); ++varIter)
115 delete varIter.val();
116
// Delete the observation sets held by the leaves (only the sets are deleted
// here, not the observations they point to).
117 for (auto nodeIter = leafDatabase_.beginSafe(); nodeIter != leafDatabase_.endSafe(); ++nodeIter)
118 delete nodeIter.val();
119
120 _clearValue_();
121
122 GUM_DESTRUCTOR(IncrementalGraphLearner);
123 }
124
125 // ============================================================================
127 // ============================================================================
128
129 // ############################################################################
134 // ############################################################################
// Inserts a new observation into the learner.
// NOTE(review): the first signature line (listing line 136) is missing from this
// extract; the member index at the end of the listing gives
// `void addObservation(const Observation* obs)`.
// The observation is pushed down from root_ to a leaf: every node met on the
// way is updated with it, and the reached leaf also stores it in leafDatabase_.
135 template < TESTNAME AttributeSelection, bool isScalar >
137 const Observation* newObs) {
138 _assumeValue_(newObs);
139
140 // Then we walk down the tree
141 NodeId currentNodeId = root_;
142
// A node that has an entry in nodeSonsMap_ is an internal node.
143 while (nodeSonsMap_.exists(currentNodeId)) {
144 // On each encountered node, we update the database
145 updateNodeWithObservation_(newObs, currentNodeId);
146
147 // Then we select the next node to go through: the son indexed by the
// branch that _branchObs_ returns for the variable bound to this node
148 currentNodeId = nodeSonsMap_[currentNodeId][_branchObs_(newObs, nodeVarMap_[currentNodeId])];
149 }
150
151 // One final insertion, into the leaf we reached
152 updateNodeWithObservation_(newObs, currentNodeId);
153 leafDatabase_[currentNodeId]->insert(newObs);
154 }
155
156 // ============================================================================
158 // ============================================================================
159
160 // ############################################################################
164 // ############################################################################
// Turns every node currently bound to `var` into a leaf.
// NOTE(review): the first signature line (listing line 166) is missing from this
// extract; the member index at the end of the listing gives
// `virtual void updateVar(const DiscreteVariable*)`.
165 template < TESTNAME AttributeSelection, bool isScalar >
167 const DiscreteVariable* var) {
168 Link< NodeId >* nodIter = var2Node_[var]->list();
169 Link< NodeId >* nni = nullptr;
170 while (nodIter) {
// Fetch the successor before converting: convertNode2Leaf_ ends up in
// chgNodeBoundVar_, which removes the node's link from the list of its
// current variable, i.e. from the list being walked here.
171 nni = nodIter->nextLink();
172 convertNode2Leaf_(nodIter->element());
173 nodIter = nni;
174 }
175 }
176
177 // ############################################################################
185 // ############################################################################
// Picks one variable at random from `varsOfInterest` and installs it as the
// test of `updatedNode`, unless the node already tests one of them.
// An empty set turns the node into a leaf.
// NOTE(review): the first signature line (listing line 187) is missing from this
// extract; the member index at the end of the listing gives
// `void updateNode_(NodeId nody, gum::VariableSet& bestVars)`.
186 template < TESTNAME AttributeSelection, bool isScalar >
188 NodeId updatedNode,
189 gum::VariableSet& varsOfInterest) {
190 // If this node has no interesting variable, we turn it into a leaf
191 if (varsOfInterest.empty()) {
192 convertNode2Leaf_(updatedNode);
193 return;
194 }
195
196 // If this node already has one of the best variables installed as test, we
197 // move on
198 if (nodeVarMap_.exists(updatedNode) && varsOfInterest.exists(nodeVarMap_[updatedNode])) {
199 return;
200 }
201
202 // In any other case we have to install variable as best test
// randomValue(max) is documented in this listing as returning an Idx between
// 0 and max-1 included, so `randy` is a valid position in the non-empty set.
203 Idx randy = randomValue(varsOfInterest.size()), basc = 0;
204 SetConstIteratorSafe< const DiscreteVariable* > varIter;
// Empty-bodied loop: it only advances varIter `randy` steps from the beginning.
205 for (varIter = varsOfInterest.cbeginSafe(), basc = 0;
206 varIter != varsOfInterest.cendSafe() && basc < randy;
207 ++varIter, basc++)
208 ;
209
210 transpose_(updatedNode, *varIter);
211 }
212
213 // ############################################################################
215 // ############################################################################
// Turns the given node into a leaf if it is not one already (a leaf is a node
// bound to value_). The observation sets of the sons are merged into the
// node's own set and the sons are removed.
// NOTE(review): listing lines 217 (signature), 225, 230 and 236 (presumably the
// function's closing brace) are missing from this extract.
216 template < TESTNAME AttributeSelection, bool isScalar >
218 NodeId currentNodeId) {
219 if (nodeVarMap_[currentNodeId] != value_) {
// Fresh, empty observation set that will collect the sons' observations.
220 leafDatabase_.insert(currentNodeId, new Set< const Observation* >());
221
222 // Resolving tensor sons issue
223 for (Idx modality = 0; modality < nodeVarMap_[currentNodeId]->domainSize(); ++modality) {
224 NodeId sonId = nodeSonsMap_[currentNodeId][modality];
// NOTE(review): listing line 225 is missing here. The next statement reads
// leafDatabase_[sonId], so the son must already be a leaf at this point;
// presumably the missing line converts it — confirm against the real file.
226 (*leafDatabase_[currentNodeId]) = (*leafDatabase_[currentNodeId]) + *(leafDatabase_[sonId]);
227 removeNode_(sonId);
228 }
229
// NOTE(review): listing line 230 is missing; the line below is the tail of a
// call whose size argument is one NodeId per value of the node's variable
// (presumably the SOA_DEALLOCATE of this node's sons array — confirm).
231 sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
232 nodeSonsMap_.erase(currentNodeId);
233
// Rebind the node to value_, which is what marks it as a leaf.
234 chgNodeBoundVar_(currentNodeId, value_);
235 }
237
238 // ############################################################################
241 // ############################################################################
// Installs `desiredVar` as the variable tested by `currentNodeId`.
// Three cases: the node already tests desiredVar (nothing to do); the node is a
// leaf (it becomes an internal node with one new leaf per value of desiredVar);
// the node is internal (desiredVar is first pushed into every son, then the
// two levels are swapped).
// NOTE(review): listing lines 243 (signature), 258, 261-263, 267, 270, 286 and
// 316 are missing from this extract; the declarations of `dbMap`, `obsIter`
// and `sonDB` are among them.
242 template < TESTNAME AttributeSelection, bool isScalar >
244 NodeId currentNodeId,
245 const DiscreteVariable* desiredVar) {
246 // **************************************************************************************
247 // If the current node is already bound to the variable we want to bring to it,
248 // there is nothing to do
249 if (nodeVarMap_[currentNodeId] == desiredVar) { return; }
250
251 // **************************************************************************************
252 // If the current node is a leaf,
253 // we must artificially insert a node bound to the variable
254 if (nodeVarMap_[currentNodeId] == value_) {
255 // We turned this leaf into an internal node.
256 // This mean that we'll need to install children leaves for each value of
257 // desiredVar
259 // First We must prepare these new leaves NodeDatabases and Sets<const
260 // Observation*>
// NOTE(review): the allocation of `dbMap` (listing lines 261-263) is not visible.
264 Set< const Observation* >** obsetMap = static_cast< Set< const Observation* >** >(
265 SOA_ALLOCATE(sizeof(Set< const Observation* >*) * desiredVar->domainSize()));
266 for (Idx modality = 0; modality < desiredVar->domainSize(); ++modality) {
// NOTE(review): listing line 267 is missing; presumably it fills dbMap[modality] — confirm.
268 obsetMap[modality] = new Set< const Observation* >();
269 }
// NOTE(review): the head of the loop over this leaf's observations (listing
// line 270) is missing; the three lines below are its remainder.
271 = leafDatabase_[currentNodeId]->beginSafe();
272 leafDatabase_[currentNodeId]->endSafe() != obsIter;
273 ++obsIter) {
// Each observation goes to the slot of the branch it takes on desiredVar.
274 dbMap[_branchObs_(*obsIter, desiredVar)]->addObservation(*obsIter);
275 obsetMap[_branchObs_(*obsIter, desiredVar)]->insert(*obsIter);
276 }
277
278 // Then we can install each new leaves (and put in place the sonsMap)
279 NodeId* sonsMap
280 = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * desiredVar->domainSize()));
281 for (Idx modality = 0; modality < desiredVar->domainSize(); ++modality)
282 sonsMap[modality] = insertLeafNode_(dbMap[modality], value_, obsetMap[modality]);
283
284 // Some necessary clean up
// Only the two temporary pointer arrays are released; the objects they pointed
// to were handed to insertLeafNode_ above.
285 SOA_DEALLOCATE(dbMap,
287 * desiredVar->domainSize());
288 SOA_DEALLOCATE(obsetMap, sizeof(Set< const Observation* >*) * desiredVar->domainSize());
289
290 // And finally we can turn the node into an internal node associated to
291 // desiredVar
292 chgNodeBoundVar_(currentNodeId, desiredVar);
293 nodeSonsMap_.insert(currentNodeId, sonsMap);
294
295 return;
296 }
297
298 // *************************************************************************************
299 // Remains the general case where currentNodeId is an internal node.
300
301 // First we ensure that children node use desiredVar as variable
302 for (Idx modality = 0; modality < nodeVarMap_[currentNodeId]->domainSize(); ++modality)
303 transpose_(nodeSonsMap_[currentNodeId][modality], desiredVar);
304
305 // Sequence<NodeDatabase<AttributeSelection, isScalar>*>
306 // sonsNodeDatabase =
307 // nodeId2Database_[currentNodeId]->splitOnVar(desiredVar);
308 NodeId* sonsMap
309 = static_cast< NodeId* >(SOA_ALLOCATE(sizeof(NodeId) * desiredVar->domainSize()));
310
311 // Then we create the new mapping
312 for (Idx desiredVarModality = 0; desiredVarModality < desiredVar->domainSize();
313 ++desiredVarModality) {
// One new son per value of desiredVar; it keeps testing the node's current
// variable, so its sons array has one entry per value of that variable.
314 NodeId* grandSonsMap = static_cast< NodeId* >(
315 SOA_ALLOCATE(sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize()));
// NOTE(review): listing line 316 (the left-hand side declaring `sonDB`) is missing.
317 = new NodeDatabase< AttributeSelection, isScalar >(&setOfVars_, value_);
318 for (Idx currentVarModality = 0;
319 currentVarModality < nodeVarMap_[currentNodeId]->domainSize();
320 ++currentVarModality) {
// Valid because every old son was transposed to desiredVar above, so its
// sons array is indexed by desiredVar's values.
321 grandSonsMap[currentVarModality]
322 = nodeSonsMap_[nodeSonsMap_[currentNodeId][currentVarModality]][desiredVarModality];
// The new son's database accumulates the databases of the grandsons it adopts.
323 sonDB->operator+=((*nodeId2Database_[grandSonsMap[currentVarModality]]));
324 }
325
326 sonsMap[desiredVarModality]
327 = insertInternalNode_(sonDB, nodeVarMap_[currentNodeId], grandSonsMap);
328 }
329
330 // Finally we clean the old remaining nodes
331 for (Idx currentVarModality = 0; currentVarModality < nodeVarMap_[currentNodeId]->domainSize();
332 ++currentVarModality) {
333 removeNode_(nodeSonsMap_[currentNodeId][currentVarModality]);
334 }
335
336 // We suppress the old sons map and remap to the new one
// The size is computed before chgNodeBoundVar_ below, i.e. still from the
// old variable's domain, matching how the old array was dimensioned.
337 SOA_DEALLOCATE(nodeSonsMap_[currentNodeId],
338 sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
339 nodeSonsMap_[currentNodeId] = sonsMap;
340
341 chgNodeBoundVar_(currentNodeId, desiredVar);
342 }
343
344 // ############################################################################
351 // ############################################################################
// Creates a new node bound to `boundVar` and registers it in the learner's
// tables. Returns the id of the new node.
// NOTE(review): the first signature lines (listing lines 353-354) are missing
// from this extract; the member index at the end of the listing gives
// `virtual NodeId insertNode_(NodeDatabase<AttributeSelection, isScalar>* nDB,
// const DiscreteVariable* boundVar)`.
352 template < TESTNAME AttributeSelection, bool isScalar >
355 const DiscreteVariable* boundVar) {
// model_ is the source of node ids.
356 NodeId newNodeId = model_.addNode();
357 nodeVarMap_.insert(newNodeId, boundVar);
// The pointer is stored as is; the destructor and removeNode_ delete it later.
358 nodeId2Database_.insert(newNodeId, nDB);
359 var2Node_[boundVar]->addLink(newNodeId);
360
361 needUpdate_ = true;
362
363 return newNodeId;
364 }
365
366 // ############################################################################
374 // ############################################################################
// Creates a new internal node: a node inserted through insertNode_ plus an
// entry in nodeSonsMap_ holding the given sons array. Returns the new node id.
// NOTE(review): the first signature lines (listing lines 376-377) are missing
// from this extract; the member index at the end of the listing gives
// `virtual NodeId insertInternalNode_(NodeDatabase<AttributeSelection, isScalar>* nDB,
// const DiscreteVariable* boundVar, NodeId* sonsMap)`.
375 template < TESTNAME AttributeSelection, bool isScalar >
378 const DiscreteVariable* boundVar,
379 NodeId* sonsMap) {
380 NodeId newNodeId = this->insertNode_(nDB, boundVar);
// The array pointer is stored as is (no copy is made here).
381 nodeSonsMap_.insert(newNodeId, sonsMap);
382 return newNodeId;
383 }
384
385 // ############################################################################
393 // ############################################################################
// Creates a new leaf node: a node inserted through insertNode_ plus an entry in
// leafDatabase_ holding the given observation set. Returns the new node id.
// NOTE(review): listing lines 395-396 and 398 (parts of the signature) are
// missing from this extract; the member index at the end of the listing gives
// `virtual NodeId insertLeafNode_(NodeDatabase<AttributeSelection, isScalar>* nDB,
// const DiscreteVariable* boundVar, Set<const Observation*>* obsSet)`.
394 template < TESTNAME AttributeSelection, bool isScalar >
397 const DiscreteVariable* boundVar,
399 NodeId newNodeId = this->insertNode_(nDB, boundVar);
// The set pointer is stored as is (no copy is made here).
400 leafDatabase_.insert(newNodeId, obsSet);
401 return newNodeId;
402 }
403
404 // ############################################################################
410 // ############################################################################
// Rebinds `currentNodeId` to `desiredVar` and keeps var2Node_, nodeVarMap_ and
// leafDatabase_ consistent with the new binding.
// NOTE(review): the first signature line (listing line 412) is missing from this
// extract; the member index at the end of the listing gives
// `virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable* desiredVar)`.
411 template < TESTNAME AttributeSelection, bool isScalar >
413 NodeId currentNodeId,
414 const DiscreteVariable* desiredVar) {
415 if (nodeVarMap_[currentNodeId] == desiredVar) return;
416
// Move the node from the old variable's node list to the new one's.
417 var2Node_[nodeVarMap_[currentNodeId]]->searchAndRemoveLink(currentNodeId);
418 var2Node_[desiredVar]->addLink(currentNodeId);
419 nodeVarMap_[currentNodeId] = desiredVar;
420
// The node is no longer a leaf: drop its observation set.
421 if (nodeVarMap_[currentNodeId] != value_ && leafDatabase_.exists(currentNodeId)) {
422 delete leafDatabase_[currentNodeId];
423 leafDatabase_.erase(currentNodeId);
424 }
425
// The node became a leaf and has no observation set yet: give it an empty one.
426 if (nodeVarMap_[currentNodeId] == value_ && !leafDatabase_.exists(currentNodeId)) {
427 leafDatabase_.insert(currentNodeId, new Set< const Observation* >());
428 }
429
430 needUpdate_ = true;
431 }
432
433 // ############################################################################
438 // ############################################################################
// Removes `currentNodeId` from the model and from every table of the learner,
// releasing its sons array, its observation set and its NodeDatabase.
// The node's sons are not removed by this function.
// NOTE(review): the signature line (listing line 440) is missing from this extract.
439 template < TESTNAME AttributeSelection, bool isScalar >
441 // Remove the node id
442 model_.eraseNode(currentNodeId);
443
444 // Remove the sons array
445 if (nodeSonsMap_.exists(currentNodeId)) {
446 SOA_DEALLOCATE(nodeSonsMap_[currentNodeId],
447 sizeof(NodeId) * nodeVarMap_[currentNodeId]->domainSize());
448 nodeSonsMap_.erase(currentNodeId);
449 }
450
// Remove the observation set, if the node holds one
451 if (leafDatabase_.exists(currentNodeId)) {
452 delete leafDatabase_[currentNodeId];
453 leafDatabase_.erase(currentNodeId);
454 }
455
456 // Remove the variable binding (done after the sons array, whose size
// computation above still needs nodeVarMap_[currentNodeId])
457 var2Node_[nodeVarMap_[currentNodeId]]->searchAndRemoveLink(currentNodeId);
458 nodeVarMap_.erase(currentNodeId);
459
460 // Remove the NodeDatabase
461 delete nodeId2Database_[currentNodeId];
462 nodeId2Database_.erase(currentNodeId);
463
464 needUpdate_ = true;
465 }
466} // namespace gum
Headers of the ChiSquare class.
Base class for discrete random variable.
virtual Size domainSize() const =0
virtual void transpose_(NodeId, const DiscreteVariable *)
virtual void updateVar(const DiscreteVariable *)
If a new modality appears to exist for the given variable, call this method to turn every associated nod...
virtual ~IncrementalGraphLearner()
Default destructor.
virtual NodeId insertLeafNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, Set< const Observation * > *obsSet)
inserts a new leaf node in internal graphs
HashTable< const DiscreteVariable *, LinkedList< NodeId > * > var2Node_
Associates to any variable the list of all nodes associated to this variable.
NodeGraphPart model_
The source of nodeId.
Idx _branchObs_(const Observation *obs, const DiscreteVariable *var)
NodeId root_
The root of the ordered tree.
HashTable< NodeId, NodeDatabase< AttributeSelection, isScalar > * > nodeId2Database_
This hashtable binds every node to an associated NodeDatabase which handles every observation that co...
virtual void addObservation(const Observation *obs)
Inserts a new observation.
void updateNode_(NodeId nody, gum::VariableSet &bestVars)
From the given set of variables, randomly selects one and installs it on the given node.
virtual void convertNode2Leaf_(NodeId)
Turns the given node into a leaf if not already so.
virtual NodeId insertNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar)
inserts a new node in internal graph
void _assumeValue_(const Observation *obs)
Get value assumed by studied variable for current observation.
virtual void updateNodeWithObservation_(const Observation *newObs, NodeId currentNodeId)
MultiDimFunctionGraph< double > * target_
The final diagram we're building.
HashTable< NodeId, Set< const Observation * > * > leafDatabase_
HashTable< NodeId, const DiscreteVariable * > nodeVarMap_
IncrementalGraphLearner(MultiDimFunctionGraph< double > *target, gum::VariableSet attributesSet, const DiscreteVariable *learnVariable)
Default constructor.
virtual void chgNodeBoundVar_(NodeId chgedNodeId, const DiscreteVariable *desiredVar)
virtual NodeId insertInternalNode_(NodeDatabase< AttributeSelection, isScalar > *nDB, const DiscreteVariable *boundVar, NodeId *sonsMap)
inserts a new internal node in internal graph
<agrum/FMDP/learning/datastructure/nodeDatabase.h>
void addObservation(const Observation *)
Nb observation taken into account by this instance.
Safe iterators for the Set class.
Definition set.h:601
Representation of a set.
Definition set.h:131
Size size() const noexcept
Returns the number of elements in the set.
Definition set_tpl.h:636
const_iterator_safe cbeginSafe() const
The usual safe begin iterator to parse the set.
Definition set_tpl.h:420
bool exists(const Key &k) const
Indicates whether a given elements belong to the set.
Definition set_tpl.h:533
const const_iterator_safe & cendSafe() const noexcept
The usual safe end iterator to parse the set.
Definition set_tpl.h:432
bool empty() const noexcept
Indicates whether the set is the empty set.
Definition set_tpl.h:642
Base class for discrete random variable.
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
Idx randomValue(const Size max=2)
Returns a random Idx between 0 and max-1 included.
Headers of the interface specifying functions to be implemented by any incremental learner.
Useful macros for maths.
Priority queues in which the same element can appear several times.
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet
#define SOA_DEALLOCATE(x, y)
#define SOA_ALLOCATE(x)
Provides basic types used in aGrUM.
Contains useful methods for random stuff.