aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
graphChangesSelector4DiGraph_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49#ifndef DOXYGEN_SHOULD_SKIP_THIS
50
# include <algorithm>
# include <functional>
# include <limits>
# include <utility>
# include <vector>
52
53namespace gum {
54
55 namespace learning {
56
58 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
61 STRUCTURAL_CONSTRAINT& constraint,
62 GRAPH_CHANGES_GENERATOR& changes_generator) :
63 _score_(score.clone()), _constraint_(&constraint), _changes_generator_(&changes_generator) {
64 _parents_.resize(32);
65 GUM_CONSTRUCTOR(GraphChangesSelector4DiGraph);
66 }
67
69 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
70 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
71 GraphChangesSelector4DiGraph(
72 const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&
73 from) :
74 _score_(from._score_ != nullptr ? from._score_->clone() : nullptr),
75 _constraint_(from._constraint_), _changes_generator_(from._changes_generator_),
76 _changes_(from._changes_), _change_scores_(from._change_scores_),
77 _change_queue_per_node_(from._change_queue_per_node_), _node_queue_(from._node_queue_),
78 _illegal_changes_(from._illegal_changes_),
79 _node_current_scores_(from._node_current_scores_), _parents_(from._parents_),
80 _queues_valid_(from._queues_valid_), _queues_to_update_(from._queues_to_update_) {
81 // for debugging
82 GUM_CONS_CPY(GraphChangesSelector4DiGraph);
83 }
84
86 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
87 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
88 GraphChangesSelector4DiGraph(
89 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&& from) :
90 _score_(from._score_), _constraint_(std::move(from._constraint_)),
91 _changes_generator_(std::move(from._changes_generator_)),
92 _changes_(std::move(from._changes_)), _change_scores_(std::move(from._change_scores_)),
93 _change_queue_per_node_(std::move(from._change_queue_per_node_)),
94 _node_queue_(std::move(from._node_queue_)),
95 _illegal_changes_(std::move(from._illegal_changes_)),
96 _node_current_scores_(std::move(from._node_current_scores_)),
97 _parents_(std::move(from._parents_)), _queues_valid_(std::move(from._queues_valid_)),
98 _queues_to_update_(std::move(from._queues_to_update_)) {
99 from._score_ = nullptr;
100 // for debugging
101 GUM_CONS_MOV(GraphChangesSelector4DiGraph);
102 }
103
105 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
106 INLINE
107 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
108 GRAPH_CHANGES_GENERATOR >::~GraphChangesSelector4DiGraph() {
109 if (_score_ != nullptr) delete _score_;
110 GUM_DESTRUCTOR(GraphChangesSelector4DiGraph);
111 }
112
114 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
115 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&
116 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::operator=(
117 const GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&
118 from) {
119 if (this != &from) {
120 // remove the old score
121 if (_score_ != nullptr) {
122 delete _score_;
123 _score_ = nullptr;
124 }
125
126 if (from._score_ != nullptr) _score_ = from._score_->clone();
127 _constraint_ = from._constraint_;
128 _changes_generator_ = from._changes_generator_;
129 _changes_ = from._changes_;
130 _change_scores_ = from._change_scores_;
131 _change_queue_per_node_ = from._change_queue_per_node_;
132 _node_queue_ = from._node_queue_;
133 _illegal_changes_ = from._illegal_changes_;
134 _node_current_scores_ = from._node_current_scores_;
135 _parents_ = from._parents_;
136 _queues_valid_ = from._queues_valid_;
137 _queues_to_update_ = from._queues_to_update_;
138 }
139
140 return *this;
141 }
142
144 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
145 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&
146 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::operator=(
147 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >&& from) {
148 if (this != &from) {
149 _score_ = from._score_;
150 from._score_ = nullptr;
151
152 _constraint_ = std::move(from._constraint_);
153 _changes_generator_ = std::move(from._changes_generator_);
154 _changes_ = std::move(from._changes_);
155 _change_scores_ = std::move(from._change_scores_);
156 _change_queue_per_node_ = std::move(from._change_queue_per_node_);
157 _node_queue_ = std::move(from._node_queue_);
158 _illegal_changes_ = std::move(from._illegal_changes_);
159 _node_current_scores_ = std::move(from._node_current_scores_);
160 _parents_ = std::move(from._parents_);
161 _queues_valid_ = std::move(from._queues_valid_);
162 _queues_to_update_ = std::move(from._queues_to_update_);
163 }
164
165 return *this;
166 }
167
169 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
170 INLINE bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
171 isChangeValid(const GraphChange& change) const {
172 return _constraint_->checkModification(change);
173 }
174
176 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
177 INLINE bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
178 _isChangeValid_(const std::size_t index) const {
179 return isChangeValid(_changes_[index]);
180 }
181
/// @brief (Re)initializes the selector from the given graph.
///
/// In order, this method:
///  - adds to @p graph the nodes it misses and erases those it should not
///    contain. When the score's nodeId2Columns mapping is empty, the expected
///    nodes are 0 .. database.nbVariables()-1; otherwise they are the first
///    elements of nodeId2Columns.
///  - calls setGraph() on the selector's constraint, on the generator's
///    constraint when it is a different object, and on the generator itself.
///  - caches the parents of every node and the score of every node given
///    those parents.
///  - retrieves all the changes proposed by the generator. Illegal ones are
///    stored into _illegal_changes_. For legal ones, the score variation is
///    computed and the change is pushed into the per-node priority queue(s);
///    an arc reversal goes into the queues of both its nodes.
///  - rebuilds the global node queue, where each node is prioritized by the
///    top priority of its own queue, and marks the queues as valid.
///
/// @param graph the graph to start from; it is modified in place.
/// NOTE(review): a change whose type is not ARC_ADDITION, ARC_DELETION or
/// ARC_REVERSAL reaches the default branch, which raises a GUM_ERROR. The
/// exception type is on a line missing from this listing (355); confirm it
/// against the actual source.
183 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
184 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::setGraph(
185 DiGraph& graph) {
186 // fill the DAG with all the missing nodes
187 const DatabaseTable& database = _score_->database();
188 const auto& nodeId2Columns = _score_->nodeId2Columns();
189
 // without an explicit mapping, node ids are the database column indices
190 if (nodeId2Columns.empty()) {
191 const NodeId nb_nodes = NodeId(database.nbVariables());
192 for (NodeId i = NodeId(0); i < nb_nodes; ++i) {
193 if (!graph.existsNode(i)) { graph.addNodeWithId(i); }
194 }
195 } else {
196 for (auto iter = nodeId2Columns.cbegin(); iter != nodeId2Columns.cend(); ++iter) {
197 const NodeId id = iter.first();
198 if (!graph.existsNode(id)) { graph.addNodeWithId(id); }
199 }
200 }
201
202
203 // remove the nodes that belong neither to the database
204 // nor to nodeId2Columns
 // NOTE(review): nodes are erased while the graph is being range-iterated.
 // This relies on the graph's node iterator tolerating the erasure of the
 // current node. Collecting the ids to erase first, or using a safe
 // iterator, would be more robust; confirm against the NodeGraphPart docs.
205 if (nodeId2Columns.empty()) {
206 const NodeId nb_nodes = NodeId(database.nbVariables());
207 for (auto node: graph) {
208 if (node >= nb_nodes) { graph.eraseNode(node); }
209 }
210 } else {
211 for (auto node: graph) {
212 if (!nodeId2Columns.existsFirst(node)) { graph.eraseNode(node); }
213 }
214 }
215
216
217 // _constraint_ is the constraint used by the selector to restrict the set
218 // of applicable changes. However, the generator may have a different set
219 // of constraints (e.g., a constraintSliceOrder needs to be tested only by the
220 // generator because the changes returned by the generator will always
221 // satisfy this constraint, hence the selector need not test this
222 // constraint). Therefore, if the selector and generator have different
223 // constraints, both should use method setGraph() to initialize
224 // themselves.
225 _constraint_->setGraph(graph);
 // address comparison: skip the second setGraph() when both share one object
226 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
227 != _constraint_) {
228 _changes_generator_->constraint().setGraph(graph);
229 }
230
231 _changes_generator_->setGraph(graph);
232
233
234 // save the set of parents of each node (this will speed-up the
235 // computations of the scores)
236 const std::size_t nb_nodes = graph.size();
237 {
238 const std::vector< NodeId > empty_pars;
239 _parents_.clear();
240 _parents_.resize(nb_nodes);
241 for (const auto node: graph) {
242 auto& node_parents = _parents_.insert(node, empty_pars).second;
243 const NodeSet& dag_parents = graph.parents(node);
244 if (!dag_parents.empty()) {
245 node_parents.resize(dag_parents.size());
246 std::size_t j = std::size_t(0);
247 for (const auto par: dag_parents) {
248 node_parents[j] = par;
249 ++j;
250 }
251 }
252 }
253 }
254
255 // assign a score to each node given its parents in the current graph
256 _node_current_scores_.clear();
257 _node_current_scores_.resize(nb_nodes);
258 for (const auto node: graph) {
259 _node_current_scores_.insert(node, _score_->score(node, _parents_[node]));
260 }
261
262 // compute all the possible changes
263 _changes_.clear();
264 _changes_.resize(nb_nodes);
265 for (const auto& change: *_changes_generator_) {
266 _changes_ << change;
267 }
268 _changes_generator_->notifyGetCompleted();
269
270 // determine the changes that are illegal and prepare the computation of
271 // the scores of all the legal changes
272 _illegal_changes_.clear();
273
274 // set the _change_scores_ and _change_queue_per_node_ for legal changes
 // NOTE(review): std::numeric_limits<double>::min() is the smallest
 // *positive* double, not the most negative one. It is used throughout as the
 // "no score / empty queue" value, so a node without any change outranks the
 // nodes whose best change has a negative score. lowest() looks like the
 // intended value; confirm before changing, as empty() relies on it.
275 _change_scores_.clear();
276 _change_scores_.resize(_changes_.size(),
277 std::pair< double, double >(std::numeric_limits< double >::min(),
278 std::numeric_limits< double >::min()));
279 _change_queue_per_node_.clear();
280 _change_queue_per_node_.resize(nb_nodes);
281 {
282 const PriorityQueue< std::size_t, double, std::greater< double > > empty_prio;
283 for (const auto node: graph) {
284 _change_queue_per_node_.insert(node, empty_prio);
285 }
286 }
287
288 for (std::size_t i = std::size_t(0); i < _changes_.size(); ++i) {
289 if (!_isChangeValid_(i)) {
290 _illegal_changes_.insert(i);
291 } else {
292 const GraphChange& change = _changes_[i];
293
 // delta = score of the head (node2) with the modified parent set minus its
 // current score; .second stores the head's delta, .first the tail's
294 switch (change.type()) {
295 case GraphChangeType::ARC_ADDITION : {
 // temporarily add node1 to node2's parents, score, then undo
296 auto& parents = _parents_[change.node2()];
297 parents.push_back(change.node1());
298 const double delta
299 = _score_->score(change.node2(), parents) - _node_current_scores_[change.node2()];
300 parents.pop_back();
301
302 _change_scores_[i].second = delta;
303 _change_queue_per_node_[change.node2()].insert(i, delta);
304 } break;
305
306 case GraphChangeType::ARC_DELETION : {
 // temporarily remove node1 (overwrite it with the last parent, then pop);
 // node1 is pushed back afterwards, so the order of the parents may change
307 auto& parents = _parents_[change.node2()];
308 for (auto& par: parents) {
309 if (par == change.node1()) {
310 par = *(parents.rbegin());
311 parents.pop_back();
312 break;
313 }
314 }
315 const double delta
316 = _score_->score(change.node2(), parents) - _node_current_scores_[change.node2()];
317 parents.push_back(change.node1());
318
319 _change_scores_[i].second = delta;
320 _change_queue_per_node_[change.node2()].insert(i, delta);
321 } break;
322
323 case GraphChangeType::ARC_REVERSAL : {
324 // remove arc ( node1 -> node2 )
325 auto& parents2 = _parents_[change.node2()];
326 for (auto& par: parents2) {
327 if (par == change.node1()) {
328 par = *(parents2.rbegin());
329 parents2.pop_back();
330 break;
331 }
332 }
333
334 const double delta2 = _score_->score(change.node2(), parents2)
335 - _node_current_scores_[change.node2()];
336 parents2.push_back(change.node1());
337
338 // add arc ( node2 -> node1 )
339 auto& parents1 = _parents_[change.node1()];
340 parents1.push_back(change.node2());
341 const double delta1 = _score_->score(change.node1(), parents1)
342 - _node_current_scores_[change.node1()];
343 parents1.pop_back();
344
345 _change_scores_[i].first = delta1;
346 _change_scores_[i].second = delta2;
347
 // a reversal modifies both nodes: queue it, with its total delta, on both
348 const double delta = delta1 + delta2;
349 _change_queue_per_node_[change.node1()].insert(i, delta);
350 _change_queue_per_node_[change.node2()].insert(i, delta);
351
352 } break;
353
354 default : {
 // NOTE(review): source line 355 (the "GUM_ERROR(<type>," opener of the
 // statement below) is missing from this listing
356 "Method setGraph of GraphChangesSelector4DiGraph "
357 << "does not handle yet graph change of type " << change.type());
358 }
359 }
360 }
361 }
362
363 // update the global queue
364 _node_queue_.clear();
365 for (const auto node: graph) {
366 _node_queue_.insert(node,
367 _change_queue_per_node_[node].empty()
368 ? std::numeric_limits< double >::min()
369 : _change_queue_per_node_[node].topPriority());
370 }
371 _queues_valid_ = true;
372 _queues_to_update_.clear();
373 }
374
376 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
377 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
378 _invalidateChange_(const std::size_t change_index) {
379 const GraphChange& change = _changes_[change_index];
380 if (change.type() == GraphChangeType::ARC_REVERSAL) {
381 // remove the tail change from its priority queue
382 PriorityQueue< std::size_t, double, std::greater< double > >& queue1
383 = _change_queue_per_node_[change.node1()];
384 queue1.erase(change_index);
385
386 // recompute the top priority for the changes of the head
387 const double new_priority
388 = queue1.empty() ? std::numeric_limits< double >::min() : queue1.topPriority();
389 _node_queue_.setPriority(change.node1(), new_priority);
390 }
391
392 // remove the head change from its priority queue
393 PriorityQueue< std::size_t, double, std::greater< double > >& queue2
394 = _change_queue_per_node_[change.node2()];
395 queue2.erase(change_index);
396
397 // recompute the top priority for the changes of the head
398 const double new_priority
399 = queue2.empty() ? std::numeric_limits< double >::min() : queue2.topPriority();
400 _node_queue_.setPriority(change.node2(), new_priority);
401
402 // put the change into the illegal set
403 _illegal_changes_.insert(change_index);
404 }
405
407 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
408 bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::empty() {
409 // put into the illegal change set all the top elements of the different
410 // queues that are not valid anymore
411 if (!_queues_valid_) {
412 for (auto& queue_pair: _change_queue_per_node_) {
413 auto& queue = queue_pair.second;
414 while (!queue.empty() && !_isChangeValid_(queue.top())) {
415 _invalidateChange_(queue.top());
416 }
417 }
418 _queues_valid_ = true;
419 }
420
421 return _node_queue_.topPriority() == std::numeric_limits< double >::min();
422 }
423
426 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
427 bool GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::empty(
428 const NodeId node) {
429 // put into the illegal change set all the top elements of the different
430 // queues that are not valid anymore
431 if (!_queues_valid_) {
432 for (auto& queue_pair: _change_queue_per_node_) {
433 auto& queue = queue_pair.second;
434 while (!queue.empty() && !_isChangeValid_(queue.top())) {
435 _invalidateChange_(queue.top());
436 }
437 }
438 _queues_valid_ = true;
439 }
440
441 return _change_queue_per_node_[node].empty();
442 }
443
445 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
446 INLINE const GraphChange&
447 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
448 GRAPH_CHANGES_GENERATOR >::bestChange() {
449 if (!empty()) return _changes_[_change_queue_per_node_[_node_queue_.top()].top()];
450 else GUM_ERROR(NotFound, "there exists no graph change applicable")
451 }
452
454 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
455 INLINE const GraphChange&
456 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::bestChange(
457 const NodeId node) {
458 if (!empty(node)) return _changes_[_change_queue_per_node_[node].top()];
459 else GUM_ERROR(NotFound, "there exists no graph change applicable")
460 }
461
463 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
464 INLINE double GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
465 GRAPH_CHANGES_GENERATOR >::bestScore() {
466 if (!empty()) return _change_queue_per_node_[_node_queue_.top()].topPriority();
467 else GUM_ERROR(NotFound, "there exists no graph change applicable")
468 }
469
471 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
472 INLINE double
473 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::bestScore(
474 const NodeId node) {
475 if (!empty(node)) return _change_queue_per_node_[node].topPriority();
476 else GUM_ERROR(NotFound, "there exists no graph change applicable")
477 }
478
480 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
481 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
482 _illegal2LegalChanges_(Set< std::size_t >& changes_to_recompute) {
483 for (auto iter = _illegal_changes_.beginSafe(); iter != _illegal_changes_.endSafe(); ++iter) {
484 if (_isChangeValid_(*iter)) {
485 const GraphChange& change = _changes_[*iter];
486 if (change.type() == GraphChangeType::ARC_REVERSAL) {
487 _change_queue_per_node_[change.node1()].insert(*iter,
488 std::numeric_limits< double >::min());
489 }
490 _change_queue_per_node_[change.node2()].insert(*iter,
491 std::numeric_limits< double >::min());
492
493 changes_to_recompute.insert(*iter);
494 _illegal_changes_.erase(iter);
495 }
496 }
497 }
498
500 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
501 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
502 _findLegalChangesNeedingUpdate_(Set< std::size_t >& changes_to_recompute,
503 const NodeId target_node) {
504 const HashTable< std::size_t, Size >& changes
505 = _change_queue_per_node_[target_node].allValues();
506 for (auto iter = changes.cbeginSafe(); iter != changes.cendSafe(); ++iter) {
507 if (!changes_to_recompute.exists(iter.key())) {
508 if (_isChangeValid_(iter.key())) {
509 changes_to_recompute.insert(iter.key());
510 } else {
511 _invalidateChange_(iter.key());
512 }
513 }
514 }
515 }
516
/// @brief Recomputes the score variation of every change whose index is in
/// @p changes_to_recompute, and updates the priority queues accordingly.
///
/// For each change, the delta is the score of the modified node(s) given the
/// modified parent set minus the node's cached current score. The parent
/// vectors in _parents_ are modified temporarily and restored right after.
/// The deltas are stored into _change_scores_: .second for the head (node2)
/// and .first for the tail (node1) of an arc reversal. The change's priority
/// is updated in the queue(s) it belongs to. Finally, the priority of every
/// node whose queue was touched is refreshed in the global node queue.
/// @param changes_to_recompute indices (in _changes_) of the changes to
/// rescore. NOTE(review): each is assumed to be already present in its
/// queue(s), since setPriority() is used rather than insert(); confirm that
/// callers guarantee it.
518 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
519 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
520 _updateScores_(const Set< std::size_t >& changes_to_recompute) {
 // nodes whose queue top may have changed
521 Set< NodeId > modified_nodes(changes_to_recompute.size());
522
523 for (const auto change_index: changes_to_recompute) {
524 const GraphChange& change = _changes_[change_index];
525
526 switch (change.type()) {
527 case GraphChangeType::ARC_ADDITION : {
528 // add the arc
529 auto& parents = _parents_[change.node2()];
530 parents.push_back(change.node1());
531 const double delta
532 = _score_->score(change.node2(), parents) - _node_current_scores_[change.node2()];
533 parents.pop_back();
534
535 // update the score
536 _change_scores_[change_index].second = delta;
537
538 // update the head queue
539 _change_queue_per_node_[change.node2()].setPriority(change_index, delta);
540 // indicate which queue was modified
541 modified_nodes.insert(change.node2());
542 } break;
543
544 case GraphChangeType::ARC_DELETION : {
545 // remove the arc
 // (node1 is overwritten with the last parent and the vector is popped; it is
 // pushed back after scoring, so the order of the parents may change)
546 auto& parents = _parents_[change.node2()];
547 for (auto& par: parents) {
548 if (par == change.node1()) {
549 par = *(parents.rbegin());
550 parents.pop_back();
551 break;
552 }
553 }
554 const double delta
555 = _score_->score(change.node2(), parents) - _node_current_scores_[change.node2()];
556 parents.push_back(change.node1());
557
558 // update the score
559 _change_scores_[change_index].second = delta;
560
561 // update the head queue
562 _change_queue_per_node_[change.node2()].setPriority(change_index, delta);
563 // indicate which queue was modified
564 modified_nodes.insert(change.node2());
565 } break;
566
567 case GraphChangeType::ARC_REVERSAL : {
568 // remove arc ( node1 -> node2 )
569 auto& parents2 = _parents_[change.node2()];
570 for (auto& par: parents2) {
571 if (par == change.node1()) {
572 par = *(parents2.rbegin());
573 parents2.pop_back();
574 break;
575 }
576 }
577
578 const double delta2
579 = _score_->score(change.node2(), parents2) - _node_current_scores_[change.node2()];
580 parents2.push_back(change.node1());
581
582 // add arc ( node2 -> node1 )
583 auto& parents1 = _parents_[change.node1()];
584 parents1.push_back(change.node2());
585 const double delta1
586 = _score_->score(change.node1(), parents1) - _node_current_scores_[change.node1()];
587 parents1.pop_back();
588
589 // update the scores
590 _change_scores_[change_index].first = delta1;
591 _change_scores_[change_index].second = delta2;
592
593 // update the queues
 // (a reversal lives in both queues with its total delta)
594 const double delta = delta1 + delta2;
595 _change_queue_per_node_[change.node1()].setPriority(change_index, delta);
596 _change_queue_per_node_[change.node2()].setPriority(change_index, delta);
597
598 // indicate which queues were modified
599 modified_nodes.insert(change.node1());
600 modified_nodes.insert(change.node2());
601 } break;
602
603 default : {
 // NOTE(review): source line 604 (the "GUM_ERROR(<type>," opener of the
 // statement below) is missing from this listing
605 "Method _updateScores_ of GraphChangesSelector4DiGraph "
606 << "does not handle yet graph change of type " << change.type());
607 }
608 }
609 }
610
611 // update the node queue
 // (an empty queue gets the numeric_limits<double>::min() sentinel)
612 for (const auto node: modified_nodes) {
613 _node_queue_.setPriority(node,
614 _change_queue_per_node_[node].empty()
615 ? std::numeric_limits< double >::min()
616 : _change_queue_per_node_[node].topPriority());
617 }
618 }
619
621 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
622 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
623 GRAPH_CHANGES_GENERATOR >::_getNewChanges_() {
624 // ask the graph change generator for all its available changes
625 for (const auto& change: *_changes_generator_) {
626 // check that the change does not already exist
627 if (!_changes_.exists(change)) {
628 // add the new change. To make the addition simple, we put the new
629 // change into the illegal changes set. Afterwards, the applyChange
630 // function will put the legal changes again into the queues
631 _illegal_changes_.insert(_changes_.size());
632 _changes_ << change;
633 _change_scores_.push_back(
634 std::pair< double, double >(std::numeric_limits< double >::min(),
635 std::numeric_limits< double >::min()));
636 }
637 }
638
639 // indicate to the generator that we have finished retrieving its changes
640 _changes_generator_->notifyGetCompleted();
641 }
642
/// @brief Applies @p change and immediately updates the affected scores.
///
/// Steps, for every supported change type:
///  - adds the cached delta (_change_scores_) to the current score of the
///    modified node(s) and updates the cached parents;
///  - propagates the change to the selector's constraint, to the generator's
///    constraint when it is a different object, and to the generator;
///  - fetches the new changes from the generator and moves back into the
///    queues the illegal changes that became legal;
///  - invalidates the applied change itself;
///  - rescores the legal changes of the modified node(s) and drops the ones
///    that became illegal.
/// The queues are marked invalid at the end; empty() cleans them lazily.
/// @param change the change to apply. Its index is looked up with
/// _changes_.pos(), so it must be known to the selector (presumably pos()
/// throws otherwise; confirm).
/// NOTE(review): the node scores are updated with the cached deltas, so those
/// deltas are assumed to be up to date when this method is called.
/// NOTE(review): @p change may be a reference into _changes_ (as returned by
/// bestChange()). This relies on _getNewChanges_() appending to _changes_
/// without invalidating references to the stored changes.
644 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
645 void
646 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::applyChange(
647 const GraphChange& change) {
648 // first, we get the index of the change
649 const std::size_t change_index = _changes_.pos(change);
650
651 // perform the change
652 Set< std::size_t > changes_to_recompute;
653 switch (change.type()) {
654 case GraphChangeType::ARC_ADDITION : {
655 // update the current score
656 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
657 _parents_[change.node2()].push_back(change.node1());
658
659 // inform the constraint that the graph has been modified
 // (address comparison: do not notify twice when both share one object)
660 _constraint_->modifyGraph(static_cast< const ArcAddition& >(change));
661 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
662 != _constraint_) {
663 _changes_generator_->constraint().modifyGraph(
664 static_cast< const ArcAddition& >(change));
665 }
666
667 // get new possible changes from the graph change generator
668 // warning: the next 2 lines must run before calling _illegal2LegalChanges_
669 _changes_generator_->modifyGraph(static_cast< const ArcAddition& >(change));
670 _getNewChanges_();
671
672 // check whether some illegal changes can be put into the valid queues
673 _illegal2LegalChanges_(changes_to_recompute);
674 _invalidateChange_(change_index);
675 _findLegalChangesNeedingUpdate_(changes_to_recompute, change.node2());
676 _updateScores_(changes_to_recompute);
677 } break;
678
679 case GraphChangeType::ARC_DELETION : {
680 // update the current score
681 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
 // remove node1 from the parents (overwrite it with the last one, then pop)
682 auto& parents = _parents_[change.node2()];
683 for (auto& par: parents) {
684 if (par == change.node1()) {
685 par = *(parents.rbegin());
686 parents.pop_back();
687 break;
688 }
689 }
690
691 // inform the constraint that the graph has been modified
692 _constraint_->modifyGraph(static_cast< const ArcDeletion& >(change));
693 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
694 != _constraint_) {
695 _changes_generator_->constraint().modifyGraph(
696 static_cast< const ArcDeletion& >(change));
697 }
698
699 // get new possible changes from the graph change generator
700 // warning: the next 2 lines must run before calling _illegal2LegalChanges_
701 _changes_generator_->modifyGraph(static_cast< const ArcDeletion& >(change));
702 _getNewChanges_();
703
704 // check whether some illegal changes can be put into the valid queues
705 _illegal2LegalChanges_(changes_to_recompute);
706 _invalidateChange_(change_index);
707 _findLegalChangesNeedingUpdate_(changes_to_recompute, change.node2());
708 _updateScores_(changes_to_recompute);
709 } break;
710
711 case GraphChangeType::ARC_REVERSAL : {
712 // update the current score
 // (a reversal modifies the scores of both nodes)
713 _node_current_scores_[change.node1()] += _change_scores_[change_index].first;
714 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
715 _parents_[change.node1()].push_back(change.node2());
716 auto& parents = _parents_[change.node2()];
717 for (auto& par: parents) {
718 if (par == change.node1()) {
719 par = *(parents.rbegin());
720 parents.pop_back();
721 break;
722 }
723 }
724
725 // inform the constraint that the graph has been modified
726 _constraint_->modifyGraph(static_cast< const ArcReversal& >(change));
727 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
728 != _constraint_) {
729 _changes_generator_->constraint().modifyGraph(
730 static_cast< const ArcReversal& >(change));
731 }
732
733 // get new possible changes from the graph change generator
734 // warning: the next 2 lines must run before calling _illegal2LegalChanges_
735 _changes_generator_->modifyGraph(static_cast< const ArcReversal& >(change));
736 _getNewChanges_();
737
738 // check whether some illegal changes can be put into the valid queues
739 _illegal2LegalChanges_(changes_to_recompute);
740 _invalidateChange_(change_index);
741 _findLegalChangesNeedingUpdate_(changes_to_recompute, change.node1());
742 _findLegalChangesNeedingUpdate_(changes_to_recompute, change.node2());
743 _updateScores_(changes_to_recompute);
744 } break;
745
746 default :
 // NOTE(review): source line 747 (the "GUM_ERROR(<type>," opener of the
 // statement below) is missing from this listing
748 "Method applyChange of GraphChangesSelector4DiGraph "
749 << "does not handle yet graph change of type " << change.type());
750 }
751
 // the tops of the queues may now hold illegal changes: cleaned by empty()
752 _queues_valid_ = false;
753 }
754
/// @brief Applies @p change but defers the rescoring of the other changes.
///
/// Like applyChange(), this method:
///  - adds the cached delta to the current score(s) and updates the cached
///    parents;
///  - propagates the change to the constraint(s) and to the generator;
///  - fetches the new changes;
///  - invalidates the applied change.
/// Instead of rescoring, it only records the modified node(s) in
/// _queues_to_update_. updateScoresAfterAppliedChanges() processes that
/// set, presumably after a batch of such calls (confirm with the callers).
/// NOTE(review): unlike applyChange(), _queues_valid_ is not reset here, and
/// the illegal changes are not re-examined until
/// updateScoresAfterAppliedChanges() runs.
/// @param change the change to apply; must be known to the selector since
/// its index is looked up with _changes_.pos().
756 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
757 void GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT, GRAPH_CHANGES_GENERATOR >::
758 applyChangeWithoutScoreUpdate(const GraphChange& change) {
759 // first, we get the index of the change
760 const std::size_t change_index = _changes_.pos(change);
761
762 // perform the change
763 switch (change.type()) {
764 case GraphChangeType::ARC_ADDITION : {
765 // update the current score
766 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
767 _parents_[change.node2()].push_back(change.node1());
768
769 // inform the constraint that the graph has been modified
 // (address comparison: do not notify twice when both share one object)
770 _constraint_->modifyGraph(static_cast< const ArcAddition& >(change));
771 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
772 != _constraint_) {
773 _changes_generator_->constraint().modifyGraph(
774 static_cast< const ArcAddition& >(change));
775 }
776
777 // get new possible changes from the graph change generator
778 // warning: the next 2 lines must run before the illegal changes are re-examined
779 _changes_generator_->modifyGraph(static_cast< const ArcAddition& >(change));
780 _getNewChanges_();
781
782 // indicate that we have just applied the change
783 _invalidateChange_(change_index);
784
785 // indicate that the queue to which the change belongs needs to be
786 // updated
787 _queues_to_update_.insert(change.node2());
788 } break;
789
790 case GraphChangeType::ARC_DELETION : {
791 // update the current score
792 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
 // remove node1 from the parents (overwrite it with the last one, then pop)
793 auto& parents = _parents_[change.node2()];
794 for (auto& par: parents) {
795 if (par == change.node1()) {
796 par = *(parents.rbegin());
797 parents.pop_back();
798 break;
799 }
800 }
801
802 // inform the constraint that the graph has been modified
803 _constraint_->modifyGraph(static_cast< const ArcDeletion& >(change));
804 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
805 != _constraint_) {
806 _changes_generator_->constraint().modifyGraph(
807 static_cast< const ArcDeletion& >(change));
808 }
809
810 // get new possible changes from the graph change generator
811 // warning: the next 2 lines must run before the illegal changes are re-examined
812 _changes_generator_->modifyGraph(static_cast< const ArcDeletion& >(change));
813 _getNewChanges_();
814
815 // indicate that we have just applied the change
816 _invalidateChange_(change_index);
817
818 // indicate that the queue to which the change belongs needs to be
819 // updated
820 _queues_to_update_.insert(change.node2());
821 } break;
822
823 case GraphChangeType::ARC_REVERSAL : {
824 // update the current score
 // (a reversal modifies the scores of both nodes)
825 _node_current_scores_[change.node1()] += _change_scores_[change_index].first;
826 _node_current_scores_[change.node2()] += _change_scores_[change_index].second;
827 _parents_[change.node1()].push_back(change.node2());
828 auto& parents = _parents_[change.node2()];
829 for (auto& par: parents) {
830 if (par == change.node1()) {
831 par = *(parents.rbegin());
832 parents.pop_back();
833 break;
834 }
835 }
836
837 // inform the constraint that the graph has been modified
838 _constraint_->modifyGraph(static_cast< const ArcReversal& >(change));
839 if (reinterpret_cast< STRUCTURAL_CONSTRAINT* >(&(_changes_generator_->constraint()))
840 != _constraint_) {
841 _changes_generator_->constraint().modifyGraph(
842 static_cast< const ArcReversal& >(change));
843 }
844
845 // get new possible changes from the graph change generator
846 // warning: the next 2 lines must run before the illegal changes are re-examined
847 _changes_generator_->modifyGraph(static_cast< const ArcReversal& >(change));
848 _getNewChanges_();
849
850 // indicate that we have just applied the change
851 _invalidateChange_(change_index);
852
853 // indicate that the queue to which the change belongs needs to be
854 // updated
855 _queues_to_update_.insert(change.node1());
856 _queues_to_update_.insert(change.node2());
857 } break;
858
859 default :
 // NOTE(review): source line 860 (the "GUM_ERROR(<type>," opener of the
 // statement below) is missing from this listing
861 "Method applyChangeWithoutScoreUpdate of "
862 << "GraphChangesSelector4DiGraph "
863 << "does not handle yet graph change of type " << change.type());
864 }
865 }
866
868 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
869 void
870 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
871 GRAPH_CHANGES_GENERATOR >::updateScoresAfterAppliedChanges() {
872 // determine which changes in the illegal set are now legal
873 Set< std::size_t > new_legal_changes;
874 for (auto iter = _illegal_changes_.beginSafe(); iter != _illegal_changes_.endSafe(); ++iter) {
875 if (_isChangeValid_(*iter)) {
876 new_legal_changes.insert(*iter);
877 _illegal_changes_.erase(iter);
878 }
879 }
880
881 // update the scores that need be updated
882 Set< std::size_t > changes_to_recompute;
883 for (const auto& node: _queues_to_update_) {
884 _findLegalChangesNeedingUpdate_(changes_to_recompute, node);
885 }
886 _queues_to_update_.clear();
887
888 // put the previously illegal changes that are now legal into their queues
889 for (const auto change_index: new_legal_changes) {
890 const GraphChange& change = _changes_[change_index];
891 if (change.type() == GraphChangeType::ARC_REVERSAL) {
892 _change_queue_per_node_[change.node1()].insert(change_index,
893 std::numeric_limits< double >::min());
894 }
895 _change_queue_per_node_[change.node2()].insert(change_index,
896 std::numeric_limits< double >::min());
897
898 changes_to_recompute.insert(change_index);
899 }
900
901 // compute the scores that we need
902 _updateScores_(changes_to_recompute);
903
904 _queues_valid_ = false;
905 }
906
908 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
909 std::vector< std::pair< NodeId, double > >
910 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
911 GRAPH_CHANGES_GENERATOR >::nodesSortedByBestScore() const {
912 std::vector< std::pair< NodeId, double > > result(_node_queue_.size());
913 for (std::size_t i = std::size_t(0); i < _node_queue_.size(); ++i) {
914 result[i].first = _node_queue_[i];
915 result[i].second = _node_queue_.priorityByPos(i);
916 }
917
918 std::sort(result.begin(),
919 result.end(),
920 [](const std::pair< NodeId, double >& a,
921 const std::pair< NodeId, double >& b) -> bool { return a.second > b.second; });
922
923 return result;
924 }
925
927 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
928 std::vector< std::pair< NodeId, double > >
929 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
930 GRAPH_CHANGES_GENERATOR >::nodesUnsortedWithScore() const {
931 std::vector< std::pair< NodeId, double > > result(_node_queue_.size());
932 for (std::size_t i = std::size_t(0); i < _node_queue_.size(); ++i) {
933 result[i].first = _node_queue_[i];
934 result[i].second = _node_queue_.priorityByPos(i);
935 }
936
937 return result;
938 }
939
941 template < typename STRUCTURAL_CONSTRAINT, typename GRAPH_CHANGES_GENERATOR >
942 INLINE typename GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
943 GRAPH_CHANGES_GENERATOR >::GeneratorType&
944 GraphChangesSelector4DiGraph< STRUCTURAL_CONSTRAINT,
945 GRAPH_CHANGES_GENERATOR >::graphChangeGenerator()
946 const noexcept {
947 return *_changes_generator_;
948 }
949
950
951 } /* namespace learning */
952
953} /* namespace gum */
954
955#endif /* DOXYGEN_SHOULD_SKIP_THIS */
Exception : the element we looked for cannot be found.
Exception : there is something wrong with an implementation.
GraphChangesSelector4DiGraph(Score &score, STRUCTURAL_CONSTRAINT &constraint, GRAPH_CHANGES_GENERATOR &changes_generator)
default constructor
The base class for all the scores used for learning (BIC, BDeu, etc).
Definition score.h:68
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
Set< NodeId > NodeSet
Some typedefs and defines for shortcuts ...
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.