/*
 * Decompiled with CFR 0.152.
 */
package qupath.lib.gui.viewer.overlays;

import java.awt.AlphaComposite;
import java.awt.Color;
import java.awt.Composite;
import java.awt.Graphics2D;
import java.awt.Image;
import java.awt.RenderingHints;
import java.awt.Shape;
import java.awt.image.BufferedImage;
import java.awt.image.DataBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.WeakHashMap;
import java.util.concurrent.Executors;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.function.Function;
import javafx.application.Platform;
import javafx.beans.property.ObjectProperty;
import javafx.beans.property.SimpleObjectProperty;
import javafx.beans.value.ObservableBooleanValue;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.awt.common.AwtTools;
import qupath.lib.classifiers.pixel.PixelClassificationImageServer;
import qupath.lib.classifiers.pixel.PixelClassifier;
import qupath.lib.color.ColorToolsAwt;
import qupath.lib.common.GeneralTools;
import qupath.lib.common.ThreadTools;
import qupath.lib.geom.Point2;
import qupath.lib.gui.QuPathGUI;
import qupath.lib.gui.images.stores.ImageRenderer;
import qupath.lib.gui.prefs.PathPrefs;
import qupath.lib.gui.viewer.OverlayOptions;
import qupath.lib.gui.viewer.RegionFilter;
import qupath.lib.gui.viewer.overlays.AbstractImageOverlay;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.ImageChannel;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.ServerTools;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.hierarchy.PathObjectHierarchy;
import qupath.lib.regions.ImageRegion;
import qupath.lib.regions.RegionRequest;
import qupath.opencv.ops.ImageDataOp;
import qupath.opencv.ops.ImageDataServer;
import qupath.opencv.ops.ImageOps;

public class PixelClassificationOverlay
extends AbstractImageOverlay {
    private static final Logger logger = LoggerFactory.getLogger(PixelClassificationOverlay.class);
    private final ObjectProperty<ImageRenderer> renderer = new SimpleObjectProperty();
    private long rendererLastTimestamp = 0L;
    private final Map<BufferedImage, BufferedImage> cacheRGB = Collections.synchronizedMap(new WeakHashMap());
    private final Set<TileRequest> pendingRequests = Collections.synchronizedSet(new HashSet());
    private final Set<TileRequest> currentRequests = Collections.synchronizedSet(new HashSet());
    private int maxThreads = ThreadTools.getParallelism();
    private final ThreadPoolExecutor pool;
    private final Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> fun;
    private boolean livePrediction = false;
    private Map<ImageData<BufferedImage>, ImageServer<BufferedImage>> cachedServers = new WeakHashMap<ImageData<BufferedImage>, ImageServer<BufferedImage>>();
    private final ObservableBooleanValue showOverlay;
    private RegionFilter lastRegionFilter;
    private final Map<RegionRequest, Boolean> acceptedTiles = new HashMap<RegionRequest, Boolean>();
    private long lastHierarchyEventCount = -1L;

    private PixelClassificationOverlay(OverlayOptions options, int nThreads, Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> fun) {
        super(options);
        this.showOverlay = options.showPixelClassificationProperty();
        ThreadFactory threadFactory = ThreadTools.createThreadFactory((String)"classifier-overlay", (boolean)true, (int)3);
        if (nThreads > 0) {
            this.maxThreads = nThreads;
        }
        this.pool = (ThreadPoolExecutor)Executors.newFixedThreadPool(this.maxThreads, threadFactory);
        this.renderer.addListener((v, o, n) -> this.cacheRGB.clear());
        this.fun = fun;
    }

    public static PixelClassificationOverlay create(OverlayOptions options, PixelClassifier classifier) {
        int nThreads = Math.max(1, PathPrefs.numCommandThreadsProperty().get());
        return PixelClassificationOverlay.create(options, classifier, nThreads);
    }

    public static PixelClassificationOverlay create(OverlayOptions options, PixelClassifier classifier, int nThreads) {
        return new PixelClassificationOverlay(options, Math.max(1, nThreads), new ClassifierServerFunction(classifier));
    }

    public static PixelClassificationOverlay create(OverlayOptions options, Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> fun, ImageRenderer renderer) {
        PixelClassificationOverlay overlay = new PixelClassificationOverlay(options, 1, fun);
        overlay.setRenderer(renderer);
        return overlay;
    }

    public static PixelClassificationOverlay create(OverlayOptions options, Map<ImageData<BufferedImage>, ImageServer<BufferedImage>> map, ImageRenderer renderer) {
        PixelClassificationOverlay overlay = new PixelClassificationOverlay(options, 1, data -> null);
        overlay.cachedServers = map;
        overlay.setRenderer(renderer);
        return overlay;
    }

    @Deprecated
    public static PixelClassificationOverlay createFeatureDisplayOverlay(OverlayOptions options, Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> fun, ImageRenderer renderer) {
        PixelClassificationOverlay overlay = new PixelClassificationOverlay(options, 1, fun);
        overlay.setRenderer(renderer);
        return overlay;
    }

    public ObjectProperty<ImageRenderer> rendererProperty() {
        return this.renderer;
    }

    public ImageRenderer getRenderer() {
        return (ImageRenderer)this.renderer.get();
    }

    public void setMaxThreads(int nThreads) {
        this.maxThreads = Math.max(1, nThreads);
        if (this.maxThreads < this.pool.getCorePoolSize()) {
            this.pool.setCorePoolSize(this.maxThreads);
        }
        this.pool.setMaximumPoolSize(this.maxThreads);
        this.pool.setCorePoolSize(this.maxThreads);
        logger.debug("Number of parallel threads set to {}", (Object)nThreads);
    }

    public int getMaxThreads() {
        return this.maxThreads;
    }

    public void setRenderer(ImageRenderer renderer) {
        this.renderer.set((Object)renderer);
    }

    public boolean getLivePrediction() {
        return this.livePrediction;
    }

    public void setLivePrediction(boolean livePrediction) {
        this.livePrediction = livePrediction;
    }

    public ImageServer<BufferedImage> getPixelClassificationServer(ImageData<BufferedImage> imageData) {
        return imageData == null ? null : this.cachedServers.computeIfAbsent(imageData, this::createPixelClassificationServer);
    }

    @Override
    public void paintOverlay(Graphics2D g2d, ImageRegion imageRegion, double downsampleFactor, ImageData<BufferedImage> imageData, boolean paintCompletely) {
        ImageRenderer renderer;
        if (!this.showOverlay.get()) {
            return;
        }
        if (imageData == null) {
            return;
        }
        ImageServer<BufferedImage> server = this.getPixelClassificationServer(imageData);
        if (server == null) {
            return;
        }
        Color colorComplete = imageData.getImageType() == ImageData.ImageType.FLUORESCENCE ? ColorToolsAwt.getCachedColor((int)255, (int)255, (int)255, (int)1) : ColorToolsAwt.getCachedColor((int)0, (int)0, (int)0, (int)1);
        Shape shapeRegion = g2d.getClip();
        RegionRequest fullRequest = shapeRegion == null ? RegionRequest.createInstance((String)server.getPath(), (double)downsampleFactor, (ImageRegion)imageRegion) : RegionRequest.createInstance((String)server.getPath(), (double)downsampleFactor, (ImageRegion)AwtTools.getImageRegion((Shape)shapeRegion, (int)imageRegion.getZ(), (int)imageRegion.getT()));
        RegionFilter filter = this.getOverlayOptions().getPixelClassificationRegionFilter();
        if (!Objects.equals(filter, this.lastRegionFilter)) {
            this.resetAcceptedTiles();
            this.lastRegionFilter = filter;
        }
        if ((renderer = (ImageRenderer)this.renderer.get()) != null && this.rendererLastTimestamp != renderer.getLastChangeTimestamp()) {
            this.clearCache();
            this.rendererLastTimestamp = renderer.getLastChangeTimestamp();
        }
        double requestedDownsample = ServerTools.getPreferredDownsampleFactor(server, (double)downsampleFactor);
        Graphics2D gCopy = (Graphics2D)g2d.create();
        if (requestedDownsample > server.getDownsampleForResolution(0)) {
            gCopy.setRenderingHint(RenderingHints.KEY_INTERPOLATION, RenderingHints.VALUE_INTERPOLATION_BILINEAR);
        } else {
            this.setInterpolation(gCopy);
        }
        AlphaComposite comp = this.getAlphaComposite();
        Composite previousComposite = gCopy.getComposite();
        if (comp != null) {
            if (previousComposite instanceof AlphaComposite) {
                gCopy.setComposite(comp.derive(((AlphaComposite)previousComposite).getAlpha() * comp.getAlpha()));
            } else {
                gCopy.setComposite(comp);
            }
        }
        ArrayList tiles = server.getTileRequestManager().getTileRequests(fullRequest);
        if (fullRequest != null) {
            double x = (double)(Math.max(0, fullRequest.getMinX()) + Math.min(server.getWidth(), fullRequest.getMaxX())) / 2.0;
            double y = (double)(Math.max(0, fullRequest.getMinY()) + Math.min(server.getHeight(), fullRequest.getMaxY())) / 2.0;
            Point2 p = new Point2(x, y);
            tiles = new ArrayList(tiles);
            ((List)tiles).sort(Comparator.comparingDouble(t -> p.distanceSq((double)t.getImageX() + (double)t.getImageWidth() / 2.0, (double)t.getImageY() + (double)t.getImageHeight() / 2.0)));
        }
        this.pendingRequests.clear();
        for (TileRequest tile : tiles) {
            BufferedImage imgRGB;
            RegionRequest request = tile.getRegionRequest();
            if (filter != null) {
                if (this.lastHierarchyEventCount != imageData.getHierarchy().getEventCount()) {
                    this.resetAcceptedTiles();
                    this.lastHierarchyEventCount = imageData.getHierarchy().getEventCount();
                }
                if (!this.acceptedTiles.computeIfAbsent(request, r -> filter.test(imageData, r)).booleanValue()) continue;
            }
            if ((imgRGB = this.getCachedTileRGB(tile, server)) != null) {
                gCopy.setColor(colorComplete);
                gCopy.fillRect(request.getX(), request.getY(), request.getWidth(), request.getHeight());
                gCopy.drawImage(imgRGB, request.getX(), request.getY(), request.getWidth(), request.getHeight(), null);
                continue;
            }
            if (!this.livePrediction) continue;
            this.requestTile(tile, imageData, server);
        }
        gCopy.dispose();
    }

    private void resetAcceptedTiles() {
        this.acceptedTiles.clear();
        this.lastHierarchyEventCount = -1L;
    }

    BufferedImage getCachedTileRGB(TileRequest request, ImageServer<BufferedImage> server) {
        if (server == null) {
            return null;
        }
        BufferedImage img = (BufferedImage)server.getCachedTile(request);
        if (img != null) {
            if (img.getType() == 2 || img.getType() == 1 || img.getType() == 13 || img.getType() == 10) {
                return img;
            }
            return this.cacheRGB.computeIfAbsent(img, this::convertToRGB);
        }
        return null;
    }

    private BufferedImage convertToRGB(BufferedImage img) {
        ImageRenderer renderer = (ImageRenderer)this.renderer.get();
        if (renderer == null) {
            BufferedImage imgRGB = new BufferedImage(img.getWidth(), img.getHeight(), 2);
            Graphics2D g = imgRGB.createGraphics();
            g.drawImage((Image)img, 0, 0, null);
            g.dispose();
            return imgRGB;
        }
        return renderer.applyTransforms(img, null);
    }

    public void clearCache() {
        this.cacheRGB.clear();
    }

    public void stop() {
        List<Runnable> pending = this.pool.shutdownNow();
        this.clearCache();
        logger.debug("Stopped classification overlay, dropped {} requests", (Object)pending.size());
    }

    synchronized ImageServer<BufferedImage> createPixelClassificationServer(ImageData<BufferedImage> imageData) {
        return this.fun.apply(imageData);
    }

    void requestTile(TileRequest tile, ImageData<BufferedImage> imageData, ImageServer<BufferedImage> classifierServer) {
        if (!this.pool.isShutdown() && this.pendingRequests.add(tile)) {
            this.pool.submit(() -> {
                if (this.pool.isShutdown()) {
                    return;
                }
                if (!this.pendingRequests.contains(tile) || !this.currentRequests.add(tile)) {
                    return;
                }
                ArrayList<PathObject> changed = new ArrayList<PathObject>();
                PathObjectHierarchy hierarchy = imageData == null ? null : imageData.getHierarchy();
                try {
                    classifierServer.readRegion(tile.getRegionRequest());
                    this.repaintAllViewers();
                    ImageServerMetadata.ChannelType channelType = classifierServer.getMetadata().getChannelType();
                    if ((channelType == ImageServerMetadata.ChannelType.CLASSIFICATION || channelType == ImageServerMetadata.ChannelType.PROBABILITY || channelType == ImageServerMetadata.ChannelType.MULTICLASS_PROBABILITY) && hierarchy != null) {
                        changed.add(hierarchy.getRootObject());
                        hierarchy.getAnnotationsForRegion((ImageRegion)tile.getRegionRequest(), changed);
                    }
                }
                catch (Exception e) {
                    logger.error("Error requesting tile classification", (Throwable)e);
                }
                finally {
                    this.currentRequests.remove(tile);
                    this.pendingRequests.remove(tile);
                    if (hierarchy != null && !changed.isEmpty()) {
                        Platform.runLater(() -> hierarchy.fireObjectMeasurementsChangedEvent((Object)this, (Collection)changed, true));
                    }
                }
            });
        }
    }

    private void repaintAllViewers() {
        QuPathGUI qupath = QuPathGUI.getInstance();
        if (qupath != null) {
            qupath.getViewerManager().repaintAllViewers();
        }
    }

    @Override
    public String getLocationString(ImageData<BufferedImage> imageData, double x, double y, int z, int t) {
        if (this.getLocationStringFunction() == null) {
            ImageServer<BufferedImage> classifierServer;
            ImageServer<BufferedImage> imageServer = classifierServer = imageData == null ? null : this.getPixelClassificationServer(imageData);
            if (classifierServer == null) {
                return null;
            }
            return PixelClassificationOverlay.getDefaultLocationString(classifierServer, x, y, z, t);
        }
        return super.getLocationString(imageData, x, y, z, t);
    }

    public static String getDefaultLocationString(ImageServer<BufferedImage> server, double x, double y, int z, int t) {
        int level = 0;
        TileRequest tile = server.getTileRequestManager().getTileRequest(level, (int)Math.round(x), (int)Math.round(y), z, t);
        if (tile == null) {
            return null;
        }
        BufferedImage img = (BufferedImage)server.getCachedTile(tile);
        if (img == null) {
            return null;
        }
        int xx = (int)Math.floor((x - (double)tile.getImageX()) / tile.getDownsample());
        int yy = (int)Math.floor((y - (double)tile.getImageY()) / tile.getDownsample());
        if (xx < 0 || yy < 0 || xx >= img.getWidth() || yy >= img.getHeight()) {
            return null;
        }
        ImageServerMetadata.ChannelType channelType = server.getMetadata().getChannelType();
        double scale = 1.0;
        double probabilityScale = 1.0;
        if (img.getRaster().getDataBuffer().getDataType() == 0) {
            probabilityScale = 0.00392156862745098;
        }
        String defaultName = "";
        switch (channelType) {
            case CLASSIFICATION: {
                String name;
                Map classificationLabels = server.getMetadata().getClassificationLabels();
                int sample = img.getRaster().getSample(xx, yy, 0);
                PathClass pathClass = (PathClass)classificationLabels.get(sample);
                String string = name = pathClass == null ? null : pathClass.toString();
                if (name == null) {
                    return null;
                }
                return String.format("Classification: %s", name);
            }
            case MULTICLASS_PROBABILITY: 
            case PROBABILITY: {
                defaultName = "Prediction: ";
                scale = probabilityScale;
                break;
            }
            case DENSITY: {
                defaultName = "Density: ";
            }
        }
        List channels = server.getMetadata().getChannels();
        StringBuilder sb = new StringBuilder(defaultName);
        StringBuilder sbWithNames = new StringBuilder(defaultName);
        for (int c = 0; c < channels.size(); ++c) {
            double sampleDouble = img.getRaster().getSampleDouble(xx, yy, c) * scale;
            String num = GeneralTools.formatNumber((double)sampleDouble, (int)2);
            if (c != 0) {
                sb.append(", ");
                sbWithNames.append(", ");
            }
            sb.append(num);
            sbWithNames.append(((ImageChannel)channels.get(c)).getName()).append(": ").append(num);
        }
        if (sbWithNames.length() < 100) {
            return sbWithNames.toString();
        }
        return sb.toString();
    }

    static class ClassifierServerFunction
    implements Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> {
        private PixelClassificationImageServer server;
        private PixelClassifier classifier;

        private ClassifierServerFunction(PixelClassifier classifier) {
            this.classifier = classifier;
        }

        @Override
        public ImageServer<BufferedImage> apply(ImageData<BufferedImage> imageData) {
            if (imageData == null) {
                this.server = null;
                return null;
            }
            if (this.server != null && this.server.getImageData() != imageData) {
                this.server = null;
            }
            if (this.server == null && this.classifier.supportsImage(imageData)) {
                this.server = new PixelClassificationImageServer(imageData, this.classifier);
            }
            return this.server;
        }
    }

    static class FeatureCalculatorServerFunction
    implements Function<ImageData<BufferedImage>, ImageServer<BufferedImage>> {
        private ImageServer<BufferedImage> server;
        private ImageDataOp calculator;
        private PixelCalibration resolution;

        private FeatureCalculatorServerFunction(ImageDataOp calculator, PixelCalibration resolution) {
            this.calculator = calculator;
            this.resolution = resolution;
        }

        private FeatureCalculatorServerFunction(ImageServer<BufferedImage> server) {
            this.server = server;
        }

        @Override
        public ImageServer<BufferedImage> apply(ImageData<BufferedImage> imageData) {
            if (imageData == null) {
                this.server = null;
                return null;
            }
            if (this.server != null && this.server instanceof ImageDataServer && ((ImageDataServer)this.server).getImageData() != imageData) {
                this.server = null;
            }
            if (this.server == null && this.calculator != null && this.calculator.supportsImage(imageData)) {
                this.server = ImageOps.buildServer(imageData, (ImageDataOp)this.calculator, (PixelCalibration)this.resolution);
            }
            return this.server;
        }
    }
}

