/*
 * Decompiled with CFR 0.152.
 */
package qupath.lib.images.writers.ome.zarr;

import com.bc.zarr.ArrayParams;
import com.bc.zarr.Compressor;
import com.bc.zarr.CompressorFactory;
import com.bc.zarr.DataType;
import com.bc.zarr.DimensionSeparator;
import com.bc.zarr.ZarrArray;
import com.bc.zarr.ZarrGroup;
import java.awt.image.BufferedImage;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;
import loci.formats.gui.AWTImageTools;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.common.ThreadTools;
import qupath.lib.images.servers.ImageServer;
import qupath.lib.images.servers.ImageServers;
import qupath.lib.images.servers.PixelType;
import qupath.lib.images.servers.TileRequest;
import qupath.lib.images.servers.TransformedServerBuilder;
import qupath.lib.images.writers.ome.zarr.OMEXMLCreator;
import qupath.lib.images.writers.ome.zarr.OMEZarrAttributesCreator;
import qupath.lib.regions.ImageRegion;

public class OMEZarrWriter
implements AutoCloseable {
    private static final Logger logger = LoggerFactory.getLogger(OMEZarrWriter.class);
    private final ImageServer<BufferedImage> server;
    private final Map<Integer, ZarrArray> levelArrays;
    private final ExecutorService executorService;
    private final Consumer<TileRequest> onTileWritten;

    /**
     * Set up a writer from the configuration held by the builder and create the Zarr hierarchy at
     * the given path.
     * <p>
     * The input server is wrapped with {@code ImageServers.pyramidalizeTiled} when the effective
     * chunk size or the downsamples differ from the ones it prefers, then sliced (z/t) and cropped
     * if the builder asks for it. The root group, the OME-XML sub-group (when an OME-XML document is
     * available) and one array per resolution level are created here; no pixel is written by the
     * constructor.
     *
     * @param builder the builder holding the configuration of this writer
     * @param path the location of the root Zarr group to create
     * @throws IOException declared by the Zarr group/array creation calls
     */
    private OMEZarrWriter(Builder builder, String path) throws IOException {
        ImageServer<BufferedImage> source = builder.server;

        // A non-positive tile size in the builder means "use what the server prefers"
        int requestedTileWidth = builder.tileWidth > 0 ? builder.tileWidth : source.getMetadata().getPreferredTileWidth();
        int requestedTileHeight = builder.tileHeight > 0 ? builder.tileHeight : source.getMetadata().getPreferredTileHeight();
        int tileWidth = getChunkSize(requestedTileWidth, builder.maxNumberOfChunks, source.getWidth());
        int tileHeight = getChunkSize(requestedTileHeight, builder.maxNumberOfChunks, source.getHeight());

        // An empty downsample list in the builder means "use what the server prefers"
        double[] downsamples = builder.downsamples;
        if (downsamples.length == 0) {
            downsamples = source.getPreferredDownsamples();
        }

        // Only re-tile / re-pyramidalize when the layout actually differs from the server's own
        boolean sourceLayoutKept = tileWidth == source.getMetadata().getPreferredTileWidth()
                && tileHeight == source.getMetadata().getPreferredTileHeight()
                && Arrays.equals(downsamples, source.getPreferredDownsamples());
        TransformedServerBuilder transformed = new TransformedServerBuilder(
                sourceLayoutKept ? source : ImageServers.pyramidalizeTiled(source, tileWidth, tileHeight, downsamples)
        );

        boolean allZSlices = builder.zStart == 0 && builder.zEnd == source.nZSlices();
        boolean allTimepoints = builder.tStart == 0 && builder.tEnd == source.nTimepoints();
        if (!(allZSlices && allTimepoints)) {
            transformed.slice(builder.zStart, builder.zEnd, builder.tStart, builder.tEnd);
        }

        if (builder.boundingBox != null) {
            transformed.crop(builder.boundingBox);
        }

        this.server = transformed.build();

        OMEZarrAttributesCreator attributes = new OMEZarrAttributesCreator(this.server.getMetadata());
        ZarrGroup root = ZarrGroup.create(path, attributes.getGroupAttributes());
        OMEXMLCreator.create(this.server.getMetadata())
                .ifPresent(omeXML -> createOmeSubGroup(root, path, omeXML));

        this.levelArrays = createLevelArrays(this.server, root, attributes.getLevelAttributes(), builder.compressor);
        this.executorService = Executors.newFixedThreadPool(
                builder.numberOfThreads,
                ThreadTools.createThreadFactory("zarr_writer_", false)
        );
        this.onTileWritten = builder.onTileWritten;
    }

    /**
     * Stop accepting new tiles and block until every tile already submitted has been processed.
     * <p>
     * If the calling thread is interrupted while waiting, the tasks still pending or running are
     * asked to stop and the interruption is propagated to the caller.
     *
     * @throws InterruptedException if the calling thread is interrupted while waiting for the
     * submitted tiles to be processed
     */
    @Override
    public void close() throws InterruptedException {
        executorService.shutdown();

        try {
            // Effectively "wait forever": the only way out is completion or interruption
            executorService.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (InterruptedException interruption) {
            logger.debug("Waiting interrupted. Stopping tasks", interruption);
            executorService.shutdownNow();
            throw interruption;
        }
    }

    /**
     * Schedule the writing of every tile of the image (all tile requests known to the tile request
     * manager of the server being written).
     * <p>
     * This only submits the work (see {@link #writeTile(TileRequest)}); it returns without waiting
     * for the tiles to be written. {@link #close()} waits for the submitted tiles.
     */
    public void writeImage() {
        for (TileRequest tile : server.getTileRequestManager().getAllTileRequests()) {
            writeTile(tile);
        }
    }

    /**
     * Schedule the writing of one tile on the executor of this writer.
     * <p>
     * The task reads the region of the tile from the server, converts it to a flat array and
     * writes it into the Zarr array of the tile's level. Any exception raised while doing so is
     * logged and not propagated. The {@code onTileWritten} callback, when one was provided, is
     * invoked afterwards whether or not the write succeeded.
     *
     * @param tileRequest the tile to write
     */
    public void writeTile(TileRequest tileRequest) {
        executorService.execute(() -> {
            try {
                ZarrArray levelArray = levelArrays.get(tileRequest.getLevel());
                BufferedImage tile = server.readRegion(tileRequest.getRegionRequest());

                levelArray.write(
                        getData(tile),
                        getDimensionsOfTile(tileRequest),
                        getOffsetsOfTile(tileRequest)
                );
            } catch (Exception e) {
                logger.error("Error when writing tile", e);
            }

            // Notify even after a failure so that callers counting tiles still see every tile
            if (onTileWritten != null) {
                onTileWritten.accept(tileRequest);
            }
        });
    }

    /**
     * Get the server this writer reads tiles from. It is the server built in the constructor from
     * the one given to the builder (possibly re-tiled, sliced and cropped), so its tile requests
     * are the ones accepted by {@link #writeTile(TileRequest)}.
     *
     * @return the server whose pixels are written by this writer
     */
    public ImageServer<BufferedImage> getReaderServer() {
        return server;
    }

    /**
     * Compute the chunk size to use along one spatial dimension.
     * <p>
     * The result is the requested tile size, enlarged if needed so that the dimension is covered by
     * at most {@code maxNumberOfChunks} chunks, and never larger than the image itself.
     *
     * @param tileSize the requested tile size along this dimension
     * @param maxNumberOfChunks the maximum number of chunks along this dimension; a value of 0 or
     *                          less means no limit
     * @param imageSize the size of the image along this dimension
     * @return the chunk size to use along this dimension
     */
    private static int getChunkSize(int tileSize, int maxNumberOfChunks, int imageSize) {
        if (maxNumberOfChunks <= 0) {
            return Math.min(imageSize, tileSize);
        }

        // Round UP: with a floored quotient the chunks can be too small and the dimension may need
        // more than maxNumberOfChunks of them (e.g. size 10, max 3: chunk 3 -> 4 chunks).
        int smallestAllowedChunk = imageSize / maxNumberOfChunks + (imageSize % maxNumberOfChunks == 0 ? 0 : 1);

        return Math.min(imageSize, Math.max(tileSize, smallestAllowedChunk));
    }

    /**
     * Create the "OME" sub-group of the provided group and write the OME-XML document into a new
     * "METADATA.ome.xml" file located in {@code imagePath/OME}.
     * <p>
     * This is best-effort: an {@link IOException} raised while creating the sub-group or the file
     * is logged and not propagated.
     *
     * @param mainGroup the group in which the "OME" sub-group is created
     * @param imagePath the path on disk of {@code mainGroup}
     * @param omeXMLContent the OME-XML document to write
     */
    private static void createOmeSubGroup(ZarrGroup mainGroup, String imagePath, String omeXMLContent) {
        String fileName = "OME";

        try {
            mainGroup.createSubGroup(fileName);

            // createFile fails if the file already exists, so existing metadata is never overwritten.
            // The content is encoded explicitly as UTF-8 rather than with the platform default charset,
            // so the bytes written do not depend on the JVM configuration.
            var metadataFile = Files.createFile(Paths.get(imagePath, fileName, "METADATA.ome.xml"));
            Files.write(metadataFile, omeXMLContent.getBytes(StandardCharsets.UTF_8));
        } catch (IOException e) {
            logger.error("Error while creating OME group or metadata XML file", e);
        }
    }

    /**
     * Create one Zarr array per resolution level of the server, named "s0", "s1", ... inside the
     * provided group.
     * <p>
     * Each array gets the shape of its level, chunks derived from the preferred tile size of the
     * server, the provided compressor, the data type matching the pixel type of the server, and
     * "/" as dimension separator.
     *
     * @param server the server whose levels are mirrored
     * @param root the group in which the arrays are created
     * @param levelAttributes the attributes given to each created array
     * @param compressor the compressor used by each created array
     * @return a map from level index to the array created for this level
     * @throws IOException declared by the array creation call
     */
    private static Map<Integer, ZarrArray> createLevelArrays(ImageServer<BufferedImage> server, ZarrGroup root, Map<String, Object> levelAttributes, Compressor compressor) throws IOException {
        Map<Integer, ZarrArray> arraysByLevel = new HashMap<>();

        for (int level = 0; level < server.getMetadata().nLevels(); level++) {
            ArrayParams layout = new ArrayParams()
                    .shape(getDimensionsOfImage(server, level))
                    .chunks(getChunksOfImage(server))
                    .compressor(compressor);

            // Exhaustive over PixelType: no default branch needed
            DataType dataType = switch (server.getPixelType()) {
                case UINT8 -> DataType.u1;
                case INT8 -> DataType.i1;
                case UINT16 -> DataType.u2;
                case INT16 -> DataType.i2;
                case UINT32 -> DataType.u4;
                case INT32 -> DataType.i4;
                case FLOAT32 -> DataType.f4;
                case FLOAT64 -> DataType.f8;
            };

            ZarrArray levelArray = root.createArray(
                    "s" + level,
                    layout.dataType(dataType).dimensionSeparator(DimensionSeparator.SLASH),
                    levelAttributes
            );
            arraysByLevel.put(level, levelArray);
        }

        return arraysByLevel;
    }

    /**
     * Compute the shape of the array storing one resolution level of the server.
     * <p>
     * The axes are, in this order: timepoints, channels, z-slices, height, width. An axis among
     * the first three is only present when the server has more than one element along it. Height
     * and width are those of the full-resolution image divided by the downsample of the level,
     * truncated to an integer.
     *
     * @param server the server to describe
     * @param level the resolution level
     * @return the shape of the array of this level
     */
    private static int[] getDimensionsOfImage(ImageServer<BufferedImage> server, int level) {
        int[] shape = new int[5];   // at most t, c, z, y, x
        int rank = 0;

        int timepoints = server.nTimepoints();
        if (timepoints > 1) {
            shape[rank++] = timepoints;
        }

        int channels = server.nChannels();
        if (channels > 1) {
            shape[rank++] = channels;
        }

        int zSlices = server.nZSlices();
        if (zSlices > 1) {
            shape[rank++] = zSlices;
        }

        double downsample = server.getDownsampleForResolution(level);
        shape[rank++] = (int) (server.getHeight() / downsample);
        shape[rank++] = (int) (server.getWidth() / downsample);

        return Arrays.copyOf(shape, rank);
    }

    /**
     * Compute the chunk shape of the arrays storing the server.
     * <p>
     * The axes follow the order timepoints, channels, z-slices, height, width, an axis among the
     * first three being only present when the server has more than one element along it. Chunks
     * have a size of 1 along these non-spatial axes and the preferred tile height / width of the
     * server along the spatial ones.
     *
     * @param server the server to describe
     * @return the chunk shape to use for every level of the server
     */
    private static int[] getChunksOfImage(ImageServer<BufferedImage> server) {
        int[] chunkShape = new int[5];   // at most t, c, z, y, x
        int rank = 0;

        if (server.nTimepoints() > 1) {
            chunkShape[rank++] = 1;
        }
        if (server.nChannels() > 1) {
            chunkShape[rank++] = 1;
        }
        if (server.nZSlices() > 1) {
            chunkShape[rank++] = 1;
        }

        chunkShape[rank++] = server.getMetadata().getPreferredTileHeight();
        chunkShape[rank++] = server.getMetadata().getPreferredTileWidth();

        return Arrays.copyOf(chunkShape, rank);
    }

    /**
     * Convert the pixels of a tile into the flat primitive array written to the Zarr array.
     * <p>
     * The pixels are obtained from {@code AWTImageTools.getPixels} and are expected to be one
     * primitive array per channel, each holding at least {@code width * height} values indexed by
     * {@code x + width * y}. The output contains the first {@code width * height} values of
     * channel 0, followed by those of channel 1, and so on.
     * <p>
     * The element type of the output is {@code int} for RGB servers; otherwise it follows the
     * pixel type of the server ({@code byte}, {@code short}, {@code int}, {@code float} or
     * {@code double}).
     *
     * @param image the tile to convert
     * @return a primitive array of length {@code nChannels * width * height}
     * @throws ClassCastException if the pixels returned for the image do not have the element type
     * expected for the server
     */
    private Object getData(BufferedImage image) {
        Object pixels = AWTImageTools.getPixels(image);

        int numberOfChannels = server.nChannels();
        int planeSize = image.getWidth() * image.getHeight();
        int outputLength = numberOfChannels * planeSize;

        if (server.isRGB()) {
            return flattenChannels((int[][]) pixels, new int[outputLength], numberOfChannels, planeSize);
        }

        // Exhaustive over PixelType: no default branch needed
        return switch (server.getPixelType()) {
            case UINT8, INT8 ->
                    flattenChannels((byte[][]) pixels, new byte[outputLength], numberOfChannels, planeSize);
            case UINT16, INT16 ->
                    flattenChannels((short[][]) pixels, new short[outputLength], numberOfChannels, planeSize);
            case UINT32, INT32 ->
                    flattenChannels((int[][]) pixels, new int[outputLength], numberOfChannels, planeSize);
            case FLOAT32 ->
                    flattenChannels((float[][]) pixels, new float[outputLength], numberOfChannels, planeSize);
            case FLOAT64 ->
                    flattenChannels((double[][]) pixels, new double[outputLength], numberOfChannels, planeSize);
        };
    }

    /**
     * Concatenate the first {@code planeSize} elements of each of the first
     * {@code numberOfChannels} channel arrays into {@code output}.
     *
     * @param channels one primitive array per channel (e.g. a {@code byte[][]})
     * @param output the primitive array to fill; must have the same element type as the channels
     * @param numberOfChannels the number of channels to copy
     * @param planeSize the number of elements to copy from each channel
     * @return {@code output}, for convenience
     */
    private static <T> T flattenChannels(Object[] channels, T output, int numberOfChannels, int planeSize) {
        for (int c = 0; c < numberOfChannels; c++) {
            System.arraycopy(channels[c], 0, output, c * planeSize, planeSize);
        }
        return output;
    }

    /**
     * Compute the shape of the block of data written for one tile.
     * <p>
     * The axes follow the order timepoints, channels, z-slices, height, width, an axis among the
     * first three being only present when the server has more than one element along it. A tile
     * spans a single timepoint and z-slice, all the channels, and its own height and width.
     *
     * @param tileRequest the tile to describe
     * @return the shape of the data written for this tile
     */
    private int[] getDimensionsOfTile(TileRequest tileRequest) {
        int[] shape = new int[5];   // at most t, c, z, y, x
        int rank = 0;

        if (server.nTimepoints() > 1) {
            shape[rank++] = 1;
        }

        int channels = server.nChannels();
        if (channels > 1) {
            shape[rank++] = channels;
        }

        if (server.nZSlices() > 1) {
            shape[rank++] = 1;
        }

        shape[rank++] = tileRequest.getTileHeight();
        shape[rank++] = tileRequest.getTileWidth();

        return Arrays.copyOf(shape, rank);
    }

    /**
     * Compute the position at which the data of one tile is written in the array of its level.
     * <p>
     * The axes follow the order timepoints, channels, z-slices, y, x, an axis among the first
     * three being only present when the server has more than one element along it. The offset is
     * the timepoint, z-slice, and y / x of the tile; it is always 0 along the channel axis.
     *
     * @param tileRequest the tile to locate
     * @return the offsets of the tile along each axis of the array
     */
    private int[] getOffsetsOfTile(TileRequest tileRequest) {
        int[] offsets = new int[5];   // at most t, c, z, y, x
        int rank = 0;

        if (server.nTimepoints() > 1) {
            offsets[rank++] = tileRequest.getT();
        }
        if (server.nChannels() > 1) {
            offsets[rank++] = 0;
        }
        if (server.nZSlices() > 1) {
            offsets[rank++] = tileRequest.getZ();
        }

        offsets[rank++] = tileRequest.getTileY();
        offsets[rank++] = tileRequest.getTileX();

        return Arrays.copyOf(offsets, rank);
    }

    /**
     * Builder of {@link OMEZarrWriter}. Every setter returns the builder itself so that calls can
     * be chained; {@link #build(String)} creates the writer.
     */
    public static class Builder {
        private static final String FILE_EXTENSION = ".ome.zarr";
        private final ImageServer<BufferedImage> server;
        private Compressor compressor = CompressorFactory.createDefaultCompressor();
        private int numberOfThreads = ThreadTools.getParallelism();
        private double[] downsamples = new double[0];   // empty: use the downsamples preferred by the server
        private int maxNumberOfChunks = -1;             // 0 or less: no limit
        private int tileWidth = -1;                     // 0 or less: use the tile width preferred by the server
        private int tileHeight = -1;                    // 0 or less: use the tile height preferred by the server
        private ImageRegion boundingBox;                // null: no cropping
        private int zStart;
        private int zEnd;
        private int tStart;
        private int tEnd;
        private Consumer<TileRequest> onTileWritten;    // null: no notification

        /**
         * Create the builder. By default, all z-slices and timepoints of the server are selected
         * (from 0 to {@code server.nZSlices()} and from 0 to {@code server.nTimepoints()}).
         *
         * @param server the image to write
         */
        public Builder(ImageServer<BufferedImage> server) {
            this.server = server;
            zStart = 0;
            zEnd = server.nZSlices();
            tStart = 0;
            tEnd = server.nTimepoints();
        }

        /**
         * Set the compressor used by the created arrays. The default is the one returned by
         * {@code CompressorFactory.createDefaultCompressor()}.
         *
         * @param compressor the compressor to use
         * @return this builder
         */
        public Builder compression(Compressor compressor) {
            this.compressor = compressor;
            return this;
        }

        /**
         * Set the number of threads of the pool writing the tiles. The default is the value
         * returned by {@code ThreadTools.getParallelism()}.
         *
         * @param numberOfThreads the number of threads to use when writing tiles
         * @return this builder
         */
        public Builder parallelize(int numberOfThreads) {
            this.numberOfThreads = numberOfThreads;
            return this;
        }

        /**
         * Set the downsamples of the levels to write. When none is provided (the default), the
         * downsamples preferred by the server are used.
         *
         * @param downsamples the downsamples of the levels to write
         * @return this builder
         */
        public Builder downsamples(double... downsamples) {
            this.downsamples = downsamples;
            return this;
        }

        /**
         * Set the maximum number of chunks along each spatial dimension; chunks are enlarged
         * beyond the tile size when needed to respect it. A value of 0 or less (the default)
         * means no limit.
         *
         * @param maxNumberOfChunks the maximum number of chunks on the x and y axes
         * @return this builder
         */
        public Builder setMaxNumberOfChunksOnEachSpatialDimension(int maxNumberOfChunks) {
            this.maxNumberOfChunks = maxNumberOfChunks;
            return this;
        }

        /**
         * Set the same value for the tile width and the tile height.
         *
         * @param tileSize the width and height of the tiles
         * @return this builder
         * @see #tileSize(int, int)
         */
        public Builder tileSize(int tileSize) {
            return tileSize(tileSize, tileSize);
        }

        /**
         * Set the requested size of the tiles. A value of 0 or less (the default) means that the
         * corresponding preferred tile size of the server is used.
         *
         * @param tileWidth the requested width of the tiles
         * @param tileHeight the requested height of the tiles
         * @return this builder
         */
        public Builder tileSize(int tileWidth, int tileHeight) {
            this.tileWidth = tileWidth;
            this.tileHeight = tileHeight;
            return this;
        }

        /**
         * Set the region the image is cropped to before being written. By default ({@code null}),
         * the image is not cropped.
         *
         * @param boundingBox the region to write, or {@code null} for no cropping
         * @return this builder
         */
        public Builder region(ImageRegion boundingBox) {
            this.boundingBox = boundingBox;
            return this;
        }

        /**
         * Set the range of z-slices to write. The values are forwarded to
         * {@code TransformedServerBuilder.slice} when they differ from the defaults
         * (0 and {@code server.nZSlices()}).
         *
         * @param zStart the start of the z-slice range
         * @param zEnd the end of the z-slice range
         * @return this builder
         */
        public Builder zSlices(int zStart, int zEnd) {
            this.zStart = zStart;
            this.zEnd = zEnd;
            return this;
        }

        /**
         * Set the range of timepoints to write. The values are forwarded to
         * {@code TransformedServerBuilder.slice} when they differ from the defaults
         * (0 and {@code server.nTimepoints()}).
         *
         * @param tStart the start of the timepoint range
         * @param tEnd the end of the timepoint range
         * @return this builder
         */
        public Builder timePoints(int tStart, int tEnd) {
            this.tStart = tStart;
            this.tEnd = tEnd;
            return this;
        }

        /**
         * Set a function called after each tile has been processed, from the thread that
         * processed it. It is called even when writing the tile failed.
         *
         * @param onTileWritten the function to call with each processed tile, or {@code null}
         * @return this builder
         */
        public Builder onTileWritten(Consumer<TileRequest> onTileWritten) {
            this.onTileWritten = onTileWritten;
            return this;
        }

        /**
         * Create the writer and the Zarr hierarchy at the provided location.
         *
         * @param path the location of the image to create; it must end with ".ome.zarr" and must
         *             not exist yet
         * @return the new writer
         * @throws IOException declared by the creation of the writer
         * @throws IllegalArgumentException if the path does not end with ".ome.zarr" or already
         * exists
         */
        public OMEZarrWriter build(String path) throws IOException {
            boolean hasExpectedExtension = path.endsWith(FILE_EXTENSION);
            if (!hasExpectedExtension) {
                throw new IllegalArgumentException(String.format(
                        "The provided path (%s) does not have the OME-Zarr extension (%s)",
                        path,
                        FILE_EXTENSION
                ));
            }

            if (Files.exists(Paths.get(path))) {
                throw new IllegalArgumentException(String.format("The provided path (%s) already exists", path));
            }

            return new OMEZarrWriter(this, path);
        }
    }
}

