/*
 * Decompiled with CFR 0.152.
 */
package qupath.lib.color;

import java.awt.image.BufferedImage;
import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.util.Arrays;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.awt.common.BufferedImageTools;
import qupath.lib.color.ColorDeconvolutionStains;
import qupath.lib.color.ColorTransformer;
import qupath.lib.color.StainVector;
import qupath.lib.common.ColorTools;

/**
 * Static helper methods for color deconvolution and optical density (OD) calculations.
 * <p>
 * Optical densities are computed from {@code -log10(value / max)}, where {@code max} is the
 * per-channel intensity that maps to an OD of 0.
 */
public class ColorDeconvolutionHelper {
    private static final Logger logger = LoggerFactory.getLogger(ColorDeconvolutionHelper.class);

    /** Channel selector for {@link #getRgbChannelOpticalDensities}: red. */
    private static final int CHANNEL_RED = 0;
    /** Channel selector for {@link #getRgbChannelOpticalDensities}: green. */
    private static final int CHANNEL_GREEN = 1;
    /** Channel selector for {@link #getRgbChannelOpticalDensities}: blue. */
    private static final int CHANNEL_BLUE = 2;

    /**
     * Convert a single intensity value into an optical density.
     * <p>
     * {@code val} is clamped to at least 1 before conversion (avoiding {@code log10(0)}), and the
     * result is clamped to be non-negative, so any {@code val >= max} gives an OD of 0.
     *
     * @param val the intensity value
     * @param max the intensity that maps to an OD of 0
     * @return {@code max(0, -log10(max(val, 1) / max))}
     */
    public static double makeOD(double val, double max) {
        return Math.max(0.0, -Math.log10(Math.max(val, 1.0) / max));
    }

    /**
     * Look up an optical density in a precomputed table.
     *
     * @param val index into the table (typically an 8-bit intensity)
     * @param OD_LUT lookup table, e.g. from {@link #makeODLUT(double, int)}
     * @return {@code OD_LUT[val]}, or {@code NaN} if {@code val} is outside the table bounds
     */
    public static double makeODByLUT(int val, double[] OD_LUT) {
        if (val >= 0 && val < OD_LUT.length) {
            return OD_LUT[val];
        }
        return Double.NaN;
    }

    /**
     * Look up an optical density in a precomputed table, rounding {@code val} to the nearest
     * integer index first (using {@link Math#round(float)}).
     *
     * @param val the intensity value
     * @param OD_LUT lookup table, e.g. from {@link #makeODLUT(double, int)}
     * @return the table entry, or {@code NaN} if the rounded index is outside the table bounds
     */
    public static double makeODByLUT(float val, double[] OD_LUT) {
        return makeODByLUT(Math.round(val), OD_LUT);
    }

    /**
     * Create a 256-entry optical density lookup table.
     *
     * @param maxValue the intensity that maps to an OD of 0
     * @return table where entry {@code i} equals {@code makeOD(i, maxValue)}
     */
    public static double[] makeODLUT(double maxValue) {
        return makeODLUT(maxValue, 256);
    }

    /**
     * Create an optical density lookup table.
     *
     * @param maxValue the intensity that maps to an OD of 0
     * @param nValues number of table entries
     * @return table where entry {@code i} equals {@code makeOD(i, maxValue)}
     */
    public static double[] makeODLUT(double maxValue, int nValues) {
        double[] OD_LUT = new double[nValues];
        for (int i = 0; i < nValues; ++i) {
            OD_LUT[i] = makeOD(i, maxValue);
        }
        return OD_LUT;
    }

    /**
     * Convert intensities to optical densities in place.
     *
     * @param px intensities to convert; overwritten with optical densities
     * @param maxValue the intensity that maps to an OD of 0
     * @param use8BitLUT if true, use a 256-entry lookup table (faster, but values that round to
     *                   an index outside 0-255 become {@code NaN}); otherwise call
     *                   {@link #makeOD(double, double)} for every pixel
     */
    public static void convertPixelsToOpticalDensities(float[] px, double maxValue, boolean use8BitLUT) {
        if (use8BitLUT) {
            double[] odLUT = makeODLUT(maxValue, 256);
            for (int i = 0; i < px.length; ++i) {
                px[i] = (float) makeODByLUT(px[i], odLUT);
            }
        } else {
            for (int i = 0; i < px.length; ++i) {
                px[i] = (float) makeOD(px[i], maxValue);
            }
        }
    }

    /**
     * Compute red-channel optical densities for packed RGB values.
     *
     * @param rgb packed RGB values
     * @param maxValue the red intensity that maps to an OD of 0
     * @param px optional output array; a new one is allocated if {@code null} or shorter than
     *           {@code rgb}. If longer, only the first {@code rgb.length} entries are written.
     * @return the array containing the optical densities
     */
    public static float[] getRedOpticalDensities(int[] rgb, double maxValue, float[] px) {
        return getRgbChannelOpticalDensities(rgb, CHANNEL_RED, maxValue, px);
    }

    /**
     * Compute green-channel optical densities for packed RGB values.
     *
     * @param rgb packed RGB values
     * @param maxValue the green intensity that maps to an OD of 0
     * @param px optional output array; a new one is allocated if {@code null} or shorter than
     *           {@code rgb}. If longer, only the first {@code rgb.length} entries are written.
     * @return the array containing the optical densities
     */
    public static float[] getGreenOpticalDensities(int[] rgb, double maxValue, float[] px) {
        return getRgbChannelOpticalDensities(rgb, CHANNEL_GREEN, maxValue, px);
    }

    /**
     * Compute blue-channel optical densities for packed RGB values.
     *
     * @param rgb packed RGB values
     * @param maxValue the blue intensity that maps to an OD of 0
     * @param px optional output array; a new one is allocated if {@code null} or shorter than
     *           {@code rgb}. If longer, only the first {@code rgb.length} entries are written.
     * @return the array containing the optical densities
     */
    public static float[] getBlueOpticalDensities(int[] rgb, double maxValue, float[] px) {
        return getRgbChannelOpticalDensities(rgb, CHANNEL_BLUE, maxValue, px);
    }

    /**
     * Shared implementation of the get{Red,Green,Blue}OpticalDensities methods.
     * Iterates over {@code rgb} (not {@code px}), so a reused output buffer of a different
     * length can neither overrun nor truncate the result.
     */
    private static float[] getRgbChannelOpticalDensities(int[] rgb, int channel, double maxValue, float[] px) {
        int n = rgb.length;
        if (px == null || px.length < n) {
            px = new float[n];
        }
        double[] odLUT = makeODLUT(maxValue, 256);
        for (int i = 0; i < n; ++i) {
            px[i] = (float) makeODByLUT(rgbChannelValue(rgb[i], channel), odLUT);
        }
        return px;
    }

    /** Extract one channel from a packed RGB value using {@link ColorTools}. */
    private static int rgbChannelValue(int rgb, int channel) {
        switch (channel) {
            case CHANNEL_RED:
                return ColorTools.red(rgb);
            case CHANNEL_GREEN:
                return ColorTools.green(rgb);
            case CHANNEL_BLUE:
                return ColorTools.blue(rgb);
            default:
                throw new IllegalArgumentException("Unsupported channel: " + channel);
        }
    }

    /**
     * Read one band of a raster and convert it to optical densities using
     * {@link #convertToOpticalDensity(float[], double)}.
     *
     * @param raster source raster
     * @param band band index
     * @param maxValue the intensity that maps to an OD of 0
     * @param px optional output array, passed to {@link Raster#getSamples}
     * @return the array containing the optical densities
     */
    public static float[] getOpticalDensities(Raster raster, int band, double maxValue, float[] px) {
        px = getPixels(raster, band, px);
        convertToOpticalDensity(px, maxValue);
        return px;
    }

    /**
     * Apply color deconvolution to an image and return the values for a single stain.
     * <p>
     * 8-bit color images are delegated to {@link ColorTransformer}. For other images, the first
     * three raster bands are converted to optical densities and multiplied by the column of the
     * inverse stain matrix for {@code stainNumber}. NOTE(review): the raster is assumed to have at
     * least 3 bands.
     *
     * @param img input image
     * @param stains stain vectors and per-channel maximum values
     * @param stainNumber zero-based stain index (0, 1 or 2)
     * @param pixels optional output array; reallocated if {@code null} or too small (non-8-bit path)
     * @return the deconvolved values for the requested stain
     * @throws IllegalArgumentException if {@code stainNumber} is not 0, 1 or 2
     */
    public static float[] colorDeconvolve(BufferedImage img, ColorDeconvolutionStains stains, int stainNumber, float[] pixels) {
        // Validate up front so both code paths reject bad indices with the same exception
        if (stainNumber < 0 || stainNumber > 2) {
            throw new IllegalArgumentException("Unsupported stain number: " + stainNumber);
        }
        if (BufferedImageTools.is8bitColorType(img.getType())) {
            ColorTransformer.ColorTransformMethod method;
            if (stainNumber == 0) {
                method = ColorTransformer.ColorTransformMethod.Stain_1;
            } else if (stainNumber == 1) {
                method = ColorTransformer.ColorTransformMethod.Stain_2;
            } else {
                method = ColorTransformer.ColorTransformMethod.Stain_3;
            }
            int[] rgb = img.getRGB(0, 0, img.getWidth(), img.getHeight(), null, 0, img.getWidth());
            return ColorTransformer.getTransformedPixels(rgb, method, pixels, stains);
        }
        WritableRaster raster = img.getRaster();
        float[] r = getChannelOpticalDensity(raster, 0, stains.getMaxRed());
        float[] g = getChannelOpticalDensity(raster, 1, stains.getMaxGreen());
        float[] b = getChannelOpticalDensity(raster, 2, stains.getMaxBlue());
        double[][] invMat = stains.getMatrixInverse();
        double rScale = invMat[0][stainNumber];
        double gScale = invMat[1][stainNumber];
        double bScale = invMat[2][stainNumber];
        int n = r.length;
        if (pixels == null || pixels.length < n) {
            pixels = new float[n];
        }
        // Iterate over the pixel count, not the (possibly larger, reused) output buffer
        for (int i = 0; i < n; ++i) {
            pixels[i] = (float) ((double) r[i] * rScale + (double) g[i] * gScale + (double) b[i] * bScale);
        }
        return pixels;
    }

    /**
     * Apply color deconvolution in place to three channel arrays.
     * <p>
     * Each array is first converted to optical densities using the stains' maximum values, then
     * the three ODs per pixel are multiplied by the inverse stain matrix. Afterwards {@code red},
     * {@code green} and {@code blue} hold the values for stains 1, 2 and 3 respectively.
     *
     * @param red red channel; overwritten with stain 1 values
     * @param green green channel; overwritten with stain 2 values
     * @param blue blue channel; overwritten with stain 3 values
     * @param stains stain vectors and per-channel maximum values
     */
    public static void colorDeconvolve(float[] red, float[] green, float[] blue, ColorDeconvolutionStains stains) {
        convertToOpticalDensity(red, stains.getMaxRed());
        convertToOpticalDensity(green, stains.getMaxGreen());
        convertToOpticalDensity(blue, stains.getMaxBlue());
        double[][] invMat = stains.getMatrixInverse();
        int n = red.length;
        for (int i = 0; i < n; ++i) {
            double r = red[i];
            double g = green[i];
            double b = blue[i];
            red[i] = (float) (r * invMat[0][0] + g * invMat[1][0] + b * invMat[2][0]);
            green[i] = (float) (r * invMat[0][1] + g * invMat[1][1] + b * invMat[2][1]);
            blue[i] = (float) (r * invMat[0][2] + g * invMat[1][2] + b * invMat[2][2]);
        }
    }

    /** Read one band of a raster into a new array and convert it to optical densities. */
    private static float[] getChannelOpticalDensity(WritableRaster raster, int channel, double max) {
        float[] arr = raster.getSamples(0, 0, raster.getWidth(), raster.getHeight(), channel, (float[]) null);
        convertToOpticalDensity(arr, max);
        return arr;
    }

    /**
     * Convert intensities to optical densities in place as {@code -log10(max(px / max, 1/255))}.
     * <p>
     * Unlike {@link #makeOD(double, double)}, the result is not clamped at 0: values above
     * {@code max} give negative optical densities. The ratio floor of about 1/255 caps the OD at
     * about 2.41.
     *
     * @param pixels intensities; overwritten with optical densities
     * @param max the intensity that maps to an OD of 0
     */
    public static void convertToOpticalDensity(float[] pixels, double max) {
        // ~1/255: lower bound on the transmitted fraction, avoids log10(0)
        double minValue = 0.00392156862745098;
        for (int i = 0; i < pixels.length; ++i) {
            pixels[i] = (float) (-Math.log10(Math.max((double) pixels[i] / max, minValue)));
        }
    }

    /**
     * Read all samples of one band of a raster.
     *
     * @param raster source raster
     * @param band band index
     * @param px optional output array, passed to {@link Raster#getSamples}
     * @return the samples
     */
    public static float[] getPixels(Raster raster, int band, float[] px) {
        return raster.getSamples(0, 0, raster.getWidth(), raster.getHeight(), band, px);
    }

    /**
     * Read all samples of one band of a raster into a new array.
     *
     * @param raster source raster
     * @param band band index
     * @return the samples
     */
    public static float[] getPixels(Raster raster, int band) {
        return getPixels(raster, band, null);
    }

    /**
     * Estimate a stain vector as the per-channel median of normalized optical densities.
     *
     * @param name name for the stain vector
     * @param img image containing (ideally only) the stain of interest
     * @param redMax red intensity that maps to an OD of 0
     * @param greenMax green intensity that maps to an OD of 0
     * @param blueMax blue intensity that maps to an OD of 0
     * @return the estimated stain vector
     */
    public static StainVector generateMedianStainVectorFromPixels(String name, BufferedImage img, double redMax, double greenMax, double blueMax) {
        if (BufferedImageTools.is8bitColorType(img.getType())) {
            int[] rgb = img.getRGB(0, 0, img.getWidth(), img.getHeight(), null, 0, img.getWidth());
            return generateMedianStainVectorFromPixels(name, rgb, redMax, greenMax, blueMax);
        }
        Raster raster = img.getRaster();
        float[] red = getPixels(raster, 0);
        float[] green = getPixels(raster, 1);
        float[] blue = getPixels(raster, 2);
        convertToOpticalDensity(red, redMax);
        convertToOpticalDensity(green, greenMax);
        convertToOpticalDensity(blue, blueMax);
        return generateMedianStainVector(name, red, green, blue);
    }

    /**
     * Estimate a stain vector from packed RGB values as the per-channel median of normalized
     * optical densities. A lookup table is used for the OD conversion when there are more than
     * 500 pixels.
     *
     * @param name name for the stain vector
     * @param rgb packed RGB values
     * @param redMax red intensity that maps to an OD of 0
     * @param greenMax green intensity that maps to an OD of 0
     * @param blueMax blue intensity that maps to an OD of 0
     * @return the estimated stain vector
     */
    public static StainVector generateMedianStainVectorFromPixels(String name, int[] rgb, double redMax, double greenMax, double blueMax) {
        boolean useLUT = rgb.length > 500;
        float[] red = ColorTransformer.getSimpleTransformedPixels(rgb, ColorTransformer.ColorTransformMethod.Red, null);
        float[] green = ColorTransformer.getSimpleTransformedPixels(rgb, ColorTransformer.ColorTransformMethod.Green, null);
        float[] blue = ColorTransformer.getSimpleTransformedPixels(rgb, ColorTransformer.ColorTransformMethod.Blue, null);
        convertPixelsToOpticalDensities(red, redMax, useLUT);
        convertPixelsToOpticalDensities(green, greenMax, useLUT);
        convertPixelsToOpticalDensities(blue, blueMax, useLUT);
        return generateMedianStainVector(name, red, green, blue);
    }

    /**
     * Normalize each (r, g, b) OD triple to unit length, then take the median of each channel
     * independently. The input arrays are overwritten and sorted as a side effect.
     * <p>
     * NOTE(review): a pixel whose three ODs are all zero normalizes to NaN (0/0). Arrays.sort
     * places NaN last, so such pixels bias the median upward. If more than half are NaN, the
     * result is NaN. Consider excluding them.
     */
    private static StainVector generateMedianStainVector(String name, float[] red, float[] green, float[] blue) {
        int n = red.length;
        for (int i = 0; i < n; ++i) {
            double r = red[i];
            double g = green[i];
            double b = blue[i];
            double denominator = Math.sqrt(r * r + g * g + b * b);
            red[i] = (float) (r / denominator);
            green[i] = (float) (g / denominator);
            blue[i] = (float) (b / denominator);
        }
        Arrays.sort(red);
        Arrays.sort(green);
        Arrays.sort(blue);
        int medianInd = n / 2;
        return StainVector.createStainVector(name, red[medianInd], green[medianInd], blue[medianInd]);
    }

    /**
     * Compute the per-channel median of packed RGB values.
     * The input array is not modified.
     *
     * @param rgb packed RGB values (must not be empty)
     * @return packed RGB value whose channels are the red, green and blue medians
     * @throws IllegalArgumentException if {@code rgb} is empty
     */
    public static int getMedianRGB(int[] rgb) {
        int n = rgb.length;
        int[] temp = new int[n];
        for (int i = 0; i < n; ++i) {
            temp[i] = ColorTools.red(rgb[i]);
        }
        int rMedian = getMedian(temp);
        for (int i = 0; i < n; ++i) {
            temp[i] = ColorTools.green(rgb[i]);
        }
        int gMedian = getMedian(temp);
        for (int i = 0; i < n; ++i) {
            temp[i] = ColorTools.blue(rgb[i]);
        }
        int bMedian = getMedian(temp);
        return ColorTools.packRGB(rMedian, gMedian, bMedian);
    }

    /**
     * Median of an int array. For even lengths this is the mean of the two middle values, using
     * integer division. The array is sorted in place.
     */
    private static int getMedian(int[] array) {
        if (array.length == 0) {
            throw new IllegalArgumentException("Cannot compute the median of an empty array");
        }
        Arrays.sort(array);
        int mid = array.length / 2;
        if (array.length % 2 == 0) {
            // Average (not sum) of the two middle values; widen to long to avoid overflow
            return (int) (((long) array[mid - 1] + array[mid]) / 2);
        }
        return array[mid];
    }

    /**
     * Median of a float array. For even lengths this is the mean of the two middle values.
     * <p>
     * Note: the array is sorted in place as a side effect.
     *
     * @param array values (must not be empty)
     * @return the median
     * @throws IllegalArgumentException if {@code array} is empty
     */
    public static float getMedian(float[] array) {
        if (array.length == 0) {
            throw new IllegalArgumentException("Cannot compute the median of an empty array");
        }
        Arrays.sort(array);
        int mid = array.length / 2;
        if (array.length % 2 == 0) {
            return (array[mid - 1] + array[mid]) / 2.0f;
        }
        return array[mid];
    }

    /**
     * Experimental refinement of the first two stain vectors from image pixels (logs a warning
     * on every call).
     * <p>
     * Steps:
     * <ol>
     *   <li>Estimate white values with {@link #estimateWhiteValues(int[])}.</li>
     *   <li>Mask out pixels that are weakly stained or otherwise excluded by the stain mask.</li>
     *   <li>Build a 2D basis from the pixels that project most strongly (top 2%) onto each of
     *       the existing stains.</li>
     *   <li>Take the angular distribution of masked pixels in that basis.</li>
     *   <li>Return the two clipped extreme angles as the refined stain vectors.</li>
     * </ol>
     *
     * @param rgb packed RGB values
     * @param stains current stains; stains 1 and 2 are refined
     * @param minStain minimum OD required in every channel for a pixel to be used
     * @param percentageClipped percentage of angles clipped at each extreme (clamped to 0.1-25)
     * @return new stains using the refined vectors and the estimated white values
     * @throws IllegalArgumentException if no pixel passes the stain mask
     */
    public static ColorDeconvolutionStains refineColorDeconvolutionStains(int[] rgb, ColorDeconvolutionStains stains, double minStain, double percentageClipped) {
        logger.warn("WARNING!  Stain vector refinement is only for testing - treat the results with caution!");
        int n = rgb.length;
        double[] whiteValues = estimateWhiteValues(rgb);
        float[] redOD = getRedOpticalDensities(rgb, whiteValues[0], null);
        float[] greenOD = getGreenOpticalDensities(rgb, whiteValues[1], null);
        float[] blueOD = getBlueOpticalDensities(rgb, whiteValues[2], null);
        boolean[] mask = createStainMask(redOD, greenOD, blueOD, minStain, stains.isH_DAB() || stains.isH_E(), true, null);
        int nnz = 0;
        for (boolean m : mask) {
            if (m) {
                ++nnz;
            }
        }
        if (nnz == 0) {
            throw new IllegalArgumentException("No pixels passed the stain mask - cannot refine stain vectors");
        }

        // Normalized projections of each masked pixel onto the existing stain vectors
        double[] stain1 = stains.getStain(1).getArray();
        double[] stain2 = stains.getStain(2).getArray();
        float[] stain1Proj = new float[nnz];
        float[] stain2Proj = new float[nnz];
        int[] indices = new int[nnz];
        int ind = 0;
        for (int i = 0; i < n; ++i) {
            if (!mask[i]) {
                continue;
            }
            double r = redOD[i];
            double g = greenOD[i];
            double b = blueOD[i];
            double norm = Math.sqrt(r * r + g * g + b * b);
            stain1Proj[ind] = (float) ((r * stain1[0] + g * stain1[1] + b * stain1[2]) / norm);
            stain2Proj[ind] = (float) ((r * stain2[0] + g * stain2[1] + b * stain2[2]) / norm);
            indices[ind] = i;
            ++ind;
        }
        float stain1Threshold = sortedValueAtFraction(stain1Proj, 0.98);
        float stain2Threshold = sortedValueAtFraction(stain2Proj, 0.98);

        // Mean OD of the pixels most strongly aligned with each stain
        double r1Sum = 0.0;
        double g1Sum = 0.0;
        double b1Sum = 0.0;
        int n1 = 0;
        double r2Sum = 0.0;
        double g2Sum = 0.0;
        double b2Sum = 0.0;
        int n2 = 0;
        for (int i = 0; i < indices.length; ++i) {
            int idx = indices[i];
            if (stain1Proj[i] >= stain1Threshold) {
                r1Sum += (double) redOD[idx];
                g1Sum += (double) greenOD[idx];
                b1Sum += (double) blueOD[idx];
                ++n1;
            }
            if (stain2Proj[i] >= stain2Threshold) {
                r2Sum += (double) redOD[idx];
                g2Sum += (double) greenOD[idx];
                b2Sum += (double) blueOD[idx];
                ++n2;
            }
        }
        StainVector stainBase1 = StainVector.createStainVector("Basis 1", r1Sum / (double) n1, g1Sum / (double) n1, b1Sum / (double) n1);
        StainVector stainBase2 = StainVector.createStainVector("Basis 2", r2Sum / (double) n2, g2Sum / (double) n2, b2Sum / (double) n2);
        // First basis vector: mean of the two stain estimates; second: orthogonal to it within their plane
        stainBase1 = StainVector.createStainVector("Basis 1", (stainBase1.getRed() + stainBase2.getRed()) / 2.0, (stainBase1.getGreen() + stainBase2.getGreen()) / 2.0, (stainBase1.getBlue() + stainBase2.getBlue()) / 2.0);
        StainVector stainNorm = StainVector.makeResidualStainVector(stainBase1, stainBase2);
        stainBase2 = StainVector.makeOrthogonalStainVector("Basis 2", stainBase1, stainNorm, false);
        double[] base1 = stainBase1.getArray();
        double[] base2 = stainBase2.getArray();

        // Angle of every masked pixel within the 2D basis
        double[] angles = new double[nnz];
        int nAngles = 0;
        for (int i = 0; i < n; ++i) {
            if (!mask[i]) {
                continue;
            }
            double r = redOD[i];
            double g = greenOD[i];
            double b = blueOD[i];
            double proj1 = r * base1[0] + g * base1[1] + b * base1[2];
            double proj2 = r * base2[0] + g * base2[1] + b * base2[2];
            angles[nAngles] = Math.atan2(proj2, proj1);
            if (++nAngles == nnz) {
                break;
            }
        }
        Arrays.sort(angles);
        double alpha = Math.min(Math.max(0.1, percentageClipped), 25.0) / 100.0;
        double minAngle = angles[(int) ((double) nAngles * alpha)];
        double maxAngle = angles[nAngles - (int) ((double) nAngles * alpha) - 1];
        double cos = Math.cos(minAngle);
        double sin = Math.sin(minAngle);
        StainVector stain2Refined = StainVector.createStainVector(stains.getStain(2).getName(), base1[0] * cos + base2[0] * sin, base1[1] * cos + base2[1] * sin, base1[2] * cos + base2[2] * sin);
        cos = Math.cos(maxAngle);
        sin = Math.sin(maxAngle);
        StainVector stain1Refined = StainVector.createStainVector(stains.getStain(1).getName(), base1[0] * cos + base2[0] * sin, base1[1] * cos + base2[1] * sin, base1[2] * cos + base2[2] * sin);
        return new ColorDeconvolutionStains(stains.getName(), stain1Refined, stain2Refined, whiteValues[0], whiteValues[1], whiteValues[2]);
    }

    /**
     * Return the element at index {@code (int)(length * fraction)} of a sorted copy of
     * {@code values}. The input is not modified.
     */
    private static float sortedValueAtFraction(float[] values, double fraction) {
        float[] sorted = Arrays.copyOf(values, values.length);
        Arrays.sort(sorted);
        return sorted[(int) ((double) sorted.length * fraction)];
    }

    /**
     * Build (or refine) a mask of pixels usable for stain estimation. Entries are only ever set
     * to {@code false}.
     * <p>
     * A pixel is excluded if any channel OD is below {@code stainThreshold}. Then:
     * <ul>
     *   <li>if {@code excludeGray}, it is excluded when its OD direction is within 0.2 radians
     *       of the gray axis (1,1,1)/sqrt(3);</li>
     *   <li>otherwise, if {@code excludeUncommonColors}, it is excluded when
     *       {@code (g < r && r <= b) || (g <= b && b < r)}.</li>
     * </ul>
     * NOTE(review): because the gray branch always moves on to the next pixel,
     * {@code excludeUncommonColors} is ignored whenever {@code excludeGray} is true. Confirm this
     * is intended; it may be a decompilation artifact or an upstream bug.
     *
     * @param mask optional existing mask; if {@code null}, a new all-true mask is created
     */
    private static boolean[] createStainMask(float[] redOD, float[] greenOD, float[] blueOD, double stainThreshold, boolean excludeGray, boolean excludeUncommonColors, boolean[] mask) {
        if (mask == null) {
            mask = new boolean[redOD.length];
            Arrays.fill(mask, true);
        }
        double gray = 1.0 / Math.sqrt(3.0);
        double grayThreshold = Math.cos(0.2);
        for (int i = 0; i < mask.length; ++i) {
            double r = redOD[i];
            double g = greenOD[i];
            double b = blueOD[i];
            if (r < stainThreshold || g < stainThreshold || b < stainThreshold) {
                mask[i] = false;
                continue;
            }
            if (excludeGray) {
                double norm = Math.sqrt(r * r + g * g + b * b);
                // cos(angle to gray axis) >= cos(0.2)  <=>  angle <= 0.2 rad
                if (r * gray + g * gray + b * gray >= grayThreshold * norm) {
                    mask[i] = false;
                }
                continue;
            }
            if (excludeUncommonColors && ((g < r && r <= b) || (g <= b && b < r))) {
                mask[i] = false;
            }
        }
        return mask;
    }

    /**
     * Estimate per-channel white (background) values from packed RGB pixels.
     * <p>
     * For each channel this is the most frequent 8-bit value strictly greater than the channel
     * mean. Ties go to the higher value. For an empty input every channel returns 255.
     *
     * @param rgb packed RGB values
     * @return {@code {red, green, blue}} estimates
     */
    public static double[] estimateWhiteValues(int[] rgb) {
        int[] countsRed = new int[256];
        int[] countsGreen = new int[256];
        int[] countsBlue = new int[256];
        double meanRed = 0.0;
        double meanGreen = 0.0;
        double meanBlue = 0.0;
        double scale = 1.0 / (double) rgb.length;
        for (int val : rgb) {
            int red = ColorTools.red(val);
            int green = ColorTools.green(val);
            int blue = ColorTools.blue(val);
            countsRed[red]++;
            countsGreen[green]++;
            countsBlue[blue]++;
            // Accumulate scaled values (kept in this form: the result feeds a strict '>' comparison)
            meanRed += (double) red * scale;
            meanGreen += (double) green * scale;
            meanBlue += (double) blue * scale;
        }
        return new double[]{
                modeAboveMean(countsRed, meanRed),
                modeAboveMean(countsGreen, meanGreen),
                modeAboveMean(countsBlue, meanBlue)
        };
    }

    /**
     * Return the histogram bin with the highest count among bins whose index is strictly greater
     * than {@code mean}. Ties resolve to the highest index. Returns 0 if no bin qualifies.
     */
    private static int modeAboveMean(int[] counts, double mean) {
        int modeCount = Integer.MIN_VALUE;
        int mode = 0;
        for (int i = 0; i < counts.length; ++i) {
            if ((double) i > mean && counts[i] >= modeCount) {
                modeCount = counts[i];
                mode = i;
            }
        }
        return mode;
    }
}

