/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml.objects;

import com.google.common.collect.Lists;
import java.nio.FloatBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.classifiers.object.AbstractObjectClassifier;
import qupath.lib.classifiers.object.ObjectClassifier;
import qupath.lib.common.GeneralTools;
import qupath.lib.images.ImageData;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectFilter;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.classes.PathClassTools;
import qupath.lib.objects.classes.Reclassifier;
import qupath.opencv.ml.OpenCVClassifiers;
import qupath.opencv.ml.objects.features.FeatureExtractor;

/**
 * Object classifier backed by an OpenCV statistical model.
 * <p>
 * Features are computed for each object by a {@link FeatureExtractor}, passed to the wrapped
 * {@link OpenCVClassifiers.OpenCVStatModel}, and the model outputs are mapped to {@link PathClass}
 * values by index into the list of classes supplied at construction.
 *
 * @param <T> generic type of the {@link ImageData} being classified
 */
public class OpenCVMLClassifier<T>
extends AbstractObjectClassifier<T> {
    private static final Logger logger = LoggerFactory.getLogger(OpenCVMLClassifier.class);

    /**
     * Cap on the number of feature values (objects x features) extracted per batch.
     * 0xA00000 = 10,485,760 float values, i.e. 40 MiB for a CV_32FC1 samples matrix.
     */
    private static final int MAX_FEATURE_VALUES_PER_BATCH = 0xA00000;

    /** Minimum probability for a class to be included in a multiclass result. */
    private static final double MULTICLASS_THRESHOLD = 0.5;

    private FeatureExtractor<T> featureExtractor;
    private OpenCVClassifiers.OpenCVStatModel classifier;
    private List<PathClass> pathClasses;
    private boolean requestProbabilityEstimate = false;

    /**
     * @param classifier  the trained OpenCV model
     * @param filter      filter selecting the objects this classifier applies to
     * @param extractor   feature extractor producing the model inputs
     * @param pathClasses classes corresponding to the model's output indices (copied defensively)
     */
    OpenCVMLClassifier(OpenCVClassifiers.OpenCVStatModel classifier, PathObjectFilter filter, FeatureExtractor<T> extractor, List<PathClass> pathClasses) {
        super(filter);
        this.classifier = classifier;
        this.featureExtractor = extractor;
        this.pathClasses = new ArrayList<PathClass>(pathClasses);
    }

    /**
     * Create a new object classifier wrapping an OpenCV model.
     *
     * @param <T>         generic type of the {@link ImageData}
     * @param model       the trained OpenCV model
     * @param filter      filter selecting the objects this classifier applies to
     * @param extractor   feature extractor producing the model inputs
     * @param pathClasses classes corresponding to the model's output indices
     * @return the new classifier
     */
    public static <T> ObjectClassifier<T> create(OpenCVClassifiers.OpenCVStatModel model, PathObjectFilter filter, FeatureExtractor<T> extractor, List<PathClass> pathClasses) {
        return new OpenCVMLClassifier<T>(model, filter, extractor, pathClasses);
    }

    /**
     * @return an unmodifiable view of the classes this classifier can output (never null)
     */
    public Collection<PathClass> getPathClasses() {
        return this.pathClasses == null ? Collections.emptyList() : Collections.unmodifiableList(this.pathClasses);
    }

    public int classifyObjects(ImageData<T> imageData, Collection<? extends PathObject> pathObjects, boolean resetExistingClass) {
        return OpenCVMLClassifier.classifyObjects(this.featureExtractor, this.classifier, this.pathClasses, imageData, pathObjects, resetExistingClass, this.requestProbabilityEstimate);
    }

    /**
     * Classify objects in batches using an OpenCV model.
     * <p>
     * Features are extracted batch by batch, predictions are collected, and all reclassifications
     * are applied together at the end. If the thread is interrupted, nothing is applied and the
     * interrupt status is preserved for the caller.
     *
     * @param featureExtractor           extractor for the model inputs; if null, nothing is classified
     * @param classifier                 the OpenCV model used for prediction
     * @param pathClasses                classes indexed by the model's output values/columns
     * @param imageData                  image data passed to the feature extractor
     * @param pathObjects                objects to classify
     * @param resetExistingClass         if false, new classes are merged with each object's existing class
     * @param requestProbabilityEstimate if true, a probability matrix is requested from the model
     * @return the number of objects in batches that were predicted successfully (0 if interrupted)
     */
    static <T> int classifyObjects(FeatureExtractor<T> featureExtractor, OpenCVClassifiers.OpenCVStatModel classifier, List<PathClass> pathClasses, ImageData<T> imageData, Collection<? extends PathObject> pathObjects, boolean resetExistingClass, boolean requestProbabilityEstimate) {
        if (featureExtractor == null) {
            logger.warn("No feature extractor! Cannot classify {} objects", pathObjects.size());
            return 0;
        }
        int nObjects = pathObjects.size();
        if (nObjects == 0)
            return 0;
        int nFeatures = featureExtractor.nFeatures();
        if (nFeatures <= 0) {
            // Guard against division by zero when sizing batches
            logger.warn("Feature extractor provides no features! Cannot classify {} objects", nObjects);
            return 0;
        }

        int batchSize = Math.max(1, Math.min(nObjects, MAX_FEATURE_VALUES_PER_BATCH / nFeatures));
        boolean doMulticlass = classifier.supportsMulticlass();
        List<Reclassifier> reclassifiers = new ArrayList<Reclassifier>();
        int counter = 0;
        int nComplete = 0;
        long startTime = System.currentTimeMillis();
        long lastTime = startTime;

        Mat samples = new Mat();
        Mat results = new Mat();
        Mat probabilities = requestProbabilityEstimate ? new Mat() : null;
        try {
            for (List<PathObject> batch : Lists.partition(new ArrayList<PathObject>(pathObjects), batchSize)) {
                // Check without clearing the interrupt flag so callers can still observe it
                if (Thread.currentThread().isInterrupted()) {
                    logger.warn("Classification interrupted - will not be applied");
                    return 0;
                }
                samples.create(batch.size(), nFeatures, opencv_core.CV_32FC1);
                FloatBuffer buffer = (FloatBuffer)samples.createBuffer();
                featureExtractor.extractFeatures(imageData, batch, buffer);

                nComplete += batch.size();
                long intermediateTime = System.currentTimeMillis();
                if (intermediateTime - lastTime > 1000L) {
                    long elapsed = intermediateTime - startTime;
                    logger.debug("Calculated features for {}/{} objects in {} ms ({} ms per object, {}% complete)",
                            nComplete, nObjects, elapsed,
                            GeneralTools.formatNumber((double)elapsed / nComplete, 2),
                            GeneralTools.formatNumber(nComplete * 100.0 / nObjects, 1));
                    // Throttle: log at most roughly once per second
                    lastTime = intermediateTime;
                }

                try {
                    classifier.predict(samples, results, probabilities);
                    List<Reclassifier> batchReclassifiers;
                    if (doMulticlass && probabilities != null && !probabilities.empty())
                        batchReclassifiers = createMulticlassReclassifiers(batch, probabilities, pathClasses, resetExistingClass);
                    else
                        batchReclassifiers = createSingleClassReclassifiers(batch, results, probabilities, pathClasses, resetExistingClass);
                    // Only commit (and count) batches that were fully processed
                    reclassifiers.addAll(batchReclassifiers);
                    counter += batch.size();
                } catch (Exception e) {
                    logger.warn("Error with samples: {}", samples);
                    logger.error(e.getLocalizedMessage(), e);
                }
            }
        } finally {
            // Release native memory on every exit path, including interruption
            samples.close();
            results.close();
            if (probabilities != null)
                probabilities.close();
        }

        long predictTime = System.currentTimeMillis() - startTime;
        // predictTime is in ms: ms * 1e6 = ns
        logger.info("Prediction time: {} ms for {} objects ({} ns per object)",
                predictTime, nObjects,
                GeneralTools.formatNumber(predictTime * 1_000_000.0 / nObjects, 2));
        reclassifiers.forEach(Reclassifier::apply);
        return counter;
    }

    /**
     * Build reclassifiers for one batch from a probability matrix (one row per object, one column
     * per class). Every class whose probability is at least {@link #MULTICLASS_THRESHOLD} is
     * combined via {@link PathClass#fromCollection(Collection)}.
     */
    private static List<Reclassifier> createMulticlassReclassifiers(List<PathObject> batch, Mat probabilities, List<PathClass> pathClasses, boolean resetExistingClass) {
        FloatIndexer idxProbabilities = (FloatIndexer)probabilities.createIndexer();
        try {
            List<Reclassifier> output = new ArrayList<Reclassifier>(batch.size());
            int nCols = probabilities.cols();
            List<String> classifications = new ArrayList<String>();
            long row = 0L;
            for (PathObject pathObject : batch) {
                classifications.clear();
                for (int col = 0; col < nCols; col++) {
                    double prob = idxProbabilities.get(row, (long)col);
                    // Written as !(>=) so that NaN probabilities are also skipped
                    if (!(prob >= MULTICLASS_THRESHOLD))
                        continue;
                    PathClass candidate = col < pathClasses.size() ? pathClasses.get(col) : null;
                    if (candidate != null)
                        classifications.add(candidate.getName());
                }
                PathClass pathClass = PathClass.fromCollection(classifications);
                if (PathClassTools.isIgnoredClass(pathClass))
                    pathClass = null;
                if (!resetExistingClass)
                    pathClass = PathClassTools.mergeClasses(pathObject.getPathClass(), pathClass);
                output.add(new Reclassifier(pathObject, pathClass, false));
                row++;
            }
            return output;
        } finally {
            idxProbabilities.release();
        }
    }

    /**
     * Build reclassifiers for one batch from the model's predicted class indices. If a non-empty
     * probability matrix is available, the probability of the predicted class is attached;
     * otherwise NaN is used.
     */
    private static List<Reclassifier> createSingleClassReclassifiers(List<PathObject> batch, Mat results, Mat probabilities, List<PathClass> pathClasses, boolean resetExistingClass) {
        IntIndexer idxResults = (IntIndexer)results.createIndexer();
        FloatIndexer idxProbabilities = null;
        try {
            if (probabilities != null && !probabilities.empty())
                idxProbabilities = (FloatIndexer)probabilities.createIndexer();
            List<Reclassifier> output = new ArrayList<Reclassifier>(batch.size());
            long row = 0L;
            for (PathObject pathObject : batch) {
                int prediction = idxResults.get(row);
                PathClass pathClass = pathClasses.get(prediction);
                double probability = idxProbabilities == null ? Double.NaN : (double)idxProbabilities.get(row, (long)prediction);
                if (PathClassTools.isIgnoredClass(pathClass)) {
                    pathClass = null;
                    probability = Double.NaN;
                }
                if (!resetExistingClass) {
                    // A merged class no longer corresponds to the model's probability
                    pathClass = PathClassTools.mergeClasses(pathObject.getPathClass(), pathClass);
                    probability = Double.NaN;
                }
                output.add(new Reclassifier(pathObject, pathClass, true, probability));
                row++;
            }
            return output;
        } finally {
            idxResults.release();
            if (idxProbabilities != null)
                idxProbabilities.release();
        }
    }

    @Override
    public String toString() {
        return String.format("OpenCV object classifier (%s, %d classes)", this.classifier.getName(), this.getPathClasses().size());
    }

    /**
     * Count, per feature name, how many objects are missing that feature.
     *
     * @param imageData   image data containing the objects
     * @param pathObjects objects to check; if null, the compatible objects in {@code imageData} are used
     * @return map from missing feature name to the number of objects lacking it, in first-seen order
     */
    public Map<String, Integer> getMissingFeatures(ImageData<T> imageData, Collection<? extends PathObject> pathObjects) {
        if (pathObjects == null) {
            pathObjects = this.getCompatibleObjects(imageData);
        }
        Map<String, Integer> missing = new LinkedHashMap<String, Integer>();
        for (PathObject pathObject : pathObjects) {
            for (String name : this.featureExtractor.getMissingFeatures(imageData, pathObject)) {
                missing.merge(name, 1, Integer::sum);
            }
        }
        return missing;
    }
}

