/*
 * Decompiled with CFR 0.152.
 */
package qupath.ext.djl;

import ai.djl.engine.Engine;
import ai.djl.ndarray.types.LayoutType;
import java.net.URI;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.util.List;
import java.util.Locale;
import qupath.ext.djl.DjlDnnModel;
import qupath.ext.djl.DjlTools;
import qupath.lib.common.GeneralTools;
import qupath.opencv.dnn.DnnModel;
import qupath.opencv.dnn.DnnModelBuilder;
import qupath.opencv.dnn.DnnModelParams;

/**
 * {@link DnnModelBuilder} that creates {@link DjlDnnModel} instances.
 * <p>
 * The engine is taken from the framework name stored in the model parameters when one
 * is given; otherwise it is guessed from the first model URI. {@link #buildModel} returns
 * {@code null} whenever no engine can be determined or the chosen engine is not available
 * according to {@link Engine#hasEngine(String)}.
 */
public class DjlDnnModelBuilder implements DnnModelBuilder {

    /** File name looked for when checking whether a path refers to a TensorFlow saved model. */
    private static final String SAVED_MODEL_FILE = "saved_model.pb";

    /**
     * Translate a framework name into an engine name.
     *
     * @param framework the requested framework name; must not be {@code null}
     * @return {@code framework} itself if it is contained in {@code DjlTools.ALL_ENGINES},
     *         the matching {@code DjlTools.ENGINE_*} constant if it is one of the aliases
     *         handled below, or {@code null} if the name is not recognized
     */
    private static String getEngineName(String framework) {
        if (DjlTools.ALL_ENGINES.contains(framework)) {
            return framework;
        }
        // Aliases are matched case-sensitively.
        switch (framework) {
            case "TensorFlow":
                return DjlTools.ENGINE_TENSORFLOW;
            case "TFLite":
                return DjlTools.ENGINE_TFLITE;
            case "OnnxRuntime":
                return DjlTools.ENGINE_ONNX_RUNTIME;
            case "PyTorch":
                return DjlTools.ENGINE_PYTORCH;
            case "MxNet":
                return DjlTools.ENGINE_MXNET;
            default:
                return null;
        }
    }

    /**
     * Guess which engine should be used for a model, based on its URI.
     * <p>
     * The checks are made in this order, and an engine is only returned if
     * {@link Engine#hasEngine(String)} reports it as available:
     * <ol>
     *   <li>URI ends with {@code .onnx}: ONNX Runtime</li>
     *   <li>URI ends with {@code pytorch} or {@code .pt}: PyTorch</li>
     *   <li>URI ends with {@code .tflite}: TensorFlow Lite</li>
     *   <li>URI converts to a path that is a {@code saved_model.pb} file, or a directory
     *       containing one: TensorFlow</li>
     * </ol>
     * Suffix comparisons ignore case.
     *
     * @param uri the model URI; must not be {@code null}
     * @return the engine name, or {@code null} if none of the checks succeeds
     */
    private static String estimateEngine(URI uri) {
        // Locale.ROOT keeps the lowercasing independent of the JVM default locale
        // (e.g. a Turkish locale would turn "I" into a dotless "ı" and break ".TFLITE").
        String urlString = uri.toString().toLowerCase(Locale.ROOT);
        if (urlString.endsWith(".onnx") && Engine.hasEngine(DjlTools.ENGINE_ONNX_RUNTIME)) {
            return DjlTools.ENGINE_ONNX_RUNTIME;
        }
        if ((urlString.endsWith("pytorch") || urlString.endsWith(".pt"))
                && Engine.hasEngine(DjlTools.ENGINE_PYTORCH)) {
            return DjlTools.ENGINE_PYTORCH;
        }
        if (urlString.endsWith(".tflite") && Engine.hasEngine(DjlTools.ENGINE_TFLITE)) {
            return DjlTools.ENGINE_TFLITE;
        }
        // The conversion result may be null, in which case the TensorFlow check is skipped.
        Path path = GeneralTools.toPath(uri);
        if (path != null && isTensorFlowSavedModel(path) && Engine.hasEngine(DjlTools.ENGINE_TENSORFLOW)) {
            return DjlTools.ENGINE_TENSORFLOW;
        }
        return null;
    }

    /**
     * Check whether a path is a {@code saved_model.pb} file, or a directory that contains one.
     *
     * @param path the path to test; must not be {@code null}
     * @return {@code true} if the path looks like a TensorFlow saved model
     */
    private static boolean isTensorFlowSavedModel(Path path) {
        // getFileName() returns null for a path with no name elements (e.g. a root directory).
        Path fileName = path.getFileName();
        if (fileName != null && SAVED_MODEL_FILE.equals(fileName.toString())) {
            return true;
        }
        return Files.isDirectory(path) && Files.exists(path.resolve(SAVED_MODEL_FILE));
    }

    /**
     * Map a single (lowercase) axis character to a {@link LayoutType}.
     *
     * @param c the axis character
     * @return the layout type for {@code b}, {@code t}, {@code c}, {@code z}, {@code y} or
     *         {@code x}; {@link LayoutType#UNKNOWN} for {@code i} and for any other character
     */
    private static LayoutType getLayout(char c) {
        switch (c) {
            case 'b':
                return LayoutType.BATCH;
            case 't':
                return LayoutType.TIME;
            case 'c':
                return LayoutType.CHANNEL;
            case 'z':
                return LayoutType.DEPTH;
            case 'y':
                return LayoutType.HEIGHT;
            case 'x':
                return LayoutType.WIDTH;
            case 'i':
            default:
                return LayoutType.UNKNOWN;
        }
    }

    /**
     * Convert an axes string into a layout string, one output character per input character.
     * <p>
     * The input is stripped of surrounding whitespace and lowercased before each character
     * is passed through {@link #getLayout(char)} and replaced by that layout type's value.
     *
     * @param axes the axes string, may be {@code null}
     * @return the layout string, or {@code null} if {@code axes} is {@code null}
     */
    private static String axesToLayout(String axes) {
        if (axes == null) {
            return null;
        }
        // Locale.ROOT: axis letters must lowercase the same way under every default locale.
        String normalized = axes.strip().toLowerCase(Locale.ROOT);
        StringBuilder sb = new StringBuilder(normalized.length());
        for (char c : normalized.toCharArray()) {
            sb.append(getLayout(c).getValue());
        }
        return sb.toString();
    }

    /**
     * Build a {@link DjlDnnModel} from the given parameters.
     * <p>
     * If the parameters name a framework, it is resolved with {@link #getEngineName(String)};
     * otherwise the engine is estimated from the first URI.
     *
     * @param params the model parameters
     * @return the new model, or {@code null} if no framework is given and there are no URIs,
     *         if no engine can be determined, or if the engine is not available
     */
    public DnnModel buildModel(DnnModelParams params) {
        String framework = params.getFramework();
        String engineName;
        if (framework == null) {
            List<URI> uris = params.getURIs();
            if (uris.isEmpty()) {
                return null;
            }
            engineName = estimateEngine(uris.get(0));
        } else {
            engineName = getEngineName(framework);
        }
        if (engineName == null || !Engine.hasEngine(engineName)) {
            return null;
        }
        // axesToLayout passes null through unchanged, so no separate null check is needed.
        String layout = axesToLayout(params.getLayout());
        return new DjlDnnModel(engineName, params.getURIs(), layout, params.getInputs(), params.getOutputs(), params.requestLazyInitialize());
    }
}

