/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.treedatalikelihood.continuous;

import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.evomodel.treedatalikelihood.continuous.ContinuousTraitGradientForBranch;
import dr.evomodel.treedatalikelihood.preorder.BranchConditionalDistributionDelegate;
import dr.evomodel.treedatalikelihood.preorder.BranchSufficientStatistics;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.loggers.LogColumn;
import dr.inference.loggers.Loggable;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.math.MultivariateFunction;
import dr.math.NumericalDerivative;
import dr.math.matrixAlgebra.Vector;
import dr.xml.Reportable;
import java.util.List;

/**
 * Gradient of a {@link TreeDataLikelihood}'s log-likelihood with respect to a branch-specific
 * {@link Parameter}.
 *
 * <p>The gradient is assembled by summing per-branch contributions computed by a
 * {@link ContinuousTraitGradientForBranch} from the branch sufficient statistics exposed by the
 * likelihood's branch-conditional-density tree trait. That trait is registered on the
 * {@link ContinuousDataLikelihoodDelegate} on demand if it is not already present. Only a single
 * trait is currently supported.
 */
public class BranchSpecificGradient
implements GradientWrtParameterProvider,
Reportable,
Loggable {
    private final TreeDataLikelihood treeDataLikelihood;
    private final TreeTrait<List<BranchSufficientStatistics>> treeTraitProvider;
    private final Tree tree;
    private final int nTraits;
    private final Parameter parameter;
    private final ContinuousTraitGradientForBranch branchProvider;

    /**
     * Log-likelihood as a function of the parameter values. {@link #getReport()} uses it to obtain
     * a finite-difference gradient for comparison with the analytic one.
     *
     * <p>Evaluating it overwrites the values of {@link #parameter}. The caller is responsible for
     * restoring them.
     */
    private final MultivariateFunction numeric1 = new MultivariateFunction(){

        @Override
        public double evaluate(double[] dArray) {
            for (int i = 0; i < dArray.length; ++i) {
                BranchSpecificGradient.this.parameter.setParameterValue(i, dArray[i]);
            }
            // Force recomputation so the returned value reflects the values just written.
            BranchSpecificGradient.this.treeDataLikelihood.makeDirty();
            return BranchSpecificGradient.this.treeDataLikelihood.getLogLikelihood();
        }

        @Override
        public int getNumArguments() {
            return BranchSpecificGradient.this.parameter.getDimension();
        }

        @Override
        public double getLowerBound(int n) {
            return 0.0;
        }

        @Override
        public double getUpperBound(int n) {
            return Double.POSITIVE_INFINITY;
        }
    };

    /**
     * Creates a branch-specific gradient provider.
     *
     * @param string name passed to {@link BranchConditionalDistributionDelegate#getName(String)}
     *        to locate (or create) the branch-conditional-density trait
     * @param treeDataLikelihood likelihood whose log-density is differentiated; must not be null
     * @param continuousDataLikelihoodDelegate delegate on which the branch-conditional-density
     *        trait is registered if the likelihood does not already expose it
     * @param continuousTraitGradientForBranch computes the gradient contribution of each branch
     * @param parameter the parameter the gradient is taken with respect to
     * @throws RuntimeException if the likelihood delegate reports a trait count other than 1
     */
    public BranchSpecificGradient(String string, TreeDataLikelihood treeDataLikelihood, ContinuousDataLikelihoodDelegate continuousDataLikelihoodDelegate, ContinuousTraitGradientForBranch continuousTraitGradientForBranch, Parameter parameter) {
        assert (treeDataLikelihood != null);
        this.treeDataLikelihood = treeDataLikelihood;
        this.tree = treeDataLikelihood.getTree();
        this.parameter = parameter;
        this.branchProvider = continuousTraitGradientForBranch;

        String traitName = BranchConditionalDistributionDelegate.getName(string);
        if (treeDataLikelihood.getTreeTrait(traitName) == null) {
            continuousDataLikelihoodDelegate.addBranchConditionalDensityTrait(string);
        }
        // The trait is fetched as a raw TreeTrait. It is expected to carry one
        // List<BranchSufficientStatistics> per node (unchecked conversion).
        TreeTrait rawTrait = treeDataLikelihood.getTreeTrait(traitName);
        this.treeTraitProvider = rawTrait;
        assert (this.treeTraitProvider != null);

        this.nTraits = treeDataLikelihood.getDataLikelihoodDelegate().getTraitCount();
        if (this.nTraits != 1) {
            throw new RuntimeException("Not yet implemented for >1 traits");
        }
    }

    @Override
    public Likelihood getLikelihood() {
        return this.treeDataLikelihood;
    }

    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    @Override
    public int getDimension() {
        return this.getParameter().getDimension();
    }

    /**
     * Computes the analytic gradient of the log-likelihood with respect to {@link #parameter}.
     *
     * <p>Every node of the tree, including the root, is visited. For each node, the branch
     * provider's gradient vector of length {@code branchProvider.getDimension()} is added into
     * the slot block {@code [index * dim, index * dim + dim)}, where {@code index} comes from
     * {@code branchProvider.getParameterIndexFromNode(node)}.
     *
     * @return a new array of length {@code parameter.getDimension()}
     */
    @Override
    public double[] getGradientLogDensity() {
        int gradientsPerBranch = this.branchProvider.getDimension();
        double[] result = new double[this.parameter.getDimension()];
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef node = this.tree.getNode(i);
            List<BranchSufficientStatistics> statistics = this.treeTraitProvider.getTrait(this.tree, node);
            assert (statistics.size() == this.nTraits);
            // Only a single trait is supported (enforced in the constructor).
            double[] branchGradient = this.branchProvider.getGradientForBranch(statistics.get(0), node);
            int destinationIndex = this.getParameterIndexFromNode(node);
            assert (destinationIndex != -1);
            for (int j = 0; j < gradientsPerBranch; ++j) {
                result[destinationIndex * gradientsPerBranch + j] += branchGradient[j];
            }
        }
        return result;
    }

    private int getParameterIndexFromNode(NodeRef nodeRef) {
        return this.branchProvider.getParameterIndexFromNode(nodeRef);
    }

    /**
     * Returns the derivation parameters of the underlying branch provider.
     *
     * @throws ClassCastException if the branch provider is not a
     *         {@link ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient}
     */
    public List<ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient.DerivationParameter> getDerivationParameter() {
        return ((ContinuousTraitGradientForBranch.ContinuousProcessParameterGradient)this.branchProvider).getDerivationParameter();
    }

    /**
     * Builds a diagnostic report comparing the analytic ("Peeling") gradient against a numerical
     * finite-difference gradient.
     *
     * <p>The numerical derivative perturbs {@link #parameter}. The original values are always
     * restored, even if the derivative computation throws, and the likelihood is marked dirty
     * afterwards. This ensures the analytic gradient is evaluated at the restored point rather
     * than at the last perturbed one.
     */
    @Override
    public String getReport() {
        double[] savedValues = this.parameter.getParameterValues();
        double[] numericGradient;
        try {
            numericGradient = NumericalDerivative.gradient(this.numeric1, this.parameter.getParameterValues());
        } finally {
            for (int i = 0; i < savedValues.length; ++i) {
                this.parameter.setParameterValue(i, savedValues[i]);
            }
            // numeric1 marks the likelihood dirty after each write. Do the same after restoring
            // so no state from the final perturbed evaluation is reused.
            this.treeDataLikelihood.makeDirty();
        }
        StringBuilder report = new StringBuilder();
        report.append("Peeling: ").append(new Vector(this.getGradientLogDensity()));
        report.append("\n");
        report.append("numeric: ").append(new Vector(numericGradient));
        report.append("\n");
        return report.toString();
    }

    /**
     * Exposes the gradient report as a single log column. The report is recomputed each time the
     * column's value is rendered.
     */
    @Override
    public LogColumn[] getColumns() {
        LogColumn[] logColumnArray = new LogColumn[]{new LogColumn.Default("gradient report", new Object(){

            @Override
            public String toString() {
                return "\n" + BranchSpecificGradient.this.getReport();
            }
        })};
        return logColumnArray;
    }
}

