/*
 * Decompiled with CFR 0.152.
 */
package qupath.process.gui.commands.ml;

import java.awt.BasicStroke;
import java.awt.Graphics2D;
import java.awt.Shape;
import java.awt.image.BufferedImage;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.Objects;
import java.util.TreeSet;
import java.util.WeakHashMap;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.MatVector;
import org.bytedeco.opencv.opencv_ml.TrainData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.color.ColorToolsAwt;
import qupath.lib.geom.Point2;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.PixelCalibration;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.regions.ImageRegion;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.interfaces.ROI;
import qupath.opencv.ops.ImageDataOp;
import qupath.opencv.ops.ImageDataServer;
import qupath.opencv.ops.ImageOps;
import qupath.process.gui.commands.ml.BoundaryStrategy;

/**
 * Extracts pixel-level training data for a pixel classifier.
 * <p>
 * For each tile of a feature server, trainable annotations (see
 * {@link #isTrainableAnnotation(PathObject, boolean)}) are rasterized into an 8-bit label image.
 * The feature values of every labelled pixel are then collected as training samples, together
 * with the corresponding integer label.
 * <p>
 * Per-tile results are cached in a static map that is shared by all instances and keyed by
 * {@link RegionRequest}.
 */
public class PixelClassifierTraining {

    private static final Logger logger = LoggerFactory.getLogger(PixelClassifierTraining.class);

    /** Annotations with this classification are never used for training. */
    private static final PathClass REGION_CLASS = PathClass.StandardPathClasses.REGION;

    /**
     * Cache of per-tile features, shared across all instances.
     * <p>
     * NOTE(review): each cached {@link TileFeatures} value holds a strong reference to its own key
     * (the {@code request} field), and to the feature server whose tile requests are the source of
     * that key. {@link WeakHashMap} values must not strongly reference their keys, otherwise the
     * entries are never discarded. Confirm whether this cache can grow without bound.
     */
    private static final Map<RegionRequest, TileFeatures> cache = Collections.synchronizedMap(new WeakHashMap<>());

    private BoundaryStrategy boundaryStrategy = BoundaryStrategy.getSkipBoundaryStrategy();
    private PixelCalibration resolution = PixelCalibration.getDefaultInstance();
    private ImageDataOp featureCalculator;
    private Mat matTraining;
    private Mat matTargets;

    /**
     * Creates a training helper.
     * @param featureCalculator op used to calculate the features for each pixel (may be null,
     *                          in which case no feature server can be built)
     */
    public PixelClassifierTraining(ImageDataOp featureCalculator) {
        this.featureCalculator = featureCalculator;
    }

    /**
     * Builds a feature server for the image, using the current feature op and resolution.
     * @param imageData the image
     * @return the feature server, or null if there is no feature op, the image is null,
     *         or the op does not support the image
     */
    synchronized ImageDataServer<BufferedImage> getFeatureServer(ImageData<BufferedImage> imageData) {
        if (featureCalculator != null && imageData != null && featureCalculator.supportsImage(imageData)) {
            return ImageOps.buildServer(imageData, featureCalculator, resolution);
        }
        return null;
    }

    /**
     * @return the op used to calculate features
     */
    public synchronized ImageDataOp getFeatureOp() {
        return featureCalculator;
    }

    /**
     * @return the resolution at which features are calculated
     */
    public synchronized PixelCalibration getResolution() {
        return resolution;
    }

    /**
     * Sets the resolution at which features are calculated.
     * @param cal the new resolution
     */
    public synchronized void setResolution(PixelCalibration cal) {
        if (Objects.equals(resolution, cal)) {
            return;
        }
        resolution = cal;
    }

    /**
     * Sets the op used to calculate features.
     * <p>
     * Any existing training data is reset if the op changes.
     * @param featureOp the new feature op
     */
    public synchronized void setFeatureOp(ImageDataOp featureOp) {
        if (Objects.equals(featureCalculator, featureOp)) {
            return;
        }
        featureCalculator = featureOp;
        resetTrainingData();
    }

    /**
     * Extracts training samples from all the images.
     * @param labelMap mapping from classification to integer label; if null, a map is created
     *                 from the trainable annotations found in the images
     * @param imageDataCollection the images to use
     * @return the training data, or null if there are no images, fewer than two labels,
     *         or no labelled pixels
     * @throws IOException declared for callers; per-tile read errors are logged rather than thrown
     */
    private synchronized ClassifierTrainingData updateTrainingData(Map<PathClass, Integer> labelMap, Collection<ImageData<BufferedImage>> imageDataCollection) throws IOException {
        if (imageDataCollection == null || imageDataCollection.isEmpty()) {
            resetTrainingData();
            return null;
        }

        Map<PathClass, Integer> labels = labelMap == null
                ? createLabelMap(imageDataCollection)
                : new LinkedHashMap<PathClass, Integer>(labelMap);

        // Check this before calculating any features, because training cannot succeed anyway
        if (labels.size() <= 1) {
            logger.warn("Unlocked annotations for at least two classes are required to train a classifier!");
            if (labelMap == null && hasLockedTrainableAnnotations(imageDataCollection)) {
                logger.warn("Image contains annotations that *could* be used for training, except they are currently locked. Please unlock them if they should be used.");
            }
            resetTrainingData();
            return null;
        }

        ArrayList<Mat> allFeatures = new ArrayList<>();
        ArrayList<Mat> allTargets = new ArrayList<>();
        for (ImageData<BufferedImage> imageData : imageDataCollection) {
            ImageDataServer<BufferedImage> featureServer = getFeatureServer(imageData);
            if (featureServer == null) {
                logger.warn("Unable to generate features for {}", imageData);
                continue;
            }
            for (TileRequest tile : featureServer.getTileRequestManager().getAllTileRequests()) {
                TileFeatures tileFeatures = getTileFeatures(tile.getRegionRequest(), featureServer, boundaryStrategy, labels);
                if (tileFeatures != null) {
                    allFeatures.add(tileFeatures.getFeatures());
                    allTargets.add(tileFeatures.getTargets());
                }
            }
        }

        if (matTraining == null) {
            matTraining = new Mat();
        }
        if (matTargets == null) {
            matTargets = new Mat();
        }
        opencv_core.vconcat(new MatVector(allFeatures.toArray(Mat[]::new)), matTraining);
        opencv_core.vconcat(new MatVector(allTargets.toArray(Mat[]::new)), matTargets);
        logger.debug("Training data: {} x {}, Target data: {} x {}",
                matTraining.rows(), matTraining.cols(), matTargets.rows(), matTargets.cols());

        if (matTraining.rows() == 0) {
            logger.warn("No training data found - if you have training annotations, check the features are compatible with the current image.");
            return null;
        }
        // NOTE(review): the returned object shares matTraining/matTargets with this instance.
        // A later update overwrites them, and resetTrainingData() releases them, which silently
        // changes any ClassifierTrainingData returned earlier. Confirm callers do not hold onto it.
        return new ClassifierTrainingData(labels, matTraining, matTargets);
    }

    /**
     * Creates a label map from the trainable (unlocked) annotations in the images.
     * <p>
     * Boundary classes are included for area annotations when the boundary strategy defines one.
     * Labels are assigned consecutively from 0, in the order of the classes' string representation.
     */
    private LinkedHashMap<PathClass, Integer> createLabelMap(Collection<ImageData<BufferedImage>> imageDataCollection) {
        TreeSet<PathClass> pathClasses = new TreeSet<>((p1, p2) -> p1.toString().compareTo(p2.toString()));
        for (ImageData<BufferedImage> imageData : imageDataCollection) {
            for (PathObject annotation : imageData.getHierarchy().getAnnotationObjects()) {
                if (!isTrainableAnnotation(annotation, true)) {
                    continue;
                }
                PathClass pathClass = annotation.getPathClass();
                pathClasses.add(pathClass);
                if (annotation.getROI().isArea()) {
                    PathClass boundaryClass = boundaryStrategy.getBoundaryClass(pathClass);
                    if (boundaryClass != null) {
                        pathClasses.add(boundaryClass);
                    }
                }
            }
        }
        LinkedHashMap<PathClass, Integer> labels = new LinkedHashMap<>();
        int label = 0;
        for (PathClass pathClass : pathClasses) {
            labels.put(pathClass, label++);
        }
        return labels;
    }

    /**
     * Returns true if any image has an annotation that would be trainable except that it is locked.
     */
    private static boolean hasLockedTrainableAnnotations(Collection<ImageData<BufferedImage>> imageDataCollection) {
        for (ImageData<BufferedImage> imageData : imageDataCollection) {
            for (PathObject annotation : imageData.getHierarchy().getAnnotationObjects()) {
                if (!isTrainableAnnotation(annotation, true) && isTrainableAnnotation(annotation, false)) {
                    return true;
                }
            }
        }
        return false;
    }

    /**
     * Checks whether an object can be used for training.
     * <p>
     * An object is trainable if it is a non-null annotation with a non-empty ROI, has a
     * classification, and that classification is not the Region class. If {@code checkLocked}
     * is true, the object must also be unlocked.
     * @param pathObject the object to test
     * @param checkLocked if true, locked objects are rejected
     * @return true if the object is trainable
     */
    static boolean isTrainableAnnotation(PathObject pathObject, boolean checkLocked) {
        return pathObject != null
                && pathObject.hasROI()
                && !pathObject.getROI().isEmpty()
                && pathObject.isAnnotation()
                && (!checkLocked || !pathObject.isLocked())
                && pathObject.getPathClass() != null
                // Reference comparison, as in the original implementation
                && pathObject.getPathClass() != REGION_CLASS;
    }

    /**
     * Sets the strategy used to handle annotation boundaries.
     * <p>
     * A null value selects the 'skip boundary' strategy. Any existing training data is reset.
     * @param strategy the boundary strategy
     */
    public synchronized void setBoundaryStrategy(BoundaryStrategy strategy) {
        if (boundaryStrategy == strategy) {
            return;
        }
        boundaryStrategy = strategy == null ? BoundaryStrategy.getSkipBoundaryStrategy() : strategy;
        resetTrainingData();
    }

    /**
     * @return the strategy used to handle annotation boundaries
     */
    public synchronized BoundaryStrategy getBoundaryStrategy() {
        return boundaryStrategy;
    }

    /**
     * Releases and discards the training and target matrices held by this instance.
     */
    private synchronized void resetTrainingData() {
        if (matTraining != null) {
            matTraining.release();
        }
        matTraining = null;
        if (matTargets != null) {
            matTargets.release();
        }
        matTargets = null;
    }

    /**
     * Creates training data from a single image, generating the label map from its annotations.
     * @param imageData the image
     * @return the training data, or null if training data could not be created
     * @throws IOException declared for callers of the training pipeline
     */
    public ClassifierTrainingData createTrainingData(ImageData<BufferedImage> imageData) throws IOException {
        return createTrainingDataForLabelMap(Collections.singleton(imageData), null);
    }

    /**
     * Creates training data from multiple images, generating the label map from their annotations.
     * @param imageData the images
     * @return the training data, or null if training data could not be created
     * @throws IOException declared for callers of the training pipeline
     */
    public ClassifierTrainingData createTrainingData(Collection<ImageData<BufferedImage>> imageData) throws IOException {
        return createTrainingDataForLabelMap(imageData, null);
    }

    /**
     * Creates training data from multiple images, using an explicit label map.
     * @param imageData the images
     * @param labels mapping from classification to integer label; if null, a map is generated
     * @return the training data, or null if training data could not be created
     * @throws IOException declared for callers of the training pipeline
     */
    public ClassifierTrainingData createTrainingDataForLabelMap(Collection<ImageData<BufferedImage>> imageData, Map<PathClass, Integer> labels) throws IOException {
        return updateTrainingData(labels, imageData);
    }

    /**
     * Gets the features and targets for a single tile, reusing a cached result when it still matches.
     * @return the tile features, or null if the tile contains no trainable ROIs or if the
     *         features could not be read
     */
    private static TileFeatures getTileFeatures(RegionRequest request, ImageDataServer<BufferedImage> featureServer, BoundaryStrategy strategy, Map<PathClass, Integer> labels) {
        TileFeatures features = cache.get(request);
        Map<ROI, PathClass> rois = getTrainingROIs(request, featureServer, labels);
        if (rois.isEmpty()) {
            if (features != null) {
                cache.remove(request);
            }
            return null;
        }
        if (features != null
                && features.featureServer.equals(featureServer)
                && features.labels.equals(labels)
                && features.strategy.equals(strategy)
                && features.rois.equals(rois)
                && features.request.equals(request)) {
            return features;
        }
        try {
            TileFeatures updated = new TileFeatures(request, featureServer, strategy, rois, labels);
            cache.put(request, updated);
            return updated;
        } catch (IOException e) {
            cache.remove(request);
            logger.error("Error requesting features for {}", request, e);
            // Never fall back to the stale cached entry (if any) - it no longer matches the inputs
            return null;
        }
    }

    /**
     * Collects the trainable ROIs that overlap the request and whose classification has a label.
     * <p>
     * Point ROIs are only included if at least one of their points falls inside the request.
     * @return a (possibly empty) map from ROI to classification
     */
    private static Map<ROI, PathClass> getTrainingROIs(RegionRequest request, ImageDataServer<BufferedImage> featureServer, Map<PathClass, Integer> labels) {
        Map<ROI, PathClass> rois = new HashMap<>();
        Collection<? extends PathObject> pathObjects = featureServer.getImageData().getHierarchy().getAllObjectsForRegion((ImageRegion)request, null);
        if (pathObjects == null) {
            return rois;
        }
        for (PathObject pathObject : pathObjects) {
            if (!isTrainableAnnotation(pathObject, true)) {
                continue;
            }
            ROI roi = pathObject.getROI();
            if (roi == null || (roi.isPoint() && !containsAnyPoint(request, roi))) {
                continue;
            }
            PathClass pathClass = pathObject.getPathClass();
            if (labels.containsKey(pathClass)) {
                rois.put(roi, pathClass);
            }
        }
        return rois;
    }

    /**
     * Returns true if the request contains at least one of the ROI's points (truncated to integer
     * coordinates), on the ROI's z-slice and timepoint.
     */
    private static boolean containsAnyPoint(RegionRequest request, ROI roi) {
        for (Point2 p : roi.getAllPoints()) {
            if (request.contains((int)p.getX(), (int)p.getY(), roi.getZ(), roi.getT())) {
                return true;
            }
        }
        return false;
    }

    /**
     * Feature values and integer targets for the labelled pixels of a single tile.
     */
    private static class TileFeatures {
        private final Map<PathClass, Integer> labels;
        private final ImageDataServer<BufferedImage> featureServer;
        private final RegionRequest request;
        private final Map<ROI, PathClass> rois;
        private final BoundaryStrategy strategy;
        private Mat matFeatures;
        private Mat matTargets;

        private TileFeatures(RegionRequest request, ImageDataServer<BufferedImage> featureServer, BoundaryStrategy strategy, Map<ROI, PathClass> rois, Map<PathClass, Integer> labels) throws IOException {
            this.request = request;
            this.strategy = strategy;
            this.featureServer = featureServer;
            this.rois = rois;
            this.labels = labels;
            ensureFeaturesCalculated();
        }

        /**
         * Reads the feature tile, rasterizes the labels, and copies the feature values of every
         * labelled pixel into a float matrix (one row per pixel). The labels go into an int matrix.
         */
        private void ensureFeaturesCalculated() throws IOException {
            if (matFeatures != null && matTargets != null) {
                return;
            }
            BufferedImage features = (BufferedImage)featureServer.readRegion(request);
            int width = features.getWidth();
            int height = features.getHeight();
            WritableRaster raster = rasterizeLabels(width, height);

            int capacity = width * height;
            WritableRaster rasterFeatures = features.getRaster();
            int nFeatures = rasterFeatures.getNumBands();
            float[] buf = new float[nFeatures];
            FloatBuffer extracted = FloatBuffer.allocate(capacity * nFeatures);
            IntBuffer targets = IntBuffer.allocate(capacity);
            for (int y = 0; y < height; ++y) {
                for (int x = 0; x < width; ++x) {
                    // 0 means 'unlabelled'; stored labels are offset by +1
                    int label = raster.getSample(x, y, 0);
                    if (label == 0) {
                        continue;
                    }
                    buf = rasterFeatures.getPixel(x, y, buf);
                    extracted.put(buf);
                    targets.put(label - 1);
                }
            }

            int n = targets.position();
            matFeatures = new Mat(n, nFeatures, opencv_core.CV_32FC1);
            matTargets = new Mat(n, 1, opencv_core.CV_32SC1);
            if (n == 0) {
                logger.debug("I thought I'd have features but I don't! {} - {}", rois.size(), request);
                return;
            }
            IntIndexer idxTargets = (IntIndexer)matTargets.createIndexer();
            FloatIndexer idxFeatures = (FloatIndexer)matFeatures.createIndexer();
            int t = 0;
            for (int i = 0; i < n; ++i) {
                for (int j = 0; j < nFeatures; ++j) {
                    idxFeatures.put((long)i, (long)j, extracted.get(t));
                    ++t;
                }
                idxTargets.put((long)i, targets.get(i));
            }
            idxTargets.release();
            idxFeatures.release();
        }

        /**
         * Paints the ROIs into an 8-bit label raster of the given size.
         * <p>
         * Each pixel holds label + 1, and 0 means unlabelled. Points set single pixels and areas
         * are filled. If the strategy has a boundary class with a label and a positive thickness,
         * area outlines are drawn with the boundary label. Lines are drawn with a stroke.
         * <p>
         * NOTE(review): the raster is 8-bit, so labels above 254 cannot be represented. Larger
         * values would be truncated for points, and are likely rejected by the AWT colour for shapes.
         */
        private WritableRaster rasterizeLabels(int width, int height) {
            double downsample = request.getDownsample();
            double boundaryThickness = strategy.getBoundaryThickness();
            BasicStroke stroke = boundaryThickness > 0.0 ? new BasicStroke((float)(downsample * boundaryThickness)) : null;
            BasicStroke singleStroke = new BasicStroke((float)downsample);

            BufferedImage imgLabels = new BufferedImage(width, height, BufferedImage.TYPE_BYTE_GRAY);
            WritableRaster raster = imgLabels.getRaster();
            for (Map.Entry<ROI, PathClass> entry : rois.entrySet()) {
                ROI roi = entry.getKey();
                PathClass pathClass = entry.getValue();
                Integer label = labels.get(pathClass);
                if (label == null) {
                    continue;
                }
                int lab = label + 1;

                if (roi.isPoint()) {
                    for (Point2 p : roi.getAllPoints()) {
                        int x = (int)Math.round((p.getX() - request.getX()) / downsample);
                        int y = (int)Math.round((p.getY() - request.getY()) / downsample);
                        if (x >= 0 && y >= 0 && x < width && y < height) {
                            raster.setSample(x, y, 0, lab);
                        }
                    }
                    continue;
                }

                Graphics2D g2d = imgLabels.createGraphics();
                try {
                    // Map full-resolution image coordinates onto the tile
                    g2d.scale(1.0 / downsample, 1.0 / downsample);
                    g2d.translate(-request.getX(), -request.getY());
                    g2d.setColor(ColorToolsAwt.getCachedColor(lab, lab, lab));
                    Shape shape = roi.getShape();
                    if (roi.isArea()) {
                        g2d.fill(shape);
                        PathClass boundaryClass = strategy.getBoundaryClass(pathClass);
                        Integer boundaryLabel = boundaryClass == null ? null : labels.get(boundaryClass);
                        if (stroke != null && boundaryLabel != null) {
                            int boundaryLab = boundaryLabel + 1;
                            g2d.setColor(ColorToolsAwt.getCachedColor(boundaryLab, boundaryLab, boundaryLab));
                            g2d.setStroke(stroke);
                            g2d.draw(shape);
                        }
                    } else if (roi.isLine()) {
                        g2d.setStroke(stroke == null ? singleStroke : stroke);
                        g2d.draw(shape);
                    }
                } finally {
                    g2d.dispose();
                }
            }
            return raster;
        }

        public Mat getFeatures() {
            return matFeatures;
        }

        public Mat getTargets() {
            return matTargets;
        }
    }

    /**
     * Training samples and targets, together with the label map used to create them.
     */
    public static class ClassifierTrainingData {
        private final Mat matTraining;
        private final Mat matTargets;
        private final Map<PathClass, Integer> pathClassesLabels;

        private ClassifierTrainingData(Map<PathClass, Integer> pathClassesLabels, Mat matTraining, Mat matTargets) {
            this.pathClassesLabels = Collections.unmodifiableMap(new LinkedHashMap<PathClass, Integer>(pathClassesLabels));
            this.matTraining = matTraining;
            this.matTargets = matTargets;
        }

        /**
         * @return an unmodifiable copy of the mapping from classification to integer label
         */
        public synchronized Map<PathClass, Integer> getLabelMap() {
            return pathClassesLabels;
        }

        /**
         * Creates OpenCV training data from clones of the sample and target matrices, with the
         * layout argument set to 0.
         * @return a new TrainData instance
         */
        public TrainData getTrainData() {
            return TrainData.create(matTraining.clone(), 0, matTargets.clone());
        }
    }
}

