/*
 * Decompiled with CFR 0.152.
 */
package qupath.ext.djl;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.engine.Engine;
import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.Translator;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Progress;
import java.io.IOException;
import java.net.URI;
import java.nio.Buffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.indexer.BooleanIndexer;
import org.bytedeco.javacpp.indexer.ByteIndexer;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.HalfIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.javacpp.indexer.LongIndexer;
import org.bytedeco.javacpp.indexer.UByteIndexer;
import org.bytedeco.javacpp.indexer.UShortIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_dnn;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.djl.DjlDnnModel;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.tools.OpenCVTools;

/**
 * Static helper methods for working with Deep Java Library (DJL) from QuPath.
 * <p>
 * This includes querying and loading DJL engines, loading models, and converting between
 * OpenCV {@link Mat} and DJL {@link NDArray} representations.
 */
public class DjlTools {
    private static final Logger logger = LoggerFactory.getLogger(DjlTools.class);

    // Engine names, as passed to Engine.hasEngine(String) and Engine.getEngine(String)
    public static String ENGINE_PYTORCH = "PyTorch";
    public static String ENGINE_TENSORFLOW = "TensorFlow";
    public static String ENGINE_MXNET = "MXNet";
    public static String ENGINE_TFLITE = "TFLite";
    public static String ENGINE_ONNX_RUNTIME = "OnnxRuntime";
    public static String ENGINE_XGBOOST = "XGBoost";
    public static String ENGINE_LIGHTGBM = "LightGBM";
    public static String ENGINE_DLR = "DLR";
    public static String ENGINE_TENSORRT = "TensorRT";
    public static String ENGINE_PADDLEPADDLE = "PaddlePaddle";

    /**
     * Names of engines that have been successfully obtained via {@link #getEngine(String, boolean)}.
     */
    public static Set<String> loadedEngines = new HashSet<String>();

    /** Optional per-engine device overrides, applied when loading models. */
    private static final Map<String, Device> defaultDevices = new HashMap<String, Device>();

    /** All engine names declared by this class. */
    static Set<String> ALL_ENGINES = Set.of(ENGINE_DLR, ENGINE_LIGHTGBM, ENGINE_MXNET, ENGINE_ONNX_RUNTIME, ENGINE_PADDLEPADDLE, ENGINE_PYTORCH, ENGINE_TENSORFLOW, ENGINE_TENSORRT, ENGINE_TFLITE, ENGINE_XGBOOST);

    /** Guards the temporary modification of the DJL offline property in {@link #getEngine(String, boolean)}. */
    private static final Object lock = new Object();

    /** System property toggled while requesting an engine, to control whether DJL may download files. */
    private static final String PROP_DJL_OFFLINE = "ai.djl.offline";

    /**
     * Create a {@link DnnModel} from a single model URI, letting the engine be chosen automatically.
     *
     * @param uri the model location
     * @param ndLayout the layout of the input/output arrays (e.g. "NCHW")
     * @param inputShape optional input shape; if null, no input map is supplied
     * @return a new model
     */
    public static DnnModel createDnnModel(URI uri, String ndLayout, int[] inputShape) {
        // Map.of rejects null values, so only build an inputs map if a shape was provided;
        // otherwise pass null, in the same way that unknown outputs are passed
        Map<String, DnnShape> inputs = null;
        if (inputShape != null) {
            long[] dims = Arrays.stream(inputShape).mapToLong(i -> i).toArray();
            inputs = Map.of("input", DnnShape.of(dims));
        }
        return createDnnModel(null, uri, ndLayout, inputs, null);
    }

    /**
     * Create a {@link DnnModel} from a single model URI.
     *
     * @param engine the engine name, or null to infer it from the URI
     * @param uri the model location
     * @param ndLayout the layout of the input/output arrays
     * @param inputs optional map of input names to shapes
     * @param outputs optional map of output names to shapes
     * @return a new model
     */
    public static DnnModel createDnnModel(String engine, URI uri, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
        return createDnnModel(engine, Collections.singletonList(uri), ndLayout, inputs, outputs);
    }

    private static DnnModel createDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs) {
        // NOTE(review): the final 'false' is presumably the lazy-initialization flag of DjlDnnModel - confirm
        return new DjlDnnModel(engine, uris, ndLayout, inputs, outputs, false);
    }

    /**
     * Query whether DJL knows about an engine with the given name.
     * This does not guarantee that the engine can actually be loaded.
     *
     * @param name the engine name
     * @return true if {@link Engine#hasEngine(String)} returns true
     */
    public static boolean hasEngine(String name) {
        return Engine.hasEngine(name);
    }

    /**
     * Query whether an engine is available without downloading anything.
     * Engines that have already been loaded are reported as available immediately.
     *
     * @param name the engine name
     * @return true if the engine has been loaded, or can be loaded without downloading
     */
    public static boolean isEngineAvailable(String name) {
        if (loadedEngines.contains(name)) {
            return true;
        }
        if (!hasEngine(name)) {
            return false;
        }
        logger.debug("Need to try to get engine to test availability");
        return getEngine(name, false) != null;
    }

    /**
     * Request an engine, optionally allowing DJL to download any files it requires.
     * <p>
     * The {@code ai.djl.offline} system property is set for the duration of the request and
     * always restored afterwards (or cleared, if it was not previously set).
     *
     * @param name the engine name
     * @param downloadIfNeeded if true, DJL is allowed to work online; otherwise it is forced offline
     * @return the engine, or null if it could not be obtained
     * @throws IllegalArgumentException if DJL does not know about the engine
     */
    public static Engine getEngine(String name, boolean downloadIfNeeded) throws IllegalArgumentException {
        if (!hasEngine(name)) {
            throw new IllegalArgumentException("Requested engine " + name + " is not available!");
        }
        synchronized (lock) {
            String offlineStatus = System.getProperty(PROP_DJL_OFFLINE);
            try {
                System.setProperty(PROP_DJL_OFFLINE, downloadIfNeeded ? "false" : "true");
                Engine engine = Engine.getEngine(name);
                if (engine != null) {
                    loadedEngines.add(name);
                }
                return engine;
            } catch (Exception e) {
                if (downloadIfNeeded) {
                    logger.error("Unable to get engine " + name + ": " + e.getMessage(), e);
                } else if (e.getLocalizedMessage() == null) {
                    logger.warn("Unable to get engine {}", name);
                } else {
                    logger.warn("Unable to get engine {} ({})", name, e.getMessage());
                }
                return null;
            } finally {
                restoreSystemProperty(PROP_DJL_OFFLINE, offlineStatus);
            }
        }
    }

    /**
     * Restore a system property to a previous value.
     * System.setProperty rejects null values, so a null previous value clears the property instead.
     */
    private static void restoreSystemProperty(String key, String previousValue) {
        if (previousValue == null) {
            System.clearProperty(key);
        } else {
            System.setProperty(key, previousValue);
        }
    }

    /**
     * Convert a DJL {@link Shape} to a {@link DnnShape} with the same dimensions.
     */
    static DnnShape convertShape(Shape shape) {
        return DnnShape.of(shape.getShape());
    }

    /**
     * Load a model that maps {@link NDList} inputs to {@link NDList} outputs, without a custom translator.
     */
    static ZooModel<NDList, NDList> loadModel(String engineName, URI ... uris) throws ModelNotFoundException, MalformedModelException, IOException {
        return loadModel(engineName, NDList.class, NDList.class, null, uris);
    }

    /**
     * Load a model from one or more URIs.
     * A warning is logged for any URI that appears to refer to a zipped model.
     *
     * @param engineName the engine name, or null to infer it from the URI
     * @param inputClass the model input type
     * @param outputClass the model output type
     * @param translator optional translator (may be null)
     * @param uris the model locations; these are passed to DJL as a comma-separated list
     * @return the loaded model
     */
    static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, URI ... uris) throws ModelNotFoundException, MalformedModelException, IOException {
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < uris.length; i++) {
            String s = uris[i].toString();
            String lower = s.toLowerCase(Locale.ROOT);
            if (lower.startsWith("jar:file:") || lower.endsWith(".zip")) {
                logger.warn("Model URI is zipped - please unzip the model and recreate it");
            }
            if (i > 0) {
                sb.append(",");
            }
            sb.append(s);
        }
        return loadModel(engineName, inputClass, outputClass, translator, sb.toString());
    }

    private static <P, Q> ZooModel<P, Q> loadModel(String engineName, Class<P> inputClass, Class<Q> outputClass, Translator<P, Q> translator, String urls) throws ModelNotFoundException, MalformedModelException, IOException {
        Criteria.Builder<P, Q> builder = Criteria.builder()
                .setTypes(inputClass, outputClass)
                .optModelUrls(urls)
                .optTranslator(translator)
                .optProgress((Progress)new ProgressBar());

        String selectedEngine = null;
        if (engineName != null && Engine.getAllEngines().contains(engineName)) {
            selectedEngine = engineName;
        }
        if (selectedEngine == null) {
            selectedEngine = guessEngineFromUrls(urls);
        }
        if (selectedEngine != null) {
            builder.optEngine(selectedEngine);
            Device device = defaultDevices.get(selectedEngine);
            if (device != null) {
                builder.optDevice(device);
                // Allows weights saved on another device to be mapped to the requested device
                builder.optOption("mapLocation", "true");
            }
        }
        Criteria<P, Q> criteria = builder.build();
        return ModelZoo.loadModel(criteria);
    }

    /**
     * Guess a suitable (available) engine from the model URL ending.
     * Only the end of the full comma-separated string is checked, i.e. the last URL.
     *
     * @return the engine name, or null if no guess could be made
     */
    private static String guessEngineFromUrls(String urls) {
        String urlString = urls.toLowerCase(Locale.ROOT);
        if (urlString.endsWith(".onnx") && Engine.hasEngine(ENGINE_ONNX_RUNTIME)) {
            return ENGINE_ONNX_RUNTIME;
        }
        if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt")) && Engine.hasEngine(ENGINE_PYTORCH)) {
            return ENGINE_PYTORCH;
        }
        if (urlString.endsWith(".tflite") && Engine.hasEngine(ENGINE_TFLITE)) {
            return ENGINE_TFLITE;
        }
        if ((urlString.endsWith(".pb") || urlString.endsWith("tf_savedmodel.zip") || urlString.endsWith("tf_savedmodel")) && Engine.hasEngine(ENGINE_TENSORFLOW)) {
            return ENGINE_TENSORFLOW;
        }
        return null;
    }

    /**
     * Set the device to use when loading models with a specific engine.
     *
     * @param engineName the engine name
     * @param device the device to use, or null to remove any override
     */
    public static void setOverrideDevice(String engineName, Device device) {
        if (device == null) {
            defaultDevices.remove(engineName);
        } else {
            defaultDevices.put(engineName, device);
        }
    }

    /**
     * Get the device override for an engine.
     *
     * @param engineName the engine name
     * @return the device set by {@link #setOverrideDevice(String, Device)}, or null if none
     */
    public static Device getOverrideDevice(String engineName) {
        return defaultDevices.get(engineName);
    }

    /**
     * Apply a model to a single Mat, using a CHW layout for both input and output.
     */
    static Mat predict(Model model, Mat mat) throws TranslateException {
        try (Predictor<Mat, Mat> predictor = model.newPredictor(new MatTranslator("CHW", "CHW"))) {
            return predictor.batchPredict(Collections.singletonList(mat)).get(0);
        }
    }

    /**
     * Convert an OpenCV Mat to a DJL NDArray with the specified layout.
     *
     * @param manager the manager used to create the array
     * @param mat the image; its depth determines the data type of the array
     * @param ndLayout the array layout (e.g. "CHW", "NCHW", "HWC"); must contain "HW"
     * @return the new array
     * @throws IllegalArgumentException if the Mat depth is unsupported, the layout does not contain "HW",
     *                                  or the layout has no channel axis but the Mat has more than one channel
     */
    public static NDArray matToNDArray(NDManager manager, Mat mat, String ndLayout) {
        DataType dataType = getDataType(mat);
        if (dataType == DataType.UNKNOWN) {
            throw new IllegalArgumentException("Unsupported data type for " + String.valueOf(mat));
        }
        int indC = ndLayout.indexOf("C");
        int indHW = ndLayout.indexOf("HW");
        if (indHW < 0) {
            throw new IllegalArgumentException("Expected layout contains HW, but provided layout is " + ndLayout);
        }
        // createBuffer() exposes the raw pixel memory, which includes row padding for ROIs - so copy if needed
        Mat matContinuous = mat.isContinuous() ? mat : mat.clone();
        Shape shape = getShape(matContinuous, ndLayout);

        if (indC < 0) {
            // Without a channel axis, only single-channel images can be represented
            if (matContinuous.channels() != 1) {
                throw new IllegalArgumentException("Layout " + ndLayout + " has no channel dimension, but Mat has " + matContinuous.channels() + " channels");
            }
            return manager.create(matContinuous.createBuffer(), shape, dataType);
        }

        long nChannels = shape.get(indC);
        if (indC > indHW || nChannels == 1L) {
            // Interleaved OpenCV pixel order already matches the requested layout
            return manager.create(matContinuous.createBuffer(), shape, dataType);
        }
        if (matContinuous.depth() == opencv_core.CV_32F
                && ("NCHW".equals(ndLayout) || "CHW".equals(ndLayout))
                && (nChannels == 3L || nChannels == 4L)) {
            // blobFromImage (with default arguments) always outputs 32-bit float data,
            // so it can only be used when the Mat is already float32
            Mat blob = opencv_dnn.blobFromImage(matContinuous);
            return manager.create(blob.createBuffer(), shape, dataType);
        }
        return concatenateChannels(manager, matContinuous, shape, indC, dataType);
    }

    /**
     * Create an array by splitting the Mat into channels and concatenating them along the channel axis.
     */
    private static NDArray concatenateChannels(NDManager manager, Mat mat, Shape shape, int indC, DataType dataType) {
        long[] shapeDims = shape.getShape().clone();
        shapeDims[indC] = 1L;
        Shape shapeChannel = new Shape(shapeDims, shape.getLayout());
        NDArray array = null;
        for (Mat channel : OpenCVTools.splitChannels(mat)) {
            NDArray arrayChannel = manager.create(channel.createBuffer(), shapeChannel, dataType);
            if (array == null) {
                array = arrayChannel;
                continue;
            }
            NDArray merged = array.concat(arrayChannel, indC);
            array.close();
            arrayChannel.close();
            array = merged;
        }
        return array;
    }

    /**
     * Convert an NDArray to a Mat, squeezing leading and trailing singleton dimensions.
     *
     * @see #ndArrayToMat(NDArray, String, boolean)
     */
    public static Mat ndArrayToMat(NDArray array, String ndLayout) {
        return ndArrayToMat(array, ndLayout, true);
    }

    /**
     * Convert an NDArray to a Mat.
     *
     * @param array the array to convert
     * @param ndLayout the array layout; if null, the layout stored in the array's shape is used
     * @param doSqueeze if true, leading and trailing dimensions of size 1 are removed first
     * @return the Mat; if the channel axis precedes the height axis, channels are converted
     *         separately and merged into a multichannel Mat
     * @throws IllegalArgumentException if the layout is null and unknown, or the array cannot be converted
     */
    public static Mat ndArrayToMat(NDArray array, String ndLayout, boolean doSqueeze) {
        DataType dataType = array.getDataType();
        Shape shape = array.getShape();
        if (ndLayout == null) {
            if (!shape.isLayoutKnown()) {
                throw new IllegalArgumentException("Can't convert ndArray to Mat - layout is unknown");
            }
            ndLayout = LayoutType.toString(shape.getLayout());
        }
        int nDim = shape.dimension();
        int nLeading = doSqueeze ? shape.getLeadingOnes() : 0;
        // If every dimension is 1, leading and trailing ones overlap - don't count them twice
        int nTrailing = doSqueeze ? Math.min(shape.getTrailingOnes(), nDim - nLeading) : 0;
        int[] dims = new int[nDim - nLeading - nTrailing];
        for (int i = 0; i < dims.length; ++i) {
            dims[i] = (int)shape.get(i + nLeading);
        }
        if (nLeading + nTrailing > 0) {
            // Squeeze only the leading/trailing axes, so the array stays consistent with 'dims' and the layout
            int[] axes = new int[nLeading + nTrailing];
            for (int i = 0; i < nLeading; i++) {
                axes[i] = i;
            }
            for (int i = 0; i < nTrailing; i++) {
                axes[nLeading + i] = nDim - nTrailing + i;
            }
            array = array.squeeze(axes);
            ndLayout = ndLayout.substring(nLeading, ndLayout.length() - nTrailing);
        }
        int indH = ndLayout.indexOf("H");
        int indW = ndLayout.indexOf("W");
        int indC = ndLayout.indexOf("C");
        int height = dimOrOne(dims, indH);
        int width = dimOrOne(dims, indW);
        int nChannels = dimOrOne(dims, indC);
        if (nChannels > 1 && indC >= 0 && indC < indH) {
            // Planar channels: convert each separately, then merge into an interleaved Mat
            return splitToMultichannelMat(array, ndLayout, indC, nChannels);
        }
        int cvDepth = getMatDepth(dataType);
        // Use long arithmetic to avoid overflow for very large arrays
        long nValues = (long)width * height * nChannels;
        Mat mat = dims.length <= 3 && nValues == array.size()
                ? new Mat(height, width, opencv_core.CV_MAKETYPE(cvDepth, nChannels))
                : new Mat(dims, cvDepth);
        try (Indexer indexer = mat.createIndexer()) {
            putArrayValues(indexer, array);
        }
        return mat;
    }

    /** Return dims[ind] if ind is a valid index, otherwise 1. */
    private static int dimOrOne(int[] dims, int ind) {
        return ind >= 0 && ind < dims.length ? dims[ind] : 1;
    }

    /**
     * Split the array along the channel axis, convert each channel to a Mat and merge the results.
     */
    private static Mat splitToMultichannelMat(NDArray array, String ndLayout, int indC, int nChannels) {
        Mat merged = new Mat();
        // Temporary per-channel Mats are released when the scope closes; 'merged' was created outside it
        try (PointerScope scope = new PointerScope()) {
            ArrayList<Mat> channels = new ArrayList<Mat>(nChannels);
            try (NDList list = array.split((long)nChannels, indC)) {
                for (NDArray ndChannel : list) {
                    channels.add(ndArrayToMat(ndChannel, ndLayout, false));
                    ndChannel.close();
                }
            }
            OpenCVTools.mergeChannels(channels, merged);
        }
        return merged;
    }

    /**
     * Copy all values from the array into the indexer, converting to the indexer's element type.
     *
     * @throws IllegalArgumentException if the indexer type is not supported
     */
    private static void putArrayValues(Indexer indexer, NDArray array) {
        if (indexer instanceof ByteIndexer) {
            ((ByteIndexer)indexer).put(0L, array.toByteArray());
        } else if (indexer instanceof UByteIndexer) {
            ((UByteIndexer)indexer).put(0L, getInts(array));
        } else if (indexer instanceof UShortIndexer) {
            ((UShortIndexer)indexer).put(0L, getInts(array));
        } else if (indexer instanceof IntIndexer) {
            ((IntIndexer)indexer).put(0L, getInts(array));
        } else if (indexer instanceof FloatIndexer) {
            ((FloatIndexer)indexer).put(0L, getFloats(array));
        } else if (indexer instanceof HalfIndexer) {
            ((HalfIndexer)indexer).put(0L, getFloats(array));
        } else if (indexer instanceof DoubleIndexer) {
            ((DoubleIndexer)indexer).put(0L, getDoubles(array));
        } else if (indexer instanceof LongIndexer) {
            ((LongIndexer)indexer).put(0L, getLongs(array));
        } else if (indexer instanceof BooleanIndexer) {
            ((BooleanIndexer)indexer).put(0L, getBooleans(array));
        } else {
            throw new IllegalArgumentException("Unable to convert array " + String.valueOf(array) + " to Mat");
        }
    }

    /**
     * Get the values of an array as longs, converting the data type if necessary.
     */
    public static long[] getLongs(NDArray array) {
        if (array.getDataType() == DataType.INT64) {
            try {
                return array.toLongArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting longs from NDArray", e);
            }
        }
        return array.toType(DataType.INT64, true).toLongArray();
    }

    private static boolean[] getBooleans(NDArray array) {
        if (array.getDataType() == DataType.BOOLEAN) {
            try {
                return array.toBooleanArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting booleans from NDArray", e);
            }
        }
        return array.toType(DataType.BOOLEAN, true).toBooleanArray();
    }

    private static int[] getInts(NDArray array) {
        if (array.getDataType() == DataType.INT32) {
            try {
                return array.toIntArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting ints from NDArray", e);
            }
        } else if (array.getDataType() == DataType.UINT8) {
            try {
                return array.toUint8Array();
            }
            catch (Exception e) {
                logger.error("Exception requesting ints from NDArray", e);
            }
        }
        return array.toType(DataType.INT32, true).toIntArray();
    }

    private static double[] getDoubles(NDArray array) {
        if (array.getDataType() == DataType.FLOAT64) {
            try {
                return array.toDoubleArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting doubles from NDArray", e);
            }
        } else if (array.getDataType() == DataType.INT64) {
            try {
                return Arrays.stream(array.toLongArray()).mapToDouble(i -> i).toArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting doubles from NDArray (from longs)", e);
            }
        }
        return array.toDevice(Device.cpu(), false).toType(DataType.FLOAT64, true).toDoubleArray();
    }

    private static float[] getFloats(NDArray array) {
        if (array.getDataType() == DataType.FLOAT32 || array.getDataType() == DataType.FLOAT16) {
            try {
                return array.toFloatArray();
            }
            catch (Exception e) {
                logger.error("Exception requesting floats from NDArray", e);
            }
        }
        return array.toType(DataType.FLOAT32, true).toFloatArray();
    }

    /**
     * Create a DJL shape for a Mat with the given layout.
     * C, H and W take the Mat's channel count, height and width; any other axis has size 1.
     */
    static Shape getShape(Mat mat, String ndLayout) {
        PairList<Long, LayoutType> pairs = new PairList<Long, LayoutType>();
        for (LayoutType layout : LayoutType.fromValue(ndLayout)) {
            // Sizes must be boxed as Long, because Shape(PairList<Long, LayoutType>) unboxes them as longs
            long size;
            switch (layout) {
                case CHANNEL:
                    size = mat.arrayChannels();
                    break;
                case HEIGHT:
                    size = mat.arrayHeight();
                    break;
                case WIDTH:
                    size = mat.arrayWidth();
                    break;
                default:
                    size = 1L;
            }
            pairs.add(size, layout);
        }
        return new Shape(pairs);
    }

    /**
     * Get the DJL data type corresponding to the depth of a Mat.
     *
     * @return the data type, or {@link DataType#UNKNOWN} if the depth is not supported (e.g. CV_16U, CV_16S)
     */
    static DataType getDataType(Mat mat) {
        switch (mat.depth()) {
            case opencv_core.CV_8U:
                return DataType.UINT8;
            case opencv_core.CV_8S:
                return DataType.INT8;
            case opencv_core.CV_32S:
                return DataType.INT32;
            case opencv_core.CV_32F:
                return DataType.FLOAT32;
            case opencv_core.CV_64F:
                return DataType.FLOAT64;
            case opencv_core.CV_16F:
                return DataType.FLOAT16;
            default:
                return DataType.UNKNOWN;
        }
    }

    /**
     * Get the OpenCV depth used to store a DJL data type.
     * BOOLEAN is stored as CV_8U and INT64 as CV_64F.
     *
     * @throws UnsupportedOperationException if the data type cannot be represented
     */
    static int getMatDepth(DataType dt) {
        switch (dt) {
            case BOOLEAN:
                return opencv_core.CV_8U;
            case FLOAT16:
                return opencv_core.CV_16F;
            case FLOAT32:
                return opencv_core.CV_32F;
            case FLOAT64:
                return opencv_core.CV_64F;
            case INT32:
                return opencv_core.CV_32S;
            case INT64:
                // No 64-bit integer depth is used here; values are converted via getDoubles
                return opencv_core.CV_64F;
            case INT8:
                return opencv_core.CV_8S;
            case UINT8:
                return opencv_core.CV_8U;
            default:
                throw new UnsupportedOperationException("Cannot convert data type " + String.valueOf(dt) + " to Mat");
        }
    }

    /**
     * Translator that converts a single Mat to a single-element NDList and back,
     * using the configured input and output layouts.
     */
    static class MatTranslator
    implements Translator<Mat, Mat> {
        private final String inputLayoutNd;
        private final String outputLayoutNd;

        MatTranslator(String inputLayoutNd, String outputLayoutNd) {
            this.inputLayoutNd = inputLayoutNd;
            this.outputLayoutNd = outputLayoutNd;
        }

        @Override
        public NDList processInput(TranslatorContext ctx, Mat input) throws Exception {
            NDArray ndarray = DjlTools.matToNDArray(ctx.getNDManager(), input, this.inputLayoutNd);
            return new NDList(ndarray);
        }

        @Override
        public Mat processOutput(TranslatorContext ctx, NDList list) throws Exception {
            NDArray array = list.get(0);
            return DjlTools.ndArrayToMat(array, this.outputLayoutNd);
        }
    }
}

