/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.dnn;

import com.google.common.collect.Lists;
import java.awt.image.BufferedImage;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.Future;
import java.util.function.IntFunction;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.classifiers.object.AbstractObjectClassifier;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.io.UriResource;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectFilter;
import qupath.lib.objects.PathObjectTools;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnTools;

public class DnnObjectClassifier
extends AbstractObjectClassifier<BufferedImage>
implements UriResource {
    private static final Logger logger = LoggerFactory.getLogger(DnnObjectClassifier.class);
    private DnnModel model;
    private List<PathClass> pathClasses;
    private double requestedPixelSize = 1.0;
    private int width;
    private int height;
    boolean preferNucleus = true;
    private int batchSize = 4;

    public Collection<PathClass> getPathClasses() {
        return Collections.unmodifiableList(this.pathClasses);
    }

    public DnnObjectClassifier(PathObjectFilter filter, DnnModel model, List<PathClass> pathClasses, int width, int height, double requestedPixelSize) {
        super(filter);
        this.model = model;
        this.pathClasses = new ArrayList<PathClass>(pathClasses);
        this.width = width;
        this.height = height;
        this.requestedPixelSize = requestedPixelSize;
    }

    public int classifyObjects(ImageData<BufferedImage> imageData, Collection<? extends PathObject> pathObjects, boolean resetExistingClass) {
        ImageServer server = imageData.getServer();
        double ds = Double.isFinite(this.requestedPixelSize) && this.requestedPixelSize > 0.0 ? this.requestedPixelSize / server.getPixelCalibration().getAveragedPixelSize().doubleValue() : 1.0;
        double downsample = ds;
        ForkJoinPool pool = ForkJoinPool.commonPool();
        ArrayList<Future> futures = new ArrayList<Future>();
        for (List list : Lists.partition((List)Lists.newArrayList(pathObjects), (int)Math.max(1, this.batchSize))) {
            futures.add(pool.submit(() -> this.tryToClassify(list, (ImageServer<BufferedImage>)server, downsample, (int i) -> this.pathClasses.get(i))));
        }
        int reclassified = 0;
        for (ForkJoinTask forkJoinTask : futures) {
            try {
                reclassified += ((Integer)forkJoinTask.get()).intValue();
            }
            catch (Exception e) {
                throw new RuntimeException(e);
            }
        }
        return reclassified;
    }

    protected boolean tryToClassify(PathObject pathObject, ImageServer<BufferedImage> server, double downsample, IntFunction<PathClass> classifier) {
        return this.tryToClassify(Collections.singletonList(pathObject), server, downsample, classifier) != 0;
    }

    protected int tryToClassify(List<? extends PathObject> pathObjects, ImageServer<BufferedImage> server, double downsample, IntFunction<PathClass> classifier) {
        int count = 0;
        try {
            ArrayList<Mat> inputImages = new ArrayList<Mat>();
            int n = pathObjects.size();
            int i = 0;
            for (PathObject pathObject : pathObjects) {
                ROI roi = PathObjectTools.getROI((PathObject)pathObject, (boolean)this.preferNucleus);
                if (roi == null) {
                    logger.warn("Cannot classify an object without a ROI!");
                    return 0;
                }
                Mat input = DnnTools.readPatch(server, roi, downsample, this.width, this.height);
                inputImages.add(input);
                ++i;
            }
            List<Mat> output = this.model.batchPredict(inputImages);
            assert (output.size() == n);
            for (i = 0; i < n; ++i) {
                PathClass pathClassNew;
                int dim;
                Indexer indexer = output.get(i).createIndexer();
                long[] sizes = indexer.sizes();
                int nClasses = this.pathClasses.size();
                for (dim = 0; dim < sizes.length && sizes[dim] != (long)nClasses; ++dim) {
                }
                if (dim == sizes.length) {
                    if (nClasses == 1) {
                        logger.error("Unable to find classification axis in output! Sizes {} for single class", (Object)Arrays.toString(sizes));
                    } else {
                        logger.error("Unable to find classification axis in output! Sizes {} for {} classes", (Object)Arrays.toString(sizes), (Object)nClasses);
                    }
                    throw new IllegalArgumentException("Unable to find classification axis in prediction output!");
                }
                PathObject pathObject = pathObjects.get(i);
                long[] inds = sizes;
                Arrays.fill(inds, 0L);
                double maxPred = Double.NEGATIVE_INFINITY;
                int maxPredInd = -1;
                for (int d = 0; d < nClasses; ++d) {
                    inds[dim] = d;
                    double pred = indexer.getDouble(inds);
                    if (!(pred > maxPred)) continue;
                    maxPred = pred;
                    maxPredInd = d;
                }
                PathClass pathClassOld = pathObject.getPathClass();
                if (pathClassOld == (pathClassNew = this.pathClasses.get(maxPredInd))) continue;
                pathObject.setPathClass(pathClassNew);
                ++count;
            }
        }
        catch (IOException e) {
            logger.warn("Error classifying object: " + e.getLocalizedMessage(), (Throwable)e);
        }
        return count;
    }

    public Map<String, Integer> getMissingFeatures(ImageData<BufferedImage> imageData, Collection<? extends PathObject> pathObjects) {
        return Collections.emptyMap();
    }

    public Collection<URI> getURIs() throws IOException {
        if (this.model instanceof UriResource) {
            return ((UriResource)this.model).getURIs();
        }
        return Collections.emptyList();
    }

    public boolean updateURIs(Map<URI, URI> replacements) throws IOException {
        if (this.model instanceof UriResource) {
            return ((UriResource)this.model).updateURIs(replacements);
        }
        return false;
    }
}

