/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.tools;

import java.io.BufferedInputStream;
import java.io.EOFException;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.DoubleBuffer;
import java.nio.FloatBuffer;
import java.nio.IntBuffer;
import java.nio.ShortBuffer;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.zip.ZipEntry;
import java.util.zip.ZipFile;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Scalar;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.common.GeneralTools;

/**
 * Static helpers for reading NumPy {@code .npy} files, and {@code .npz} (zip) archives of them,
 * into OpenCV {@link Mat}s.
 * <p>
 * Supported element types are float32, float64, uint8, int8, bool (read as uint8), int16, uint16 and int32.
 * Arrays stored in Fortran (column-major) order are rejected.
 */
public class NumpyTools {
    private static final Logger logger = LoggerFactory.getLogger(NumpyTools.class);

    /** First byte of every .npy file. */
    private static final int NPY_MAGIC_PREFIX = 0x93;
    /** ASCII magic string that immediately follows {@link #NPY_MAGIC_PREFIX}. */
    private static final byte[] NPY_MAGIC = "NUMPY".getBytes(StandardCharsets.US_ASCII);

    private static final Pattern PATTERN_DTYPE = Pattern.compile("'descr': *'([^']+)'");
    // Capture only the word (True/False) so a following '}' is never included when fortran_order is the last key
    private static final Pattern PATTERN_ORDER = Pattern.compile("'fortran_order': *([A-Za-z]+)");
    // Capture up to the first ')'; this also matches the empty shape '()' written for 0-d arrays
    private static final Pattern PATTERN_SHAPE = Pattern.compile("'shape': *\\(([^)]*)\\)");

    /**
     * Read a single array from an .npy file, or the first array from an .npz file, without squeezing.
     *
     * @param path path to the file
     * @return the array as a Mat
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz, or is an empty npz
     */
    public static Mat readMat(String path) throws IOException {
        return readMat(Paths.get(path));
    }

    /**
     * Read a single array from an .npy file, or the first array from an .npz file.
     *
     * @param path path to the file
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @return the array as a Mat
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz, or is an empty npz
     */
    public static Mat readMat(String path, boolean squeezeDimensions) throws IOException {
        return readMat(Paths.get(path), squeezeDimensions);
    }

    /**
     * Read all arrays from an .npy or .npz file, without squeezing.
     *
     * @param path path to the file
     * @return map of array names to Mats, in file order
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz
     */
    public static Map<String, Mat> readAllMats(String path) throws IOException {
        return readAllMats(Paths.get(path));
    }

    /**
     * Read all arrays from an .npy or .npz file.
     *
     * @param path path to the file
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @return map of array names to Mats, in file order
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz
     */
    public static Map<String, Mat> readAllMats(String path, boolean squeezeDimensions) throws IOException {
        return readAllMats(Paths.get(path), squeezeDimensions);
    }

    /**
     * Read all arrays from an .npy or .npz file, without squeezing.
     *
     * @param path path to the file
     * @return map of array names to Mats, in file order
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz
     */
    public static Map<String, Mat> readAllMats(Path path) throws IOException {
        return readAllMats(path, false);
    }

    /**
     * Read all arrays from an .npy or .npz file.
     * <p>
     * For an .npy file the single map key is the file name without its extension.
     * For an .npz file the keys are the archive entry names without their extensions.
     *
     * @param path path to the file
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @return map of array names to Mats, in file order
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz
     */
    public static Map<String, Mat> readAllMats(Path path, boolean squeezeDimensions) throws IOException {
        switch (detectType(path)) {
            case NPY:
                return Collections.singletonMap(
                        GeneralTools.getNameWithoutExtension(path.toFile()),
                        readNpy(path, squeezeDimensions));
            case ZIP:
                return readZipped(path, squeezeDimensions, false);
            default:
                throw new IllegalArgumentException(path + " is not a valid npy or npz file!");
        }
    }

    /**
     * Read arrays from the entries of a zip (.npz) archive.
     *
     * @param path path to the archive
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @param firstOnly if true, stop after the first array has been read
     * @return map of entry names (extension stripped) to Mats, in archive order; may be empty
     */
    private static Map<String, Mat> readZipped(Path path, boolean squeezeDimensions, boolean firstOnly) throws IOException {
        Map<String, Mat> map = new LinkedHashMap<>();
        try (ZipFile zipFile = new ZipFile(path.toFile())) {
            Iterator<? extends ZipEntry> iter = zipFile.entries().asIterator();
            while (iter.hasNext()) {
                ZipEntry entry = iter.next();
                // Directory entries carry no data and cannot be parsed as arrays
                if (entry.isDirectory()) {
                    continue;
                }
                try (InputStream stream = new BufferedInputStream(zipFile.getInputStream(entry))) {
                    map.put(GeneralTools.stripExtension(entry.getName()), readMat(stream, squeezeDimensions));
                }
                if (firstOnly) {
                    break;
                }
            }
        }
        return map;
    }

    /**
     * Read a single array from an .npy file, or the first array from an .npz file, without squeezing.
     *
     * @param path path to the file
     * @return the array as a Mat
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz, or is an empty npz
     */
    public static Mat readMat(Path path) throws IOException {
        return readMat(path, false);
    }

    /**
     * Read a single array from an .npy file, or the first array from an .npz file.
     *
     * @param path path to the file
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @return the array as a Mat
     * @throws IOException if the file cannot be read or parsed
     * @throws IllegalArgumentException if the file is neither npy nor npz, or is an empty npz
     */
    public static Mat readMat(Path path, boolean squeezeDimensions) throws IOException {
        switch (detectType(path)) {
            case NPY:
                return readNpy(path, squeezeDimensions);
            case ZIP: {
                Map<String, Mat> map = readZipped(path, squeezeDimensions, true);
                if (map.isEmpty()) {
                    throw new IllegalArgumentException(path + " does not contain any arrays!");
                }
                return map.values().iterator().next();
            }
            default:
                throw new IllegalArgumentException(path + " is not a valid npy or npz file!");
        }
    }

    /** Open a plain .npy file and parse it, closing the file afterwards. */
    private static Mat readNpy(Path path, boolean squeezeDimensions) throws IOException {
        try (InputStream stream = new BufferedInputStream(Files.newInputStream(path))) {
            return readMat(stream, squeezeDimensions);
        }
    }

    /** Sniff the leading bytes of a file to decide whether it is npy, zip or something else. */
    private static FileType detectType(Path path) throws IOException {
        try (InputStream stream = Files.newInputStream(path)) {
            return checkType(stream);
        }
    }

    /**
     * Classify a stream from its first bytes: 0x93 followed by "NUMPY" means npy, "PK" means zip.
     * Consumes up to 6 bytes from the stream.
     */
    private static FileType checkType(InputStream stream) throws IOException {
        int firstByte = stream.read();
        if (firstByte == NPY_MAGIC_PREFIX) {
            byte[] magic = stream.readNBytes(NPY_MAGIC.length);
            return Arrays.equals(NPY_MAGIC, magic) ? FileType.NPY : FileType.UNKNOWN;
        }
        if (firstByte == 'P' && stream.read() == 'K') {
            return FileType.ZIP;
        }
        return FileType.UNKNOWN;
    }

    /**
     * Read one .npy-format array from a stream, starting at its magic bytes.
     * The stream is left open and positioned just after the array data.
     *
     * @param stream the input stream
     * @param squeezeDimensions if true, remove all dimensions of length 1
     * @return a new Mat containing the array data
     * @throws IOException if the stream is not in npy format, uses an unsupported dtype or Fortran order,
     *                     or ends before all header/array bytes have been read
     */
    public static Mat readMat(InputStream stream, boolean squeezeDimensions) throws IOException {
        // Short-circuit: the magic string is only read if the first byte matched
        if (stream.read() != NPY_MAGIC_PREFIX || !Arrays.equals(NPY_MAGIC, stream.readNBytes(NPY_MAGIC.length))) {
            throw new IOException("File is not in npy format!");
        }
        int majorVersion = readUnsignedByte(stream);
        int minorVersion = readUnsignedByte(stream);
        long headerLength = readHeaderLength(stream, majorVersion);
        if (headerLength <= 0 || headerLength > Integer.MAX_VALUE) {
            throw new IOException("Unsupported header length " + headerLength);
        }
        byte[] headerBytes = readExactly(stream, (int) headerLength);
        // Version 3.x headers may contain UTF-8; earlier versions are latin1
        String dict = new String(headerBytes,
                majorVersion >= 3 ? StandardCharsets.UTF_8 : StandardCharsets.ISO_8859_1).strip();
        logger.debug("Version: {}.{}, Dict: {}", majorVersion, minorVersion, dict);

        String descr = getMatch(PATTERN_DTYPE, dict);
        if (descr == null) {
            throw new IOException("Unable to find dtype in file");
        }
        String orderString = getMatch(PATTERN_ORDER, dict);
        if (orderString == null) {
            throw new IOException("Unable to find fortran_order in file");
        }
        // parseBoolean is case-insensitive, so Python's 'True' is recognized
        if (Boolean.parseBoolean(orderString)) {
            throw new IOException("Fortran order is not supported, sorry");
        }
        String shapeString = getMatch(PATTERN_SHAPE, dict);
        if (shapeString == null) {
            throw new IOException("Unable to find array shape in file");
        }
        int[] shape = parseShape(shapeString);

        // The first character of descr may give the byte order; '|' means byte order is irrelevant
        ByteOrder byteOrder = ByteOrder.nativeOrder();
        String dtypeString = descr;
        if (descr.startsWith("=")) {
            logger.debug("Byte order specified as native order");
            dtypeString = descr.substring(1);
        } else if (descr.startsWith(">")) {
            logger.debug("Byte order specified as big endian");
            byteOrder = ByteOrder.BIG_ENDIAN;
            dtypeString = descr.substring(1);
        } else if (descr.startsWith("<")) {
            logger.debug("Byte order specified as little endian");
            byteOrder = ByteOrder.LITTLE_ENDIAN;
            dtypeString = descr.substring(1);
        } else if (descr.startsWith("|")) {
            dtypeString = descr.substring(1);
        } else {
            logger.warn("Byte order not specified - will use {}", byteOrder);
        }

        if (squeezeDimensions) {
            shape = Arrays.stream(shape).filter(d -> d != 1).toArray();
        }
        // A 0-d array (or one whose dimensions were all squeezed away) still holds exactly one element
        if (shape.length == 0) {
            shape = new int[] {1};
        }

        int depth;
        int bytesPerElement;
        switch (dtypeString) {
            case "float32":
            case "f4":
            case "f":
                depth = opencv_core.CV_32F;
                bytesPerElement = 4;
                break;
            case "float64":
            case "f8":
            case "d":
                depth = opencv_core.CV_64F;
                bytesPerElement = 8;
                break;
            // NumPy character codes: 'B' is unsigned byte, 'b' is signed byte.
            // Booleans are stored as one byte per element (0 or 1), so read them as uint8.
            case "uint8":
            case "u1":
            case "B":
            case "bool":
            case "b1":
            case "?":
                depth = opencv_core.CV_8U;
                bytesPerElement = 1;
                break;
            case "int8":
            case "i1":
            case "b":
                depth = opencv_core.CV_8S;
                bytesPerElement = 1;
                break;
            case "int16":
            case "i2":
            case "h":
                depth = opencv_core.CV_16S;
                bytesPerElement = 2;
                break;
            case "uint16":
            case "u2":
            case "H":
                depth = opencv_core.CV_16U;
                bytesPerElement = 2;
                break;
            case "int32":
            case "i4":
            case "i":
                depth = opencv_core.CV_32S;
                bytesPerElement = 4;
                break;
            default:
                throw new IOException("Unsupported data type " + descr);
        }

        // Read all data before allocating the Mat, so a truncated stream fails without a partial result
        int nBytes = computeByteCount(shape, bytesPerElement);
        ByteBuffer data = ByteBuffer.wrap(readExactly(stream, nBytes)).order(byteOrder);
        Mat mat = createMat(shape, depth);
        copyData(mat, depth, data);
        return mat;
    }

    /**
     * Read the npy header-length field. It is a 2-byte little-endian unsigned short for format version 1.x,
     * and a 4-byte little-endian unsigned int for version 2.x and later.
     */
    private static long readHeaderLength(InputStream stream, int majorVersion) throws IOException {
        if (majorVersion >= 2) {
            return ByteBuffer.wrap(readExactly(stream, 4)).order(ByteOrder.LITTLE_ENDIAN).getInt() & 0xFFFFFFFFL;
        }
        return ByteBuffer.wrap(readExactly(stream, 2)).order(ByteOrder.LITTLE_ENDIAN).getShort() & 0xFFFF;
    }

    /**
     * Parse the comma-separated contents of a shape tuple, e.g. "3, 4," into {3, 4}.
     * An empty string (0-d array) gives an empty array.
     */
    private static int[] parseShape(String shapeString) throws IOException {
        try {
            return Arrays.stream(shapeString.split(","))
                    .map(String::strip)
                    // Python 2 may write long integers with an 'L' suffix, e.g. (3L, 4L)
                    .map(s -> s.endsWith("L") || s.endsWith("l") ? s.substring(0, s.length() - 1) : s)
                    .filter(s -> !s.isEmpty())
                    .mapToInt(Integer::parseInt)
                    .toArray();
        } catch (NumberFormatException e) {
            throw new IOException("Unable to parse array shape (" + shapeString + ")", e);
        }
    }

    /**
     * Compute the number of data bytes for the given shape and element size.
     * Rejects negative dimensions and totals that do not fit in a Java byte array.
     */
    private static int computeByteCount(int[] shape, int bytesPerElement) throws IOException {
        long nBytes = bytesPerElement;
        for (int dim : shape) {
            if (dim < 0) {
                throw new IOException("Invalid array shape " + Arrays.toString(shape));
            }
            nBytes *= dim;
            // Both factors are <= Integer.MAX_VALUE here, so the long product cannot overflow before this check
            if (nBytes > Integer.MAX_VALUE) {
                throw new IOException("Array is too large to read, shape " + Arrays.toString(shape));
            }
        }
        return (int) nBytes;
    }

    /** Copy raw array bytes into the Mat's buffer, using the Java buffer view that matches the OpenCV depth. */
    private static void copyData(Mat mat, int depth, ByteBuffer data) {
        if (depth == opencv_core.CV_32F) {
            ((FloatBuffer) mat.createBuffer()).put(data.asFloatBuffer());
        } else if (depth == opencv_core.CV_64F) {
            ((DoubleBuffer) mat.createBuffer()).put(data.asDoubleBuffer());
        } else if (depth == opencv_core.CV_16S || depth == opencv_core.CV_16U) {
            ((ShortBuffer) mat.createBuffer()).put(data.asShortBuffer());
        } else if (depth == opencv_core.CV_32S) {
            ((IntBuffer) mat.createBuffer()).put(data.asIntBuffer());
        } else {
            // CV_8U and CV_8S: byte order is irrelevant
            ((ByteBuffer) mat.createBuffer()).put(data);
        }
    }

    /**
     * Create a zero-filled Mat for the given shape and OpenCV depth.
     * A 3-dimensional shape whose last dimension is below 512 becomes a 2D multichannel Mat.
     * Any other shape becomes an n-dimensional Mat.
     */
    private static Mat createMat(int[] shape, int depth) {
        if (shape.length == 3 && shape[2] > 0 && shape[2] < 512) {
            return new Mat(shape[0], shape[1], opencv_core.CV_MAKETYPE(depth, shape[2]), Scalar.ZERO);
        }
        return new Mat(shape, depth, Scalar.ZERO);
    }

    /** Return the stripped first capture group of the first match, or null if there is no match. */
    private static String getMatch(Pattern pattern, String string) {
        Matcher matcher = pattern.matcher(string);
        if (matcher.find()) {
            return matcher.group(1).strip();
        }
        return null;
    }

    /** Read a single byte, failing on end of stream rather than returning -1. */
    private static int readUnsignedByte(InputStream stream) throws IOException {
        int value = stream.read();
        if (value < 0) {
            throw new EOFException("Unexpected end of stream while reading npy header");
        }
        return value;
    }

    /** Read exactly {@code n} bytes, failing if the stream ends first. */
    private static byte[] readExactly(InputStream stream, int n) throws IOException {
        byte[] bytes = stream.readNBytes(n);
        if (bytes.length != n) {
            throw new EOFException("Expected " + n + " bytes but only " + bytes.length + " were available");
        }
        return bytes;
    }

    /** Kinds of file recognized from their leading bytes. */
    private static enum FileType {
        NPY,
        ZIP,
        UNKNOWN;

    }
}

