/*
 * Decompiled with CFR 0.152.
 */
package qupath.ext.djl;

import ai.djl.inference.Predictor;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.translate.NoBatchifyTranslator;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorContext;
import ai.djl.util.PairList;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.ext.djl.DjlTools;
import qupath.lib.io.UriResource;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnShape;

/**
 * {@link DnnModel} implementation backed by a Deep Java Library (DJL) {@link ZooModel}.
 *
 * <p>The model is loaded from the configured URIs either in the constructor or, when lazy
 * initialization is requested, on the first call to {@link #predict(Map)}. Inputs and outputs
 * are exchanged as OpenCV {@link Mat} objects and converted to/from DJL arrays by an inner
 * translator.
 */
class DjlDnnModel
implements DnnModel,
AutoCloseable,
UriResource {
    private static final Logger logger = LoggerFactory.getLogger(DjlDnnModel.class);
    private List<URI> uris;
    private String engine;
    private String ndLayout;
    private Map<String, DnnShape> inputs;
    private Map<String, DnnShape> outputs;
    private boolean lazyInitialize;
    // The three fields below are read without holding the lock (fast path of ensureInitialized()
    // and predict()), so they are volatile to make writes made under the lock visible.
    private transient volatile boolean failed;
    private transient volatile ZooModel<Mat[], Mat[]> model;
    private transient volatile Predictor<Mat[], Mat[]> predictor;
    private static final String DEFAULT_MAT_LAYOUT = DjlDnnModel.getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);

    /**
     * Creates a new model wrapper.
     *
     * @param engine name of the engine passed on to {@code DjlTools.loadModel}
     * @param uris URIs of the model resources; copied into an internal list (must not be null)
     * @param ndLayout layout string for the arrays exchanged with the model, converted to upper
     *                 case; may be null, in which case the layout is estimated per prediction
     * @param inputs named input shapes; if null or empty they are queried from the model on initialization
     * @param outputs named output shapes; if null or empty they are queried from the model on initialization
     * @param lazyInitialize if false the model is loaded immediately, otherwise on first use
     * @throws RuntimeException if eager initialization is requested and loading the model fails
     */
    DjlDnnModel(String engine, Collection<URI> uris, String ndLayout, Map<String, DnnShape> inputs, Map<String, DnnShape> outputs, boolean lazyInitialize) {
        this.engine = engine;
        this.uris = new ArrayList<URI>(uris);
        if (ndLayout == null) {
            logger.warn("ndLayout not specified - I'll need to try to guess");
        } else {
            this.ndLayout = ndLayout.toUpperCase();
        }
        this.inputs = inputs;
        this.outputs = outputs;
        this.lazyInitialize = lazyInitialize;
        if (!lazyInitialize) {
            this.ensureInitialized();
        }
    }

    /**
     * Loads the model and creates its predictor if that has not been done yet.
     *
     * <p>Does nothing if the model is already loaded or if a previous attempt failed. The state is
     * checked once without the lock and again while synchronized on this instance. The model
     * field is assigned last, so a caller that observes a non-null model also observes the
     * predictor and the input/output descriptions.
     *
     * @throws RuntimeException wrapping the cause if loading fails; the instance is then marked
     *                          as failed and no further attempts are made
     */
    private void ensureInitialized() {
        if (this.model != null || this.failed) {
            return;
        }
        synchronized (this) {
            if (this.model != null || this.failed) {
                return;
            }
            ZooModel<Mat[], Mat[]> loadedModel = null;
            Predictor<Mat[], Mat[]> newPredictor = null;
            try {
                logger.debug("Initializing DjlDnnModel");
                loadedModel = DjlTools.loadModel(this.engine, Mat[].class, Mat[].class, new ModelMatTranslator(), this.uris.toArray(URI[]::new));
                newPredictor = loadedModel.newPredictor();
                if (this.inputs == null || this.inputs.isEmpty()) {
                    Map<String, DnnShape> described = DjlDnnModel.toShapeMap(loadedModel.describeInput());
                    this.inputs = described.isEmpty() ? Map.of("input", DnnShape.UNKNOWN_SHAPE) : described;
                }
                if (this.outputs == null || this.outputs.isEmpty()) {
                    Map<String, DnnShape> described = Map.of();
                    try {
                        described = DjlDnnModel.toShapeMap(loadedModel.describeOutput());
                    }
                    catch (Exception e) {
                        // Best effort only: a failure to describe the outputs is logged and
                        // the generic fallback below is used instead.
                        logger.debug(e.getMessage(), (Throwable)e);
                    }
                    this.outputs = described.isEmpty() ? Map.of("output", DnnShape.UNKNOWN_SHAPE) : described;
                }
                // Publish the predictor before the model: readers use a non-null model as the
                // signal that initialization is complete.
                this.predictor = newPredictor;
                this.model = loadedModel;
            }
            catch (Exception e) {
                this.failed = true;
                logger.debug("Failed to create DjlDnnModel");
                // Do not leak a partially created predictor/model.
                DjlDnnModel.closeAfterFailure(newPredictor, e);
                DjlDnnModel.closeAfterFailure(loadedModel, e);
                throw new RuntimeException(e);
            }
        }
    }

    /**
     * Converts a model input/output description into a name-to-shape map.
     *
     * @param description description returned by the model; may be null
     * @return the converted map, or an empty map if the description is null or empty
     */
    private static Map<String, DnnShape> toShapeMap(PairList<String, Shape> description) {
        if (description == null || description.isEmpty()) {
            return Map.of();
        }
        return description.stream().collect(Collectors.toMap(p -> p.getKey(), p -> DjlTools.convertShape(p.getValue())));
    }

    /**
     * Closes a resource during failure handling, attaching any close error to the primary failure.
     *
     * @param resource resource to close; may be null
     * @param primary the exception that caused the failure being handled
     */
    private static void closeAfterFailure(AutoCloseable resource, Exception primary) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        }
        catch (Exception closeError) {
            primary.addSuppressed(closeError);
        }
    }

    /**
     * Runs a prediction on the given inputs.
     *
     * <p>The values of {@code blobs} are passed to the predictor in the map's iteration order;
     * the keys are not used. Calls to the predictor are serialized by synchronizing on it.
     *
     * @param blobs input Mats
     * @return an empty map if the model produced no output, a single entry named {@code "output"}
     *         for one output, or entries {@code "output0"}, {@code "output1"}, ... in order otherwise
     * @throws IllegalStateException if no predictor is available because initialization failed
     *                               earlier or the model has been closed concurrently
     * @throws RuntimeException if initialization fails now, or wrapping a {@link TranslateException}
     *                          thrown by the predictor
     */
    @Override
    public Map<String, Mat> predict(Map<String, Mat> blobs) {
        this.ensureInitialized();
        // Work on a local snapshot so that a concurrent close() cannot null the reference
        // between the check and the call.
        Predictor<Mat[], Mat[]> activePredictor = this.predictor;
        if (activePredictor == null) {
            throw new IllegalStateException("DjlDnnModel is not available: initialization failed or the model was closed");
        }
        Mat[] inputMats = blobs.values().toArray(Mat[]::new);
        Mat[] result;
        synchronized (activePredictor) {
            try {
                result = activePredictor.predict(inputMats);
            }
            catch (TranslateException e) {
                throw new RuntimeException(e);
            }
        }
        if (result.length == 1) {
            return Map.of("output", result[0]);
        }
        if (result.length == 0) {
            return Map.of();
        }
        LinkedHashMap<String, Mat> output = new LinkedHashMap<String, Mat>();
        for (int i = 0; i < result.length; ++i) {
            output.put("output" + i, result[i]);
        }
        return output;
    }

    /**
     * Delegates to the default single-Mat prediction of {@link DnnModel}.
     *
     * @param mat input Mat
     * @return the prediction returned by the interface's default implementation
     */
    @Override
    public Mat predict(Mat mat) {
        // An interface default method must be invoked through the qualified super form.
        return DnnModel.super.predict(mat);
    }

    /**
     * Delegates to the default batch prediction of {@link DnnModel}.
     *
     * @param mats input Mats
     * @return the predictions returned by the interface's default implementation
     */
    @Override
    public List<Mat> batchPredict(List<? extends Mat> mats) {
        return DnnModel.super.batchPredict(mats);
    }

    /**
     * Closes the predictor and the model, if present, and clears both references.
     *
     * <p>Because the references are cleared and the failed flag is untouched, a later call to
     * {@link #predict(Map)} will attempt to load the model again.
     *
     * @throws Exception if closing the predictor or the model fails
     */
    @Override
    public synchronized void close() throws Exception {
        Predictor<Mat[], Mat[]> oldPredictor = this.predictor;
        ZooModel<Mat[], Mat[]> oldModel = this.model;
        if (oldPredictor == null && oldModel == null) {
            return;
        }
        this.model = null;
        this.predictor = null;
        try {
            if (oldPredictor != null) {
                oldPredictor.close();
            }
        }
        finally {
            // Close the model even if closing the predictor threw.
            if (oldModel != null) {
                oldModel.close();
            }
        }
        logger.debug("Closed DjlDnnModel");
    }

    /**
     * Builds a layout string from the given layout types.
     *
     * @param layouts layout types in axis order
     * @return the string produced by {@link LayoutType#toString(LayoutType[])}
     */
    private static String getLayout(LayoutType ... layouts) {
        return LayoutType.toString((LayoutType[])layouts);
    }

    /**
     * Guesses the layout of an input Mat when none was configured.
     *
     * <p>A Mat reporting at least one channel is treated as height/width/channel. Otherwise the
     * number of dimensions reported by the Mat's indexer decides (1, 2 or 3 dimensions).
     *
     * @param mat the input Mat
     * @return the estimated layout string
     * @throws IllegalArgumentException if the Mat has no channels and not 1-3 dimensions
     */
    private static String estimateInputLayout(Mat mat) {
        if (mat.channels() >= 1) {
            return DEFAULT_MAT_LAYOUT;
        }
        // NOTE(review): the indexer created here is never explicitly released - confirm whether it should be.
        long[] sizes = mat.createIndexer().sizes();
        switch (sizes.length) {
            case 1: {
                return DjlDnnModel.getLayout(LayoutType.HEIGHT);
            }
            case 2: {
                return DjlDnnModel.getLayout(LayoutType.HEIGHT, LayoutType.WIDTH);
            }
            case 3: {
                return DjlDnnModel.getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);
            }
        }
        throw new IllegalArgumentException("Unknown layout for input Mat " + String.valueOf(mat));
    }

    /**
     * Guesses the layout of an output array.
     *
     * <p>If the array's shape already knows its layout, that layout is used. Otherwise the number
     * of dimensions decides; for three dimensions the array is treated as channels-last when the
     * third dimension is smaller than the first, and channels-first otherwise.
     *
     * @param array the output array
     * @return the estimated layout string
     * @throws IllegalArgumentException if the layout is unknown and the array does not have 1-3 dimensions
     */
    private static String estimateOutputLayout(NDArray array) {
        Shape shape = array.getShape();
        if (shape.isLayoutKnown()) {
            return shape.toLayoutString();
        }
        switch (shape.dimension()) {
            case 1: {
                return DjlDnnModel.getLayout(LayoutType.HEIGHT);
            }
            case 2: {
                return DjlDnnModel.getLayout(LayoutType.HEIGHT, LayoutType.WIDTH);
            }
            case 3: {
                if (shape.get(2) < shape.get(0)) {
                    return DjlDnnModel.getLayout(LayoutType.HEIGHT, LayoutType.WIDTH, LayoutType.CHANNEL);
                }
                return DjlDnnModel.getLayout(LayoutType.CHANNEL, LayoutType.HEIGHT, LayoutType.WIDTH);
            }
        }
        throw new IllegalArgumentException("Unknown layout for output shape " + String.valueOf(shape));
    }

    /**
     * Returns the URIs of the model resources.
     *
     * @return a new list containing the current URIs
     */
    @Override
    public Collection<URI> getURIs() throws IOException {
        return new ArrayList<URI>(this.uris);
    }

    /**
     * Replaces stored URIs according to the given mapping.
     *
     * @param replacements map from current URI to its replacement; URIs without an entry are kept
     * @return true if the stored list changed, false if it is equal to the previous list
     */
    @Override
    public boolean updateURIs(Map<URI, URI> replacements) throws IOException {
        List<URI> newUris = this.uris.stream().map(u -> replacements.getOrDefault(u, u)).collect(Collectors.toList());
        if (Objects.equals(newUris, this.uris)) {
            return false;
        }
        this.uris = newUris;
        return true;
    }

    /**
     * Translator converting between arrays of OpenCV Mats and DJL {@link NDList}s, using the
     * enclosing model's configured layout where available.
     */
    private class ModelMatTranslator
    implements NoBatchifyTranslator<Mat[], Mat[]> {
        private ModelMatTranslator() {
        }

        /**
         * Converts the model output into Mats.
         *
         * <p>The configured layout is used when its length matches the number of dimensions of
         * the first array; otherwise a layout is estimated from the first array. The chosen
         * layout is applied to every array in the list.
         *
         * @param ctx translator context (unused)
         * @param list arrays produced by the model
         * @return one Mat per array, or an empty array if the list is empty
         */
        @Override
        public Mat[] processOutput(TranslatorContext ctx, NDList list) throws Exception {
            // Check emptiness before touching element 0.
            if (list.isEmpty()) {
                return new Mat[0];
            }
            NDArray first = list.get(0);
            String configured = DjlDnnModel.this.ndLayout;
            String layout = configured != null && configured.length() == first.getShape().dimension()
                    ? configured
                    : DjlDnnModel.estimateOutputLayout(first);
            return list.stream().map(b -> DjlTools.ndArrayToMat(b, layout)).toArray(Mat[]::new);
        }

        /**
         * Converts the input Mats into an {@link NDList}.
         *
         * <p>If no layout was configured, the layout estimated from the first Mat is reused for
         * all following Mats.
         *
         * @param ctx translator context supplying the NDManager used to create the arrays
         * @param input input Mats
         * @return list with one array per input Mat, in order
         */
        @Override
        public NDList processInput(TranslatorContext ctx, Mat ... input) throws Exception {
            NDList list = new NDList();
            String layout = DjlDnnModel.this.ndLayout;
            for (Mat mat : input) {
                if (layout == null) {
                    layout = DjlDnnModel.estimateInputLayout(mat);
                }
                list.add(DjlTools.matToNDArray(ctx.getNDManager(), mat, layout));
            }
            return list;
        }
    }
}

