/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.classifiers.pixel.PixelClassifier;
import qupath.lib.classifiers.pixel.PixelClassifierMetadata;
import qupath.lib.images.servers.ColorTransforms;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.PixelType;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.Padding;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.ml.pixel.PixelClassifiers;
import qupath.opencv.ops.ImageDataOp;
import qupath.opencv.ops.ImageOp;
import qupath.opencv.ops.ImageOps;

/**
 * Parameters describing a patch-based classifier.
 * <p>
 * The parameters cover:
 * <ul>
 *   <li>which input channels to extract;</li>
 *   <li>the patch size and halo;</li>
 *   <li>the input resolution;</li>
 *   <li>the chain of preprocessing, prediction and postprocessing ops;</li>
 *   <li>how output channels map to classifications.</li>
 * </ul>
 * <p>
 * Create instances with {@link #builder()}, or with {@link #builder(PatchClassifierParams)} to
 * start from an existing instance. Instances have no setters, and collection-valued getters
 * return copies. Use {@link #buildPixelClassifier(PatchClassifierParams)} to turn the
 * parameters into a {@link PixelClassifier}.
 */
public class PatchClassifierParams {

    private static final Logger logger = LoggerFactory.getLogger(PatchClassifierParams.class);

    /** Patch width and height used when none is set via the builder. */
    private static final int DEFAULT_PATCH_SIZE = 256;

    private List<ColorTransforms.ColorTransform> inputChannels;
    private int patchWidth = DEFAULT_PATCH_SIZE;
    private int patchHeight = DEFAULT_PATCH_SIZE;
    // null means "no halo"; getHalo() maps it to Padding.empty()
    private Padding halo = null;
    private PixelCalibration inputResolution = null;
    private ImageServerMetadata.ChannelType outputChannelType;
    private List<ImageOp> preprocessingOps;
    private ImageOp predictionOp;
    private List<ImageOp> postprocessingOps;
    private Map<Integer, PathClass> outputClasses = new LinkedHashMap<>();

    private PatchClassifierParams() {
    }

    /**
     * Copy constructor.
     * <p>
     * The lists and the output-class map are copied. The contained transforms, ops, padding
     * and calibration objects are shared with {@code params}.
     *
     * @param params the instance to copy
     */
    private PatchClassifierParams(PatchClassifierParams params) {
        this.patchWidth = params.patchWidth;
        this.patchHeight = params.patchHeight;
        this.halo = params.halo;
        this.inputResolution = params.inputResolution;
        this.inputChannels = params.inputChannels == null ? null : new ArrayList<>(params.inputChannels);
        this.outputChannelType = params.outputChannelType;
        this.preprocessingOps = params.preprocessingOps == null ? null : new ArrayList<>(params.preprocessingOps);
        this.postprocessingOps = params.postprocessingOps == null ? null : new ArrayList<>(params.postprocessingOps);
        this.predictionOp = params.predictionOp;
        this.outputClasses = params.outputClasses == null ? null : new LinkedHashMap<>(params.outputClasses);
    }

    /**
     * Returns the channels to extract from the input image.
     *
     * @return a copy of the input channels, or an empty list if none were set
     */
    public List<ColorTransforms.ColorTransform> getInputChannels() {
        return this.inputChannels == null ? Collections.emptyList() : new ArrayList<ColorTransforms.ColorTransform>(this.inputChannels);
    }

    /**
     * Returns the patch width.
     *
     * @return the patch width ({@value #DEFAULT_PATCH_SIZE} unless set via the builder)
     */
    public int getPatchWidth() {
        return this.patchWidth;
    }

    /**
     * Returns the patch height.
     *
     * @return the patch height ({@value #DEFAULT_PATCH_SIZE} unless set via the builder)
     */
    public int getPatchHeight() {
        return this.patchHeight;
    }

    /**
     * Returns the halo (padding) around each patch.
     *
     * @return the halo, or {@link Padding#empty()} if none was set; never null
     */
    public Padding getHalo() {
        return this.halo == null ? Padding.empty() : this.halo;
    }

    /**
     * Returns the resolution at which the input should be requested.
     *
     * @return the input resolution, or null if none was set
     */
    public PixelCalibration getInputResolution() {
        return this.inputResolution;
    }

    /**
     * Returns the channel type of the classifier output.
     *
     * @return the output channel type, or null if none was set
     */
    public ImageServerMetadata.ChannelType getOutputChannelType() {
        return this.outputChannelType;
    }

    /**
     * Returns the mapping from output index to classification.
     *
     * @return a copy of the output classes in insertion order, or an empty map if none are stored
     */
    public Map<Integer, PathClass> getOutputClasses() {
        // The copy constructor permits a null map, so guard against it here rather than NPE
        return this.outputClasses == null ? Collections.emptyMap() : new LinkedHashMap<Integer, PathClass>(this.outputClasses);
    }

    /**
     * Returns the ops applied before prediction.
     *
     * @return a copy of the preprocessing ops, or an empty list if none were set
     */
    public List<ImageOp> getPreprocessing() {
        return this.preprocessingOps == null ? Collections.emptyList() : new ArrayList<ImageOp>(this.preprocessingOps);
    }

    /**
     * Returns the op that performs the prediction.
     *
     * @return the prediction op, or null if none was set
     */
    public ImageOp getPredictionOp() {
        return this.predictionOp;
    }

    /**
     * Returns the ops applied after prediction.
     *
     * @return a copy of the postprocessing ops, or an empty list if none were set
     */
    public List<ImageOp> getPostprocessing() {
        return this.postprocessingOps == null ? Collections.emptyList() : new ArrayList<ImageOp>(this.postprocessingOps);
    }

    /**
     * Builds a {@link PixelClassifier} from the given parameters.
     * <p>
     * The classifier's op chain is built in this order:
     * <ol>
     *   <li>an {@link ImageDataOp} that extracts the input channels;</li>
     *   <li>the preprocessing ops;</li>
     *   <li>the prediction op, if any;</li>
     *   <li>the postprocessing ops.</li>
     * </ol>
     * Only a symmetric halo is supported. Any other non-empty halo is logged as a warning and
     * ignored, so no padding is used.
     *
     * @param params the parameters to use
     * @return a pixel classifier wrapping the op chain and its metadata
     * @throws NullPointerException if {@code params} is null
     */
    public static PixelClassifier buildPixelClassifier(PatchClassifierParams params) {
        Objects.requireNonNull(params, "params");
        int pad = symmetricPaddingOrZero(params.halo);

        // The getters never return null lists, so no null checks are needed here
        List<ImageOp> ops = new ArrayList<>(params.getPreprocessing());
        ImageOp prediction = params.getPredictionOp();
        if (prediction != null) {
            ops.add(prediction);
        }
        ops.addAll(params.getPostprocessing());

        ImageDataOp dataOp = ImageOps.buildImageDataOp(params.getInputChannels())
                .appendOps(ops.toArray(ImageOp[]::new));

        // NOTE(review): output pixel type is queried from the op chain for FLOAT32 input;
        // confirm this matches how ImageDataOp supplies pixels to the first op.
        PixelClassifierMetadata metadata = new PixelClassifierMetadata.Builder()
                .inputShape(params.getPatchWidth(), params.getPatchHeight())
                .classificationLabels(params.getOutputClasses())
                .setChannelType(params.getOutputChannelType())
                .inputResolution(params.getInputResolution())
                .inputPadding(pad)
                .outputPixelType(dataOp.getOutputType(PixelType.FLOAT32))
                .build();
        return PixelClassifiers.createClassifier(dataOp, metadata);
    }

    /**
     * Converts a halo to the single padding value used in the classifier metadata.
     *
     * @param padding the halo; may be null
     * @return the padding value if the halo is symmetric and non-empty, otherwise 0 (a warning
     *         is logged for an asymmetric halo)
     */
    private static int symmetricPaddingOrZero(Padding padding) {
        if (padding == null || padding.isEmpty()) {
            return 0;
        }
        if (padding.isSymmetric()) {
            return padding.getX1();
        }
        logger.warn("Only symmetric padding is supported - {} will be ignored", padding);
        return 0;
    }

    /**
     * Creates a builder with default parameters.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }

    /**
     * Creates a builder initialized from a copy of existing parameters.
     *
     * @param params the parameters to start from; if null, defaults are used
     * @return a new builder
     */
    public static Builder builder(PatchClassifierParams params) {
        return new Builder(params);
    }

    /**
     * Builder for {@link PatchClassifierParams}.
     * <p>
     * Each call to {@link #build()} returns a fresh copy, so later builder calls do not affect
     * instances that were already built.
     */
    public static class Builder {

        private final PatchClassifierParams params;

        // A DNN-based prediction is stored separately and only turned into an op in build(),
        // so that it uses the final patch size even if patchSize(...) is called afterwards.
        private Padding dnnPadding;
        private DnnModel dnnModel;
        private String[] dnnOutputNames;

        private Builder() {
            this(null);
        }

        private Builder(PatchClassifierParams params) {
            this.params = params == null ? new PatchClassifierParams() : new PatchClassifierParams(params);
        }

        /**
         * Sets the input channels by name.
         * <p>
         * Each name is converted with {@link ColorTransforms#createChannelExtractor(String)}.
         *
         * @param channels the channel names
         * @return this builder
         */
        public Builder inputChannels(String... channels) {
            return this.inputChannels(Arrays.stream(channels).map(ColorTransforms::createChannelExtractor).toList());
        }

        /**
         * Sets the input channels by index.
         * <p>
         * Each index is converted with {@link ColorTransforms#createChannelExtractor(int)}.
         *
         * @param channels the channel indices
         * @return this builder
         */
        public Builder inputChannels(int... channels) {
            return this.inputChannels(Arrays.stream(channels).mapToObj(ColorTransforms::createChannelExtractor).toList());
        }

        /**
         * Sets the input channels.
         *
         * @param channels the channel transforms; copied into a new list
         * @return this builder
         */
        public Builder inputChannels(Collection<? extends ColorTransforms.ColorTransform> channels) {
            this.params.inputChannels = new ArrayList<>(channels);
            return this;
        }

        /**
         * Sets the input resolution as a downsampled version of a calibration.
         *
         * @param cal the base calibration; if null, {@link PixelCalibration#getDefaultInstance()} is used
         * @param downsample the downsample factor, applied equally in x and y
         * @return this builder
         */
        public Builder inputResolution(PixelCalibration cal, double downsample) {
            PixelCalibration base = cal == null ? PixelCalibration.getDefaultInstance() : cal;
            return this.inputResolution(base.createScaledInstance(downsample, downsample));
        }

        /**
         * Sets the input resolution.
         *
         * @param cal the calibration; may be null
         * @return this builder
         */
        public Builder inputResolution(PixelCalibration cal) {
            this.params.inputResolution = cal;
            return this;
        }

        /**
         * Sets a symmetric halo.
         *
         * @param padding the padding applied on every side
         * @return this builder
         */
        public Builder halo(int padding) {
            return this.halo(Padding.symmetric(padding));
        }

        /**
         * Sets the halo.
         * <p>
         * Only symmetric halos are honored by
         * {@link PatchClassifierParams#buildPixelClassifier(PatchClassifierParams)}.
         *
         * @param halo the halo; null means no halo
         * @return this builder
         */
        public Builder halo(Padding halo) {
            this.params.halo = halo;
            return this;
        }

        /**
         * Sets a square patch size.
         *
         * @param patchSize the patch width and height
         * @return this builder
         */
        public Builder patchSize(int patchSize) {
            return this.patchSize(patchSize, patchSize);
        }

        /**
         * Sets the patch size.
         *
         * @param patchWidth the patch width
         * @param patchHeight the patch height
         * @return this builder
         */
        public Builder patchSize(int patchWidth, int patchHeight) {
            this.params.patchWidth = patchWidth;
            this.params.patchHeight = patchHeight;
            return this;
        }

        /**
         * Sets the ops applied before prediction.
         *
         * @param preprocessingOps the ops; copied so later changes to the array have no effect
         * @return this builder
         */
        public Builder preprocessing(ImageOp... preprocessingOps) {
            this.params.preprocessingOps = new ArrayList<>(Arrays.asList(preprocessingOps));
            return this;
        }

        /**
         * Sets the ops applied before prediction.
         *
         * @param preprocessingOps the ops; copied into a new list
         * @return this builder
         */
        public Builder preprocessing(Collection<? extends ImageOp> preprocessingOps) {
            this.params.preprocessingOps = new ArrayList<>(preprocessingOps);
            return this;
        }

        /**
         * Sets the op that performs the prediction.
         *
         * @param predictionOp the prediction op
         * @return this builder
         */
        public Builder prediction(ImageOp predictionOp) {
            this.params.predictionOp = predictionOp;
            return this;
        }

        /**
         * Uses a {@link DnnModel} for prediction.
         * <p>
         * The op is created in {@link #build()} with {@code ImageOps.ML.dnn}, using the patch
         * size in effect at that time.
         *
         * @param model the model
         * @param padding the padding passed to the DNN op
         * @param outputNames the model outputs to use; the array is copied
         * @return this builder
         */
        public Builder prediction(DnnModel model, Padding padding, String... outputNames) {
            this.dnnModel = model;
            this.dnnPadding = padding;
            this.dnnOutputNames = outputNames == null ? null : outputNames.clone();
            return this;
        }

        /**
         * Sets the ops applied after prediction.
         *
         * @param postprocessingOps the ops; copied so later changes to the array have no effect
         * @return this builder
         */
        public Builder postprocessing(ImageOp... postprocessingOps) {
            this.params.postprocessingOps = new ArrayList<>(Arrays.asList(postprocessingOps));
            return this;
        }

        /**
         * Sets the ops applied after prediction.
         *
         * @param postprocessingOps the ops; copied into a new list
         * @return this builder
         */
        public Builder postprocessing(Collection<? extends ImageOp> postprocessingOps) {
            this.params.postprocessingOps = new ArrayList<>(postprocessingOps);
            return this;
        }

        /**
         * Sets the channel type of the classifier output.
         *
         * @param channelType the output channel type
         * @return this builder
         */
        public Builder outputChannelType(ImageServerMetadata.ChannelType channelType) {
            this.params.outputChannelType = channelType;
            return this;
        }

        /**
         * Sets the mapping from output index to classification.
         *
         * @param outputClasses the mapping; copied, preserving its iteration order
         * @return this builder
         */
        public Builder outputClasses(Map<Integer, PathClass> outputClasses) {
            this.params.outputClasses = new LinkedHashMap<>(outputClasses);
            return this;
        }

        /**
         * Sets the output classes by position.
         * <p>
         * Output index {@code i} maps to {@code outputClasses[i]}.
         *
         * @param outputClasses the classes in output order
         * @return this builder
         */
        public Builder outputClasses(PathClass... outputClasses) {
            Map<Integer, PathClass> map = new LinkedHashMap<>();
            for (int i = 0; i < outputClasses.length; i++) {
                map.put(i, outputClasses[i]);
            }
            return this.outputClasses(map);
        }

        /**
         * Sets the output classes by position, using class names.
         * <p>
         * Each name is converted with {@link PathClass#fromString(String)}.
         *
         * @param outputClasses the class names in output order
         * @return this builder
         */
        public Builder outputClassNames(String... outputClasses) {
            return this.outputClasses(Arrays.stream(outputClasses).map(PathClass::fromString).toArray(PathClass[]::new));
        }

        /**
         * Sets the mapping from output index to class name.
         * <p>
         * Each name is converted with {@link PathClass#fromString(String)}. The iteration order
         * of {@code outputClasses} is preserved.
         *
         * @param outputClasses the mapping from output index to class name
         * @return this builder
         * @throws NullPointerException if a name converts to a null class
         */
        public Builder outputClassNames(Map<Integer, String> outputClasses) {
            // Collect into a LinkedHashMap; the plain 2-arg toMap gives a HashMap and loses order
            Map<Integer, PathClass> map = outputClasses.entrySet().stream()
                    .collect(Collectors.toMap(
                            Map.Entry::getKey,
                            e -> PathClass.fromString(e.getValue()),
                            (first, second) -> second,
                            LinkedHashMap::new));
            return this.outputClasses(map);
        }

        /**
         * Builds a new {@link PatchClassifierParams} from the current builder state.
         *
         * @return a new instance, independent of later builder changes
         * @throws IllegalArgumentException if both a {@link DnnModel} and a prediction op were provided
         */
        public PatchClassifierParams build() {
            PatchClassifierParams result = new PatchClassifierParams(this.params);
            if (this.dnnModel != null) {
                if (result.predictionOp != null) {
                    throw new IllegalArgumentException("Both DnnModel and prediction op were provided - only one is allowed");
                }
                result.predictionOp = ImageOps.ML.dnn(this.dnnModel, result.patchWidth, result.patchHeight, this.dnnPadding, this.dnnOutputNames);
            }
            return result;
        }
    }
}

