aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
correctedMutualInformation.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
50#ifndef GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
51#define GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H
52
53#include <vector>
54
55#include <agrum/config.h>
56
58
61
62namespace gum {
63 namespace learning {
80 public:
81 // ##########################################################################
83 // ##########################################################################
85
87
108 const Prior& prior,
109 const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
110 const Bijection< NodeId, std::size_t >& nodeId2columns
112
114
129 const Prior& prior,
130 const Bijection< NodeId, std::size_t >& nodeId2columns
132
135
138
141
144
146
147
148 // ##########################################################################
150 // ##########################################################################
151
153
156
159
161
162
163 // ##########################################################################
165 // ##########################################################################
167
169 virtual void clear();
170
172
180 virtual void clearCache();
181
183
191 virtual void useCache(bool on_off);
192
194 void useICache(bool on_off);
195
198
200 void useHCache(bool on_off);
201
204
206 void useKCache(bool on_off);
207
210
212 void useCnrCache(bool on_off);
213
216
218
219
220 // ##########################################################################
222 // ##########################################################################
224
226 double score(NodeId var1, NodeId var2);
227
229 double score(NodeId var1, NodeId var2, const std::vector< NodeId >& conditioning_ids);
230
232 double score(NodeId var1, NodeId var2, NodeId var3);
233
235 double score(NodeId var1,
236 NodeId var2,
237 NodeId var3,
238 const std::vector< NodeId >& conditioning_ids);
239
241
242
243 // ##########################################################################
245 // ##########################################################################
247
249 void useMDL();
250
252 void useNML();
253
255 void useNoCorr();
256
258 virtual void setNumberOfThreads(Size nb);
259
261 virtual std::size_t getNumberOfThreads() const;
262
264 virtual bool isGumNumberOfThreadsOverriden() const;
265
275 virtual void setMinNbRowsPerThread(const std::size_t nb) const;
276
278 virtual std::size_t minNbRowsPerThread() const;
279
281
287 void setRanges(const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges);
288
291
293 const std::vector< std::pair< std::size_t, std::size_t > >& ranges() const;
294
296
297
299 enum class KModeTypes { MDL, NML, NoCorr };
300
301
302#ifndef DOXYGEN_SHOULD_SKIP_THIS
303
304 private:
306 /* Note that the log2-likelihood is equal to N times the entropy H */
308
310 KNML _k_NML_;
311
314 ScoreMDL _score_MDL_;
315
317 KModeTypes _kmode_{KModeTypes::MDL};
318
319
321
323 bool _use_ICache_{true};
324
326
329 bool _use_HCache_{true};
330
332
335 bool _use_KCache_{true};
336
338
342 bool _use_CnrCache_{true};
343
344
346 ScoringCache _ICache_;
347
349 ScoringCache _KCache_;
350
351
353 const std::vector< NodeId > _empty_conditioning_set_;
354
356 const double _threshold_{1e-10};
357
358
360 double _NI_score_(NodeId var_x, NodeId var_y, const std::vector< NodeId >& vars_z);
361
363 double _NI_score_(NodeId var_x,
364 NodeId var_y,
365 NodeId var_z,
366 const std::vector< NodeId >& vars_ui);
367
369 double _K_score_(NodeId var_x, NodeId var_y, const std::vector< NodeId >& vars_z);
370
372 double
373 _K_score_(NodeId var_x, NodeId var_y, NodeId var_z, const std::vector< NodeId >& vars_ui);
374
375#endif /* DOXYGEN_SHOULD_SKIP_THIS */
376 };
377
378 } /* namespace learning */
379
380} /* namespace gum */
381
382// include the inlined functions if necessary
383#ifndef GUM_NO_INLINE
385#endif /* GUM_NO_INLINE */
386
387#endif /* GUM_LEARNING_CORRECTED_MUTUAL_INFORMATION_H */
double score(NodeId var1, NodeId var2, NodeId var3)
returns the 3-point mutual information corresponding to a given nodeset
virtual CorrectedMutualInformation * clone() const
virtual copy constructor
virtual std::size_t minNbRowsPerThread() const
returns the minimum number of rows that each thread should process
virtual ~CorrectedMutualInformation()
destructor
void clearKCache()
clears the KCache (the cache for the penalties)
void useKCache(bool on_off)
turn on/off the use of the KCache (the cache for the penalties)
virtual void setNumberOfThreads(Size nb)
changes the max number of threads used to parse the database
virtual void setMinNbRowsPerThread(const std::size_t nb) const
changes the minimum number of rows a thread should process in a multithreading context
virtual void clearCache()
clears all the current caches
CorrectedMutualInformation & operator=(const CorrectedMutualInformation &from)
copy operator
KModeTypes
the type describing which complexity correction (penalty) is applied: MDL, NML or none (NoCorr)
virtual void useCache(bool on_off)
turn on/off the use of all the caches
virtual bool isGumNumberOfThreadsOverriden() const
indicates whether the user set the number of threads herself
void clearRanges()
reset the ranges to the one range corresponding to the whole database
CorrectedMutualInformation & operator=(CorrectedMutualInformation &&from)
move operator
CorrectedMutualInformation(const CorrectedMutualInformation &from)
copy constructor
double score(NodeId var1, NodeId var2, NodeId var3, const std::vector< NodeId > &conditioning_ids)
returns the 3-point mutual information corresponding to a given nodeset
const std::vector< std::pair< std::size_t, std::size_t > > & ranges() const
returns the current ranges
double score(NodeId var1, NodeId var2)
returns the 2-point mutual information corresponding to a given nodeset
void useCnrCache(bool on_off)
turn on/off the use of the CnrCache (the cache for the Cnr formula)
void useICache(bool on_off)
turn on/off the use of the ICache (the mutual information cache)
double score(NodeId var1, NodeId var2, const std::vector< NodeId > &conditioning_ids)
returns the 2-point mutual information corresponding to a given nodeset
void useNML()
use the kNML penalty function
void clearICache()
clears the ICache (the mutual information cache)
void clearHCache()
clears the HCache (the cache for the entropies)
CorrectedMutualInformation(const DBRowGeneratorParser &parser, const Prior &prior, const std::vector< std::pair< std::size_t, std::size_t > > &ranges, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
constructor (requires a parser, a prior and the ranges; the nodeId2columns mapping is optional)
CorrectedMutualInformation(const DBRowGeneratorParser &parser, const Prior &prior, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
constructor (requires a parser and a prior; the nodeId2columns mapping is optional)
void useHCache(bool on_off)
turn on/off the use of the HCache (the cache for the entropies)
virtual std::size_t getNumberOfThreads() const
returns the number of threads used to parse the database
virtual void clear()
clears all the data structures from memory
void clearCnrCache()
clears the CnrCache (the cache for the Cnr formula)
CorrectedMutualInformation(CorrectedMutualInformation &&from)
move constructor
void setRanges(const std::vector< std::pair< std::size_t, std::size_t > > &new_ranges)
sets new ranges to perform the counts used by the mutual information
void useNoCorr()
use no correction/penalty function
void useMDL()
use the MDL penalty function
the class used to read a row in the database and to transform it into a set of DBRow instances that can be used for learning
the class for computing the NML penalty used by MIIC
Definition kNML.h:67
the base class for all priors
Definition prior.h:83
the class for computing Log2-likelihood scores
The class computing n times the corrected mutual information, as used in the MIIC algorithm.
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
Size NodeId
Type for node ids.
The class for the NML penalty used in MIIC.
include the inlined functions if necessary
Definition CSVParser.h:54
ScoreBIC ScoreMDL
Definition scoreMDL.h:67
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
the class for computing Log2-likelihood scores
the class for computing MDL scores