/*
 * Decompiled with CFR 0.152.
 */
package qupath.opencv.ml;

import com.google.gson.TypeAdapter;
import com.google.gson.annotations.JsonAdapter;
import com.google.gson.stream.JsonReader;
import com.google.gson.stream.JsonWriter;
import java.io.IOException;
import java.nio.IntBuffer;
import java.util.Arrays;
import java.util.Locale;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import org.bytedeco.javacpp.indexer.DoubleIndexer;
import org.bytedeco.javacpp.indexer.FloatIndexer;
import org.bytedeco.javacpp.indexer.Indexer;
import org.bytedeco.javacpp.indexer.IntIndexer;
import org.bytedeco.opencv.global.opencv_core;
import org.bytedeco.opencv.opencv_core.Mat;
import org.bytedeco.opencv.opencv_core.Scalar;
import org.bytedeco.opencv.opencv_core.TermCriteria;
import org.bytedeco.opencv.opencv_core.UMat;
import org.bytedeco.opencv.opencv_ml.ANN_MLP;
import org.bytedeco.opencv.opencv_ml.Boost;
import org.bytedeco.opencv.opencv_ml.DTrees;
import org.bytedeco.opencv.opencv_ml.EM;
import org.bytedeco.opencv.opencv_ml.KNearest;
import org.bytedeco.opencv.opencv_ml.LogisticRegression;
import org.bytedeco.opencv.opencv_ml.NormalBayesClassifier;
import org.bytedeco.opencv.opencv_ml.RTrees;
import org.bytedeco.opencv.opencv_ml.SVM;
import org.bytedeco.opencv.opencv_ml.SVMSGD;
import org.bytedeco.opencv.opencv_ml.StatModel;
import org.bytedeco.opencv.opencv_ml.TrainData;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import qupath.lib.common.GeneralTools;
import qupath.lib.plugins.parameters.ParameterList;
import qupath.opencv.io.OpenCVTypeAdapters;
import qupath.opencv.tools.OpenCVTools;

public class OpenCVClassifiers {
    private static final Logger logger = LoggerFactory.getLogger(OpenCVClassifiers.class);

    /**
     * Creates a fresh wrapper for the requested OpenCV model type.
     * <p>
     * Matching is by exact class identity, so a subclass of a supported type is rejected.
     *
     * @param cls the {@link StatModel} class for which a wrapper is wanted
     * @return a new wrapper built with its no-argument constructor
     * @throws IllegalArgumentException if {@code cls} is {@code null} or is not one of the classes listed here
     */
    public static OpenCVStatModel createStatModel(Class<? extends StatModel> cls) {
        OpenCVStatModel wrapper;
        if (cls == RTrees.class) {
            wrapper = new RTreesClassifier();
        } else if (cls == Boost.class) {
            wrapper = new BoostClassifier();
        } else if (cls == DTrees.class) {
            wrapper = new DTreesClassifier();
        } else if (cls == KNearest.class) {
            wrapper = new KNearestClassifierCV();
        } else if (cls == ANN_MLP.class) {
            wrapper = new ANNClassifierCV();
        } else if (cls == LogisticRegression.class) {
            wrapper = new LogisticRegressionClassifier();
        } else if (cls == EM.class) {
            wrapper = new EMClusterer();
        } else if (cls == NormalBayesClassifier.class) {
            wrapper = new NormalBayesClassifierCV();
        } else if (cls == SVM.class) {
            wrapper = new SVMClassifierCV();
        } else if (cls == SVMSGD.class) {
            wrapper = new SVMSGDClassifierCV();
        } else {
            throw new IllegalArgumentException("Unknown StatModel class " + cls);
        }
        return wrapper;
    }

    /**
     * Wraps an existing OpenCV model in the matching {@link OpenCVStatModel} implementation.
     * <p>
     * The runtime class of the model must be exactly one of the supported classes; the check
     * is deliberately not {@code instanceof}-based, so subclasses are rejected.
     *
     * @param statModel the model to wrap; must not be {@code null}
     * @return a wrapper holding {@code statModel}
     * @throws NullPointerException if {@code statModel} is {@code null}
     * @throws IllegalArgumentException if the runtime class of {@code statModel} is not supported
     */
    public static OpenCVStatModel wrapStatModel(StatModel statModel) {
        Class<?> modelClass = statModel.getClass();
        OpenCVStatModel wrapper;
        if (modelClass == RTrees.class) {
            wrapper = new RTreesClassifier((RTrees)statModel);
        } else if (modelClass == Boost.class) {
            wrapper = new BoostClassifier((Boost)statModel);
        } else if (modelClass == DTrees.class) {
            wrapper = new DTreesClassifier((DTrees)statModel);
        } else if (modelClass == KNearest.class) {
            wrapper = new KNearestClassifierCV((KNearest)statModel);
        } else if (modelClass == ANN_MLP.class) {
            wrapper = new ANNClassifierCV((ANN_MLP)statModel);
        } else if (modelClass == LogisticRegression.class) {
            wrapper = new LogisticRegressionClassifier((LogisticRegression)statModel);
        } else if (modelClass == EM.class) {
            wrapper = new EMClusterer((EM)statModel);
        } else if (modelClass == NormalBayesClassifier.class) {
            wrapper = new NormalBayesClassifierCV((NormalBayesClassifier)statModel);
        } else if (modelClass == SVM.class) {
            wrapper = new SVMClassifierCV((SVM)statModel);
        } else if (modelClass == SVMSGD.class) {
            wrapper = new SVMSGDClassifierCV((SVMSGD)statModel);
        } else {
            throw new IllegalArgumentException("Unknown StatModel class " + modelClass);
        }
        return wrapper;
    }

    /**
     * Appends a "Termination criteria" section to a parameter list, containing an iteration
     * limit ({@code termIterations}) and an epsilon ({@code termEpsilon}).
     *
     * @param params the list to which the title and the two parameters are added
     * @param defaultCriteria supplies the default values shown for both parameters
     */
    static void addTerminationCriteriaParameters(ParameterList params, TermCriteria defaultCriteria) {
        params.addTitleParameter("Termination criteria");

        int defaultMaxCount = defaultCriteria.maxCount();
        params.addIntParameter("termIterations", "Max iterations", defaultMaxCount, null, "Maximum number of iterations for training");

        double defaultEpsilon = defaultCriteria.epsilon();
        params.addDoubleParameter("termEpsilon", "Epsilon", defaultEpsilon, null, "Desired accuracy for training");
    }

    /**
     * Builds termination criteria from the {@code termIterations} and {@code termEpsilon}
     * parameters, reusing the supplied criteria when nothing has changed.
     * <p>
     * The type flags are assembled the same way as in {@code RTreesClassifier.updateModel}:
     * 1 is added when an iteration limit is set and 2 when a positive epsilon is set
     * (presumably {@code TermCriteria.COUNT} and {@code TermCriteria.EPS} - confirm against
     * the OpenCV constants). Previously the epsilon flag was gated on the iteration count,
     * so it never reflected the epsilon value.
     *
     * @param params parameter list containing {@code termIterations} and {@code termEpsilon}
     * @param termCriteria existing criteria, or {@code null} if there are none
     * @return {@code termCriteria} itself when its count and epsilon already match the
     *         parameters, otherwise a new {@link TermCriteria}
     */
    static TermCriteria updateTermCriteria(ParameterList params, TermCriteria termCriteria) {
        int termIterations = params.getIntParameterValue("termIterations");
        double termEpsilon = params.getDoubleParameterValue("termEpsilon");
        if (termCriteria != null && termCriteria.maxCount() == termIterations && termCriteria.epsilon() == termEpsilon) {
            return termCriteria;
        }
        int type = 0;
        if (termIterations >= 1) {
            // Iteration limit requested
            type += 1;
        }
        if (termEpsilon > 0.0) {
            // Accuracy threshold requested (was wrongly tested against termIterations)
            type += 2;
        }
        return new TermCriteria(type, termIterations, termEpsilon);
    }

    /**
     * Wrapper around OpenCV's {@link RTrees} model.
     * <p>
     * In addition to the tree parameters from {@link AbstractTreeClassifier}, this exposes the
     * active variable count, the maximum number of trees, a termination epsilon and optional
     * calculation of variable importance. When probabilities are requested, predictions are
     * derived from the per-class vote counts returned by {@code RTrees.getVotes}.
     */
    public static class RTreesClassifier
    extends AbstractTreeClassifier<RTrees> {
        /** Importance values copied from the model after the last training, or null if not calculated. */
        private double[] featureImportance;

        RTreesClassifier() {
        }

        RTreesClassifier(RTrees model) {
            super(model);
        }

        /**
         * Creates a new model with max depth 0 and termination criteria limited to 50 iterations.
         */
        @Override
        RTrees createStatModel() {
            RTrees model = RTrees.create();
            model.setMaxDepth(0);
            model.setTermCriteria(new TermCriteria(1, 50, 0.0));
            return model;
        }

        /**
         * Extends the tree parameters with the random-trees options, using the model's current
         * values as defaults.
         */
        @Override
        ParameterList createParameterList(RTrees model) {
            ParameterList params = super.createParameterList(model);
            int activeVarCount = model.getActiveVarCount();
            TermCriteria termCrit = model.getTermCriteria();
            int maxTrees = termCrit.maxCount();
            double epsilon = termCrit.epsilon();
            boolean calcImportance = model.getCalculateVarImportance();
            params.addIntParameter("activeVarCount", "Active variable count", activeVarCount, null, "Number of features per tree node (if <=0, will use square root of number of features)");
            params.addIntParameter("maxTrees", "Maximum number of trees", maxTrees, null, "Maximum possible number of trees - but fewer may be used if 'Termination epsilon' is high");
            params.addDoubleParameter("epsilon", "Termination epsilon", epsilon, null, "Termination criterion - if this is high, fewer trees may be used for classification");
            params.addBooleanParameter("calcImportance", "Calculate variable importance", calcImportance, "Calculate estimate of each variable's importance (this impacts the results of the classifier!)");
            return params;
        }

        /**
         * Trains the model and then caches the variable importance if the model was configured
         * to calculate it; otherwise any previously cached importance is cleared.
         */
        @Override
        public void train(TrainData trainData) {
            super.train(trainData);
            RTrees trees = this.getStatModel();
            if (!trees.getCalculateVarImportance()) {
                this.featureImportance = null;
                return;
            }
            Mat importance = trees.getVarImportance();
            try {
                Indexer indexer = importance.createIndexer();
                try {
                    int nFeatures = (int)indexer.size(0);
                    double[] values = new double[nFeatures];
                    for (int r = 0; r < nFeatures; ++r) {
                        values[r] = indexer.getDouble(new long[]{r});
                    }
                    this.featureImportance = values;
                } finally {
                    indexer.release();
                }
            } finally {
                // Values have been copied to a Java array, so the Mat wrapper is no longer needed
                importance.close();
            }
        }

        /**
         * @return true if importance values were stored by the most recent call to {@link #train(TrainData)}
         */
        public synchronized boolean hasFeatureImportance() {
            return this.featureImportance != null;
        }

        /**
         * @return a copy of the stored importance values, or null if none are available
         */
        public double[] getFeatureImportance() {
            return this.featureImportance == null ? null : this.featureImportance.clone();
        }

        /**
         * Applies the tree parameters, then the random-trees options. Surrogates are always
         * disabled. The termination type is built from flag 1 when {@code maxTrees >= 1} and
         * flag 2 when {@code epsilon > 0}.
         */
        @Override
        void updateModel(RTrees model, ParameterList params, TrainData trainData) {
            super.updateModel(model, params, trainData);
            int activeVarCount = params.getIntParameterValue("activeVarCount");
            int maxTrees = params.getIntParameterValue("maxTrees");
            double epsilon = params.getDoubleParameterValue("epsilon");
            boolean calcImportance = params.getBooleanParameterValue("calcImportance");
            int type = 0;
            if (maxTrees >= 1) {
                type += 1;
            }
            if (epsilon > 0.0) {
                type += 2;
            }
            TermCriteria termCrit = new TermCriteria(type, maxTrees, epsilon);
            model.setActiveVarCount(activeVarCount);
            model.setUseSurrogates(false);
            model.setTermCriteria(termCrit);
            model.setCalculateVarImportance(calcImportance);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return RTrees.class;
        }

        /**
         * Predicts classes and, optionally, probabilities.
         * <p>
         * Without a probabilities Mat the model's own prediction is used and converted to
         * 32-bit signed integers. Otherwise the votes matrix is read: this code treats its
         * first row as the class label for each column and each following row as the vote
         * counts for one sample. Probabilities are the counts divided by their sum, and the
         * result is the label with the highest count (the first one wins on ties).
         *
         * @param samples input samples, one per row
         * @param results output Mat receiving one predicted class per sample
         * @param probabilities output Mat receiving per-class probabilities, or null to skip them
         */
        @Override
        public void predictWithLock(Mat samples, Mat results, Mat probabilities) {
            RTrees model = this.getStatModel();
            if (probabilities == null) {
                model.predict(samples, results, 0);
                results.convertTo(results, opencv_core.CV_32SC1);
                return;
            }
            Mat votes = new Mat();
            IntIndexer idxVotes = null;
            FloatIndexer idxProbabilities = null;
            IntIndexer idxResults = null;
            try {
                model.getVotes(samples, votes, 0);
                int nClasses = votes.cols();
                int nSamples = samples.rows();
                idxVotes = (IntIndexer)votes.createIndexer();
                probabilities.create(nSamples, nClasses, opencv_core.CV_32FC1);
                idxProbabilities = (FloatIndexer)probabilities.createIndexer();
                results.create(nSamples, 1, opencv_core.CV_32SC1);
                idxResults = (IntIndexer)results.createIndexer();

                // Row 0 is read as the class label belonging to each votes column
                int[] orderedClasses = new int[nClasses];
                for (int c = 0; c < nClasses; ++c) {
                    orderedClasses[c] = idxVotes.get(0L, (long)c);
                }

                for (int i = 0; i < nSamples; ++i) {
                    // Votes for sample i are stored one row below the label row
                    long row = i + 1L;
                    double sum = 0.0;
                    int maxCount = -1;
                    int maxInd = -1;
                    for (int c = 0; c < nClasses; ++c) {
                        int count = idxVotes.get(row, (long)c);
                        if (count > maxCount) {
                            maxCount = count;
                            maxInd = c;
                        }
                        sum += (double)count;
                    }
                    for (int c = 0; c < nClasses; ++c) {
                        int count = idxVotes.get(row, (long)c);
                        idxProbabilities.put((long)i, (long)orderedClasses[c], (float)((double)count / sum));
                    }
                    idxResults.put((long)i, orderedClasses[maxInd]);
                }
            } finally {
                // Release native resources even if prediction or indexing fails part-way
                if (idxVotes != null) {
                    idxVotes.release();
                }
                if (idxProbabilities != null) {
                    idxProbabilities.release();
                }
                if (idxResults != null) {
                    idxResults.release();
                }
                votes.close();
            }
        }
    }

    /**
     * Wrapper around OpenCV's {@link Boost} model, adding the weak classifier count and the
     * weight trim rate to the parameters inherited from {@link AbstractTreeClassifier}.
     */
    public static class BoostClassifier
    extends AbstractTreeClassifier<Boost> {
        BoostClassifier() {
        }

        BoostClassifier(Boost model) {
            super(model);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return Boost.class;
        }

        @Override
        Boost createStatModel() {
            return Boost.create();
        }

        /**
         * Adds the boosting options after the tree parameters, defaulting to the model's current values.
         */
        @Override
        ParameterList createParameterList(Boost model) {
            ParameterList list = super.createParameterList(model);
            final int defaultWeakCount = model.getWeakCount();
            final double defaultTrimRate = model.getWeightTrimRate();
            list.addIntParameter("weakCount", "Number of weak classifiers", defaultWeakCount, null, "Number of weak classifiers to train");
            list.addDoubleParameter("weightTrimRate", "Weight trim rate", defaultTrimRate, null, 0.0, 1.0, "Threshold used to save computational time");
            return list;
        }

        /**
         * Applies the tree parameters first, then copies the boosting options onto the model.
         */
        @Override
        void updateModel(Boost model, ParameterList params, TrainData trainData) {
            super.updateModel(model, params, trainData);
            // Read both values before touching the model
            final int requestedWeakCount = params.getIntParameterValue("weakCount");
            final double requestedTrimRate = params.getDoubleParameterValue("weightTrimRate");
            model.setWeakCount(requestedWeakCount);
            model.setWeightTrimRate(requestedTrimRate);
        }
    }

    /**
     * Wrapper around OpenCV's {@link DTrees} model. It adds nothing to the parameters and
     * update logic inherited from {@link AbstractTreeClassifier}.
     */
    public static class DTreesClassifier
    extends AbstractTreeClassifier<DTrees> {
        /** Creates a wrapper without a model. */
        DTreesClassifier() {
        }

        /**
         * Creates a wrapper around an existing model.
         *
         * @param model the model to wrap
         */
        DTreesClassifier(DTrees model) {
            super(model);
        }

        /** @return {@code DTrees.class} */
        @Override
        Class<? extends StatModel> getStatModelClass() {
            return DTrees.class;
        }

        /** @return a new model from {@code DTrees.create()} */
        @Override
        DTrees createStatModel() {
            return DTrees.create();
        }
    }

    /**
     * Wrapper around OpenCV's {@link KNearest} model, exposing a single parameter:
     * the default number of neighbours.
     */
    static class KNearestClassifierCV
    extends AbstractOpenCVClassifierML<KNearest> {
        KNearestClassifierCV() {
        }

        KNearestClassifierCV(KNearest model) {
            super(model);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return KNearest.class;
        }

        @Override
        KNearest createStatModel() {
            return KNearest.create();
        }

        /**
         * Builds a list holding only {@code defaultK}, defaulting to the model's current value.
         */
        @Override
        ParameterList createParameterList(KNearest model) {
            ParameterList list = new ParameterList();
            final int currentK = model.getDefaultK();
            list.addIntParameter("defaultK", "Default K", currentK, null, "Number of nearest neighbors");
            return list;
        }

        /**
         * Copies {@code defaultK} onto the model and marks the model as a classifier.
         */
        @Override
        void updateModel(KNearest model, ParameterList params, TrainData trainData) {
            final int requestedK = params.getIntParameterValue("defaultK");
            model.setDefaultK(requestedK);
            model.setIsClassifier(true);
        }
    }

    /**
     * Wrapper around OpenCV's {@link ANN_MLP} model.
     * <p>
     * Targets are converted to a float matrix with one column per class, using -1 for "not this
     * class" and 1 for "this class". Outputs in the probabilities Mat are linearly rescaled by
     * {@code value * 0.5 + 0.5} after prediction.
     */
    static class ANNClassifierCV
    extends AbstractOpenCVClassifierML<ANN_MLP> {
        private static final Logger logger = LoggerFactory.getLogger(ANNClassifierCV.class);

        /**
         * Number of layer-size parameters offered. Despite the constant-style name this is
         * instance state: it is overwritten when an existing model already has layer sizes.
         * NOTE(review): the name is kept unchanged in case anything depends on the field name
         * (e.g. reflection-based serialization) - confirm before renaming.
         */
        private int MAX_HIDDEN_LAYERS = 5;

        ANNClassifierCV() {
        }

        ANNClassifierCV(ANN_MLP model) {
            super(model);
        }

        /**
         * Creates one integer parameter per layer ({@code hidden1}, {@code hidden2}, ...) followed
         * by the termination criteria. If the model already has layer sizes, those are used as
         * defaults and determine the number of parameters; otherwise all defaults are 0.
         */
        @Override
        ParameterList createParameterList(ANN_MLP model) {
            int[] layerSizes;
            Mat sizes = model.getLayerSizes();
            if (sizes.empty()) {
                layerSizes = new int[this.MAX_HIDDEN_LAYERS];
            } else {
                int n = (int)sizes.total();
                layerSizes = new int[n];
                Indexer idx = sizes.createIndexer();
                try {
                    for (int i = 0; i < n; ++i) {
                        layerSizes[i] = (int)idx.getDouble(new long[]{i});
                    }
                } finally {
                    idx.release();
                }
                this.MAX_HIDDEN_LAYERS = n;
            }
            ParameterList params = new ParameterList();
            params.addTitleParameter("Hidden layers");
            for (int i = 1; i <= layerSizes.length; ++i) {
                // Help text now names the actual layer rather than always saying "first"
                params.addIntParameter("hidden" + i, "Layer " + i, layerSizes[i - 1], "Nodes", "Size of hidden layer " + i + " (0 to omit layer)");
            }
            OpenCVClassifiers.addTerminationCriteriaParameters(params, model.getTermCriteria());
            return params;
        }

        /**
         * @return the flag value 4 (presumably {@code ANN_MLP.NO_OUTPUT_SCALE} - confirm against OpenCV)
         */
        @Override
        protected int getTrainFlags() {
            return 4;
        }

        @Override
        ANN_MLP createStatModel() {
            return ANN_MLP.create();
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return ANN_MLP.class;
        }

        /**
         * Rewrites {@code targets} in place as a float matrix of -1/1 values and then delegates
         * to the superclass.
         * <p>
         * For multiclass input, every target value greater than 0 becomes 1 and everything else
         * -1. Otherwise each row's integer label selects the single column set to 1, and the
         * number of columns is the largest label plus one.
         *
         * @param samples training samples
         * @param targets training targets; replaced with the converted matrix
         * @param weights optional sample weights
         * @param doMulticlass whether targets already contain one column per class
         * @return the training data created by the superclass from the converted targets
         */
        @Override
        public TrainData createTrainData(Mat samples, Mat targets, Mat weights, boolean doMulticlass) {
            if (doMulticlass) {
                int nRows = targets.rows();
                int nCols = targets.cols();
                Mat targets2 = new Mat(nRows, nCols, opencv_core.CV_32FC1, Scalar.all(-1.0));
                try {
                    Indexer indexer = targets.createIndexer();
                    FloatIndexer idxTargets = (FloatIndexer)targets2.createIndexer();
                    try {
                        long[] inds = new long[2];
                        for (int r = 0; r < nRows; ++r) {
                            for (int c = 0; c < nCols; ++c) {
                                inds[0] = r;
                                inds[1] = c;
                                if (indexer.getDouble(inds) > 0.0) {
                                    idxTargets.put(inds, 1.0f);
                                }
                            }
                        }
                    } finally {
                        // Release before copying so the indexers are finished with both Mats
                        idxTargets.release();
                        indexer.release();
                    }
                    targets.put(targets2);
                } finally {
                    targets2.close();
                }
            } else {
                IntBuffer buffer = (IntBuffer)OpenCVTools.ensureContinuous(targets, false).createBuffer();
                int[] vals = new int[targets.rows()];
                buffer.get(vals);
                int max = Arrays.stream(vals).max().orElse(0) + 1;
                Mat targets2 = new Mat(targets.rows(), max, opencv_core.CV_32FC1, Scalar.all(-1.0));
                try {
                    FloatIndexer idxTargets = (FloatIndexer)targets2.createIndexer();
                    try {
                        int row = 0;
                        for (int v : vals) {
                            idxTargets.put((long)row, (long)v, 1.0f);
                            ++row;
                        }
                    } finally {
                        idxTargets.release();
                    }
                    targets.put(targets2);
                } finally {
                    targets2.close();
                }
            }
            return super.createTrainData(samples, targets, weights, doMulticlass);
        }

        /**
         * Runs the superclass prediction and rescales the probabilities Mat.
         * <p>
         * When the caller passes {@code null} for probabilities, a temporary Mat is used for the
         * superclass call and closed afterwards; it is not rescaled, because its contents are
         * discarded.
         *
         * @param samples input samples
         * @param results output Mat for predictions
         * @param probabilities output Mat for rescaled network outputs, or null if not required
         */
        @Override
        public void predictWithLock(Mat samples, Mat results, Mat probabilities) {
            if (probabilities == null) {
                Mat scratch = new Mat();
                try {
                    super.predictWithLock(samples, results, scratch);
                } finally {
                    scratch.close();
                }
                return;
            }
            super.predictWithLock(samples, results, probabilities);
            rescaleOutputs(probabilities);
        }

        /**
         * Applies {@code value * (0.5 / beta) + 0.5} with beta = 1 to every element of a
         * two-dimensional Mat, in place.
         */
        private static void rescaleOutputs(Mat probabilities) {
            // NOTE(review): beta presumably mirrors the last argument of setActivationFunction(1, 1.0, 1.0) - confirm
            double beta = 1.0;
            double scale = 0.5 / beta;
            double offset = 0.5;
            Indexer indexer = probabilities.createIndexer();
            try {
                long rows = indexer.size(0);
                long cols = indexer.size(1);
                long[] inds = new long[2];
                for (long r = 0L; r < rows; ++r) {
                    inds[0] = r;
                    for (long c = 0L; c < cols; ++c) {
                        inds[1] = c;
                        indexer.putDouble(inds, indexer.getDouble(inds) * scale + offset);
                    }
                }
            } finally {
                indexer.release();
            }
        }

        /**
         * Configures layer sizes, activation function and termination criteria.
         * <p>
         * The first layer has one node per training variable and the last has one node per
         * response column. Layer parameters that are missing, or whose value is 1 or less,
         * are skipped.
         */
        @Override
        void updateModel(ANN_MLP model, ParameterList params, TrainData trainData) {
            int nMeasurements = trainData.getNVars();
            int nClasses = trainData.getResponses().cols();
            double[] layers = new double[this.MAX_HIDDEN_LAYERS + 2];
            layers[0] = nMeasurements;
            int n = 1;
            for (int i = 1; i <= this.MAX_HIDDEN_LAYERS; ++i) {
                String name = "hidden" + i;
                if (!params.containsKey((Object)name)) {
                    continue;
                }
                int size = params.getIntParameterValue(name).intValue();
                if (size <= 1) {
                    continue;
                }
                layers[n] = size;
                ++n;
            }
            layers[n] = nClasses;
            ++n;
            if (n < layers.length) {
                layers = Arrays.copyOf(layers, n);
            }
            // Type code 6 is used for a double-valued Mat here, matching the DoubleIndexer below
            Mat mat = new Mat(n, 1, 6, Scalar.ZERO);
            try {
                DoubleIndexer idx = (DoubleIndexer)mat.createIndexer();
                try {
                    for (int i = 0; i < n; ++i) {
                        idx.put((long)i, layers[i]);
                    }
                } finally {
                    idx.release();
                }
                model.setLayerSizes(mat);
            } finally {
                mat.close();
            }
            model.setActivationFunction(1, 1.0, 1.0);
            model.setTermCriteria(OpenCVClassifiers.updateTermCriteria(params, model.getTermCriteria()));
            if (logger.isDebugEnabled()) {
                logger.debug("Initializing ANN with layer sizes: {}", GeneralTools.arrayToString(Locale.getDefault(Locale.Category.FORMAT), layers, 0));
            }
        }

        /**
         * Training methods, each mapped to an integer code.
         */
        static enum TrainingMethod {
            BACKPROP,
            RPROP,
            ANNEAL;

            /**
             * @return 0 for BACKPROP, 1 for RPROP, 2 for ANNEAL
             */
            public int getTrainingMethod() {
                switch (this) {
                    case RPROP:
                        return 1;
                    case ANNEAL:
                        return 2;
                    case BACKPROP:
                    default:
                        return 0;
                }
            }
        }

        /**
         * Activation functions, each mapped to an integer code.
         */
        static enum ActivationFunction {
            IDENTITY,
            SIGMOID_SYM,
            GAUSSIAN,
            RELU,
            LEAKY_RELU;

            /**
             * @return 0 for IDENTITY, 1 for SIGMOID_SYM, 2 for GAUSSIAN, 3 for RELU, 4 for LEAKY_RELU
             */
            public int getActivationFunction() {
                switch (this) {
                    case IDENTITY:
                        return 0;
                    case GAUSSIAN:
                        return 2;
                    case RELU:
                        return 3;
                    case LEAKY_RELU:
                        return 4;
                    case SIGMOID_SYM:
                    default:
                        return 1;
                }
            }
        }
    }

    /**
     * Wrapper around OpenCV's {@link LogisticRegression} model, exposing the learning rate,
     * iteration count, regularization choice and termination criteria.
     */
    public static class LogisticRegressionClassifier
    extends AbstractOpenCVClassifierML<LogisticRegression> {
        LogisticRegressionClassifier() {
        }

        LogisticRegressionClassifier(LogisticRegression model) {
            super(model);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return LogisticRegression.class;
        }

        @Override
        LogisticRegression createStatModel() {
            return LogisticRegression.create();
        }

        /**
         * Builds the parameter list from the model's current settings. The default
         * regularization choice is the first enum constant whose code equals the model's
         * value, falling back to {@code DISABLE} when none matches.
         */
        @Override
        ParameterList createParameterList(LogisticRegression model) {
            ParameterList list = new ParameterList();
            double currentRate = model.getLearningRate();
            int currentIterations = model.getIterations();
            int currentCode = model.getRegularization();

            Regularization defaultReg = Regularization.DISABLE;
            for (Regularization candidate : Regularization.values()) {
                if (candidate.getRegularization() == currentCode) {
                    defaultReg = candidate;
                    break;
                }
            }

            list.addTitleParameter("Logistic regression options");
            list.addDoubleParameter("learningRate", "Learning rate", currentRate);
            list.addIntParameter("nIterations", "Number of iterations", currentIterations);
            list.addChoiceParameter("regularization", "Regularization", (Object)defaultReg, Arrays.asList(Regularization.values()));
            OpenCVClassifiers.addTerminationCriteriaParameters(list, model.getTermCriteria());
            return list;
        }

        /**
         * Converts {@code targets} in place to type code 5 before delegating to the superclass.
         * NOTE(review): 5 is presumably the OpenCV 32-bit float type - confirm.
         */
        @Override
        public TrainData createTrainData(Mat samples, Mat targets, Mat weights, boolean doMulticlass) {
            final int targetType = 5;
            targets.convertTo(targets, targetType);
            return super.createTrainData(samples, targets, weights, doMulticlass);
        }

        /**
         * Copies regularization, learning rate, iteration count and termination criteria from
         * the parameter list onto the model, in that order.
         */
        @Override
        void updateModel(LogisticRegression model, ParameterList params, TrainData trainData) {
            double requestedRate = params.getDoubleParameterValue("learningRate");
            int requestedIterations = params.getIntParameterValue("nIterations");
            Regularization chosen = (Regularization)((Object)params.getChoiceParameterValue("regularization"));
            model.setRegularization(chosen.getRegularization());
            model.setLearningRate(requestedRate);
            model.setIterations(requestedIterations);
            TermCriteria updated = OpenCVClassifiers.updateTermCriteria(params, model.getTermCriteria());
            model.setTermCriteria(updated);
        }

        /**
         * Regularization choices with their integer codes and display names.
         */
        static enum Regularization {
            DISABLE,
            L1,
            L2;

            /**
             * @return 0 for L1, 1 for L2 and -1 for DISABLE
             */
            public int getRegularization() {
                switch (this) {
                    case L1:
                        return 0;
                    case L2:
                        return 1;
                    default:
                        return -1;
                }
            }

            /**
             * @return "L1", "L2", or "None" for DISABLE
             */
            @Override
            public String toString() {
                switch (this) {
                    case L1:
                        return "L1";
                    case L2:
                        return "L2";
                    default:
                        return "None";
                }
            }
        }
    }

    /**
     * Wrapper around OpenCV's {@link EM} model, exposing the number of clusters as its
     * only parameter.
     */
    public static class EMClusterer
    extends AbstractOpenCVClassifierML<EM> {
        EMClusterer() {
        }

        EMClusterer(EM model) {
            super(model);
        }

        @Override
        EM createStatModel() {
            return EM.create();
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return EM.class;
        }

        /**
         * Builds a list holding only {@code nClusters}, defaulting to the model's current value.
         */
        @Override
        ParameterList createParameterList(EM model) {
            final int currentClusters = model.getClustersNumber();
            ParameterList list = new ParameterList();
            list.addIntParameter("nClusters", "Number of clusters", currentClusters);
            return list;
        }

        /**
         * Copies {@code nClusters} from the parameter list onto the model.
         */
        @Override
        void updateModel(EM model, ParameterList params, TrainData trainData) {
            int requestedClusters = params.getIntParameterValue("nClusters").intValue();
            model.setClustersNumber(requestedClusters);
        }
    }

    /**
     * Wrapper around OpenCV's {@link NormalBayesClassifier}. The model has no adjustable
     * parameters; prediction always goes through {@code predictProb}.
     */
    public static class NormalBayesClassifierCV
    extends AbstractOpenCVClassifierML<NormalBayesClassifier> {
        NormalBayesClassifierCV() {
        }

        NormalBayesClassifierCV(NormalBayesClassifier model) {
            super(model);
        }

        /**
         * @return a list containing only a title stating that there is nothing to adjust
         */
        @Override
        ParameterList createParameterList(NormalBayesClassifier model) {
            ParameterList params = new ParameterList();
            params.addTitleParameter("No parameters to adjust!");
            return params;
        }

        @Override
        NormalBayesClassifier createStatModel() {
            return NormalBayesClassifier.create();
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return NormalBayesClassifier.class;
        }

        /** Nothing to update: the model has no parameters. */
        @Override
        void updateModel(NormalBayesClassifier model, ParameterList params, TrainData trainData) {
        }

        /**
         * Predicts classes and probabilities with {@code predictProb}.
         * <p>
         * If the caller does not supply a probabilities Mat, a temporary one is used and
         * closed once prediction has finished, so that its native memory is not left for
         * the garbage collector to reclaim.
         *
         * @param samples input samples
         * @param results output Mat for predictions
         * @param probabilities output Mat for probabilities, or null if not required
         */
        @Override
        public void predictWithLock(Mat samples, Mat results, Mat probabilities) {
            NormalBayesClassifier model = this.getStatModel();
            if (probabilities != null) {
                model.predictProb(samples, results, probabilities, 0);
                return;
            }
            Mat scratch = new Mat();
            try {
                model.predictProb(samples, results, scratch, 0);
            } finally {
                scratch.close();
            }
        }
    }

    /**
     * Wrapper around OpenCV's {@link SVM} model. No parameters are exposed, nothing is
     * changed on the model before training, and auto-update is reported as unsupported.
     */
    public static class SVMClassifierCV
    extends AbstractOpenCVClassifierML<SVM> {
        SVMClassifierCV() {
        }

        SVMClassifierCV(SVM model) {
            super(model);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return SVM.class;
        }

        @Override
        SVM createStatModel() {
            return SVM.create();
        }

        /** @return always {@code false} */
        @Override
        public boolean supportsAutoUpdate() {
            return false;
        }

        /** @return a new, empty parameter list */
        @Override
        ParameterList createParameterList(SVM model) {
            return new ParameterList();
        }

        /** Intentionally empty: there are no parameters to apply. */
        @Override
        void updateModel(SVM model, ParameterList params, TrainData trainData) {
        }
    }

    /**
     * Wrapper around OpenCV's {@link SVMSGD} model. No parameters are exposed, nothing is
     * changed on the model before training, and auto-update is reported as unsupported.
     */
    public static class SVMSGDClassifierCV
    extends AbstractOpenCVClassifierML<SVMSGD> {
        SVMSGDClassifierCV() {
        }

        SVMSGDClassifierCV(SVMSGD model) {
            super(model);
        }

        @Override
        Class<? extends StatModel> getStatModelClass() {
            return SVMSGD.class;
        }

        @Override
        SVMSGD createStatModel() {
            return SVMSGD.create();
        }

        /** @return always {@code false} */
        @Override
        public boolean supportsAutoUpdate() {
            return false;
        }

        /** @return a new, empty parameter list */
        @Override
        ParameterList createParameterList(SVMSGD model) {
            return new ParameterList();
        }

        /** Intentionally empty: there are no parameters to apply. */
        @Override
        void updateModel(SVMSGD model, ParameterList params, TrainData trainData) {
        }
    }

    /**
     * Variant of {@link ANNClassifierCV} that reports multiclass support and uses its own
     * display name.
     */
    static class MulticlassANNClassifierCV
    extends ANNClassifierCV {
        /** Creates a wrapper without a model. */
        MulticlassANNClassifierCV() {
        }

        /** @return the fixed display name of this classifier */
        @Override
        public String getName() {
            return "ANN MLP (Multiclass)";
        }

        /** @return always {@code true} */
        @Override
        public boolean supportsMulticlass() {
            return true;
        }
    }

    /**
     * Shared base for wrappers of {@link DTrees} and its subclasses, handling the maximum
     * depth, minimum sample count and 1SE rule parameters.
     *
     * @param <T> the concrete tree model type
     */
    static abstract class AbstractTreeClassifier<T extends DTrees>
    extends AbstractOpenCVClassifierML<T> {
        AbstractTreeClassifier() {
        }

        AbstractTreeClassifier(T model) {
            super(model);
        }

        /**
         * Builds the common tree parameters from the model's current values. The default
         * maximum depth shown is capped at 1000.
         */
        @Override
        ParameterList createParameterList(T model) {
            int shownDepth = model.getMaxDepth();
            if (shownDepth > 1000) {
                shownDepth = 1000;
            }
            int shownMinSamples = model.getMinSampleCount();
            boolean shownOneSE = model.getUse1SERule();
            return new ParameterList()
                    .addIntParameter("maxDepth", "Maximum tree depth", shownDepth, null, "Maximum possible tree depth")
                    .addIntParameter("minSampleCount", "Minimum sample count", shownMinSamples, null, "Minimum number of samples per node")
                    .addBooleanParameter("use1SERule", "Use 1SE rule", shownOneSE, "Harsher pruning, more compact tree");
        }

        /**
         * Applies the common tree parameters. CV folds are always set to 0, a non-positive
         * depth is replaced by {@code Integer.MAX_VALUE}, and the minimum sample count is
         * raised to at least 1.
         */
        @Override
        void updateModel(T model, ParameterList params, TrainData trainData) {
            int requestedDepth = params.getIntParameterValue("maxDepth");
            int requestedMinSamples = params.getIntParameterValue("minSampleCount");
            boolean requestedOneSE = params.getBooleanParameterValue("use1SERule");

            int effectiveDepth;
            if (requestedDepth > 0) {
                effectiveDepth = requestedDepth;
            } else {
                effectiveDepth = Integer.MAX_VALUE;
            }
            int effectiveMinSamples = Math.max(requestedMinSamples, 1);

            model.setCVFolds(0);
            model.setMaxDepth(effectiveDepth);
            model.setMinSampleCount(effectiveMinSamples);
            model.setUse1SERule(requestedOneSE);
        }
    }

    /**
     * Generic wrapper for an existing {@link StatModel}, with no parameters and no
     * model updates.
     * <p>
     * NOTE(review): {@code createStatModel()} returns {@code getStatModel()}, and
     * {@code getStatModel()} calls {@code createStatModel()} when the stored model is null,
     * so this wrapper appears to rely on always being given a non-null model - confirm.
     *
     * @param <T> the wrapped model type
     */
    static class DefaultOpenCVStatModel<T extends StatModel>
    extends AbstractOpenCVClassifierML<T> {
        /**
         * @param model the model to wrap
         */
        DefaultOpenCVStatModel(T model) {
            super(model);
        }

        /** @return the runtime class of the wrapped model */
        @Override
        Class<? extends StatModel> getStatModelClass() {
            T wrapped = this.getStatModel();
            return wrapped.getClass();
        }

        /** @return the wrapped model itself; no new model is created */
        @Override
        T createStatModel() {
            return this.getStatModel();
        }

        /** @return a new, empty parameter list */
        @Override
        ParameterList createParameterList(T model) {
            return new ParameterList();
        }

        /** Intentionally empty: there are no parameters to apply. */
        @Override
        void updateModel(T model, ParameterList params, TrainData trainData) {
        }
    }

    /**
     * Base class for {@link OpenCVStatModel} implementations backed by a single OpenCV
     * {@link StatModel}.
     * <p>
     * The wrapped model and its {@link ParameterList} are both created lazily on first access.
     * {@link #train(TrainData)} runs while holding the write lock and
     * {@link #predict(Mat, Mat, Mat)} while holding the read lock of {@link #lock}.
     *
     * @param <T> the concrete type of the wrapped model
     */
    static abstract class AbstractOpenCVClassifierML<T extends StatModel>
    extends OpenCVStatModel {
        @JsonAdapter(value=OpenCVTypeAdapters.OpenCVTypeAdaptorFactory.class)
        private T model;
        private transient ParameterList params;
        transient ReentrantReadWriteLock lock = new ReentrantReadWriteLock();

        /**
         * Create the parameters that can be adjusted for the given model.
         *
         * @param model the wrapped model
         * @return the parameter list to expose through {@link #getParameterList()}
         */
        abstract ParameterList createParameterList(T model);

        /**
         * Create the model to wrap; called by {@link #getStatModel()} when no model is stored yet.
         *
         * @return the model instance
         */
        abstract T createStatModel();

        /**
         * Transfer parameter values to the model; called immediately before training.
         *
         * @param model the model about to be trained
         * @param params the current parameter list
         * @param trainData the data the model will be trained with
         */
        abstract void updateModel(T model, ParameterList params, TrainData trainData);

        /**
         * Create a classifier whose model and parameters are built on demand.
         */
        AbstractOpenCVClassifierML() {
        }

        /**
         * Create a classifier around an existing model.
         *
         * @param model the model to wrap
         */
        AbstractOpenCVClassifierML(T model) {
            this.model = model;
            this.params = this.createParameterList(model);
        }

        @Override
        public boolean supportsMulticlass() {
            return false;
        }

        @Override
        public boolean supportsAutoUpdate() {
            return true;
        }

        /**
         * @return true if the wrapped model is an {@link RTrees}, {@link ANN_MLP} or
         *         {@link NormalBayesClassifier}
         */
        @Override
        public boolean supportsProbabilities() {
            T model = this.getStatModel();
            return model instanceof RTrees || model instanceof ANN_MLP || model instanceof NormalBayesClassifier;
        }

        /**
         * Get the wrapped model, creating it with {@link #createStatModel()} if none is stored yet.
         *
         * @return the wrapped model
         */
        @Override
        T getStatModel() {
            if (this.model == null) {
                this.model = this.createStatModel();
            }
            return this.model;
        }

        @Override
        public boolean isTrained() {
            return this.getStatModel().isTrained();
        }

        /**
         * Get the parameter list, creating it with {@link #createParameterList(StatModel)}
         * on first access.
         *
         * @return the parameter list
         */
        @Override
        public ParameterList getParameterList() {
            if (this.params == null) {
                this.params = this.createParameterList(this.getStatModel());
            }
            return this.params;
        }

        @Override
        public String toString() {
            return this.getName();
        }

        /**
         * Build training data from samples, targets and optional weights.
         * <p>
         * The layout argument passed to {@code TrainData.create} is always 0 (OpenCV's
         * {@code ROW_SAMPLE}). When {@link #useUMat()} returns true the matrices are first
         * converted with {@code getUMat(0x1000000)} (OpenCV's {@code ACCESS_READ}).
         *
         * @param samples the training samples
         * @param targets the training targets
         * @param weights optional sample weights; if null or empty, no weights are passed on
         * @param doMulticlass if true and multiclass is unsupported, a warning is logged;
         *        it does not otherwise change the result
         * @return the training data
         */
        @Override
        public TrainData createTrainData(Mat samples, Mat targets, Mat weights, boolean doMulticlass) {
            if (doMulticlass && !this.supportsMulticlass()) {
                logger.warn("Multiclass classification requested, but not supported");
            }
            boolean hasWeights = weights != null && !weights.empty();
            if (this.useUMat()) {
                UMat uSamples = samples.getUMat(0x1000000);
                UMat uTargets = targets.getUMat(0x1000000);
                if (!hasWeights) {
                    return TrainData.create(uSamples, 0, uTargets);
                }
                UMat uWeights = weights.getUMat(0x1000000);
                return TrainData.create(uSamples, 0, uTargets, null, null, uWeights, null);
            }
            if (!hasWeights) {
                return TrainData.create(samples, 0, targets);
            }
            return TrainData.create(samples, 0, targets, null, null, weights, null);
        }

        /**
         * @return true if training data should be created from {@link UMat}; false by default
         */
        boolean useUMat() {
            return false;
        }

        /**
         * Train the model while holding the write lock.
         *
         * @param trainData the training data
         */
        @Override
        public void train(TrainData trainData) {
            this.lock.writeLock().lock();
            try {
                this.trainWithLock(trainData);
            }
            finally {
                this.lock.writeLock().unlock();
            }
        }

        /**
         * Perform the training itself; {@link #train(TrainData)} calls this with the write
         * lock held. This method does not acquire any lock on its own.
         *
         * @param trainData the training data
         */
        public void trainWithLock(TrainData trainData) {
            T statModel = this.getStatModel();
            // Fixed seed set before every training run
            opencv_core.setRNGSeed(1012);
            this.updateModel(statModel, this.getParameterList(), trainData);
            statModel.train(trainData, this.getTrainFlags());
        }

        /**
         * @return the flags passed to {@code StatModel.train}; 0 by default
         */
        protected int getTrainFlags() {
            return 0;
        }

        /**
         * @return the model class used by {@link #getName()} to choose a display name
         */
        abstract Class<? extends StatModel> getStatModelClass();

        /**
         * Get a readable name for the known model classes, falling back to the simple
         * class name of the wrapped model for anything else.
         *
         * @return the display name
         */
        @Override
        public String getName() {
            // Must be the wildcard type: getStatModelClass() does not return Class<StatModel>
            Class<? extends StatModel> cls = this.getStatModelClass();
            if (ANN_MLP.class.equals(cls)) {
                return "Artificial neural network (ANN_MLP)";
            }
            if (RTrees.class.equals(cls)) {
                return "Random trees (RTrees)";
            }
            if (Boost.class.equals(cls)) {
                return "Boosted trees (Boost)";
            }
            if (DTrees.class.equals(cls)) {
                return "Decision tree (DTrees)";
            }
            if (EM.class.equals(cls)) {
                return "Expectation maximization";
            }
            if (KNearest.class.equals(cls)) {
                return "K nearest neighbor";
            }
            if (LogisticRegression.class.equals(cls)) {
                return "Logistic regression";
            }
            if (NormalBayesClassifier.class.equals(cls)) {
                return "Normal Bayes classifier";
            }
            return this.getStatModel().getClass().getSimpleName();
        }

        /**
         * Apply the model while holding the read lock.
         *
         * @param samples the samples to classify
         * @param results receives one 32-bit integer value per sample
         * @param probabilities optional (may be null) output, see
         *        {@link #predictWithLock(Mat, Mat, Mat)}
         */
        @Override
        public void predict(Mat samples, Mat results, Mat probabilities) {
            this.lock.readLock().lock();
            try {
                this.predictWithLock(samples, results, probabilities);
            }
            finally {
                this.lock.readLock().unlock();
            }
        }

        /**
         * Perform the prediction itself; {@link #predict(Mat, Mat, Mat)} calls this with the
         * read lock held.
         * <p>
         * If the model output has more than one column, the raw output is assigned to
         * {@code probabilities} (when non-null) and {@code results} is replaced by a single
         * column holding, for each row, the index of the column with the largest value.
         * Otherwise {@code results} is converted to 32-bit integers and
         * {@code probabilities} (when non-null) is made empty.
         *
         * @param samples the samples to classify
         * @param results receives the predictions
         * @param probabilities optional output for the raw multi-column model output
         */
        protected void predictWithLock(Mat samples, Mat results, Mat probabilities) {
            T statModel = this.getStatModel();
            statModel.predict(samples, results, 0);
            int nSamples = results.rows();
            int nClasses = results.cols();
            if (nClasses <= 1) {
                results.convertTo(results, opencv_core.CV_32SC1);
                if (probabilities != null) {
                    probabilities.create(0, 0, opencv_core.CV_32FC1);
                }
                return;
            }
            Indexer indexer = results.createIndexer();
            Mat matResultsNew = new Mat(nSamples, 1, opencv_core.CV_32SC1);
            try {
                IntIndexer idxResults = (IntIndexer)matResultsNew.createIndexer();
                if (probabilities != null) {
                    probabilities.create(nSamples, nClasses, opencv_core.CV_32FC1);
                    probabilities.put(results);
                }
                long[] inds = new long[2];
                for (int row = 0; row < nSamples; ++row) {
                    double maxValue = Double.NEGATIVE_INFINITY;
                    // Stays -1 if no value in the row compares greater than negative infinity
                    int maxInd = -1;
                    inds[0] = row;
                    for (long c = 0L; c < (long)nClasses; ++c) {
                        inds[1] = c;
                        double val = indexer.getDouble(inds);
                        if (val > maxValue) {
                            maxValue = val;
                            maxInd = (int)c;
                        }
                    }
                    idxResults.put((long)row, maxInd);
                }
                // Release both indexers before reassigning results
                indexer.release();
                idxResults.release();
                results.put(matResultsNew);
            }
            finally {
                // Close the temporary Mat even if the loop above fails
                matResultsNew.close();
            }
        }

        /**
         * @return true if the wrapped model is a {@link DTrees} (or a subclass of it)
         */
        @Override
        public boolean supportsMissingValues() {
            return this.getStatModel() instanceof DTrees;
        }
    }

    static class OpenCVClassifierTypeAdapter
    extends TypeAdapter<OpenCVStatModel> {
        OpenCVClassifierTypeAdapter() {
        }

        public void write(JsonWriter out, OpenCVStatModel value) throws IOException {
            OpenCVTypeAdapters.getTypeAdaptor(StatModel.class).write(out, (Object)value.getStatModel());
        }

        public OpenCVStatModel read(JsonReader in) throws IOException {
            StatModel statModel = (StatModel)OpenCVTypeAdapters.getTypeAdaptor(StatModel.class).read(in);
            return OpenCVClassifiers.wrapStatModel(statModel);
        }
    }

    /**
     * Wrapper around an OpenCV {@link StatModel} that exposes training and prediction
     * together with information about what the model supports.
     * <p>
     * Instances are (de)serialized with {@link OpenCVClassifierTypeAdapter}.
     */
    @JsonAdapter(value=OpenCVClassifierTypeAdapter.class)
    public static abstract class OpenCVStatModel {
        /**
         * @return true if the model can handle missing values
         */
        public abstract boolean supportsMissingValues();

        /**
         * @return a name for the model
         */
        public abstract String getName();

        /**
         * @return true if the model has been trained
         */
        public abstract boolean isTrained();

        /**
         * @return true if the model supports multiclass classification
         */
        public abstract boolean supportsMulticlass();

        /**
         * @return true if the model supports automatic updates
         */
        public abstract boolean supportsAutoUpdate();

        /**
         * @return true if the model can output probabilities
         */
        public abstract boolean supportsProbabilities();

        /**
         * @return the parameters associated with the model
         */
        public abstract ParameterList getParameterList();

        /**
         * Create training data in a form suitable for this model.
         *
         * @param samples the training samples
         * @param targets the training targets
         * @param weights optional sample weights
         * @param doMulticlass whether multiclass classification is requested
         * @return the training data
         */
        public abstract TrainData createTrainData(Mat samples, Mat targets, Mat weights, boolean doMulticlass);

        /**
         * Train the model.
         *
         * @param trainData the training data
         */
        public abstract void train(TrainData trainData);

        /**
         * Apply the model to samples.
         *
         * @param samples the samples to classify
         * @param results output for the predictions
         * @param probabilities optional output for probabilities
         */
        public abstract void predict(Mat samples, Mat results, Mat probabilities);

        /**
         * @return the wrapped OpenCV model
         */
        abstract StatModel getStatModel();

        /**
         * @return {@code "OpenCV "} followed by the simple class name of the wrapped model
         */
        @Override
        public String toString() {
            // The previous format string had no %s, so the class name was silently dropped
            return String.format("OpenCV %s", this.getStatModel().getClass().getSimpleName());
        }
    }
}

