aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
paramEstimator.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
47#ifndef GUM_LEARNING_PARAM_ESTIMATOR_H
48#define GUM_LEARNING_PARAM_ESTIMATOR_H
49
50#include <agrum/agrum.h>
51
54
55#include <type_traits>
56
57namespace gum {
58
59 namespace learning {
60
61
68 public:
69 // ##########################################################################
71 // ##########################################################################
73
75
97 const Prior& external_prior,
98 const Prior& _score_internal_prior,
99 const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
100 const Bijection< NodeId, std::size_t >& nodeId2columns
102
104
120 const Prior& external_prior,
121 const Prior& _score_internal_prior,
122 const Bijection< NodeId, std::size_t >& nodeId2columns
124
127
130
132 virtual ParamEstimator* clone() const = 0;
133
136
138
139
140 // ##########################################################################
142 // ##########################################################################
144
146 virtual void clear();
147
149
153 virtual void setNumberOfThreads(Size nb);
154
156 virtual Size getNumberOfThreads() const;
157
159 virtual bool isGumNumberOfThreadsOverriden() const;
160
170 virtual void setMinNbRowsPerThread(const std::size_t nb) const;
171
173 virtual std::size_t minNbRowsPerThread() const;
174
176
182 void setRanges(const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges);
183
186
188 const std::vector< std::pair< std::size_t, std::size_t > >& ranges() const;
189
191 std::vector< double > parameters(const NodeId target_node);
192
194 std::pair< std::vector< double >, double >
196
198
203 virtual std::vector< double > parameters(const NodeId target_node,
204 const std::vector< NodeId >& conditioning_nodes)
205 = 0;
206
219 virtual std::pair< std::vector< double >, double >
221 const std::vector< NodeId >& conditioning_nodes)
222 = 0;
223
240 template < typename GUM_SCALAR >
241 double setParameters(const NodeId target_node,
242 const std::vector< NodeId >& conditioning_nodes,
243 Tensor< GUM_SCALAR >& pot,
244 const bool compute_log_likelihood = false);
245
247
251
253 const DatabaseTable& database() const;
254
256
259 template < typename GUM_SCALAR >
260 void setBayesNet(const BayesNet< GUM_SCALAR >& new_bn);
261
263
264 protected:
267
271
274
276 const std::vector< NodeId > empty_nodevect_;
277
278
281
284
285 private:
286#ifndef DOXYGEN_SHOULD_SKIP_THIS
287
290 template < typename GUM_SCALAR >
291 void _checkParameters_(const NodeId target_node,
292 const std::vector< NodeId >& conditioning_nodes,
293 Tensor< GUM_SCALAR >& pot);
294
295 // sets the CPT's parameters corresponding to a given Tensor
296 // when the tensor belongs to a BayesNet<GUM_SCALAR> when
297 // GUM_SCALAR is different from a double
298 template < typename GUM_SCALAR >
299 typename std::enable_if< !std::is_same< GUM_SCALAR, double >::value, double >::type
300 _setParameters_(const NodeId target_node,
301 const std::vector< NodeId >& conditioning_nodes,
302 Tensor< GUM_SCALAR >& pot,
303 const bool compute_log_likelihood);
304
305 // sets the CPT's parameters corresponding to a given Tensor
306 // when the tensor belongs to a BayesNet<GUM_SCALAR> when
307 // GUM_SCALAR is equal to double (the code is optimized for doubles)
308 template < typename GUM_SCALAR >
309 typename std::enable_if< std::is_same< GUM_SCALAR, double >::value, double >::type
310 _setParameters_(const NodeId target_node,
311 const std::vector< NodeId >& conditioning_nodes,
312 Tensor< GUM_SCALAR >& pot,
313 const bool compute_log_likelihood);
314
315 friend class DAG2BNLearner;
316
317#endif /* DOXYGEN_SHOULD_SKIP_THIS */
318 };
319
320 } /* namespace learning */
321
322} /* namespace gum */
323
326
327// include the inlined functions if necessary
328#ifndef GUM_NO_INLINE
330#endif /* GUM_NO_INLINE */
331
332#endif /* GUM_LEARNING_PARAM_ESTIMATOR_H */
A class that, given a structure and a parameter estimator, returns a full Bayes net.
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
The class representing a tabular database as used by learning tasks.
RecordCounter counter_
the record counter used to parse the database
ParamEstimator & operator=(const ParamEstimator &from)
copy assignment operator
virtual std::vector< double > parameters(const NodeId target_node, const std::vector< NodeId > &conditioning_nodes)=0
returns the CPT's parameters corresponding to a given nodeset
void clearRanges()
reset the ranges to the one range corresponding to the whole database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context
const Bijection< NodeId, std::size_t > & nodeId2Columns() const
returns the mapping from ids to column positions in the database
virtual bool isGumNumberOfThreadsOverriden() const
indicates whether the user explicitly set the number of threads
ParamEstimator(const DBRowGeneratorParser &parser, const Prior &external_prior, const Prior &_score_internal_prior, const std::vector< std::pair< std::size_t, std::size_t > > &ranges, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
constructor (with explicit ranges of the database on which the counts are performed)
virtual ~ParamEstimator()
destructor
double setParameters(const NodeId target_node, const std::vector< NodeId > &conditioning_nodes, Tensor< GUM_SCALAR > &pot, const bool compute_log_likelihood=false)
sets a CPT's parameters and, possibly, returns its log-likelihood
virtual Size getNumberOfThreads() const
returns the current maximum number of threads of the scheduler
ParamEstimator(ParamEstimator &&from)
move constructor
ParamEstimator(const ParamEstimator &from)
copy constructor
const std::vector< NodeId > empty_nodevect_
an empty vector of nodes, used for empty conditioning
Prior * score_internal_prior_
if a score was used for learning the structure of the PGM, this is the prior internal to the score
ParamEstimator(const DBRowGeneratorParser &parser, const Prior &external_prior, const Prior &_score_internal_prior, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
constructor (without explicit ranges)
virtual void clear()
clears all the data structures from memory
void setBayesNet(const BayesNet< GUM_SCALAR > &new_bn)
assigns a new Bayes net to all the counter's generators that depend on a BN
std::pair< std::vector< double >, double > parametersAndLogLikelihood(const NodeId target_node)
returns the parameters of a CPT as well as its log-likelihood
virtual std::pair< std::vector< double >, double > parametersAndLogLikelihood(const NodeId target_node, const std::vector< NodeId > &conditioning_nodes)=0
returns the parameters of a CPT as well as its log-likelihood
void setRanges(const std::vector< std::pair< std::size_t, std::size_t > > &new_ranges)
sets new ranges to perform the counts used by the parameter estimator
virtual ParamEstimator * clone() const =0
virtual copy constructor
const std::vector< std::pair< std::size_t, std::size_t > > & ranges() const
returns the current ranges
ParamEstimator & operator=(ParamEstimator &&from)
move assignment operator
std::vector< double > parameters(const NodeId target_node)
returns the CPT's parameters corresponding to a given target node
virtual void setNumberOfThreads(Size nb)
sets the maximum number of threads that can be used
const DatabaseTable & database() const
returns the database on which we perform the counts
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
Prior * external_prior_
an external prior
the base class for all priors
Definition prior.h:83
The class that computes counting of observations from the database.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size NodeId
Type for node ids.
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
the base class for estimating parameters of CPTs
the base class for estimating parameters of CPTs
the base class for all priors
The class that computes counting of observations from the database.