/*
 * Decompiled with CFR 0.152.
 */
package qupath.bioimageio.spec;

import com.google.gson.FieldNamingPolicy;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonDeserializationContext;
import com.google.gson.JsonDeserializer;
import com.google.gson.JsonElement;
import com.google.gson.JsonObject;
import com.google.gson.JsonParseException;
import com.google.gson.JsonPrimitive;
import com.google.gson.reflect.TypeToken;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Type;
import java.net.URI;
import java.nio.file.FileSystem;
import java.nio.file.FileSystems;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.yaml.snakeyaml.LoaderOptions;
import org.yaml.snakeyaml.Yaml;
import org.yaml.snakeyaml.constructor.BaseConstructor;
import org.yaml.snakeyaml.constructor.SafeConstructor;
import qupath.bioimageio.spec.Author;
import qupath.bioimageio.spec.Dataset;
import qupath.bioimageio.spec.FileDescription;
import qupath.bioimageio.spec.ModelParent;
import qupath.bioimageio.spec.Resource;
import qupath.bioimageio.spec.Weights;
import qupath.bioimageio.spec.tensor.BaseTensor;
import qupath.bioimageio.spec.tensor.InputTensor;
import qupath.bioimageio.spec.tensor.OutputTensor;
import qupath.bioimageio.spec.tensor.Tensors;

/**
 * A model resource description, read from a {@code model.yaml}/{@code rdf.yaml} file.
 * <p>
 * Parsing loads the YAML with SnakeYAML (using a {@link SafeConstructor}), converts the resulting
 * map to JSON, and then deserializes that JSON with Gson using the custom deserializers registered
 * in {@link #createGson()}.
 */
public class Model extends Resource {

    private static final Logger logger = LoggerFactory.getLogger(Model.class);

    /** Base URI; {@link #parse(Path)} sets this to the URI of the directory containing the rdf file. */
    private URI baseURI;
    /** URI of the rdf file itself; {@link #parse(Path)} sets this to the rdf file's URI. */
    private URI uri;
    private List<InputTensor> inputs;
    private List<String> testInputs;
    private List<String> testOutputs;
    private String timestamp;
    private Weights.WeightsMap weights;
    private Map<?, ?> config;
    private List<OutputTensor> outputs;
    private List<Author> packagedBy;
    private ModelParent parent;
    private Map<?, ?> runMode;
    private List<String> sampleInputs;
    private List<String> sampleOutputs;

    /** Candidate rdf file names, checked in this order when a directory holds several yaml files. */
    static final List<String> MODEL_NAMES = List.of("model.yaml", "model.yml", "rdf.yaml", "rdf.yml");

    /**
     * Sets the base URI returned by {@link #getBaseURI()}.
     *
     * @param baseURI the base URI (may be null)
     */
    public void setBaseURI(URI baseURI) {
        this.baseURI = baseURI;
    }

    /**
     * Sets the URI of the rdf file returned by {@link #getURI()}.
     *
     * @param uri the rdf URI (may be null)
     */
    public void setUri(URI uri) {
        this.uri = uri;
    }

    /**
     * Parses a model from a file, directory or zip; equivalent to {@code parse(file.toPath())}.
     *
     * @param file a yaml file, a directory containing one, or a {@code .zip} file
     * @return the parsed model, or null if the rdf content was empty/null
     * @throws IOException if no rdf can be found or it cannot be read/parsed
     */
    public static Model parse(File file) throws IOException {
        return parse(file.toPath());
    }

    /**
     * Parses a model from a path.
     * <p>
     * The rdf is located with {@link #findModelRdf(Path)}. On success, the model's base URI is set
     * to the rdf's parent directory and its URI to the rdf file itself. If locating the rdf required
     * opening a zip {@link FileSystem}, that file system is closed again once the rdf has been read.
     *
     * @param path a yaml file, a directory containing one, or a {@code .zip} file
     * @return the parsed model, or null if the rdf content was empty/null
     * @throws IOException if no rdf can be found or it cannot be read/parsed
     */
    public static Model parse(Path path) throws IOException {
        Path pathYaml = findModelRdf(path);
        if (pathYaml == null) {
            throw new IOException("Can't find rdf.yaml from " + path);
        }
        // findModelRdf only returns a path on a *different* file system when it opened a new zip
        // file system itself; in that case we own it and must close it to avoid leaking the handle.
        // A path on a caller-supplied file system is never closed here.
        boolean ownsFileSystem = pathYaml.getFileSystem() != path.getFileSystem();
        try (InputStream stream = Files.newInputStream(pathYaml)) {
            Model model = parse(stream);
            if (model != null) {
                // toAbsolutePath() so a bare relative file name still has a parent directory
                Path parentDir = pathYaml.toAbsolutePath().getParent();
                if (parentDir != null) {
                    model.setBaseURI(parentDir.toUri());
                }
                model.setUri(pathYaml.toUri());
            }
            return model;
        } finally {
            if (ownsFileSystem) {
                pathYaml.getFileSystem().close();
            }
        }
    }

    /**
     * Parses a model from a stream of YAML.
     *
     * @param stream the YAML input (not closed by this method)
     * @return the parsed model, or null if the YAML document was empty/null
     * @throws IOException wrapping any exception thrown while loading or deserializing
     */
    public static Model parse(InputStream stream) throws IOException {
        try {
            // SafeConstructor restricts SnakeYAML to standard types, so the input cannot
            // instantiate arbitrary Java classes.
            Yaml yaml = new Yaml((BaseConstructor) new SafeConstructor(new LoaderOptions()));
            Map<?, ?> map = (Map<?, ?>) yaml.load(stream);
            Gson gson = createGson();
            String json = gson.toJson(map);
            return gson.fromJson(json, Model.class);
        } catch (Exception e) {
            throw new IOException(e);
        }
    }

    /**
     * Creates the Gson instance used for model deserialization, with all custom deserializers
     * (including those from {@link Tensors#getDeserializers()}) registered.
     */
    private static Gson createGson() {
        GsonBuilder builder = new GsonBuilder()
                .serializeSpecialFloatingPointValues()
                .setPrettyPrinting()
                .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES)
                .registerTypeAdapter(Model.class, new Deserializer())
                .registerTypeAdapter(Resource.class, new Resource.Deserializer())
                .registerTypeAdapter(Dataset.class, new Dataset.Deserializer())
                .registerTypeAdapter(Author.class, new Author.Deserializer())
                .registerTypeAdapter(Weights.WeightsEntry.class, new Weights.WeightsEntryDeserializer())
                .registerTypeAdapter(Weights.WeightsMap.class, new Weights.WeightsMapDeserializer())
                .registerTypeAdapter(double[].class, new DoubleArrayDeserializer())
                .setDateFormat("yyyy-MM-dd'T'HH:mm:ss");
        Map<Class<?>, JsonDeserializer<?>> deserializers = Tensors.getDeserializers();
        for (Map.Entry<Class<?>, JsonDeserializer<?>> entry : deserializers.entrySet()) {
            builder.registerTypeAdapter(entry.getKey(), entry.getValue());
        }
        return builder.create();
    }

    /**
     * Returns the base URI of the model.
     * <p>
     * If no base URI was set explicitly, this falls back to the directory containing {@link #getURI()}.
     * That is consistent with what {@link #parse(Path)} stores.
     *
     * @return the base URI, or null if neither a base URI nor a URI is known
     */
    public URI getBaseURI() {
        if (this.baseURI != null) {
            return this.baseURI;
        }
        if (this.uri != null) {
            // "." resolves against a file URI to its containing directory ("..", as used previously,
            // skipped one level too far and returned the grandparent directory).
            return this.uri.resolve(".");
        }
        return null;
    }

    /**
     * Returns the URI of the rdf file, if known.
     *
     * @return the rdf URI, or null
     */
    public URI getURI() {
        return this.uri;
    }

    /**
     * Returns all weights, keyed by their string key.
     *
     * @return the weights map, or an empty map if none were parsed
     */
    public Map<String, Weights.ModelWeights> getWeights() {
        return this.weights == null ? Collections.emptyMap() : this.weights.withStringKeys();
    }

    /**
     * Returns the weights for a specific entry.
     *
     * @param key the weights entry (may be null)
     * @return the weights, or null if {@code key} is null, no weights were parsed, or the entry is absent
     */
    public Weights.ModelWeights getWeights(Weights.WeightsEntry key) {
        if (this.weights == null || key == null) {
            return null;
        }
        return this.weights.getMap().get(key);
    }

    /**
     * Returns the weights for the entry identified by a string key.
     *
     * @param key the key, converted with {@link Weights.WeightsEntry#fromKey(String)}
     * @return the weights, or null if not found
     */
    public Weights.ModelWeights getWeights(String key) {
        return this.getWeights(Weights.WeightsEntry.fromKey(key));
    }

    /** @return the parent model description, or null if none was given */
    public ModelParent getParent() {
        return this.parent;
    }

    /** @return the timestamp string (set to {@code ""} for newer formats during deserialization) */
    public String getTimestamp() {
        return this.timestamp;
    }

    /** @return the (unmodifiable) config map; empty if none was given */
    public Map<?, ?> getConfig() {
        return this.config;
    }

    /** @return an unmodifiable list of input tensors, never null */
    public List<InputTensor> getInputs() {
        return toUnmodifiableList(this.inputs);
    }

    /** @return an unmodifiable list of output tensors, never null */
    public List<OutputTensor> getOutputs() {
        return toUnmodifiableList(this.outputs);
    }

    /**
     * Returns the test input sources. If none were declared and the format is newer than 0.5, the
     * sources of each input tensor's test tensor are returned instead.
     *
     * @return an unmodifiable list, never null
     */
    public List<String> getTestInputs() {
        return declaredOrTensorSources(this.testInputs, this.inputs);
    }

    /**
     * Returns the test output sources. If none were declared and the format is newer than 0.5, the
     * sources of each output tensor's test tensor are returned instead.
     *
     * @return an unmodifiable list, never null
     */
    public List<String> getTestOutputs() {
        return declaredOrTensorSources(this.testOutputs, this.outputs);
    }

    /**
     * Returns the sample input sources. If none were declared and the format is newer than 0.5, the
     * sources of each input tensor's <em>test</em> tensor are returned instead.
     *
     * @return an unmodifiable list, never null
     */
    public List<String> getSampleInputs() {
        // NOTE(review): falls back to test tensors, not sample tensors — confirm this is intended.
        return declaredOrTensorSources(this.sampleInputs, this.inputs);
    }

    /**
     * Returns the sample output sources. If none were declared and the format is newer than 0.5, the
     * sources of each output tensor's <em>test</em> tensor are returned instead.
     *
     * @return an unmodifiable list, never null
     */
    public List<String> getSampleOutputs() {
        // NOTE(review): falls back to test tensors, not sample tensors — confirm this is intended.
        return declaredOrTensorSources(this.sampleOutputs, this.outputs);
    }

    /**
     * Returns {@code declared} if it is non-empty. Otherwise, for formats newer than 0.5, falls back
     * to the test-tensor sources of {@code tensors}. Null inputs are treated as empty.
     */
    private List<String> declaredOrTensorSources(List<String> declared, List<? extends BaseTensor> tensors) {
        List<String> result = declared;
        // NOTE(review): this checks "0.5" while deserializeModelFields checks "0.5.0"; presumably
        // equivalent — confirm against Resource.isFormatNewerThan.
        if ((result == null || result.isEmpty()) && this.isFormatNewerThan("0.5")) {
            result = testTensorSources(tensors);
        }
        return toUnmodifiableList(result);
    }

    /** Maps each tensor to its test tensor's source, using {@link FileDescription#NULL_FILE} when absent. */
    private static List<String> testTensorSources(List<? extends BaseTensor> tensors) {
        if (tensors == null) {
            return Collections.emptyList();
        }
        return tensors.stream()
                .map(BaseTensor::getTestTensor)
                .map(ofd -> ofd.orElse(FileDescription.NULL_FILE).source())
                .collect(Collectors.toList());
    }

    /**
     * Wraps a list as unmodifiable.
     *
     * @param list the list (may be null)
     * @return an empty list if {@code list} is null or empty, otherwise an unmodifiable view of it
     */
    public static <T> List<T> toUnmodifiableList(List<T> list) {
        return list == null || list.isEmpty() ? Collections.emptyList() : Collections.unmodifiableList(list);
    }

    /**
     * Deserializes a field, optionally requiring it to be present.
     *
     * @param context the Gson context
     * @param obj the object containing the field
     * @param name the field name
     * @param typeOfT the target type
     * @param doStrict if true, a missing field is an error
     * @return the value (List/Map/Set results are unmodifiable), or null if absent and not strict
     * @throws IllegalArgumentException if {@code doStrict} is true and the field is missing
     */
    public static <T> T deserializeField(JsonDeserializationContext context, JsonObject obj, String name, Type typeOfT, boolean doStrict) throws IllegalArgumentException {
        if (doStrict && !obj.has(name)) {
            throw new IllegalArgumentException("Required field " + name + " not found");
        }
        return deserializeField(context, obj, name, typeOfT, (T) null);
    }

    /**
     * Deserializes a field, falling back to a default value if it is missing.
     *
     * @param context the Gson context
     * @param obj the object containing the field
     * @param name the field name
     * @param typeOfT the target type
     * @param defaultValue the value used when the field is missing
     * @return the value or default; List/Map/Set results are wrapped as unmodifiable
     */
    public static <T> T deserializeField(JsonDeserializationContext context, JsonObject obj, String name, Type typeOfT, T defaultValue) {
        if (obj.has(name)) {
            T value = context.deserialize(obj.get(name), typeOfT);
            return ensureUnmodifiable(value);
        }
        return ensureUnmodifiable(defaultValue);
    }

    /**
     * Finds a model rdf from a yaml file, a directory containing one, or a {@code .zip} file.
     * <p>
     * For a zip, the returned path belongs to a newly opened {@link FileSystem}, and the caller is
     * responsible for closing it.
     *
     * @param path the path to search
     * @return the rdf path, or null if none was found
     * @throws IOException if a directory cannot be listed or a zip cannot be opened
     */
    static Path findModelRdf(Path path) throws IOException {
        return findRdf(path, MODEL_NAMES);
    }

    /**
     * Finds an rdf file.
     * <ul>
     *   <li>A yaml file is returned if {@code names} is empty, if it matches one of {@code names}, or
     *       if its name starts with "model" or "rdf".</li>
     *   <li>For a directory, its only yaml file is returned, or otherwise the first yaml file (in
     *       {@code names} order) whose name matches case-insensitively.</li>
     *   <li>For a {@code .zip}, each root directory is checked for an entry named in {@code names}.
     *       The zip file system is closed again if nothing is found.</li>
     * </ul>
     */
    private static Path findRdf(Path path, Collection<String> names) throws IOException {
        if (isYamlPath(path)) {
            if (names.isEmpty()) {
                return path;
            }
            String name = path.getFileName().toString().toLowerCase(Locale.ROOT);
            boolean accepted = names.contains(name) || name.startsWith("model") || name.startsWith("rdf");
            return accepted ? path : null;
        }
        if (Files.isDirectory(path)) {
            try (Stream<Path> pathStream = Files.list(path)) {
                List<Path> yamlFiles = pathStream.filter(Model::isYamlPath).toList();
                if (yamlFiles.isEmpty()) {
                    return null;
                }
                if (yamlFiles.size() == 1) {
                    return yamlFiles.get(0);
                }
                for (String name : names) {
                    for (Path candidate : yamlFiles) {
                        if (candidate.getFileName().toString().equalsIgnoreCase(name)) {
                            return candidate;
                        }
                    }
                }
            }
        }
        if (path.toAbsolutePath().toString().toLowerCase(Locale.ROOT).endsWith(".zip")) {
            FileSystem fs = FileSystems.newFileSystem(path, (ClassLoader) null);
            for (Path dir : fs.getRootDirectories()) {
                for (String name : names) {
                    Path candidate = dir.resolve(name);
                    if (Files.exists(candidate)) {
                        // fs must stay open: the returned path is only usable while it is
                        return candidate;
                    }
                }
            }
            // Nothing found: release the zip handle rather than leaking it
            fs.close();
        }
        return null;
    }

    /**
     * Checks whether a path is a regular file with a {@code .yaml} or {@code .yml} extension
     * (case-insensitive).
     */
    static boolean isYamlPath(Path path) {
        if (!Files.isRegularFile(path)) {
            return false;
        }
        String name = path.getFileName().toString().toLowerCase(Locale.ROOT);
        return name.endsWith(".yaml") || name.endsWith(".yml");
    }

    /** Returns the type {@code List<typeOfList>}, for use with Gson. */
    static Type parameterizedListType(Type typeOfList) {
        return TypeToken.getParameterized(List.class, typeOfList).getType();
    }

    /** Wraps List, Map and Set values as unmodifiable views; other values are returned unchanged. */
    @SuppressWarnings("unchecked")
    private static <T> T ensureUnmodifiable(T input) {
        if (input instanceof List<?> list) {
            return (T) Collections.unmodifiableList(list);
        }
        if (input instanceof Map<?, ?> map) {
            return (T) Collections.unmodifiableMap(map);
        }
        if (input instanceof Set<?> set) {
            return (T) Collections.unmodifiableSet(set);
        }
        return input;
    }

    /**
     * Populates the model-specific fields of {@code model} from {@code obj}, validating each tensor
     * against the tensors deserialized before it (inputs see inputs; outputs see inputs and outputs).
     */
    private static void deserializeModelFields(Model model, JsonObject obj, JsonDeserializationContext context, boolean doStrict) {
        model.inputs = deserializeField(context, obj, "inputs", parameterizedListType(InputTensor.class), doStrict);
        if (model.inputs == null) {
            model.inputs = Collections.emptyList();
        }
        List<BaseTensor> tensors = new ArrayList<>(model.inputs);
        for (InputTensor inputTensor : model.inputs) {
            inputTensor.validate(tensors);
        }

        // Newer formats don't read test_inputs/test_outputs/timestamp here; the test getters fall
        // back to per-tensor test tensors instead.
        if (model.isFormatNewerThan("0.5.0")) {
            model.testInputs = List.of();
            model.testOutputs = List.of();
            model.timestamp = "";
        } else {
            model.testInputs = deserializeField(context, obj, "test_inputs", parameterizedListType(String.class), doStrict);
            model.testOutputs = deserializeField(context, obj, "test_outputs", parameterizedListType(String.class), doStrict);
            model.timestamp = deserializeField(context, obj, "timestamp", String.class, doStrict);
        }

        model.weights = deserializeField(context, obj, "weights", Weights.WeightsMap.class, doStrict);
        model.config = deserializeField(context, obj, "config", Map.class, Collections.emptyMap());

        model.outputs = deserializeField(context, obj, "outputs", parameterizedListType(OutputTensor.class), Collections.emptyList());
        if (model.outputs == null) {
            model.outputs = Collections.emptyList();
        }
        tensors.addAll(model.outputs);
        for (OutputTensor outputTensor : model.outputs) {
            outputTensor.validate(tensors);
        }

        model.packagedBy = deserializeField(context, obj, "packaged_by", parameterizedListType(Author.class), Collections.emptyList());
        model.parent = deserializeField(context, obj, "parent", ModelParent.class, null);
        if (obj.has("run_mode")) {
            if (obj.get("run_mode").isJsonObject()) {
                model.runMode = deserializeField(context, obj, "run_mode", Map.class, Collections.emptyMap());
            } else {
                logger.warn("Can't parse run_mode (not an object)");
            }
        }
        model.sampleInputs = deserializeField(context, obj, "sample_inputs", parameterizedListType(String.class), Collections.emptyList());
        model.sampleOutputs = deserializeField(context, obj, "sample_outputs", parameterizedListType(String.class), Collections.emptyList());
    }

    /** Gson deserializer for {@link Model}: reads the resource fields, then the model fields (strictly). */
    static class Deserializer implements JsonDeserializer<Model> {

        @Override
        public Model deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException {
            if (json.isJsonNull()) {
                return null;
            }
            Model model = new Model();
            JsonObject obj = json.getAsJsonObject();
            Resource.deserializeResourceFields(model, obj, context, true);
            deserializeModelFields(model, obj, context, true);
            return model;
        }
    }

    /**
     * Gson deserializer for {@code double[]}.
     * <p>
     * Null elements become NaN (with a warning). Numbers are used as-is. Strings such as
     * {@code "inf"}, {@code "-inf"}, {@code "nan"} (case-insensitive) are mapped to the corresponding
     * special values, and other strings are parsed with {@link Double#parseDouble(String)}.
     */
    static class DoubleArrayDeserializer implements JsonDeserializer<double[]> {

        private static final Logger logger = LoggerFactory.getLogger(DoubleArrayDeserializer.class);

        @Override
        public double[] deserialize(JsonElement json, Type typeOfT, JsonDeserializationContext context) throws JsonParseException {
            if (json.isJsonNull()) {
                return null;
            }
            if (!json.isJsonArray()) {
                throw new JsonParseException("Can't parse data range from " + json);
            }
            var array = json.getAsJsonArray();
            double[] values = new double[array.size()];
            for (int i = 0; i < values.length; i++) {
                values[i] = parseValue(array.get(i));
            }
            return values;
        }

        /** Converts a single array element to a double. */
        private static double parseValue(JsonElement jsonVal) {
            if (jsonVal.isJsonNull()) {
                logger.warn("Found null when expecting a double - will replace with NaN");
                return Double.NaN;
            }
            if (!jsonVal.isJsonPrimitive()) {
                throw new JsonParseException("Expected a number but found " + jsonVal);
            }
            JsonPrimitive jsonPrimitive = jsonVal.getAsJsonPrimitive();
            if (jsonPrimitive.isNumber()) {
                return jsonPrimitive.getAsDouble();
            }
            String s = jsonPrimitive.getAsString().strip();
            return switch (s.toLowerCase(Locale.ROOT)) {
                case "inf", "+inf", "infinity", "+infinity" -> Double.POSITIVE_INFINITY;
                case "-inf", "-infinity" -> Double.NEGATIVE_INFINITY;
                case "nan" -> Double.NaN;
                default -> parseDoubleString(s);
            };
        }

        /** Parses a numeric string, reporting failures as {@link JsonParseException}. */
        private static double parseDoubleString(String s) {
            try {
                return Double.parseDouble(s);
            } catch (NumberFormatException e) {
                throw new JsonParseException("Can't parse double from '" + s + "'", e);
            }
        }
    }
}

