/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml.pixel;

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Rectangle;
import java.awt.RenderingHints;
import java.awt.Shape;
import java.awt.Stroke;
import java.awt.geom.AffineTransform;
import java.awt.geom.Point2D;
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.lang.invoke.CallSite;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.stream.Collectors;
import java.util.stream.LongStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.awt.common.BufferedImageTools;
import qupath.lib.common.GeneralTools;
import qupath.lib.common.ThreadTools;
import qupath.lib.geom.Point2;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.measurements.MeasurementList;
import qupath.lib.measurements.MeasurementListFactory;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.classes.PathClassTools;
import qupath.lib.regions.ImagePlane;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.ROIs;
import qupath.lib.roi.RoiTools;
import qupath.lib.roi.interfaces.ROI;

public class PixelClassificationMeasurementManager {
    private static final Logger logger = LoggerFactory.getLogger(PixelClassificationMeasurementManager.class);
    private static final Map<ImageServer<BufferedImage>, Map<ROI, MeasurementList>> measuredROIs = Collections.synchronizedMap(new WeakHashMap());
    private final ImageServer<BufferedImage> classifierServer;
    private List<String> measurementNames = null;
    private ROI rootROI = null;
    private ThreadLocal<BufferedImage> imgTileMask = new ThreadLocal();
    private boolean isMulticlass = false;
    private double requestedDownsample;
    private double pixelArea;
    private String pixelAreaUnits;

    /*
     * WARNING - Removed try catching itself - possible behaviour change.
     */
    public PixelClassificationMeasurementManager(ImageServer<BufferedImage> classifierServer) {
        ImageServerMetadata.ChannelType type;
        this.classifierServer = classifierServer;
        Map<ImageServer<BufferedImage>, Map<ROI, MeasurementList>> map = measuredROIs;
        synchronized (map) {
            if (!measuredROIs.containsKey(classifierServer)) {
                measuredROIs.put(classifierServer, new HashMap());
            }
        }
        this.requestedDownsample = classifierServer.getDownsampleForResolution(0);
        PixelCalibration cal = classifierServer.getPixelCalibration();
        if (cal.unitsMatch2D()) {
            this.pixelArea = cal.getPixelWidth().doubleValue() * this.requestedDownsample * (cal.getPixelHeight().doubleValue() * this.requestedDownsample);
            this.pixelAreaUnits = cal.getPixelWidthUnit() + "^2";
        } else {
            this.pixelArea = this.requestedDownsample * this.requestedDownsample;
            this.pixelAreaUnits = "px^2";
        }
        if (classifierServer.nZSlices() == 1 || classifierServer.nTimepoints() == 1) {
            this.rootROI = ROIs.createRectangleROI((double)0.0, (double)0.0, (double)classifierServer.getWidth(), (double)classifierServer.getHeight(), (ImagePlane)ImagePlane.getDefaultPlane());
        }
        if ((type = classifierServer.getMetadata().getChannelType()) == ImageServerMetadata.ChannelType.MULTICLASS_PROBABILITY || type == ImageServerMetadata.ChannelType.PROBABILITY && classifierServer.nChannels() == 1) {
            this.isMulticlass = true;
        }
        this.updateMeasurements(classifierServer.getMetadata().getClassificationLabels(), null, this.pixelArea, this.pixelAreaUnits);
    }

    @Deprecated
    public Number getMeasurementValue(PathObject pathObject, String name, boolean cachedOnly) {
        if (cachedOnly) {
            return this.getCachedMeasurementValue(pathObject, name);
        }
        return this.getMeasurementValue(pathObject, name);
    }

    private ROI getROI(PathObject pathObject) {
        ROI roi = pathObject.getROI();
        if (roi == null || pathObject.isRootObject()) {
            return this.rootROI;
        }
        return roi;
    }

    public Number getCachedMeasurementValue(PathObject pathObject, String name) {
        return this.getCachedMeasurementValue(this.getROI(pathObject), name);
    }

    public Number getMeasurementValue(PathObject pathObject, String name) {
        return this.getMeasurementValue(this.getROI(pathObject), name);
    }

    @Deprecated
    public Number getMeasurementValue(ROI roi, String name, boolean cachedOnly) {
        if (cachedOnly) {
            return this.getCachedMeasurementValue(roi, name);
        }
        return this.getMeasurementValue(roi, name);
    }

    public Number getCachedMeasurementValue(ROI roi, String name) {
        MeasurementList ml = this.getMeasurementList(roi, null);
        if (ml == null) {
            return null;
        }
        return ml.get(name);
    }

    public Number getMeasurementValue(ROI roi, String name) {
        MeasurementList ml = this.getMeasurementList(roi, this.getDefaultPool());
        if (ml == null) {
            return null;
        }
        return ml.get(name);
    }

    public boolean addMeasurements(Collection<? extends PathObject> objectsToMeasure, String measurementID) {
        if (objectsToMeasure.isEmpty()) {
            return false;
        }
        measurementID = measurementID == null || ((String)measurementID).isBlank() ? "" : (((String)(measurementID = ((String)measurementID).strip())).endsWith(":") ? (String)measurementID + " " : (String)measurementID + ": ");
        int maxParallelism = this.calculatePreferredParallelism();
        int nObjectThreads = 1;
        int nTileThreads = maxParallelism;
        if (objectsToMeasure.size() > 1 && maxParallelism > 2) {
            if (objectsToMeasure.size() > maxParallelism) {
                nObjectThreads = maxParallelism - 1;
                nTileThreads = 2;
            } else {
                nObjectThreads = 2;
                nTileThreads = maxParallelism - 1;
            }
        }
        logger.debug("Measuring {} objects (object threads={}, tile threads={})", new Object[]{objectsToMeasure.size(), nObjectThreads, nTileThreads});
        ThreadFactory factoryObjects = ThreadTools.createThreadFactory((String)"pixel-classification-objects", (boolean)true, (int)6);
        ThreadFactory factoryTiles = ThreadTools.createThreadFactory((String)"pixel-classification-tiles", (boolean)true, (int)5);
        ExecutorService poolObjects = Executors.newFixedThreadPool(nObjectThreads, factoryObjects);
        ExecutorService poolTiles = Executors.newFixedThreadPool(nTileThreads, factoryTiles);
        Object measurementIdFinal = measurementID;
        ArrayList tasks = new ArrayList();
        for (PathObject pathObject : objectsToMeasure) {
            tasks.add(poolObjects.submit(() -> this.lambda$addMeasurements$0(pathObject, (String)measurementIdFinal, poolTiles)));
        }
        poolObjects.shutdown();
        try {
            for (Future future : tasks) {
                future.get();
            }
        }
        catch (Exception e) {
            throw new RuntimeException(e);
        }
        finally {
            poolTiles.shutdown();
        }
        return true;
    }

    private void measureObject(PathObject pathObject, String measurementID, ExecutorService pool) {
        try (MeasurementList ml = pathObject.getMeasurementList();){
            Map<String, Number> map = this.getMeasurementListAsMap(pathObject.getROI(), pool);
            List<String> measurementNames = this.getMeasurementNames();
            if (map.isEmpty() || measurementNames.isEmpty()) {
                logger.warn("Map or measurements names are empty!");
            }
            for (String name : measurementNames) {
                Number value = map.getOrDefault(name, null);
                if (value == null) {
                    ml.put(measurementID + name, Double.NaN);
                    continue;
                }
                ml.put(measurementID + name, value.doubleValue());
            }
        }
        if (!pathObject.isRootObject()) {
            pathObject.setLocked(true);
        }
    }

    private ExecutorService getDefaultPool() {
        return ForkJoinPool.commonPool();
    }

    private Map<String, Number> getMeasurementListAsMap(ROI roi, ExecutorService pool) {
        MeasurementList ml = this.getMeasurementList(roi, pool);
        return ml == null ? Collections.emptyMap() : Collections.unmodifiableMap(ml.asMap());
    }

    private MeasurementList getMeasurementList(ROI roi, ExecutorService pool) {
        if (roi == null) {
            return null;
        }
        Map map = measuredROIs.computeIfAbsent(this.classifierServer, s -> new ConcurrentHashMap());
        MeasurementList ml = map.getOrDefault(roi, null);
        if (ml == null && (ml = this.calculateMeasurements(roi, pool)) != null) {
            map.put(roi, ml);
        }
        return ml;
    }

    public List<String> getMeasurementNames() {
        return this.measurementNames == null ? Collections.emptyList() : this.measurementNames;
    }

    private static boolean completelyContainsTile(Shape shape, TileRequest tile, double padding) {
        return shape.contains((double)tile.getImageX() - padding, (double)tile.getImageY() - padding, (double)tile.getImageWidth() + padding * 2.0, (double)tile.getImageHeight() + padding * 2.0);
    }

    private static boolean mayIntersectTile(Shape shape, TileRequest tile, double padding) {
        return shape.intersects((double)tile.getImageX() - padding, (double)tile.getImageY() - padding, (double)tile.getImageWidth() + padding * 2.0, (double)tile.getImageHeight() + padding * 2.0);
    }

    /*
     * WARNING - void declaration
     */
    private MeasurementList calculateMeasurements(ROI roi, ExecutorService pool) {
        Collection<Object> requests;
        boolean cachedOnly = pool == null;
        Map classificationLabels = this.classifierServer.getMetadata().getClassificationLabels();
        long[] counts = null;
        ImageServer<BufferedImage> server = this.classifierServer;
        ImageServerMetadata.ChannelType type = this.classifierServer.getMetadata().getChannelType();
        if (type == ImageServerMetadata.ChannelType.FEATURE) {
            return null;
        }
        Shape shape = null;
        if (!roi.isPoint()) {
            shape = RoiTools.getShape((ROI)roi);
        }
        if (roi == this.rootROI) {
            requests = server.getTileRequestManager().getAllTileRequests();
        } else if (!roi.isEmpty()) {
            RegionRequest regionRequest = RegionRequest.createInstance((String)server.getPath(), (double)this.requestedDownsample, (ROI)roi);
            requests = server.getTileRequestManager().getTileRequests(regionRequest);
            if (shape != null) {
                Shape shapeTemp = shape;
                requests = requests.stream().filter(r -> PixelClassificationMeasurementManager.mayIntersectTile(shapeTemp, r, r.getDownsample())).toList();
            }
        } else {
            requests = Collections.emptyList();
        }
        if (requests.isEmpty()) {
            logger.debug("Request empty for {}", (Object)roi);
            return null;
        }
        HashMap<TileRequest, Object> localCache = new HashMap<TileRequest, Object>();
        ArrayList<TileRequest> tilesToRequest = new ArrayList<TileRequest>();
        ArrayList<TileRequest> missingTiles = new ArrayList<TileRequest>();
        for (TileRequest tileRequest : requests) {
            BufferedImage tile = (BufferedImage)this.classifierServer.getCachedTile(tileRequest);
            if (cachedOnly && tile == null) {
                return null;
            }
            tilesToRequest.add(tileRequest);
            if (tile != null) {
                localCache.put(tileRequest, tile);
                continue;
            }
            if (cachedOnly) {
                logger.trace("No cached tile for {} - returning now", (Object)tile);
                return null;
            }
            missingTiles.add(tileRequest);
        }
        HashMap<TileRequest, Future<BufferedImage>> requestMap = new HashMap<TileRequest, Future<BufferedImage>>();
        if (!missingTiles.isEmpty()) {
            boolean bl = missingTiles.size() > 1;
            for (TileRequest request : missingTiles) {
                if (bl) {
                    requestMap.put(request, pool.submit(() -> (BufferedImage)this.classifierServer.readRegion(request.getRegionRequest())));
                    continue;
                }
                try {
                    localCache.put(request, (BufferedImage)this.classifierServer.readRegion(request.getRegionRequest()));
                }
                catch (IOException e) {
                    logger.error("Error reading tile " + String.valueOf(request), (Throwable)e);
                    return null;
                }
            }
        }
        Object var14_16 = null;
        byte[] mask = null;
        BufferedImage imgMask = this.imgTileMask.get();
        Rectangle bounds = new Rectangle();
        Point2D.Double p1 = new Point2D.Double();
        Point2D.Double p2 = new Point2D.Double();
        long startTime = System.currentTimeMillis();
        for (TileRequest region : tilesToRequest) {
            BufferedImage tile = (BufferedImage)localCache.remove(region);
            if (!cachedOnly && tile == null) {
                try {
                    tile = (BufferedImage)((Future)requestMap.get(region)).get();
                }
                catch (Exception e) {
                    logger.error("Error requesting tile " + String.valueOf(region), (Throwable)e);
                }
            }
            if (tile == null) {
                return null;
            }
            if (imgMask == null || imgMask.getWidth() < tile.getWidth() || imgMask.getHeight() < tile.getHeight() || imgMask.getType() != 10) {
                imgMask = new BufferedImage(tile.getWidth(), tile.getHeight(), 10);
                this.imgTileMask.set(imgMask);
            }
            boolean fullMask = false;
            if (shape != null && PixelClassificationMeasurementManager.completelyContainsTile(shape, region, region.getDownsample())) {
                fullMask = true;
                bounds.setRect(0.0, 0.0, tile.getWidth(), tile.getHeight());
            } else {
                bounds.setBounds(0, 0, -1, -1);
                if (roi.isLine() || roi.isArea()) {
                    Graphics2D g2d = imgMask.createGraphics();
                    g2d.setColor(Color.BLACK);
                    g2d.fillRect(0, 0, tile.getWidth(), tile.getHeight());
                    g2d.setColor(Color.WHITE);
                    g2d.scale(1.0 / region.getDownsample(), 1.0 / region.getDownsample());
                    g2d.translate((double)(-region.getTileX()) * region.getDownsample(), (double)(-region.getTileY()) * region.getDownsample());
                    g2d.setRenderingHint(RenderingHints.KEY_ANTIALIASING, RenderingHints.VALUE_ANTIALIAS_OFF);
                    g2d.setRenderingHint(RenderingHints.KEY_FRACTIONALMETRICS, RenderingHints.VALUE_FRACTIONALMETRICS_OFF);
                    if (roi.isLine()) {
                        void var14_17;
                        float fDownsample = (float)region.getDownsample();
                        if (var14_17 == null || var14_17.getLineWidth() != fDownsample) {
                            BasicStroke basicStroke = new BasicStroke(fDownsample);
                        }
                        g2d.setStroke((Stroke)var14_17);
                        g2d.draw(shape);
                    } else if (roi.isArea()) {
                        g2d.fill(shape);
                    }
                    AffineTransform transform = g2d.getTransform();
                    ((Point2D)p1).setLocation(roi.getBoundsX(), roi.getBoundsY());
                    transform.transform(p1, p1);
                    ((Point2D)p2).setLocation(roi.getBoundsX() + roi.getBoundsWidth(), roi.getBoundsY() + roi.getBoundsHeight());
                    transform.transform(p2, p2);
                    bounds.x = (int)Math.max(0.0, ((Point2D)p1).getX() - 1.0);
                    bounds.y = (int)Math.max(0.0, ((Point2D)p1).getY() - 1.0);
                    bounds.width = (int)Math.min((double)tile.getWidth(), Math.ceil(((Point2D)p2).getX() + 1.0)) - bounds.x;
                    bounds.height = (int)Math.min((double)tile.getHeight(), Math.ceil(((Point2D)p2).getY() + 1.0)) - bounds.y;
                    g2d.dispose();
                } else if (roi.isPoint()) {
                    boolean anyPoints = false;
                    for (Point2 p : roi.getAllPoints()) {
                        int x = (int)((p.getX() - (double)region.getImageX()) / region.getDownsample());
                        int y = (int)((p.getY() - (double)region.getImageY()) / region.getDownsample());
                        if (x < 0 || y < 0 || x >= tile.getWidth() || y >= tile.getHeight()) continue;
                        if (!anyPoints) {
                            Graphics2D g2d = imgMask.createGraphics();
                            g2d.setColor(Color.BLACK);
                            g2d.fillRect(0, 0, tile.getWidth(), tile.getHeight());
                            g2d.dispose();
                            anyPoints = true;
                        }
                        imgMask.getRaster().setSample(x, y, 0, 255);
                        bounds.add(x, y);
                        bounds.add(x + 1, y + 1);
                    }
                    if (!anyPoints) continue;
                }
            }
            int h = tile.getHeight();
            int w = tile.getWidth();
            if (mask == null || mask.length != h * w) {
                mask = new byte[w * h];
            }
            int nChannels = tile.getSampleModel().getNumBands();
            try {
                switch (type) {
                    case CLASSIFICATION: {
                        counts = BufferedImageTools.computeUnsignedIntHistogram((WritableRaster)tile.getRaster(), (long[])counts, (WritableRaster)(fullMask ? null : imgMask.getRaster()), (Rectangle)bounds);
                        break;
                    }
                    case PROBABILITY: {
                        if (nChannels > 1) {
                            counts = BufferedImageTools.computeArgMaxHistogram((WritableRaster)tile.getRaster(), (long[])counts, (WritableRaster)(fullMask ? null : imgMask.getRaster()), (Rectangle)bounds);
                            break;
                        }
                    }
                    case MULTICLASS_PROBABILITY: {
                        if (counts == null) {
                            counts = new long[nChannels];
                        }
                        double threshold = PixelClassificationMeasurementManager.getProbabilityThreshold(tile.getRaster());
                        for (int c = 0; c < nChannels; ++c) {
                            int n = c;
                            counts[n] = counts[n] + BufferedImageTools.computeAboveThresholdCounts((WritableRaster)tile.getRaster(), (int)c, (double)threshold, (WritableRaster)(fullMask ? null : imgMask.getRaster()), (Rectangle)bounds);
                        }
                    }
                    default: {
                        return this.updateMeasurements(classificationLabels, counts, this.pixelArea, this.pixelAreaUnits);
                    }
                }
            }
            catch (Exception e) {
                logger.error("Error calculating classification areas", (Throwable)e);
                if (nChannels <= 1 || type != ImageServerMetadata.ChannelType.CLASSIFICATION) continue;
                logger.error("There are {} channels - are you sure this is really a classification image?", (Object)nChannels);
            }
        }
        long endTime = System.currentTimeMillis();
        if (logger.isDebugEnabled()) {
            long totalCounts = LongStream.of(counts).sum();
            logger.debug("Counted {} pixels in {} ms (area {} {})", new Object[]{totalCounts, endTime - startTime, GeneralTools.formatNumber((double)((double)totalCounts * this.pixelArea), (int)2), this.pixelAreaUnits});
        }
        return this.updateMeasurements(classificationLabels, counts, this.pixelArea, this.pixelAreaUnits);
    }

    protected int calculatePreferredParallelism() {
        int poolSize = this.getPoolSizeProp();
        if (poolSize > 0) {
            return poolSize;
        }
        int minSize = 2;
        int maxSize = ThreadTools.getParallelism();
        if (maxSize <= minSize) {
            return maxSize;
        }
        Runtime runtime = Runtime.getRuntime();
        long availableMemory = runtime.maxMemory() / 2L;
        int nThreads = (int)Math.min((long)maxSize, availableMemory / 0x20000000L);
        return GeneralTools.clipValue((int)nThreads, (int)minSize, (int)maxSize);
    }

    protected int getPoolSizeProp() {
        String prop = System.getProperty("pixel.classification.pool.size");
        if (prop != null) {
            try {
                return Integer.parseInt(prop);
            }
            catch (NumberFormatException e) {
                logger.error("Error parsing pixel.classification.pool.size", (Throwable)e);
            }
        }
        return -1;
    }

    public static double getProbabilityThreshold(WritableRaster raster) {
        return switch (raster.getTransferType()) {
            case 0, 1, 2, 3 -> 127.5;
            default -> 0.5;
        };
    }

    private synchronized MeasurementList updateMeasurements(Map<Integer, PathClass> classificationLabels, long[] counts, double pixelArea, String pixelAreaUnits) {
        PathClass pathClass;
        long total = counts == null ? 0L : GeneralTools.sum((long[])counts);
        LinkedHashSet<PathClass> pathClasses = new LinkedHashSet<PathClass>(classificationLabels.values());
        boolean addNames = this.measurementNames == null;
        ArrayList<CallSite> tempList = null;
        int nMeasurements = pathClasses.size() * 2;
        if (!this.isMulticlass) {
            nMeasurements += 2;
        }
        if (addNames) {
            tempList = new ArrayList<CallSite>();
            this.measurementNames = Collections.unmodifiableList(tempList);
        } else {
            nMeasurements = this.measurementNames.size();
        }
        MeasurementList measurementList = MeasurementListFactory.createMeasurementList((int)nMeasurements, (MeasurementList.MeasurementListType)MeasurementList.MeasurementListType.DOUBLE);
        Set ignored = pathClasses.stream().filter(p -> p == null || PathClassTools.isIgnoredClass((PathClass)p)).collect(Collectors.toSet());
        LinkedHashMap<PathClass, Long> pathClassTotals = new LinkedHashMap<PathClass, Long>();
        long totalWithoutIgnored = 0L;
        if (counts != null) {
            for (Map.Entry entry : classificationLabels.entrySet()) {
                pathClass = (PathClass)entry.getValue();
                if (pathClass == null || ignored.contains(pathClass)) continue;
                int c = (Integer)entry.getKey();
                long temp = counts == null || c >= counts.length ? 0L : counts[c];
                totalWithoutIgnored += temp;
                pathClassTotals.put(pathClass, pathClassTotals.getOrDefault(pathClass, 0L) + temp);
            }
        } else {
            for (PathClass pathClass2 : pathClasses) {
                if (pathClass2 == null || ignored.contains(pathClass2)) continue;
                pathClassTotals.put(pathClass2, 0L);
            }
        }
        for (Map.Entry entry : pathClassTotals.entrySet()) {
            pathClass = (PathClass)entry.getKey();
            String name = pathClass.toString();
            String namePercentage = name + " %";
            String nameArea = name + " area " + pixelAreaUnits;
            if (tempList != null) {
                if (pathClassTotals.size() > 1) {
                    tempList.add((CallSite)((Object)namePercentage));
                }
                tempList.add((CallSite)((Object)nameArea));
            }
            if (counts == null) continue;
            long count = (Long)entry.getValue();
            if (pathClassTotals.size() > 1) {
                measurementList.put(namePercentage, (double)count / (double)totalWithoutIgnored * 100.0);
            }
            if (Double.isNaN(pixelArea)) continue;
            measurementList.put(nameArea, (double)count * pixelArea);
        }
        String nameArea = "Total annotated area " + pixelAreaUnits;
        String string = "Total quantified area " + pixelAreaUnits;
        if (counts != null && !Double.isNaN(pixelArea)) {
            if (tempList != null) {
                tempList.add((CallSite)((Object)nameArea));
                tempList.add((CallSite)((Object)string));
            }
            measurementList.put(nameArea, (double)totalWithoutIgnored * pixelArea);
            measurementList.put(string, (double)total * pixelArea);
        }
        measurementList.close();
        return measurementList;
    }

    private /* synthetic */ void lambda$addMeasurements$0(PathObject pathObject, String measurementIdFinal, ExecutorService poolTiles) {
        this.measureObject(pathObject, measurementIdFinal, poolTiles);
    }
}

