/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.dnn;

import java.awt.image.BufferedImage;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;
import java.util.stream.Collectors;
import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.javacpp.IntPointer;
import org.bytedeco.javacpp.PointerScope;
import org.bytedeco.javacpp.SizeTPointer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.global.opencv_dnn;
import org.bytedeco.opencv.opencv_core.IntIntPairVector;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.MatVector;
import org.bytedeco.opencv.opencv_core.Point2f;
import org.bytedeco.opencv.opencv_core.Point2fVector;
import org.bytedeco.opencv.opencv_core.Rect;
import org.bytedeco.opencv.opencv_core.RectVector;
import org.bytedeco.opencv.opencv_core.Scalar;
import org.bytedeco.opencv.opencv_core.Size;
import org.bytedeco.opencv.opencv_core.StringVector;
import org.bytedeco.opencv.opencv_dnn.ClassificationModel;
import org.bytedeco.opencv.opencv_dnn.DetectionModel;
import org.bytedeco.opencv.opencv_dnn.IntFloatPair;
import org.bytedeco.opencv.opencv_dnn.KeypointsModel;
import org.bytedeco.opencv.opencv_dnn.MatShapeVector;
import org.bytedeco.opencv.opencv_dnn.Net;
import org.bytedeco.opencv.opencv_dnn.SegmentationModel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.common.GeneralTools;
import qupath.lib.common.LogTools;
import qupath.lib.geom.Point2;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.measurements.MeasurementList;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectTools;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.ImagePlane;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.ROIs;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnModels;
import qupath.opencv.dnn.DnnShape;
import qupath.opencv.dnn.OpenCVDnn;
import qupath.opencv.tools.OpenCVTools;

/**
 * Static helper methods for working with the OpenCV DNN module.
 * <p>
 * This includes querying layer names and shapes of a {@link Net}, reading image patches as {@link Mat}s,
 * applying classification, segmentation, detection and keypoint models to QuPath objects and images,
 * and converting between images and 4D blobs.
 * <p>
 * Whether CUDA is available is checked once, when this class is initialized.
 */
public class DnnTools {

    private static final Logger logger = LoggerFactory.getLogger(DnnTools.class);

    /** True if a CUDA-enabled device and a CUDA backend/target pair were found during class initialization. */
    private static boolean cudaAvailable = false;

    /** True if CUDA should be used; can only be switched on when {@link #cudaAvailable} is true. */
    private static boolean useCuda = false;

    // This block must stay textually after the field initializers above: static initializers run in
    // source order, so placing it earlier would let the initializers reset the flags to false.
    static {
        int cudaDeviceCount = opencv_core.getCudaEnabledDeviceCount();
        if (cudaDeviceCount > 0) {
            try (IntIntPairVector backends = opencv_dnn.getAvailableBackends()) {
                long n = backends.size();
                for (long i = 0; i < n; i++) {
                    int backend = backends.first(i);
                    int target = backends.second(i);
                    logger.trace("Available backend {}, target {}", backend, target);
                    if (backend == opencv_dnn.DNN_BACKEND_CUDA && target == opencv_dnn.DNN_TARGET_CUDA) {
                        logger.info("CUDA detected and will be used if possible. Use DnnTools.setUseCuda(false) to turn this off.");
                        cudaAvailable = true;
                        useCuda = true;
                    }
                }
            }
            if (!cudaAvailable) {
                logger.warn("CUDA is not available - no compatible backend found with OpenCV DNN");
            }
        } else if (cudaDeviceCount < 0) {
            logger.warn("CUDA is not available - device count returns {}, which may mean a driver is missing or incompatible", cudaDeviceCount);
        } else {
            logger.debug("CUDA is not available (OpenCV not compiled with CUDA support)");
        }
    }

    /**
     * Register a new {@link DnnModel} subtype.
     *
     * @param <T> the model type
     * @param subtype the model class
     * @param name the name used to identify the subtype
     * @deprecated use {@link DnnModels#registerDnnModel(Class, String)} instead
     */
    @Deprecated
    public static <T extends DnnModel> void registerDnnModel(Class<T> subtype, String name) {
        LogTools.warnOnce(logger, "DnnTools.registerDnnModel is deprecated - use DnnModels.registerDnnModel instead");
        DnnModels.registerDnnModel(subtype, name);
    }

    /**
     * Create a builder for an {@link OpenCVDnn}; this delegates to {@link OpenCVDnn#builder(String)}.
     *
     * @param modelPath path to the model
     * @return a new builder
     */
    public static OpenCVDnn.Builder builder(String modelPath) {
        return OpenCVDnn.builder(modelPath);
    }

    /**
     * Log (at INFO level) every backend/target pair reported by OpenCV DNN.
     */
    static void logAvailableBackends() {
        try (IntIntPairVector backends = opencv_dnn.getAvailableBackends()) {
            long n = backends.size();
            for (long i = 0; i < n; i++) {
                logger.info("Available backend {}, target {}", backends.first(i), backends.second(i));
            }
        }
    }

    /**
     * @return true if a CUDA device and a CUDA backend/target were detected when this class was initialized
     */
    public static boolean isCudaAvailable() {
        return cudaAvailable;
    }

    /**
     * Request whether CUDA should be used.
     * A request to turn CUDA on is ignored (with a warning) if CUDA is not available.
     *
     * @param requestUseCuda true to request CUDA, false to turn it off
     */
    public static void setUseCuda(boolean requestUseCuda) {
        if (requestUseCuda && !cudaAvailable) {
            logger.warn("CUDA is not available - request will be ignored");
            return;
        }
        useCuda = requestUseCuda;
    }

    /**
     * @return true if CUDA should be used
     */
    public static boolean useCuda() {
        return useCuda;
    }

    /**
     * Get the names of the unconnected output layers of a network.
     *
     * @param net the network
     * @return the output layer names, in the order reported by OpenCV
     */
    public static List<String> getOutputLayerNames(Net net) {
        try (StringVector names = net.getUnconnectedOutLayersNames()) {
            return parseStrings(names);
        }
    }

    /**
     * Get the output layers of a network, mapped to their output shapes.
     *
     * @param net the network
     * @param inputShape the network input shape(s); if none are given, every layer maps to
     *                   {@link DnnShape#UNKNOWN_SHAPE}
     * @return an insertion-ordered map from output layer name to its (first) output shape
     */
    public static Map<String, DnnShape> getOutputLayers(Net net, DnnShape... inputShape) {
        List<String> names = getOutputLayerNames(net);
        Map<String, DnnShape> output = new LinkedHashMap<>();
        if (inputShape.length == 0) {
            for (String name : names) {
                output.put(name, DnnShape.UNKNOWN_SHAPE);
            }
            return output;
        }
        // The scope releases the temporary native shape vectors and pointers, even on failure
        try (PointerScope scope = new PointerScope()) {
            IntPointer[] pointers = Arrays.stream(inputShape)
                    .map(DnnTools::toIntPointer)
                    .toArray(IntPointer[]::new);
            MatShapeVector inputShapes = new MatShapeVector(pointers);
            MatShapeVector inLayerShapes = new MatShapeVector();
            MatShapeVector outLayerShapes = new MatShapeVector();
            for (String name : names) {
                int id = net.getLayerId(name);
                net.getLayerShapes(inputShapes, id, inLayerShapes, outLayerShapes);
                List<DnnShape> shapes = parseShape(outLayerShapes);
                if (shapes.isEmpty()) {
                    logger.warn("No output shape found for layer {}", name);
                    output.put(name, DnnShape.UNKNOWN_SHAPE);
                    continue;
                }
                if (shapes.size() > 1) {
                    logger.warn("Multiple output shapes for layer {}, will use the first only", name);
                }
                output.put(name, shapes.get(0));
            }
        }
        return output;
    }

    /** Copy a {@link DnnShape} into a new native int pointer (values are narrowed from long to int). */
    private static IntPointer toIntPointer(DnnShape shape) {
        return new IntPointer(Arrays.stream(shape.getShape()).mapToInt(l -> (int)l).toArray());
    }

    /**
     * Parse the layers of a network for an input of the given size.
     *
     * @param net the network
     * @param width input width
     * @param height input height
     * @param channels input channel count
     * @param batchSize input batch size
     * @return one {@link DNNLayer} per layer reported by the network
     */
    public static List<DNNLayer> parseLayers(Net net, int width, int height, int channels, int batchSize) {
        try (MatShapeVector netInputShape = getShapeVector(width, height, channels, batchSize)) {
            return parseLayers(net, netInputShape);
        }
    }

    /** Create a single-element shape vector in (batch, channels, height, width) order. */
    private static MatShapeVector getShapeVector(int width, int height, int channels, int batchSize) {
        int[] shapeInput = new int[]{batchSize, channels, height, width};
        return new MatShapeVector(new IntPointer(shapeInput));
    }

    /** Parse all layers of the network for the given input shape; temporary native objects are released. */
    private static List<DNNLayer> parseLayers(Net net, MatShapeVector netInputShape) {
        List<DNNLayer> list = new ArrayList<>();
        try (PointerScope scope = new PointerScope()) {
            StringVector names = net.getLayerNames();
            MatShapeVector inputShape = new MatShapeVector();
            MatShapeVector outputShape = new MatShapeVector();
            for (BytePointer nameBytes : names.get()) {
                String name = nameBytes.getString();
                int id = net.getLayerId(name);
                net.getLayerShapes(netInputShape, id, inputShape, outputShape);
                list.add(new DNNLayer(name, id, parseShape(inputShape), parseShape(outputShape)));
            }
        }
        return list;
    }

    /**
     * Convert a native string vector into a Java list.
     *
     * @param vector the strings
     * @return a new list containing each string, in order
     */
    public static List<String> parseStrings(StringVector vector) {
        int n = (int)vector.size();
        List<String> list = new ArrayList<>(n);
        for (int i = 0; i < n; i++) {
            list.add(vector.get(i).getString());
        }
        return list;
    }

    /**
     * Convert a native vector of shapes into {@link DnnShape} objects.
     *
     * @param vector the shapes
     * @return one {@link DnnShape} per element of the vector
     */
    public static List<DnnShape> parseShape(MatShapeVector vector) {
        List<DnnShape> shapes = new ArrayList<>();
        for (IntPointer pointer : vector.get()) {
            long[] shape = new long[(int)pointer.limit()];
            for (int i = 0; i < shape.length; i++) {
                shape[i] = pointer.get(i);
            }
            shapes.add(DnnShape.of(shape));
        }
        return shapes;
    }

    /**
     * Create a human-readable summary of a network: layer types, layers with shapes, FLOPS and memory use,
     * for a single input of the given size.
     *
     * @param net the network
     * @param width input width
     * @param height input height
     * @param nChannels input channel count
     * @return a multi-line summary string
     * @throws IOException declared for API compatibility; not thrown by the current implementation
     */
    public static String summarize(Net net, int width, int height, int nChannels) throws IOException {
        StringBuilder sb = new StringBuilder();
        // All native objects created here are released when the scope closes
        try (PointerScope scope = new PointerScope()) {
            MatShapeVector netInputShape = getShapeVector(width, height, nChannels, 1);
            StringVector types = new StringVector();
            net.getLayerTypes(types);
            sb.append("Layer types:");
            for (String type : parseStrings(types)) {
                sb.append("\n\t").append(type);
            }
            sb.append("\nLayers:");
            for (DNNLayer layer : parseLayers(net, netInputShape)) {
                sb.append("\n\t").append(layer.toString());
            }
            long flops = net.getFLOPS(netInputShape);
            sb.append("\nFLOPS: ").append(flops);
            SizeTPointer weights = new SizeTPointer(1L);
            SizeTPointer blobs = new SizeTPointer(1L);
            net.getMemoryConsumption(netInputShape, weights, blobs);
            sb.append("\nMemory (weights): ").append(weights.get());
            sb.append("\nMemory (blobs): ").append(blobs.get());
        }
        return sb.toString();
    }

    /** Read a region from the server and convert it to a {@link Mat}; the caller owns the result. */
    private static Mat readMat(ImageServer<BufferedImage> server, RegionRequest request) throws IOException {
        BufferedImage img = server.readRegion(request);
        return OpenCVTools.imageToMat(img);
    }

    /**
     * Read a patch centered on a ROI, using constant border type 0 for any padding.
     *
     * @see #readPatch(ImageServer, ROI, double, int, int, int)
     */
    public static Mat readPatch(ImageServer<BufferedImage> server, ROI roi, double downsample, int width, int height) throws IOException {
        return readPatch(server, roi, downsample, width, height, 0);
    }

    /**
     * Read an image patch for a ROI.
     * <p>
     * If both {@code width} and {@code height} are negative, the region covering the full ROI is read.
     * Otherwise a patch of exactly {@code width x height} pixels (at the requested downsample) centered on
     * the ROI centroid is returned; regions falling outside the image are padded with
     * {@code opencv_core.copyMakeBorder}.
     *
     * @param server the image server
     * @param roi the ROI
     * @param downsample downsample factor
     * @param width output width in pixels, or &lt; 0 to use the full ROI
     * @param height output height in pixels, or &lt; 0 to use the full ROI
     * @param borderPadding OpenCV border type passed to {@code copyMakeBorder} when padding is needed
     * @return a continuous {@link Mat}; the caller is responsible for closing it
     * @throws IOException if the region cannot be read
     * @throws IllegalArgumentException if exactly one of width/height is negative, or either is zero
     */
    public static Mat readPatch(ImageServer<BufferedImage> server, ROI roi, double downsample, int width, int height, int borderPadding) throws IOException {
        Mat input;
        if (width < 0 && height < 0) {
            RegionRequest request = RegionRequest.createInstance(server.getPath(), downsample, roi);
            input = readMat(server, request);
        } else {
            if (width <= 0 || height <= 0) {
                throw new IllegalArgumentException("Width and height must both be > 0, or < 0 if the full ROI is used");
            }
            // Patch size in full-resolution pixel coordinates
            double scaledWidth = width * downsample;
            double scaledHeight = height * downsample;
            // Unclipped patch bounds, centered on the ROI centroid
            int xi = (int)Math.round(roi.getCentroidX() - scaledWidth / 2.0);
            int yi = (int)Math.round(roi.getCentroidY() - scaledHeight / 2.0);
            int xi2 = (int)Math.round(xi + scaledWidth);
            int yi2 = (int)Math.round(yi + scaledHeight);
            // Clip so that only pixels inside the image are requested
            int x = GeneralTools.clipValue(xi, 0, server.getWidth());
            int x2 = GeneralTools.clipValue(xi2, 0, server.getWidth());
            int y = GeneralTools.clipValue(yi, 0, server.getHeight());
            int y2 = GeneralTools.clipValue(yi2, 0, server.getHeight());
            RegionRequest request = RegionRequest.createInstance(server.getPath(), downsample, x, y, x2 - x, y2 - y, roi.getZ(), roi.getT());
            input = readMat(server, request);
            int matWidth = input.cols();
            int matHeight = input.rows();
            // Crop if rounding produced a larger region than requested; the temporary header is closed,
            // while the data remains referenced by 'input'
            if (matWidth > width) {
                try (Mat cropped = input.colRange(0, width)) {
                    input.put(cropped);
                }
                matWidth = width;
            }
            if (matHeight > height) {
                try (Mat cropped = input.rowRange(0, height)) {
                    input.put(cropped);
                }
                matHeight = height;
            }
            // Pad if the region was clipped at the image boundary (or rounding made it smaller),
            // splitting the padding according to which side(s) were clipped
            if (height > matHeight || width > matWidth) {
                double xProp = calculateFirstPadProportion(xi, xi2, 0.0, server.getWidth());
                double yProp = calculateFirstPadProportion(yi, yi2, 0.0, server.getHeight());
                int padX = (int)Math.round((width - matWidth) * xProp);
                int padY = (int)Math.round((height - matHeight) * yProp);
                opencv_core.copyMakeBorder(input, input, padY, height - matHeight - padY, padX, width - matWidth - padX, borderPadding);
            }
        }
        OpenCVTools.ensureContinuous(input, true);
        return input;
    }

    /**
     * Compute the proportion of padding to place before (left/top) the data.
     * Returns 0 if the start was not clipped, 1 if only the start was clipped, and otherwise the fraction
     * of the total overhang that lies before {@code minVal}.
     */
    private static double calculateFirstPadProportion(double v1, double v2, double minVal, double maxVal) {
        if (v1 >= minVal) {
            return 0.0;
        }
        if (v2 <= maxVal) {
            return 1.0;
        }
        double d1 = minVal - v1;
        double d2 = v2 - maxVal;
        return d1 / (d1 + d2);
    }

    /**
     * Classify an object using the image region covering its ROI.
     *
     * @param model the classification model
     * @param pathObject the object to classify
     * @param server the image server
     * @param downsample downsample at which to read the region
     * @param classifier maps the predicted class index to a {@link PathClass}; if null, the class is set to null
     * @param predictionMeasurement if not null, the name of a measurement that stores the prediction value
     * @return true if the object's classification changed
     * @throws IOException if the region cannot be read
     */
    public static boolean classify(ClassificationModel model, PathObject pathObject, ImageServer<BufferedImage> server, double downsample, IntFunction<PathClass> classifier, String predictionMeasurement) throws IOException {
        ROI roi = pathObject.getROI();
        if (roi == null) {
            logger.warn("Cannot classify an object without a ROI!");
            return false;
        }
        // Read only the region covering this object's ROI (not the whole image)
        RegionRequest request = RegionRequest.createInstance(server.getPath(), downsample, roi);
        try (Mat input = readMat(server, request)) {
            return classify(model, pathObject, input, classifier, predictionMeasurement);
        }
    }

    /**
     * Classify an object using a fixed-size patch centered on its ROI (preferring the nucleus ROI where
     * available).
     *
     * @param model the classification model
     * @param pathObject the object to classify
     * @param server the image server
     * @param downsample downsample at which to read the patch
     * @param width patch width in pixels
     * @param height patch height in pixels
     * @param classifier maps the predicted class index to a {@link PathClass}; if null, the class is set to null
     * @param predictionMeasurement if not null, the name of a measurement that stores the prediction value
     * @return true if the object's classification changed
     * @throws IOException if the patch cannot be read
     * @throws IllegalArgumentException if the width/height are invalid (see {@link #readPatch})
     */
    public static boolean classify(ClassificationModel model, PathObject pathObject, ImageServer<BufferedImage> server, double downsample, int width, int height, IntFunction<PathClass> classifier, String predictionMeasurement) throws IOException, IllegalArgumentException {
        ROI roi = pathObject.getROI();
        if (roi == null) {
            logger.warn("Cannot classify an object without a ROI!");
            return false;
        }
        boolean preferNucleus = true;
        ROI patchROI = PathObjectTools.getROI(pathObject, preferNucleus);
        try (Mat input = readPatch(server, patchROI, downsample, width, height)) {
            return classify(model, pathObject, input, classifier, predictionMeasurement);
        }
    }

    /**
     * Classify an object using a pre-extracted input image.
     * Calls on the same model instance are serialized with {@code synchronized}.
     *
     * @param model the classification model
     * @param pathObject the object whose classification (and optionally measurements) will be set
     * @param input the input image; not closed by this method
     * @param classifier maps the predicted class index to a {@link PathClass}; if null, the class is set to null
     * @param predictionMeasurement if not null, the name of a measurement that stores the prediction value
     * @return true if the object's classification changed (compared by identity)
     */
    public static boolean classify(ClassificationModel model, PathObject pathObject, Mat input, IntFunction<PathClass> classifier, String predictionMeasurement) {
        IntFloatPair result;
        synchronized (model) {
            result = model.classify(input);
        }
        try {
            int ind = result.first();
            PathClass pathClass = classifier == null ? null : classifier.apply(ind);
            boolean changed = pathClass != pathObject.getPathClass();
            pathObject.setPathClass(pathClass);
            if (predictionMeasurement != null) {
                try (MeasurementList ml = pathObject.getMeasurementList()) {
                    ml.put(predictionMeasurement, result.second());
                }
            }
            return changed;
        } finally {
            // close() releases the native pair; a separate deallocate() is unnecessary
            result.close();
        }
    }

    /**
     * Apply a segmentation model to a region of an image.
     *
     * @param model the segmentation model
     * @param server the image server
     * @param request the region to read
     * @return the segmentation output; the caller is responsible for closing it
     * @throws IOException if the region cannot be read
     */
    public static Mat segment(SegmentationModel model, ImageServer<BufferedImage> server, RegionRequest request) throws IOException {
        // Created outside the scope so that it survives the scope's cleanup
        Mat output = new Mat();
        try (PointerScope scope = new PointerScope()) {
            Mat input = readMat(server, request);
            segment(model, input, output);
        }
        return output;
    }

    /**
     * Apply a segmentation model to an image.
     * Calls on the same model instance are serialized with {@code synchronized}.
     *
     * @param model the segmentation model
     * @param input the input image
     * @param output optional output image; if null, a new {@link Mat} is created
     * @return the output image (either {@code output} or a newly-created {@link Mat})
     */
    public static Mat segment(SegmentationModel model, Mat input, Mat output) {
        // Must be created outside the scope below, otherwise it would be deallocated before returning
        Mat result = output == null ? new Mat() : output;
        try (PointerScope scope = new PointerScope()) {
            synchronized (model) {
                model.segment(input, result);
            }
        }
        return result;
    }

    /**
     * Apply a detection model to a region of an image.
     *
     * @see #detect(DetectionModel, Mat, RegionRequest, IntFunction, Function)
     */
    public static List<PathObject> detect(DetectionModel model, ImageServer<BufferedImage> server, RegionRequest request, IntFunction<PathClass> classifier, Function<ROI, PathObject> creator) throws IOException {
        try (PointerScope scope = new PointerScope()) {
            Mat mat = readMat(server, request);
            return detect(model, mat, request, classifier, creator);
        }
    }

    /**
     * Apply a detection model to an image, creating one object per detected rectangle.
     * Calls on the same model instance are serialized with {@code synchronized}.
     *
     * @param model the detection model
     * @param mat the input image
     * @param request the region the image corresponds to, used to map rectangles back to full-resolution
     *                image coordinates; if null, a downsample of 1, origin (0,0) and the default plane are used
     * @param classifier maps each detection's class id to a {@link PathClass}; if null, the class is set to null
     * @param creator creates an object from each rectangle ROI
     * @return the created objects, each with a "Probability" measurement
     */
    public static List<PathObject> detect(DetectionModel model, Mat mat, RegionRequest request, IntFunction<PathClass> classifier, Function<ROI, PathObject> creator) {
        double downsample = request == null ? 1.0 : request.getDownsample();
        ImagePlane plane = request == null ? ImagePlane.getDefaultPlane() : request.getImagePlane();
        double xOrigin = request == null ? 0.0 : request.getX();
        double yOrigin = request == null ? 0.0 : request.getY();
        List<PathObject> pathObjects = new ArrayList<>();
        try (PointerScope scope = new PointerScope()) {
            IntPointer ids = new IntPointer();
            FloatPointer preds = new FloatPointer();
            RectVector rects = new RectVector();
            synchronized (model) {
                model.detect(mat, ids, preds, rects);
            }
            long n = rects.size();
            for (long i = 0; i < n; i++) {
                Rect rect = rects.get(i);
                ROI roi = ROIs.createRectangleROI(
                        xOrigin + rect.x() * downsample,
                        yOrigin + rect.y() * downsample,
                        rect.width() * downsample,
                        rect.height() * downsample,
                        plane);
                PathClass pathClass = classifier == null ? null : classifier.apply(ids.get(i));
                double pred = preds.get(i);
                PathObject pathObject = creator.apply(roi);
                pathObject.setPathClass(pathClass);
                try (MeasurementList ml = pathObject.getMeasurementList()) {
                    ml.put("Probability", pred);
                }
                pathObjects.add(pathObject);
            }
        }
        return pathObjects;
    }

    /**
     * Detect keypoints and wrap them in a single object created from a points ROI.
     *
     * @see #detectKeypointsROI(KeypointsModel, Mat, RegionRequest, ROI, double)
     */
    static PathObject detectKeypoints(KeypointsModel model, Mat mat, RegionRequest request, ROI mask, double threshold, Function<ROI, PathObject> creator) {
        ROI roi = detectKeypointsROI(model, mat, request, mask, threshold);
        return creator.apply(roi);
    }

    /**
     * Detect keypoints and return them as a points ROI.
     * Calls on the same model instance are serialized with {@code synchronized}.
     *
     * @param model the keypoints model
     * @param mat the input image
     * @param request the region the image corresponds to; if null, a downsample of 1, origin (0,0) and the
     *                default plane are used (consistent with {@code detect})
     * @param mask optional ROI; if not null, only points it contains are kept
     * @param threshold confidence threshold passed to the model
     * @return a points ROI in full-resolution image coordinates
     */
    static ROI detectKeypointsROI(KeypointsModel model, Mat mat, RegionRequest request, ROI mask, double threshold) {
        double downsample = request == null ? 1.0 : request.getDownsample();
        double xOrigin = request == null ? 0.0 : request.getX();
        double yOrigin = request == null ? 0.0 : request.getY();
        ImagePlane plane = request == null ? ImagePlane.getDefaultPlane() : request.getImagePlane();
        float thresh = (float)threshold;
        List<Point2> points = new ArrayList<>();
        // Releases the native point vector returned by the model
        try (PointerScope scope = new PointerScope()) {
            Point2fVector output;
            synchronized (model) {
                output = model.estimate(mat, thresh);
            }
            for (Point2f p : output.get()) {
                double x = xOrigin + p.x() * downsample;
                double y = yOrigin + p.y() * downsample;
                if (mask != null && !mask.contains(x, y)) {
                    continue;
                }
                points.add(new Point2(x, y));
            }
        }
        return ROIs.createPointsROI(points, plane);
    }

    /**
     * Create a blob from one or more images.
     *
     * @param mats the images; with more than one image, all must have 1, 3 or 4 channels
     * @return the blob, or an empty {@link Mat} if no images are provided
     * @throws UnsupportedOperationException if multiple images with another channel count are provided
     */
    public static Mat blobFromImages(Mat... mats) {
        if (mats.length == 0) {
            return new Mat();
        }
        if (mats.length == 1) {
            return blobFromImage(mats[0]);
        }
        if (!isStandardChannelCount(mats[0].channels())) {
            throw new UnsupportedOperationException("Converting multiple images to a blob is only supported for 1, 3, or 4 channels, sorry!");
        }
        // The returned blob is independent of the vector, so the vector can be released here
        try (MatVector matvec = new MatVector(mats)) {
            return opencv_dnn.blobFromImages(matvec);
        }
    }

    /**
     * Create a blob from a single image, using OpenCV defaults for images with 1, 3 or 4 channels and a
     * simple reshape otherwise.
     *
     * @param mat the image
     * @return the blob
     */
    public static Mat blobFromImage(Mat mat) {
        if (isStandardChannelCount(mat.channels())) {
            return opencv_dnn.blobFromImage(mat);
        }
        return blobFromImages(Collections.singletonList(mat), 1.0, new Size(), new Scalar(), false, false);
    }

    /**
     * Create a blob from a single image with preprocessing options.
     *
     * @see #blobFromImages(Collection, double, Size, Scalar, boolean, boolean)
     */
    public static Mat blobFromImages(Mat mat, double scaleFactor, Size size, Scalar mean, boolean swapRB, boolean crop) {
        if (isStandardChannelCount(mat.channels())) {
            return opencv_dnn.blobFromImage(mat, scaleFactor, size, mean, swapRB, crop, opencv_core.CV_32F);
        }
        return blobFromImages(Collections.singletonList(mat), scaleFactor, size, mean, swapRB, crop);
    }

    /**
     * Create a 32-bit float blob from a collection of images.
     * <p>
     * Images with 1, 3 or 4 channels are passed to OpenCV with the given preprocessing options.
     * For other channel counts the pixels are copied into an (N, C, H, W) blob directly and the
     * preprocessing options are ignored (a warning is logged).
     *
     * @param mats the images; the channel count of the first image determines the path taken
     * @param scaleFactor scale factor
     * @param size output spatial size
     * @param mean mean to subtract
     * @param swapRB whether to swap the first and last channels
     * @param crop whether to crop after resizing
     * @return the blob, or an empty {@link Mat} if {@code mats} is empty
     * @throws IllegalArgumentException on the reshape path, if the images differ in size or channel count
     */
    public static Mat blobFromImages(Collection<Mat> mats, double scaleFactor, Size size, Scalar mean, boolean swapRB, boolean crop) {
        if (mats.isEmpty()) {
            return new Mat();
        }
        Mat first = mats.iterator().next();
        int nChannels = first.channels();
        if (!isStandardChannelCount(nChannels)) {
            return reshapeToBlob(mats, nChannels);
        }
        if (mats.size() == 1) {
            return opencv_dnn.blobFromImage(first, scaleFactor, size, mean, swapRB, crop, opencv_core.CV_32F);
        }
        try (MatVector matvec = new MatVector(mats.toArray(Mat[]::new))) {
            return opencv_dnn.blobFromImages(matvec, scaleFactor, size, mean, swapRB, crop, opencv_core.CV_32F);
        }
    }

    /** Returns true for channel counts that this class passes directly to OpenCV's blob functions. */
    private static boolean isStandardChannelCount(int nChannels) {
        return nChannels == 1 || nChannels == 3 || nChannels == 4;
    }

    /**
     * Copy 2D multichannel images into a new (N, C, H, W) CV_32F blob, without any preprocessing.
     * All images must match the first one's rows, columns and channel count.
     */
    private static Mat reshapeToBlob(Collection<Mat> mats, int nChannels) {
        logger.warn("Attempting to reshape an image with " + nChannels + " channels - this may not work! Only 1, 3 and 4 full supported, preprocessing will be ignored.");
        Mat first = mats.iterator().next();
        int nRows = first.size(0);
        int nCols = first.size(1);
        // Validate before allocating, so mismatched inputs fail clearly rather than with an indexing error
        for (Mat mat : mats) {
            if (mat.size(0) != nRows || mat.size(1) != nCols || mat.channels() != nChannels) {
                throw new IllegalArgumentException(String.format(
                        "All images must have the same size and channel count (expected %d x %d x %d, found %d x %d x %d)",
                        nRows, nCols, nChannels, mat.size(0), mat.size(1), mat.channels()));
            }
        }
        int[] shape = new int[]{mats.size(), nChannels, nRows, nCols};
        Mat blob = new Mat(shape, opencv_core.CV_32F);
        Indexer idxBlob = blob.createIndexer();
        try {
            long[] indsBlob = new long[4];
            // A 2D multichannel Mat is indexed as (row, column, channel)
            long[] indsMat = new long[3];
            int n = 0;
            for (Mat mat : mats) {
                indsBlob[0] = n++;
                Indexer idxMat = mat.createIndexer();
                try {
                    for (int r = 0; r < nRows; r++) {
                        indsMat[0] = r;
                        indsBlob[2] = r;
                        for (int c = 0; c < nCols; c++) {
                            indsMat[1] = c;
                            indsBlob[3] = c;
                            for (int channel = 0; channel < nChannels; channel++) {
                                indsMat[2] = channel;
                                indsBlob[1] = channel;
                                idxBlob.putDouble(indsBlob, idxMat.getDouble(indsMat));
                            }
                        }
                    }
                } finally {
                    idxMat.close();
                }
            }
        } finally {
            idxBlob.close();
        }
        return blob;
    }

    /**
     * Split a blob into its individual images.
     *
     * @param blob the blob
     * @return a list of images; each {@link Mat} owns its own (reference-counted) header, so it remains valid
     *         after the temporary vector used internally is released
     */
    public static List<Mat> imagesFromBlob(Mat blob) {
        try (MatVector vec = new MatVector()) {
            opencv_dnn.imagesFromBlob(blob, vec);
            long n = vec.size();
            List<Mat> images = new ArrayList<>((int)n);
            for (long i = 0; i < n; i++) {
                // vec.get(i) refers to storage inside 'vec'; assigning into a new Mat shares the pixel data
                // via OpenCV reference counting, so it stays valid after 'vec' is closed
                Mat image = new Mat();
                image.put(vec.get(i));
                images.add(image);
            }
            return images;
        }
    }

    /**
     * Simple description of a network layer: its name, id, and input/output shapes.
     */
    public static class DNNLayer {
        private final String name;
        private final int id;
        private final List<DnnShape> inputShapes;
        private final List<DnnShape> outputShapes;

        private DNNLayer(String name, int id, List<DnnShape> inputShapes, List<DnnShape> outputShapes) {
            this.name = name;
            this.id = id;
            this.inputShapes = inputShapes;
            this.outputShapes = outputShapes;
        }

        /**
         * @return the layer name
         */
        public String getName() {
            return this.name;
        }

        /**
         * @return the layer id
         */
        public int getID() {
            return this.id;
        }

        /**
         * @return an unmodifiable view of the layer's input shapes
         */
        public List<DnnShape> getInputShapes() {
            return Collections.unmodifiableList(this.inputShapes);
        }

        /**
         * @return an unmodifiable view of the layer's output shapes
         */
        public List<DnnShape> getOutputShapes() {
            return Collections.unmodifiableList(this.outputShapes);
        }

        @Override
        public String toString() {
            return String.format("%s \t%s -> %s", this.name, this.inputShapes, this.outputShapes);
        }
    }
}

