aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
fmdpLearner_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49
50// =========================================================================
52
53// =========================================================================
54
55namespace gum {
56
57 // ==========================================================================
58 // Constructor & destructor.
59 // ==========================================================================
60
61 // ###################################################################
62 // Default constructor
//
// Parameters, going by the declaration summary further down this page
// (learningThreshold, actionReward, similarityThreshold = 0.05):
//   lT           - learning threshold
//   actionReward - presumably stored in _actionReward_, which initialize()
//                  and addObservation() use to choose between one reward
//                  learner per action and a single shared one -- TODO confirm
//   sT           - similarity threshold
//
// NOTE(review): the qualifier line (listing line 67) and the member
// initializer list (listing line 69) are absent from this extract, so which
// members are initialized from the arguments cannot be checked here.
63 // ###################################################################
64 template < TESTNAME VariableAttributeSelection,
65 TESTNAME RewardAttributeSelection,
66 LEARNERNAME LearnerSelection >
68 FMDPLearner(double lT, bool actionReward, double sT) :
70 GUM_CONSTRUCTOR(FMDPLearner);
// No shared reward learner exists until initialize() creates one.
71 _rewardLearner_ = nullptr;
72 }
73
74 // ###################################################################
75 // Default destructor
//
// Walks _actionLearners_ and, for every action, deletes each variable
// learner stored in that action's table, then the table itself, then the
// per-action reward learner if one is registered for the same key.
//
// NOTE(review): the function header (listing lines 80-81) and listing
// line 93 are absent from this extract; line 93 presumably releases
// _rewardLearner_ -- TODO confirm against the real source.
76 // ###################################################################
77 template < TESTNAME VariableAttributeSelection,
78 TESTNAME RewardAttributeSelection,
79 LEARNERNAME LearnerSelection >
82 for (auto actionIter = _actionLearners_.beginSafe(); actionIter != _actionLearners_.endSafe();
83 ++actionIter) {
// Delete every learner held in this action's table ...
84 for (auto learnerIter = actionIter.val()->beginSafe();
85 learnerIter != actionIter.val()->endSafe();
86 ++learnerIter)
87 delete learnerIter.val();
// ... then the table itself.
88 delete actionIter.val();
// A per-action reward learner is deleted only if one exists for this key.
89 if (_actionRewardLearners_.exists(actionIter.key()))
90 delete _actionRewardLearners_[actionIter.key()];
91 }
92
94
95 GUM_DESTRUCTOR(FMDPLearner);
96 }
97
98 // ==========================================================================
99 //
100 // ==========================================================================
101
102 // ###################################################################
103 // Initializes the learner
//
// Stores the given FMDP, resets _modaMax_ and _rmax_, and then builds:
//   - for every action, a VarLearnerTable holding one learner per variable;
//     each learner's target graph is named and registered in the FMDP as
//     the transition for that (action, variable) pair;
//   - either one reward learner per action (when _actionReward_ is set) or
//     a single shared reward learner (otherwise), with the target graph
//     registered in the FMDP accordingly.
//
// NOTE(review): the function header (listing lines 108-109; the page summary
// gives `void initialize(FMDP< double > *fmdp)`) and the declarations of
// `varTrans` and `reward` (listing lines 128, 138, 147) are absent from this
// extract. They presumably come from _instantiateFunctionGraph_() -- TODO
// confirm.
104 // ###################################################################
105 template < TESTNAME VariableAttributeSelection,
106 TESTNAME RewardAttributeSelection,
107 LEARNERNAME LearnerSelection >
110 _fmdp_ = fmdp;
111
112 _modaMax_ = 0;
113 _rmax_ = 0.0;
114
// Collect the FMDP's variables and track the largest domain size among them.
115 gum::VariableSet mainVariables;
116 for (auto varIter = _fmdp_->beginVariables(); varIter != _fmdp_->endVariables(); ++varIter) {
117 mainVariables.insert(*varIter);
118 _modaMax_ = _modaMax_ < (*varIter)->domainSize() ? (*varIter)->domainSize() : _modaMax_;
119 }
120
121 for (auto actionIter = _fmdp_->beginActions(); actionIter != _fmdp_->endActions();
122 ++actionIter) {
123 // Adding a Hashtable for the action
124 _actionLearners_.insert(*actionIter, new VarLearnerTable());
125
126 // Adding a learner for each variable
127 for (auto varIter = _fmdp_->beginVariables(); varIter != _fmdp_->endVariables(); ++varIter) {
// Name the transition graph, hand it to the FMDP, and attach a learner to
// it. The learned variable is main2prime(*varIter), presumably the primed
// (next-step) counterpart of the variable -- TODO confirm.
129 varTrans->setTableName("ACTION : " + _fmdp_->actionName(*actionIter)
130 + " - VARIABLE : " + (*varIter)->name());
131 _fmdp_->addTransitionForAction(*actionIter, *varIter, varTrans);
132 _actionLearners_[*actionIter]->insert(
133 (*varIter),
134 _instantiateVarLearner_(varTrans, mainVariables, _fmdp_->main2prime(*varIter)));
135 }
136
// Per-action reward: one reward graph and one reward learner per action.
137 if (_actionReward_) {
139 reward->setTableName("REWARD - ACTION : " + _fmdp_->actionName(*actionIter));
140 _fmdp_->addRewardForAction(*actionIter, reward);
141 _actionRewardLearners_.insert(*actionIter,
142 _instantiateRewardLearner_(reward, mainVariables));
143 }
144 }
145
// Shared reward: a single reward graph and learner for the whole FMDP.
146 if (!_actionReward_) {
148 reward->setTableName("REWARD");
149 _fmdp_->addReward(reward);
150 _rewardLearner_ = _instantiateRewardLearner_(reward, mainVariables);
151 }
152 }
153
154 // ###################################################################
155 // Gives to the learner a new transition
//
// For the given action, passes the observation to the learner of every FMDP
// variable and asks each one to update its graph. It then does the same for
// the reward learner (the per-action one when _actionReward_ is set, the
// shared one otherwise) and raises _rmax_ to |newObs->reward()| if that is
// larger.
//
//   actionId - key used to look up the action's learners
//   newObs   - observation forwarded to the learners
//   returns  - always false; the meaning of the return value is not visible
//              in this extract
//
// NOTE(review): the line carrying the return type and the
// `FMDPLearner< ... >::` qualifier (listing line 160) is absent from this
// extract.
156 // ###################################################################
157 template < TESTNAME VariableAttributeSelection,
158 TESTNAME RewardAttributeSelection,
159 LEARNERNAME LearnerSelection >
161 addObservation(Idx actionId, const Observation* newObs) {
162 for (SequenceIteratorSafe< const DiscreteVariable* > varIter = _fmdp_->beginVariables();
163 varIter != _fmdp_->endVariables();
164 ++varIter) {
// NOTE(review): the same lookup is done twice per variable; a local pointer
// would avoid the second one. The lookup passes nullptr as the default and
// dereferences the result straight away, so this presumably relies on
// initialize() having registered a learner for every (action, variable)
// pair -- TODO confirm.
165 _actionLearners_[actionId]->getWithDefault(*varIter, nullptr)->addObservation(newObs);
166 _actionLearners_[actionId]->getWithDefault(*varIter, nullptr)->updateGraph();
167 }
168
169 if (_actionReward_) {
170 _actionRewardLearners_[actionId]->addObservation(newObs);
171 _actionRewardLearners_[actionId]->updateGraph();
172 } else {
173 _rewardLearner_->addObservation(newObs);
174 _rewardLearner_->updateGraph();
175 }
176
// Keep the largest absolute reward seen so far.
177 _rmax_ = _rmax_ < std::abs(newObs->reward()) ? std::abs(newObs->reward()) : _rmax_;
178
179 return false;
180 }
181
182 // ###################################################################
183 // Size of the learner
//
// Returns the sum of size() over every learner held: for each action, one
// learner per FMDP variable plus, when _actionReward_ is set, that action's
// reward learner; otherwise the single shared reward learner is added once
// at the end.
//
// NOTE(review): the line carrying the return type and the
// `FMDPLearner< ... >::` qualifier (listing line 188) is absent from this
// extract; the page summary gives the signature as `Size size()`.
184 // ###################################################################
185 template < TESTNAME VariableAttributeSelection,
186 TESTNAME RewardAttributeSelection,
187 LEARNERNAME LearnerSelection >
189 size() {
190 Size s = 0;
191 for (SequenceIteratorSafe< Idx > actionIter = _fmdp_->beginActions();
192 actionIter != _fmdp_->endActions();
193 ++actionIter) {
// Transition learners of this action, one per variable.
194 for (SequenceIteratorSafe< const DiscreteVariable* > varIter = _fmdp_->beginVariables();
195 varIter != _fmdp_->endVariables();
196 ++varIter)
197 s += _actionLearners_[*actionIter]->getWithDefault(*varIter, nullptr)->size();
198 if (_actionReward_) s += _actionRewardLearners_[*actionIter]->size();
199 }
200
// The shared reward learner is counted once, outside the action loop.
201 if (!_actionReward_) s += _rewardLearner_->size();
202
203 return s;
204 }
205
206 // ###################################################################
207 // Starts an update of datastructure in the associated FMDP
//
// Calls updateFunctionGraph() on every learner: for each action, the learner
// of each FMDP variable and, when _actionReward_ is set, that action's reward
// learner; otherwise the single shared reward learner is updated once at the
// end.
//
// NOTE(review): the line carrying the return type and the
// `FMDPLearner< ... >::` qualifier (listing line 212) is absent from this
// extract; the page summary gives the signature as `void updateFMDP()`.
208 // ###################################################################
209 template < TESTNAME VariableAttributeSelection,
210 TESTNAME RewardAttributeSelection,
211 LEARNERNAME LearnerSelection >
213 updateFMDP() {
214 for (SequenceIteratorSafe< Idx > actionIter = _fmdp_->beginActions();
215 actionIter != _fmdp_->endActions();
216 ++actionIter) {
217 for (SequenceIteratorSafe< const DiscreteVariable* > varIter = _fmdp_->beginVariables();
218 varIter != _fmdp_->endVariables();
219 ++varIter)
220 _actionLearners_[*actionIter]->getWithDefault(*varIter, nullptr)->updateFunctionGraph();
221 if (_actionReward_) _actionRewardLearners_[*actionIter]->updateFunctionGraph();
222 }
223
224 if (!_actionReward_) _rewardLearner_->updateFunctionGraph();
225 }
226} // End of namespace gum
HashTable< Idx, RewardLearnerType * > _actionRewardLearners_
const double _similarityThreshold_
RewardLearnerType * _instantiateRewardLearner_(MultiDimFunctionGraph< double > *target, gum::VariableSet &mainVariables)
Creates the reward learner that learns into the given target function graph.
~FMDPLearner()
Default destructor.
MultiDimFunctionGraph< double > * _instantiateFunctionGraph_()
Initializes the learner.
double _modaMax_
Largest domain size among the variables of the FMDP (computed in initialize()).
Size size()
Returns the learner size: the sum of the sizes of all transition learners and reward learners.
void updateFMDP()
Starts an update of datastructure in the associated FMDP.
FMDP< double > * _fmdp_
The FMDP to store the learned model.
double _rmax_
Largest absolute reward among the observations added so far.
HashTable< const DiscreteVariable *, VariableLearnerType * > VarLearnerTable
Definition fmdpLearner.h:86
const double _learningThreshold_
RewardLearnerType * _rewardLearner_
FMDPLearner(double learningThreshold, bool actionReward, double similarityThreshold=0.05)
Default constructor.
void initialize(FMDP< double > *fmdp)
Initializes the learner.
HashTable< Idx, VarLearnerTable * > _actionLearners_
VariableLearnerType * _instantiateVarLearner_(MultiDimFunctionGraph< double > *target, gum::VariableSet &mainVariables, const DiscreteVariable *learnedVar)
Creates the learner for one variable's transition, learning into the given target function graph.
bool addObservation(Idx actionId, const Observation *obs)
Gives to the learner a new transition.
void setTableName(const std::string &name)
Sets the name of the table represented by this structure.
double reward() const
Returns the reward recorded in this observation.
Safe iterators for Sequence.
Definition sequence.h:1134
void insert(const Key &k)
Inserts a new element into the set.
Definition set_tpl.h:539
Headers of the FMDPLearner class.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size Idx
Type for indexes.
Definition types.h:79
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
Set< const DiscreteVariable * > VariableSet