/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml;

import java.net.URI;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.bioimageio.spec.Model;
import qupath.bioimageio.spec.Weights;
import qupath.bioimageio.spec.tensor.InputTensor;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.bioimageio.spec.tensor.Processing;
import qupath.bioimageio.spec.tensor.axes.Axes;
import qupath.bioimageio.spec.tensor.axes.Axis;
import qupath.lib.common.GeneralTools;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelType;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.Padding;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnModelParams;
import qupath.opencv.dnn.DnnModels;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.ml.PatchClassifierParams;
import qupath.opencv.ops.ImageOp;
import qupath.opencv.ops.ImageOps;

public class BioimageIoTools {
    private static final Logger logger = LoggerFactory.getLogger(BioimageIoTools.class);

    /**
     * Attempt to create a {@link DnnModel} from the weights declared in a model spec.
     * <p>
     * Weight formats are tried in a fixed order (TorchScript, TensorFlow saved model bundle,
     * ONNX, TensorFlow.js, PyTorch state dict, Keras HDF5) and the first model that can be
     * built is returned. Weights whose source starts with {@code http:} or {@code https:} are
     * skipped, as are local weights that do not exist on disk. Any exception raised while
     * handling one format is logged and the next format is tried.
     *
     * @param spec the model spec providing the weights, base URI, inputs and outputs
     * @return the first model built successfully, or null if none could be built
     */
    public static DnnModel buildDnnModel(Model spec) {
        List<Weights.WeightsEntry> preferenceOrder = Arrays.asList(
                Weights.WeightsEntry.TORCHSCRIPT,
                Weights.WeightsEntry.TENSORFLOW_SAVED_MODEL_BUNDLE,
                Weights.WeightsEntry.ONNX,
                Weights.WeightsEntry.TENSORFLOW_JS,
                Weights.WeightsEntry.PYTORCH_STATE_DICT,
                Weights.WeightsEntry.KERAS_HDF5);

        DnnModel model = null;
        for (Weights.WeightsEntry entry : preferenceOrder) {
            try {
                Weights.ModelWeights modelWeights = spec.getWeights(entry);
                if (modelWeights == null) {
                    continue;
                }
                URI base = spec.getBaseURI();
                Path baseDir = GeneralTools.toPath(base);
                String sourceName = modelWeights.getSource();
                String sourceLower = sourceName.toLowerCase();
                if (sourceLower.startsWith("http:") || sourceLower.startsWith("https:")) {
                    logger.debug("Don't support source {}", sourceName);
                    continue;
                }

                URI weightsUri;
                if (baseDir != null) {
                    // Zipped weights are expected to have been unzipped next to the archive,
                    // into a path with the same name minus the ".zip" suffix
                    boolean zipped = sourceLower.endsWith(".zip");
                    String localName = zipped ? sourceName.substring(0, sourceName.length() - 4) : sourceName;
                    Path weightsPath = baseDir.resolve(localName);
                    if (!Files.exists(weightsPath)) {
                        logger.warn(zipped ? "Please unzip the model weights to {}" : "Can't find model weights at {}", weightsPath);
                        continue;
                    }
                    weightsUri = weightsPath.toUri();
                } else {
                    // The base URI could not be converted to a path, so resolve as a URI instead
                    weightsUri = resolveUri(base, sourceName);
                }

                // Formats without a known framework leave the name null
                String framework = switch (entry) {
                    case ONNX -> "OnnxRuntime";
                    case TENSORFLOW_SAVED_MODEL_BUNDLE -> "TensorFlow";
                    case TORCHSCRIPT -> "PyTorch";
                    default -> null;
                };

                String layout = Axes.getAxesString(spec.getInputs().getFirst().getAxes());
                Map<String, DnnShape> inputShapes = spec.getInputs().stream()
                        .collect(Collectors.toMap(InputTensor::getName, BioimageIoTools::getMinShape));
                // Output shapes are derived from the input shape only when there is exactly one input
                DnnShape singleInputShape = inputShapes.size() == 1 ? inputShapes.values().iterator().next() : null;
                Map<String, DnnShape> outputShapes = spec.getOutputs().stream()
                        .collect(Collectors.toMap(OutputTensor::getName, o -> getOutputShapeFromInput(o, singleInputShape)));

                DnnModelParams params = DnnModelParams.builder()
                        .framework(framework)
                        .URIs(weightsUri)
                        .layout(layout)
                        .inputs(inputShapes)
                        .outputs(outputShapes)
                        .build();
                model = DnnModels.buildModel(params);
                if (model == null) {
                    logger.warn("Unable to build model for weights {} (source={}, framework={})", entry, weightsUri, framework);
                } else {
                    logger.info("Loaded model {}", model);
                }
            } catch (Exception e) {
                logger.warn("Unsupported weights: {}", entry);
                logger.error(e.getLocalizedMessage(), e);
            }
            if (model != null) {
                break;
            }
        }
        return model;
    }

    /**
     * Build patch classifier parameters for a model without requesting a specific tile size.
     * <p>
     * Delegates to the overload that takes preferred tile dimensions, passing -1 for both
     * the width and the height.
     *
     * @param model the model spec
     * @param inputOps optional ops forwarded unchanged to the delegate
     * @return the parameters created by the delegate
     */
    public static PatchClassifierParams buildPatchClassifierParams(Model model, ImageOp ... inputOps) {
        final int noPreference = -1;
        return buildPatchClassifierParams(model, noPreference, noPreference, inputOps);
    }

    /**
     * Build patch classifier parameters from a model spec with one input and one output.
     * <p>
     * The patch size starts from the minimum input shape and is enlarged in multiples of the
     * input shape step, without exceeding the preferred tile size. The preprocessing list is
     * made of the supplied {@code inputOps}, a conversion to 32-bit float, and then every
     * input preprocessing transform that can be converted to an {@link ImageOp}. Transforms
     * that cannot be converted are logged and skipped.
     *
     * @param modelSpec the model spec
     * @param preferredTileWidth preferred patch width; values &lt;= 0 are replaced by 512
     * @param preferredTileHeight preferred patch height; values &lt;= 0 are replaced by 512
     * @param inputOps optional ops applied before the spec's own preprocessing
     * @return the classifier parameters
     * @throws UnsupportedOperationException if the model does not have exactly one input and
     *         one output, if the input axes lack x, y or c, or if no {@link DnnModel} can be built
     */
    public static PatchClassifierParams buildPatchClassifierParams(Model modelSpec, int preferredTileWidth, int preferredTileHeight, ImageOp ... inputOps) {
        var inputs = modelSpec.getInputs();
        if (inputs.size() != 1) {
            throw new UnsupportedOperationException("Only single inputs currently supported! Model requires " + inputs.size());
        }
        var outputs = modelSpec.getOutputs();
        if (outputs.size() != 1) {
            throw new UnsupportedOperationException("Only single outputs currently supported! Model requires " + outputs.size());
        }
        InputTensor input = inputs.getFirst();
        OutputTensor output = outputs.getFirst();

        String axes = Axes.getAxesString(input.getAxes());
        int indChannels = axes.indexOf("c");
        int indX = axes.indexOf("x");
        int indY = axes.indexOf("y");
        if (indChannels < 0 || indX < 0 || indY < 0) {
            // Fail with a clear message rather than an ArrayIndexOutOfBoundsException below
            throw new UnsupportedOperationException("Input axes must include x, y and c, but were " + axes);
        }

        var inputShapeSpec = input.getShape();
        int[] shapeMin = inputShapeSpec.getShapeMin();
        // The step must come from the shape step, not the minimum: using the minimum would
        // only allow sizes that are multiples of the minimum shape
        int[] shapeStep = inputShapeSpec.getShapeStep();
        int nChannelsIn = shapeMin[indChannels];

        int targetWidth = preferredTileWidth <= 0 ? 512 : preferredTileWidth;
        int targetHeight = preferredTileHeight <= 0 ? 512 : preferredTileHeight;
        int width = updateLength(shapeMin[indX], shapeStep[indX], targetWidth);
        int height = updateLength(shapeMin[indY], shapeStep[indY], targetHeight);

        long[] inputShape = Arrays.stream(shapeMin).asLongStream().toArray();
        inputShape[indX] = width;
        inputShape[indY] = height;

        // Use the explicit output shape if there is one, otherwise derive it from the
        // input shape using the output's scale and offset
        var outputShapeSpec = output.getShape();
        int[] outputShape = outputShapeSpec.getShape();
        if (outputShape == null || outputShape.length == 0) {
            double[] scale = outputShapeSpec.getScale();
            double[] offset = outputShapeSpec.getOffset();
            outputShape = new int[scale.length];
            for (int i = 0; i < outputShape.length; i++) {
                outputShape[i] = (int) Math.round((double) inputShape[i] * scale[i] + offset[i]);
            }
        }
        int nChannelsOut = outputShape[indChannels];

        Padding padding = Padding.empty();
        int[] halo = output.getHalo();
        if (halo != null && halo.length != 0) {
            padding = Padding.getPadding(halo[indX], halo[indY]);
        }

        List<ImageOp> preprocessing = new ArrayList<>();
        Collections.addAll(preprocessing, inputOps);
        preprocessing.add(ImageOps.Core.ensureType(PixelType.FLOAT32));
        if (input.getPreprocessing() != null) {
            for (Processing transform : input.getPreprocessing()) {
                ImageOp op = transformToOp(transform);
                if (op == null) {
                    logger.warn("Unsupported preprocessing transform: {}", transform);
                } else {
                    preprocessing.add(op);
                }
            }
        }

        DnnModel dnn = buildDnnModel(modelSpec);
        if (dnn == null) {
            throw new UnsupportedOperationException("Unable to create a DnnModel for " + modelSpec.getName() + ".\nCheck 'View > Show log' for more details.");
        }

        List<ImageOp> postprocessing = new ArrayList<>();
        if (output.getPostprocessing() != null) {
            for (Processing transform : output.getPostprocessing()) {
                ImageOp op = transformToOp(transform);
                if (op == null) {
                    logger.warn("Unsupported postprocessing transform: {}", transform);
                } else {
                    postprocessing.add(op);
                }
            }
        }

        // One placeholder class per output channel, in channel order
        Map<Integer, PathClass> labels = new LinkedHashMap<>();
        for (int c = 0; c < nChannelsOut; c++) {
            labels.put(c, PathClass.getInstance("Class " + c));
        }

        return PatchClassifierParams.builder()
                .inputChannels(IntStream.range(0, nChannelsIn).toArray())
                .patchSize(width, height)
                .halo(padding)
                .preprocessing(preprocessing)
                .prediction(dnn, padding)
                .postprocessing(postprocessing)
                .outputClasses(labels)
                .outputChannelType(ImageServerMetadata.ChannelType.PROBABILITY)
                .build();
    }

    /**
     * Convert a collection of transforms into a single {@link ImageOp}.
     * <p>
     * The result always starts with a conversion to 32-bit float, followed by the op for each
     * transform in iteration order. Transforms for which {@link #transformToOp(Processing)}
     * returns null are left out.
     *
     * @param transforms the transforms to convert
     * @return the float conversion op alone if no transform could be converted, otherwise a
     *         sequential op combining the float conversion with the converted transforms
     */
    public static ImageOp transformsToOp(Collection<? extends Processing> transforms) {
        List<ImageOp> pipeline = new ArrayList<>();
        pipeline.add(ImageOps.Core.ensureType(PixelType.FLOAT32));
        pipeline.addAll(transforms.stream()
                .map(BioimageIoTools::transformToOp)
                .filter(Objects::nonNull)
                .toList());
        if (pipeline.size() == 1) {
            return pipeline.getFirst();
        }
        return ImageOps.Core.sequential(pipeline);
    }

    /**
     * Convert a single {@link Processing} transform into an {@link ImageOp}.
     * <p>
     * Binarize, Clip, ScaleLinear, ScaleRange, Sigmoid and ZeroMeanUnitVariance are converted.
     * ScaleMeanVariance and any unrecognized transform are logged as a warning and give null.
     *
     * @param transform the transform to convert
     * @return the corresponding op, or null if the transform cannot be converted
     */
    public static ImageOp transformToOp(Processing transform) {
        if (transform instanceof Processing.Binarize binarize) {
            return ImageOps.Threshold.threshold(binarize.getThreshold());
        }
        if (transform instanceof Processing.Clip clip) {
            return ImageOps.Core.clip(clip.getMin(), clip.getMax());
        }
        if (transform instanceof Processing.ScaleLinear scaleLinear) {
            return ImageOps.Core.sequential(
                    ImageOps.Core.multiply(scaleLinear.getGain()),
                    ImageOps.Core.add(scaleLinear.getOffset()));
        }
        if (transform instanceof Processing.ScaleMeanVariance scaleMeanVariance) {
            logger.warn("Unsupported transform {} - cannot access reference tensor {}", transform, scaleMeanVariance.getReferenceTensor());
            return null;
        }
        if (transform instanceof Processing.ScaleRange scaleRange) {
            return scaleRangeToOp(scaleRange);
        }
        if (transform instanceof Processing.Sigmoid) {
            return ImageOps.Normalize.sigmoid();
        }
        if (transform instanceof Processing.ZeroMeanUnitVariance zeroMeanUnitVariance) {
            return zeroMeanUnitVarianceToOp(zeroMeanUnitVariance);
        }
        logger.warn("Unknown transform {} - cannot convert to ImageOp", transform);
        return null;
    }

    /**
     * Convert a ScaleRange transform into a percentile normalization op.
     * Normalization is per channel when the transform's axes string does not contain "c".
     */
    private static ImageOp scaleRangeToOp(Processing.ScaleRange scaleRange) {
        Processing.ProcessingMode mode = warnIfUnsupportedMode(
                scaleRange.getName(), scaleRange.getMode(), List.of(Processing.ProcessingMode.PER_SAMPLE));
        assert mode == Processing.ProcessingMode.PER_SAMPLE;
        String axes = Axes.getAxesString(scaleRange.getAxes());
        boolean perChannel = false;
        if (axes != null) {
            perChannel = !axes.contains("c");
        } else {
            logger.warn("Axes not specified for {} - channels will be normalized jointly", scaleRange);
        }
        return ImageOps.Normalize.percentile(scaleRange.getMinPercentile(), scaleRange.getMaxPercentile(), perChannel, scaleRange.getEps());
    }

    /**
     * Convert a ZeroMeanUnitVariance transform into an op, either computing statistics per
     * sample or using the fixed mean and standard deviation stored in the transform.
     */
    private static ImageOp zeroMeanUnitVarianceToOp(Processing.ZeroMeanUnitVariance zeroMeanUnitVariance) {
        Processing.ProcessingMode mode = warnIfUnsupportedMode(
                zeroMeanUnitVariance.getName(), zeroMeanUnitVariance.getMode(),
                List.of(Processing.ProcessingMode.PER_SAMPLE, Processing.ProcessingMode.FIXED));
        if (mode == Processing.ProcessingMode.PER_SAMPLE) {
            String axes = Axes.getAxesString(zeroMeanUnitVariance.getAxes());
            boolean perChannel = false;
            if (axes != null) {
                perChannel = !axes.contains("c");
                if (!sameAxes(axes, "xy") && !sameAxes(axes, "xyc")) {
                    logger.warn("Unsupported axes {} for {} - I will use {} instead", axes, zeroMeanUnitVariance.getName(), perChannel ? "xy" : "xyc");
                }
            } else {
                logger.warn("Axes not specified for {} - channels will be normalized jointly", zeroMeanUnitVariance);
            }
            return ImageOps.Normalize.zeroMeanUnitVariance(perChannel, zeroMeanUnitVariance.getEps());
        }
        assert mode == Processing.ProcessingMode.FIXED;
        // Add eps into a new array rather than writing into the array returned by getStd(),
        // so the transform's own values can never be changed by a conversion
        double eps = zeroMeanUnitVariance.getEps();
        double[] stdPlusEps = Arrays.stream(zeroMeanUnitVariance.getStd()).map(s -> s + eps).toArray();
        return ImageOps.Core.sequential(
                ImageOps.Core.subtract(zeroMeanUnitVariance.getMean()),
                ImageOps.Core.divide(stdPlusEps));
    }

    /**
     * Create a {@link DnnShape} from the minimum shape of an input tensor.
     *
     * @param spec the input tensor
     * @return a shape holding the tensor's minimum shape values, widened to long
     */
    private static DnnShape getMinShape(InputTensor spec) {
        int[] minimum = spec.getShape().getShapeMin();
        long[] dimensions = Arrays.stream(minimum).asLongStream().toArray();
        return DnnShape.of(dimensions);
    }

    /**
     * Determine the shape of an output tensor, using the input shape when it is available.
     * <p>
     * With an input shape, the output shape is whatever the output's shape spec reports as
     * its target shape for that input. Without one, the output's offset values (truncated to
     * long) are used, and a warning is logged if any scale value is non-zero.
     *
     * @param outputSpec the output tensor
     * @param inputShape the input shape, or null if it is not available
     * @return the output shape
     */
    private static DnnShape getOutputShapeFromInput(OutputTensor outputSpec, DnnShape inputShape) {
        if (inputShape != null) {
            int nDims = inputShape.numDimensions();
            // The shape spec works with int arrays, so narrow on the way in and widen on the way out
            int[] inputDims = IntStream.range(0, nDims)
                    .map(d -> (int) inputShape.get(d))
                    .toArray();
            int[] targetDims = outputSpec.getShape().getTargetShape(inputDims);
            long[] outputDims = IntStream.range(0, nDims)
                    .mapToLong(d -> targetDims[d])
                    .toArray();
            return DnnShape.of(outputDims);
        }

        boolean scaled = Arrays.stream(outputSpec.getShape().getScale()).anyMatch(s -> s != 0.0);
        if (scaled) {
            logger.warn("Attempting to infer scaled output shape, but input shape is not available");
        }
        long[] offsetDims = Arrays.stream(outputSpec.getShape().getOffset())
                .mapToLong(offset -> (long) offset)
                .toArray();
        return DnnShape.of(offsetDims);
    }

    /**
     * Resolve a relative location against a base URI.
     * <p>
     * A base beginning with {@code jar:file:} is handled by removing the leading {@code jar:},
     * resolving against the remainder, and putting {@code jar:} back in front of the result.
     * Any other base is resolved directly.
     *
     * @param base the base URI
     * @param relative the location to resolve against the base
     * @return the resolved URI
     */
    private static URI resolveUri(URI base, String relative) {
        String baseText = base.toString();
        if (!baseText.startsWith("jar:file:")) {
            return base.resolve(relative);
        }
        URI inner = URI.create(baseText.substring("jar:".length()));
        return URI.create("jar:" + inner.resolve(relative));
    }

    /**
     * Find the largest length of the form {@code minLength + k * step} (k &gt;= 0) that does
     * not exceed the target length.
     *
     * @param minLength the minimum length
     * @param step the increment between valid lengths
     * @param targetLength the length to approach without exceeding
     * @return the adjusted length, or {@code minLength} if the target is not larger than the
     *         minimum or the step is not positive
     */
    private static int updateLength(int minLength, int step, int targetLength) {
        boolean canGrow = targetLength > minLength && step > 0;
        if (!canGrow) {
            return minLength;
        }
        // Dropping the remainder from the target leaves the minimum plus a whole number of steps
        int remainder = (targetLength - minLength) % step;
        return targetLength - remainder;
    }

    /**
     * Check whether two axes strings contain the same characters, ignoring order and case.
     *
     * @param input the first axes string; may be null
     * @param target the second axes string; may be null
     * @return true if both are null, if they are equal, or if they have the same length and
     *         the same lower-cased characters in any order; false otherwise
     */
    private static boolean sameAxes(String input, String target) {
        if (input == null || target == null) {
            return input == null && target == null;
        }
        if (input.equals(target)) {
            return true;
        }
        if (input.length() != target.length()) {
            return false;
        }
        char[] sortedInput = input.toLowerCase().toCharArray();
        Arrays.sort(sortedInput);
        char[] sortedTarget = target.toLowerCase().toCharArray();
        Arrays.sort(sortedTarget);
        return Arrays.equals(sortedInput, sortedTarget);
    }

    /**
     * Check a processing mode against the modes a transform supports, falling back to the
     * first supported mode (with a warning) if it is not usable.
     * <p>
     * A mode is treated as unsupported if it is null, if it is {@code PER_DATASET}, or if it
     * is not contained in {@code allowed}.
     *
     * @param transformName name of the transform, used only in the warning message
     * @param mode the requested mode; may be null
     * @param allowed the supported modes; the first entry is the fallback
     * @return {@code mode} if it is supported, otherwise the first entry of {@code allowed}
     */
    private static Processing.ProcessingMode warnIfUnsupportedMode(String transformName, Processing.ProcessingMode mode, List<Processing.ProcessingMode> allowed) {
        // The null check must come first: immutable lists reject contains(null)
        boolean unsupported = mode == null
                || mode == Processing.ProcessingMode.PER_DATASET
                || !allowed.contains(mode);
        if (!unsupported) {
            return mode;
        }
        Processing.ProcessingMode fallback = allowed.getFirst();
        logger.warn("Unsupported mode {} for {}, will be switched to {}", mode, transformName, fallback);
        return fallback;
    }
}

