/*
 * Decompiled with CFR 0.152.
 */
package dr.evomodel.continuous;

import dr.evolution.tree.BranchRates;
import dr.evolution.tree.NodeRef;
import dr.evolution.tree.Tree;
import dr.evolution.tree.TreeTrait;
import dr.evomodel.continuous.ContinuousDiffusionStatistic;
import dr.evomodel.tree.TreeStatistic;
import dr.evomodel.treedatalikelihood.TreeDataLikelihood;
import dr.evomodel.treedatalikelihood.continuous.ContinuousDataLikelihoodDelegate;
import dr.xml.AbstractXMLObjectParser;
import dr.xml.AttributeRule;
import dr.xml.ElementRule;
import dr.xml.XMLObject;
import dr.xml.XMLObjectParser;
import dr.xml.XMLParseException;
import dr.xml.XMLSyntaxRule;

public class TreeDataContinuousDiffusionStatistic
extends TreeStatistic {
    public static final String CONTINUOUS_DIFFUSION_STATISTIC = "traitDataContinuousDiffusionStatistic";
    private final TreeTrait.DA trait;
    private final Tree tree;
    private final BranchRates branchRates;
    private final WeightingScheme weightingScheme;
    private final DisplacementScheme displacementScheme;
    private final ScalingScheme scalingScheme;
    private static final String WEIGHTING_SCHEME = "weightingScheme";
    private static final String BRANCH_RATE_SCHEME = "scalingScheme";
    private static final String DISPLACEMENT_SCHEME = "displacementScheme";
    public static XMLObjectParser PARSER = new AbstractXMLObjectParser(){
        private XMLSyntaxRule[] rules = new XMLSyntaxRule[]{AttributeRule.newStringRule("name", true), AttributeRule.newStringRule("traitName"), new ElementRule(TreeDataLikelihood.class), AttributeRule.newStringRule("weightingScheme", true), AttributeRule.newStringRule("displacementScheme", true)};

        @Override
        public String getParserName() {
            return TreeDataContinuousDiffusionStatistic.CONTINUOUS_DIFFUSION_STATISTIC;
        }

        @Override
        public Object parseXMLObject(XMLObject xMLObject) throws XMLParseException {
            TreeDataLikelihood treeDataLikelihood = (TreeDataLikelihood)xMLObject.getChild(TreeDataLikelihood.class);
            if (!(treeDataLikelihood.getDataLikelihoodDelegate() instanceof ContinuousDataLikelihoodDelegate)) {
                throw new XMLParseException("Must provide a continuous trait data likelihood");
            }
            String string = xMLObject.getAttribute("name", xMLObject.getId());
            String string2 = xMLObject.getStringAttribute("traitName");
            TreeTrait.DA dA = (TreeTrait.DA)treeDataLikelihood.getTreeTrait(string2);
            if (dA == null) {
                throw new XMLParseException("Not trait `" + string2 + "' in likelihood `" + treeDataLikelihood.getId() + "`");
            }
            WeightingScheme weightingScheme = this.parseWeightingScheme(xMLObject);
            DisplacementScheme displacementScheme = this.parseDisplacementScheme(xMLObject);
            ScalingScheme scalingScheme = this.parseScalingScheme(xMLObject);
            return new TreeDataContinuousDiffusionStatistic(string, dA, treeDataLikelihood, weightingScheme, displacementScheme, scalingScheme);
        }

        @Override
        public String getParserDescription() {
            return "A statistic that returns the average of the branch diffusion rates";
        }

        @Override
        public Class getReturnType() {
            return TreeStatistic.class;
        }

        @Override
        public XMLSyntaxRule[] getSyntaxRules() {
            return this.rules;
        }

        WeightingScheme parseWeightingScheme(XMLObject xMLObject) throws XMLParseException {
            String string = xMLObject.getAttribute(TreeDataContinuousDiffusionStatistic.WEIGHTING_SCHEME, WeightingScheme.WEIGHTED.getName());
            for (WeightingScheme weightingScheme : WeightingScheme.values()) {
                if (string.compareToIgnoreCase(weightingScheme.getName()) != 0) continue;
                return weightingScheme;
            }
            throw new XMLParseException("Unknown weighting scheme '" + string + "'");
        }

        DisplacementScheme parseDisplacementScheme(XMLObject xMLObject) throws XMLParseException {
            String string = xMLObject.getAttribute(TreeDataContinuousDiffusionStatistic.DISPLACEMENT_SCHEME, DisplacementScheme.QUADRATIC.getName());
            for (DisplacementScheme displacementScheme : DisplacementScheme.values()) {
                if (string.compareToIgnoreCase(displacementScheme.getName()) != 0) continue;
                return displacementScheme;
            }
            throw new XMLParseException("Unknown displacement scheme '" + string + "'");
        }

        ScalingScheme parseScalingScheme(XMLObject xMLObject) throws XMLParseException {
            String string = xMLObject.getAttribute(TreeDataContinuousDiffusionStatistic.BRANCH_RATE_SCHEME, ScalingScheme.RATE_DEPENDENT.getName());
            for (ScalingScheme scalingScheme : ScalingScheme.values()) {
                if (string.compareToIgnoreCase(scalingScheme.getName()) != 0) continue;
                return scalingScheme;
            }
            throw new XMLParseException("Unknown scaling scheme '" + string + "'");
        }
    };

    private TreeDataContinuousDiffusionStatistic(String string, TreeTrait.DA dA, TreeDataLikelihood treeDataLikelihood, WeightingScheme weightingScheme, DisplacementScheme displacementScheme, ScalingScheme scalingScheme) {
        super(string);
        this.trait = dA;
        this.tree = treeDataLikelihood.getTree();
        this.branchRates = treeDataLikelihood.getBranchRateModel();
        this.weightingScheme = weightingScheme;
        this.displacementScheme = displacementScheme;
        this.scalingScheme = scalingScheme;
    }

    @Override
    public void setTree(Tree tree) {
        throw new RuntimeException("Cannot set the tree");
    }

    @Override
    public Tree getTree() {
        return this.tree;
    }

    @Override
    public int getDimension() {
        return 1;
    }

    @Override
    public double getStatisticValue(int n) {
        Statistic statistic = new Statistic();
        for (int i = 0; i < this.tree.getNodeCount(); ++i) {
            NodeRef nodeRef = this.tree.getNode(i);
            if (nodeRef == this.tree.getRoot()) continue;
            this.addBranchStatistic(statistic, nodeRef);
        }
        return statistic.numerator / statistic.denominator;
    }

    private void addBranchStatistic(Statistic statistic, NodeRef nodeRef) {
        NodeRef nodeRef2 = this.tree.getParent(nodeRef);
        double[] dArray = (double[])this.trait.getTrait(this.tree, nodeRef2);
        double[] dArray2 = (double[])this.trait.getTrait(this.tree, nodeRef);
        double d = this.displacementScheme.displace(dArray, dArray2);
        double d2 = this.tree.getNodeHeight(nodeRef2) - this.tree.getNodeHeight(nodeRef);
        double d3 = d2 * this.scalingScheme.scale(this.branchRates, this.tree, nodeRef);
        this.weightingScheme.add(statistic, d, d3);
    }

    private static double distance(double[] dArray, double[] dArray2) {
        assert (dArray.length == dArray2.length);
        double d = 0.0;
        for (int i = 0; i < dArray.length; ++i) {
            d += (dArray[i] - dArray2[i]) * (dArray[i] - dArray2[i]);
        }
        return d;
    }

    private static enum WeightingScheme {
        WEIGHTED{

            @Override
            void add(Statistic statistic, double d, double d2) {
                statistic.numerator += d;
                statistic.denominator += d2;
            }

            @Override
            String getName() {
                return "weighted";
            }
        }
        ,
        UNWEIGHTED{

            @Override
            void add(Statistic statistic, double d, double d2) {
                statistic.numerator += d / d2;
                statistic.denominator += 1.0;
            }

            @Override
            String getName() {
                return "unweighted";
            }
        };


        abstract void add(Statistic var1, double var2, double var4);

        abstract String getName();
    }

    private static enum DisplacementScheme {
        LINEAR{

            @Override
            double displace(double[] dArray, double[] dArray2) {
                return Math.sqrt(TreeDataContinuousDiffusionStatistic.distance(dArray, dArray2));
            }

            @Override
            String getName() {
                return "linear";
            }
        }
        ,
        QUADRATIC{

            @Override
            double displace(double[] dArray, double[] dArray2) {
                return TreeDataContinuousDiffusionStatistic.distance(dArray, dArray2);
            }

            @Override
            String getName() {
                return "quadratic";
            }
        }
        ,
        GREAT_CIRCLE_DISTANCE{

            @Override
            double displace(double[] dArray, double[] dArray2) {
                if (dArray.length == 2 && dArray2.length == 2) {
                    return ContinuousDiffusionStatistic.getGreatCircleDistance(dArray, dArray2);
                }
                return LINEAR.displace(dArray, dArray2);
            }

            @Override
            String getName() {
                return "greatCircleDistance";
            }
        };


        abstract String getName();

        abstract double displace(double[] var1, double[] var2);
    }

    private static enum ScalingScheme {
        RATE_DEPENDENT{

            @Override
            double scale(BranchRates branchRates, Tree tree, NodeRef nodeRef) {
                return 1.0;
            }

            @Override
            String getName() {
                return "dependent";
            }
        }
        ,
        RATE_INDEPENDENT{

            @Override
            double scale(BranchRates branchRates, Tree tree, NodeRef nodeRef) {
                return branchRates.getBranchRate(tree, nodeRef);
            }

            @Override
            String getName() {
                return "independent";
            }
        };


        abstract double scale(BranchRates var1, Tree var2, NodeRef var3);

        abstract String getName();
    }

    private class Statistic {
        double numerator = 0.0;
        double denominator = 0.0;

        Statistic() {
        }
    }
}

