/*
 * Decompiled with CFR 0.152.
 */
package qupath.lib.analysis.stats.survival;

import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.stats.survival.KaplanMeierData;
import qupath.lib.common.GeneralTools;

/**
 * Static helper that compares two {@link KaplanMeierData} survival datasets with a
 * log-rank test and derives a hazard ratio with confidence bounds.
 *
 * <p>The outcome is returned as an immutable {@link LogRankResult}.
 */
public class LogRankTest {
    private static final Logger logger = LoggerFactory.getLogger(LogRankTest.class);

    /**
     * Reference distribution for the test statistic: chi-squared with one degree of freedom.
     * Only {@code cumulativeProbability} is called on it.
     */
    private static final ChiSquaredDistribution CHI_SQUARED = new ChiSquaredDistribution(1.0);

    /**
     * Multiplier applied to the standard error of the log hazard ratio when forming the
     * confidence bounds (1.96 is the normal quantile of a two-sided 95% interval).
     */
    private static final double CONFIDENCE_Z = 1.96;

    /**
     * Runs the log-rank comparison between two survival datasets.
     *
     * <p>The events of both datasets are pooled to obtain the time points to visit. At each
     * time point the observed events of each group are accumulated, together with the events
     * expected in each group if the pooled event proportion applied to both. The statistic
     * {@code (O1-E1)^2/E1 + (O2-E2)^2/E2} is then referred to a chi-squared distribution with
     * one degree of freedom.
     *
     * <p>The hazard ratio is {@code (O1/E1) / (O2/E2)}; its bounds are computed on the log
     * scale using a standard error of {@code sqrt(1/E1 + 1/E2)}.
     *
     * <p>No guard is applied to zero expected counts or zero observed events: the resulting
     * NaN or infinite values simply propagate into the returned result.
     *
     * @param km1 first survival dataset
     * @param km2 second survival dataset
     * @return the p-value and hazard ratio estimates, or an all-NaN (invalid) result if either
     *         dataset is empty
     */
    public static LogRankResult computeLogRankTest(KaplanMeierData km1, KaplanMeierData km2) {
        if (km1.isEmpty() || km2.isEmpty()) {
            return new LogRankResult();
        }

        // Pool both groups so that every time present in either dataset is visited.
        KaplanMeierData pooled = new KaplanMeierData("Temp", km1.getEvents()).addEvents(km2.getEvents());

        double observed1 = 0.0;
        double observed2 = 0.0;
        double expected1 = 0.0;
        double expected2 = 0.0;
        for (double time : pooled.getAllTimes()) {
            double atRisk1 = km1.getAtRisk(time);
            double atRisk2 = km2.getAtRisk(time);
            double events1 = km1.getEventsAtTime(time);
            double events2 = km2.getEventsAtTime(time);
            observed1 += events1;
            observed2 += events2;
            // Pooled proportion of at-risk subjects having an event at this time.
            double pooledProportion = (events1 + events2) / (atRisk1 + atRisk2);
            expected1 += pooledProportion * atRisk1;
            expected2 += pooledProportion * atRisk2;
        }

        double diff1 = observed1 - expected1;
        double diff2 = observed2 - expected2;
        double stat = diff1 * diff1 / expected1 + diff2 * diff2 / expected2;
        // Upper-tail probability; the magnitude is used so a negative statistic is treated as positive.
        double logRankPValue = 1.0 - CHI_SQUARED.cumulativeProbability(Math.abs(stat));

        double hazardRatio = observed1 / expected1 / (observed2 / expected2);
        double hazardSE = Math.sqrt(1.0 / expected1 + 1.0 / expected2);
        double hazardLog = Math.log(hazardRatio);
        double margin = CONFIDENCE_Z * hazardSE;
        double hazardLower = Math.exp(hazardLog - margin);
        double hazardUpper = Math.exp(hazardLog + margin);

        // Avoid the cost of String.format unless trace output is actually wanted.
        if (logger.isTraceEnabled()) {
            logger.trace(String.format("Log rank: %.4f\tHAZARD: %.3f (%.3f - %.3f)", logRankPValue, hazardRatio, hazardLower, hazardUpper));
        }
        return new LogRankResult(logRankPValue, hazardRatio, hazardLower, hazardUpper);
    }

    /**
     * Immutable holder for the outcome of {@link #computeLogRankTest(KaplanMeierData, KaplanMeierData)}.
     *
     * <p>A result whose p-value is NaN is considered invalid; see {@link #isValid()}.
     */
    public static class LogRankResult {
        final double pValue;
        final double hazardRatio;
        final double hazardRatioLowerConfidence;
        final double hazardRatioUpperConfidence;

        /**
         * Creates an invalid result in which every value is NaN.
         */
        LogRankResult() {
            this.pValue = Double.NaN;
            this.hazardRatio = Double.NaN;
            this.hazardRatioLowerConfidence = Double.NaN;
            this.hazardRatioUpperConfidence = Double.NaN;
        }

        /**
         * Creates a result from already-computed values.
         *
         * @param pValue p-value of the log-rank test
         * @param hazardRatio estimated hazard ratio
         * @param hazardRatioLowerConfidence lower confidence bound of the hazard ratio
         * @param hazardRatioUpperConfidence upper confidence bound of the hazard ratio
         */
        LogRankResult(double pValue, double hazardRatio, double hazardRatioLowerConfidence, double hazardRatioUpperConfidence) {
            this.pValue = pValue;
            this.hazardRatio = hazardRatio;
            this.hazardRatioLowerConfidence = hazardRatioLowerConfidence;
            this.hazardRatioUpperConfidence = hazardRatioUpperConfidence;
        }

        /**
         * Builds a one-line summary of the form {@code "p (hr; lower-upper)"}.
         *
         * <p>The p-value is formatted through {@code GeneralTools.formatNumber} with a maximum
         * number of decimal places that grows as the p-value shrinks; the hazard ratio and
         * its bounds are printed with two decimals.
         *
         * @return the summary, or {@code "-"} when the p-value is NaN
         */
        public String getResultString() {
            if (Double.isNaN(this.pValue)) {
                return "-";
            }
            String pValueString = GeneralTools.formatNumber(this.pValue, maxDecimalPlacesFor(this.pValue));
            return String.format("%s (%.2f; %.2f-%.2f)", pValueString, this.hazardRatio, this.hazardRatioLowerConfidence, this.hazardRatioUpperConfidence);
        }

        /**
         * Picks the maximum number of decimal places for displaying a p-value: 4 above 0.001,
         * one more for each further power of ten down to 1e-6, and 8 otherwise.
         */
        private static int maxDecimalPlacesFor(double p) {
            if (p > 0.001) {
                return 4;
            }
            if (p > 1.0E-4) {
                return 5;
            }
            if (p > 1.0E-5) {
                return 6;
            }
            if (p > 1.0E-6) {
                return 7;
            }
            return 8;
        }

        /**
         * @return {@code true} if the p-value is not NaN
         */
        public boolean isValid() {
            return !Double.isNaN(this.pValue);
        }

        /**
         * @return the p-value of the log-rank test, or NaN for an invalid result
         */
        public double getPValue() {
            return this.pValue;
        }

        /**
         * @return the estimated hazard ratio
         */
        public double getHazardRatio() {
            return this.hazardRatio;
        }

        /**
         * @return the lower confidence bound of the hazard ratio
         */
        public double getHazardRatioLowerConfidence() {
            return this.hazardRatioLowerConfidence;
        }

        /**
         * @return the upper confidence bound of the hazard ratio
         */
        public double getHazardRatioUpperConfidence() {
            return this.hazardRatioUpperConfidence;
        }
    }
}

