/****************************************************************************
 *   This file is part of the aGrUM/pyAgrum library.                        *
 *                                                                          *
 *   Copyright (c) 2005-2025 by                                             *
 *       - Pierre-Henri WUILLEMIN(_at_LIP6)                                 *
 *       - Christophe GONZALES(_at_AMU)                                     *
 *                                                                          *
 *   The aGrUM/pyAgrum library is free software; you can redistribute it    *
 *   and/or modify it under the terms of either :                           *
 *                                                                          *
 *    - the GNU Lesser General Public License as published by               *
 *      the Free Software Foundation, either version 3 of the License,      *
 *      or (at your option) any later version,                              *
 *    - the MIT license (MIT),                                              *
 *    - or both in dual license, as here.                                   *
 *                                                                          *
 *   (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html)    *
 *                                                                          *
 *   This aGrUM/pyAgrum library is distributed in the hope that it will be  *
 *   useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED,          *
 *   INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS *
 *   FOR A PARTICULAR PURPOSE  AND NONINFRINGEMENT. IN NO EVENT SHALL THE   *
 *   AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER *
 *   LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE,        *
 *   ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR  *
 *   OTHER DEALINGS IN THE SOFTWARE.                                        *
 *                                                                          *
 *   See LICENCES for more details.                                         *
 *                                                                          *
 *   SPDX-FileCopyrightText: Copyright 2005-2025                            *
 *       - Pierre-Henri WUILLEMIN(_at_LIP6)                                 *
 *       - Christophe GONZALES(_at_AMU)                                     *
 *   SPDX-License-Identifier: LGPL-3.0-or-later OR MIT                      *
 *                                                                          *
 *   Contact  : info_at_agrum_dot_org                                       *
 *   homepage : http://agrum.gitlab.io                                      *
 *   gitlab   : https://gitlab.com/agrumery/agrum                           *
 *                                                                          *
 ****************************************************************************/


/**
 * @file
 * @brief Implementation of the generic class for the computation of
 * (possibly incrementally) marginal posteriors
 */
#include <iterator>

namespace gum {


  // Constructor: attaches the inference engine to the Bayesian network bn and,
  // when bn is not null, makes every node of its DAG a marginal target while
  // leaving the engine in non-targeted mode.
  template < typename GUM_SCALAR >
  MarginalTargetedInference< GUM_SCALAR >::MarginalTargetedInference(
      const IBayesNet< GUM_SCALAR >* bn) : BayesNetInference< GUM_SCALAR >(bn) {
    // With virtual inheritance, the BayesNetInference sub-object may not have
    // received the BN yet: hand it over now if no model is attached.
    const bool bnNotYetAssigned = this->hasNoModel_();
    if (bnNotYetAssigned) { BayesNetInference< GUM_SCALAR >::_setBayesNetDuringConstruction_(bn); }

    // by default, all the nodes of the BN are targets (non-targeted mode)
    if (bn != nullptr) {
      _targets_       = bn->dag().asNodeSet();
      _targeted_mode_ = false;
    }

    GUM_CONSTRUCTOR(MarginalTargetedInference);
  }

  // Destructor: its body only invokes the GUM_DESTRUCTOR macro.
  template < typename GUM_SCALAR >
  MarginalTargetedInference< GUM_SCALAR >::~MarginalTargetedInference() {
    GUM_DESTRUCTOR(MarginalTargetedInference);
  }

  // Fired when a new BN is assigned to the inference engine: every node of the
  // new model becomes a marginal target.
  // The parameter is deliberately unnamed (it was unused, triggering
  // -Wunused-parameter): the new model is read through this->BN() inside
  // _setAllMarginalTargets_ rather than through the argument.
  // NOTE(review): the constructor sets _targeted_mode_ to false after making all
  // nodes targets, whereas this handler sets it to true; with true, a later
  // addTarget() will not clear the other targets. Confirm this asymmetry is
  // intended.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::onModelChanged_(const GraphicalModel* /*bn*/) {
    _targeted_mode_ = true;
    _setAllMarginalTargets_();
  }

  // ##############################################################################
  // Targets
  // ##############################################################################

  // Returns whether node is currently one of the marginal targets.
  // Throws NullElement if no BN is attached to the engine, and UndefinedElement
  // if node does not belong to the BN's DAG.
  template < typename GUM_SCALAR >
  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTarget(NodeId node) const {
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }

    const bool nodeInBN = this->BN().dag().exists(node);
    if (!nodeInBN) { GUM_ERROR(UndefinedElement, node << " is not a NodeId in the bn") }

    return _targets_.contains(node);
  }

  // Returns whether the node named nodeName is a marginal target: the name is
  // resolved through the BN, then isTarget(NodeId) does the check.
  template < typename GUM_SCALAR >
  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTarget(const std::string& nodeName) const {
    const auto id = this->BN().idFromName(nodeName);
    return isTarget(id);
  }

  // Removes every marginal target. The onAllMarginalTargetsErased_ hook is
  // called first, then the target set is emptied, the engine is put in targeted
  // mode and the inference structure is marked as outdated.
  template < typename GUM_SCALAR >
  INLINE void MarginalTargetedInference< GUM_SCALAR >::eraseAllTargets() {
    // the hook runs before the clear, presumably so it can still see the old targets
    onAllMarginalTargetsErased_();
    _targets_.clear();

    setTargetedMode_();   // no-op when the engine is already in targeted mode
    this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
  }

  // Adds node target to the marginal targets. If the engine was not in targeted
  // mode, entering it first clears the current targets (see setTargetedMode_).
  // Throws NullElement if no BN is attached, and UndefinedElement if target is
  // not a node of the BN.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::addTarget(NodeId target) {
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }
    if (!this->BN().dag().exists(target)) {
      GUM_ERROR(UndefinedElement, target << " is not a NodeId in the bn")
    }

    setTargetedMode_();   // clears the targets only when leaving non-targeted mode

    if (_targets_.contains(target)) return;   // already a target: nothing changes

    _targets_.insert(target);
    onMarginalTargetAdded_(target);
    this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
  }

  // Makes every node of the BN a marginal target, in targeted mode. Each newly
  // added node triggers onMarginalTargetAdded_ and marks the inference
  // structure as outdated.
  // Throws NullElement if no BN is attached to the engine.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::addAllTargets() {
    // a model is required to enumerate the nodes
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }

    setTargetedMode_();   // no-op when the engine is already in targeted mode

    for (const auto node: this->BN().dag()) {
      if (_targets_.contains(node)) continue;   // skip nodes that already are targets

      _targets_.insert(node);
      onMarginalTargetAdded_(node);
      this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
    }
  }

  // Adds the node named nodeName to the marginal targets by resolving the name
  // and delegating to addTarget(NodeId).
  // Throws NullElement if no BN is attached to the engine.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::addTarget(const std::string& nodeName) {
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }

    const auto id = this->BN().idFromName(nodeName);
    addTarget(id);
  }

  // Removes node target from the marginal targets, if it is one. The engine is
  // switched to targeted mode without clearing the remaining targets.
  // Throws NullElement if no BN is attached, and UndefinedElement if target is
  // not a node of the BN.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::eraseTarget(NodeId target) {
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }
    if (!this->BN().dag().exists(target)) {
      GUM_ERROR(UndefinedElement, target << " is not a NodeId in the bn")
    }

    if (!_targets_.contains(target)) return;   // not a target: nothing to erase

    // setTargetedMode_() is bypassed on purpose: it would clear all the targets
    _targeted_mode_ = true;
    onMarginalTargetErased_(target);
    _targets_.erase(target);
    this->setState_(GraphicalModelInference< GUM_SCALAR >::StateOfInference::OutdatedStructure);
  }

  // Removes the node named nodeName from the marginal targets by resolving the
  // name and delegating to eraseTarget(NodeId).
  // Throws NullElement if no BN is attached to the engine.
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::eraseTarget(const std::string& nodeName) {
    if (this->hasNoModel_()) {
      GUM_ERROR(NullElement,
                "No Bayes net has been assigned to the "
                "inference algorithm");
    }

    const auto id = this->BN().idFromName(nodeName);
    eraseTarget(id);
  }

  // Returns a const reference to the internal set of marginal targets.
  template < typename GUM_SCALAR >
  INLINE const NodeSet& MarginalTargetedInference< GUM_SCALAR >::targets() const noexcept {
    const NodeSet& currentTargets = _targets_;
    return currentTargets;
  }

  // Returns the number of marginal targets.
  template < typename GUM_SCALAR >
  INLINE Size MarginalTargetedInference< GUM_SCALAR >::nbrTargets() const noexcept {
    const Size count = _targets_.size();
    return count;
  }

  // Public accessor: true when the engine is in targeted mode, i.e. when the
  // targets were explicitly chosen through the add/erase methods.
  template < typename GUM_SCALAR >
  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isInTargetMode() const noexcept {
    const bool inTargetedMode = _targeted_mode_;
    return inTargetedMode;
  }

  /// makes every node of the current BN a marginal target and calls the
  /// onAllMarginalTargetsAdded_ hook; with no BN attached, the target set is
  /// simply left empty (and the hook is not called)
  template < typename GUM_SCALAR >
  void MarginalTargetedInference< GUM_SCALAR >::_setAllMarginalTargets_() {
    _targets_.clear();
    if (this->hasNoModel_()) return;

    _targets_ = this->BN().dag().asNodeSet();
    onAllMarginalTargetsAdded_();
  }

  // ##############################################################################
  // Inference
  // ##############################################################################

  // Returns the posterior of node, running the inference first if it has not
  // been done yet. For a node carrying hard evidence, the evidence tensor
  // itself is returned. Throws UndefinedElement if node is not a target (and
  // propagates the exceptions raised by isTarget).
  template < typename GUM_SCALAR >
  const Tensor< GUM_SCALAR >& MarginalTargetedInference< GUM_SCALAR >::posterior(NodeId node) {
    const bool hardEvidence = this->hardEvidenceNodes().contains(node);
    if (hardEvidence) return *(this->evidence()[node]);

    if (!isTarget(node)) { GUM_ERROR(UndefinedElement, node << " is not a target node") }

    if (!this->isInferenceDone()) this->makeInference();

    return posterior_(node);
  }

  // Returns the posterior of the node named nodeName (see posterior(NodeId)).
  template < typename GUM_SCALAR >
  const Tensor< GUM_SCALAR >&
      MarginalTargetedInference< GUM_SCALAR >::posterior(const std::string& nodeName) {
    const auto id = this->BN().idFromName(nodeName);
    return posterior(id);
  }

  /* Entropy
   * Computes Shannon's entropy of node X's posterior given the current evidence
   * (via posterior(X), which may trigger the inference).
   */
  template < typename GUM_SCALAR >
  INLINE GUM_SCALAR MarginalTargetedInference< GUM_SCALAR >::H(NodeId X) {
    const auto& post = posterior(X);
    return post.entropy();
  }

  /* Entropy
   * Computes Shannon's entropy of the posterior of the node named nodeName
   * given the current evidence (see H(NodeId)).
   */
  template < typename GUM_SCALAR >
  INLINE GUM_SCALAR MarginalTargetedInference< GUM_SCALAR >::H(const std::string& nodeName) {
    const auto id = this->BN().idFromName(nodeName);
    return H(id);
  }

  /* Evidence impact
   * Builds a tensor over the variable of target and the variables of
   * condset = BN().minimalCondSet(target, evs) (presumably the part of evs that
   * actually matters for target -- see IBayesNet::minimalCondSet). For every
   * joint configuration of condset, the nodes of condset receive hard evidence
   * on that configuration, inference is run, and the posterior of target is
   * stored in the corresponding slice of the result.
   *
   * Throws InvalidArgument if target belongs to evs.
   *
   * Side effects: all the targets and all the evidence previously set on this
   * engine are erased. On return, target is the only target and the nodes of
   * condset carry hard evidence on the last configuration visited.
   */
  template < typename GUM_SCALAR >
  Tensor< GUM_SCALAR > MarginalTargetedInference< GUM_SCALAR >::evidenceImpact(NodeId target,
                                                                               const NodeSet& evs) {
    const auto& vtarget = this->BN().variable(target);

    if (evs.contains(target)) {
      GUM_ERROR(InvalidArgument,
                "Target <" << vtarget.name() << "> (" << target << ") can not be in evs (" << evs
                           << ").");
    }
    auto condset = this->BN().minimalCondSet(target, evs);

    // reset the engine: only target is queried, only condset is observed
    Tensor< GUM_SCALAR > res;
    this->eraseAllTargets();
    this->eraseAllEvidence();
    res.add(this->BN().variable(target));
    this->addTarget(target);
    for (const auto& n: condset) {
      res.add(this->BN().variable(n));
      // placeholder hard evidence; overwritten by chgEvidence in the loop below
      this->addEvidence(n, 0);
    }

    Instantiation inst(res);
    // outer loop: enumerate the configurations of the non-target variables
    // (incNotVar leaves vtarget untouched)
    for (inst.setFirst(); !inst.end(); inst.incNotVar(vtarget)) {
      // inferring
      for (const auto& n: condset)
        this->chgEvidence(n, inst.val(this->BN().variable(n)));
      this->makeInference();
      // populate res
      const auto& pot = this->posterior(target);
      // inner loop: sweep the values of the target variable for the current
      // configuration of condset
      for (inst.setFirstVar(vtarget); !inst.end(); inst.incVar(vtarget)) {
        res.set(inst, pot[inst]);
      }
      inst.setFirstVar(vtarget);   // remove inst.end() flag
    }

    return res;
  }

  // Name-based overload of evidenceImpact: target and the evidence nodes evs are
  // given by their variable names and resolved through the BN.
  template < typename GUM_SCALAR >
  Tensor< GUM_SCALAR > MarginalTargetedInference< GUM_SCALAR >::evidenceImpact(
      const std::string&                target,
      const std::vector< std::string >& evs) {
    const auto& net = this->BN();
    return evidenceImpact(net.idFromName(target), net.nodeset(evs));
  }

  // Protected accessor: true when the engine is in targeted mode.
  template < typename GUM_SCALAR >
  INLINE bool MarginalTargetedInference< GUM_SCALAR >::isTargetedMode_() const {
    const bool targeted = _targeted_mode_;
    return targeted;
  }

  // Switches the engine to targeted mode. On the transition from non-targeted
  // mode, the current targets are cleared; if the engine is already in targeted
  // mode, nothing happens.
  template < typename GUM_SCALAR >
  INLINE void MarginalTargetedInference< GUM_SCALAR >::setTargetedMode_() {
    if (_targeted_mode_) return;

    _targets_.clear();
    _targeted_mode_ = true;
  }
} /* namespace gum */
