aGrUM 2.3.2
a C++ library for (probabilistic) graphical models
O3ClassFactory_tpl.h
Go to the documentation of this file.
1/****************************************************************************
2 * This file is part of the aGrUM/pyAgrum library. *
3 * *
4 * Copyright (c) 2005-2025 by *
5 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
6 * - Christophe GONZALES(_at_AMU) *
7 * *
8 * The aGrUM/pyAgrum library is free software; you can redistribute it *
9 * and/or modify it under the terms of either : *
10 * *
11 * - the GNU Lesser General Public License as published by *
12 * the Free Software Foundation, either version 3 of the License, *
13 * or (at your option) any later version, *
14 * - the MIT license (MIT), *
15 * - or both in dual license, as here. *
16 * *
17 * (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) *
18 * *
19 * This aGrUM/pyAgrum library is distributed in the hope that it will be *
20 * useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, *
21 * INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
22 * FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE *
23 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
24 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, *
25 * ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR *
26 * OTHER DEALINGS IN THE SOFTWARE. *
27 * *
28 * See LICENCES for more details. *
29 * *
30 * SPDX-FileCopyrightText: Copyright 2005-2025 *
31 * - Pierre-Henri WUILLEMIN(_at_LIP6) *
32 * - Christophe GONZALES(_at_AMU) *
33 * SPDX-License-Identifier: LGPL-3.0-or-later OR MIT *
34 * *
35 * Contact : info_at_agrum_dot_org *
36 * homepage : http://agrum.gitlab.io *
37 * gitlab : https://gitlab.com/agrumery/agrum *
38 * *
39 ****************************************************************************/
40#pragma once
41
42
50
52
53namespace gum {
54 namespace prm {
55 namespace o3prm {
56
// Constructor. The declarator (source line 58) and the `solver` parameter
// (source line 60) are not rendered in this listing; the member index gives
//   O3ClassFactory(PRM< GUM_SCALAR >& prm, O3PRM& o3_prm,
//                  O3NameSolver< GUM_SCALAR >& solver, ErrorsContainer& errors)
// Only the addresses of the four arguments are stored (no copies).
// NOTE(review): the referenced objects presumably must outlive the factory.
57 template < typename GUM_SCALAR >
59 O3PRM& o3_prm,
61 ErrorsContainer& errors) :
62 _prm_(&prm), _o3_prm_(&o3_prm), _solver_(&solver), _errors_(&errors) {
63 GUM_CONSTRUCTOR(O3ClassFactory);
64 }
65
// NOTE(review): only the template header of this definition is rendered;
// source lines 67-72 (declarator and body) are missing from this listing.
// From its position (between the constructor and the move constructor) this is
// presumably the copy constructor. Confirm against the real header.
66 template < typename GUM_SCALAR >
73
// Move constructor (declarator on source line 75 not rendered). Move-constructs
// every member from `src`. Moving the raw pointer members (_prm_, _o3_prm_,
// _solver_, _errors_) copies them, so `src` keeps pointing at the same objects.
74 template < typename GUM_SCALAR >
76 _prm_(std::move(src._prm_)), _o3_prm_(std::move(src._o3_prm_)),
77 _solver_(std::move(src._solver_)), _errors_(std::move(src._errors_)),
78 _nameMap_(std::move(src._nameMap_)), _classMap_(std::move(src._classMap_)),
79 _nodeMap_(std::move(src._nodeMap_)), _dag_(std::move(src._dag_)),
80 _o3Classes_(std::move(src._o3Classes_)) {
81 GUM_CONS_MOV(O3ClassFactory);
82 }
83
// NOTE(review): source lines 85-87 are not rendered; this is presumably the
// destructor. Its body (and whether it releases anything) cannot be seen here.
84 template < typename GUM_SCALAR >
88
// Copy assignment: operator=(const O3ClassFactory< GUM_SCALAR >& src)
// (declarator lines 90-91 not rendered). Self-assignment is a no-op.
// NOTE(review): source lines 98 and 101 are missing from this listing, and
// _classMap_ and _o3Classes_ do not appear among the rendered assignments,
// although the move assignment moves both. Verify they are copied here.
89 template < typename GUM_SCALAR >
92 if (this == &src) { return *this; }
93 _prm_ = src._prm_;
94 _o3_prm_ = src._o3_prm_;
95 _solver_ = src._solver_;
96 _errors_ = src._errors_;
97 _nameMap_ = src._nameMap_;
99 _nodeMap_ = src._nodeMap_;
100 _dag_ = src._dag_;
102 return *this;
103 }
104
// Move assignment (declarator lines 106-107 not rendered). Self-assignment is
// a no-op; otherwise every member is moved from `src`.
105 template < typename GUM_SCALAR >
108 if (this == &src) { return *this; }
109 _prm_ = std::move(src._prm_);
110 _o3_prm_ = std::move(src._o3_prm_);
111 _solver_ = std::move(src._solver_);
112 _errors_ = std::move(src._errors_);
113 _nameMap_ = std::move(src._nameMap_);
114 _classMap_ = std::move(src._classMap_);
115 _nodeMap_ = std::move(src._nodeMap_);
116 _dag_ = std::move(src._dag_);
117 _o3Classes_ = std::move(src._o3Classes_);
118 return *this;
119 }
120
// Declares one (still empty) PRM class per O3 class, if _checkO3Classes_()
// succeeds. For each class:
//  - the interfaces that resolve are collected into `implements`;
//  - the class is started with its super label and delayInheritance = true;
//  - endClass(false) closes it without checking implementations.
// Source lines 122-123 and 127 are not rendered. `factory` is presumably
// declared there, and line 127 presumably fills _o3Classes_ (creation order).
// TODO confirm.
121 template < typename GUM_SCALAR >
124
125 // Class with a super class must be declared after
126 if (_checkO3Classes_()) {
128
129 for (auto c: _o3Classes_) {
130 // Solving interfaces
131 auto implements = Set< std::string >();
132 for (auto& i: c->interfaces()) {
133 if (_solver_->resolveInterface(i)) { implements.insert(i.label()); }
134 }
135
136 // Adding the class
// Classes whose super label does not resolve are skipped here (no error is
// emitted in this block; presumably the solver reports it).
137 if (_solver_->resolveClass(c->superLabel())) {
138 factory.startClass(c->name().label(), c->superLabel().label(), &implements, true);
139 factory.endClass(false);
140 }
141 }
142 }
143 }
144
// Appends the O3 classes to _o3Classes_ in declaration order (declarator on
// source line 146 not rendered). Arcs in _dag_ go from a class to its super
// class (tail = class, head = super). Walking the topological order backwards
// therefore lists every super class before its subclasses.
// NOTE(review): _o3Classes_ is appended to, never cleared, so a second call
// would duplicate entries.
145 template < typename GUM_SCALAR >
147 auto topo_order = _dag_.topologicalOrder();
148
// Starts at rbegin() and *decrements* toward rend(). This presumably relies on
// gum::Sequence iterator semantics (the likely return type of
// topologicalOrder()), not std::reverse_iterator, where ++ would be required.
// Confirm before changing.
149 for (auto id = topo_order.rbegin(); id != topo_order.rend(); --id) {
150 _o3Classes_.push_back(_nodeMap_[*id]);
151 }
152 }
153
// NOTE(review): source lines 155-157 (declarator and body) are not rendered.
// This is presumably the _checkO3Classes_() called by build(), which validates
// class declarations and inheritance before any class is created. Confirm.
154 template < typename GUM_SCALAR >
158
// Adds one DAG node per O3 class of the AST and registers it in three lookup
// tables: name -> node id, name -> O3Class*, node id -> O3Class* (declarator on
// source line 160 not rendered).
// Returns false after reporting O3PRM_CLASS_DUPLICATE as soon as a class name
// is already registered; returns true otherwise.
// NOTE(review): the node for a duplicate is added to _dag_ before the insert
// throws and is left in place; callers are expected to stop on false.
159 template < typename GUM_SCALAR >
161 for (auto& c: _o3_prm_->classes()) {
162 auto id = _dag_.addNode();
163
164 try {
165 _nameMap_.insert(c->name().label(), id);
166 _classMap_.insert(c->name().label(), c.get());
167 _nodeMap_.insert(id, c.get());
168
169 } catch (DuplicateElement const&) {
170 O3PRM_CLASS_DUPLICATE(c->name(), *_errors_);
171 return false;
172 }
173 }
174
175 return true;
176 }
177
// Adds one inheritance arc (class -> super class) per O3 class that declares a
// super class (declarator on source line 179 not rendered).
// Returns false when a super label cannot be resolved (no error is emitted
// here) or when the arc would create a cycle (reports
// O3PRM_CLASS_CYLIC_INHERITANCE). Returns true otherwise.
178 template < typename GUM_SCALAR >
180 for (auto& c: _o3_prm_->classes()) {
181 if (c->superLabel().label() != "") {
182 if (!_solver_->resolveClass(c->superLabel())) { return false; }
183
// NOTE(review): _nameMap_ only holds classes of this O3PRM. If the solver can
// resolve a super class declared elsewhere, this lookup is unguarded. Confirm.
184 auto head = _nameMap_[c->superLabel().label()];
185 auto tail = _nameMap_[c->name().label()];
186
187 try {
188 _dag_.addArc(tail, head);
189 } catch (InvalidDirectedCycle const&) {
190 // Cyclic inheritance
191 O3PRM_CLASS_CYLIC_INHERITANCE(c->name(), c->superLabel(), *_errors_);
192 return false;
193 }
194 }
195 }
196
197 return true;
198 }
199
// For every O3 class whose interface implementations are valid, initializes
// inheritance on the corresponding PRM class (declarator on source line 201
// not rendered). Classes failing the check are skipped; the check itself
// reports the errors.
200 template < typename GUM_SCALAR >
202 for (auto& c: _o3_prm_->classes()) {
203 if (_checkImplementation_(*c)) {
204 _prm_->getClass(c->name().label()).initializeInheritance();
205 }
206 }
207 }
208
212
// Checks that O3 class `c` correctly implements every interface it declares
// (declarator on source line 214 not rendered). It builds name -> element maps
// for the class's attributes, aggregates and reference slots, then validates
// each resolvable interface against them. Interfaces that do not resolve are
// skipped. Returns false at the first interface that is not implemented
// correctly.
213 template < typename GUM_SCALAR >
215 // Saving attributes names for fast lookup
216 auto attr_map = AttrMap();
217 for (auto& a: c.attributes()) {
218 attr_map.insert(a->name().label(), a.get());
219 }
220
221 // Saving aggregates names for fast lookup
222 auto agg_map = AggMap();
223 for (auto& agg: c.aggregates()) {
224 agg_map.insert(agg.name().label(), &agg);
225 }
226
// Saving reference slot names for fast lookup
227 auto ref_map = RefMap();
228 for (auto& ref: c.referenceSlots()) {
229 ref_map.insert(ref.name().label(), &ref);
230 }
231
232 // Checking interface implementation
233 for (auto& i: c.interfaces()) {
234 if (_solver_->resolveInterface(i)) {
235 if (!_checkImplementation_(c, i, attr_map, agg_map, ref_map)) { return false; }
236 }
237 }
238
239 return true;
240 }
241
// Checks that class `c` implements interface `i` (declarator on source line 243
// not rendered):
//  - every interface attribute must exist as an attribute or aggregate of `c`,
//    with a type that is a subtype of the interface's type;
//  - every interface reference slot that `c` declares must have a compatible
//    slot type.
// Reports ATTR/AGG/REF_IMPLEMENTATION or MISSING_ATTRIBUTES and returns false
// on the first failure.
242 template < typename GUM_SCALAR >
244 O3Label& i,
245 AttrMap& attr_map,
246 AggMap& agg_map,
247 RefMap& ref_map) {
248 const auto& real_i = _prm_->getInterface(i.label());
249
250 auto counter = (Size)0;
251 for (const auto& a: real_i.attributes()) {
// NOTE(review): a name present in both attr_map and agg_map is counted twice.
252 if (attr_map.exists(a->name())) {
253 ++counter;
254
255 if (!_checkImplementation_(attr_map[a->name()]->type(), a->type())) {
256 O3PRM_CLASS_ATTR_IMPLEMENTATION(c.name(), i, attr_map[a->name()]->name(), *_errors_);
257 return false;
258 }
259 }
260
261 if (agg_map.exists(a->name())) {
262 ++counter;
263
264 if (!_checkImplementation_(agg_map[a->name()]->variableType(), a->type())) {
265 O3PRM_CLASS_AGG_IMPLEMENTATION(c.name(), i, agg_map[a->name()]->name(), *_errors_);
266 return false;
267 }
268 }
269 }
270
271 if (counter != real_i.attributes().size()) {
272 O3PRM_CLASS_MISSING_ATTRIBUTES(c.name(), i, *_errors_);
273 return false;
274 }
275
276 counter = 0;
277 for (const auto& r: real_i.referenceSlots()) {
278 if (ref_map.exists(r->name())) {
279 ++counter;
280
281 if (!_checkImplementation_(ref_map[r->name()]->type(), r->slotType())) {
282 O3PRM_CLASS_REF_IMPLEMENTATION(c.name(), i, ref_map[r->name()]->name(), *_errors_);
283 return false;
284 }
285 }
286 }
// NOTE(review): unlike the attribute check above, `counter` is never compared
// with real_i.referenceSlots().size(). A class that omits some of the
// interface's reference slots is therefore accepted silently. Consider a
// matching missing-references error (e.g. an O3PRM_CLASS_MISSING_REFERENCES
// macro, if one exists next to O3PRM_CLASS_MISSING_ATTRIBUTES). TODO confirm
// the intended behavior.
287 return true;
288 }
289
// Type compatibility used for attributes and aggregates (declarator and the
// `o3_type` parameter on source line 291 are not rendered). Returns true iff
// `o3_type` resolves to a PRM type that is a subtype of `type`.
290 template < typename GUM_SCALAR >
292 const PRMType& type) {
293 if (!_solver_->resolveType(o3_type)) { return false; }
294
295 return _prm_->type(o3_type.label()).isSubTypeOf(type);
296 }
297
// Slot-type compatibility used for reference slots (source lines 299 and 301,
// the declarator and the `type` parameter, are not rendered). Returns true iff
// `o3_type` resolves to an interface or class that is a subtype of `type`.
298 template < typename GUM_SCALAR >
300 O3Label& o3_type,
302 if (!_solver_->resolveSlotType(o3_type)) { return false; }
303
304 if (_prm_->isInterface(o3_type.label())) {
305 return _prm_->getInterface(o3_type.label()).isSubTypeOf(type);
306 } else {
307 return _prm_->getClass(o3_type.label()).isSubTypeOf(type);
308 }
309 }
310
// Adds parameters to every class, in creation order (source lines 312-313 not
// rendered; `factory` is presumably declared there). Each class first inherits
// its super class's parameters, then its own parameters are added.
311 template < typename GUM_SCALAR >
314 // A class with a super class must be processed after its super class
315 for (auto c: _o3Classes_) {
316 _prm_->getClass(c->name().label()).inheritParameters();
317
318 factory.continueClass(c->name().label());
319
320 _addParameters_(factory, *c);
321
322 factory.endClass(false);
323 }
324 }
325
// Declares each O3 parameter of `c` in the factory as an "int" or "real"
// parameter with its default value. Throws FatalError for any other parameter
// type. The declarator (source line 327) and the two case labels (lines 331
// and 336) are not rendered here.
326 template < typename GUM_SCALAR >
328 O3Class& c) {
329 for (auto& p: c.parameters()) {
330 switch (p.type()) {
332 factory.addParameter("int", p.name().label(), p.value().value());
333 break;
334 }
335
337 factory.addParameter("real", p.name().label(), p.value().value());
338 break;
339 }
340
341 default : {
342 GUM_ERROR(FatalError, "unknown O3Parameter type")
343 }
344 }
345 }
346 }
347
// Makes every class inherit its super class's reference slots, in creation
// order. The declarator (source line 349) and line 353 are not rendered; the
// latter presumably declares the class's own reference slots. TODO confirm.
348 template < typename GUM_SCALAR >
350 // A class with a super class must be processed after its super class
351 for (auto c: _o3Classes_) {
352 _prm_->getClass(c->name().label()).inheritReferenceSlots();
354 }
355 }
356
// Adds to class `c` every reference slot that passes _checkReferenceSlot_
// (declarator lines 358-359 not rendered). Invalid slots are skipped; the
// check reports them.
357 template < typename GUM_SCALAR >
360
361 factory.continueClass(c.name().label());
362
363 // References
364 for (auto& ref: c.referenceSlots()) {
365 if (_checkReferenceSlot_(c, ref)) {
366 factory.addReferenceSlot(ref.type().label(), ref.name().label(), ref.isArray());
367 }
368 }
369
370 factory.endClass(false);
371 }
372
// Validates reference slot `ref` of class `c` (declarator on source line 374
// not rendered). Fails, reporting an error, when:
//  - the slot type does not resolve;
//  - the name is already used, unless it legally overloads an inherited
//    reference with a strict subtype;
//  - the slot references the class itself or one of its subclasses.
373 template < typename GUM_SCALAR >
375 O3ReferenceSlot& ref) {
376 if (!_solver_->resolveSlotType(ref.type())) { return false; }
377
378 const auto& real_c = _prm_->getClass(c.name().label());
379
380 // Check for duplicates
381 if (real_c.exists(ref.name().label())) {
382 const auto& elt = real_c.get(ref.name().label());
383
// NOTE(review): source line 384 is not rendered. It opens the branch closed by
// `} else {` at line 405 and presumably tests that `elt` is a reference slot,
// which is what makes the static_cast below safe.
385 auto slot_type = (PRMClassElementContainer< GUM_SCALAR >*)nullptr;
386
387 if (_prm_->isInterface(ref.type().label())) {
388 slot_type = &(_prm_->getInterface(ref.type().label()));
389
390 } else {
391 slot_type = &(_prm_->getClass(ref.type().label()));
392 }
393
394 auto real_ref = static_cast< const PRMReferenceSlot< GUM_SCALAR >* >(&elt);
395
// Same slot type: a plain redeclaration. Otherwise the new type must be a
// subtype of the inherited one to be a legal overload.
396 if (slot_type->name() == real_ref->slotType().name()) {
397 O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *_errors_);
398 return false;
399
400 } else if (!slot_type->isSubTypeOf(real_ref->slotType())) {
401 O3PRM_CLASS_ILLEGAL_OVERLOAD(ref.name(), c.name(), *_errors_);
402 return false;
403 }
404
405 } else {
406 O3PRM_CLASS_DUPLICATE_REFERENCE(ref.name(), *_errors_);
407 return false;
408 }
409 }
410
411 // If class we need to check for illegal references
412 if (_prm_->isClass(ref.type().label())) {
413 const auto& ref_type = _prm_->getClass(ref.type().label());
414
415 // No recursive reference
416 if ((&ref_type) == (&real_c)) {
417 O3PRM_CLASS_SELF_REFERENCE(c.name(), ref.name(), *_errors_);
418 return false;
419 }
420
421 // No reference to subclasses
422 if (ref_type.isSubTypeOf(real_c)) {
423 O3PRM_CLASS_ILLEGAL_SUB_REFERENCE(c.name(), ref.type(), *_errors_);
424 return false;
425 }
426 }
427
428 return true;
429 }
430
// Makes every class inherit its super class's attributes, in creation order.
// The declarator (source line 432) and line 436 are not rendered; the latter
// presumably declares the class's own attributes. TODO confirm.
431 template < typename GUM_SCALAR >
433 // A class with a super class must be processed after its super class
434 for (auto c: _o3Classes_) {
435 _prm_->getClass(c->name().label()).inheritAttributes();
437 }
438 }
439
// Makes every class inherit its super class's aggregates, in creation order.
// The declarator (source line 441) and line 445 are not rendered; the latter
// presumably declares the class's own aggregates. TODO confirm.
440 template < typename GUM_SCALAR >
442 // A class with a super class must be processed after its super class
443 for (auto c: _o3Classes_) {
444 _prm_->getClass(c->name().label()).inheritAggregates();
446 }
447 }
448
// Declares (type and name only, no parents or CPT yet) every attribute of `c`
// that passes _checkAttributeForDeclaration_. Declarator lines 450-451 are not
// rendered.
449 template < typename GUM_SCALAR >
452 factory.continueClass(c.name().label());
453
454 for (auto& attr: c.attributes()) {
455 if (_checkAttributeForDeclaration_(c, *attr)) {
456 factory.startAttribute(attr->type().label(), attr->name().label());
457 factory.endAttribute();
458 }
459 }
460
461 factory.endClass(false);
462 }
463
// Checks that attribute `attr` of `c` can be declared (declarator on source
// line 465 not rendered):
//  - its type must resolve;
//  - if it overloads an element of the super class, its type must be a
//    subtype of that element's type (otherwise reports ILLEGAL_OVERLOAD).
464 template < typename GUM_SCALAR >
466 O3Attribute& attr) {
467 // Check type
468 if (!_solver_->resolveType(attr.type())) { return false; }
469
470 // Checking type legality if overload
471 if (c.superLabel().label() != "") {
472 const auto& super = _prm_->getClass(c.superLabel().label());
473
474 if (!super.exists(attr.name().label())) { return true; }
475
// NOTE(review): elsewhere in this file (source lines 943-949) type() on a class
// element is guarded against OperationNotAllowed. Here it is not, so an untyped
// super element with this name would propagate that exception. Confirm.
476 const auto& super_type = super.get(attr.name().label()).type();
477 const auto& type = _prm_->type(attr.type().label());
478
479 if (!type.isSubTypeOf(super_type)) {
480 O3PRM_CLASS_ILLEGAL_OVERLOAD(attr.name(), c.superLabel(), *_errors_);
481 return false;
482 }
483 }
484 return true;
485 }
486
// Completes attributes (parents and CPTs) of every class, in creation order
// (source lines 488-489 not rendered; `factory` is presumably declared there).
// For a subclass, every attribute or aggregate of the super class that the
// subclass does not redefine is passed to completeInheritance() by safe name.
// The class is closed with endClass(true), i.e. with implementation checks.
487 template < typename GUM_SCALAR >
490
491 // Class with a super class must be declared in order
492 for (auto c: _o3Classes_) {
493 _prm_->getClass(c->name().label()).inheritSlotChains();
494 factory.continueClass(c->name().label());
495
496 _completeAttribute_(factory, *c);
497
498 if (c->superLabel().label() != "") {
499 auto& super = _prm_->getClass(c->superLabel().label());
500 auto to_complete = Set< std::string >();
501
// Start from every inherited attribute and aggregate...
502 for (auto a: super.attributes()) {
503 to_complete.insert(a->safeName());
504 }
505
506 for (auto a: super.aggregates()) {
507 to_complete.insert(a->safeName());
508 }
509
// ...and drop those redefined locally.
// NOTE(review): get() assumes each local attribute/aggregate was actually
// declared; one rejected by the declaration checks may be missing here. Confirm.
510 for (auto& a: c->attributes()) {
511 to_complete.erase(
512 _prm_->getClass(c->name().label()).get(a->name().label()).safeName());
513 }
514
515 for (auto& a: c->aggregates()) {
516 to_complete.erase(
517 _prm_->getClass(c->name().label()).get(a.name().label()).safeName());
518 }
519
520 for (auto a: to_complete) {
521 _prm_->getClass(c->name().label()).completeInheritance(a);
522 }
523 }
524
525 factory.endClass(true);
526 }
527 }
528
// Completes the aggregates (parents) of every class, in creation order (source
// lines 530-531 not rendered; `factory` is presumably declared there).
529 template < typename GUM_SCALAR >
532
533 // Class with a super class must be declared in order
534 for (auto c: _o3Classes_) {
535 factory.continueClass(c->name().label());
536
537 _completeAggregates_(factory, *c);
538
539 factory.endClass(false);
540 }
541 }
542
// Adds the parents of every aggregate of `c` that passes
// _checkAggregateForCompletion_ (declarator on source line 545 not rendered).
543 template < typename GUM_SCALAR >
544 INLINE void
546 O3Class& c) {
547 // Aggregates
548 for (auto& agg: c.aggregates()) {
549 if (_checkAggregateForCompletion_(c, agg)) {
550 factory.continueAggregator(agg.name().label());
551
552 for (const auto& parent: agg.parents()) {
553 factory.addParent(parent.label());
554 }
555
556 factory.endAggregator();
557 }
558 }
559 }
560
// An aggregate can be completed when all its parents resolve to one common
// type and its parameters match that aggregate kind (declarator on source line
// 562 not rendered).
561 template < typename GUM_SCALAR >
563 O3Aggregate& agg) {
564 // Checking parents
565 auto t = _checkAggParents_(c, agg);
566 if (t == nullptr) { return false; }
567
568 // Checking parameters numbers
569 if (!_checkAggParameters_(c, agg, t)) { return false; }
570
571 return true;
572 }
573
// For every attribute of `c` that passes _checkAttributeForCompletion_, adds
// its parents and then its CPT (declarator on source line 576 not rendered):
//  - a raw CPT is passed as its list of formula strings;
//  - a rule CPT is passed rule by rule as (labels, formula strings).
574 template < typename GUM_SCALAR >
575 INLINE void
577 O3Class& c) {
578 // Attributes
579 for (auto& attr: c.attributes()) {
580 if (_checkAttributeForCompletion_(c, *attr)) {
581 factory.continueAttribute(attr->name().label());
582
583 for (const auto& parent: attr->parents()) {
584 factory.addParent(parent.label());
585 }
586
587 auto raw = dynamic_cast< const O3RawCPT* >(attr.get());
588
589 if (raw) {
590 auto values = std::vector< std::string >();
591 for (const auto& val: raw->values()) {
592 values.push_back(val.formula().formula());
593 }
594 factory.setRawCPFByColumns(values);
595 }
596
597 auto rule_cpt = dynamic_cast< const O3RuleCPT* >(attr.get());
598 if (rule_cpt) {
599 for (const auto& rule: rule_cpt->rules()) {
600 auto labels = std::vector< std::string >();
601 auto values = std::vector< std::string >();
602
603 for (const auto& lbl: rule.first) {
604 labels.push_back(lbl.label());
605 }
606
607 for (const auto& form: rule.second) {
608 values.push_back(form.formula().formula());
609 }
610
611 factory.setCPFByRule(labels, values);
612 }
613 }
614
615 factory.endAttribute();
616 }
617 }
618 }
619
// An attribute can be completed when all its parents are valid and its CPT
// (raw or rule-based) is well formed (declarator on source line 621 not
// rendered). Attributes with neither CPT form pass once the parents check out.
620 template < typename GUM_SCALAR >
622 O3Attribute& attr) {
623 // Check for parents existence
624 const auto& c = _prm_->getClass(o3_c.name().label());
625 for (auto& prnt: attr.parents()) {
626 if (!_checkParent_(c, prnt)) { return false; }
627 }
628
629 // Check that CPT sums to 1
630 auto raw = dynamic_cast< O3RawCPT* >(&attr);
631 if (raw) { return _checkRawCPT_(c, *raw); }
632
633 auto rule = dynamic_cast< O3RuleCPT* >(&attr);
634 if (rule) { return _checkRuleCPT_(c, *rule); }
635
636 return true;
637 }
638
// Dispatches the parent check: a label without '.' names a local element of
// `c`; a dotted label is a slot chain (declarator on source line 640 not
// rendered).
639 template < typename GUM_SCALAR >
641 const O3Label& prnt) {
642 if (prnt.label().find('.') == std::string::npos) {
643 return _checkLocalParent_(c, prnt);
644
645 } else {
646 return _checkRemoteParent_(c, prnt);
647 }
648 }
649
// Checks a non-dotted parent label (declarator on source line 651 not
// rendered). Reports PARENT_NOT_FOUND if `c` has no such element, or
// ILLEGAL_PARENT if the element is not an acceptable parent.
// NOTE(review): the condition on `elt` (source lines 659-661) is not rendered.
650 template < typename GUM_SCALAR >
652 const O3Label& prnt) {
653 if (!c.exists(prnt.label())) {
654 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *_errors_);
655 return false;
656 }
657
658 const auto& elt = c.get(prnt.label());
662 O3PRM_CLASS_ILLEGAL_PARENT(prnt, *_errors_);
663 return false;
664 }
665
666 return true;
667 }
668
// A dotted parent is valid iff its slot chain resolves to an element; errors
// are reported during resolution (declarator lines 670-671 not rendered).
669 template < typename GUM_SCALAR >
672 const O3Label& prnt) {
673 if (_resolveSlotChain_(c, prnt) == nullptr) { return false; }
674 return true;
675 }
676
// A rule must give exactly one label per parent of the attribute; otherwise
// reports ILLEGAL_RULE_SIZE (declarator on source line 678 not rendered).
677 template < typename GUM_SCALAR >
679 const O3RuleCPT::O3Rule& rule) {
680 // Check that the number of labels is correct
681 if (rule.first.size() != attr.parents().size()) {
682 O3PRM_CLASS_ILLEGAL_RULE_SIZE(rule, rule.first.size(), attr.parents().size(), *_errors_);
683 return false;
684 }
685 return true;
686 }
687
// Checks that every label of a rule is a valid value of the matching parent's
// type: label k of rule.first is checked against the labels of parent k of
// attr.parents(), and the wildcard "*" matches any value. Each invalid label
// is reported with O3PRM_CLASS_ILLEGAL_RULE_LABEL. All labels are examined (no
// early return); the result is false if any was invalid.
// Source line 689 (the declarator and the first parameter `c`) is not rendered
// in this listing.
688 template < typename GUM_SCALAR >
690 const O3RuleCPT::O3RuleCPT& attr,
691 const O3RuleCPT::O3Rule& rule) {
692 bool errors = false;
// Only visit positions present in BOTH lists. _checkRuleCPT_ calls this check
// even when _checkLabelsNumber_ has just reported a size mismatch, and a rule
// with fewer labels than parents would otherwise index rule.first out of range.
693 for (std::size_t i = 0; i < attr.parents().size() && i < rule.first.size(); ++i) {
694 auto label = rule.first[i];
695 auto prnt = attr.parents()[i];
696 try {
// NOTE(review): the result of _resolveSlotChain_ is dereferenced without a null
// check. This relies on the parents having been validated first, as
// _checkAttributeForCompletion_ does. Confirm for any other caller.
697 auto real_labels = _resolveSlotChain_(c, prnt)->type()->labels();
698 // c.get(prnt.label()).type()->labels();
699 if (label.label() != "*"
700 && std::find(real_labels.begin(), real_labels.end(), label.label())
701 == real_labels.end()) {
702 O3PRM_CLASS_ILLEGAL_RULE_LABEL(rule, label, prnt, *_errors_);
703 errors = true;
704 }
705 } catch (Exception const&) {
706 // parent does not exist and has already been reported
707 }
708 }
709 return errors == false;
710 }
711
// Binds the class's parameters into every formula of a rule: each formula's
// variable table is cleared, then refilled with (parameter name -> parameter
// value) for every entry of `scope` (declarator on source line 713 not
// rendered).
712 template < typename GUM_SCALAR >
714 const HashTable< std::string, const PRMParameter< GUM_SCALAR >* >& scope,
715 O3RuleCPT::O3Rule& rule) {
716 // Add parameters to formulas
717 for (auto& f: rule.second) {
718 f.formula().variables().clear();
719 for (const auto& values: scope) {
720 f.formula().variables().insert(values.first, values.second->value());
721 }
722 }
723 }
724
// Evaluates the formulas of one rule and checks them (declarator on source line
// 727 not rendered):
//  - each value must lie in [0, 1] and evaluate without OperationNotAllowed
//    (otherwise ILLEGAL_CPT_VALUE);
//  - the values must sum to 1: a deviation > 1e-3 is an error, and one
//    > 1e-6 only emits a warning.
// NOTE(review): the two tolerance tests mix `1.0` and `1.0f`. Both equal 1
// exactly, so this is only a style inconsistency.
725 template < typename GUM_SCALAR >
726 INLINE bool
728 const O3RuleCPT& attr,
729 const O3RuleCPT::O3Rule& rule) {
730 bool errors = false;
731 // Check that formulas are valid and sum to 1
732 GUM_SCALAR sum = 0.0;
733 for (const auto& f: rule.second) {
734 try {
735 auto value = GUM_SCALAR(f.formula().result());
736 sum += value;
737 if (value < 0.0 || 1.0 < value) {
738 O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *_errors_);
739 errors = true;
740 }
741 } catch (OperationNotAllowed const&) {
742 O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), f, *_errors_);
743 errors = true;
744 }
745 }
746
747 // Check that CPT sums to 1
748 if (std::abs(sum - 1.0) > 1e-3) {
749 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(c.name(), attr.name(), float(sum), *_errors_);
750 errors = true;
751 } else if (std::abs(sum - 1.0f) > 1e-6) {
752 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(c.name(), attr.name(), float(sum), *_errors_);
753 }
754 return errors == false;
755 }
756
// Validates every rule of a rule-based CPT (declarator on source line 758 not
// rendered):
//  - label count;
//  - label values;
//  - parameter binding into the formulas;
//  - per-rule sum to 1.
// All rules are checked even after a failure. Any aGrUM Exception raised
// meanwhile is printed and counted as an error.
757 template < typename GUM_SCALAR >
759 O3RuleCPT& attr) {
760 const auto& scope = c.scope();
761 bool errors = false;
762 for (auto& rule: attr.rules()) {
763 try {
// The label-values check below still runs when this count check fails, so the
// values check must tolerate rules whose label list is shorter than the
// parent list.
764 if (!_checkLabelsNumber_(attr, rule)) { errors = true; }
765 if (!_checkLabelsValues_(c, attr, rule)) { errors = true; }
766 _addParamsToForms_(scope, rule);
767 if (!_checkRuleCPTSumsTo1_(c, attr, rule)) { errors = true; }
768 } catch (Exception& e) {
769 GUM_SHOWERROR(e);
770 errors = true;
771 }
772 }
773
774 return errors == false;
775 }
776
// Validates a raw CPT (declarator on source line 778 not rendered):
//  1. its length must equal the attribute's domain size times the product of
//     the parents' domain sizes (otherwise ILLEGAL_CPT_SIZE);
//  2. the class parameters are bound into every formula;
//  3. every value must evaluate and lie in [0, 1] (otherwise ILLEGAL_CPT_VALUE);
//  4. entry i contributes to column (i % parent_size), and every column must
//     sum to 1: > 1e-3 off is an error, > 1e-6 only warns.
// Returns false on the first failure.
777 template < typename GUM_SCALAR >
779 O3RawCPT& attr) {
780 const auto& type = _prm_->type(attr.type().label());
781
782 auto domainSize = type->domainSize();
783 for (auto& prnt: attr.parents()) {
784 try {
785 domainSize *= c.get(prnt.label()).type()->domainSize();
786 } catch (NotFound const&) {
787 // If we are here, all parents have been checked so _resolveSlotChain_
788 // will not raise an error and not return a nullptr
789 domainSize *= _resolveSlotChain_(c, prnt)->type()->domainSize();
790 }
791 }
792
793 // Check for CPT size
794 if (domainSize != attr.values().size()) {
795 O3PRM_CLASS_ILLEGAL_CPT_SIZE(c.name(),
796 attr.name(),
797 Size(attr.values().size()),
798 domainSize,
799 *_errors_);
800 return false;
801 }
802
803 // Add parameters to formulas
804 const auto& scope = c.scope();
805 for (auto& f: attr.values()) {
806 f.formula().variables().clear();
807
808 for (const auto& values: scope) {
809 f.formula().variables().insert(values.first, values.second->value());
810 }
811 }
812
813 // Check that CPT sums to 1
// parent_size = number of parent configurations = number of columns.
814 Size parent_size = domainSize / type->domainSize();
815 auto values = std::vector< GUM_SCALAR >(parent_size, 0.0f);
816
817 for (std::size_t i = 0; i < attr.values().size(); ++i) {
818 try {
819 auto idx = i % parent_size;
820 auto val = (GUM_SCALAR)attr.values()[i].formula().result();
821 values[idx] += val;
822
823 if (val < 0.0 || 1.0 < val) {
824 O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), attr.values()[i], *_errors_);
825 return false;
826 }
827 } catch (Exception const&) {
828 O3PRM_CLASS_ILLEGAL_CPT_VALUE(c.name(), attr.name(), attr.values()[i], *_errors_);
829 return false;
830 }
831 }
832
833 for (auto f: values) {
834 if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-3) {
835 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1(c.name(), attr.name(), float(f), *_errors_);
836 return false;
837 } else if (std::abs(f - GUM_SCALAR(1.0)) > 1.0e-6) {
838 O3PRM_CLASS_CPT_DOES_NOT_SUM_TO_1_WARNING(c.name(), attr.name(), float(f), *_errors_);
839 }
840 }
841 return true;
842 }
843
// Follows a dotted chain ("ref1.ref2.attr") from container `c` and returns the
// element named by the last link (declarator lines 845-846 not rendered).
// Each link must exist in the current container; a missing link is reported
// (LINK_NOT_FOUND) and yields nullptr. Every link but the last must be a
// reference slot; if not, nullptr is returned without reporting an error.
// The kind of the last element is not checked.
// NOTE(review): the local `s` is unused.
844 template < typename GUM_SCALAR >
847 const O3Label& chain) {
848 auto s = chain.label();
849 auto current = &c;
850 std::vector< std::string > v;
851
852 decomposePath(chain.label(), v);
853
854 for (size_t i = 0; i < v.size(); ++i) {
855 auto link = v[i];
856
857 if (!_checkSlotChainLink_(*current, chain, link)) { return nullptr; }
858
859 auto elt = &(current->get(link));
860
861 if (i == v.size() - 1) {
862 // last link, should be an attribute or aggregate
863 return elt;
864
865 } else {
866 // should be a reference slot
867
868 auto ref = dynamic_cast< const PRMReferenceSlot< GUM_SCALAR >* >(elt);
869 if (ref) {
870 current = &(ref->slotType());
871 } else {
872 return nullptr; // failsafe to prevent infinite loop
873 }
874 }
875 }
876
877 // Only reached when the chain has no links: every non-empty chain returns inside the loop
878
879 return nullptr;
880 }
881
// True iff container `c` has an element named `s`; otherwise reports
// LINK_NOT_FOUND for `chain` (declarator lines 883-884 not rendered).
882 template < typename GUM_SCALAR >
885 const O3Label& chain,
886 const std::string& s) {
887 if (!c.exists(s)) {
888 O3PRM_CLASS_LINK_NOT_FOUND(chain, s, *_errors_);
889 return false;
890 }
891 return true;
892 }
893
// Declares every aggregate of `c` that passes _checkAggregateForDeclaration_,
// with its name, aggregate kind, variable type and parameter labels. Parents
// are added later, during completion. Declarator lines 895-896 are not
// rendered.
894 template < typename GUM_SCALAR >
897 factory.continueClass(c.name().label());
898
899 for (auto& agg: c.aggregates()) {
900 if (_checkAggregateForDeclaration_(c, agg)) {
901 auto params = std::vector< std::string >();
902 for (auto& p: agg.parameters()) {
903 params.push_back(p.label());
904 }
905
906 factory.startAggregator(agg.name().label(),
907 agg.aggregateType().label(),
908 agg.variableType().label(),
909 params);
910 factory.endAggregator();
911 }
912 }
913
914 factory.endClass(false);
915 }
916
// An aggregate can be declared when its variable type resolves and any
// overload of a super-class element is type-legal (declarator on source line
// 918 not rendered).
917 template < typename GUM_SCALAR >
919 O3Aggregate& agg) {
920 if (!_solver_->resolveType(agg.variableType())) { return false; }
921
922 // Checking type legality if overload
923 if (!_checkAggTypeLegality_(o3class, agg)) { return false; }
924
925 return true;
926 }
927
// Resolves every parent of an aggregate and returns their common PRM type, or
// nullptr on error (declarator on source line 929 not rendered).
// Reports PARENT_NOT_FOUND, WRONG_PARENT (first parent untyped) or
// WRONG_PARENT_TYPE (types differ).
// NOTE(review): with no parents the loop is skipped and nullptr is returned
// without any error being reported.
// NOTE(review): a parent whose chain fails inside _resolveSlotChain_ may be
// reported twice (LINK_NOT_FOUND there, PARENT_NOT_FOUND here).
928 template < typename GUM_SCALAR >
930 O3Aggregate& agg) {
931 const auto& c = _prm_->getClass(o3class.name().label());
932 auto t = (const PRMType*)nullptr;
933
934 for (const auto& prnt: agg.parents()) {
935 auto elt = _resolveSlotChain_(c, prnt);
936
937 if (elt == nullptr) {
938 O3PRM_CLASS_PARENT_NOT_FOUND(prnt, *_errors_);
939 return nullptr;
940
941 } else {
942 if (t == nullptr) {
943 try {
944 t = &(elt->type());
945
946 } catch (OperationNotAllowed const&) {
947 O3PRM_CLASS_WRONG_PARENT(prnt, *_errors_);
948 return nullptr;
949 }
950
// NOTE(review): for the second and later parents, elt->type() is outside the
// try above, so an untyped element would propagate OperationNotAllowed.
951 } else if ((*t) != elt->type()) {
952 // Wrong type in chain
953 O3PRM_CLASS_WRONG_PARENT_TYPE(prnt, t->name(), elt->type().name(), *_errors_);
954 return nullptr;
955 }
956 }
957 }
958 return t;
959 }
960
// When the aggregate overloads an element of the super class, its variable type
// must be a subtype of that element's type; otherwise reports ILLEGAL_OVERLOAD
// (declarator on source line 962 not rendered).
961 template < typename GUM_SCALAR >
963 O3Aggregate& agg) {
964 if (_prm_->isClass(o3class.superLabel().label())) {
965 const auto& super = _prm_->getClass(o3class.superLabel().label());
966 const auto& agg_type = _prm_->type(agg.variableType().label());
967
968 if (super.exists(agg.name().label())
969 && !agg_type.isSubTypeOf(super.get(agg.name().label()).type())) {
970 O3PRM_CLASS_ILLEGAL_OVERLOAD(agg.name(), o3class.superLabel(), *_errors_);
971 return false;
972 }
973 }
974
975 return true;
976 }
977
// Checks an aggregate's parameters against its kind (declarator on source line
// 979 not rendered):
//  1. the number of parameters: 0 for some kinds, 1 for others; an unknown
//     kind throws FatalError;
//  2. for some kinds, that the single parameter is a label of the parents'
//     common type `t`.
// NOTE(review): the switch heads and case labels (source lines 984-991,
// 996-998 and 1011-1014) are not rendered here.
978 template < typename GUM_SCALAR >
980 O3Aggregate& agg,
981 const PRMType* t) {
982 bool ok = false;
983
992 ok = _checkParametersNumber_(agg, 0);
993 break;
994 }
995
999 ok = _checkParametersNumber_(agg, 1);
1000 break;
1001 }
1002
1003 default : {
1004 GUM_ERROR(FatalError, "unknown aggregate type")
1005 }
1006 }
1007
1008 if (!ok) { return false; }
1009
1010 // Checking parameters type
1015 ok = _checkParameterValue_(agg, *t);
1016 break;
1017 }
1018
1019 default : { /* Nothing to do */
1020 }
1021 }
1022
1023 return ok;
1024 }
1025
// True iff the aggregate has exactly `n` parameters; otherwise reports
// AGG_PARAMETERS with the expected and actual counts (declarator on source line
// 1027 not rendered).
1026 template < typename GUM_SCALAR >
1028 if (agg.parameters().size() != n) {
1029 O3PRM_CLASS_AGG_PARAMETERS(agg.name(), Size(n), Size(agg.parameters().size()), *_errors_);
1030 return false;
1031 }
1032
1033 return true;
1034 }
1035
// True iff the aggregate's first parameter is one of the labels of type `t`'s
// variable; otherwise reports AGG_PARAMETER_NOT_FOUND (declarator on source
// line 1037 not rendered).
// NOTE(review): front() requires a non-empty parameter list. This presumably
// holds because callers first check the count with _checkParametersNumber_;
// the guarding case labels are not rendered here.
1036 template < typename GUM_SCALAR >
1038 const gum::prm::PRMType& t) {
1039 const auto& param = agg.parameters().front();
1040 bool found = false;
1041 for (Size idx = 0; idx < t.variable().domainSize(); ++idx) {
1042 if (t.variable().label(idx) == param.label()) {
1043 found = true;
1044 break;
1045 }
1046 }
1047
1048 if (!found) {
1049 O3PRM_CLASS_AGG_PARAMETER_NOT_FOUND(agg.name(), param, *_errors_);
1050 return false;
1051 }
1052
1053 return true;
1054 }
1055
1056 } // namespace o3prm
1057 } // namespace prm
1058} // namespace gum
Headers for the O3ClassFactory class.
virtual std::string label(Idx i) const =0
Get the i-th label. This method is pure virtual.
Exception : a similar element already exists.
This class is used contain and manipulate gum::ParseError.
Base class for all aGrUM's exceptions.
Definition exceptions.h:118
Exception : fatal (unknown ?) error.
The class for generic Hash Tables.
Definition hashTable.h:637
bool exists(const Key &key) const
Checks whether there exists an element with a given key in the hashtable.
Exception : existence of a directed cycle in a graph.
Exception : the element we looked for cannot be found.
Exception : operation not allowed.
static AggregateType str2enum(const std::string &str)
Static method which returns the AggregateType given its string representation.
<agrum/PRM/classElementContainer.h>
virtual bool exists(const std::string &name) const
Returns true if a member with the given name exists in this PRMClassElementContainer or in the PRMCla...
Abstract class representing an element of PRM class.
static INLINE bool isSlotChain(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMSlotChain.
static INLINE bool isAggregate(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMAggregate.
static INLINE bool isReferenceSlot(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMReferenceSlot.
static INLINE bool isAttribute(const PRMClassElement< GUM_SCALAR > &elt)
Returns true if elt is of type PRMAttribute.
A PRMClass is an object of a PRM representing a fragment of a Bayesian network which can be instantia...
Definition PRMClass.h:75
PRMClassElement< GUM_SCALAR > & get(NodeId id)
See gum::prm::PRMClassElementContainer<GUM_SCALAR>::get(NodeId).
HashTable< std::string, const PRMParameter< GUM_SCALAR > * > scope() const
Returns all the parameters in the scope of this class.
Factory which builds a PRM<GUM_SCALAR>.
Definition PRMFactory.h:88
virtual void startAttribute(const std::string &type, const std::string &name, bool scalar_atttr=false) override
Tells the factory that we start an attribute declaration.
void endAggregator()
Finishes an aggregate declaration.
void addParameter(const std::string &type, const std::string &name, double value) override
Add a parameter to the current class with a default value.
virtual void addParent(const std::string &name) override
Tells the factory that we add a parent to the current declared attribute.
void startAggregator(const std::string &name, const std::string &agg_type, const std::string &rv_type, const std::vector< std::string > &params)
Start an aggregator declaration.
virtual void setCPFByRule(const std::vector< std::string > &labels, const std::vector< GUM_SCALAR > &values)
Fills the CPF using a rule.
virtual void addReferenceSlot(const std::string &type, const std::string &name, bool isArray) override
Tells the factory that we started declaring a slot.
virtual void endAttribute() override
Tells the factory that we finished declaring an attribute.
virtual void startClass(const std::string &c, const std::string &ext="", const Set< std::string > *implements=nullptr, bool delayInheritance=false) override
Tells the factory that we start a class declaration.
void setRawCPFByColumns(const std::vector< GUM_SCALAR > &array)
Gives the factory the CPF in its raw form.
virtual void continueClass(const std::string &c) override
Continue the declaration of a class.
virtual void endClass(bool checkImplementations=true) override
Tells the factory that we finished a class declaration.
void continueAggregator(const std::string &name)
Continues an aggregator declaration.
virtual void continueAttribute(const std::string &name) override
Continues the declaration of an attribute.
const std::string & name() const
Returns the name of this object.
PRMParameter is a member of a Class in a PRM.
A PRMReferenceSlot represent a relation between two PRMClassElementContainer.
This is a decoration of the DiscreteVariable class.
Definition PRMType.h:78
DiscreteVariable & variable()
Return a reference on the DiscreteVariable contained in this.
Definition PRMType_inl.h:64
This class represents a Probabilistic Relational PRMSystem<GUM_SCALAR>.
Definition PRM.h:74
The O3Aggregate is part of the AST of the O3PRM language.
Definition O3prm.h:598
O3LabelList & parameters()
Definition O3prm.cpp:1163
O3LabelList & parents()
Definition O3prm.cpp:1159
The O3Attribute is part of the AST of the O3PRM language.
Definition O3prm.h:486
virtual O3Label & type()
Definition O3prm.cpp:744
virtual O3Label & name()
Definition O3prm.cpp:748
virtual O3LabelList & parents()
Definition O3prm.cpp:752
Builds gum::prm::Class from gum::prm::o3prm::O3Class.
bool _checkRemoteParent_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &prnt)
O3NameSolver< GUM_SCALAR > * _solver_
bool _checkLabelsNumber_(const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
bool _checkAggTypeLegality_(O3Class &o3class, O3Aggregate &agg)
bool _checkRawCPT_(const PRMClass< GUM_SCALAR > &c, O3RawCPT &attr)
const PRMClassElement< GUM_SCALAR > * _resolveSlotChain_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain)
HashTable< std::string, gum::NodeId > _nameMap_
O3ClassFactory(PRM< GUM_SCALAR > &prm, O3PRM &o3_prm, O3NameSolver< GUM_SCALAR > &solver, ErrorsContainer &errors)
bool _checkAggParameters_(O3Class &o3class, O3Aggregate &agg, const PRMType *t)
bool _checkLocalParent_(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
void _addParamsToForms_(const HashTable< std::string, const PRMParameter< GUM_SCALAR > * > &scope, O3RuleCPT::O3Rule &rule)
bool _checkAggregateForDeclaration_(O3Class &o3class, O3Aggregate &agg)
bool _checkRuleCPT_(const PRMClass< GUM_SCALAR > &c, O3RuleCPT &attr)
std::vector< O3Class * > _o3Classes_
void _addParameters_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkAttributeForDeclaration_(O3Class &o3_c, O3Attribute &attr)
bool _checkParameterValue_(O3Aggregate &agg, const gum::prm::PRMType &t)
const PRMType * _checkAggParents_(O3Class &o3class, O3Aggregate &agg)
void _completeAttribute_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkSlotChainLink_(const PRMClassElementContainer< GUM_SCALAR > &c, const O3Label &chain, const std::string &s)
bool _checkLabelsValues_(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
void _completeAggregates_(PRMFactory< GUM_SCALAR > &factory, O3Class &c)
bool _checkAttributeForCompletion_(const O3Class &o3_c, O3Attribute &attr)
bool _checkReferenceSlot_(O3Class &c, O3ReferenceSlot &ref)
bool _checkParent_(const PRMClass< GUM_SCALAR > &c, const O3Label &prnt)
bool _checkParametersNumber_(O3Aggregate &agg, Size n)
bool _checkAggregateForCompletion_(O3Class &o3class, O3Aggregate &agg)
O3ClassFactory< GUM_SCALAR > & operator=(const O3ClassFactory< GUM_SCALAR > &src)
HashTable< std::string, O3Class * > _classMap_
HashTable< NodeId, O3Class * > _nodeMap_
bool _checkRuleCPTSumsTo1_(const PRMClass< GUM_SCALAR > &c, const O3RuleCPT &attr, const O3RuleCPT::O3Rule &rule)
The O3Class is part of the AST of the O3PRM language.
Definition O3prm.h:640
O3ParameterList & parameters()
Definition O3prm.cpp:971
O3AggregateList & aggregates()
Definition O3prm.cpp:981
O3LabelList & interfaces()
Definition O3prm.cpp:967
O3ReferenceSlotList & referenceSlots()
Definition O3prm.cpp:975
O3Label & superLabel()
Definition O3prm.cpp:963
O3AttributeList & attributes()
Definition O3prm.cpp:977
The O3Label is part of the AST of the O3PRM language.
Definition O3prm.h:192
std::string & label()
Definition O3prm.cpp:287
Resolves names for the different O3PRM factories.
The O3PRM is part of the AST of the O3PRM language.
Definition O3prm.h:913
The O3RawCPT is part of the AST of the O3PRM language.
Definition O3prm.h:523
virtual O3FormulaList & values()
Definition O3prm.cpp:799
The O3ReferenceSlot is part of the AST of the O3PRM language.
Definition O3prm.h:453
The O3RuleCPT is part of the AST of the O3PRM language.
Definition O3prm.h:559
std::pair< O3LabelList, O3FormulaList > O3Rule
Definition O3prm.h:563
virtual O3RuleList & rules()
Definition O3prm.cpp:851
#define GUM_ERROR(type, msg)
Definition exceptions.h:72
#define GUM_SHOWERROR(e)
Definition exceptions.h:85
std::size_t Size
In aGrUM, hashed values are unsigned long int.
Definition types.h:74
HashTable< std::string, O3Aggregate * > AggMap
HashTable< std::string, O3ReferenceSlot * > RefMap
HashTable< std::string, O3Attribute * > AttrMap
namespace for all probabilistic relational models entities
Definition agrum.h:68
void decomposePath(const std::string &path, std::vector< std::string > &v)
Decompose a string in a vector of strings using "." as separators.
Definition utils_prm.cpp:48
gum is the global namespace for all aGrUM entities
Definition agrum.h:46
STL namespace.