/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml.pixel;

import java.awt.image.BufferedImage;
import java.awt.image.ColorModel;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.locationtech.jts.geom.Geometry;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.images.ContourTracing;
import qupath.lib.classifiers.pixel.PixelClassificationImageServer;
import qupath.lib.classifiers.pixel.PixelClassifier;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.objects.DefaultPathObjectComparator;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectTools;
import qupath.lib.objects.PathObjects;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.classes.PathClassTools;
import qupath.lib.objects.classes.Reclassifier;
import qupath.lib.objects.hierarchy.PathObjectHierarchy;
import qupath.lib.regions.ImagePlane;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.GeometryTools;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.ml.pixel.PixelClassificationMeasurementManager;
import qupath.opencv.ml.pixel.PixelClassifiers;

public class PixelClassifierTools {
    private static final Logger logger = LoggerFactory.getLogger(PixelClassifierTools.class);

    /**
     * Create detection objects from the output of a pixel classification server,
     * using the objects currently selected in the hierarchy as parents.
     * <p>
     * When nothing is selected, the root object of the hierarchy is used as the only parent.
     *
     * @param hierarchy hierarchy supplying the selection (or root object) and passed on for object creation
     * @param classifierServer server providing the pixel classification
     * @param minArea minimum area, forwarded unchanged to {@code createObjectsFromPredictions}
     * @param minHoleArea minimum hole area, forwarded unchanged to {@code createObjectsFromPredictions}
     * @param options optional flags controlling object creation
     * @return the result of {@code createObjectsFromPredictions}
     * @throws IOException if thrown while creating the objects
     */
    public static boolean createDetectionsFromPixelClassifier(PathObjectHierarchy hierarchy, ImageServer<BufferedImage> classifierServer, double minArea, double minHoleArea, CreateObjectOptions ... options) throws IOException {
        Set<PathObject> selection = hierarchy.getSelectionModel().getSelectedObjects();
        // Fall back to the root object when there is no selection
        Set<PathObject> parents = selection.isEmpty()
                ? Collections.singleton(hierarchy.getRootObject())
                : selection;
        Function<ROI, PathObject> detectionCreator = roi -> PathObjects.createDetectionObject(roi);
        return PixelClassifierTools.createObjectsFromPredictions(
                classifierServer, hierarchy, parents, detectionCreator, minArea, minHoleArea, options);
    }

    /**
     * Create detection objects by applying a pixel classifier to an image.
     * <p>
     * A classification server is created for the image and classifier, then the call is
     * delegated to the hierarchy-based overload using the hierarchy of {@code imageData}.
     *
     * @param imageData image data providing both the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply
     * @param minArea minimum area, forwarded unchanged
     * @param minHoleArea minimum hole area, forwarded unchanged
     * @param options optional flags controlling object creation
     * @return the result of the hierarchy-based overload
     * @throws IOException if thrown while creating the objects
     */
    public static boolean createDetectionsFromPixelClassifier(ImageData<BufferedImage> imageData, PixelClassifier classifier, double minArea, double minHoleArea, CreateObjectOptions ... options) throws IOException {
        var hierarchy = imageData.getHierarchy();
        var classifierServer = PixelClassifierTools.createPixelClassificationServer(imageData, classifier);
        return PixelClassifierTools.createDetectionsFromPixelClassifier(
                hierarchy, classifierServer, minArea, minHoleArea, options);
    }

    /**
     * Create objects from the predictions of a classification server and add them as
     * children of the given parent objects.
     * <p>
     * Only parents that are the root object or that have an area ROI are processed. Parents are
     * handled with null ROIs first, then in descending order of ROI area.
     * <ul>
     *   <li>With {@link CreateObjectOptions#DELETE_EXISTING}, a parent that is a descendant of an
     *       already-processed parent is skipped (it would be removed anyway), and existing child
     *       objects of processed parents are cleared before the new ones are added.</li>
     *   <li>If a parent is a detection and any created child is not a detection, the children for
     *       that parent are discarded and a warning is logged once.</li>
     *   <li>With {@link CreateObjectOptions#SELECT_NEW}, the new objects that are present in the
     *       hierarchy afterwards become the selection; otherwise skipped parents are deselected.</li>
     * </ul>
     * Every non-root parent that receives children is locked.
     *
     * @param server server providing the predictions
     * @param hierarchy hierarchy to notify and whose selection model is updated
     * @param selectedObjects candidate parent objects
     * @param creator function that converts a ROI to a new object
     * @param minArea minimum area, forwarded to {@code createObjectsFromPixelClassifier}
     * @param minHoleArea minimum hole area, forwarded to {@code createObjectsFromPixelClassifier}
     * @param options optional flags controlling object creation
     * @return {@code false} if {@code selectedObjects} is empty or the current thread is found to be
     *         interrupted after processing a parent (in which case nothing is added); {@code true} otherwise
     * @throws IOException if thrown while creating the objects
     */
    public static boolean createObjectsFromPredictions(ImageServer<BufferedImage> server, PathObjectHierarchy hierarchy, Collection<PathObject> selectedObjects, Function<ROI, ? extends PathObject> creator, double minArea, double minHoleArea, CreateObjectOptions ... options) throws IOException {
        if (selectedObjects.isEmpty()) {
            return false;
        }
        Set<CreateObjectOptions> optionSet = new HashSet<>(Arrays.asList(options));
        boolean doSplit = optionSet.contains(CreateObjectOptions.SPLIT);
        boolean includeIgnored = optionSet.contains(CreateObjectOptions.INCLUDE_IGNORED);
        boolean clearExisting = optionSet.contains(CreateObjectOptions.DELETE_EXISTING);
        // null means "do not change the selection to the new objects"
        Set<PathObject> toSelect = optionSet.contains(CreateObjectOptions.SELECT_NEW) ? new HashSet<>() : null;

        // Insertion-ordered so parents are updated in the order they were processed
        Map<PathObject, Collection<PathObject>> map = new LinkedHashMap<>();
        boolean firstWarning = true;

        // Root (null ROI) first, then largest areas first, so that ancestors are
        // handled before any descendants that clearing them would remove
        List<PathObject> parentObjects = new ArrayList<>(selectedObjects).stream()
                .filter(p -> p.isRootObject() || p.hasROI() && p.getROI().isArea())
                .sorted(Comparator.comparing(PathObject::getROI, Comparator.nullsFirst(Comparator.comparingDouble(ROI::getArea).reversed())))
                .toList();

        List<PathObject> completed = new ArrayList<>();
        List<PathObject> toDeselect = new ArrayList<>();
        for (PathObject pathObject : parentObjects) {
            if (clearExisting) {
                boolean willRemove = false;
                for (PathObject possibleAncestor : completed) {
                    if (PathObjectTools.isAncestor(pathObject, possibleAncestor)) {
                        willRemove = true;
                        break;
                    }
                }
                if (willRemove) {
                    toDeselect.add(pathObject);
                    logger.warn("Skipping {} during object creation (is descendant of an object that is already being processed)", (Object)pathObject);
                    continue;
                }
            }
            Map<Integer, PathClass> labels = PixelClassifierTools.parseClassificationLabels(server.getMetadata().getClassificationLabels(), includeIgnored);
            Collection<PathObject> children = PixelClassifierTools.createObjectsFromPixelClassifier(server, labels, pathObject.getROI(), creator, minArea, minHoleArea, doSplit);
            if (pathObject.isDetection() && children.stream().anyMatch(p -> !p.isDetection())) {
                if (firstWarning) {
                    logger.warn("Cannot add non-detection objects to detections! Objects will be skipped...");
                    firstWarning = false;
                }
            } else {
                if (toSelect != null) {
                    toSelect.addAll(children);
                }
                map.put(pathObject, children);
            }
            completed.add(pathObject);
            // Abort without modifying the hierarchy if interrupted
            if (Thread.currentThread().isInterrupted()) {
                return false;
            }
        }

        for (Map.Entry<PathObject, Collection<PathObject>> entry : map.entrySet()) {
            PathObject parent = entry.getKey();
            Collection<PathObject> children = entry.getValue();
            if (clearExisting && parent.hasChildObjects()) {
                parent.clearChildObjects();
            }
            parent.addChildObjects(children);
            if (!parent.isRootObject()) {
                parent.setLocked(true);
            }
        }

        // Fire the narrowest event that covers all changes
        if (map.size() == 1) {
            hierarchy.fireHierarchyChangedEvent(null, map.keySet().iterator().next());
        } else if (map.size() > 1) {
            hierarchy.fireHierarchyChangedEvent(null);
        }

        if (toSelect != null) {
            // Only select objects that are actually in the hierarchy
            Set<PathObject> stillPresent = toSelect.stream()
                    .filter(p -> PathObjectTools.hierarchyContainsObject(hierarchy, p))
                    .collect(Collectors.toSet());
            hierarchy.getSelectionModel().setSelectedObjects(stillPresent, null);
        } else if (!toDeselect.isEmpty()) {
            hierarchy.getSelectionModel().deselectObjects(toDeselect);
        }
        return true;
    }

    /**
     * Create annotation objects by applying a pixel classifier to an image.
     * <p>
     * A classification server is created for the image and classifier, then the call is
     * delegated to the hierarchy-based overload using the hierarchy of {@code imageData}.
     *
     * @param imageData image data providing both the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply
     * @param minArea minimum area, forwarded unchanged
     * @param minHoleArea minimum hole area, forwarded unchanged
     * @param options optional flags controlling object creation
     * @return the result of the hierarchy-based overload
     * @throws IOException if thrown while creating the objects
     */
    public static boolean createAnnotationsFromPixelClassifier(ImageData<BufferedImage> imageData, PixelClassifier classifier, double minArea, double minHoleArea, CreateObjectOptions ... options) throws IOException {
        var hierarchy = imageData.getHierarchy();
        var classifierServer = PixelClassifierTools.createPixelClassificationServer(imageData, classifier);
        return PixelClassifierTools.createAnnotationsFromPixelClassifier(
                hierarchy, classifierServer, minArea, minHoleArea, options);
    }

    /**
     * Create annotation objects from the output of a pixel classification server,
     * using the objects currently selected in the hierarchy as parents.
     * <p>
     * When nothing is selected, the root object of the hierarchy is used as the only parent.
     * Each annotation is locked as soon as it is created.
     *
     * @param hierarchy hierarchy supplying the selection (or root object) and passed on for object creation
     * @param classifierServer server providing the pixel classification
     * @param minArea minimum area, forwarded unchanged to {@code createObjectsFromPredictions}
     * @param minHoleArea minimum hole area, forwarded unchanged to {@code createObjectsFromPredictions}
     * @param options optional flags controlling object creation
     * @return the result of {@code createObjectsFromPredictions}
     * @throws IOException if thrown while creating the objects
     */
    public static boolean createAnnotationsFromPixelClassifier(PathObjectHierarchy hierarchy, ImageServer<BufferedImage> classifierServer, double minArea, double minHoleArea, CreateObjectOptions ... options) throws IOException {
        Set<PathObject> selection = hierarchy.getSelectionModel().getSelectedObjects();
        // Fall back to the root object when there is no selection
        Set<PathObject> parents = selection.isEmpty()
                ? Collections.singleton(hierarchy.getRootObject())
                : selection;
        Function<ROI, PathObject> lockedAnnotationCreator = roi -> {
            PathObject created = PathObjects.createAnnotationObject(roi);
            created.setLocked(true);
            return created;
        };
        return PixelClassifierTools.createObjectsFromPredictions(
                classifierServer, hierarchy, parents, lockedAnnotationCreator, minArea, minHoleArea, options);
    }

    /**
     * Copy the usable entries of a label map into a new insertion-ordered map.
     * <p>
     * Entries whose class is {@code null} or {@code PathClass.NULL_CLASS} are always dropped.
     * Entries whose class is reported as ignored by {@code PathClassTools.isIgnoredClass} are
     * dropped unless {@code includeIgnored} is {@code true}.
     *
     * @param labelsOrig original mapping from label to classification
     * @param includeIgnored whether ignored classes should be retained
     * @return a new map containing only the retained entries, in the iteration order of the input
     */
    private static Map<Integer, PathClass> parseClassificationLabels(Map<Integer, PathClass> labelsOrig, boolean includeIgnored) {
        Map<Integer, PathClass> retained = new LinkedHashMap<>();
        labelsOrig.forEach((label, candidate) -> {
            if (candidate == null || candidate == PathClass.NULL_CLASS) {
                return;
            }
            if (!includeIgnored && PathClassTools.isIgnoredClass(candidate)) {
                return;
            }
            retained.put(label, candidate);
        });
        return retained;
    }

    /**
     * Create objects by tracing the classified regions of a classification server.
     * <p>
     * For {@code MULTICLASS_PROBABILITY} output, or single-channel {@code PROBABILITY} output,
     * each label is thresholded "above" a probability threshold: half the upper bound of the
     * pixel type for integer types (with a warning logged for all but {@code UINT8}), and 0.5
     * otherwise. For any other output, thresholds are created directly from the label values.
     * <p>
     * {@code minArea} and {@code minHoleArea} are divided by the product of the calibrated pixel
     * width and height before being applied.
     *
     * @param server server providing the classification
     * @param labels mapping from label to classification; if {@code null}, the labels of the
     *               server metadata are used with ignored classes excluded
     * @param roi region to process, also used to clip the traced geometries; if {@code null},
     *            requests covering the whole server are used
     * @param creator function that converts a ROI to a new object
     * @param minArea minimum area, in calibrated units
     * @param minHoleArea minimum hole area, in calibrated units
     * @param doSplit whether multi-part geometries should give one object per part
     * @return the created objects, sorted using {@code DefaultPathObjectComparator}; an empty
     *         list if {@code roi} is non-null but not an area
     * @throws IllegalArgumentException if no classification labels are available
     * @throws IOException if thrown while tracing the geometries
     */
    public static Collection<PathObject> createObjectsFromPixelClassifier(ImageServer<BufferedImage> server, Map<Integer, PathClass> labels, ROI roi, Function<ROI, ? extends PathObject> creator, double minArea, double minHoleArea, boolean doSplit) throws IOException {
        // Resolve labels once into an effectively-final reference that lambdas can capture
        Map<Integer, PathClass> labelMap = labels != null
                ? labels
                : PixelClassifierTools.parseClassificationLabels(server.getMetadata().getClassificationLabels(), false);
        if (labelMap.isEmpty()) {
            throw new IllegalArgumentException("Cannot create objects for server - no classification labels are available!");
        }

        int nChannels = server.nChannels();
        ImageServerMetadata.ChannelType channelType = server.getMetadata().getChannelType();
        ContourTracing.ChannelThreshold[] thresholds;
        if (channelType == ImageServerMetadata.ChannelType.MULTICLASS_PROBABILITY || channelType == ImageServerMetadata.ChannelType.PROBABILITY && nChannels == 1) {
            double probabilityThreshold;
            switch (server.getPixelType()) {
                case INT16:
                case INT32:
                case INT8:
                case UINT16:
                case UINT32:
                    logger.warn("Probability threshold for int image will be set to half the maximum value for the pixel type");
                    // intentional fall through: same threshold rule as UINT8, just with a warning
                case UINT8:
                    probabilityThreshold = server.getPixelType().getUpperBound().doubleValue() / 2.0;
                    break;
                default:
                    probabilityThreshold = 0.5;
            }
            thresholds = labelMap.keySet().stream()
                    .map(channel -> ContourTracing.ChannelThreshold.createAbove((int)channel, probabilityThreshold))
                    .toArray(ContourTracing.ChannelThreshold[]::new);
        } else {
            thresholds = labelMap.keySet().stream()
                    .map(ContourTracing.ChannelThreshold::create)
                    .toArray(ContourTracing.ChannelThreshold[]::new);
        }

        if (roi != null && !roi.isArea()) {
            logger.warn("Cannot create objects for non-area ROIs");
            return Collections.emptyList();
        }

        Geometry clipArea = roi == null ? null : roi.getGeometry();
        List<RegionRequest> regionRequests;
        if (roi != null) {
            RegionRequest request = RegionRequest.createInstance(server.getPath(), server.getDownsampleForResolution(0), roi);
            regionRequests = Collections.singletonList(request);
        } else {
            regionRequests = RegionRequest.createAllRequests(server, server.getDownsampleForResolution(0));
        }

        // Convert calibrated area thresholds to pixel units
        double pixelArea = server.getPixelCalibration().getPixelWidth().doubleValue() * server.getPixelCalibration().getPixelHeight().doubleValue();
        double minAreaPixels = minArea / pixelArea;
        double minHoleAreaPixels = minHoleArea / pixelArea;

        List<PathObject> pathObjects = new ArrayList<>();
        for (RegionRequest regionRequest : regionRequests) {
            Map<Integer, Geometry> geometryMap = ContourTracing.traceGeometries(server, regionRequest, clipArea, thresholds);
            pathObjects.addAll(geometryMap.entrySet().parallelStream()
                    .flatMap(e -> PixelClassifierTools.geometryToObjects(e.getValue(), creator, labelMap.get(e.getKey()), minAreaPixels, minHoleAreaPixels, doSplit, regionRequest.getImagePlane()).stream())
                    .toList());
        }
        pathObjects.sort(DefaultPathObjectComparator.getInstance());
        return pathObjects;
    }

    /**
     * Convert a geometry into one or more classified objects.
     * <p>
     * The geometry is first passed through {@code GeometryTools.refineAreas} with the given
     * pixel thresholds; if nothing remains, an empty list is returned.
     *
     * @param geometry geometry to convert
     * @param creator function that converts a ROI to a new object
     * @param pathClass classification to set on every created object
     * @param minAreaPixels minimum area passed to {@code refineAreas}
     * @param minHoleAreaPixels minimum hole area passed to {@code refineAreas}
     * @param doSplit if {@code true}, one object is created for each component geometry;
     *                otherwise a single object is created for the whole geometry
     * @param plane image plane used when converting geometries to ROIs
     * @return the created objects
     */
    private static List<PathObject> geometryToObjects(Geometry geometry, Function<ROI, ? extends PathObject> creator, PathClass pathClass, double minAreaPixels, double minHoleAreaPixels, boolean doSplit, ImagePlane plane) {
        Geometry refined = GeometryTools.refineAreas(geometry, minAreaPixels, minHoleAreaPixels);
        if (refined == null || refined.isEmpty()) {
            return Collections.emptyList();
        }

        // Shared conversion: geometry -> ROI -> object with the requested class
        Function<Geometry, PathObject> toClassifiedObject = g -> {
            ROI partRoi = GeometryTools.geometryToROI(g, plane);
            PathObject created = creator.apply(partRoi);
            created.setPathClass(pathClass);
            return created;
        };

        if (!doSplit) {
            return Collections.singletonList(toClassifiedObject.apply(refined));
        }

        List<PathObject> parts = new ArrayList<>();
        int index = 0;
        while (index < refined.getNumGeometries()) {
            parts.add(toClassifiedObject.apply(refined.getGeometryN(index)));
            index++;
        }
        return parts;
    }

    /**
     * Create a pixel classification server for an image and classifier, with no explicit
     * ID or color model and without caching all tiles.
     *
     * @param imageData image data to classify
     * @param classifier pixel classifier to apply
     * @return the classification server
     */
    public static ImageServer<BufferedImage> createPixelClassificationServer(ImageData<BufferedImage> imageData, PixelClassifier classifier) {
        String noId = null;
        ColorModel noColorModel = null;
        boolean cacheAllTiles = false;
        return PixelClassifierTools.createPixelClassificationServer(imageData, classifier, noId, noColorModel, cacheAllTiles);
    }

    /**
     * Create a pixel classification server for an image and classifier.
     *
     * @param imageData image data to classify
     * @param classifier pixel classifier to apply
     * @param id ID passed to the {@code PixelClassificationImageServer} constructor (may be {@code null})
     * @param colorModel color model passed to the {@code PixelClassificationImageServer} constructor (may be {@code null})
     * @param cacheAllTiles if {@code true}, {@code readAllTiles()} is called on the server before it is returned
     * @return the classification server
     */
    public static ImageServer<BufferedImage> createPixelClassificationServer(ImageData<BufferedImage> imageData, PixelClassifier classifier, String id, ColorModel colorModel, boolean cacheAllTiles) {
        PixelClassificationImageServer classificationServer = new PixelClassificationImageServer(imageData, classifier, id, colorModel);
        if (!cacheAllTiles) {
            return classificationServer;
        }
        logger.debug("Caching all tiles for {}", (Object)classificationServer);
        classificationServer.readAllTiles();
        return classificationServer;
    }

    /**
     * Create a classification server that thresholds an image using a map of thresholds.
     * <p>
     * The output uses label 0 for {@code below} and label 1 for {@code aboveEquals}.
     * The label map is built with {@code Map.of}, so neither class may be {@code null}.
     *
     * @param server server providing the pixels to threshold
     * @param thresholds thresholds passed to {@code PixelClassifiers.createThresholdFunction}
     * @param below classification assigned to label 0
     * @param aboveEquals classification assigned to label 1
     * @return the thresholding classification server
     */
    public static ImageServer<BufferedImage> createThresholdServer(ImageServer<BufferedImage> server, Map<Integer, ? extends Number> thresholds, PathClass below, PathClass aboveEquals) {
        PixelClassifiers.ClassifierFunction thresholdFunction = PixelClassifiers.createThresholdFunction(thresholds);
        Map<Integer, PathClass> outputLabels = Map.of(
                0, below,
                1, aboveEquals);
        return PixelClassifierTools.createThresholdServer(server, outputLabels, thresholdFunction);
    }

    /**
     * Create a classification server that thresholds a single channel of an image.
     * <p>
     * The output uses label 0 for {@code below} and label 1 for {@code aboveEquals}.
     * The label map is built with {@code Map.of}, so neither class may be {@code null}.
     *
     * @param server server providing the pixels to threshold
     * @param channel channel passed to {@code PixelClassifiers.createThresholdFunction}
     * @param threshold threshold passed to {@code PixelClassifiers.createThresholdFunction}
     * @param below classification assigned to label 0
     * @param aboveEquals classification assigned to label 1
     * @return the thresholding classification server
     */
    public static ImageServer<BufferedImage> createThresholdServer(ImageServer<BufferedImage> server, int channel, double threshold, PathClass below, PathClass aboveEquals) {
        PixelClassifiers.ClassifierFunction thresholdFunction = PixelClassifiers.createThresholdFunction(channel, threshold);
        Map<Integer, PathClass> outputLabels = Map.of(
                0, below,
                1, aboveEquals);
        return PixelClassifierTools.createThresholdServer(server, outputLabels, thresholdFunction);
    }

    /**
     * Create a thresholding classification server from output labels and a classifier function.
     * <p>
     * The input resolution is the pixel calibration of the server, scaled by the downsample of
     * resolution level 0 when that downsample is greater than 1.
     *
     * @param server server providing the pixels to threshold
     * @param labels output labels for the classifier
     * @param fun function applied by the threshold classifier
     * @return a classification server wrapping a new {@code ImageData} for {@code server}
     */
    private static ImageServer<BufferedImage> createThresholdServer(ImageServer<BufferedImage> server, Map<Integer, PathClass> labels, PixelClassifiers.ClassifierFunction fun) {
        PixelCalibration serverCalibration = server.getPixelCalibration();
        double baseDownsample = server.getDownsampleForResolution(0);
        PixelCalibration inputResolution = baseDownsample > 1.0
                ? serverCalibration.createScaledInstance(baseDownsample, baseDownsample)
                : serverCalibration;
        PixelClassifier thresholdClassifier = PixelClassifiers.createThresholdClassifier(inputResolution, labels, fun);
        ImageData<BufferedImage> wrappedData = new ImageData<>(server);
        return PixelClassifierTools.createPixelClassificationServer(wrappedData, thresholdClassifier);
    }

    /**
     * Create a measurement manager for an image and pixel classifier.
     * <p>
     * A classification server is created first and then wrapped by the manager.
     *
     * @param imageData image data to classify
     * @param classifier pixel classifier to apply
     * @return the measurement manager
     */
    public static PixelClassificationMeasurementManager createMeasurementManager(ImageData<BufferedImage> imageData, PixelClassifier classifier) {
        var classifierServer = PixelClassifierTools.createPixelClassificationServer(imageData, classifier);
        return PixelClassifierTools.createMeasurementManager(classifierServer);
    }

    /**
     * Create a measurement manager that wraps an existing classification server.
     *
     * @param classifierServer server providing the pixel classification
     * @return a new {@code PixelClassificationMeasurementManager} for the server
     */
    public static PixelClassificationMeasurementManager createMeasurementManager(ImageServer<BufferedImage> classifierServer) {
        PixelClassificationMeasurementManager manager = new PixelClassificationMeasurementManager(classifierServer);
        return manager;
    }

    /**
     * Add pixel classification measurements to the selected objects of an image.
     * <p>
     * When nothing is selected, the root object of the hierarchy is measured instead.
     * A measurements-changed event is fired for the measured objects afterwards.
     *
     * @param imageData image data providing the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply
     * @param measurementID identifier passed through to {@code addMeasurements}
     * @return always {@code true}; the result of {@code addMeasurements} is not propagated
     */
    public static boolean addMeasurementsToSelectedObjects(ImageData<BufferedImage> imageData, PixelClassifier classifier, String measurementID) {
        PixelClassificationMeasurementManager measurementManager = PixelClassifierTools.createMeasurementManager(imageData, classifier);
        PathObjectHierarchy objectHierarchy = imageData.getHierarchy();
        Set<PathObject> selection = objectHierarchy.getSelectionModel().getSelectedObjects();
        // Fall back to the root object when there is no selection
        Set<PathObject> targets = selection.isEmpty()
                ? Collections.singleton(objectHierarchy.getRootObject())
                : selection;
        PixelClassifierTools.addMeasurements(targets, measurementManager, measurementID);
        objectHierarchy.fireObjectMeasurementsChangedEvent((Object)measurementManager, targets);
        return true;
    }

    /**
     * Add measurements to a collection of objects using a measurement manager.
     *
     * @param objectsToMeasure objects that should receive measurements
     * @param manager manager that performs the measurement
     * @param measurementID identifier passed to {@code manager.addMeasurements}
     * @return the value returned by {@code manager.addMeasurements}
     */
    public static boolean addMeasurements(Collection<? extends PathObject> objectsToMeasure, PixelClassificationMeasurementManager manager, String measurementID) {
        boolean result = manager.addMeasurements(objectsToMeasure, measurementID);
        return result;
    }

    /**
     * Classify objects according to the prediction at their ROI centroid.
     * <p>
     * For each object the centroid (truncated to integer coordinates) is looked up with
     * {@code getClassification}, and the label found there is mapped to a classification using
     * the classification labels of the server metadata. Labels without a mapping give a
     * {@code null} classification. All reclassifiers are created first, then applied.
     *
     * @param classifierServer server providing the pixel classification
     * @param pathObjects objects to classify
     * @param preferNucleusROI passed to {@code PathObjectTools.getROI} when choosing the ROI of each object
     */
    public static void classifyObjectsByCentroid(ImageServer<BufferedImage> classifierServer, Collection<PathObject> pathObjects, boolean preferNucleusROI) {
        Map<Integer, PathClass> labelLookup = classifierServer.getMetadata().getClassificationLabels();
        List<Reclassifier> pending = pathObjects.parallelStream()
                .map(target -> {
                    PathClass predicted;
                    try {
                        ROI centroidRoi = PathObjectTools.getROI(target, preferNucleusROI);
                        int cx = (int)centroidRoi.getCentroidX();
                        int cy = (int)centroidRoi.getCentroidY();
                        int label = PixelClassifierTools.getClassification(classifierServer, cx, cy, centroidRoi.getZ(), centroidRoi.getT());
                        predicted = labelLookup.getOrDefault(label, null);
                    } catch (Exception e) {
                        // Best effort: any failure results in a null classification.
                        // NOTE(review): the exception is swallowed without logging - consider reporting it.
                        predicted = null;
                    }
                    return new Reclassifier(target, predicted, false);
                })
                .toList();
        pending.parallelStream().forEach(Reclassifier::apply);
    }

    /**
     * Get the classification label predicted by a server at a single pixel location.
     * <p>
     * The tile at resolution level 0 containing the location is read, and the location is
     * converted to tile coordinates and clamped to the bounds of the tile image.
     * <ul>
     *   <li>For {@code CLASSIFICATION} output with a single band, the sample value of band 0 is returned.</li>
     *   <li>For {@code PROBABILITY} output, the index of the band with the highest value is returned;
     *       with a single band, -1 is returned if that value is below the threshold (127.5 when
     *       the raster transfer type is byte, 0.5 otherwise).</li>
     * </ul>
     *
     * @param server server providing the pixel classification
     * @param x x-coordinate of the location to query
     * @param y y-coordinate of the location to query
     * @param z z-slice index
     * @param t timepoint index
     * @return the label, or -1 if the channel type is neither {@code CLASSIFICATION} nor
     *         {@code PROBABILITY}, no tile or image is available, the sample could not be read,
     *         or no classification can be determined
     * @throws IOException if the tile cannot be read
     */
    public static int getClassification(ImageServer<BufferedImage> server, int x, int y, int z, int t) throws IOException {
        ImageServerMetadata.ChannelType type = server.getMetadata().getChannelType();
        boolean isClassification = type == ImageServerMetadata.ChannelType.CLASSIFICATION;
        boolean isProbability = type == ImageServerMetadata.ChannelType.PROBABILITY;
        if (!isClassification && !isProbability) {
            return -1;
        }
        TileRequest tile = server.getTileRequestManager().getTileRequest(0, x, y, z, t);
        if (tile == null) {
            return -1;
        }
        int xx = (int)Math.floor((double)x / tile.getDownsample() - (double)tile.getTileX());
        int yy = (int)Math.floor((double)y / tile.getDownsample() - (double)tile.getTileY());
        BufferedImage img = server.readRegion(tile.getRegionRequest());
        if (img == null) {
            // Defensive: treat a missing tile image like a missing tile rather than failing with an NPE
            return -1;
        }
        // Clamp to the image bounds, since rounding can push the location just outside the tile
        xx = Math.max(0, Math.min(xx, img.getWidth() - 1));
        yy = Math.max(0, Math.min(yy, img.getHeight() - 1));

        WritableRaster raster = img.getRaster();
        int nBands = raster.getNumBands();
        if (isClassification && nBands == 1) {
            try {
                return raster.getSample(xx, yy, 0);
            }
            catch (Exception e) {
                logger.error("Error requesting classification", (Throwable)e);
                return -1;
            }
        }
        if (isProbability) {
            int maxInd = -1;
            double maxVal = Double.NEGATIVE_INFINITY;
            // Byte rasters are thresholded at half of 255; anything else at 0.5
            double threshold = raster.getTransferType() == java.awt.image.DataBuffer.TYPE_BYTE ? 127.5 : 0.5;
            for (int b = 0; b < nBands; ++b) {
                double value = raster.getSampleDouble(xx, yy, b);
                if (value > maxVal) {
                    maxInd = b;
                    maxVal = value;
                }
            }
            // A single probability band only counts as a hit when it passes the threshold
            if (nBands == 1 && maxVal < threshold) {
                return -1;
            }
            return maxInd;
        }
        return -1;
    }

    /**
     * Classify all cell objects of an image according to the prediction of a pixel classifier
     * at their centroids.
     *
     * @param imageData image data providing the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply
     * @param preferNucleusROI forwarded to {@code classifyObjectsByCentroid}
     */
    public static void classifyCellsByCentroid(ImageData<BufferedImage> imageData, PixelClassifier classifier, boolean preferNucleusROI) {
        var cells = imageData.getHierarchy().getCellObjects();
        PixelClassifierTools.classifyObjectsByCentroid(imageData, classifier, cells, preferNucleusROI);
    }

    /**
     * Classify all detection objects of an image according to the prediction of a pixel
     * classifier at their centroids, passing {@code true} for the nucleus ROI preference.
     *
     * @param imageData image data providing the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply
     */
    public static void classifyDetectionsByCentroid(ImageData<BufferedImage> imageData, PixelClassifier classifier) {
        var detections = imageData.getHierarchy().getDetectionObjects();
        boolean preferNucleusROI = true;
        PixelClassifierTools.classifyObjectsByCentroid(imageData, classifier, detections, preferNucleusROI);
    }

    /**
     * Classify objects according to the prediction of a pixel classifier at their centroids,
     * then fire a classifications-changed event on the hierarchy of the image.
     *
     * @param imageData image data providing the hierarchy and the pixels to classify
     * @param classifier pixel classifier to apply; also used as the source of the fired event
     * @param pathObjects objects to classify
     * @param preferNucleusROI forwarded to the server-based {@code classifyObjectsByCentroid}
     */
    public static void classifyObjectsByCentroid(ImageData<BufferedImage> imageData, PixelClassifier classifier, Collection<PathObject> pathObjects, boolean preferNucleusROI) {
        var classifierServer = PixelClassifierTools.createPixelClassificationServer(imageData, classifier);
        PixelClassifierTools.classifyObjectsByCentroid(classifierServer, pathObjects, preferNucleusROI);
        var hierarchy = imageData.getHierarchy();
        hierarchy.fireObjectClassificationsChangedEvent((Object)classifier, pathObjects);
    }

    /**
     * Options that control how objects are created from predictions.
     * <p>
     * {@link #toString()} returns a human-readable label for each option.
     */
    public enum CreateObjectOptions {
        /** Clear the existing child objects of a parent before adding the new objects. */
        DELETE_EXISTING,
        /** Create a separate object for each component geometry. */
        SPLIT,
        /** Also create objects for classifications that are flagged as ignored. */
        INCLUDE_IGNORED,
        /** Select the newly created objects afterwards. */
        SELECT_NEW;

        /**
         * Get a human-readable label for the option.
         *
         * @return the display label
         * @throws IllegalArgumentException if the option has no label
         */
        @Override
        public String toString() {
            // Switch on the constant itself rather than its ordinal, so that
            // reordering the constants cannot silently change the labels
            switch (this) {
                case DELETE_EXISTING:
                    return "Delete existing";
                case INCLUDE_IGNORED:
                    return "Include ignored";
                case SPLIT:
                    return "Split";
                case SELECT_NEW:
                    return "Select new";
                default:
                    // Use name() here: concatenating 'this' would call toString() again and recurse
                    throw new IllegalArgumentException("Unknown option " + name());
            }
        }
    }
}

