aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
greedyHillClimbing_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
48
51
52namespace gum {
53
54 namespace learning {
55
57 template < typename GRAPH_CHANGES_SELECTOR >
58 DAG GreedyHillClimbing::learnStructure(GRAPH_CHANGES_SELECTOR& selector, DAG dag) {
59 selector.setGraph(dag);
60
61 unsigned int nb_changes_applied = 1;
62 double delta_score;
63
65
66 // a vector that indicates which queues have valid scores, i.e., scores
67 // that were not invalidated by previously applied changes
68 std::vector< bool > impacted_queues(dag.size(), false);
69
70 do {
71 nb_changes_applied = 0;
72 delta_score = 0;
73
74 std::vector< std::pair< NodeId, double > > ordered_queues
75 = selector.nodesSortedByBestScore();
76
77 for (Idx j = 0; j < dag.size(); ++j) {
78 Idx i = ordered_queues[j].first;
79
80 if (!(selector.empty(i)) && (selector.bestScore(i) > 0)) {
81 // pick up the best change
82 const GraphChange& change = selector.bestChange(i);
83
84 // perform the change
85 switch (change.type()) {
87 if (!impacted_queues[change.node2()] && selector.isChangeValid(change)) {
88 delta_score += selector.bestScore(i);
89 dag.addArc(change.node1(), change.node2());
90 impacted_queues[change.node2()] = true;
91 selector.applyChangeWithoutScoreUpdate(change);
92 ++nb_changes_applied;
93 }
94
95 break;
96
98 if (!impacted_queues[change.node2()] && selector.isChangeValid(change)) {
99 delta_score += selector.bestScore(i);
100 dag.eraseArc(Arc(change.node1(), change.node2()));
101 impacted_queues[change.node2()] = true;
102 selector.applyChangeWithoutScoreUpdate(change);
103 ++nb_changes_applied;
104 }
105
106 break;
107
109 if ((!impacted_queues[change.node1()]) && (!impacted_queues[change.node2()])
110 && selector.isChangeValid(change)) {
111 delta_score += selector.bestScore(i);
112 dag.eraseArc(Arc(change.node1(), change.node2()));
113 dag.addArc(change.node2(), change.node1());
114 impacted_queues[change.node1()] = true;
115 impacted_queues[change.node2()] = true;
116 selector.applyChangeWithoutScoreUpdate(change);
117 ++nb_changes_applied;
118 }
119
120 break;
121
122 default :
124 "edge modifications are not supported by local search")
125 }
126 }
127 }
128
129 selector.updateScoresAfterAppliedChanges();
130
131 // reset the impacted queue and applied changes structures
132 for (auto iter = impacted_queues.begin(); iter != impacted_queues.end(); ++iter) {
133 *iter = false;
134 }
135
136 updateApproximationScheme(nb_changes_applied);
137
138 } while (nb_changes_applied && continueApproximationScheme(delta_score));
139
140 stopApproximationScheme(); // just to be sure of the approximationScheme
141 // has
142 // been notified of the end of looop
143
144 return dag;
145 }
146
148 template < typename GUM_SCALAR, typename GRAPH_CHANGES_SELECTOR, typename PARAM_ESTIMATOR >
149 BayesNet< GUM_SCALAR > GreedyHillClimbing::learnBN(GRAPH_CHANGES_SELECTOR& selector,
150 PARAM_ESTIMATOR& estimator,
151 DAG initial_dag) {
153 learnStructure(selector, initial_dag));
154 }
155
156 } /* namespace learning */
157
158} /* namespace gum */
A class that, given a structure and a parameter estimator, returns a full Bayes net.
void updateApproximationScheme(unsigned int incr=1)
Update the scheme w.r.t. the new error and increment steps.
void initApproximationScheme()
Initialise the scheme.
void stopApproximationScheme()
Stop the approximation scheme.
bool continueApproximationScheme(double error)
Update the scheme w.r.t. the new error.
virtual void eraseArc(const Arc &arc)
removes an arc from the ArcGraphPart
The base class for all directed edges.
Base class for dag.
Definition DAG.h:121
void addArc(NodeId tail, NodeId head) final
insert a new arc into the directed graph
Definition DAG_inl.h:63
Size size() const
alias for sizeNodes
Exception : operation not allowed.
static BayesNet< GUM_SCALAR > createBN(ParamEstimator &estimator, const DAG &dag)
create a BN from a DAG using a one-pass generator (typically ML)
NodeId node1() const noexcept
returns the first node involved in the modification
GraphChangeType type() const noexcept
returns the type of the operation
NodeId node2() const noexcept
returns the second node involved in the modification
DAG learnStructure(GRAPH_CHANGES_SELECTOR &selector, DAG initial_dag=DAG())
learns the structure of a Bayes net
BayesNet< GUM_SCALAR > learnBN(GRAPH_CHANGES_SELECTOR &selector, PARAM_ESTIMATOR &estimator, DAG initial_dag=DAG())
learns the structure and the parameters of a BN
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
the classes to account for structure changes in a graph
Size Idx
Type for indexes.
Definition types.h:79
include the inlined functions if necessary
Definition CSVParser.h:54
gum is the global namespace for all aGrUM entities
Definition agrum.h:46