/*
 * Decompiled with CFR 0.152.
 */
package dr.inference.hmc;

import dr.inference.hmc.GradientWrtParameterProvider;
import dr.inference.hmc.HessianWrtParameterProvider;
import dr.inference.model.Likelihood;
import dr.inference.model.Parameter;
import dr.inference.operators.PathDependent;
import java.util.Arrays;

/**
 * Gradient provider that linearly blends two underlying gradient providers along a path.
 *
 * <p>For the path parameter {@code beta} (set via {@link #setPathParameter(double)}), every
 * blended quantity is {@code beta * source + (1 - beta) * destination}. With the default
 * {@code beta = 1.0}, only the source provider is consulted.
 *
 * <p>Both providers must have the same dimension and, at construction time, identical parameter
 * values. The parameter reported by {@link #getParameter()} is the source provider's.
 */
public class PathGradient
implements HessianWrtParameterProvider,
PathDependent {
    private final int dimension;
    private final Likelihood likelihood;
    private final Parameter parameter;
    private final GradientWrtParameterProvider source;
    private final GradientWrtParameterProvider destination;
    /** Blending weight on the source provider; 1.0 means "source only". */
    private double beta = 1.0;

    /**
     * Creates a path between two gradient providers.
     *
     * @param source      provider whose quantities are returned unchanged when beta == 1
     * @param destination provider whose quantities receive weight (1 - beta)
     * @throws RuntimeException if the providers differ in dimension or in current parameter values
     */
    public PathGradient(final GradientWrtParameterProvider source, final GradientWrtParameterProvider destination) {
        this.source = source;
        this.destination = destination;
        this.dimension = source.getDimension();
        this.parameter = source.getParameter();
        if (destination.getDimension() != this.dimension) {
            throw new RuntimeException("Unequal parameter dimensions");
        }
        // Only the values are compared here; the two Parameter objects may be distinct instances.
        if (!Arrays.equals(destination.getParameter().getParameterValues(), this.parameter.getParameterValues())) {
            throw new RuntimeException("Unequal parameter values");
        }
        // NOTE(review): the wrapper is constructed with the source likelihood's model only;
        // confirm that changes to the destination's model also invalidate this likelihood.
        this.likelihood = new Likelihood.Abstract(source.getLikelihood().getModel()) {

            @Override
            protected double calculateLogLikelihood() {
                double logLikelihood = source.getLikelihood().getLogLikelihood();
                // Skip evaluating the destination entirely when it has no weight.
                if (PathGradient.this.beta != 1.0) {
                    logLikelihood = PathGradient.blend(logLikelihood,
                            destination.getLikelihood().getLogLikelihood(), PathGradient.this.beta);
                }
                return logLikelihood;
            }
        };
    }

    /**
     * Sets the path parameter (weight on the source provider).
     *
     * @param d new value of beta
     */
    @Override
    public void setPathParameter(double d) {
        this.beta = d;
        // The blended log-likelihood depends on beta, so any value cached by the wrapper is stale.
        this.likelihood.makeDirty();
    }

    /** @return the blended log-likelihood wrapper */
    @Override
    public Likelihood getLikelihood() {
        return this.likelihood;
    }

    /** @return the source provider's parameter */
    @Override
    public Parameter getParameter() {
        return this.parameter;
    }

    /** @return the shared parameter dimension */
    @Override
    public int getDimension() {
        return this.dimension;
    }

    /**
     * Returns the blended gradient of the log density.
     *
     * <p>When blending, the result is a newly allocated array, so neither provider's returned
     * array is modified.
     */
    @Override
    public double[] getGradientLogDensity() {
        double[] sourceGradient = this.source.getGradientLogDensity();
        if (this.beta == 1.0) {
            return sourceGradient;
        }
        return blendArrays(sourceGradient, this.destination.getGradientLogDensity(), this.beta);
    }

    /** Linear interpolation: weight {@code d3} on {@code d}, {@code 1 - d3} on {@code d2}. */
    private static double blend(double d, double d2, double d3) {
        return d3 * d + (1.0 - d3) * d2;
    }

    /** Element-wise {@link #blend} of two equal-length arrays into a new array. */
    private static double[] blendArrays(double[] sourceValues, double[] destinationValues, double weight) {
        double[] blended = new double[sourceValues.length];
        for (int i = 0; i < sourceValues.length; ++i) {
            blended[i] = blend(sourceValues[i], destinationValues[i], weight);
        }
        return blended;
    }

    /**
     * Returns the blended diagonal of the Hessian of the log density.
     *
     * @throws RuntimeException if either provider is not a {@link HessianWrtParameterProvider}
     */
    @Override
    public double[] getDiagonalHessianLogDensity() {
        if (!(this.source instanceof HessianWrtParameterProvider) || !(this.destination instanceof HessianWrtParameterProvider)) {
            throw new RuntimeException("Must use Hessian providers");
        }
        double[] sourceDiagonal = ((HessianWrtParameterProvider) this.source).getDiagonalHessianLogDensity();
        if (this.beta == 1.0) {
            return sourceDiagonal;
        }
        double[] destinationDiagonal = ((HessianWrtParameterProvider) this.destination).getDiagonalHessianLogDensity();
        return blendArrays(sourceDiagonal, destinationDiagonal, this.beta);
    }

    /**
     * The full Hessian is not supported.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public double[][] getHessianLogDensity() {
        throw new UnsupportedOperationException("Not yet implemented");
    }
}

