aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
correctedMutualInformation.cpp
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40
41
48
50
51#ifndef DOXYGEN_SHOULD_SKIP_THIS
52
54# ifdef GUM_NO_INLINE
56# endif /* GUM_NO_INLINE */
57
58namespace gum {
59
60 namespace learning {
61
64 const DBRowGeneratorParser& parser,
65 const Prior& prior,
66 const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
67 const Bijection< NodeId, std::size_t >& nodeId2columns) :
68 _NH_(parser, prior, ranges, nodeId2columns), _k_NML_(parser, prior, ranges, nodeId2columns),
69 _score_MDL_(parser, prior, ranges, nodeId2columns) {
70 GUM_CONSTRUCTOR(CorrectedMutualInformation);
71 }
72
74 CorrectedMutualInformation::CorrectedMutualInformation(
75 const DBRowGeneratorParser& parser,
76 const Prior& prior,
77 const Bijection< NodeId, std::size_t >& nodeId2columns) :
78 _NH_(parser, prior, nodeId2columns), _k_NML_(parser, prior, nodeId2columns),
79 _score_MDL_(parser, prior, nodeId2columns) {
80 GUM_CONSTRUCTOR(CorrectedMutualInformation);
81 }
82
84 CorrectedMutualInformation::CorrectedMutualInformation(const CorrectedMutualInformation& from) :
85 _NH_(from._NH_), _k_NML_(from._k_NML_), _score_MDL_(from._score_MDL_),
86 _kmode_(from._kmode_), _use_ICache_(from._use_ICache_), _use_HCache_(from._use_HCache_),
87 _use_KCache_(from._use_KCache_), _use_CnrCache_(from._use_CnrCache_),
88 _ICache_(from._ICache_), _KCache_(from._KCache_) {
89 GUM_CONS_CPY(CorrectedMutualInformation);
90 }
91
93 CorrectedMutualInformation::CorrectedMutualInformation(CorrectedMutualInformation&& from) :
94 _NH_(std::move(from._NH_)), _k_NML_(std::move(from._k_NML_)),
95 _score_MDL_(std::move(from._score_MDL_)), _kmode_(from._kmode_),
96 _use_ICache_(from._use_ICache_), _use_HCache_(from._use_HCache_),
97 _use_KCache_(from._use_KCache_), _use_CnrCache_(from._use_CnrCache_),
98 _ICache_(std::move(from._ICache_)), _KCache_(std::move(from._KCache_)) {
99 GUM_CONS_MOV(CorrectedMutualInformation);
100 }
101
103 CorrectedMutualInformation* CorrectedMutualInformation::clone() const {
104 return new CorrectedMutualInformation(*this);
105 }
106
108 CorrectedMutualInformation::~CorrectedMutualInformation() {
109 // for debugging purposes
110 GUM_DESTRUCTOR(CorrectedMutualInformation);
111 }
112
114 CorrectedMutualInformation&
115 CorrectedMutualInformation::operator=(const CorrectedMutualInformation& from) {
116 if (this != &from) {
117 _NH_ = from._NH_;
118 _k_NML_ = from._k_NML_;
119 _score_MDL_ = from._score_MDL_;
120 _kmode_ = from._kmode_;
121 _use_ICache_ = from._use_ICache_;
122 _use_HCache_ = from._use_HCache_;
123 _use_KCache_ = from._use_KCache_;
124 _use_CnrCache_ = from._use_CnrCache_;
125 _ICache_ = from._ICache_;
126 _KCache_ = from._KCache_;
127 }
128 return *this;
129 }
130
132 CorrectedMutualInformation&
133 CorrectedMutualInformation::operator=(CorrectedMutualInformation&& from) {
134 if (this != &from) {
135 _NH_ = std::move(from._NH_);
136 _k_NML_ = std::move(from._k_NML_);
137 _score_MDL_ = std::move(from._score_MDL_);
138 _kmode_ = from._kmode_;
139 _use_ICache_ = from._use_ICache_;
140 _use_HCache_ = from._use_HCache_;
141 _use_KCache_ = from._use_KCache_;
142 _use_CnrCache_ = from._use_CnrCache_;
143 _ICache_ = std::move(from._ICache_);
144 _KCache_ = std::move(from._KCache_);
145 }
146 return *this;
147 }
148
150
156 void CorrectedMutualInformation::setRanges(
157 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges) {
158 std::vector< std::pair< std::size_t, std::size_t > > old_ranges = ranges();
159
160 _NH_.setRanges(new_ranges);
161 _k_NML_.setRanges(new_ranges);
162 _score_MDL_.setRanges(new_ranges);
163
164 if (old_ranges != ranges()) clear();
165 }
166
168 void CorrectedMutualInformation::clearRanges() {
169 std::vector< std::pair< std::size_t, std::size_t > > old_ranges = ranges();
170 _NH_.clearRanges();
171 _k_NML_.clearRanges();
172 _score_MDL_.clearRanges();
173 if (old_ranges != ranges()) clear();
174 }
175
177 double CorrectedMutualInformation::_NI_score_(NodeId var_x,
178 NodeId var_y,
179 const std::vector< NodeId >& vars_z) {
180 /*
181 * We have a few partial entropies to compute in order to have the
182 * 2-point mutual information:
183 * I(x;y) = H(x) + H(y) - H(x,y)
184 * correspondingly
185 * I(x;y) = Hx + Hy - Hxy
186 * or
187 * I(x;y|z) = H(x,z) + H(y,z) - H(z) - H(x,y,z)
188 * correspondingly
189 * I(x;y|z) = Hxz + Hyz - Hz - Hxyz
190 * Note that Entropy H is equal to 1/N times the log2Likelihood,
191 * where N is the size of the database.
192 * Remember that we return N times I(x;y|z)
193 */
194
195 // if the score has already been computed, get its value
196 const IdCondSet idset_xyz(var_x, var_y, vars_z, false, false);
197 if (_use_ICache_)
198 if (_ICache_.exists(idset_xyz)) return _ICache_.score(idset_xyz);
199
200 // compute the score
201
202 // here, we distinguish nodesets with conditioning nodes from those
203 // without conditioning nodes
204 double score;
205 if (!vars_z.empty()) {
206 std::vector< NodeId > vars(vars_z);
207 vars.push_back(var_x);
208 vars.push_back(var_y);
209 const double NHxyz = -_NH_.score(IdCondSet(vars, false, true));
210
211 vars.pop_back();
212 const double NHxz = -_NH_.score(IdCondSet(vars, false, true));
213
214 vars.pop_back();
215 vars.push_back(var_y);
216 const double NHyz = -_NH_.score(IdCondSet(vars, false, true));
217
218 vars.pop_back();
219 const double NHz = -_NH_.score(IdCondSet(vars, false, true));
220
221 const double NHxz_NHyz = NHxz + NHyz;
222 double NHz_NHxyz = NHz + NHxyz;
223
224 // avoid numeric instability due to rounding errors
225 double ratio = 1;
226 if (NHxz_NHyz > 0) {
227 ratio = (NHxz_NHyz - NHz_NHxyz) / NHxz_NHyz;
228 } else if (NHz_NHxyz > 0) {
229 ratio = (NHxz_NHyz - NHz_NHxyz) / NHz_NHxyz;
230 }
231 if (ratio < 0) ratio = -ratio;
232 if (ratio < _threshold_) {
233 NHz_NHxyz = NHxz_NHyz; // ensure that the score is equal to 0
234 }
235
236 score = NHxz_NHyz - NHz_NHxyz;
237 } else {
238 const double NHxy
239 = -_NH_.score(IdCondSet(var_x, var_y, _empty_conditioning_set_, true, false));
240 const double NHx = -_NH_.score(var_x);
241 const double NHy = -_NH_.score(var_y);
242
243 double NHx_NHy = NHx + NHy;
244
245 // avoid numeric instability due to rounding errors
246 double ratio = 1;
247 if (NHx_NHy > 0) {
248 ratio = (NHx_NHy - NHxy) / NHx_NHy;
249 } else if (NHxy > 0) {
250 ratio = (NHx_NHy - NHxy) / NHxy;
251 }
252 if (ratio < 0) ratio = -ratio;
253 if (ratio < _threshold_) {
254 NHx_NHy = NHxy; // ensure that the score is equal to 0
255 }
256
257 score = NHx_NHy - NHxy;
258 }
259
260
261 // shall we put the score into the cache?
262 if (_use_ICache_) { _ICache_.insert(idset_xyz, score); }
263
264 return score;
265 }
266
268 double CorrectedMutualInformation::_K_score_(NodeId var1,
269 NodeId var2,
270 const std::vector< NodeId >& conditioning_ids) {
271 // if no penalty, return 0
272 if (_kmode_ == KModeTypes::NoCorr) return 0.0;
273
274
275 // If using the K cache, verify whether the set isn't already known
276 IdCondSet idset = IdCondSet(var1, var2, conditioning_ids, false);
277 if (_use_KCache_)
278 if (_KCache_.exists(idset)) return _KCache_.score(idset);
279
280
281 // compute the score
282 double score;
283 size_t rx;
284 size_t ry;
285 size_t rui;
286 switch (_kmode_) {
287 case KModeTypes::MDL : {
288 const auto& database = _NH_.database();
289 const auto& node2cols = _NH_.nodeId2Columns();
290
291 rui = 1;
292 if (!node2cols.empty()) {
293 rx = database.domainSize(node2cols.second(var1));
294 ry = database.domainSize(node2cols.second(var2));
295 for (const NodeId i: conditioning_ids) {
296 rui *= database.domainSize(node2cols.second(i));
297 }
298 } else {
299 rx = database.domainSize(var1);
300 ry = database.domainSize(var2);
301 for (const NodeId i: conditioning_ids) {
302 rui *= database.domainSize(i);
303 }
304 }
305
306 // compute the size of the database, including the a priori
307 const double N = _score_MDL_.N(idset);
308
309 score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
310 } break;
311
312 case KModeTypes::NML : score = _k_NML_.score(var1, var2, conditioning_ids); break;
313
314 default :
316 "CorrectedMutualInformation mode does "
317 "not support yet this correction");
318 }
319
320 // shall we put the score into the cache?
321 if (_use_KCache_) { _KCache_.insert(idset, score); }
322 return score;
323 }
324
325 } /* namespace learning */
326
327} /* namespace gum */
328
329#endif /* DOXYGEN_SHOULD_SKIP_THIS */
Exception: raised when something is wrong with an implementation (feature not yet implemented).
CorrectedMutualInformation(const DBRowGeneratorParser &parser, const Prior &prior, const std::vector< std::pair< std::size_t, std::size_t > > &ranges, const Bijection< NodeId, std::size_t > &nodeId2columns=Bijection< NodeId, std::size_t >())
default constructor
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
the base class for all a priori
Definition prior.h:83
The class computing n times the corrected mutual information (where n is the size (or the weight) of ...
The class computing n times the corrected mutual information, as used in the MIIC algorithm.
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.