/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.operators.hmc;

import cern.colt.matrix.DoubleFactory2D;
import cern.colt.matrix.DoubleMatrix1D;
import cern.colt.matrix.DoubleMatrix2D;
import cern.colt.matrix.impl.DenseDoubleMatrix2D;
import cern.colt.matrix.linalg.Algebra;
import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.AbstractModel;
import dr.inference.model.Model;
import dr.inference.model.Parameter;
import dr.inference.model.PriorPreconditioningProvider;
import dr.inference.model.Variable;
import dr.inference.operators.hmc.MassPreconditioningOptions;
import dr.inference.operators.hmc.SecantHessian;
import dr.math.AdaptableCovariance;
import dr.math.AdaptableVector;
import dr.math.MachineAccuracy;
import dr.math.MathUtils;
import dr.math.MultivariateFunction;
import dr.math.distributions.MultivariateNormalDistribution;
import dr.math.matrixAlgebra.ReadableVector;
import dr.math.matrixAlgebra.RobustEigenDecomposition;
import dr.math.matrixAlgebra.WrappedMatrix;
import dr.math.matrixAlgebra.WrappedVector;
import dr.util.Transform;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

public interface MassPreconditioner {
    /**
     * Draws a random momentum vector of length {@link #getDimension()}.
     * The implementations in this file sample a zero-mean Gaussian whose scale is
     * derived from the current (inverse) mass.
     */
    public WrappedVector drawInitialMomentum();

    /**
     * Returns element {@code var1} of the velocity implied by the vector {@code var2}.
     * In the implementations in this file this is (inverse mass) * {@code var2}.
     */
    public double getVelocity(int var1, ReadableVector var2);

    /**
     * Records a pair of vectors for adaptive implementations. Depending on the class,
     * this feeds the second argument into a running (co)variance estimate, forwards
     * both arguments to a SecantHessian, or does nothing.
     */
    public void storeSecant(ReadableVector var1, ReadableVector var2);

    /** Recomputes the (inverse) mass from whatever state the implementation tracks. */
    public void updateMass();

    /**
     * Returns the mass as a vector where supported; several implementations throw
     * a RuntimeException instead.
     */
    public WrappedVector getMass();

    /** Feeds a sample into the implementation's variance estimate, if it keeps one. */
    public void updateVariance(WrappedVector var1);

    /**
     * Returns a copy of {@code var2} with the coordinates named in {@code var1} updated
     * by a collision rule. Implementations in this file only support exactly two indices.
     */
    public ReadableVector doCollision(int[] var1, ReadableVector var2);

    /** Number of coordinates handled by this preconditioner. */
    public int getDimension();

    /**
     * Computes the whole velocity vector by evaluating
     * {@link #getVelocity(int, ReadableVector)} for every coordinate of the input.
     *
     * @param readableVector the vector (typically a momentum) to transform
     * @return a new array with one velocity element per input coordinate
     */
    default public double[] getVelocity(ReadableVector readableVector) {
        final int length = readableVector.getDim();
        final double[] velocity = new double[length];
        int index = 0;
        while (index < length) {
            velocity[index] = this.getVelocity(index, readableVector);
            index++;
        }
        return velocity;
    }

    /**
     * Full-matrix preconditioner whose inverse mass is an averaged, trace-normalized
     * empirical covariance of the samples passed to {@link #storeSecant}.
     */
    public static class AdaptiveFullHessianPreconditioning
    extends FullHessianPreconditioning {
        private final AdaptableCovariance adaptableCovariance;
        private final GradientWrtParameterProvider gradientProvider;
        /** Running average of the flattened (row-major) covariance estimates. */
        private final AdaptableVector averageCovariance;
        /** Scratch buffer holding the latest PD-transformed inverse mass (row-major). */
        private final double[] inverseMassBuffer;
        /** Number of covariance updates required before the inverse mass is adapted. */
        private final int minimumUpdates;
        protected MultivariateFunction numeric1 = new MultivariateFunction(){

            @Override
            public double evaluate(double[] dArray) {
                for (int i = 0; i < dArray.length; ++i) {
                    gradientProvider.getParameter().setParameterValue(i, dArray[i]);
                }
                return gradientProvider.getLikelihood().getLogLikelihood();
            }

            @Override
            public int getNumArguments() {
                return gradientProvider.getParameter().getDimension();
            }

            @Override
            public double getLowerBound(int n) {
                return 0.0;
            }

            @Override
            public double getUpperBound(int n) {
                return Double.POSITIVE_INFINITY;
            }
        };

        AdaptiveFullHessianPreconditioning(GradientWrtParameterProvider gradientWrtParameterProvider, AdaptableCovariance adaptableCovariance, Transform transform, int n, int n2) {
            super(null, transform, n);
            this.adaptableCovariance = adaptableCovariance;
            this.gradientProvider = gradientWrtParameterProvider;
            this.averageCovariance = new AdaptableVector.Default(n * n);
            this.inverseMassBuffer = new double[n * n];
            this.minimumUpdates = n2;
        }

        /**
         * Once more than {@code minimumUpdates} samples have been seen, folds the current
         * empirical covariance into the running average, normalizes it to trace
         * {@code dim}, makes it positive definite and installs it as the inverse mass.
         */
        @Override
        protected void computeInverseMass() {
            if (this.adaptableCovariance.getUpdateCount() > this.minimumUpdates) {
                WrappedMatrix.ArrayOfArray covariance = (WrappedMatrix.ArrayOfArray)this.adaptableCovariance.getCovariance();
                double[][] rows = covariance.getArrays();
                double[] flattened = new double[this.dim * this.dim];
                for (int i = 0; i < this.dim; ++i) {
                    System.arraycopy(rows[i], 0, flattened, i * this.dim, this.dim);
                }
                this.averageCovariance.update(new WrappedVector.Raw(flattened));
                this.cacheAverageCovariance(this.normalizeCovariance(this.averageCovariance.getMean()));
                this.setInverseMassFromArray(this.inverseMassBuffer);
            }
        }

        /**
         * Returns a copy of the flattened {@code dim x dim} matrix scaled so its trace
         * equals {@code dim}.
         *
         * <p>Works on a copy on purpose: {@code getMean()} may expose the running mean's
         * own storage (other code in this file clones it before modifying), and
         * rescaling it in place would corrupt subsequent averages.
         */
        private ReadableVector normalizeCovariance(ReadableVector mean) {
            double trace = 0.0;
            for (int i = 0; i < this.dim; ++i) {
                trace += mean.get(i * this.dim + i);
            }
            double scale = (double)this.dim / trace;
            double[] normalized = new double[this.dim * this.dim];
            for (int i = 0; i < normalized.length; ++i) {
                normalized[i] = mean.get(i) * scale;
            }
            return new WrappedVector.Raw(normalized);
        }

        /**
         * Negates the flattened matrix, maps it through the Default PD transform (which
         * bounds, rescales and negates the eigenvalues back) and stores the row-major
         * result in {@link #inverseMassBuffer}.
         */
        private void cacheAverageCovariance(ReadableVector readableVector) {
            double[][] negated = new double[this.dim][this.dim];
            for (int i = 0; i < this.dim; ++i) {
                for (int j = 0; j < this.dim; ++j) {
                    negated[i][j] = -readableVector.get(i * this.dim + j);
                }
            }
            double[] transformed = FullHessianPreconditioning.PDTransformMatrix.Default.transformMatrix(negated, this.dim);
            System.arraycopy(transformed, 0, this.inverseMassBuffer, 0, this.dim * this.dim);
        }

        /** Feeds the second vector into the empirical covariance estimate. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            this.adaptableCovariance.update(readableVector2);
        }
    }

    /**
     * Full-matrix preconditioner whose Hessian comes from a {@link SecantHessian};
     * secant pairs are forwarded to that object.
     */
    public static class Secant
    extends FullHessianPreconditioning {
        private final SecantHessian secantHessian;

        Secant(SecantHessian secant, Transform transform) {
            super(secant, transform);
            this.secantHessian = secant;
        }

        /** Delegates both vectors unchanged to the underlying SecantHessian. */
        @Override
        public void storeSecant(ReadableVector first, ReadableVector second) {
            secantHessian.storeSecant(first, second);
        }
    }

    /**
     * Dense preconditioner: the inverse mass is stored row-major in a
     * {@code dim * dim} parameter and is derived from the Hessian of the log density
     * after transforming it into a positive-definite matrix.
     */
    public static class FullHessianPreconditioning
    extends HessianBased {
        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform) {
            super(hessianWrtParameterProvider, transform);
        }

        FullHessianPreconditioning(HessianWrtParameterProvider hessianWrtParameterProvider, Transform transform, int n) {
            super(hessianWrtParameterProvider, transform, n);
        }

        /**
         * Sets the inverse mass to the {@code dim x dim} identity, allocating and
         * registering the parameter first when needed.
         *
         * <p>This runs from the HessianBased constructor, i.e. before any constructor
         * body of this class, so allocation must happen here. Previously the identity was
         * built and discarded, the 2-argument constructor never allocated the parameter
         * at all, and the 3-argument one allocated it only after this method had run.
         */
        @Override
        protected void initializeMass() {
            if (this.inverseMass == null) {
                this.inverseMass = new Parameter.Default(MASSNAME, this.dim * this.dim);
                this.addVariable(this.inverseMass);
            }
            double[] identity = new double[this.dim * this.dim];
            for (int i = 0; i < this.dim; ++i) {
                identity[i * this.dim + i] = 1.0;
            }
            this.setInverseMassFromArray(identity);
        }

        /** Optionally maps the Hessian through {@code transform}, then makes it PD. */
        private double[] computeInverseMass(WrappedMatrix.ArrayOfArray arrayOfArray, GradientWrtParameterProvider gradientWrtParameterProvider, PDTransformMatrix pDTransformMatrix) {
            double[][] dArray = arrayOfArray.getArrays();
            if (this.transform != null) {
                dArray = this.transform.updateHessianLogDensity(dArray, new double[this.dim][this.dim], gradientWrtParameterProvider.getGradientLogDensity(), gradientWrtParameterProvider.getParameter().getParameterValues(), 0, this.dim);
            }
            return pDTransformMatrix.transformMatrix(dArray, this.dim);
        }

        @Override
        protected void computeInverseMass() {
            WrappedMatrix.ArrayOfArray arrayOfArray = new WrappedMatrix.ArrayOfArray(this.hessian.getHessianLogDensity());
            this.setInverseMassFromArray(this.computeInverseMass(arrayOfArray, this.hessian, PDTransformMatrix.Invert));
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        /** Samples momentum from a zero-mean MVN with the inverse mass passed as its matrix argument. */
        @Override
        public WrappedVector drawInitialMomentum() {
            MultivariateNormalDistribution multivariateNormalDistribution = new MultivariateNormalDistribution(new double[this.dim], FullHessianPreconditioning.toArray(this.inverseMass.getParameterValues(), this.dim, this.dim));
            return new WrappedVector.Raw(multivariateNormalDistribution.nextMultivariateNormal());
        }

        /** Row {@code n} of the inverse mass dotted with {@code readableVector}. */
        @Override
        public double getVelocity(int n, ReadableVector readableVector) {
            double d = 0.0;
            for (int i = 0; i < this.dim; ++i) {
                d += this.inverseMass.getParameterValue(n * this.dim + i) * readableVector.get(i);
            }
            return d;
        }

        /** Reshapes a row-major flat array into {@code n} rows of {@code n2} columns. */
        private static double[][] toArray(double[] dArray, int n, int n2) {
            double[][] dArrayArray = new double[n][];
            for (int i = 0; i < n; ++i) {
                dArrayArray[i] = new double[n2];
                System.arraycopy(dArray, n2 * i, dArrayArray[i], 0, n2);
            }
            return dArrayArray;
        }

        /**
         * Strategies that eigendecompose a matrix, bound/rescale/transform its
         * eigenvalues and reassemble a positive-definite matrix.
         */
        static enum PDTransformMatrix {
            Invert("Transform inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.inverseNegateEigenvalues(doubleMatrix1D);
                }
            }
            ,
            Default("Transform matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                }
            }
            ,
            Negate("Transform negative matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                    this.boundEigenvalues(doubleMatrix1D);
                    this.scaleEigenvalues(doubleMatrix1D);
                }
            }
            ,
            NegateInvert("Transform negative inverse matrix into a PD matrix"){

                @Override
                protected void transformEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.inverseNegateEigenvalues(doubleMatrix1D);
                }

                @Override
                protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                    this.negateEigenvalues(doubleMatrix1D);
                    this.boundEigenvalues(doubleMatrix1D);
                    this.scaleEigenvalues(doubleMatrix1D);
                }
            };

            String desc;
            private static final double MIN_EIGENVALUE = -20.0;
            private static final double MAX_EIGENVALUE = -0.5;

            private PDTransformMatrix(String string2) {
                this.desc = string2;
            }

            @Override
            public String toString() {
                return this.desc;
            }

            // NOTE: all loops below use size(), not cardinality(): in Colt, cardinality()
            // counts non-zero cells, so a zero eigenvalue would silently shorten the loop.

            /** Clamps each eigenvalue into [MIN_EIGENVALUE, MAX_EIGENVALUE]. */
            protected void boundEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.size(); ++i) {
                    if (doubleMatrix1D.get(i) > MAX_EIGENVALUE) {
                        doubleMatrix1D.set(i, MAX_EIGENVALUE);
                    } else if (doubleMatrix1D.get(i) < MIN_EIGENVALUE) {
                        doubleMatrix1D.set(i, MIN_EIGENVALUE);
                    }
                }
            }

            /** Divides every eigenvalue by minus their mean. */
            protected void scaleEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                final int n = doubleMatrix1D.size();
                double sum = 0.0;
                for (int i = 0; i < n; ++i) {
                    sum += doubleMatrix1D.get(i);
                }
                double negMean = -sum / (double)n;
                for (int i = 0; i < n; ++i) {
                    doubleMatrix1D.set(i, doubleMatrix1D.get(i) / negMean);
                }
            }

            protected void normalizeEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                this.boundEigenvalues(doubleMatrix1D);
                this.scaleEigenvalues(doubleMatrix1D);
            }

            /** Replaces each eigenvalue x by -1/x. */
            protected void inverseNegateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.size(); ++i) {
                    doubleMatrix1D.set(i, -1.0 / doubleMatrix1D.get(i));
                }
            }

            /** Replaces each eigenvalue x by -x. */
            protected void negateEigenvalues(DoubleMatrix1D doubleMatrix1D) {
                for (int i = 0; i < doubleMatrix1D.size(); ++i) {
                    doubleMatrix1D.set(i, -doubleMatrix1D.get(i));
                }
            }

            /**
             * Eigendecomposes {@code dArray}, normalizes then transforms the eigenvalues,
             * and returns V * diag(eigenvalues) * V^-1 flattened row-major ({@code n * n}).
             */
            public double[] transformMatrix(double[][] dArray, int n) {
                Algebra algebra = new Algebra();
                DenseDoubleMatrix2D denseDoubleMatrix2D = new DenseDoubleMatrix2D(dArray);
                RobustEigenDecomposition robustEigenDecomposition = new RobustEigenDecomposition(denseDoubleMatrix2D);
                DoubleMatrix1D doubleMatrix1D = robustEigenDecomposition.getRealEigenvalues();
                this.normalizeEigenvalues(doubleMatrix1D);
                DoubleMatrix2D doubleMatrix2D = robustEigenDecomposition.getV();
                this.transformEigenvalues(doubleMatrix1D);
                double[][] dArray2 = algebra.mult(algebra.mult(doubleMatrix2D, DoubleFactory2D.dense.diagonal(doubleMatrix1D)), algebra.inverse(doubleMatrix2D)).toArray();
                double[] dArray3 = new double[n * n];
                for (int i = 0; i < n; ++i) {
                    System.arraycopy(dArray2[i], 0, dArray3, i * n, n);
                }
                return dArray3;
            }

            protected abstract void transformEigenvalues(DoubleMatrix1D var1);
        }
    }

    /**
     * Diagonal preconditioner whose inverse mass adapts to the empirical variance of
     * the samples it is fed, after a configurable delay.
     */
    public static class AdaptiveDiagonalPreconditioning
    extends DiagonalPreconditioning {
        private final AdaptableVector.AdaptableVariance variance;
        /** Number of variance updates required before the inverse mass is adapted. */
        private final int minimumUpdates;
        private final GradientWrtParameterProvider gradient;
        private final MassPreconditioningOptions options;

        AdaptiveDiagonalPreconditioning(int n, GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
            this(n, gradientWrtParameterProvider, transform, massPreconditioningOptions, false);
        }

        /**
         * @param n dimension
         * @param gradientWrtParameterProvider gradient used for the optional numerical initial mass
         * @param transform optional transform (may be null)
         * @param massPreconditioningOptions supplies the delay and eigenvalue bounds
         * @param bl if true, initialize the mass from a finite-difference estimate; otherwise use
         *           the default normalized unit diagonal
         * @throws RuntimeException if the preconditioning delay is less than 2
         */
        AdaptiveDiagonalPreconditioning(int n, GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions, boolean bl) {
            super(n, transform);
            this.variance = new AdaptableVector.AdaptableVariance(n);
            this.options = massPreconditioningOptions;
            this.minimumUpdates = massPreconditioningOptions.preconditioningDelay();
            if (this.minimumUpdates < 2) {
                // The message now matches the check (a delay of exactly 2 is accepted).
                throw new RuntimeException("Need at least two samples to calculate empirical variance.  Set HMC's option preconditioningDelay >= 2 please!");
            }
            this.gradient = gradientWrtParameterProvider;
            if (bl) {
                this.setInitialMass();
            } else {
                super.initializeMass();
            }
        }

        /** Intentionally empty: the constructor chooses the initialization after its fields are set. */
        @Override
        protected void initializeMass() {
        }

        /**
         * Sets the inverse mass from |(g(x + h) - g(x - h)) / (2h)|, where every coordinate
         * is shifted simultaneously by h = SQRT_SQRT_EPSILON, then fills zeros and
         * normalizes. The parameter is restored to its original values afterwards.
         */
        private void setInitialMass() {
            final Parameter parameter = this.gradient.getParameter();
            // Perturb relative to a private copy so the base values cannot drift if
            // getParameterValues() ever returns live storage.
            final double[] original = parameter.getParameterValues().clone();
            final double h = MachineAccuracy.SQRT_SQRT_EPSILON;

            final double[] gradientPlus = this.gradientAtShift(parameter, original, h);
            final double[] gradientMinus = this.gradientAtShift(parameter, original, -h);

            for (int i = 0; i < this.dim; ++i) {
                parameter.setParameterValueQuietly(i, original[i]);
            }
            parameter.fireParameterChangedEvent();

            double[] diagonal = new double[this.dim];
            for (int i = 0; i < this.dim; ++i) {
                diagonal[i] = Math.abs((gradientPlus[i] - gradientMinus[i]) / (2.0 * h));
            }
            this.fillZeros(diagonal);
            this.setInverseMassFromArray(this.normalizeVector(new WrappedVector.Raw(diagonal), this.dim));
        }

        /**
         * Shifts every coordinate of {@code parameter} by {@code shift} from {@code base},
         * fires one change event and returns a copy of the resulting gradient (copied in
         * case the provider reuses its output buffer between calls).
         */
        private double[] gradientAtShift(Parameter parameter, double[] base, double shift) {
            for (int i = 0; i < this.dim; ++i) {
                parameter.setParameterValueQuietly(i, base[i] + shift);
            }
            parameter.fireParameterChangedEvent();
            return this.gradient.getGradientLogDensity().clone();
        }

        /**
         * Replaces zero entries with the smallest positive entry, or sets every entry to
         * 1.0 when the entries sum to zero.
         */
        private void fillZeros(double[] dArray) {
            double sum = 0.0;
            double minPositive = Double.POSITIVE_INFINITY;
            for (double value : dArray) {
                sum += value;
                if (value > 0.0 && value < minPositive) {
                    minPositive = value;
                }
            }
            if (sum == 0.0) {
                Arrays.fill(dArray, 1.0);
            } else {
                for (int i = 0; i < dArray.length; ++i) {
                    if (dArray[i] == 0.0) {
                        dArray[i] = minPositive;
                    }
                }
            }
        }

        /**
         * After more than {@code minimumUpdates} samples, averages the empirical variance
         * and installs it, bounded by the options' eigenvalue limits, as the inverse mass.
         */
        @Override
        protected void computeInverseMass() {
            if (this.variance.getUpdateCount() > this.minimumUpdates) {
                double[] dArray = this.variance.getVariance();
                this.adaptiveDiagonal.update(new WrappedVector.Raw(dArray));
                this.setInverseMassFromArray(DiagonalHessianPreconditioning.boundMassInverse(((WrappedVector)this.adaptiveDiagonal.getMean()).getBuffer(), this.options.preconditioningEigenLowerBound(), this.options.preconditioningEigenUpperBound(), this.dim, DiagonalHessianPreconditioning.VarianceConverter.VARIANCE));
            }
        }

        /** Feeds the second vector into the variance estimate. */
        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            this.variance.update(readableVector2);
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
            this.variance.update(wrappedVector);
        }

        /** Returns the element-wise reciprocal of the inverse mass. */
        @Override
        public WrappedVector getMass() {
            double[] dArray = new double[this.dim];
            for (int i = 0; i < this.dim; ++i) {
                dArray[i] = 1.0 / this.inverseMass.getParameterValue(i);
            }
            return new WrappedVector.Raw(dArray);
        }
    }

    /**
     * Diagonal preconditioner built from a running average of the diagonal Hessian of
     * the log density, converted to variances and clamped between two bounds.
     */
    public static class DiagonalHessianPreconditioning
    extends DiagonalPreconditioning {
        protected final HessianWrtParameterProvider hessian;
        private final Parameter lowerBound;
        private final Parameter upperBound;

        /**
         * @param provider supplies the diagonal Hessian and gradient
         * @param transform optional transform (may be null)
         * @param memorySize if positive, average only over that many recent updates;
         *                   otherwise keep a full running average
         * @param lower lower clamp (element 0 is used)
         * @param upper upper clamp (element 0 is used)
         */
        DiagonalHessianPreconditioning(HessianWrtParameterProvider provider, Transform transform, int memorySize, Parameter lower, Parameter upper) {
            super(provider.getDimension(), transform);
            this.hessian = provider;
            final int size = provider.getDimension();
            if (memorySize > 0) {
                this.adaptiveDiagonal = new AdaptableVector.LimitedMemory(size, memorySize);
            } else {
                this.adaptiveDiagonal = new AdaptableVector.Default(size);
            }
            this.lowerBound = lower;
            this.upperBound = upper;
        }

        /** Updates the running Hessian diagonal and derives the inverse mass from its mean. */
        @Override
        protected void computeInverseMass() {
            double[] diagonalHessian = hessian.getDiagonalHessianLogDensity();
            if (transform != null) {
                final double[] values = hessian.getParameter().getParameterValues();
                final double[] gradientLogDensity = hessian.getGradientLogDensity();
                diagonalHessian = transform.updateDiagonalHessianLogDensity(diagonalHessian, gradientLogDensity, values, 0, dim);
            }
            adaptiveDiagonal.update(new WrappedVector.Raw(diagonalHessian));
            final double[] mean = ((WrappedVector) adaptiveDiagonal.getMean()).getBuffer();
            setInverseMassFromArray(boundMassInverse(mean, lowerBound, upperBound, dim, VarianceConverter.HESSIAN));
        }

        /**
         * Returns a new array: L1-normalize a copy of {@code dArray} to total {@code n},
         * convert each of the first {@code n} entries with {@code varianceConverter},
         * clamp it to [lower(0), upper(0)], then L1-normalize again.
         */
        public static double[] boundMassInverse(double[] dArray, Parameter lower, Parameter upper, int n, VarianceConverter varianceConverter) {
            final double[] result = Arrays.copyOf(dArray, dArray.length);
            normalizeL1(result, n);
            for (int k = 0; k < n; ++k) {
                final double converted = varianceConverter.convertVariance(result[k]);
                final double floor = lower.getParameterValue(0);
                if (converted < floor) {
                    result[k] = floor;
                } else {
                    final double ceiling = upper.getParameterValue(0);
                    result[k] = converted > ceiling ? ceiling : converted;
                }
            }
            normalizeL1(result, n);
            return result;
        }

        /** Rescales {@code values} in place so that the sum of absolute values equals {@code target}. */
        private static void normalizeL1(double[] values, double target) {
            double absSum = 0.0;
            for (double v : values) {
                absSum += Math.abs(v);
            }
            final double factor = target / absSum;
            for (int k = 0; k < values.length; ++k) {
                values[k] = values[k] * factor;
            }
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        /** Maps an averaged diagonal entry to a variance-like value. */
        static enum VarianceConverter {
            /** Hessian diagonal entry h maps to -1/h. */
            HESSIAN{

                @Override
                double convertVariance(double d) {
                    return -1.0 / d;
                }
            }
            ,
            /** Entry is already a variance; returned unchanged. */
            VARIANCE{

                @Override
                double convertVariance(double d) {
                    return d;
                }
            };


            abstract double convertVariance(double var1);
        }
    }

    /**
     * Diagonal preconditioner whose inverse mass is the prior variance
     * (squared prior standard deviation) of each coordinate.
     */
    public static class PriorPreconditioner
    extends DiagonalPreconditioning {
        PriorPreconditioningProvider priorDistribution;

        public PriorPreconditioner(PriorPreconditioningProvider provider, Transform transform) {
            super(provider.getDimension(), transform);
            priorDistribution = provider;
            computeInverseMass();
        }

        /** Convenience constructor wrapper; returns a new, independent instance. */
        public MassPreconditioner factory(PriorPreconditioningProvider provider, Transform transform) {
            return new PriorPreconditioner(provider, transform);
        }

        /** Sets inverse mass element i to sd_i * sd_i taken from the prior provider. */
        @Override
        protected void computeInverseMass() {
            int index = 0;
            while (index < priorDistribution.getDimension()) {
                final double sd = priorDistribution.getStandardDeviation(index);
                inverseMass.setParameterValue(index, sd * sd);
                index++;
            }
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }
    }

    /**
     * Base class for preconditioners whose inverse mass is a diagonal stored as a
     * length-{@code dim} parameter.
     */
    public static abstract class DiagonalPreconditioning
    extends AbstractMassPreconditioning {
        protected AdaptableVector adaptiveDiagonal;

        protected DiagonalPreconditioning(int n, Transform transform) {
            super(n, transform);
            adaptiveDiagonal = new AdaptableVector.Default(n);
            inverseMass = new Parameter.Default("InverseMass", n);
            initializeMass();
            addVariable(inverseMass);
        }

        /** Starts from a unit diagonal normalized to sum to {@code dim}. */
        @Override
        protected void initializeMass() {
            final double[] ones = new double[dim];
            Arrays.fill(ones, 1.0);
            setInverseMassFromArray(normalizeVector(new WrappedVector.Raw(ones), dim));
        }

        /** Returns a new array holding {@code readableVector} scaled so its elements sum to {@code d}. */
        protected double[] normalizeVector(ReadableVector readableVector, double d) {
            final int length = readableVector.getDim();
            double total = 0.0;
            for (int k = 0; k < length; ++k) {
                total += readableVector.get(k);
            }
            final double factor = d / total;
            final double[] scaled = new double[length];
            for (int k = 0; k < length; ++k) {
                scaled[k] = readableVector.get(k) * factor;
            }
            return scaled;
        }

        /** Draws each momentum element as a standard normal times sqrt(1 / inverseMass_i). */
        @Override
        public WrappedVector drawInitialMomentum() {
            final double[] momentum = new double[dim];
            for (int k = 0; k < dim; ++k) {
                final double scale = Math.sqrt(1.0 / inverseMass.getParameterValue(k));
                momentum[k] = MathUtils.nextGaussian() * scale;
            }
            return new WrappedVector.Raw(momentum);
        }

        @Override
        public double getVelocity(int n, ReadableVector readableVector) {
            return readableVector.get(n) * inverseMass.getParameterValue(n);
        }

        /**
         * Returns a copy of the momentum with the two indexed coordinates updated by a
         * 1-D elastic-collision rule, with mass_i = 1 / inverseMass_i.
         */
        @Override
        public ReadableVector doCollision(int[] nArray, ReadableVector readableVector) {
            if (nArray.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final int size = readableVector.getDim();
            final WrappedVector.Raw result = new WrappedVector.Raw(new double[size]);
            for (int k = 0; k < size; ++k) {
                result.set(k, readableVector.get(k));
            }
            final int first = nArray[0];
            final int second = nArray[1];
            final double invMassFirst = inverseMass.getParameterValue(first);
            final double invMassSecond = inverseMass.getParameterValue(second);
            final double momentumFirst = readableVector.get(first);
            final double momentumSecond = readableVector.get(second);
            final double invMassTotal = invMassFirst + invMassSecond;
            final double newFirst = ((invMassSecond - invMassFirst) * momentumFirst + 2.0 * invMassSecond * momentumSecond) / invMassTotal;
            final double newSecond = ((invMassFirst - invMassSecond) * momentumSecond + 2.0 * invMassFirst * momentumFirst) / invMassTotal;
            result.set(first, newFirst);
            result.set(second, newSecond);
            return result;
        }
    }

    /**
     * Base class for preconditioners that derive their mass from a Hessian provider.
     * Calls {@link #initializeMass()} once the provider has been stored.
     */
    public static abstract class HessianBased
    extends AbstractMassPreconditioning {
        protected final HessianWrtParameterProvider hessian;

        /** Uses the provider's own dimension. */
        HessianBased(HessianWrtParameterProvider provider, Transform transform) {
            this(provider, transform, provider.getDimension());
        }

        HessianBased(HessianWrtParameterProvider provider, Transform transform, int n) {
            super(n, transform);
            hessian = provider;
            initializeMass();
        }

        @Override
        public ReadableVector doCollision(int[] indices, ReadableVector momentum) {
            throw new RuntimeException("Not yet implemented.");
        }
    }

    /**
     * Common base for mass preconditioners modeled as a BEAST {@link AbstractModel}.
     * Holds the dimension, an optional transform and the inverse-mass parameter;
     * {@link #updateMass()} delegates to {@link #computeInverseMass()}.
     */
    public static abstract class AbstractMassPreconditioning
    extends AbstractModel
    implements MassPreconditioner {
        protected final int dim;
        protected final Transform transform;
        protected Parameter inverseMass;
        private static final String PRECONDITIONING = "MassPreconditioning";
        protected static final String MASSNAME = "InverseMass";

        protected AbstractMassPreconditioning(int n, Transform transform) {
            super(PRECONDITIONING);
            this.dim = n;
            this.transform = transform;
        }

        /** Sets the starting inverse mass. */
        protected abstract void initializeMass();

        /** Recomputes the inverse mass from the subclass's current state. */
        protected abstract void computeInverseMass();

        @Override
        public void updateMass() {
            computeInverseMass();
        }

        @Override
        public int getDimension() {
            return dim;
        }

        @Override
        public abstract void storeSecant(ReadableVector var1, ReadableVector var2);

        /** Copies {@code dArray} element by element into the inverse-mass parameter. */
        protected void setInverseMassFromArray(double[] dArray) {
            int index = 0;
            for (double value : dArray) {
                inverseMass.setParameterValue(index++, value);
            }
        }

        // The model-lifecycle hooks below are intentionally no-ops.

        @Override
        protected void handleModelChangedEvent(Model model, Object object, int n) {
        }

        @Override
        protected void handleVariableChangedEvent(Variable variable, int n, Variable.ChangeType changeType) {
        }

        @Override
        protected void storeState() {
        }

        @Override
        protected void restoreState() {
        }

        @Override
        protected void acceptState() {
        }
    }

    /** Identity preconditioner: unit mass in every coordinate, no adaptation. */
    public static class NoPreconditioning
    implements MassPreconditioner {
        final int dim;

        NoPreconditioning(int n) {
            dim = n;
        }

        /** Draws {@code dim} independent standard-normal values. */
        @Override
        public WrappedVector drawInitialMomentum() {
            final double[] momentum = new double[dim];
            for (int k = 0; k < momentum.length; ++k) {
                momentum[k] = MathUtils.nextGaussian();
            }
            return new WrappedVector.Raw(momentum);
        }

        /** With unit mass the velocity equals the input element. */
        @Override
        public double getVelocity(int n, ReadableVector readableVector) {
            return readableVector.get(n);
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
        }

        @Override
        public void updateMass() {
        }

        /** Returns a copy of the input with the two indexed coordinates swapped. */
        @Override
        public ReadableVector doCollision(int[] nArray, ReadableVector readableVector) {
            if (nArray.length != 2) {
                throw new RuntimeException("Not implemented for more than two dimensions yet.");
            }
            final int size = readableVector.getDim();
            final WrappedVector.Raw swapped = new WrappedVector.Raw(new double[size]);
            for (int k = 0; k < size; ++k) {
                swapped.set(k, readableVector.get(k));
            }
            final int first = nArray[0];
            final int second = nArray[1];
            swapped.set(first, readableVector.get(second));
            swapped.set(second, readableVector.get(first));
            return swapped;
        }

        /** Unit mass: a vector of {@code dim} ones. */
        @Override
        public WrappedVector getMass() {
            final double[] ones = new double[dim];
            Arrays.fill(ones, 1.0);
            return new WrappedVector.Raw(ones);
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
        }

        @Override
        public int getDimension() {
            return dim;
        }
    }

    public static class CompoundPreconditioning
    implements MassPreconditioner {
        final int dim;
        final List<MassPreconditioner> preconditionerList;
        boolean velocityKnown = false;
        double[] velocity;

        CompoundPreconditioning(List<MassPreconditioner> list) {
            int n = 0;
            for (MassPreconditioner massPreconditioner : list) {
                n += massPreconditioner.getDimension();
            }
            this.dim = n;
            this.preconditionerList = list;
            this.velocity = new double[this.dim];
        }

        @Override
        public WrappedVector drawInitialMomentum() {
            WrappedVector.Raw raw = new WrappedVector.Raw(new double[this.dim]);
            int n = 0;
            for (MassPreconditioner massPreconditioner : this.preconditionerList) {
                WrappedVector wrappedVector = massPreconditioner.drawInitialMomentum();
                for (int i = 0; i < massPreconditioner.getDimension(); ++i) {
                    raw.set(n + i, wrappedVector.get(i));
                }
                n += massPreconditioner.getDimension();
            }
            return raw;
        }

        @Override
        public double getVelocity(int n, ReadableVector readableVector) {
            this.getVelocityVector(readableVector);
            return this.velocity[n];
        }

        private void getVelocityVector(ReadableVector readableVector) {
            if (!this.velocityKnown) {
                int n = 0;
                List<ReadableVector> list = this.separateVectors(readableVector);
                for (int i = 0; i < this.preconditionerList.size(); ++i) {
                    MassPreconditioner massPreconditioner = this.preconditionerList.get(i);
                    ReadableVector readableVector2 = list.get(i);
                    for (int j = 0; j < massPreconditioner.getDimension(); ++j) {
                        this.velocity[n + j] = massPreconditioner.getVelocity(j, readableVector2);
                    }
                    n += massPreconditioner.getDimension();
                }
                this.velocityKnown = true;
            }
        }

        private List<ReadableVector> separateVectors(ReadableVector readableVector) {
            ArrayList<ReadableVector> arrayList = new ArrayList<ReadableVector>();
            int n = 0;
            for (MassPreconditioner massPreconditioner : this.preconditionerList) {
                WrappedVector.Raw raw = new WrappedVector.Raw(new double[massPreconditioner.getDimension()]);
                for (int i = 0; i < massPreconditioner.getDimension(); ++i) {
                    raw.set(i, readableVector.get(n + i));
                }
                arrayList.add(raw);
                n += massPreconditioner.getDimension();
            }
            return arrayList;
        }

        private ReadableVector combineVectors(List<ReadableVector> list) {
            WrappedVector.Raw raw = new WrappedVector.Raw(new double[this.dim]);
            int n = 0;
            for (ReadableVector readableVector : list) {
                for (int i = 0; i < readableVector.getDim(); ++i) {
                    raw.set(n + i, readableVector.get(i));
                }
            }
            return raw;
        }

        @Override
        public void storeSecant(ReadableVector readableVector, ReadableVector readableVector2) {
            List<ReadableVector> list = this.separateVectors(readableVector);
            List<ReadableVector> list2 = this.separateVectors(readableVector2);
            for (int i = 0; i < this.preconditionerList.size(); ++i) {
                this.preconditionerList.get(i).storeSecant(list.get(i), list2.get(i));
            }
        }

        @Override
        public void updateMass() {
            for (MassPreconditioner massPreconditioner : this.preconditionerList) {
                massPreconditioner.updateMass();
            }
            this.velocityKnown = false;
        }

        @Override
        public ReadableVector doCollision(int[] nArray, ReadableVector readableVector) {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public WrappedVector getMass() {
            throw new RuntimeException("Not yet implemented!");
        }

        @Override
        public void updateVariance(WrappedVector wrappedVector) {
        }

        @Override
        public int getDimension() {
            return this.dim;
        }
    }

    /**
     * The available mass-preconditioning schemes. Each constant carries the name used to select it
     * and builds the matching {@link MassPreconditioner} through {@link #factory}.
     */
    public static enum Type {
        NONE("none") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                int dimension = gradientWrtParameterProvider.getParameter().getDimension();
                // A multivariable transform defines its own dimension, which takes precedence.
                if (transform instanceof Transform.MultivariableTransform) {
                    dimension = ((Transform.MultivariableTransform) transform).getDimension();
                }
                return new NoPreconditioning(dimension);
            }
        },

        DIAGONAL("diagonal") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                // The provider must also supply Hessian information; otherwise this cast fails.
                HessianWrtParameterProvider hessianProvider = (HessianWrtParameterProvider) gradientWrtParameterProvider;
                return new DiagonalHessianPreconditioning(
                        hessianProvider,
                        transform,
                        massPreconditioningOptions.preconditioningMemory(),
                        massPreconditioningOptions.preconditioningEigenLowerBound(),
                        massPreconditioningOptions.preconditioningEigenUpperBound());
            }
        },

        ADAPTIVE_DIAGONAL("adaptiveDiagonal") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                final int dimension;
                if (transform instanceof Transform.MultivariableTransform) {
                    dimension = ((Transform.MultivariableTransform) transform).getDimension();
                } else {
                    dimension = gradientWrtParameterProvider.getDimension();
                }
                return new AdaptiveDiagonalPreconditioning(dimension, gradientWrtParameterProvider, transform, massPreconditioningOptions);
            }
        },

        PRIOR_DIAGONAL("priorDiagonal") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                if (gradientWrtParameterProvider instanceof PriorPreconditioningProvider) {
                    PriorPreconditioningProvider priorProvider = (PriorPreconditioningProvider) ((Object) gradientWrtParameterProvider);
                    return new PriorPreconditioner(priorProvider, transform);
                }
                throw new RuntimeException("Gradient must be a PriorPreconditioningProvider for prior preconditioning!");
            }
        },

        FULL("full") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                HessianWrtParameterProvider hessianProvider = (HessianWrtParameterProvider) gradientWrtParameterProvider;
                return new FullHessianPreconditioning(hessianProvider, transform);
            }
        },

        SECANT("secant") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                return new Secant(
                        new SecantHessian(gradientWrtParameterProvider, massPreconditioningOptions.preconditioningMemory()),
                        transform);
            }
        },

        ADAPTIVE("adaptive") {
            @Override
            public MassPreconditioner factory(GradientWrtParameterProvider gradientWrtParameterProvider, Transform transform, MassPreconditioningOptions massPreconditioningOptions) {
                // NOTE(review): unlike NONE and ADAPTIVE_DIAGONAL, the dimension here ignores a
                // MultivariableTransform — confirm this is intended.
                AdaptableCovariance covariance = new AdaptableCovariance(gradientWrtParameterProvider.getDimension());
                return new AdaptiveFullHessianPreconditioning(
                        gradientWrtParameterProvider,
                        covariance,
                        transform,
                        gradientWrtParameterProvider.getDimension(),
                        massPreconditioningOptions.preconditioningDelay());
            }
        };

        // Name used to select this scheme in parseFromString.
        private final String name;

        private Type(String name) {
            this.name = name;
        }

        /**
         * Builds the preconditioner for this scheme.
         *
         * @param var1 gradient provider for the parameter being sampled
         * @param var2 transform applied to the parameter (may be {@code null})
         * @param var3 tuning options for the preconditioner
         * @return a new preconditioner instance
         */
        public abstract MassPreconditioner factory(GradientWrtParameterProvider var1, Transform var2, MassPreconditioningOptions var3);

        /** @return the name used to select this scheme */
        public String getName() {
            return this.name;
        }

        /**
         * Finds the scheme whose name matches {@code string}, ignoring case.
         * Unrecognised names fall back to {@link #NONE}. A {@code null} argument causes a
         * {@link NullPointerException} from {@code compareToIgnoreCase}.
         */
        public static Type parseFromString(String string) {
            final Type[] candidates = Type.values();
            for (int i = 0; i < candidates.length; ++i) {
                final Type candidate = candidates[i];
                if (candidate.name.toLowerCase().compareToIgnoreCase(string) == 0) {
                    return candidate;
                }
            }
            return NONE;
        }
    }
}

