49#ifndef DOXYGEN_SHOULD_SKIP_THIS
// _checkParameters_: validates that `pot` is compatible with the database
// before it gets filled. The first variable of the tensor (vars[0]) is matched
// with target_node, and each following variable vars[i] with
// conditioning_nodes[i - 1]. For every pair, the domain size reported by the
// database must equal the variable's domain size.
// NOTE(review): this is a Doxygen source listing. The leading numbers are the
// original file's line numbers, and several original lines are absent here:
// the opening of each mismatch GUM_ERROR call, the `else` line and the closing braces.
56 template <
typename GUM_SCALAR >
57 void ParamEstimator::_checkParameters_(
const NodeId target_node,
58 const std::vector< NodeId >& conditioning_nodes,
59 Tensor< GUM_SCALAR >& pot) {
// Variables of the tensor, in the order used for the matching described above.
61 const Sequence< const DiscreteVariable* >& vars = pot.variablesSequence();
// A tensor without any variable cannot be filled: reject it.
62 if (vars.size() == 0) {
GUM_ERROR(SizeError,
"the tensor contains no variable") }
// Node-id -> database-column mapping held by the record counter. When it is
// empty, node ids are passed directly to database.domainSize(); otherwise
// each node id is first translated with node2cols.second().
65 const auto& node2cols =
counter_.nodeId2Columns();
66 if (node2cols.empty()) {
// Target node vs. vars[0].
67 if (
database.domainSize(target_node) != vars[0]->domainSize()) {
// NOTE(review): the GUM_ERROR(...) opening for this message (original line 68)
// is not shown, so the exception type raised on mismatch cannot be confirmed
// from this listing. The message also lacks a space between the variable
// name and "of the tensor to be filled".
69 "Variable " << vars[0]->name() <<
"of the tensor to be filled "
70 <<
"has a domain size of " << vars[0]->domainSize()
71 <<
", which is different from that of node " << target_node
72 <<
" which is equal to " <<
database.domainSize(target_node));
// Conditioning nodes: vars[i] is matched with conditioning_nodes[i - 1].
// NOTE(review): no check that conditioning_nodes holds at least
// vars.size() - 1 entries is visible here; a shorter vector would be indexed
// out of range.
74 for (std::size_t i = 1; i < vars.size(); ++i) {
75 if (
database.domainSize(conditioning_nodes[i - 1]) != vars[i]->domainSize()) {
77 "Variable " << vars[i]->name() <<
"of the tensor to be filled "
78 <<
"has a domain size of " << vars[i]->domainSize()
79 <<
", which is different from that of node "
80 << conditioning_nodes[i - 1] <<
" which is equal to "
81 <<
database.domainSize(conditioning_nodes[i - 1]));
// Presumably the `else` branch (its line is not shown): a mapping exists, so
// each node id is translated to its database column before comparing sizes.
85 std::size_t col = node2cols.second(target_node);
86 if (
database.domainSize(col) != vars[0]->domainSize()) {
88 "Variable " << vars[0]->name() <<
"of the tensor to be filled "
89 <<
"has a domain size of " << vars[0]->domainSize()
90 <<
", which is different from that of node " << target_node
91 <<
" which is equal to " <<
database.domainSize(col));
// Same per-variable check for the conditioning nodes, through the mapping.
93 for (std::size_t i = 1; i < vars.size(); ++i) {
94 col = node2cols.second(conditioning_nodes[i - 1]);
95 if (
database.domainSize(col) != vars[i]->domainSize()) {
97 "Variable " << vars[i]->name() <<
"of the tensor to be filled "
98 <<
"has a domain size of " << vars[i]->domainSize()
99 <<
", which is different from that of node "
100 << conditioning_nodes[i - 1] <<
" which is equal to "
// (The final operand of this message and the closing braces, original lines
// 101-107, are absent from this listing.)
// _setParameters_ overload for scalar types other than double (selected by
// enable_if on !is_same<GUM_SCALAR, double>). It validates `pot`, obtains the
// estimated parameters as doubles, converts each one to GUM_SCALAR and fills
// `pot` with them. It returns the log-likelihood, which stays 0.0 unless
// compute_log_likelihood is true.
108 template <
typename GUM_SCALAR >
109 INLINE
typename std::enable_if< !std::is_same< GUM_SCALAR, double >::value,
double >::type
110 ParamEstimator::_setParameters_(
const NodeId target_node,
111 const std::vector< NodeId >& conditioning_nodes,
112 Tensor< GUM_SCALAR >& pot,
113 const bool compute_log_likelihood) {
// Ensure pot's variables match the database before filling it.
114 _checkParameters_(target_node, conditioning_nodes, pot);
116 std::vector< double > params;
117 double log_likelihood = 0.0;
118 if (compute_log_likelihood) {
// NOTE(review): the declaration of `xparams` (original line 119) is missing
// from this listing. Presumably it is the (parameters, log-likelihood) pair
// returned by parametersAndLogLikelihood(); confirm against the source.
// Only .first is moved out, so reading .second afterwards is still valid.
120 params = std::move(xparams).first;
121 log_likelihood = xparams.second;
// Presumably the `else` branch (its line is not shown): parameters only, no
// log-likelihood.
123 params =
parameters(target_node, conditioning_nodes);
// Convert the double estimates to the tensor's scalar type. This `xparams`
// is a new local in this scope, unrelated to the pair used above.
127 const std::size_t size = params.size();
128 std::vector< GUM_SCALAR > xparams(size);
129 for (std::size_t i = std::size_t(0); i < size; ++i)
130 xparams[i] = GUM_SCALAR(params[i]);
132 pot.fillWith(xparams);
133 return log_likelihood;
// _setParameters_ overload for GUM_SCALAR == double (selected by enable_if on
// is_same<GUM_SCALAR, double>). It validates `pot` and fills it directly with
// the estimated parameters, so no conversion is needed. It returns the
// log-likelihood, which stays 0.0 unless compute_log_likelihood is true.
137 template <
typename GUM_SCALAR >
138 INLINE
typename std::enable_if< std::is_same< GUM_SCALAR, double >::value,
double >::type
139 ParamEstimator::_setParameters_(
const NodeId target_node,
140 const std::vector< NodeId >& conditioning_nodes,
141 Tensor< GUM_SCALAR >& pot,
142 const bool compute_log_likelihood) {
// Ensure pot's variables match the database before filling it.
143 _checkParameters_(target_node, conditioning_nodes, pot);
145 std::vector< double > params;
146 double log_likelihood = 0.0;
147 if (compute_log_likelihood) {
// NOTE(review): the declaration of `xparams` (original line 148) is missing
// from this listing. Presumably it is the (parameters, log-likelihood) pair
// returned by parametersAndLogLikelihood(); confirm against the source.
// Only .first is moved out, so reading .second afterwards is still valid.
149 params = std::move(xparams).first;
150 log_likelihood = xparams.second;
// Presumably the `else` branch (its line is not shown): parameters only.
152 params =
parameters(target_node, conditioning_nodes);
// GUM_SCALAR is double here, so the estimates can be used as-is.
155 pot.fillWith(params);
156 return log_likelihood;
// Public entry point (presumably setParameters): it forwards all arguments to
// _setParameters_, whose double / non-double overload is picked by GUM_SCALAR,
// and returns that overload's log-likelihood result.
160 template <
typename GUM_SCALAR >
// NOTE(review): original line 161, which carries the return type, the function
// name and the first parameter (target_node), is missing from this listing.
162 const std::vector< NodeId >& conditioning_nodes,
163 Tensor< GUM_SCALAR >& pot,
164 const bool compute_log_likelihood) {
165 return _setParameters_(target_node, conditioning_nodes, pot, compute_log_likelihood);
169 template <
typename GUM_SCALAR >
RecordCounter counter_
the record counter used to parse the database
double setParameters(const NodeId target_node, const std::vector< NodeId > &conditioning_nodes, Tensor< GUM_SCALAR > &pot, const bool compute_log_likelihood=false)
sets a CPT's parameters and, possibly, returns its log-likelihood
void setBayesNet(const BayesNet< GUM_SCALAR > &new_bn)
assigns a new Bayes net to all of the counter's generators that depend on a BN
std::pair< std::vector< double >, double > parametersAndLogLikelihood(const NodeId target_node)
returns the parameters of a CPT as well as its log-likelihood
std::vector< double > parameters(const NodeId target_node)
returns the CPT's parameters corresponding to a given target node
const DatabaseTable & database() const
returns the database on which we perform the counts
#define GUM_ERROR(type, msg)
Size NodeId
Type for node ids.
include the inlined functions if necessary
gum is the global namespace for all aGrUM entities