aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
localSearchWithTabuList_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
49
52
53namespace gum {
54
55 namespace learning {
56
58 template < typename GRAPH_CHANGES_SELECTOR >
59 DAG LocalSearchWithTabuList::learnStructure(GRAPH_CHANGES_SELECTOR& selector, DAG dag) {
60 selector.setGraph(dag);
61
62 unsigned int nb_changes_applied = 0;
63 Idx applied_change_with_positive_score = 0;
64 Idx current_N = 0;
65
67
68 // a vector that indicates which queues have valid scores, i.e., scores
69 // that were not invalidated by previously applied changes
70 std::vector< bool > impacted_queues(dag.size(), false);
71
72 // the best dag found so far with its score
73 DAG best_dag = dag;
74 double best_score = 0;
75 double current_score = 0;
76 double delta_score = 0;
77
78 do {
79 applied_change_with_positive_score = 0;
80 delta_score = 0;
81
82 std::vector< std::pair< NodeId, double > > ordered_queues
83 = selector.nodesSortedByBestScore();
84
85 for (Idx j = 0; j < dag.size(); ++j) {
86 NodeId i = ordered_queues[j].first;
87
88 if (!selector.empty(i) && (!nb_changes_applied || (selector.bestScore(i) > 0))) {
89 // pick up the best change
90 const GraphChange& change = selector.bestChange(i);
91
92 // perform the change
93 switch (change.type()) {
95 if (!impacted_queues[change.node2()] && selector.isChangeValid(change)) {
96 if (selector.bestScore(i) > 0) {
97 ++applied_change_with_positive_score;
98 } else if (current_score > best_score) {
99 best_score = current_score;
100 best_dag = dag;
101 }
102
103 // std::cout << "apply arc addition " << change.node1()
104 // << " -> " << change.node2()
105 // << " delta = " << selector.bestScore( i )
106 // << std::endl;
107
108 delta_score += selector.bestScore(i);
109 current_score += selector.bestScore(i);
110 dag.addArc(change.node1(), change.node2());
111 impacted_queues[change.node2()] = true;
112 selector.applyChangeWithoutScoreUpdate(change);
113 ++nb_changes_applied;
114 }
115
116 break;
117
119 if (!impacted_queues[change.node2()] && selector.isChangeValid(change)) {
120 if (selector.bestScore(i) > 0) {
121 ++applied_change_with_positive_score;
122 } else if (current_score > best_score) {
123 best_score = current_score;
124 best_dag = dag;
125 }
126
127 // std::cout << "apply arc deletion " << change.node1()
128 // << " -> " << change.node2()
129 // << " delta = " << selector.bestScore( i )
130 // << std::endl;
131
132 delta_score += selector.bestScore(i);
133 current_score += selector.bestScore(i);
134 dag.eraseArc(Arc(change.node1(), change.node2()));
135 impacted_queues[change.node2()] = true;
136 selector.applyChangeWithoutScoreUpdate(change);
137 ++nb_changes_applied;
138 }
139
140 break;
141
143 if ((!impacted_queues[change.node1()]) && (!impacted_queues[change.node2()])
144 && selector.isChangeValid(change)) {
145 if (selector.bestScore(i) > 0) {
146 ++applied_change_with_positive_score;
147 } else if (current_score > best_score) {
148 best_score = current_score;
149 best_dag = dag;
150 }
151
152 // std::cout << "apply arc reversal " << change.node1()
153 // << " -> " << change.node2()
154 // << " delta = " << selector.bestScore( i )
155 // << std::endl;
156
157 delta_score += selector.bestScore(i);
158 current_score += selector.bestScore(i);
159 dag.eraseArc(Arc(change.node1(), change.node2()));
160 dag.addArc(change.node2(), change.node1());
161 impacted_queues[change.node1()] = true;
162 impacted_queues[change.node2()] = true;
163 selector.applyChangeWithoutScoreUpdate(change);
164 ++nb_changes_applied;
165 }
166
167 break;
168
169 default :
171 "edge modifications are not "
172 "supported by local search");
173 }
174
175 break;
176 }
177 }
178
179 selector.updateScoresAfterAppliedChanges();
180
181 // reset the impacted queue and applied changes structures
182 for (auto iter = impacted_queues.begin(); iter != impacted_queues.end(); ++iter) {
183 *iter = false;
184 }
185
186 updateApproximationScheme(nb_changes_applied);
187
188 // update current_N
189 if (applied_change_with_positive_score) {
190 current_N = 0;
191 nb_changes_applied = 0;
192 } else {
193 ++current_N;
194 }
195
196 // std::cout << "current N = " << current_N << std::endl;
197 } while ((current_N <= _MaxNbDecreasing_) && continueApproximationScheme(delta_score));
198
199 stopApproximationScheme(); // just to be sure of the
200 // approximationScheme has
201 // been notified of the end of looop
202
203 if (current_score > best_score) {
204 return dag;
205 } else {
206 return best_dag;
207 }
208 }
209
211 template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
212 BayesNet< GUM_SCALAR > LocalSearchWithTabuList::learnBN(GRAPH_CHANGES_SELECTOR& selector,
213 PARAM_ESTIMATOR& estimator,
214 DAG initial_dag) {
216 learnStructure(selector, initial_dag));
217 }
218
219 } /* namespace learning */
220
221} /* namespace gum */
A class that, given a structure and a parameter estimator returns a full Bayes net.
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t the new error and increment steps.
void initApproximationScheme()
Initialise the scheme.
void stopApproximationScheme()
Stop the approximation scheme.
bool continueApproximationScheme(double error)
Update the scheme w.r.t the new error.
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
The base class for all directed edges.
Base class for dag.
Definition DAG.h:121
void addArc(NodeId tail, NodeId head) final
insert a new arc into the directed graph
Definition DAG_inl.h:63
Size size() const
alias for sizeNodes
Exception : operation not allowed.
static BayesNet< GUM_SCALAR > createBN(ParamEstimator &estimator, const DAG &dag)
create a BN from a DAG using a one pass generator (typically ML)
NodeId node1() const noexcept
returns the first node involved in the modification
GraphChangeType type() const noexcept
returns the type of the operation
NodeId node2() const noexcept
returns the second node involved in the modification
Size _MaxNbDecreasing_
the max number of changes decreasing the score that we allow to apply
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
the classes to account for structure changes in a graph
Size Idx
Type for indexes.
Definition types.h:79
Size NodeId
Type for node ids.
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46