/*
 * Decompiled with CFR 0.152.
 */
package qupath.ext.djl;

import ai.djl.Application;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.modality.Classifications;
import ai.djl.modality.cv.BufferedImageFactory;
import ai.djl.modality.cv.Image;
import ai.djl.modality.cv.ImageFactory;
import ai.djl.modality.cv.output.BoundingBox;
import ai.djl.modality.cv.output.CategoryMask;
import ai.djl.modality.cv.output.DetectedObjects;
import ai.djl.modality.cv.output.Joints;
import ai.djl.modality.cv.output.Landmark;
import ai.djl.modality.cv.output.Mask;
import ai.djl.modality.cv.output.Point;
import ai.djl.modality.cv.output.Rectangle;
import ai.djl.modality.cv.translator.BigGANTranslator;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.repository.Artifact;
import ai.djl.repository.MRL;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ModelZoo;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import ai.djl.translate.TranslatorFactory;
import ai.djl.util.ClassLoaderUtils;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Progress;
import java.awt.image.BandedSampleModel;
import java.awt.image.BufferedImage;
import java.awt.image.DataBufferFloat;
import java.awt.image.Raster;
import java.awt.image.WritableRaster;
import java.io.IOException;
import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.locationtech.jts.geom.Geometry;
import org.locationtech.jts.geom.util.AffineTransformation;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.analysis.images.ContourTracing;
import qupath.lib.analysis.images.SimpleImage;
import qupath.lib.geom.Point2;
import qupath.lib.images.ImageData;
import qupath.lib.images.servers.AbstractTileableImageServer;
import qupath.lib.images.servers.GeneratingImageServer;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServerBuilder;
import qupath.lib.images.servers.ImageServerMetadata;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.measurements.MeasurementList;
import qupath.lib.objects.PathObject;
import qupath.lib.objects.PathObjectTools;
import qupath.lib.objects.PathObjects;
import qupath.lib.objects.classes.PathClass;
import qupath.lib.objects.hierarchy.PathObjectHierarchy;
import qupath.lib.regions.ImagePlane;
import qupath.lib.regions.ImageRegion;
import qupath.lib.regions.RegionRequest;
import qupath.lib.roi.GeometryTools;
import qupath.lib.roi.ROIs;
import qupath.lib.roi.RoiTools;
import qupath.lib.roi.interfaces.ROI;

public class DjlZoo {
    private static final Logger logger = LoggerFactory.getLogger(DjlZoo.class);

    /**
     * Input classes accepted when choosing among the types offered by a translator factory.
     * Entries earlier in the list take priority.
     */
    private static final List<Class<?>> preferredInputs = Arrays.asList(
            Image.class,
            NDList.class);

    /**
     * Output classes accepted when choosing among the types offered by a translator factory.
     * Entries earlier in the list take priority.
     */
    private static final List<Class<?>> preferredOutputs = Arrays.asList(
            CategoryMask.class,
            DetectedObjects.class,
            Joints.class,
            Classifications.class,
            Image.class,
            NDList.class);

    /**
     * Log, at INFO level, every model known to the DJL model zoo, grouped by application.
     */
    public static void logAvailableModels() {
        var modelsByApplication = ModelZoo.listModels();
        modelsByApplication.forEach((application, mrls) -> {
            logger.info("Application: {}", application);
            for (MRL mrl : mrls) {
                logger.info("  {}", (Object)mrl.toString());
            }
        });
    }

    /**
     * List the model resource locators (MRLs) in the DJL model zoo that match some criteria.
     *
     * @param criteria the criteria used to filter the zoo
     * @return a new, mutable list containing the matching MRLs for all applications
     * @throws ModelNotFoundException declared for callers querying the model zoo
     * @throws IOException declared for callers querying the model zoo
     */
    public static List<MRL> listModels(Criteria<?, ?> criteria) throws ModelNotFoundException, IOException {
        return ModelZoo.listModels(criteria).values().stream()
                .flatMap(Collection::stream)
                .collect(Collectors.toCollection(ArrayList::new));
    }

    /**
     * List the MRLs of all zoo models registered for a specific application.
     *
     * @param application the application to filter by
     * @return the matching MRLs
     */
    public static List<MRL> listModels(Application application) throws ModelNotFoundException, IOException {
        return listModels(Criteria.builder().optApplication(application).build());
    }

    /**
     * List the MRLs of all models in the DJL model zoo, without filtering.
     *
     * @return all MRLs
     */
    public static List<MRL> listModels() throws ModelNotFoundException, IOException {
        Criteria<?, ?> unfiltered = Criteria.builder().build();
        return listModels(unfiltered);
    }

    /**
     * List the MRLs of zoo models for image classification.
     */
    public static List<MRL> listImageClassificationModels() throws ModelNotFoundException, IOException {
        return listModels(Application.CV.IMAGE_CLASSIFICATION);
    }

    /**
     * List the MRLs of zoo models for semantic segmentation.
     */
    public static List<MRL> listSemanticSegmentationModels() throws ModelNotFoundException, IOException {
        return listModels(Application.CV.SEMANTIC_SEGMENTATION);
    }

    /**
     * List the MRLs of zoo models for object detection.
     */
    public static List<MRL> listObjectDetectionModels() throws ModelNotFoundException, IOException {
        return listModels(Application.CV.OBJECT_DETECTION);
    }

    /**
     * List the MRLs of zoo models for instance segmentation.
     */
    public static List<MRL> listInstanceSegmentationModels() throws ModelNotFoundException, IOException {
        return listModels(Application.CV.INSTANCE_SEGMENTATION);
    }

    /**
     * Load the model described by a zoo artifact.
     * <p>
     * If {@code allowDownload} is true, the {@code ai.djl.offline} system property is set to
     * {@code "false"} while the model is loaded, and restored afterwards (cleared if it was not set).
     *
     * @param artifact the artifact describing the model
     * @param allowDownload whether to temporarily disable DJL's offline mode
     * @return the loaded model
     * @throws ModelNotFoundException if the model cannot be found
     * @throws MalformedModelException if the model is malformed
     * @throws IOException if the model cannot be read
     */
    public static ZooModel<?, ?> loadModel(Artifact artifact, boolean allowDownload) throws ModelNotFoundException, MalformedModelException, IOException {
        Criteria<?, ?> criteria = DjlZoo.buildCriteria(artifact, allowDownload);
        // buildCriteria only overrides ai.djl.offline while the criteria are built; presumably DJL
        // consults the property when the artifact is fetched inside loadModel(), so override it here too.
        String before = System.getProperty("ai.djl.offline");
        try {
            if (allowDownload) {
                System.setProperty("ai.djl.offline", "false");
            }
            return criteria.loadModel();
        } finally {
            // System.setProperty rejects null values, so an originally-absent property must be cleared
            if (before == null) {
                System.clearProperty("ai.djl.offline");
            } else {
                System.setProperty("ai.djl.offline", before);
            }
        }
    }

    /**
     * Build DJL {@link Criteria} for loading the model described by a zoo artifact.
     * <p>
     * Input/output types are taken from the artifact's {@code translatorFactory} argument when it names a
     * factory that can be instantiated, preferring the types listed in {@code preferredInputs} and
     * {@code preferredOutputs}. Otherwise the types are chosen from the artifact's application, falling back
     * to {@code NDList -> NDList}.
     * <p>
     * If {@code allowDownload} is true, {@code ai.djl.offline} is set to {@code "false"} while the criteria
     * are built. It is restored afterwards, or cleared if it was not previously set.
     *
     * @param artifact the artifact describing the model
     * @param allowDownload whether to temporarily disable DJL's offline mode
     * @return the criteria
     * @throws ModelNotFoundException declared for API compatibility
     * @throws MalformedModelException declared for API compatibility
     * @throws IOException declared for API compatibility
     */
    public static Criteria<?, ?> buildCriteria(Artifact artifact, boolean allowDownload) throws ModelNotFoundException, MalformedModelException, IOException {
        String before = System.getProperty("ai.djl.offline");
        try {
            if (allowDownload) {
                System.setProperty("ai.djl.offline", "false");
            }
            Application application = artifact.getMetadata().getApplication();
            Criteria.Builder builder = Criteria.builder()
                    .optApplication(application)
                    .optArtifactId(artifact.getMetadata().getArtifactId())
                    .optProgress((Progress)new ProgressBar())
                    .optArguments(artifact.getArguments())
                    .optGroupId(artifact.getMetadata().getGroupId())
                    .optFilters(artifact.getProperties());
            Object factoryClass = artifact.getArguments() == null ? null : artifact.getArguments().get("translatorFactory");
            TranslatorFactory factory = null;
            if (factoryClass instanceof String) {
                factory = DjlZoo.getTranslatorFactory((String)factoryClass);
                if (factory == null) {
                    logger.warn("Unable to create translatorFactory {} - will try to choose suitable input/output class based on the application", factoryClass);
                }
            } else {
                logger.warn("No translatorFactory specified - will try to choose suitable input/output class based on the application.\nIf this fails, please call .builder().setTypes(inputClass, outputClass).build() to specify these directly");
            }
            if (factory != null) {
                builder = DjlZoo.setTypesFromTranslatorFactory(builder, factory, factoryClass);
            } else {
                builder = DjlZoo.setTypesFromApplication(builder, application);
            }
            return builder.build();
        }
        finally {
            // System.setProperty throws NullPointerException for null values, so clear instead
            if (before == null) {
                System.clearProperty("ai.djl.offline");
            } else {
                System.setProperty("ai.djl.offline", before);
            }
        }
    }

    /**
     * Set the builder's input/output types to the best supported pair offered by a translator factory.
     * Pairs made only of preferred classes are ranked by input then output preference. Otherwise a single
     * supported pair is used. If nothing suitable is found, the builder is returned unchanged.
     */
    private static Criteria.Builder setTypesFromTranslatorFactory(Criteria.Builder builder, TranslatorFactory factory, Object factoryClass) {
        var supportedTypes = factory.getSupportedTypes();
        var preferredTypes = supportedTypes.stream()
                .filter(p -> preferredInputs.contains(p.getKey()) && preferredOutputs.contains(p.getValue()))
                .sorted(Comparator.comparingInt((Pair<?, ?> p) -> preferredInputs.indexOf(p.getKey()))
                        .thenComparingInt(p -> preferredOutputs.indexOf(p.getValue())))
                .findFirst()
                .orElse(null);
        if (preferredTypes == null && supportedTypes.size() == 1) {
            preferredTypes = supportedTypes.iterator().next();
        }
        if (preferredTypes == null) {
            logger.warn("No supported types found in " + String.valueOf(factoryClass) + " -\nPlease call .builder().setTypes(inputClass, outputClass).build() to specify these directly");
            return builder;
        }
        Object inputType = preferredTypes.getKey();
        Object outputType = preferredTypes.getValue();
        // setTypes needs Class objects; a supported type could in principle be a generic Type
        if (!(inputType instanceof Class) || !(outputType instanceof Class)) {
            logger.warn("Unable to use types {} -> {} from {} - please call .builder().setTypes(inputClass, outputClass).build() to specify these directly", inputType, outputType, factoryClass);
            return builder;
        }
        return builder.setTypes((Class<?>)inputType, (Class<?>)outputType);
    }

    /**
     * Set the builder's input/output types based on the model's application.
     * If the application is not recognized, NDList -> NDList is used.
     */
    private static Criteria.Builder setTypesFromApplication(Criteria.Builder builder, Application application) {
        if (Application.CV.IMAGE_CLASSIFICATION.equals(application)) {
            return builder.setTypes(Image.class, Classifications.class);
        }
        if (Application.CV.SEMANTIC_SEGMENTATION.equals(application)) {
            return builder.setTypes(Image.class, CategoryMask.class);
        }
        if (Application.CV.IMAGE_GENERATION.equals(application)) {
            return builder.setTypes(Image.class, Image.class);
        }
        if (Application.CV.OBJECT_DETECTION.equals(application)
                || Application.CV.INSTANCE_SEGMENTATION.equals(application)
                || Application.CV.WORD_RECOGNITION.equals(application)) {
            return builder.setTypes(Image.class, DetectedObjects.class);
        }
        if (Application.CV.POSE_ESTIMATION.equals(application)) {
            return builder.setTypes(Image.class, Joints.class);
        }
        return builder.setTypes(NDList.class, NDList.class);
    }

    /**
     * Instantiate a {@link TranslatorFactory} from its class name, using the context class loader.
     *
     * @param factoryClass fully-qualified name of the factory class
     * @return the value produced by {@link ClassLoaderUtils#initClass}
     */
    private static TranslatorFactory getTranslatorFactory(String factoryClass) {
        return (TranslatorFactory)ClassLoaderUtils.initClass(
                ClassLoaderUtils.getContextClassLoader(), TranslatorFactory.class, factoryClass);
    }

    /**
     * Create a ROI for a DJL detected object, choosing the conversion from the bounding-box type.
     * Masks are traced with a 0.5 threshold, landmarks are converted by the Landmark overload, and any
     * other bounding box by the BoundingBox overload.
     *
     * @param obj the detected object
     * @param region the region the detection refers to
     * @return the ROI
     */
    public static ROI createROI(DetectedObjects.DetectedObject obj, ImageRegion region) {
        BoundingBox boundingBox = obj.getBoundingBox();
        ROI result;
        if (boundingBox instanceof Mask) {
            result = createROI((Mask)boundingBox, region, 0.5);
        } else if (boundingBox instanceof Landmark) {
            result = createROI((Landmark)boundingBox, region);
        } else {
            result = createROI(boundingBox, region);
        }
        return result;
    }

    /**
     * Create a rectangular ROI from a DJL bounding box.
     * <p>
     * The box bounds are treated as relative to {@code region}. They are multiplied by the region's
     * width/height, offset by its origin, and placed on its image plane. If {@code region} is null, the
     * bounds are used unscaled on the default image plane (previously this threw a NullPointerException).
     *
     * @param box the bounding box
     * @param region the region the box is relative to; may be null
     * @return a rectangle ROI
     */
    public static ROI createROI(BoundingBox box, ImageRegion region) {
        Rectangle bounds = box.getBounds();
        double xo = 0.0;
        double yo = 0.0;
        double scaleX = 1.0;
        double scaleY = 1.0;
        ImagePlane plane = ImagePlane.getDefaultPlane();
        if (region != null) {
            plane = region.getImagePlane();
            xo = region.getMinX();
            yo = region.getMinY();
            scaleX = region.getWidth();
            scaleY = region.getHeight();
        }
        return ROIs.createRectangleROI(
                xo + bounds.getX() * scaleX,
                yo + bounds.getY() * scaleY,
                bounds.getWidth() * scaleX,
                bounds.getHeight() * scaleY,
                plane);
    }

    /**
     * Create a ROI by thresholding the probability map of a DJL mask.
     * <p>
     * The probability map is copied into a single-band float raster and traced from {@code threshold}
     * up to +infinity. The geometry is then mapped from raster pixels to the mask bounds (relative
     * to the region), and from there to the region's pixel coordinates. If {@code region} is null, a region
     * at the origin with the size of the probability map is used.
     *
     * @param mask the mask holding the probability map and its bounds
     * @param region the region the mask bounds are relative to; may be null
     * @param threshold the minimum probability traced
     * @return the traced ROI on the region's image plane
     */
    public static ROI createROI(Mask mask, ImageRegion region, double threshold) {
        float[][] probDist = mask.getProbDist();
        // NOTE(review): the first index is treated as x and the second as y - confirm this matches the
        // probDist layout of the DJL version in use.
        int width = probDist.length;
        int height = probDist[0].length;
        DataBufferFloat dataBuffer = new DataBufferFloat(width * height, 1);
        WritableRaster raster = WritableRaster.createWritableRaster(
                new BandedSampleModel(dataBuffer.getDataType(), width, height, 1), dataBuffer, null);
        for (int y = 0; y < height; y++) {
            for (int x = 0; x < width; x++) {
                raster.setSample(x, y, 0, probDist[x][y]);
            }
        }
        ImageRegion target = region != null ? region : ImageRegion.createInstance(0, 0, width, height, 0, 0);
        Geometry traced = ContourTracing.createTracedGeometry((Raster)raster, threshold, Double.POSITIVE_INFINITY, 0, null);

        Rectangle bounds = mask.getBounds();
        AffineTransformation transform = new AffineTransformation();
        // Raster pixels -> unit square
        transform.scale(1.0 / raster.getWidth(), 1.0 / raster.getHeight());
        // Unit square -> mask bounds (relative to the region)
        transform.scale(bounds.getWidth(), bounds.getHeight());
        transform.translate(bounds.getX(), bounds.getY());
        // Relative coordinates -> region pixel coordinates
        transform.scale(target.getWidth(), target.getHeight());
        transform.translate(target.getX(), target.getY());

        Geometry mapped = transform.isIdentity() ? traced : transform.transform(traced);
        return GeometryTools.geometryToROI(mapped, target.getImagePlane());
    }

    /**
     * Create a points ROI from the path of a DJL landmark.
     * <p>
     * Each point is treated as relative to {@code region}. It is multiplied by the region's width/height
     * and offset by its origin. If {@code region} is null, the points are used unscaled on the default
     * image plane (previously this threw a NullPointerException).
     *
     * @param landmark the landmark
     * @param region the region the landmark is relative to; may be null
     * @return a points ROI
     */
    public static ROI createROI(Landmark landmark, ImageRegion region) {
        double xo = 0.0;
        double yo = 0.0;
        double scaleX = 1.0;
        double scaleY = 1.0;
        ImagePlane plane = ImagePlane.getDefaultPlane();
        if (region != null) {
            plane = region.getImagePlane();
            xo = region.getMinX();
            yo = region.getMinY();
            scaleX = region.getWidth();
            scaleY = region.getHeight();
        }
        List<Point2> points = new ArrayList<>();
        for (Point p : landmark.getPath()) {
            points.add(new Point2(xo + p.getX() * scaleX, yo + p.getY() * scaleY));
        }
        return ROIs.createPointsROI(points, plane);
    }

    /**
     * Run object detection over the whole image, using the hierarchy's root object as the sole parent.
     *
     * @param model the detection model
     * @param imageData the image data
     * @return the detected objects, or empty if processing was interrupted
     */
    public static Optional<List<PathObject>> detect(ZooModel<Image, DetectedObjects> model, ImageData<BufferedImage> imageData) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        PathObject root = imageData.getHierarchy().getRootObject();
        return detect(model, imageData, Collections.singleton(root));
    }

    /**
     * Run object detection within the specified parent objects.
     * <p>
     * Regions are requested at the server's full-resolution downsample. A parent with a ROI gets one
     * request for that ROI; a parent without a ROI gets every z-slice/timepoint of the whole image.
     * Requests may be downsampled further to match the model input size.
     * <p>
     * Detections below the threshold are discarded. The threshold is 0.5, or the model's
     * {@code threshold} property. Mask and point detections are clipped to the parent ROI. Each
     * remaining detection becomes an annotation, classified by class name and given a
     * "Class probability" measurement. The new objects replace the existing children of each parent only
     * after every region has been processed.
     *
     * @param model the detection model; it is not closed by this method
     * @param imageData the image data
     * @param parentObjects the parents; if null, the hierarchy root is used
     * @return the new objects, or an empty optional if the thread was interrupted (in which case the
     *         hierarchy is left unchanged)
     */
    public static Optional<List<PathObject>> detect(ZooModel<Image, DetectedObjects> model, ImageData<BufferedImage> imageData, Collection<? extends PathObject> parentObjects) throws ModelNotFoundException, MalformedModelException, IOException, TranslateException {
        if (parentObjects == null) {
            parentObjects = Collections.singleton(imageData.getHierarchy().getRootObject());
        }
        // getInputHeightWidth returns the shape as [height, width]
        Shape inputHeightWidth = DjlZoo.getInputHeightWidth(model);
        long inputHeight = inputHeightWidth.get(0);
        long inputWidth = inputHeightWidth.get(1);
        double defaultThreshold = 0.5;
        double threshold = DjlZoo.tryToParseDoubleProperty(model, "threshold", defaultThreshold);
        if (threshold != defaultThreshold) {
            logger.debug("Setting threshold to {} from model properties", (Object)threshold);
        }
        ImageServer<BufferedImage> server = imageData.getServer();
        double downsampleBase = server.getDownsampleForResolution(0);
        Map<PathObject, List<PathObject>> map = new ConcurrentHashMap<>();
        List<PathObject> list = new ArrayList<>();
        // Close only the predictor: the model's NDManager belongs to the model, which the caller still owns
        try (Predictor<Image, DetectedObjects> predictor = model.newPredictor()) {
            for (PathObject pathObject : parentObjects) {
                ROI roi = pathObject.getROI();
                List<RegionRequest> requests = roi == null
                        ? DjlZoo.getAllRequests(RegionRequest.createInstance(server, downsampleBase), server.nZSlices(), server.nTimepoints())
                        : Collections.singletonList(RegionRequest.createInstance(server.getPath(), downsampleBase, roi));
                // Children are only replaced at the end, so an interruption leaves existing objects intact
                List<PathObject> childObjects = map.computeIfAbsent(pathObject, p -> new ArrayList<>());
                for (RegionRequest baseRequest : requests) {
                    if (Thread.currentThread().isInterrupted()) {
                        logger.warn("Detection interrupted! Discarding {} detection(s)", (Object)list.size());
                        return Optional.empty();
                    }
                    RegionRequest request = DjlZoo.updateDownsampleForInput(baseRequest, inputWidth, inputHeight);
                    BufferedImage img = server.readRegion(request);
                    DetectedObjects detections = DjlZoo.detect(predictor, img);
                    for (Classifications.Classification item : detections.items()) {
                        DetectedObjects.DetectedObject detected = (DetectedObjects.DetectedObject)item;
                        if (detected.getProbability() < threshold) {
                            continue;
                        }
                        ROI detectedROI = DjlZoo.createROI(detected, request);
                        // Rectangles are kept as-is; masks and points are clipped to the parent ROI
                        if (roi != null && (detected.getBoundingBox() instanceof Mask || detectedROI.isPoint())) {
                            detectedROI = RoiTools.intersection(new ROI[]{detectedROI, roi});
                        }
                        if (detectedROI.isEmpty()) {
                            logger.debug("ROI detected, but empty (class={}, prob={})", (Object)item.getClassName(), (Object)detected.getProbability());
                            continue;
                        }
                        PathObject newObject = PathObjects.createAnnotationObject(detectedROI, PathClass.fromString(item.getClassName()));
                        try (MeasurementList ml = newObject.getMeasurementList()) {
                            ml.put("Class probability", item.getProbability());
                        }
                        list.add(newObject);
                        childObjects.add(newObject);
                    }
                }
            }
        }
        DjlZoo.updateObjectsAndHierarchy(imageData.getHierarchy(), map, model);
        return Optional.of(list);
    }

    /**
     * Expand a request so that it covers every timepoint and z-slice.
     * The result is ordered by timepoint, then by z-slice.
     *
     * @param request the request providing the region and downsample
     * @param nZSlices number of z-slices
     * @param nTimepoints number of timepoints
     * @return one request per (t, z) combination
     */
    private static List<RegionRequest> getAllRequests(RegionRequest request, int nZSlices, int nTimepoints) {
        List<RegionRequest> requests = new ArrayList<>();
        for (int t = 0; t < nTimepoints; t++) {
            RegionRequest atTimepoint = request.updateT(t);
            for (int z = 0; z < nZSlices; z++) {
                requests.add(atTimepoint.updateZ(z));
            }
        }
        return requests;
    }

    /**
     * Segment the whole image into annotations, one per non-background class, under the root object.
     */
    public static Optional<List<PathObject>> segmentAnnotations(ZooModel<Image, CategoryMask> model, ImageData<BufferedImage> imageData) throws TranslateException, IOException {
        PathObject root = imageData.getHierarchy().getRootObject();
        return segmentAnnotations(model, imageData, Collections.singletonList(root));
    }

    /**
     * Segment the specified parents into annotations, one per non-background class and region.
     */
    public static Optional<List<PathObject>> segmentAnnotations(ZooModel<Image, CategoryMask> model, ImageData<BufferedImage> imageData, Collection<? extends PathObject> parentObjects) throws TranslateException, IOException {
        Function<ROI, PathObject> creator = PathObjects::createAnnotationObject;
        return segmentObjects(model, imageData, parentObjects, creator, true);
    }

    /**
     * Segment the whole image into detections, one per non-background class, under the root object.
     */
    public static Optional<List<PathObject>> segmentDetections(ZooModel<Image, CategoryMask> model, ImageData<BufferedImage> imageData) throws TranslateException, IOException {
        PathObject root = imageData.getHierarchy().getRootObject();
        return segmentDetections(model, imageData, Collections.singletonList(root));
    }

    /**
     * Segment the specified parents into detections, one per non-background class and region.
     */
    public static Optional<List<PathObject>> segmentDetections(ZooModel<Image, CategoryMask> model, ImageData<BufferedImage> imageData, Collection<? extends PathObject> parentObjects) throws TranslateException, IOException {
        Function<ROI, PathObject> creator = PathObjects::createDetectionObject;
        return segmentObjects(model, imageData, parentObjects, creator, true);
    }

    /**
     * Get the expected input size of a model as a shape {@code [height, width]}.
     * <p>
     * If the model has an initialized block, the size is read from its single input description using the
     * WIDTH/HEIGHT layout axes. If that is unavailable, or neither dimension can be found, the model's
     * {@code width} and {@code height} properties are used instead. Any dimension that cannot be determined
     * is -1.
     *
     * @param model the model
     * @return shape {@code [height, width]}
     */
    private static Shape getInputHeightWidth(Model model) {
        long inputWidth = -1L;
        long inputHeight = -1L;
        Block block = model.getBlock();
        if (block != null && block.isInitialized()) {
            Shape inputShape = DjlZoo.getSingleInputShape((PairList<String, Shape>)block.describeInput());
            inputWidth = DjlZoo.getDim(inputShape, LayoutType.WIDTH);
            inputHeight = DjlZoo.getDim(inputShape, LayoutType.HEIGHT);
        }
        if (inputWidth <= 0L && inputHeight <= 0L) {
            long w = (long)DjlZoo.tryToParseDoubleProperty(model, "width", inputWidth);
            long h = (long)DjlZoo.tryToParseDoubleProperty(model, "height", inputHeight);
            if (w != inputWidth || h != inputHeight) {
                // Log the new values (previously the stale defaults were logged)
                logger.debug("Setting input size to {} x {}", (Object)w, (Object)h);
                inputWidth = w;
                inputHeight = h;
            }
        }
        return new Shape(new long[]{inputHeight, inputWidth});
    }

    /**
     * Run semantic segmentation within the specified parent objects, creating one object per class and region.
     * <p>
     * A parent with a ROI gets one request for that ROI; a parent without a ROI gets every
     * z-slice/timepoint of the whole image. Requests may be downsampled to match the model input size.
     * The new objects replace the existing children of each parent only after every region has been processed.
     *
     * @param model the segmentation model; it is not closed by this method
     * @param imageData the image data
     * @param parentObjects the parents; if null, the hierarchy root is used
     * @param creator function creating an object from each traced ROI
     * @param skipBackground if true, class index 0 is not converted to objects
     * @return the new objects, or an empty optional if the thread was interrupted (in which case the
     *         hierarchy is left unchanged)
     */
    public static Optional<List<PathObject>> segmentObjects(ZooModel<Image, CategoryMask> model, ImageData<BufferedImage> imageData, Collection<? extends PathObject> parentObjects, Function<ROI, PathObject> creator, boolean skipBackground) throws TranslateException, IOException {
        if (parentObjects == null) {
            parentObjects = Collections.singleton(imageData.getHierarchy().getRootObject());
        }
        // getInputHeightWidth returns the shape as [height, width]
        Shape inputHeightWidth = DjlZoo.getInputHeightWidth(model);
        long inputHeight = inputHeightWidth.get(0);
        long inputWidth = inputHeightWidth.get(1);
        Map<PathObject, List<PathObject>> map = new ConcurrentHashMap<>();
        List<PathObject> list = new ArrayList<>();
        ImageServer<BufferedImage> server = imageData.getServer();
        try (Predictor<Image, CategoryMask> predictor = model.newPredictor()) {
            for (PathObject pathObject : parentObjects) {
                ROI roi = pathObject.getROI();
                List<RegionRequest> requests;
                if (roi != null) {
                    RegionRequest roiRequest = RegionRequest.createInstance(server.getPath(), server.getDownsampleForResolution(0), roi);
                    requests = Collections.singletonList(DjlZoo.updateDownsampleForInput(roiRequest, inputWidth, inputHeight));
                } else {
                    RegionRequest imageRequest = DjlZoo.updateDownsampleForInput(RegionRequest.createInstance(server), inputWidth, inputHeight);
                    requests = DjlZoo.getAllRequests(imageRequest, server.nZSlices(), server.nTimepoints());
                }
                // Children are only attached by updateObjectsAndHierarchy, so an interruption leaves the hierarchy untouched
                List<PathObject> childList = map.computeIfAbsent(pathObject, p -> new ArrayList<>());
                for (RegionRequest request : requests) {
                    // isInterrupted() (unlike Thread.interrupted()) preserves the interrupt status for callers
                    if (Thread.currentThread().isInterrupted()) {
                        logger.warn("Processing interrupted - {} object(s) will be discarded", (Object)list.size());
                        return Optional.empty();
                    }
                    BufferedImage img = server.readRegion(request);
                    List<PathObject> segmented = DjlZoo.segmentObjects(predictor, img, request, roi, creator, skipBackground);
                    childList.addAll(segmented);
                    list.addAll(segmented);
                }
            }
        }
        DjlZoo.updateObjectsAndHierarchy(imageData.getHierarchy(), map, model);
        return Optional.of(list);
    }

    /**
     * Read a model property as a double.
     * <p>
     * The default is returned if the property is missing or blank, or if it cannot be parsed; in the
     * last case a warning is logged.
     *
     * @param model the model holding the properties
     * @param key the property name
     * @param defaultValue value used when no valid number is available
     * @return the parsed value or the default
     */
    private static double tryToParseDoubleProperty(Model model, String key, double defaultValue) {
        String text = model.getProperty(key);
        if (text != null && !text.isBlank()) {
            try {
                return Double.parseDouble(text);
            }
            catch (NumberFormatException e) {
                logger.warn("Unable to parse property {} as double: {}", (Object)key, (Object)e.getLocalizedMessage());
            }
        }
        return defaultValue;
    }

    /**
     * Replace the child objects of every parent in {@code map} and notify the hierarchy.
     * <p>
     * Parents that receive at least one child are locked. If the map has any entries, a single
     * hierarchy-changed event is fired with {@code changeSource`, and selected objects that are no longer
     * in the hierarchy are deselected.
     *
     * @param hierarchy the hierarchy to notify
     * @param map new children keyed by parent
     * @param changeSource source reported in the hierarchy event
     */
    private static void updateObjectsAndHierarchy(PathObjectHierarchy hierarchy, Map<PathObject, List<PathObject>> map, Object changeSource) {
        if (map.isEmpty()) {
            return;
        }
        map.forEach((parent, children) -> {
            parent.clearChildObjects();
            parent.addChildObjects(children);
            if (!children.isEmpty()) {
                parent.setLocked(true);
            }
        });
        hierarchy.fireHierarchyChangedEvent(changeSource);
        deselectDeletedObjects(hierarchy);
    }

    /**
     * Deselect any selected objects that are no longer contained in the hierarchy.
     *
     * @param hierarchy the hierarchy whose selection model is updated
     */
    private static void deselectDeletedObjects(PathObjectHierarchy hierarchy) {
        var selectionModel = hierarchy.getSelectionModel();
        List<PathObject> orphans = new ArrayList<>();
        for (PathObject selected : selectionModel.getSelectedObjects()) {
            if (!PathObjectTools.hierarchyContainsObject(hierarchy, selected)) {
                orphans.add(selected);
            }
        }
        if (!orphans.isEmpty()) {
            selectionModel.deselectObjects(orphans);
        }
    }

    /**
     * Increase the downsample of a request so the region roughly matches the model input size.
     * <p>
     * Input sizes of 0 or less are treated as unknown. For each known dimension, the target downsample is the
     * rounded ratio of region size to input size. When both are known, the smaller target is used. When only
     * one is known, that one is used. Previously a single known dimension was always combined with the current
     * downsample, so the request was never changed. The downsample is only ever increased.
     *
     * @param request the original request
     * @param inputWidth model input width, or &lt;= 0 if unknown
     * @param inputHeight model input height, or &lt;= 0 if unknown
     * @return the updated request, or {@code request} itself if no increase is needed
     */
    private static RegionRequest updateDownsampleForInput(RegionRequest request, long inputWidth, long inputHeight) {
        boolean hasWidth = inputWidth > 0L;
        boolean hasHeight = inputHeight > 0L;
        if (!hasWidth && !hasHeight) {
            return request;
        }
        double targetDownsample;
        if (hasWidth && hasHeight) {
            targetDownsample = Math.min(
                    Math.round(request.getWidth() / (double)inputWidth),
                    Math.round(request.getHeight() / (double)inputHeight));
        } else if (hasWidth) {
            targetDownsample = Math.round(request.getWidth() / (double)inputWidth);
        } else {
            targetDownsample = Math.round(request.getHeight() / (double)inputHeight);
        }
        if (targetDownsample > request.getDownsample()) {
            return request.updateDownsample(targetDownsample);
        }
        return request;
    }

    /**
     * Get the size of the first axis of a shape with the specified layout type.
     *
     * @param shape the shape; may be null
     * @param layoutType the layout type to look for
     * @return the size of the matching axis, or -1 if the shape is null or has no such axis
     */
    private static long getDim(Shape shape, LayoutType layoutType) {
        if (shape == null) {
            return -1L;
        }
        int axis = 0;
        while (axis < shape.dimension()) {
            if (layoutType.equals(shape.getLayoutType(axis))) {
                return shape.get(axis);
            }
            axis++;
        }
        return -1L;
    }

    /**
     * Get the shape of a model's only input.
     *
     * @param input the input description; may be null
     * @return the input shape, or null if {@code input} is null
     * @throws IllegalArgumentException if the model does not have exactly one input
     */
    private static Shape getSingleInputShape(PairList<String, Shape> input) {
        if (input == null) {
            return null;
        }
        int nInputs = input.size();
        if (nInputs == 1) {
            return input.get(0).getValue();
        }
        throw new IllegalArgumentException("Only single inputs are supported! Model requires " + nInputs + " inputs");
    }

    /**
     * Segment an image into detection objects, one per non-background class.
     */
    public static List<PathObject> segmentDetections(Predictor<Image, CategoryMask> predictor, BufferedImage img, RegionRequest request) throws TranslateException {
        Function<ROI, PathObject> creator = PathObjects::createDetectionObject;
        return segmentObjects(predictor, img, request, null, creator, true);
    }

    /**
     * Segment an image into annotation objects, one per non-background class.
     */
    public static List<PathObject> segmentAnnotations(Predictor<Image, CategoryMask> predictor, BufferedImage img, RegionRequest request) throws TranslateException {
        Function<ROI, PathObject> creator = PathObjects::createAnnotationObject;
        return segmentObjects(predictor, img, request, null, creator, true);
    }

    /**
     * Segment an image and create one object per class ROI, classified by the class name.
     *
     * @param predictor the segmentation predictor
     * @param img the image to segment
     * @param request the region corresponding to {@code img}
     * @param roiMask optional ROI used to clip the results; may be null
     * @param creator function creating an object from each ROI
     * @param skipBackground if true, class index 0 is skipped
     * @return the created objects, in class order
     */
    public static List<PathObject> segmentObjects(Predictor<Image, CategoryMask> predictor, BufferedImage img, RegionRequest request, ROI roiMask, Function<ROI, PathObject> creator, boolean skipBackground) throws TranslateException {
        Map<String, ROI> roisByClass = segmentROIs(predictor, img, request, roiMask, skipBackground);
        List<PathObject> objects = new ArrayList<>(roisByClass.size());
        for (Map.Entry<String, ROI> entry : roisByClass.entrySet()) {
            objects.add(createPathObject(creator, entry.getValue(), entry.getKey()));
        }
        return objects;
    }

    /**
     * Create an object from a ROI and, if a classification name is given, assign the matching class.
     *
     * @param creator function creating the object
     * @param roi the ROI
     * @param classification class name; may be null for no classification
     * @return the new object
     */
    private static PathObject createPathObject(Function<ROI, PathObject> creator, ROI roi, String classification) {
        PathObject created = creator.apply(roi);
        if (classification == null) {
            return created;
        }
        created.setPathClass(PathClass.getInstance(classification));
        return created;
    }

    /**
     * Run semantic segmentation on an image and trace one ROI per class present in the output mask.
     * <p>
     * If the mask size differs from the image size, the ROIs are scaled to match the image. The scaling is
     * about the request origin when a request is given. Previously a {@code scaleY != 0.0} typo meant the
     * scale was always applied. If {@code roiMask} is given, each ROI is clipped to it, and classes whose
     * ROI ends up empty are omitted.
     *
     * @param predictor the segmentation predictor
     * @param img the image to segment
     * @param request the region corresponding to {@code img}, used for tracing coordinates; may be null
     * @param roiMask optional ROI used to clip the results; may be null
     * @param skipBackground if true, class index 0 is skipped
     * @return map from class name to ROI, in class-index order
     */
    private static Map<String, ROI> segmentROIs(Predictor<Image, CategoryMask> predictor, BufferedImage img, RegionRequest request, ROI roiMask, boolean skipBackground) throws TranslateException {
        Image input = BufferedImageFactory.getInstance().fromImage((Object)img);
        CategoryMask output = predictor.predict(input);
        List<String> classes = output.getClasses();
        int[][] maskOrig = output.getMask();
        SimpleMaskImage mask = new SimpleMaskImage(maskOrig);
        int nClasses = classes.size();
        // Count pixels per class so that absent classes are not traced at all
        int[] hist = new int[nClasses];
        for (int[] row : maskOrig) {
            for (int val : row) {
                if (val >= 0 && val < nClasses) {
                    hist[val]++;
                }
            }
        }
        // Loop-invariant: the scale between the model output mask and the input image
        double scaleX = (double)img.getWidth() / (double)mask.getWidth();
        double scaleY = (double)img.getHeight() / (double)mask.getHeight();
        boolean needsScaling = scaleX != 1.0 || scaleY != 1.0;
        Map<String, ROI> map = new LinkedHashMap<>();
        for (int i = skipBackground ? 1 : 0; i < nClasses; i++) {
            if (hist[i] == 0) {
                continue;
            }
            ROI roi = DjlZoo.createROI(mask, request, i);
            if (roi == null || roi.isEmpty()) {
                continue;
            }
            if (needsScaling) {
                roi = request == null ? roi.scale(scaleX, scaleY) : roi.scale(scaleX, scaleY, (double)request.getX(), (double)request.getY());
            }
            if (roiMask != null) {
                roi = RoiTools.intersection(new ROI[]{roi, roiMask});
                if (roi.isEmpty()) {
                    continue;
                }
            }
            map.put(classes.get(i), roi);
        }
        return map;
    }

    /**
     * Trace the pixels of a label mask with value {@code val} (the traced range is {@code val} to
     * {@code val}). The request is used to map mask pixels to image coordinates.
     *
     * @param mask the label mask
     * @param request the region corresponding to the mask; may be null
     * @param val the label to trace
     * @return the traced ROI
     */
    private static ROI createROI(SimpleImage mask, RegionRequest request, int val) {
        double label = val;
        return ContourTracing.createTracedROI(mask, label, label, request);
    }

    /**
     * Run an object-detection predictor on a single image.
     *
     * @param predictor the predictor
     * @param img the image
     * @return the detections
     */
    public static DetectedObjects detect(Predictor<Image, DetectedObjects> predictor, BufferedImage img) throws TranslateException {
        return predictor.predict(ImageFactory.getInstance().fromImage((Object)img));
    }

    /**
     * Run an image-classification predictor on a single image.
     *
     * @param predictor the predictor
     * @param img the image
     * @return the classifications
     */
    public static Classifications classify(Predictor<Image, Classifications> predictor, BufferedImage img) throws TranslateException {
        return predictor.predict(ImageFactory.getInstance().fromImage((Object)img));
    }

    /**
     * Generate images with a BigGAN model for the given class indices.
     * <p>
     * A warning is logged (with the typo "translater" corrected) if the model's translator is not a
     * {@link BigGANTranslator}, but generation is still attempted.
     *
     * @param model the generative model
     * @param indices class indices to generate
     * @return the generated images
     * @throws TranslateException if prediction fails
     * @throws IllegalArgumentException if a generated image does not wrap a BufferedImage
     */
    static List<BufferedImage> bigGanGenerate(ZooModel<int[], Image[]> model, int ... indices) throws TranslateException {
        if (!(model.getTranslator() instanceof BigGANTranslator)) {
            logger.warn("Model translator is not an instance of BigGANTranslator");
        }
        try (Predictor<int[], Image[]> predictor = model.newPredictor()) {
            Image[] output = predictor.predict(indices);
            return Arrays.stream(output)
                    .map(DjlZoo::toBufferedImage)
                    .collect(Collectors.toList());
        }
    }

    /**
     * Pass an image through an image-to-image predictor.
     *
     * @param predictor the predictor
     * @param img the input image
     * @return the output image
     * @throws TranslateException if prediction fails
     */
    public static BufferedImage imageToImage(Predictor<Image, Image> predictor, BufferedImage img) throws TranslateException {
        Image prediction = predictor.predict(BufferedImageFactory.getInstance().fromImage((Object)img));
        return toBufferedImage(prediction);
    }

    /**
     * Unwrap a DJL image that wraps a {@link BufferedImage}.
     *
     * @param image the DJL image
     * @return the wrapped BufferedImage
     * @throws IllegalArgumentException if the wrapped object is not a BufferedImage
     */
    public static BufferedImage toBufferedImage(Image image) throws IllegalArgumentException {
        Object wrapped = image.getWrappedImage();
        if (!(wrapped instanceof BufferedImage)) {
            throw new IllegalArgumentException("Need a java.awt.image.BufferedImage, but found " + wrapped);
        }
        return (BufferedImage)wrapped;
    }

    /**
     * Wrap a server so that its tiles are generated by an image-to-image model.
     * The returned server takes ownership of the model and closes it when the server is closed.
     */
    static ImageServer<BufferedImage> wrapImageToImage(ZooModel<Image, Image> model, ImageServer<BufferedImage> server) {
        DjlPredictionImageServer wrapper = new DjlPredictionImageServer(server, model);
        return wrapper;
    }

    private static class SimpleMaskImage
    implements SimpleImage {
        private int[][] values;
        private int width;
        private int height;

        private SimpleMaskImage(int[][] values) {
            this.values = values;
            this.width = values[0].length;
            this.height = values.length;
        }

        public float getValue(int x, int y) {
            return this.values[y][x];
        }

        public int getWidth() {
            return this.width;
        }

        public int getHeight() {
            return this.height;
        }
    }

    static class DjlPredictionImageServer
    extends AbstractTileableImageServer
    implements GeneratingImageServer<BufferedImage> {
        private ImageServer<BufferedImage> server;
        private ZooModel<Image, Image> model;
        private ThreadLocal<Predictor<Image, Image>> predictors;

        DjlPredictionImageServer(ImageServer<BufferedImage> server, ZooModel<Image, Image> model) {
            this.server = server;
            this.model = model;
            this.predictors = ThreadLocal.withInitial(() -> model.newPredictor());
            Shape imageHeightWidth = DjlZoo.getInputHeightWidth(model);
            long tileWidth = imageHeightWidth.size(new int[]{1}) <= 0L ? 512L : imageHeightWidth.size(new int[]{1});
            long tileHeight = imageHeightWidth.size(new int[]{0}) <= 0L ? tileWidth : imageHeightWidth.size(new int[]{0});
            this.setMetadata(new ImageServerMetadata.Builder(server.getMetadata()).preferredTileSize((int)tileWidth, (int)tileHeight).build());
        }

        public Collection<URI> getURIs() {
            return this.server.getURIs();
        }

        public String getServerType() {
            return "Deep Java Library prediction server";
        }

        public ImageServerMetadata getOriginalMetadata() {
            return this.server.getOriginalMetadata();
        }

        protected BufferedImage readTile(TileRequest tileRequest) throws IOException {
            if (this.server.isEmptyRegion(tileRequest.getRegionRequest())) {
                return this.getEmptyTile(tileRequest.getTileWidth(), tileRequest.getTileHeight());
            }
            BufferedImage img = (BufferedImage)this.server.readRegion(tileRequest.getRegionRequest());
            try {
                return DjlZoo.imageToImage(this.predictors.get(), img);
            }
            catch (TranslateException e) {
                throw new IOException(e);
            }
        }

        protected ImageServerBuilder.ServerBuilder<BufferedImage> createServerBuilder() {
            throw new UnsupportedOperationException("DjlPredictionImageServer cannot currently be serialized");
        }

        protected String createID() {
            return UUID.randomUUID().toString();
        }

        public void close() throws Exception {
            super.close();
            this.model.close();
        }
    }
}

