/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.dnn;

import java.io.File;
import java.io.IOException;
import java.net.URI;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_dnn;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.MatVector;
import org.bytedeco.opencv.opencv_core.Scalar;
import org.bytedeco.opencv.opencv_core.Size;
import org.bytedeco.opencv.opencv_core.StringVector;
import org.bytedeco.opencv.opencv_dnn.ClassificationModel;
import org.bytedeco.opencv.opencv_dnn.DetectionModel;
import org.bytedeco.opencv.opencv_dnn.KeypointsModel;
import org.bytedeco.opencv.opencv_dnn.Model;
import org.bytedeco.opencv.opencv_dnn.Net;
import org.bytedeco.opencv.opencv_dnn.SegmentationModel;
import org.bytedeco.opencv.opencv_dnn.TextDetectionModel_DB;
import org.bytedeco.opencv.opencv_dnn.TextDetectionModel_EAST;
import org.bytedeco.opencv.opencv_dnn.TextRecognitionModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.io.UriResource;
import qupath.opencv.dnn.AbstractDnnModel;
import qupath.opencv.dnn.BlobFunction;
import qupath.opencv.dnn.DefaultBlobFunction;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.dnn.DnnTools;
import qupath.opencv.dnn.PredictionFunction;
import qupath.opencv.ops.ImageOp;
import qupath.opencv.ops.ImageOps;

public class OpenCVDnn
extends AbstractDnnModel<Mat>
implements UriResource {
    /** Class-wide logger; also used by the nested {@code OpenCVNetFunction} and {@code Builder} classes. */
    private static final Logger logger = LoggerFactory.getLogger(OpenCVDnn.class);
    /** Optional display name; the {@code Builder} defaults it to the model file name. */
    private String name;
    /** Kind of OpenCV {@code Model} wrapper created by {@code buildModel()}. */
    private ModelType modelType = ModelType.DEFAULT;
    /** Location of the model file (required to read the net). */
    private URI pathModel;
    /** Location of an optional configuration file; may be {@code null}. */
    private URI pathConfig;
    /** Framework name passed through to {@code opencv_dnn.readNet}; may be {@code null}. */
    private String framework;
    // Backend/target ids handed to Net.setPreferableBackend/setPreferableTarget.
    // initializeNet treats backend 5 as the CUDA backend and targets 6/7 as CUDA targets;
    // 0 is used for both when DnnTools.useCuda() is false.
    private int backend = DnnTools.useCuda() ? 5 : 0;
    private int target = DnnTools.useCuda() ? 6 : 0;
    /** Forwarded to {@code Model.setInputCrop} and to the default blob function. */
    private boolean crop = false;
    /** Forwarded to {@code Model.setInputSwapRB}. */
    private boolean swapRB = false;
    /** Optional input size; {@code null} means no size is set. */
    private Size size;
    /** Optional mean that is subtracted during preprocessing; {@code null} means none. */
    private Scalar mean;
    /** Multiplicative input scale. No initializer here; the {@code Builder} supplies 1.0 by default. */
    private double scale;
    /** Optional declared inputs; when {@code null} a single input of unknown shape is reported. */
    private Map<String, DnnShape> inputs;
    /** Optional declared outputs; when {@code null} or empty the net's unconnected output layers are used. */
    private Map<String, DnnShape> outputs;
    /** Set once a net has been built; URIs can no longer be updated afterwards. */
    private transient boolean constructed = false;
    /** Lazily created prediction function (see {@code getPredictionFunction()}). */
    private transient PredictionFunction<Mat> predFun;
    /** Lazily created blob function (see {@code getBlobFunction()}). */
    private transient BlobFunction<Mat> blobFun;

    /**
     * Reads a new {@link Net} from the model file (and the optional config file) and
     * applies the configured backend and target preferences.
     * <p>
     * As a side effect this marks the instance as constructed, after which
     * {@code updateURIs} refuses further changes.
     *
     * @return the newly read net
     */
    public Net buildNet() {
        String fileModel = null;
        if (pathModel != null) {
            fileModel = Paths.get(pathModel).toFile().getAbsolutePath();
        }
        String fileConfig = null;
        if (pathConfig != null) {
            fileConfig = Paths.get(pathConfig).toFile().getAbsolutePath();
        }
        Net created = opencv_dnn.readNet(fileModel, fileConfig, framework);
        initializeNet(created);
        constructed = true;
        return created;
    }

    /**
     * Builds a new net and wraps it in an OpenCV {@link Model} of the requested type,
     * then applies the input preprocessing settings of this instance.
     *
     * @param type the kind of model to create; {@code null} is treated as {@link ModelType#DEFAULT}
     * @param <T> the model class the caller expects; the cast is unchecked
     * @return the initialized model
     */
    @SuppressWarnings("unchecked")
    public <T extends Model> T buildModel(ModelType type) {
        ModelType requested = type == null ? ModelType.DEFAULT : type;
        Net net = buildNet();
        Model wrapped = buildModel(requested, net);
        initializeModel(wrapped);
        return (T) wrapped;
    }

    /**
     * Wraps a net in the OpenCV {@link Model} subclass that corresponds to a model type.
     * <p>
     * Dispatch is on the enum constants themselves rather than on their ordinal values,
     * so the mapping stays correct if {@link ModelType} is ever reordered or extended.
     *
     * @param type the requested model type; must not be {@code null}
     * @param net the net to wrap
     * @return a type-specific model, or a plain {@link Model} for {@link ModelType#DEFAULT}
     *         and any type without a dedicated wrapper
     * @throws NullPointerException if {@code type} is {@code null}
     */
    private static Model buildModel(ModelType type, Net net) {
        switch (type) {
            case CLASSIFICATION:
                return new ClassificationModel(net);
            case DETECTION:
                return new DetectionModel(net);
            case SEGMENTATION:
                return new SegmentationModel(net);
            case KEYPOINTS:
                return new KeypointsModel(net);
            case TEXT_RECOGNITION:
                return new TextRecognitionModel(net);
            case TEXT_DETECTION_DB:
                return new TextDetectionModel_DB(net);
            case TEXT_DETECTION_EAST:
                return new TextDetectionModel_EAST(net);
            case DEFAULT:
            default:
                return new Model(net);
        }
    }

    /**
     * Builds a model using the model type configured for this instance.
     *
     * @param <T> the model class the caller expects; the cast is unchecked
     * @return the initialized model
     */
    public <T extends Model> T buildModel() {
        ModelType configuredType = modelType;
        return buildModel(configuredType);
    }

    /**
     * Copies the input preprocessing settings of this instance onto an OpenCV model.
     * <p>
     * Crop and swapRB are always set. Mean and size are only set when non-null, and
     * the scale is only set when it is a finite number.
     *
     * @param model the model to configure
     */
    public void initializeModel(Model model) {
        model.setInputCrop(crop);
        model.setInputSwapRB(swapRB);

        Scalar inputMean = mean;
        if (inputMean != null) {
            model.setInputMean(inputMean);
        }

        double inputScale = scale;
        if (Double.isFinite(inputScale)) {
            model.setInputScale(Scalar.all(inputScale));
        }

        Size inputSize = size;
        if (inputSize != null) {
            model.setInputSize(inputSize);
        }
    }

    /**
     * Creates the prediction function for this model.
     *
     * @return a new {@code OpenCVNetFunction} bound to this instance
     */
    private PredictionFunction<Mat> createPredictionFunction() {
        OpenCVNetFunction function = new OpenCVNetFunction();
        return function;
    }

    /**
     * Applies the configured backend and target preferences to a net.
     * <ul>
     *   <li>Targets 6 and 7 (CUDA): applied only if at least one CUDA device is reported
     *       and the backend is 5 (CUDA); otherwise a warning is logged and nothing is set.</li>
     *   <li>Targets 1 and 2 (OpenCL): applied only if OpenCL is available and the backend
     *       is not the CUDA backend; otherwise a warning is logged and nothing is set.</li>
     *   <li>Any other target: backend and target are both applied as given.</li>
     * </ul>
     *
     * @param net the net to configure
     */
    private void initializeNet(Net net) {
        switch (target) {
            case 6:
            case 7: {
                int count = opencv_core.getCudaEnabledDeviceCount();
                if (count < 0) {
                    logger.warn("Unable to set CUDA target - driver may be missing or unavailable (device count = {})", count);
                    break;
                }
                if (count == 0) {
                    logger.warn("Unable to set CUDA target - OpenCV not compiled with CUDA support (device count = {})", count);
                    break;
                }
                if (backend != 5) {
                    logger.warn("Must specify CUDA backend to use CUDA target - request will be ignored");
                    break;
                }
                logger.debug("Setting CUDA backend and target ({}:{})", backend, target);
                net.setPreferableBackend(backend);
                net.setPreferableTarget(target);
                break;
            }
            case 1:
            case 2: {
                if (!opencv_core.haveOpenCL()) {
                    logger.warn("Cannot set OpenCL target - OpenCL is unavailable on this platform");
                    break;
                }
                if (backend == 5) {
                    logger.warn("Cannot set CUDA backend and OpenCL target");
                    break;
                }
                logger.debug("Setting OpenCL backend and target ({}:{})", backend, target);
                net.setPreferableBackend(backend);
                net.setPreferableTarget(target);
                break;
            }
            default: {
                // The target id must go to setPreferableTarget; passing it to
                // setPreferableBackend would overwrite the backend and leave the target unset.
                net.setPreferableBackend(backend);
                net.setPreferableTarget(target);
            }
        }
    }

    /**
     * Returns the name given to this model.
     *
     * @return the model name, possibly {@code null}
     */
    public String getName() {
        return name;
    }

    /**
     * Returns the input scale factor configured for this model.
     *
     * @return the scale as a boxed {@link Double} (never {@code null})
     */
    public Double getScale() {
        return Double.valueOf(scale);
    }

    /**
     * Returns the model type used by the no-argument {@code buildModel()}.
     *
     * @return the configured model type
     */
    public ModelType getModelType() {
        return modelType;
    }

    /**
     * Returns the configured mean, if any.
     *
     * @return a new {@link Scalar} constructed from the stored mean, or {@code null} if no mean is set
     */
    public Scalar getMean() {
        if (mean == null) {
            return null;
        }
        return new Scalar(mean);
    }

    /**
     * Returns the location of the model file.
     *
     * @return the model URI, possibly {@code null}
     */
    public URI getModelUri() {
        return pathModel;
    }

    /**
     * Returns the location of the optional configuration file.
     *
     * @return the config URI, or {@code null} if none was set
     */
    public URI getConfigUri() {
        return pathConfig;
    }

    /**
     * Returns the framework name that is passed to OpenCV when reading the net.
     *
     * @return the framework name, possibly {@code null}
     */
    public String getFramework() {
        return framework;
    }

    /**
     * Starts building an {@link OpenCVDnn} for a model file.
     *
     * @param pathModel file system path of the model file
     * @return a new builder
     */
    public static Builder builder(String pathModel) {
        Builder created = new Builder(pathModel);
        return created;
    }

    /**
     * Lists the URIs this model depends on.
     *
     * @return a new mutable list holding the model URI followed by the config URI,
     *         skipping whichever of them is {@code null}
     * @throws IOException declared by the method signature; not thrown by this implementation
     */
    public Collection<URI> getURIs() throws IOException {
        return Arrays.asList(pathModel, pathConfig)
                .stream()
                .filter(Objects::nonNull)
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * Replaces the model and/or config URI according to a replacement map.
     * <p>
     * Entries with a {@code null} key, or whose key equals its value, are ignored.
     * Entries are applied in the map's iteration order.
     *
     * @param replacements map from current URI to replacement URI
     * @return {@code true} if the model or config URI was replaced
     * @throws UnsupportedOperationException if a net has already been built from this instance
     * @throws IOException declared by the method signature; not thrown by this implementation
     */
    public boolean updateURIs(Map<URI, URI> replacements) throws IOException {
        if (constructed) {
            throw new UnsupportedOperationException("URIs cannot be updated after construction!");
        }
        boolean changed = false;
        for (Map.Entry<URI, URI> replacement : replacements.entrySet()) {
            URI from = replacement.getKey();
            URI to = replacement.getValue();
            if (from == null || from.equals(to)) {
                continue;
            }
            if (Objects.equals(pathModel, from)) {
                pathModel = to;
                changed = true;
            }
            if (Objects.equals(pathConfig, from)) {
                pathConfig = to;
                changed = true;
            }
        }
        return changed;
    }

    /**
     * Creates the blob function that preprocesses input images.
     * <p>
     * Preprocessing subtracts the mean (when set) and then multiplies by the scale
     * (when it differs from 1.0). With no steps the preprocessing op is {@code null},
     * with one step it is used directly, and with two they are chained sequentially.
     *
     * @return a new {@code DefaultBlobFunction} using the configured size and crop flag
     */
    private BlobFunction<Mat> createBlobFunction() {
        ArrayList<ImageOp> steps = new ArrayList<>();
        if (mean != null) {
            double[] meanValues = new double[4];
            mean.get(meanValues);
            steps.add(ImageOps.Core.subtract(meanValues));
        }
        if (scale != 1.0) {
            steps.add(ImageOps.Core.multiply(scale));
        }

        ImageOp preprocess;
        if (steps.isEmpty()) {
            preprocess = null;
        } else if (steps.size() == 1) {
            preprocess = steps.get(0);
        } else {
            preprocess = ImageOps.Core.sequential(steps);
        }
        return new DefaultBlobFunction(preprocess, size, crop);
    }

    /**
     * Returns the blob function for this model, creating it on first use.
     * <p>
     * The check, creation and read all happen while holding this instance's monitor.
     * The backing field is not volatile, so an unsynchronized first check would not
     * guarantee that other threads see a fully constructed function.
     *
     * @return the shared blob function
     */
    @Override
    public BlobFunction<Mat> getBlobFunction() {
        synchronized (this) {
            if (blobFun == null) {
                blobFun = createBlobFunction();
            }
            return blobFun;
        }
    }

    /**
     * Returns the blob function for a named input.
     * <p>
     * The name is ignored; the same blob function is used for every input.
     *
     * @param name the input name (unused)
     * @return the shared blob function
     */
    @Override
    public BlobFunction<Mat> getBlobFunction(String name) {
        BlobFunction<Mat> shared = getBlobFunction();
        return shared;
    }

    /**
     * Returns the prediction function for this model, creating it on first use.
     * <p>
     * The check, creation and read all happen while holding this instance's monitor.
     * The backing field is not volatile, so an unsynchronized first check would not
     * guarantee that other threads see a fully constructed function.
     *
     * @return the shared prediction function
     */
    @Override
    public PredictionFunction<Mat> getPredictionFunction() {
        synchronized (this) {
            if (predFun == null) {
                predFun = createPredictionFunction();
            }
            return predFun;
        }
    }

    /**
     * Closes the prediction function if one has been created and it is {@link AutoCloseable}.
     *
     * @throws Exception if closing the prediction function fails
     */
    @Override
    public void close() throws Exception {
        PredictionFunction<Mat> function = predFun;
        if (!(function instanceof AutoCloseable)) {
            return;
        }
        ((AutoCloseable) function).close();
    }

    /**
     * Kinds of OpenCV model wrapper that {@code buildModel} can create around a net.
     * <p>
     * NOTE(review): keep the declaration order stable - any code that switches on
     * {@code ordinal()} depends on it.
     */
    public static enum ModelType {
        /** Plain {@link Model} with no task-specific wrapper. */
        DEFAULT,
        /** Built as a {@link DetectionModel}. */
        DETECTION,
        /** Built as a {@link SegmentationModel}. */
        SEGMENTATION,
        /** Built as a {@link ClassificationModel}. */
        CLASSIFICATION,
        /** Built as a {@link KeypointsModel}. */
        KEYPOINTS,
        /** Built as a {@link TextRecognitionModel}. */
        TEXT_RECOGNITION,
        /** Built as a {@link TextDetectionModel_DB}. */
        TEXT_DETECTION_DB,
        /** Built as a {@link TextDetectionModel_EAST}. */
        TEXT_DETECTION_EAST;

    }

    /**
     * {@link PredictionFunction} that runs inference through an OpenCV {@link Net} built
     * from the enclosing {@link OpenCVDnn}.
     * <p>
     * The net is built in the constructor and rebuilt on demand whenever it is found to be
     * {@code null} or a null pointer (presumably the state after {@link #close()} - confirm
     * against JavaCPP). Inference calls synchronize on the net itself.
     */
    class OpenCVNetFunction
    implements PredictionFunction<Mat>,
    AutoCloseable {
        private transient Net net;
        /** Output layer names, in the order used to request and return outputs. */
        private transient List<String> outputLayerNames;
        /** Native copy of {@link #outputLayerNames} passed to multi-output forward calls. */
        private transient StringVector outputLayerNamesVector;

        OpenCVNetFunction() {
            this.ensureInitialized();
        }

        /**
         * Builds the net and the output layer name list if the net is missing or null.
         * <p>
         * Output names come from the enclosing model's declared outputs when present and
         * non-empty, otherwise from the net's unconnected output layers.
         */
        private void ensureInitialized() {
            if (this.net == null || this.net.isNull()) {
                synchronized (this) {
                    if (this.net == null || this.net.isNull()) {
                        this.net = OpenCVDnn.this.buildNet();
                        this.net.retainReference();
                        this.outputLayerNames = new ArrayList<String>();
                        if (OpenCVDnn.this.outputs != null && !OpenCVDnn.this.outputs.isEmpty()) {
                            this.outputLayerNames.addAll(OpenCVDnn.this.outputs.keySet());
                        } else {
                            StringVector names = this.net.getUnconnectedOutLayersNames();
                            for (BytePointer bp : names.get()) {
                                this.outputLayerNames.add(bp.getString());
                            }
                        }
                        this.outputLayerNamesVector = new StringVector((String[])this.outputLayerNames.toArray(String[]::new));
                        this.outputLayerNamesVector.retainReference();
                    }
                }
            }
        }

        /**
         * Returns the net, (re)building it first if necessary.
         */
        private Net getNet() {
            this.ensureInitialized();
            return this.net;
        }

        /**
         * Runs the net on a single input and returns a clone of the first output.
         * <p>
         * A warning is logged if the net has more than one output layer, because only the
         * first one is returned.
         *
         * @param input the input blob
         * @return a clone of the first output layer's result
         */
        @Override
        public Mat predict(Mat input) {
            Net net = this.getNet();
            synchronized (net) {
                net.setInput(input);
                if (this.outputLayerNames.size() > 1) {
                    logger.warn("Single output requested for multi-output model - only the first will be returned");
                }
                return net.forward(this.outputLayerNames.get(0)).clone();
            }
        }

        /**
         * Runs the net on named inputs and returns one result per output layer.
         * <p>
         * With exactly one input and one output this delegates to {@link #predict(Mat)}.
         * Otherwise a single input is set without a name, multiple inputs are set by name,
         * and results are returned in output-layer order.
         *
         * @param input map from input name to input blob
         * @return map from output layer name to result
         */
        @Override
        public Map<String, Mat> predict(Map<String, Mat> input) {
            Net net = this.getNet();
            if (input.size() == 1 && this.outputLayerNames.size() == 1) {
                Mat output = this.predict(input.values().iterator().next());
                return Map.of(this.outputLayerNames.get(0), output);
            }
            if (this.outputLayerNamesVector == null || this.outputLayerNamesVector.isNull()) {
                this.outputLayerNamesVector = new StringVector((String[])this.outputLayerNames.toArray(String[]::new));
            }
            // Result Mats are allocated before the PointerScope is opened and filled via put(...)
            // inside it - presumably so that they outlive the scope; confirm against JavaCPP.
            LinkedHashMap<String, Mat> result = new LinkedHashMap<String, Mat>();
            for (String name : this.outputLayerNames) {
                result.put(name, new Mat());
            }
            try (PointerScope scope = new PointerScope();){
                MatVector output = new MatVector();
                synchronized (net) {
                    boolean singleInput = input.size() == 1;
                    for (Map.Entry<String, Mat> entry : input.entrySet()) {
                        if (singleInput) {
                            net.setInput(entry.getValue());
                            continue;
                        }
                        net.setInput(entry.getValue(), entry.getKey(), 1.0, null);
                    }
                    net.forward(output, this.outputLayerNamesVector);
                    Mat[] mats = output.get();
                    int i = 0;
                    for (String name : this.outputLayerNames) {
                        result.get(name).put(mats[i].clone());
                        ++i;
                    }
                }
            }
            return result;
        }

        /**
         * Returns the declared inputs of the enclosing model, or a single input named
         * {@code "input"} with unknown shape when none were declared.
         */
        @Override
        public Map<String, DnnShape> getInputs() {
            if (OpenCVDnn.this.inputs != null) {
                return OpenCVDnn.this.inputs;
            }
            return Collections.singletonMap("input", DnnShape.UNKNOWN_SHAPE);
        }

        /**
         * Returns the declared outputs of the enclosing model, or queries the net for its
         * output layers when none were declared.
         *
         * @param inputShapes input shapes forwarded to {@code DnnTools.getOutputLayers}
         */
        @Override
        public Map<String, DnnShape> getOutputs(DnnShape ... inputShapes) {
            if (OpenCVDnn.this.outputs != null) {
                return OpenCVDnn.this.outputs;
            }
            // Go through getNet() like the predict methods do, so a missing or null
            // net is rebuilt instead of being handed to native code.
            return DnnTools.getOutputLayers(this.getNet(), inputShapes);
        }

        /**
         * Closes and deallocates the net and the native output-name vector, if present.
         *
         * @throws Exception if closing a native resource fails
         */
        @Override
        public synchronized void close() throws Exception {
            if (this.net != null) {
                logger.debug("Closing {}", (Object)this.net);
                this.net.close();
                this.net.deallocate();
            }
            if (this.outputLayerNamesVector != null) {
                this.outputLayerNamesVector.close();
                this.outputLayerNamesVector.deallocate();
            }
        }
    }

    /**
     * Fluent builder for {@link OpenCVDnn} instances.
     * <p>
     * Obtain one via {@link OpenCVDnn#builder(String)}, chain configuration calls and
     * finish with {@link #build()}.
     */
    public static class Builder {
        private String name;
        private ModelType modelType = ModelType.DEFAULT;
        private URI pathModel;
        private URI pathConfig;
        private String framework;
        private Size size = null;
        private Scalar mean = null;
        private double scale = 1.0;
        private boolean swapRB = false;
        // Backend 5 / target 6 are the ids set by cuda(); 0 is used when CUDA is not requested.
        private int backend = DnnTools.useCuda() ? 5 : 0;
        private int target = DnnTools.useCuda() ? 6 : 0;
        private Map<String, DnnShape> outputs;

        private Builder(String pathModel) {
            this(new File(pathModel).toURI());
        }

        private Builder(URI pathModel) {
            this.pathModel = pathModel;
            try {
                this.name = Paths.get(pathModel).getFileName().toString();
            }
            catch (Exception e) {
                // Best effort only: the name stays null if it cannot be derived from the URI.
                logger.debug("Unable to set default Net name from {} ({})", (Object)pathModel, (Object)e.getLocalizedMessage());
            }
        }

        /**
         * Sets the framework name passed to OpenCV when reading the net.
         *
         * @param name the framework name
         * @return this builder
         */
        public Builder framework(String name) {
            this.framework = name;
            return this;
        }

        /**
         * Sets the configuration file from a file system path.
         *
         * @param pathConfig path of the configuration file
         * @return this builder
         */
        public Builder config(String pathConfig) {
            return this.config(new File(pathConfig).toURI());
        }

        /**
         * Sets the configuration file from a URI.
         *
         * @param pathConfig URI of the configuration file
         * @return this builder
         */
        public Builder config(URI pathConfig) {
            this.pathConfig = pathConfig;
            return this;
        }

        /**
         * Sets the model name, replacing the default derived from the model file name.
         *
         * @param name the model name
         * @return this builder
         */
        public Builder name(String name) {
            this.name = name;
            return this;
        }

        /**
         * Requests backend 3 with the OpenCL target (1).
         * Backend 3 is presumably OpenCV's own DNN backend - confirm against the OpenCV constants.
         *
         * @return this builder
         */
        public Builder opencl() {
            this.backend = 3;
            this.target = 1;
            return this;
        }

        /**
         * Requests backend 3 with target 2, which {@code initializeNet} also handles as an
         * OpenCL target (presumably the 16-bit variant, going by the method name).
         *
         * @return this builder
         */
        public Builder opencl16() {
            this.backend = 3;
            this.target = 2;
            return this;
        }

        /**
         * Requests the CUDA backend (5) with the CUDA target (6).
         *
         * @return this builder
         */
        public Builder cuda() {
            this.backend = 5;
            this.target = 6;
            return this;
        }

        /**
         * Requests backend 3 with target 0 (presumably the CPU target, going by the method name).
         *
         * @return this builder
         */
        public Builder cpu() {
            this.backend = 3;
            this.target = 0;
            return this;
        }

        /**
         * Requests the CUDA backend (5) with target 7, which {@code initializeNet} also
         * handles as a CUDA target (presumably the 16-bit variant, going by the method name).
         *
         * @return this builder
         */
        public Builder cuda16() {
            this.backend = 5;
            this.target = 7;
            return this;
        }

        /**
         * Sets the raw target id passed to {@code Net.setPreferableTarget}.
         *
         * @param target the target id
         * @return this builder
         */
        public Builder target(int target) {
            this.target = target;
            return this;
        }

        /**
         * Sets the raw backend id passed to {@code Net.setPreferableBackend}.
         *
         * @param backend the backend id
         * @return this builder
         */
        public Builder backend(int backend) {
            this.backend = backend;
            return this;
        }

        /**
         * Sets the mean used for input preprocessing. The value is copied in {@link #build()}.
         *
         * @param mean the mean, or {@code null} for none
         * @return this builder
         */
        public Builder mean(Scalar mean) {
            this.mean = mean;
            return this;
        }

        /**
         * Sets the input scale factor (default 1.0).
         *
         * @param scale the scale factor
         * @return this builder
         */
        public Builder scale(double scale) {
            this.scale = scale;
            return this;
        }

        /**
         * Sets the input size.
         *
         * @param width input width
         * @param height input height
         * @return this builder
         */
        public Builder size(int width, int height) {
            this.size = new Size(width, height);
            return this;
        }

        /**
         * Sets the input size from an existing {@link Size}; only its width and height are read.
         *
         * @param size the input size; must not be {@code null}
         * @return this builder
         */
        public Builder size(Size size) {
            return this.size(size.width(), size.height());
        }

        /**
         * Sets the model type; {@code null} is replaced by {@link ModelType#DEFAULT} in {@link #build()}.
         *
         * @param type the model type
         * @return this builder
         */
        public Builder modelType(ModelType type) {
            this.modelType = type;
            return this;
        }

        /**
         * Shortcut for {@code modelType(ModelType.CLASSIFICATION)}.
         *
         * @return this builder
         */
        public Builder classification() {
            return this.modelType(ModelType.CLASSIFICATION);
        }

        /**
         * Shortcut for {@code modelType(ModelType.SEGMENTATION)}.
         *
         * @return this builder
         */
        public Builder segmentation() {
            return this.modelType(ModelType.SEGMENTATION);
        }

        /**
         * Shortcut for {@code modelType(ModelType.DETECTION)}.
         *
         * @return this builder
         */
        public Builder detection() {
            return this.modelType(ModelType.DETECTION);
        }

        /**
         * Declares the output layers by name, each with an unknown shape.
         * <p>
         * The given order is preserved. It determines which output a single-output
         * prediction returns and the order in which multiple outputs are requested.
         *
         * @param layers output layer names, in the desired order
         * @return this builder
         * @throws IllegalStateException if a layer name is given more than once
         */
        public Builder outputs(String ... layers) {
            this.outputs = Arrays.stream(layers).collect(Collectors.toMap(
                    n -> n,
                    n -> DnnShape.UNKNOWN_SHAPE,
                    (first, second) -> {
                        throw new IllegalStateException("Duplicate output layer name");
                    },
                    LinkedHashMap::new));
            return this;
        }

        /**
         * Declares the output layers and their shapes. The map is copied, preserving its
         * iteration order, and stored as an unmodifiable view.
         *
         * @param outputs map from output layer name to shape
         * @return this builder
         */
        public Builder outputs(Map<String, DnnShape> outputs) {
            this.outputs = Collections.unmodifiableMap(new LinkedHashMap<String, DnnShape>(outputs));
            return this;
        }

        /**
         * Creates a new {@link OpenCVDnn} from the current settings.
         * <p>
         * Size and mean are copied; the outputs map is shared with this builder.
         *
         * @return the configured model
         */
        public OpenCVDnn build() {
            OpenCVDnn dnn = new OpenCVDnn();
            dnn.pathModel = this.pathModel;
            dnn.pathConfig = this.pathConfig;
            dnn.framework = this.framework;
            dnn.name = this.name;
            dnn.modelType = this.modelType == null ? ModelType.DEFAULT : this.modelType;
            dnn.size = this.size == null ? null : new Size(this.size);
            dnn.backend = this.backend;
            dnn.target = this.target;
            dnn.mean = this.mean == null ? null : new Scalar(this.mean);
            dnn.scale = this.scale;
            dnn.swapRB = this.swapRB;
            dnn.outputs = this.outputs;
            return dnn;
        }
    }
}

