57 template <
typename GUM_SCALAR >
66 template <
typename GUM_SCALAR >
74 template <
typename GUM_SCALAR >
84 template <
typename GUM_SCALAR >
89 template <
typename GUM_SCALAR >
92 if (
this == &src) {
return *
this; }
105 template <
typename GUM_SCALAR >
108 if (
this == &src) {
return *
this; }
109 _prm_ = std::move(src._prm_);
116 _dag_ = std::move(src._dag_);
121 template <
typename GUM_SCALAR >
132 for (
auto& i: c->interfaces()) {
133 if (
_solver_->resolveInterface(i)) { implements.insert(i.label()); }
137 if (
_solver_->resolveClass(c->superLabel())) {
138 factory.
startClass(c->name().label(), c->superLabel().label(), &implements,
true);
145 template <
typename GUM_SCALAR >
147 auto topo_order =
_dag_.topologicalOrder();
149 for (
auto id = topo_order.rbegin();
id != topo_order.rend(); --
id) {
154 template <
typename GUM_SCALAR >
159 template <
typename GUM_SCALAR >
161 for (
auto& c:
_o3_prm_->classes()) {
162 auto id =
_dag_.addNode();
166 _classMap_.insert(c->name().label(), c.get());
170 O3PRM_CLASS_DUPLICATE(c->name(), *
_errors_);
178 template <
typename GUM_SCALAR >
180 for (
auto& c:
_o3_prm_->classes()) {
181 if (c->superLabel().label() !=
"") {
182 if (!
_solver_->resolveClass(c->superLabel())) {
return false; }
184 auto head =
_nameMap_[c->superLabel().label()];
185 auto tail =
_nameMap_[c->name().label()];
188 _dag_.addArc(tail, head);
191 O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *
_errors_);
200 template <
typename GUM_SCALAR >
202 for (
auto& c:
_o3_prm_->classes()) {
204 _prm_->getClass(c->name().label()).initializeInheritance();
213 template <
typename GUM_SCALAR >
218 attr_map.insert(a->name().label(), a.get());
224 agg_map.insert(agg.name().label(), &agg);
229 ref_map.insert(ref.name().label(), &ref);
234 if (
_solver_->resolveInterface(i)) {
242 template <
typename GUM_SCALAR >
248 const auto& real_i = _prm_->getInterface(i.
label());
250 auto counter = (
Size)0;
251 for (
const auto& a: real_i.attributes()) {
252 if (attr_map.
exists(a->name())) {
255 if (!_checkImplementation_(attr_map[a->name()]->type(), a->type())) {
256 O3PRM_CLASS_ATTR_IMPLEMENTATION(c.
name(), i, attr_map[a->name()]->name(), *_errors_);
261 if (agg_map.
exists(a->name())) {
264 if (!_checkImplementation_(agg_map[a->name()]->variableType(), a->type())) {
265 O3PRM_CLASS_AGG_IMPLEMENTATION(c.
name(), i, agg_map[a->name()]->name(), *_errors_);
271 if (counter != real_i.attributes().size()) {
272 O3PRM_CLASS_MISSING_ATTRIBUTES(c.
name(), i, *_errors_);
277 for (
const auto& r: real_i.referenceSlots()) {
278 if (ref_map.
exists(r->name())) {
281 if (!_checkImplementation_(ref_map[r->name()]->type(), r->slotType())) {
282 O3PRM_CLASS_REF_IMPLEMENTATION(c.
name(), i, ref_map[r->name()]->name(), *_errors_);
290 template <
typename GUM_SCALAR >
293 if (!
_solver_->resolveType(o3_type)) {
return false; }
295 return _prm_->type(o3_type.
label()).isSubTypeOf(type);
298 template <
typename GUM_SCALAR >
302 if (!
_solver_->resolveSlotType(o3_type)) {
return false; }
305 return _prm_->getInterface(o3_type.
label()).isSubTypeOf(type);
307 return _prm_->getClass(o3_type.
label()).isSubTypeOf(type);
311 template <
typename GUM_SCALAR >
326 template <
typename GUM_SCALAR >
332 factory.
addParameter(
"int", p.name().label(), p.value().value());
337 factory.
addParameter(
"real", p.name().label(), p.value().value());
348 template <
typename GUM_SCALAR >
357 template <
typename GUM_SCALAR >
366 factory.
addReferenceSlot(ref.type().label(), ref.name().label(), ref.isArray());
373 template <
typename GUM_SCALAR >
376 if (!
_solver_->resolveSlotType(ref.
type())) {
return false; }
382 const auto& elt = real_c.get(ref.
name().
label());
396 if (slot_type->name() == real_ref->slotType().name()) {
400 }
else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
416 if ((&ref_type) == (&real_c)) {
422 if (ref_type.isSubTypeOf(real_c)) {
431 template <
typename GUM_SCALAR >
440 template <
typename GUM_SCALAR >
449 template <
typename GUM_SCALAR >
456 factory.
startAttribute(attr->type().label(), attr->name().label());
464 template <
typename GUM_SCALAR >
468 if (!
_solver_->resolveType(attr.
type())) {
return false; }
474 if (!super.exists(attr.
name().
label())) {
return true; }
476 const auto& super_type = super.get(attr.
name().
label()).type();
479 if (!type.isSubTypeOf(super_type)) {
487 template <
typename GUM_SCALAR >
502 for (
auto a: super.attributes()) {
503 to_complete.insert(a->safeName());
506 for (
auto a: super.aggregates()) {
507 to_complete.insert(a->safeName());
512 _prm_->getClass(c->
name().
label()).get(a->name().label()).safeName());
517 _prm_->getClass(c->
name().
label()).get(a.name().label()).safeName());
520 for (
auto a: to_complete) {
529 template <
typename GUM_SCALAR >
543 template <
typename GUM_SCALAR >
552 for (
const auto& parent: agg.parents()) {
561 template <
typename GUM_SCALAR >
566 if (t ==
nullptr) {
return false; }
574 template <
typename GUM_SCALAR >
583 for (
const auto& parent: attr->parents()) {
587 auto raw =
dynamic_cast< const O3RawCPT*
>(attr.get());
590 auto values = std::vector< std::string >();
591 for (
const auto& val: raw->values()) {
592 values.push_back(val.formula().formula());
597 auto rule_cpt =
dynamic_cast< const O3RuleCPT*
>(attr.get());
599 for (
const auto& rule: rule_cpt->rules()) {
600 auto labels = std::vector< std::string >();
601 auto values = std::vector< std::string >();
603 for (
const auto& lbl: rule.first) {
604 labels.push_back(lbl.label());
607 for (
const auto& form: rule.second) {
608 values.push_back(form.formula().formula());
620 template <
typename GUM_SCALAR >
625 for (
auto& prnt: attr.
parents()) {
630 auto raw =
dynamic_cast< O3RawCPT*
>(&attr);
633 auto rule =
dynamic_cast< O3RuleCPT*
>(&attr);
639 template <
typename GUM_SCALAR >
642 if (prnt.
label().find(
'.') == std::string::npos) {
650 template <
typename GUM_SCALAR >
654 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
_errors_);
658 const auto& elt = c.
get(prnt.
label());
662 O3PRM_CLASS_ILLEGAL_PARENT(prnt, *
_errors_);
669 template <
typename GUM_SCALAR >
677 template <
typename GUM_SCALAR >
681 if (rule.first.size() != attr.
parents().size()) {
682 O3PRM_CLASS_ILLEGAL_RULE_SIZE(rule, rule.first.size(), attr.
parents().size(), *
_errors_);
688 template <
typename GUM_SCALAR >
693 for (std::size_t i = 0; i < attr.
parents().size(); ++i) {
694 auto label = rule.first[i];
699 if (label.label() !=
"*"
700 && std::find(real_labels.begin(), real_labels.end(), label.label())
701 == real_labels.end()) {
702 O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *
_errors_);
709 return errors ==
false;
712 template <
typename GUM_SCALAR >
717 for (
auto& f: rule.second) {
718 f.formula().variables().clear();
719 for (
const auto& values: scope) {
720 f.formula().variables().insert(values.first, values.second->value());
725 template <
typename GUM_SCALAR >
732 GUM_SCALAR sum = 0.0;
733 for (
const auto& f: rule.second) {
735 auto value = GUM_SCALAR(f.formula().result());
737 if (value < 0.0 || 1.0 < value) {
748 if (std::abs(sum - 1.0) > 1e-3) {
749 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(c.
name(), attr.
name(),
float(sum), *
_errors_);
751 }
else if (std::abs(sum - 1.0f) > 1e-6) {
752 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(c.
name(), attr.
name(),
float(sum), *
_errors_);
754 return errors ==
false;
757 template <
typename GUM_SCALAR >
760 const auto& scope = c.
scope();
762 for (
auto& rule: attr.
rules()) {
774 return errors ==
false;
777 template <
typename GUM_SCALAR >
782 auto domainSize = type->domainSize();
783 for (
auto& prnt: attr.
parents()) {
785 domainSize *= c.
get(prnt.label()).type()->domainSize();
794 if (domainSize != attr.
values().size()) {
795 O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.
name(),
804 const auto& scope = c.
scope();
805 for (
auto& f: attr.
values()) {
806 f.formula().variables().clear();
808 for (
const auto& values: scope) {
809 f.formula().variables().insert(values.first, values.second->value());
814 Size parent_size = domainSize / type->domainSize();
815 auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
817 for (std::size_t i = 0; i < attr.
values().size(); ++i) {
819 auto idx = i % parent_size;
820 auto val = (GUM_SCALAR)attr.
values()[i].formula().result();
823 if (val < 0.0 || 1.0 < val) {
833 for (
auto f: values) {
834 if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
835 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(c.
name(), attr.
name(),
float(f), *
_errors_);
837 }
else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
838 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(c.
name(), attr.
name(),
float(f), *
_errors_);
844 template <
typename GUM_SCALAR >
848 auto s = chain.
label();
850 std::vector< std::string > v;
854 for (
size_t i = 0; i < v.size(); ++i) {
859 auto elt = &(current->get(link));
861 if (i == v.size() - 1) {
870 current = &(ref->slotType());
882 template <
typename GUM_SCALAR >
886 const std::string& s) {
888 O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *
_errors_);
894 template <
typename GUM_SCALAR >
901 auto params = std::vector< std::string >();
902 for (
auto& p: agg.parameters()) {
903 params.push_back(p.label());
907 agg.aggregateType().label(),
908 agg.variableType().label(),
917 template <
typename GUM_SCALAR >
928 template <
typename GUM_SCALAR >
932 auto t = (
const PRMType*)
nullptr;
934 for (
const auto& prnt: agg.
parents()) {
937 if (elt ==
nullptr) {
938 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *
_errors_);
947 O3PRM_CLASS_WRONG_PARENT(prnt, *
_errors_);
951 }
else if ((*t) != elt->type()) {
953 O3PRM_CLASS_WRONG_PARENT_TYPE(prnt, t->name(), elt->type().name(), *
_errors_);
961 template <
typename GUM_SCALAR >
969 && !agg_type.isSubTypeOf(super.get(agg.
name().
label()).type())) {
978 template <
typename GUM_SCALAR >
1008 if (!ok) {
return false; }
1026 template <
typename GUM_SCALAR >
1036 template <
typename GUM_SCALAR >
1039 const auto& param = agg.
parameters().front();
1041 for (
Size idx = 0; idx < t.
variable().domainSize(); ++idx) {
1049 O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.
name(), param, *
_errors_);
Headers for the O3ClassFactory class.
virtual std::string label(Idx i) const =0
Get the i-th label. This method is pure virtual.
Exception : a similar element already exists.
This class is used to contain and manipulate gum::ParseError.
Base class for all aGrUM's exceptions.
Exception : fatal (unknown ?) error.
The class for generic Hash Tables.
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
Exception : existence of a directed cycle in a graph.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
static AggregateType str2enum(const std::string &str)
Static method which returns the AggregateType given its string representation.
<agrum/PRM/classElementContainer.h>
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
Abstract class representing an element of PRM class.
static INLINE bool isSlotChain(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMSlotChain.
static INLINE bool isAggregate(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMAggregate.
static INLINE bool isReferenceSlot(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMReferenceSlot.
static INLINE bool isAttribute(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMAttribute.
A PRMClass is an object of a PRM representing a fragment of a Bayesian network which can be instantia...
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
HashTable< std::string, const PRMParameter< GUM_SCALAR > * > scope() const
Returns all the parameters in the scope of this class.
Factory which builds a PRM<GUM_SCALAR>.
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
void endAggregator()
Finishes an aggregate declaration.
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > &params)
Start an aggregator declaration.
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
void continueAggregator(const std::string &name)
Continues an aggregator declaration.
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
const std::string & name() const
Returns the name of this object.
PRMParameter is a member of a Class in a PRM.
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
This is a decoration of the DiscreteVariable class.
DiscreteVariable & variable()
Return a reference on the DiscreteVariable contained in this.
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
The O3Aggregate is part of the AST of the O3PRM language.
O3LabelList & parameters()
O3Label & aggregateType()
The O3Attribute is part of the AST of the O3PRM language.
virtual O3LabelList & parents()
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
bool _checkAndAddNodesToDag_()
bool _checkRemoteParent_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
O3NameSolver< GUM_SCALAR > * _solver_
bool _checkLabelsNumber_(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
bool _checkAggTypeLegality_(O3Class &o3class, O3Aggregate &agg)
void _addReferenceSlots_(O3Class &c)
bool _checkRawCPT_(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
void buildReferenceSlots()
void _declareAggregates_(O3Class &c)
const PRMClassElement< GUM_SCALAR > * _resolveSlotChain_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
void buildImplementations()
HashTable< std::string, gum::NodeId > _nameMap_
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
bool _checkAggParameters_(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
bool _checkLocalParent_(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
void _addParamsToForms_(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
bool _checkAggregateForDeclaration_(O3Class &o3class, O3Aggregate &agg)
bool _checkRuleCPT_(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
std::vector< O3Class * > _o3Classes_
void _addParameters_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkAttributeForDeclaration_(O3Class &o3_c, O3Attribute &attr)
bool _checkParameterValue_(O3Aggregate &agg, const gum::prm::PRMType &t)
const PRMType * _checkAggParents_(O3Class &o3class, O3Aggregate &agg)
void _completeAttribute_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkSlotChainLink_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
void completeAttributes()
bool _checkAndAddArcsToDag_()
bool _checkLabelsValues_(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
void _declareAttribute_(O3Class &c)
ErrorsContainer * _errors_
void _completeAggregates_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkImplementation_(O3Class &c)
void completeAggregates()
bool _checkAttributeForCompletion_(const O3Class &o3_c, O3Attribute &attr)
bool _checkReferenceSlot_(O3Class &c, O3ReferenceSlot &ref)
bool _checkParent_(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
bool _checkParametersNumber_(O3Aggregate &agg, Size n)
PRM< GUM_SCALAR > * _prm_
void _setO3ClassCreationOrder_()
bool _checkAggregateForCompletion_(O3Class &o3class, O3Aggregate &agg)
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
HashTable< std::string, O3Class * > _classMap_
HashTable< NodeId, O3Class * > _nodeMap_
bool _checkRuleCPTSumsTo1_(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3Class is part of the AST of the O3PRM language.
O3ParameterList & parameters()
O3AggregateList & aggregates()
O3LabelList & interfaces()
O3ReferenceSlotList & referenceSlots()
O3AttributeList & attributes()
The O3Label is part of the AST of the O3PRM language.
Resolves names for the different O3PRM factories.
The O3PRM is part of the AST of the O3PRM language.
The O3RawCPT is part of the AST of the O3PRM language.
virtual O3FormulaList & values()
The O3ReferenceSlot is part of the AST of the O3PRM language.
The O3RuleCPT is part of the AST of the O3PRM language.
std::pair< O3LabelList, O3FormulaList > O3Rule
virtual O3RuleList & rules()
#define GUM_ERROR(type, msg)
std::size_t Size
In aGrUM, hashed values are unsigned long int.
HashTable< std::string, O3Aggregate * > AggMap
HashTable< std::string, O3ReferenceSlot * > RefMap
HashTable< std::string, O3Attribute * > AttrMap
namespace for all probabilistic relational models entities
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decompose a string in a vector of strings using "." as separators.
gum is the global namespace for all aGrUM entities