/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml.pixel;

import com.google.gson.Gson;
import com.google.gson.TypeAdapter;
import com.google.gson.TypeAdapterFactory;
import com.google.gson.reflect.TypeToken;
import java.awt.image.BufferedImage;
import java.awt.image.IndexColorModel;
import java.awt.image.WritableRaster;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.io.Reader;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Map;
import qupath.lib.classifiers.pixel.PixelClassifier;
import qupath.lib.classifiers.pixel.PixelClassifierMetadata;
import qupath.lib.color.ColorModelFactory;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.PixelType;
import qupath.lib.io.GsonTools;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.RegionRequest;
import qupath.opencv.ml.OpenCVClassifiers;
import qupath.opencv.ml.pixel.OpenCVPixelClassifier;
import qupath.opencv.ops.ImageDataOp;
import qupath.opencv.ops.ImageOp;
import qupath.opencv.ops.ImageOps;

public class PixelClassifiers {
    /**
     * Single shared factory instance; registers the {@link PixelClassifier} and
     * {@link ClassifierFunction} subtypes defined in this class.
     */
    private static final TypeAdapterFactory ADAPTER_FACTORY = new PixelClassifierTypeAdapterFactory();

    /**
     * Get a Gson {@link TypeAdapterFactory} able to handle the pixel classifier and classifier
     * function subtypes registered by this class.
     *
     * @return the shared factory (the same instance on every call)
     */
    public static TypeAdapterFactory getTypeAdapterFactory() {
        final TypeAdapterFactory shared = ADAPTER_FACTORY;
        return shared;
    }

    /**
     * Read a {@link PixelClassifier} from a UTF-8 encoded JSON file.
     *
     * @param path the file to read
     * @return the classifier deserialized with {@link GsonTools#getInstance()}
     * @throws IOException if the file cannot be opened or read
     */
    public static PixelClassifier readClassifier(Path path) throws IOException {
        try (Reader jsonReader = Files.newBufferedReader(path, StandardCharsets.UTF_8)) {
            return GsonTools.getInstance().fromJson(jsonReader, PixelClassifier.class);
        }
    }

    /**
     * Write a {@link PixelClassifier} to a file as UTF-8 encoded JSON.
     * <p>
     * The file is opened with the default options of {@link Files#newBufferedWriter}, so an
     * existing file is truncated and replaced.
     *
     * @param classifier the classifier to serialize (declared as {@link PixelClassifier} so the
     *                   subtype adapter is used)
     * @param path       the destination file
     * @throws IOException if the file cannot be created or written
     */
    public static void writeClassifier(PixelClassifier classifier, Path path) throws IOException {
        // NOTE(review): getInstance(true) is presumably the pretty-printing variant - confirm in GsonTools
        try (BufferedWriter out = Files.newBufferedWriter(path, StandardCharsets.UTF_8)) {
            GsonTools.getInstance(true).toJson(classifier, PixelClassifier.class, out);
        }
    }

    /**
     * Create an OpenCV-backed classification classifier from an {@link ImageDataOp}.
     * <p>
     * The metadata declares a 512 x 512 input shape and a CLASSIFICATION channel type.
     *
     * @param op              the op that produces the classification output
     * @param inputResolution resolution at which the op is applied
     * @param classifications mapping from output label value to {@link PathClass}
     * @return a new {@link OpenCVPixelClassifier}
     */
    public static PixelClassifier createClassifier(ImageDataOp op, PixelCalibration inputResolution, Map<Integer, PathClass> classifications) {
        PixelClassifierMetadata classifierMetadata = new PixelClassifierMetadata.Builder()
                .inputResolution(inputResolution)
                .classificationLabels(classifications)
                .inputShape(512, 512)
                .setChannelType(ImageServerMetadata.ChannelType.CLASSIFICATION)
                .build();
        return createClassifier(op, classifierMetadata);
    }

    /**
     * Wrap an {@link ImageDataOp} and its metadata in an {@link OpenCVPixelClassifier}.
     *
     * @param op       the op producing the classifier output
     * @param metadata metadata describing input and output
     * @return a new classifier instance
     */
    public static PixelClassifier createClassifier(ImageDataOp op, PixelClassifierMetadata metadata) {
        PixelClassifier wrapped = new OpenCVPixelClassifier(op, metadata);
        return wrapped;
    }

    /**
     * Create a classifier that appends an OpenCV statistical model (and optional post-processing)
     * to a feature-calculating {@link ImageDataOp}.
     * <p>
     * Ops appended, in order:
     * <ol>
     *   <li>the stat model, asked for probabilities if the output type is PROBABILITY or MULTICLASS_PROBABILITY;</li>
     *   <li>for PROBABILITY output only: channel-sum normalization to 255 (8-bit) or 1;</li>
     *   <li>if {@code do8Bit}: conversion to {@link PixelType#UINT8}.</li>
     * </ol>
     *
     * @param statModel  the trained model
     * @param calculator op computing the model's input features
     * @param metadata   classifier metadata; its output type selects the post-processing
     * @param do8Bit     whether to produce 8-bit output
     * @return a new {@link OpenCVPixelClassifier}
     */
    public static PixelClassifier createClassifier(OpenCVClassifiers.OpenCVStatModel statModel, ImageDataOp calculator, PixelClassifierMetadata metadata, boolean do8Bit) {
        ImageServerMetadata.ChannelType outputType = metadata.getOutputType();
        boolean wantProbabilities = outputType == ImageServerMetadata.ChannelType.PROBABILITY
                || outputType == ImageServerMetadata.ChannelType.MULTICLASS_PROBABILITY;

        ArrayList<ImageOp> pipeline = new ArrayList<>();
        pipeline.add(ImageOps.ML.statModel(statModel, wantProbabilities));
        if (outputType == ImageServerMetadata.ChannelType.PROBABILITY) {
            double channelTotal = do8Bit ? 255.0 : 1.0;
            pipeline.add(ImageOps.Normalize.channelSum(channelTotal));
        }
        // NOTE(review): MULTICLASS_PROBABILITY output is not rescaled before the UINT8 conversion
        // below - confirm the stat model's output range makes that conversion meaningful.
        if (do8Bit) {
            pipeline.add(ImageOps.Core.ensureType(PixelType.UINT8));
        }
        ImageDataOp combined = calculator.appendOps(pipeline.toArray(new ImageOp[0]));
        return new OpenCVPixelClassifier(combined, metadata);
    }

    /**
     * Create a classifier that thresholds a single channel.
     * <p>
     * Pixels with a value {@code >= threshold} get label 1 ({@code aboveEquals}); all others,
     * including NaN, get label 0 ({@code below}).
     *
     * @param inputResolution resolution at which to apply the classifier
     * @param channel         index of the channel to threshold
     * @param threshold       threshold value (narrowed to float)
     * @param below           class for label 0; must not be null ({@link Map#of} rejects nulls)
     * @param aboveEquals     class for label 1; must not be null
     * @return the threshold classifier
     */
    public static PixelClassifier createThresholdClassifier(PixelCalibration inputResolution, int channel, double threshold, PathClass below, PathClass aboveEquals) {
        ClassifierFunction thresholdFunction = createThresholdFunction(channel, threshold);
        return createThresholdClassifier(inputResolution, Map.of(0, below, 1, aboveEquals), thresholdFunction);
    }

    /**
     * Create a classifier that thresholds one or more channels.
     * <p>
     * A pixel gets label 1 ({@code aboveEquals}) only if every listed channel is {@code >=} its
     * threshold and not NaN; otherwise it gets label 0 ({@code below}).
     *
     * @param inputResolution resolution at which to apply the classifier
     * @param thresholds      map from channel index to threshold value
     * @param below           class for label 0; must not be null ({@link Map#of} rejects nulls)
     * @param aboveEquals     class for label 1; must not be null
     * @return the threshold classifier
     */
    public static PixelClassifier createThresholdClassifier(PixelCalibration inputResolution, Map<Integer, ? extends Number> thresholds, PathClass below, PathClass aboveEquals) {
        ClassifierFunction thresholdFunction = createThresholdFunction(thresholds);
        return createThresholdClassifier(inputResolution, Map.of(0, below, 1, aboveEquals), thresholdFunction);
    }

    /**
     * Build a threshold function from a map of channel index to threshold.
     * <p>
     * A single entry yields a {@link ClassifierGreaterEquals}. Any other size, including an empty
     * map, yields a {@link MultiThresholdClassifierFunction}. Note that with an empty map that
     * function returns 1 for every pixel.
     *
     * @param thresholds map from channel index to threshold value
     * @return the threshold function
     */
    static ClassifierFunction createThresholdFunction(Map<Integer, ? extends Number> thresholds) {
        if (thresholds.size() != 1) {
            return new MultiThresholdClassifierFunction(thresholds);
        }
        Map.Entry<Integer, ? extends Number> onlyEntry = thresholds.entrySet().iterator().next();
        int channel = onlyEntry.getKey();
        double threshold = onlyEntry.getValue().doubleValue();
        return createThresholdFunction(channel, threshold);
    }

    /**
     * Build a function returning 1 when {@code channel >= threshold} (and not NaN), else 0.
     *
     * @param channel   index of the channel to test
     * @param threshold threshold value; narrowed to float, matching the float pixel values compared
     * @return the threshold function
     */
    static ClassifierFunction createThresholdFunction(int channel, double threshold) {
        float floatThreshold = (float) threshold;
        return new ClassifierGreaterEquals(channel, floatThreshold);
    }

    /**
     * Wrap a {@link ClassifierFunction} in a {@link ThresholdPixelClassifier} with CLASSIFICATION metadata.
     *
     * @param inputResolution resolution at which to apply the classifier
     * @param labels          map from output label value to class
     * @param fun             per-pixel function producing the label value
     * @return the classifier
     */
    static PixelClassifier createThresholdClassifier(PixelCalibration inputResolution, Map<Integer, PathClass> labels, ClassifierFunction fun) {
        PixelClassifierMetadata classifierMetadata = new PixelClassifierMetadata.Builder()
                .classificationLabels(labels)
                .setChannelType(ImageServerMetadata.ChannelType.CLASSIFICATION)
                .inputResolution(inputResolution)
                .build();
        return new ThresholdPixelClassifier(classifierMetadata, fun);
    }

    /**
     * Maps the channel values of a single pixel to an integer label.
     * The implementations in this class return 0 or 1.
     */
    interface ClassifierFunction {
        /**
         * Compute the label for one pixel.
         *
         * @param input the pixel's value for each channel, indexed by channel
         * @return the label value
         */
        int predict(float[] input);
    }

    /**
     * Threshold function over several channels.
     * <p>
     * Returns 1 only if, for every configured band, the pixel value is not NaN and is
     * {@code >=} that band's threshold; otherwise returns 0. With no bands configured, it
     * returns 1 for every pixel.
     */
    static class MultiThresholdClassifierFunction
    implements ClassifierFunction {
        // Field names/order kept unchanged: this class is registered with a Gson subtype adapter,
        // so the fields are presumably persisted by name.
        private int n;
        private int[] bands;
        protected float[] thresholds;

        /**
         * @param thresholds map from channel index to threshold; iteration order of the map decides
         *                   the order in which bands are tested
         */
        MultiThresholdClassifierFunction(Map<Integer, ? extends Number> thresholds) {
            this.n = thresholds.size();
            this.bands = new int[this.n];
            this.thresholds = new float[this.n];
            int i = 0;
            for (Map.Entry<Integer, ? extends Number> entry : thresholds.entrySet()) {
                this.bands[i] = entry.getKey();
                this.thresholds[i] = entry.getValue().floatValue();
                ++i;
            }
        }

        /**
         * @param bands      channel indices to test (copied)
         * @param thresholds threshold for each entry of {@code bands} (copied)
         * @throws IllegalArgumentException if the arrays differ in length
         */
        MultiThresholdClassifierFunction(int[] bands, float[] thresholds) {
            // Fail fast: a shorter thresholds array would otherwise throw inside the per-pixel loop,
            // and a longer one would have its extra values silently ignored.
            if (bands.length != thresholds.length) {
                throw new IllegalArgumentException("Number of bands (" + bands.length
                        + ") must match number of thresholds (" + thresholds.length + ")");
            }
            this.n = bands.length;
            this.bands = bands.clone();
            this.thresholds = thresholds.clone();
        }

        @Override
        public int predict(float[] input) {
            for (int i = 0; i < this.n; ++i) {
                float val = input[this.bands[i]];
                if (Float.isNaN(val) || val < this.thresholds[i]) {
                    return 0;
                }
            }
            return 1;
        }
    }

    /**
     * Single-band threshold: returns 1 when the value is {@code >=} the threshold, 0 otherwise
     * (NaN always gives 0).
     */
    static class ClassifierGreaterEquals
    extends SingleThresholdClassifierFunction {

        ClassifierGreaterEquals(int band, float threshold) {
            super(band, threshold);
        }

        @Override
        protected int predict(float input) {
            boolean reachesThreshold = !Float.isNaN(input) && input >= this.threshold;
            return reachesThreshold ? 1 : 0;
        }
    }

    /**
     * {@link PixelClassifier} that applies a {@link ClassifierFunction} to every pixel of the
     * requested region and writes the resulting label into an 8-bit indexed image.
     */
    static class ThresholdPixelClassifier
    implements PixelClassifier {
        private PixelClassifierMetadata metadata;
        private ClassifierFunction fun;
        // Created lazily from the classification labels and not persisted (transient).
        // volatile is required for the double-checked initialization in getColorModel(): without it,
        // another thread may see a non-null reference to a not-yet-fully-constructed color model.
        private transient volatile IndexColorModel cm;

        ThresholdPixelClassifier(PixelClassifierMetadata metadata, ClassifierFunction fun) {
            this.metadata = metadata;
            this.fun = fun;
        }

        /**
         * @param imageData the image to test
         * @return true if the image server's channel count equals the metadata's input channel count
         */
        public boolean supportsImage(ImageData<BufferedImage> imageData) {
            // NOTE(review): the ClassifierFunction only reads specific bands; an exact channel-count
            // match may be stricter than necessary - confirm how input channels are set in the metadata.
            return this.metadata.getInputNumChannels() == imageData.getServer().nChannels();
        }

        /**
         * Get the indexed color model for output images, creating it on first use from the
         * metadata's classification labels.
         *
         * @return the (shared) color model
         */
        protected IndexColorModel getColorModel() {
            // Single volatile read on the fast path
            IndexColorModel result = this.cm;
            if (result == null) {
                synchronized (this) {
                    result = this.cm;
                    if (result == null) {
                        result = ColorModelFactory.getIndexedClassificationColorModel((Map)this.metadata.getClassificationLabels());
                        this.cm = result;
                    }
                }
            }
            return result;
        }

        /**
         * @param imgInput the input image, used only for its dimensions
         * @return an empty 8-bit indexed image of the same size using {@link #getColorModel()}
         */
        protected BufferedImage createOutputImage(BufferedImage imgInput) {
            return new BufferedImage(imgInput.getWidth(), imgInput.getHeight(), BufferedImage.TYPE_BYTE_INDEXED, this.getColorModel());
        }

        /**
         * Read the requested region and label each pixel with the classifier function.
         *
         * @param imageData image to read from
         * @param request   region to classify
         * @return indexed image whose band 0 holds the label for each pixel
         * @throws IOException if reading the region fails
         */
        public BufferedImage applyClassification(ImageData<BufferedImage> imageData, RegionRequest request) throws IOException {
            BufferedImage img = (BufferedImage)imageData.getServer().readRegion(request);
            WritableRaster raster = img.getRaster();
            int w = img.getWidth();
            int h = img.getHeight();
            BufferedImage imgOutput = this.createOutputImage(img);
            WritableRaster rasterOutput = imgOutput.getRaster();
            // Reused for every pixel to avoid per-pixel allocation
            float[] px = new float[raster.getNumBands()];
            for (int y = 0; y < h; ++y) {
                for (int x = 0; x < w; ++x) {
                    px = raster.getPixel(x, y, px);
                    rasterOutput.setSample(x, y, 0, this.fun.predict(px));
                }
            }
            return imgOutput;
        }

        public PixelClassifierMetadata getMetadata() {
            return this.metadata;
        }
    }

    private static class PixelClassifierTypeAdapterFactory
    implements TypeAdapterFactory {
        private static String typeName = "pixel_classifier_type";
        private static final GsonTools.SubTypeAdapterFactory<PixelClassifier> pixelClassifierTypeAdapter = GsonTools.createSubTypeAdapterFactory(PixelClassifier.class, (String)typeName).registerSubtype(OpenCVPixelClassifier.class).registerSubtype(ThresholdPixelClassifier.class);
        private static final TypeAdapterFactory classifierFunctionTypeAdapter = GsonTools.createSubTypeAdapterFactory(ClassifierFunction.class, (String)"classifier_function").registerSubtype(ClassifierGreaterEquals.class).registerSubtype(ClassifierGreater.class).registerSubtype(MultiThresholdClassifierFunction.class);

        PixelClassifierTypeAdapterFactory() {
        }

        public static void registerSubtype(Class<? extends PixelClassifier> cls) {
            pixelClassifierTypeAdapter.registerSubtype(cls);
        }

        public <T> TypeAdapter<T> create(Gson gson, TypeToken<T> type) {
            TypeAdapter adapter = pixelClassifierTypeAdapter.create(gson, type);
            if (adapter == null) {
                return classifierFunctionTypeAdapter.create(gson, type);
            }
            return adapter;
        }
    }

    /**
     * Single-band threshold: returns 1 when the value is strictly greater than the threshold,
     * 0 otherwise (NaN always gives 0).
     */
    static class ClassifierGreater
    extends SingleThresholdClassifierFunction {

        ClassifierGreater(int band, float threshold) {
            super(band, threshold);
        }

        @Override
        protected int predict(float input) {
            if (!Float.isNaN(input) && input > this.threshold) {
                return 1;
            }
            return 0;
        }
    }

    /**
     * Base class for threshold functions that inspect exactly one band of the pixel.
     * Subclasses decide how that band's value is compared with {@link #threshold}.
     */
    abstract static class SingleThresholdClassifierFunction
    implements ClassifierFunction {
        // Field names/order kept unchanged: subclasses are registered with a Gson subtype adapter.
        private int band;
        protected float threshold;

        SingleThresholdClassifierFunction(int band, float threshold) {
            this.band = band;
            this.threshold = threshold;
        }

        /** Extract the configured band and delegate to {@link #predict(float)}. */
        @Override
        public int predict(float[] input) {
            float bandValue = input[this.band];
            return this.predict(bandValue);
        }

        /**
         * @param value the pixel's value in the configured band
         * @return the label for that value
         */
        protected abstract int predict(float value);
    }

    /**
     * Builder for a {@link ThresholdPixelClassifier}.
     * <p>
     * Unset fields remain null and are passed through to the classifier metadata/constructor as-is.
     */
    static class ThresholdClassifierBuilder {
        private PixelCalibration inputResolution;
        private ClassifierFunction fun;
        private Map<Integer, PathClass> labels;

        ThresholdClassifierBuilder() {
        }

        /**
         * @param inputResolution resolution at which to apply the classifier
         * @return this builder
         */
        public ThresholdClassifierBuilder inputResolution(PixelCalibration inputResolution) {
            this.inputResolution = inputResolution;
            return this;
        }

        /**
         * @param fun per-pixel function producing the label value
         * @return this builder
         */
        public ThresholdClassifierBuilder classifierFunction(ClassifierFunction fun) {
            this.fun = fun;
            return this;
        }

        /**
         * @param labels map from output label value to class
         * @return this builder
         */
        public ThresholdClassifierBuilder labels(Map<Integer, PathClass> labels) {
            this.labels = labels;
            return this;
        }

        /**
         * @return a new classifier built from the current builder state
         */
        public PixelClassifier build() {
            // Same metadata construction as the static factory; delegate to keep them in sync
            return PixelClassifiers.createThresholdClassifier(this.inputResolution, this.labels, this.fun);
        }
    }
}

