/*
 * Decompiled with CFR 0.152.
 */
package org.tribuo.regression.evaluation;

import java.util.function.BiFunction;
import java.util.function.ToDoubleBiFunction;
import org.tribuo.evaluation.metrics.EvaluationMetric;
import org.tribuo.evaluation.metrics.MetricTarget;
import org.tribuo.regression.Regressor;
import org.tribuo.regression.evaluation.RegressionMetric;
import org.tribuo.regression.evaluation.RegressionSufficientStatistics;
import org.tribuo.util.Util;

/**
 * The regression metrics supported by this package: R^2, root mean squared error (RMSE),
 * mean absolute error (MAE) and explained variance (EV).
 * <p>
 * Each constant wraps a function which computes the metric for a {@link MetricTarget} using the
 * {@link RegressionSufficientStatistics} returned by {@code context.getMemo()}.
 */
public enum RegressionMetrics {
    /** Coefficient of determination, see {@link #r2(MetricTarget, RegressionSufficientStatistics)}. */
    R2((target, context) -> RegressionMetrics.r2(target, context.getMemo())),
    /** Root mean squared error, see {@link #rmse(MetricTarget, RegressionSufficientStatistics)}. */
    RMSE((target, context) -> RegressionMetrics.rmse(target, context.getMemo())),
    /** Mean absolute error, see {@link #mae(MetricTarget, RegressionSufficientStatistics)}. */
    MAE((target, context) -> RegressionMetrics.mae(target, context.getMemo())),
    /** Explained variance, see {@link #explainedVariance(MetricTarget, RegressionSufficientStatistics)}. */
    EV((target, context) -> RegressionMetrics.explainedVariance(target, context.getMemo()));

    private final ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl;

    // Enum constructors are implicitly private.
    RegressionMetrics(ToDoubleBiFunction<MetricTarget<Regressor>, RegressionMetric.Context> impl) {
        this.impl = impl;
    }

    /**
     * Creates a {@link RegressionMetric} for the supplied target, named after this constant and
     * backed by this constant's implementation function.
     *
     * @param target The metric target (a single output dimension or an average).
     * @return A regression metric for that target.
     */
    RegressionMetric forTarget(MetricTarget<Regressor> target) {
        return new RegressionMetric(target, name(), impl);
    }

    /**
     * Computes R^2 for the supplied target, either for a single output dimension or as a
     * macro average over every dimension in the domain.
     *
     * @param target The metric target.
     * @param sufficientStats The sufficient statistics.
     * @return The R^2 value.
     * @throws IllegalStateException If the target requests a micro average, or has neither an
     *                               output nor an average target.
     */
    public static double r2(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
        return compute(target, sufficientStats, RegressionMetrics::r2);
    }

    /**
     * Computes R^2 for the first named dimension of {@code variable}.
     * <p>
     * This is {@code 1 - SSE / SST}. SSE is the stored sum of squared errors for the dimension.
     * SST is the weighted sum of squared deviations of the true values from their weighted mean.
     * If the true values have zero weighted variance, SST is zero and the result is non-finite
     * (negative infinity, or NaN when SSE is also zero).
     *
     * @param variable The regressor whose first dimension name selects the column.
     * @param sufficientStats The sufficient statistics.
     * @return The R^2 value.
     */
    public static double r2(Regressor variable, RegressionSufficientStatistics sufficientStats) {
        String varname = variable.getNames()[0];
        double[] trueArray = sufficientStats.trueValues.get(varname);
        double sumSquaredError = sufficientStats.sumSquaredError.get(varname).doubleValue();
        double meanTruth = Util.weightedMean(trueArray, sufficientStats.weights, sufficientStats.n);
        double totalSumOfSquares = weightedSumOfSquaredDeviations(trueArray, meanTruth, sufficientStats.weights, sufficientStats.n);
        return 1.0 - sumSquaredError / totalSumOfSquares;
    }

    /**
     * Computes the root mean squared error for the supplied target, either for a single output
     * dimension or as a macro average over every dimension in the domain.
     *
     * @param target The metric target.
     * @param sufficientStats The sufficient statistics.
     * @return The RMSE.
     * @throws IllegalStateException If the target requests a micro average, or has neither an
     *                               output nor an average target.
     */
    public static double rmse(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
        return compute(target, sufficientStats, RegressionMetrics::rmse);
    }

    /**
     * Computes the root mean squared error for the first named dimension of {@code variable}.
     * <p>
     * This is the square root of the stored sum of squared errors divided by the total example
     * weight ({@code weightSum}).
     *
     * @param variable The regressor whose first dimension name selects the column.
     * @param sufficientStats The sufficient statistics.
     * @return The RMSE.
     */
    public static double rmse(Regressor variable, RegressionSufficientStatistics sufficientStats) {
        String varname = variable.getNames()[0];
        double sumSqErr = sufficientStats.sumSquaredError.get(varname).doubleValue();
        return Math.sqrt(sumSqErr / sufficientStats.weightSum);
    }

    /**
     * Computes the mean absolute error for the supplied target, either for a single output
     * dimension or as a macro average over every dimension in the domain.
     *
     * @param target The metric target.
     * @param sufficientStats The sufficient statistics.
     * @return The MAE.
     * @throws IllegalStateException If the target requests a micro average, or has neither an
     *                               output nor an average target.
     */
    public static double mae(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
        return compute(target, sufficientStats, RegressionMetrics::mae);
    }

    /**
     * Computes the mean absolute error for the first named dimension of {@code variable}.
     * <p>
     * This is the stored sum of absolute errors divided by the total example weight
     * ({@code weightSum}).
     *
     * @param variable The regressor whose first dimension name selects the column.
     * @param sufficientStats The sufficient statistics.
     * @return The MAE.
     */
    public static double mae(Regressor variable, RegressionSufficientStatistics sufficientStats) {
        String varname = variable.getNames()[0];
        double sumAbsErr = sufficientStats.sumAbsoluteError.get(varname).doubleValue();
        return sumAbsErr / sufficientStats.weightSum;
    }

    /**
     * Computes the explained variance for the supplied target, either for a single output
     * dimension or as a macro average over every dimension in the domain.
     *
     * @param target The metric target.
     * @param sufficientStats The sufficient statistics.
     * @return The explained variance.
     * @throws IllegalStateException If the target requests a micro average, or has neither an
     *                               output nor an average target.
     */
    public static double explainedVariance(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats) {
        return compute(target, sufficientStats, RegressionMetrics::explainedVariance);
    }

    /**
     * Computes the explained variance for the first named dimension of {@code variable}.
     * <p>
     * This is {@code 1 - Var(residual) / Var(truth)}, where both variances are weighted. The
     * normalisation constants cancel, so both are computed as weighted sums of squared
     * deviations from the respective weighted mean. As with R^2, the result is non-finite if
     * the true values have zero weighted variance.
     *
     * @param variable The regressor whose first dimension name selects the column.
     * @param sufficientStats The sufficient statistics.
     * @return The explained variance.
     */
    public static double explainedVariance(Regressor variable, RegressionSufficientStatistics sufficientStats) {
        String varname = variable.getNames()[0];
        double[] trueArray = sufficientStats.trueValues.get(varname);
        double[] predictedArray = sufficientStats.predictedValues.get(varname);
        float[] weights = sufficientStats.weights;
        int n = sufficientStats.n;

        // Weighted mean of the residuals (truth - prediction).
        double meanResidual = 0.0;
        for (int i = 0; i < n; i++) {
            meanResidual += weights[i] * (trueArray[i] - predictedArray[i]);
        }
        meanResidual /= sufficientStats.weightSum;

        // Weighted sum of squared deviations of the residuals from their mean.
        double residualSumOfSquares = 0.0;
        for (int i = 0; i < n; i++) {
            double deviation = trueArray[i] - predictedArray[i] - meanResidual;
            residualSumOfSquares += weights[i] * deviation * deviation;
        }

        double meanTruth = Util.weightedMean(trueArray, weights, n);
        double totalSumOfSquares = weightedSumOfSquaredDeviations(trueArray, meanTruth, weights, n);
        return 1.0 - residualSumOfSquares / totalSumOfSquares;
    }

    /**
     * Computes {@code sum_i weights[i] * (values[i] - mean)^2} over the first {@code n} elements.
     * The weight is multiplied first, so the float weight is widened to double before the
     * squared terms are applied.
     *
     * @param values The values.
     * @param mean The centre to measure deviations from.
     * @param weights The per-element weights.
     * @param n The number of elements to use.
     * @return The weighted sum of squared deviations.
     */
    private static double weightedSumOfSquaredDeviations(double[] values, double mean, float[] weights, int n) {
        double sum = 0.0;
        for (int i = 0; i < n; i++) {
            double difference = values[i] - mean;
            sum += weights[i] * difference * difference;
        }
        return sum;
    }

    /**
     * Dispatches a per-dimension metric according to the target.
     * <ul>
     *     <li>If the target names a single output, the metric is computed for that output.</li>
     *     <li>If the target is a {@code MACRO} average, the metric is computed for every regressor
     *     in the domain and the unweighted mean is returned. An empty domain yields NaN.</li>
     *     <li>{@code MICRO} averages are not supported.</li>
     * </ul>
     *
     * @param target The metric target.
     * @param sufficientStats The sufficient statistics.
     * @param impl The per-dimension metric implementation.
     * @return The metric value.
     * @throws IllegalStateException For a micro average, an unknown average type, or a target
     *                               with neither an output nor an average target.
     */
    private static double compute(MetricTarget<Regressor> target, RegressionSufficientStatistics sufficientStats, BiFunction<Regressor, RegressionSufficientStatistics, Double> impl) {
        if (target.getOutputTarget().isPresent()) {
            return impl.apply(target.getOutputTarget().get(), sufficientStats);
        }
        if (target.getAverageTarget().isPresent()) {
            EvaluationMetric.Average averageType = target.getAverageTarget().get();
            switch (averageType) {
                case MACRO: {
                    double accumulator = 0.0;
                    for (Regressor r : sufficientStats.domain.getDomain()) {
                        accumulator += impl.apply(r, sufficientStats);
                    }
                    return accumulator / sufficientStats.domain.size();
                }
                case MICRO:
                    throw new IllegalStateException("Micro averages are not supported for regression metrics.");
                default:
                    throw new IllegalStateException("Unexpected average type " + averageType);
            }
        }
        throw new IllegalStateException("MetricTarget without target.");
    }
}

