51#ifndef DOXYGEN_SHOULD_SKIP_THIS
66 const std::vector< std::pair< std::size_t, std::size_t > >& ranges,
67 const Bijection< NodeId, std::size_t >& nodeId2columns) :
68 _NH_(parser, prior, ranges, nodeId2columns), _k_NML_(parser, prior, ranges, nodeId2columns),
69 _score_MDL_(parser, prior, ranges, nodeId2columns) {
70 GUM_CONSTRUCTOR(CorrectedMutualInformation);
74 CorrectedMutualInformation::CorrectedMutualInformation(
75 const DBRowGeneratorParser& parser,
77 const Bijection< NodeId, std::size_t >& nodeId2columns) :
78 _NH_(parser, prior, nodeId2columns), _k_NML_(parser, prior, nodeId2columns),
79 _score_MDL_(parser, prior, nodeId2columns) {
80 GUM_CONSTRUCTOR(CorrectedMutualInformation);
84 CorrectedMutualInformation::CorrectedMutualInformation(
const CorrectedMutualInformation& from) :
85 _NH_(from._NH_), _k_NML_(from._k_NML_), _score_MDL_(from._score_MDL_),
86 _kmode_(from._kmode_), _use_ICache_(from._use_ICache_), _use_HCache_(from._use_HCache_),
87 _use_KCache_(from._use_KCache_), _use_CnrCache_(from._use_CnrCache_),
88 _ICache_(from._ICache_), _KCache_(from._KCache_) {
89 GUM_CONS_CPY(CorrectedMutualInformation);
93 CorrectedMutualInformation::CorrectedMutualInformation(CorrectedMutualInformation&& from) :
94 _NH_(
std::move(from._NH_)), _k_NML_(
std::move(from._k_NML_)),
95 _score_MDL_(
std::move(from._score_MDL_)), _kmode_(from._kmode_),
96 _use_ICache_(from._use_ICache_), _use_HCache_(from._use_HCache_),
97 _use_KCache_(from._use_KCache_), _use_CnrCache_(from._use_CnrCache_),
98 _ICache_(
std::move(from._ICache_)), _KCache_(
std::move(from._KCache_)) {
99 GUM_CONS_MOV(CorrectedMutualInformation);
103 CorrectedMutualInformation* CorrectedMutualInformation::clone()
const {
104 return new CorrectedMutualInformation(*
this);
108 CorrectedMutualInformation::~CorrectedMutualInformation() {
110 GUM_DESTRUCTOR(CorrectedMutualInformation);
114 CorrectedMutualInformation&
115 CorrectedMutualInformation::operator=(
const CorrectedMutualInformation& from) {
118 _k_NML_ = from._k_NML_;
119 _score_MDL_ = from._score_MDL_;
120 _kmode_ = from._kmode_;
121 _use_ICache_ = from._use_ICache_;
122 _use_HCache_ = from._use_HCache_;
123 _use_KCache_ = from._use_KCache_;
124 _use_CnrCache_ = from._use_CnrCache_;
125 _ICache_ = from._ICache_;
126 _KCache_ = from._KCache_;
132 CorrectedMutualInformation&
133 CorrectedMutualInformation::operator=(CorrectedMutualInformation&& from) {
135 _NH_ = std::move(from._NH_);
136 _k_NML_ = std::move(from._k_NML_);
137 _score_MDL_ = std::move(from._score_MDL_);
138 _kmode_ = from._kmode_;
139 _use_ICache_ = from._use_ICache_;
140 _use_HCache_ = from._use_HCache_;
141 _use_KCache_ = from._use_KCache_;
142 _use_CnrCache_ = from._use_CnrCache_;
143 _ICache_ = std::move(from._ICache_);
144 _KCache_ = std::move(from._KCache_);
156 void CorrectedMutualInformation::setRanges(
157 const std::vector< std::pair< std::size_t, std::size_t > >& new_ranges) {
158 std::vector< std::pair< std::size_t, std::size_t > > old_ranges = ranges();
160 _NH_.setRanges(new_ranges);
161 _k_NML_.setRanges(new_ranges);
162 _score_MDL_.setRanges(new_ranges);
164 if (old_ranges != ranges()) clear();
168 void CorrectedMutualInformation::clearRanges() {
169 std::vector< std::pair< std::size_t, std::size_t > > old_ranges = ranges();
171 _k_NML_.clearRanges();
172 _score_MDL_.clearRanges();
173 if (old_ranges != ranges()) clear();
177 double CorrectedMutualInformation::_NI_score_(NodeId var_x,
179 const std::vector< NodeId >& vars_z) {
196 const IdCondSet idset_xyz(var_x, var_y, vars_z,
false,
false);
198 if (_ICache_.exists(idset_xyz))
return _ICache_.score(idset_xyz);
205 if (!vars_z.empty()) {
206 std::vector< NodeId > vars(vars_z);
207 vars.push_back(var_x);
208 vars.push_back(var_y);
209 const double NHxyz = -_NH_.score(IdCondSet(vars,
false,
true));
212 const double NHxz = -_NH_.score(IdCondSet(vars,
false,
true));
215 vars.push_back(var_y);
216 const double NHyz = -_NH_.score(IdCondSet(vars,
false,
true));
219 const double NHz = -_NH_.score(IdCondSet(vars,
false,
true));
221 const double NHxz_NHyz = NHxz + NHyz;
222 double NHz_NHxyz = NHz + NHxyz;
227 ratio = (NHxz_NHyz - NHz_NHxyz) / NHxz_NHyz;
228 }
else if (NHz_NHxyz > 0) {
229 ratio = (NHxz_NHyz - NHz_NHxyz) / NHz_NHxyz;
231 if (ratio < 0) ratio = -ratio;
232 if (ratio < _threshold_) {
233 NHz_NHxyz = NHxz_NHyz;
236 score = NHxz_NHyz - NHz_NHxyz;
239 = -_NH_.score(IdCondSet(var_x, var_y, _empty_conditioning_set_,
true,
false));
240 const double NHx = -_NH_.score(var_x);
241 const double NHy = -_NH_.score(var_y);
243 double NHx_NHy = NHx + NHy;
248 ratio = (NHx_NHy - NHxy) / NHx_NHy;
249 }
else if (NHxy > 0) {
250 ratio = (NHx_NHy - NHxy) / NHxy;
252 if (ratio < 0) ratio = -ratio;
253 if (ratio < _threshold_) {
257 score = NHx_NHy - NHxy;
262 if (_use_ICache_) { _ICache_.insert(idset_xyz, score); }
268 double CorrectedMutualInformation::_K_score_(NodeId var1,
270 const std::vector< NodeId >& conditioning_ids) {
272 if (_kmode_ == KModeTypes::NoCorr)
return 0.0;
276 IdCondSet idset = IdCondSet(var1, var2, conditioning_ids,
false);
278 if (_KCache_.exists(idset))
return _KCache_.score(idset);
287 case KModeTypes::MDL : {
288 const auto& database = _NH_.database();
289 const auto& node2cols = _NH_.nodeId2Columns();
292 if (!node2cols.empty()) {
293 rx = database.domainSize(node2cols.second(var1));
294 ry = database.domainSize(node2cols.second(var2));
295 for (
const NodeId i: conditioning_ids) {
296 rui *= database.domainSize(node2cols.second(i));
299 rx = database.domainSize(var1);
300 ry = database.domainSize(var2);
301 for (
const NodeId i: conditioning_ids) {
302 rui *= database.domainSize(i);
307 const double N = _score_MDL_.N(idset);
309 score = 0.5 * (rx - 1) * (ry - 1) * rui * std::log2(N);
312 case KModeTypes::NML : score = _k_NML_.score(var1, var2, conditioning_ids);
break;
316 "CorrectedMutualInformation mode does "
317 "not support yet this correction");
321 if (_use_KCache_) { _KCache_.insert(idset, score); }
Exception : there is something wrong with an implementation.
the class used to read a row in the database and to transform it into a set of DBRow instances that c...
the base class for all a priori
#define GUM_ERROR(type, msg)
include the inlined functions if necessary
gum is the global namespace for all aGrUM entities