55 template <
typename GUM_SCALAR >
59 = *(this->
tree_->data(p).iso_map.begin().val());
62 for (
const auto inst: seq) {
63 for (
const auto input: inst->type().slotChains())
64 for (
const auto inst2: inst->getInstances(input->id()))
66 && (!input_set.
exists(&(inst2->get(input->lastElt().safeName()))))) {
67 cost += std::log(input->type().variable().domainSize());
68 input_set.
insert(&(inst2->get(input->lastElt().safeName())));
71 for (
auto vec = inst->beginInvRef(); vec != inst->endInvRef(); ++vec)
72 for (
const auto& inverse: *vec.val())
73 if (!seq.
exists(inverse.first)) {
74 cost += std::log(inst->get(vec.key()).type().variable().domainSize());
82 template <
typename GUM_SCALAR >
85 Set< Tensor< GUM_SCALAR >* >& pool,
87 for (
const auto inst: match) {
88 for (
const auto& elt: *inst) {
92 data.
mod.insert(
id, elt.second->type()->domainSize());
93 data.
vars.insert(
id, &elt.second->type().variable());
94 pool.insert(
const_cast< Tensor< GUM_SCALAR >*
>(&(elt.second->cpf())));
99 for (
const auto inst: match)
100 for (
const auto& elt: *inst) {
105 for (
const auto chld: inst->type().containerDag().children(elt.second->id())) {
110 for (
const auto par: inst->type().containerDag().parents(elt.second->id())) {
111 switch (inst->type().get(par).elt_type()) {
119 for (
const auto inst2: inst->getInstances(par))
120 if (match.exists(inst2))
125 inst->type().get(par)))));
136 if (inst->hasRefAttr(elt.second->id())) {
137 const std::vector< std::pair< PRMInstance< GUM_SCALAR >*, std::string > >& ref_attr
138 = inst->getRefAttr(elt.second->id());
140 for (
auto pair = ref_attr.begin(); pair != ref_attr.end(); ++pair) {
141 if (match.exists(pair->first)) {
142 NodeId id = pair->first->type().get(pair->second).id();
144 for (
const auto child: pair->first->type().containerDag().children(
id))
159 template <
typename GUM_SCALAR >
162 Set< Tensor< GUM_SCALAR >* >& pool) {
171 Size max(0), max_count(1);
173 Tensor< GUM_SCALAR >* pot = 0;
175 for (
size_t idx = 0; idx < data.
inners.
size(); ++idx) {
177 pot->add(*(data.
vars.second(elim_order[idx])));
181 for (
const auto p: pool)
182 if (p->contains(*(data.
vars.second(elim_order[idx])))) {
183 for (
auto var = p->variablesSequence().begin(); var != p->variablesSequence().end();
193 if (pot->domainSize() > max) {
194 max = pot->domainSize();
196 }
else if (pot->domainSize() == max) {
200 for (
const auto p: toRemove)
203 pot->erase(*(data.
vars.second(elim_order[idx])));
206 for (
const auto pot: trash)
209 return std::make_pair(max, max_count);
213 template <
typename GUM_SCALAR >
218 template <
typename GUM_SCALAR >
225 template <
typename GUM_SCALAR >
230 template <
typename GUM_SCALAR >
237 template <
typename GUM_SCALAR >
245 template <
typename GUM_SCALAR >
251 template <
typename GUM_SCALAR >
258 template <
typename GUM_SCALAR >
263 template <
typename GUM_SCALAR >
270 template <
typename GUM_SCALAR >
275 template <
typename GUM_SCALAR >
283 template <
typename GUM_SCALAR >
286 return this->
tree_->frequency(*i) > this->
tree_->frequency(*j);
289 template <
typename GUM_SCALAR >
291 return (this->
tree_->graph().size(i) > this->tree_->graph().size(j));
297 template <
typename GUM_SCALAR >
303 template <
typename GUM_SCALAR >
309 template <
typename GUM_SCALAR >
314 template <
typename GUM_SCALAR >
321 template <
typename GUM_SCALAR >
326 template <
typename GUM_SCALAR >
335 template <
typename GUM_SCALAR >
341 template <
typename GUM_SCALAR >
347 template <
typename GUM_SCALAR >
350 return _map_[p].first;
353 return _map_[p].first;
357 template <
typename GUM_SCALAR >
360 return _map_[p].second;
363 return _map_[p].second;
367 template <
typename GUM_SCALAR >
374 template <
typename GUM_SCALAR >
381 template <
typename GUM_SCALAR >
388 template <
typename GUM_SCALAR >
395 _map_.insert(p, std::make_pair(inner, outer));
400 template <
typename GUM_SCALAR >
405 template <
typename GUM_SCALAR >
411 template <
typename GUM_SCALAR >
416 template <
typename GUM_SCALAR >
422 template <
typename GUM_SCALAR >
432 template <
typename GUM_SCALAR >
436 for (
const auto n: r->
nodes())
439 return tree_width >=
cost(*r);
442 template <
typename GUM_SCALAR >
447 return cost(*parent) >=
cost(*child);
450 template <
typename GUM_SCALAR >
455 template <
typename GUM_SCALAR >
void insert(const T1 &first, const T2 &second)
const T1 & first(const T2 &second) const
Exception : a similar element already exists.
Generic doubly linked lists.
Val & insert(const Val &val)
Inserts a new element at the end of the chained list (alias of pushBack).
Multidimensional matrix stored as a sparse array in memory.
virtual NodeId addNode()
insert a new node and return its id
Exception : the element we looked for cannot be found.
class for graph triangulations for which we enforce a given partial ordering on the nodes elimination...
bool exists(const Key &k) const
Check the existence of k in the sequence.
void insert(const Key &k)
Insert an element at the end of the sequence.
The generic class for storing (ordered) sequences of objects.
Size size() const noexcept
Returns the number of elements in the set.
void insert(const Key &k)
Inserts a new element into the set.
void erase(const Key &k)
Erases an element from the set.
const std::vector< NodeId > & eliminationOrder()
returns an elimination ordering compatible with the triangulated graph
void addEdge(NodeId first, NodeId second) override
insert a new edge into the undirected graph
PRMAttribute is a member of a Class in a PRM.
const std::string & safeName() const
Returns the safe name of this PRMClassElement, if any.
An PRMInstance is a Bayesian network fragment defined by a Class and used in a PRMSystem.
const std::string & name() const
Returns the name of this object.
A PRMSlotChain represents a sequence of gum::prm::PRMClassElement<GUM_SCALAR> where the n-1 first gum...
PRMClassElement< GUM_SCALAR > & lastElt()
Returns the last element of the slot chain, typically this is an gum::PRMAttribute or a gum::PRMAggre...
This class is used to define an edge growth of a pattern in this DFSTree.
This is class is an implementation of a simple serach strategy for the gspan algorithm: it accept a g...
virtual ~FrequenceSearch()
Destructor.
FrequenceSearch(Size freq)
Default constructor.
virtual bool operator()(LabelData *i, LabelData *j)
virtual bool accept_root(const Pattern *r)
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
FrequenceSearch & operator=(const FrequenceSearch &from)
Copy operator.
This contains all the information we want for a node in a DFSTree.
const NodeGraphPart & nodes() const
LabelData & label(NodeId node)
Returns the LabelData assigned to node.
This is an abstract class used to tune search strategies in the gspan algorithm.
double computeCost_(const Pattern &p)
SearchStrategy()
Default constructor.
SearchStrategy< GUM_SCALAR > & operator=(const SearchStrategy< GUM_SCALAR > &from)
Copy operator.
DFSTree< GUM_SCALAR > * tree_
void setTree(DFSTree< GUM_SCALAR > *tree)
virtual ~SearchStrategy()
Destructor.
This is class is an implementation of a strict strategy for the GSpan algorithm.
double _outer_cost_(const Pattern *p)
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual ~StrictSearch()
Destructor.
StrictSearch(Size freq=2)
Default constructor.
virtual bool accept_root(const Pattern *r)
StrictSearch & operator=(const StrictSearch &from)
Copy operator.
double _inner_cost_(const Pattern *p)
HashTable< const Pattern *, std::pair< double, double > > _map_
virtual bool operator()(LabelData *i, LabelData *j)
void _compute_costs_(const Pattern *p)
void _buildPatternGraph_(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Tensor< GUM_SCALAR > * > &pool, const Sequence< PRMInstance< GUM_SCALAR > * > &match)
std::pair< Size, Size > _elimination_cost_(typename StrictSearch< GUM_SCALAR >::PData &data, Set< Tensor< GUM_SCALAR > * > &pool)
std::string _str_(const PRMInstance< GUM_SCALAR > *i, const PRMAttribute< GUM_SCALAR > *a) const
A growth is accepted if and only if the new growth has a tree width less large or equal than its fath...
TreeWidthSearch()
Default constructor.
virtual bool accept_growth(const Pattern *parent, const Pattern *child, const EdgeGrowth< GUM_SCALAR > &growth)
virtual ~TreeWidthSearch()
Destructor.
HashTable< const Pattern *, double > _map_
TreeWidthSearch & operator=(const TreeWidthSearch &from)
Copy operator.
virtual bool operator()(LabelData *i, LabelData *j)
virtual bool accept_root(const Pattern *r)
double cost(const Pattern &p)
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Size NodeId
Type for node ids.
namespace for all probabilistic relational models entities
gum is the global namespace for all aGrUM entities
Headers of the SearchStrategy class and child.
Inner class to handle data about labels in this interface graph.
Size tree_width
The size in terms of tree width of the given label.
Private structure to represent data about a pattern.
Bijection< NodeId, std::string > node2attr
A bijection to easily keep track between graph and attributes, its of the form instance_name DOT attr...
NodeProperty< Size > mod
The pattern's variables modalities.
UndiGraph graph
A yet to be triangulated undigraph.
NodeSet outputs
Returns the set of outputs nodes given all the matches of pattern.
NodeSet inners
Returns the set of inner nodes.
Bijection< NodeId, const DiscreteVariable * > vars
Bijection between graph's nodes and their corresponding DiscreteVariable, for inference purpose.