package edu.umass.cs.mallet.base.classify.tui;

import bsh.EvalError;
import com.fasterxml.jackson.core.util.MinimalPrettyPrinter;
import edu.umass.cs.mallet.base.classify.Classifier;
import edu.umass.cs.mallet.base.classify.ClassifierTrainer;
import edu.umass.cs.mallet.base.classify.NaiveBayesTrainer;
import edu.umass.cs.mallet.base.classify.Trial;
import edu.umass.cs.mallet.base.classify.evaluate.ConfusionMatrix;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.util.CommandOption;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.MalletProgressMessageLogger;
import edu.umass.cs.mallet.base.util.ProgressMessageLogFormatter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Random;
import java.util.logging.ConsoleHandler;
import java.util.logging.Handler;
import java.util.logging.Logger;
import org.apache.xalan.templates.Constants;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/classify/tui/Vectors2Classify.class */
/**
 * Command-line tool that trains one or more {@link ClassifierTrainer}s on MALLET instance lists
 * and prints diagnostics (accuracy, per-label F1, confusion matrices, raw classifications) for
 * the training, test and validation portions of the data over one or more trials.
 *
 * <p>Data comes either from a single {@code --input} list that is randomly re-split on every
 * trial according to {@code --training-portion} and {@code --validation-portion}, or from the
 * separate lists named by {@code --training-file}, {@code --testing-file} and
 * {@code --validation-file}, which are reused unchanged on every trial.
 *
 * <p>NOTE: the anonymous option classes below subclass {@link CommandOption} and therefore
 * inherit its nested {@code String}/{@code Object}/... option classes; inside those bodies the
 * {@code java.lang} types must be written fully qualified.
 */
public abstract class Vectors2Classify {
    private static final Logger logger = MalletLogger.getLogger(Vectors2Classify.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(Vectors2Classify.class.getName() + "-pl");

    /** Slack allowed when checking that training + validation portions do not exceed 1. */
    private static final double PORTION_TOLERANCE = 1e-9;

    /** Trainers collected from each --trainer occurrence; main adds a NaiveBayesTrainer if none. */
    private static final ArrayList<ClassifierTrainer> classifierTrainers = new ArrayList<ClassifierTrainer>();

    /** ReportOptions[data][report] is set when "data:report" was requested via --report. */
    private static final boolean[][] ReportOptions = new boolean[3][4];

    /** Label argument of each requested report that takes one (only f1), indexed like ReportOptions. */
    private static final String[][] ReportOptionArgs = new String[3][4];

    static CommandOption.SpacedStrings report = new CommandOption.SpacedStrings(Vectors2Classify.class, "report", "[train|test|validation]:[accuracy|f1|confusion|raw]", true, new String[]{"test:accuracy", "test:confusion", "train:accuracy"}, "", null) {
        /**
         * Records every "data:report[=arg]" entry in ReportOptions / ReportOptionArgs.
         *
         * @throws IllegalArgumentException for a malformed entry, an unknown data portion or
         *     report name, an f1 report without a label, or an argument on any other report
         */
        @Override
        public void postParsing(CommandOption.List list) {
            for (int i = 0; i < this.value.length; i++) {
                java.lang.String spec = this.value[i];
                java.lang.String[] parts = spec.split("[:=]");
                if (parts.length < 2) {
                    throw new IllegalArgumentException("Expected data:report in --report " + spec);
                }
                java.lang.String dataName = parts[0];
                java.lang.String reportName = parts[1];
                java.lang.String reportArg = parts.length >= 3 ? parts[2] : null;

                int dataIndex = Vectors2Classify.findOption(ReportOption.dataOptions, dataName);
                if (dataIndex < 0) {
                    throw new IllegalArgumentException("Unknown argument = " + dataName + " in --report " + spec);
                }
                int reportIndex = Vectors2Classify.findOption(ReportOption.reportOptions, reportName);
                if (reportIndex < 0) {
                    throw new IllegalArgumentException("Unknown argument = " + reportName + " in --report " + spec);
                }

                Vectors2Classify.ReportOptions[dataIndex][reportIndex] = true;
                if (reportIndex == ReportOption.f1) {
                    // F1 is computed for one label, which must be named after the report.
                    if (reportArg == null) {
                        throw new IllegalArgumentException("F1 must have label argument in --report " + spec);
                    }
                    Vectors2Classify.ReportOptionArgs[dataIndex][reportIndex] = reportArg;
                } else if (reportArg != null) {
                    throw new IllegalArgumentException("No arguments after = allowed in --report " + spec);
                }
            }
        }
    };

    static CommandOption.Object trainerConstructor = new CommandOption.Object(Vectors2Classify.class, "trainer", "ClassifierTrainer constructor", true, new NaiveBayesTrainer(), "Java code for the constructor used to create a ClassifierTrainer.  If no '(' appears, then \"new \" will be prepended and \"Trainer()\" will be appended, and comma-separated name=value pairs after the trainer name call the matching setters (name=value calls setName(value)). You may use this option multiple times to compare multiple classifiers.", null) {
        /**
         * Builds the trainer from either literal Java code (when the first comma-separated field
         * contains '(') or a short name such as "NaiveBayes", then applies name=value setters.
         *
         * @throws IllegalArgumentException for a malformed name=value pair, an expression the
         *     interpreter cannot evaluate, an unknown parameter, or a failing setter call
         */
        @Override
        public void parseArg(java.lang.String arg) {
            java.lang.String[] fields = arg.split(",");
            java.lang.String trainerName = fields[0];
            if (trainerName.indexOf('(') != -1) {
                // Literal Java code may itself contain commas (constructor arguments), so it is
                // evaluated whole and never re-interpreted as name=value setter pairs.
                super.parseArg(arg);
                return;
            }
            if (trainerName.endsWith("Trainer")) {
                super.parseArg("new " + trainerName + "()");
            } else {
                super.parseArg("new " + trainerName + "Trainer()");
            }

            Method[] methods = this.value.getClass().getMethods();
            for (int i = 1; i < fields.length; i++) {
                // Limit 2 so that a value expression may itself contain '='.
                java.lang.String[] nameValue = fields[i].split("=", 2);
                if (nameValue.length != 2 || nameValue[0].length() == 0) {
                    throw new IllegalArgumentException("Expected name=value for trainer parameter \"" + fields[i] + "\" in --trainer " + arg);
                }
                java.lang.String paramName = nameValue[0];
                java.lang.Object paramValue;
                try {
                    paramValue = getInterpreter().eval(nameValue[1]);
                } catch (EvalError e) {
                    throw new IllegalArgumentException("Java interpreter eval error on parameter " + paramName + "\n" + e, e);
                }

                java.lang.String setterName = "set" + Character.toUpperCase(paramName.charAt(0)) + paramName.substring(1);
                Method setter = null;
                for (int j = 0; j < methods.length && setter == null; j++) {
                    if (methods[j].getName().equals(setterName) && methods[j].getParameterTypes().length == 1) {
                        setter = methods[j];
                    }
                }
                if (setter == null) {
                    System.out.println("Parameter " + paramName + " not found on trainer " + trainerName);
                    System.out.println("Available parameters for " + trainerName);
                    for (int j = 0; j < methods.length; j++) {
                        java.lang.String name = methods[j].getName();
                        if (name.startsWith("set") && name.length() > 3 && methods[j].getParameterTypes().length == 1) {
                            System.out.println(Character.toLowerCase(name.charAt(3)) + name.substring(4));
                        }
                    }
                    throw new IllegalArgumentException("no setter found for parameter " + paramName);
                }

                try {
                    setter.invoke(this.value, new java.lang.Object[] {paramValue});
                } catch (IllegalAccessException e) {
                    System.out.println("IllegalAccessException " + e);
                    throw new IllegalArgumentException("Java access error calling setter\n" + e, e);
                } catch (InvocationTargetException e) {
                    System.out.println("InvocationTargetException " + e);
                    throw new IllegalArgumentException("Java target error calling setter\n" + e, e);
                }
            }
        }

        /** Registers the constructed trainer so main trains and evaluates it. */
        @Override
        public void postParsing(CommandOption.List list) {
            assert this.value instanceof ClassifierTrainer;
            Vectors2Classify.classifierTrainers.add((ClassifierTrainer) this.value);
        }
    };

    static CommandOption.String outputFile = new CommandOption.String(Vectors2Classify.class, "output-classifier", "FILENAME", true, "classifier.mallet", "The filename in which to write the classifier after it has been trained.", null);
    static CommandOption.String inputFile = new CommandOption.String(Vectors2Classify.class, "input", "FILENAME", true, "text.vectors", "The filename from which to read the list of training instances.  Use - for stdin.", null);
    static CommandOption.String trainingFile = new CommandOption.String(Vectors2Classify.class, "training-file", "FILENAME", true, "text.vectors", "Read the training set instance list from this file. If this is specified, the input file parameter is ignored", null);
    static CommandOption.String testFile = new CommandOption.String(Vectors2Classify.class, "testing-file", "FILENAME", true, "text.vectors", "Read the test set instance list from this file. If this option is specified, the training-file parameter must be specified and the input-file parameter is ignored", null);
    static CommandOption.String validationFile = new CommandOption.String(Vectors2Classify.class, "validation-file", "FILENAME", true, "text.vectors", "Read the validation set instance list from this file. If this option is specified, the training-file parameter must be specified and the input-file parameter is ignored", null);
    static CommandOption.Double trainingProportionOption = new CommandOption.Double(Vectors2Classify.class, "training-portion", "DECIMAL", true, 1.0d, "The fraction of the instances that should be used for training.", null);
    static CommandOption.Double validationProportionOption = new CommandOption.Double(Vectors2Classify.class, "validation-portion", "DECIMAL", true, 0.0d, "The fraction of the instances that should be used for validation.", null);
    static CommandOption.Integer randomSeedOption = new CommandOption.Integer(Vectors2Classify.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    static CommandOption.Integer numTrialsOption = new CommandOption.Integer(Vectors2Classify.class, "num-trials", "INTEGER", true, 1, "The number of random train/test splits to perform", null);
    static CommandOption.Object classifierEvaluatorOption = new CommandOption.Object(Vectors2Classify.class, "classifier-evaluator", "CONSTRUCTOR", true, null, "Java code for constructing a ClassifierEvaluating object", null);
    static CommandOption.Integer verbosityOption = new CommandOption.Integer(Vectors2Classify.class, "verbosity", "INTEGER", true, -1, "The level of messages to print: 0 is silent, 8 is most verbose. Levels 0-8 correspond to the java.logger predefined levels off, severe, warning, info, config, fine, finer, finest, all. The default value is taken from the mallet logging.properties file, which currently defaults to INFO level (3)", null);
    static CommandOption.Boolean noOverwriteProgressMessagesOption = new CommandOption.Boolean(Vectors2Classify.class, "noOverwriteProgressMessages", "true|false", false, false, "Suppress writing-in-place on terminal for progress messages - repetitive messages of which only the latest is generally of interest", null);

    /** Names and indices of the --report data portions and report kinds. */
    private static final class ReportOption {
        // "test" was decompiled as xalan's Constants.ATTRNAME_TEST, whose value is "test".
        static final String[] dataOptions = {"train", "test", "validation"};
        static final String[] reportOptions = {"accuracy", "f1", "confusion", "raw"};
        static final int train = 0;
        static final int test = 1;
        static final int validation = 2;
        static final int accuracy = 0;
        static final int f1 = 1;
        static final int confusion = 2;
        static final int raw = 3;

        /** Heading printed before each portion's raw classifications, indexed by portion. */
        static final String[] rawTitles = {" Raw Training Data", " Raw Testing Data", " Raw Validation Data"};
        /** Suffix of the heading printed before each portion's confusion matrix. */
        static final String[] confusionTitles = {" Training Data Confusion Matrix", " Test Data Confusion Matrix", " Validation Data Confusion Matrix"};
        /** Portion name used in the per-trial accuracy and F1 lines. */
        static final String[] metricLabels = {" training data", " test data", " validation data"};
        /** Order in which per-trial metrics and final accuracy summaries are printed. */
        static final int[] metricOrder = {train, validation, test};

        private ReportOption() {
        }
    }

    /**
     * Entry point: parses options, loads the data, runs every trainer on every trial, prints the
     * requested per-trial reports and finally the requested accuracy summaries.
     *
     * @param args command-line arguments as accepted by {@link CommandOption#process}
     * @throws IllegalArgumentException for invalid option values or when a classifier cannot be
     *     written to the --output-classifier file
     */
    public static void main(String[] args) throws EvalError, IOException {
        CommandOption.setSummary(Vectors2Classify.class, "A tool for training, saving and printing diagnostics from a classifier on vectors.");
        CommandOption.process(Vectors2Classify.class, args);
        if (!trainerConstructor.wasInvoked()) {
            classifierTrainers.add(new NaiveBayesTrainer());
        }
        if (!report.wasInvoked()) {
            // Apply the default report specification when --report was not given.
            report.postParsing(null);
        }
        configureLogging();

        boolean separateFiles = testFile.wasInvoked() || trainingFile.wasInvoked() || validationFile.wasInvoked();
        InstanceList inputList = null;
        InstanceList[] fixedSplit = null;
        if (separateFiles) {
            fixedSplit = loadSeparateLists();
        } else {
            inputList = InstanceList.load(new File(inputFile.value));
        }

        int numTrials = numTrialsOption.value;
        Random random = randomSeedOption.wasInvoked() ? new Random(randomSeedOption.value) : new Random();
        ClassifierTrainer[] trainers = classifierTrainers.toArray(new ClassifierTrainer[classifierTrainers.size()]);
        for (int t = 0; t < trainers.length; t++) {
            logger.fine("Trainer specified = " + trainers[t].toString());
        }

        // accuracies[portion][trainer][trial]; portion is ReportOption.train/test/validation.
        double[][][] accuracies = new double[ReportOption.dataOptions.length][trainers.length][numTrials];

        double trainingPortion = trainingProportionOption.value;
        double validationPortion = validationProportionOption.value;
        if (!separateFiles) {
            if (trainingPortion < 0 || validationPortion < 0 || trainingPortion + validationPortion > 1.0d + PORTION_TOLERANCE) {
                throw new IllegalArgumentException("--training-portion (" + trainingPortion + ") and --validation-portion (" + validationPortion + ") must be non-negative and sum to at most 1");
            }
            logger.info("Training portion = " + trainingPortion);
            logger.info("Validation portion = " + validationPortion);
            logger.info("Testing portion = " + ((1.0d - validationPortion) - trainingPortion));
        }

        for (int trialIndex = 0; trialIndex < numTrials; trialIndex++) {
            System.out.println("\n-------------------- Trial " + trialIndex + "  --------------------\n");
            InstanceList[] split = separateFiles
                    ? fixedSplit
                    : inputList.split(random, new double[]{trainingPortion, (1.0d - trainingPortion) - validationPortion, validationPortion});
            for (int t = 0; t < trainers.length; t++) {
                runTrial(trainers[t], t, trainers.length, trialIndex, numTrials, split, accuracies);
            }
        }
        printSummaries(trainers, accuracies);
    }

    /** Applies --verbosity to the root logger and, unless disabled, the progress-message formatter. */
    private static void configureLogging() {
        Logger rootLogger = ((MalletLogger) progressLogger).getRootLogger();
        if (verbosityOption.wasInvoked()) {
            int verbosity = verbosityOption.value;
            if (verbosity < 0 || verbosity >= MalletLogger.LoggingLevels.length) {
                throw new IllegalArgumentException("--verbosity must be between 0 and " + (MalletLogger.LoggingLevels.length - 1) + ", got " + verbosity);
            }
            rootLogger.setLevel(MalletLogger.LoggingLevels[verbosity]);
        }
        if (!noOverwriteProgressMessagesOption.value) {
            Handler[] handlers = rootLogger.getHandlers();
            for (int i = 0; i < handlers.length; i++) {
                if (handlers[i] instanceof ConsoleHandler) {
                    handlers[i].setFormatter(new ProgressMessageLogFormatter());
                }
            }
        }
    }

    /**
     * Loads the lists named by --training-file, --testing-file and --validation-file.
     *
     * @return a three-element array indexed by ReportOption.train/test/validation; a list whose
     *     option was not given is an empty list built with the training list's pipe
     */
    private static InstanceList[] loadSeparateLists() throws IOException {
        InstanceList trainingList = InstanceList.load(new File(trainingFile.value));
        logger.info("Training vectors loaded from " + trainingFile.value);
        InstanceList testList = null;
        if (testFile.wasInvoked()) {
            testList = InstanceList.load(new File(testFile.value));
            logger.info("Testing vectors loaded from " + testFile.value);
        }
        InstanceList validationList = null;
        if (validationFile.wasInvoked()) {
            validationList = InstanceList.load(new File(validationFile.value));
            logger.info("validation vectors loaded from " + validationFile.value);
        }
        // Empty stand-ins let every trial evaluate all three portions without null checks.
        if (testList == null) {
            testList = new InstanceList(trainingList.getPipe());
        }
        if (validationList == null) {
            validationList = new InstanceList(trainingList.getPipe());
        }
        InstanceList[] lists = new InstanceList[ReportOption.dataOptions.length];
        lists[ReportOption.train] = trainingList;
        lists[ReportOption.test] = testList;
        lists[ReportOption.validation] = validationList;
        return lists;
    }

    /**
     * Trains {@code trainer} on the training portion of {@code split}, records the accuracy on
     * each portion in {@code accuracies}, optionally saves the classifier and prints the
     * requested per-trial reports.
     */
    private static void runTrial(ClassifierTrainer trainer, int trainerIndex, int numTrainers, int trialIndex, int numTrials, InstanceList[] split, double[][][] accuracies) {
        InstanceList trainingData = split[ReportOption.train];
        long startMillis = System.currentTimeMillis();
        System.out.println("Trial " + trialIndex + " Training " + trainer.toString() + " with " + trainingData.size() + " instances");
        Classifier classifier = trainer.train(trainingData);
        String trainerName = trainer.toString();
        System.out.println("Trial " + trialIndex + " Training " + trainerName + " finished");
        logger.fine("Trial " + trialIndex + " Training " + trainerName + " took " + (System.currentTimeMillis() - startMillis) + " ms");

        Trial[] trials = new Trial[split.length];
        for (int p = 0; p < split.length; p++) {
            trials[p] = new Trial(classifier, split[p]);
            accuracies[p][trainerIndex][trialIndex] = trials[p].accuracy();
        }

        if (outputFile.wasInvoked()) {
            // Disambiguate the file name when several trainers or trials would overwrite it.
            String filename = outputFile.value;
            if (numTrainers > 1) {
                filename = filename + trainerName;
            }
            if (numTrials > 1) {
                filename = filename + ".trial" + trialIndex;
            }
            saveClassifier(classifier, filename);
        }

        String prefix = "Trial " + trialIndex + " Trainer " + trainerName;
        for (int p = 0; p < split.length; p++) {
            if (ReportOptions[p][ReportOption.raw]) {
                System.out.println(prefix);
                System.out.println(ReportOption.rawTitles[p]);
                printTrialClassification(trials[p]);
            }
        }
        for (int k = 0; k < ReportOption.metricOrder.length; k++) {
            int p = ReportOption.metricOrder[k];
            if (ReportOptions[p][ReportOption.confusion]) {
                System.out.println(prefix + ReportOption.confusionTitles[p]);
                if (split[p].size() > 0) {
                    System.out.println(new ConfusionMatrix(trials[p]).toString());
                }
            }
            if (ReportOptions[p][ReportOption.accuracy]) {
                System.out.println(prefix + ReportOption.metricLabels[p] + " accuracy= " + accuracies[p][trainerIndex][trialIndex]);
            }
            if (ReportOptions[p][ReportOption.f1]) {
                String label = ReportOptionArgs[p][ReportOption.f1];
                System.out.println(prefix + ReportOption.metricLabels[p] + " F1(" + label + ") = " + trials[p].labelF1(label));
            }
        }
    }

    /**
     * Serializes {@code classifier} to {@code filename}, closing the file even on failure.
     *
     * @throws IllegalArgumentException (with the I/O error as cause) when writing fails
     */
    private static void saveClassifier(Classifier classifier, String filename) {
        FileOutputStream fileOut = null;
        try {
            fileOut = new FileOutputStream(filename);
            ObjectOutputStream out = new ObjectOutputStream(fileOut);
            out.writeObject(classifier);
            out.close();
            fileOut = null;
        } catch (IOException e) {
            throw new IllegalArgumentException("Couldn't write classifier to filename " + filename, e);
        } finally {
            if (fileOut != null) {
                try {
                    fileOut.close();
                } catch (IOException ignored) {
                    // The write already failed; that error is the one reported.
                }
            }
        }
    }

    /** Prints the mean/stddev/stderr of each requested accuracy across trials, per trainer. */
    private static void printSummaries(ClassifierTrainer[] trainers, double[][][] accuracies) {
        for (int t = 0; t < trainers.length; t++) {
            System.out.println("\n" + trainers[t].toString());
            for (int k = 0; k < ReportOption.metricOrder.length; k++) {
                int p = ReportOption.metricOrder[k];
                if (ReportOptions[p][ReportOption.accuracy]) {
                    double[] values = accuracies[p][t];
                    System.out.println("Summary. " + ReportOption.dataOptions[p] + " accuracy mean = " + MatrixOps.mean(values) + " stddev = " + MatrixOps.stddev(values) + " stderr = " + MatrixOps.stderr(values));
                }
            }
        }
    }

    /** Returns the index of {@code value} in {@code options}, or -1 when it is absent. */
    private static int findOption(String[] options, String value) {
        for (int i = 0; i < options.length; i++) {
            if (value.equals(options[i])) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Prints one line per classification in {@code trial}: the instance name and target, then a
     * "label:score" pair for each rank 0..numLocations()-1 of its labeling, each field followed
     * by a single space.
     */
    private static void printTrialClassification(Trial trial) {
        // The decompiled source used Jackson's MinimalPrettyPrinter.DEFAULT_ROOT_VALUE_SEPARATOR,
        // which is a single space.
        final String separator = " ";
        int count = trial.toArrayList().size();
        for (int i = 0; i < count; i++) {
            Instance instance = trial.getClassification(i).getInstance();
            Labeling labeling = trial.getClassification(i).getLabeling();
            StringBuilder line = new StringBuilder();
            line.append(instance.getName()).append(separator).append(instance.getTarget()).append(separator);
            for (int rank = 0; rank < labeling.numLocations(); rank++) {
                line.append(labeling.getLabelAtRank(rank).toString()).append(':').append(labeling.getValueAtRank(rank)).append(separator);
            }
            System.out.println(line);
        }
    }
}
