package edu.umass.cs.mallet.base.fst;

import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.pipe.iterator.LineGroupIterator;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.FeatureVectorSequence;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.LabelSequence;
import edu.umass.cs.mallet.base.types.Sequence;
import edu.umass.cs.mallet.base.util.CommandOption;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Random;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/fst/SimpleTagger.class */
public class SimpleTagger {
    private static Logger logger = MalletLogger.getLogger(SimpleTagger.class.getName());
    private static final CommandOption.Double gaussianVarianceOption = new CommandOption.Double(SimpleTagger.class, "gaussian-variance", "DECIMAL", true, 10.0d, "The gaussian prior variance used for training.", null);
    private static final CommandOption.Boolean trainOption = new CommandOption.Boolean(SimpleTagger.class, "train", "true|false", true, false, "Whether to train", null);
    private static final CommandOption.String testOption = new CommandOption.String(SimpleTagger.class, "test", "lab or seg=start-1.continue-1,...,start-n.continue-n", true, null, "Test measuring labeling or segmentation (start-i, continue-i) accuracy", null);
    private static final CommandOption.File modelOption = new CommandOption.File(SimpleTagger.class, "model-file", "FILENAME", true, null, "The filename for reading (train/run) or saving (train) the model.", null);
    private static final CommandOption.Double trainingFractionOption = new CommandOption.Double(SimpleTagger.class, "training-proportion", "DECIMAL", true, 0.5d, "Fraction of data to use for training in a random split.", null);
    private static final CommandOption.Integer randomSeedOption = new CommandOption.Integer(SimpleTagger.class, "random-seed", "INTEGER", true, 0, "The random seed for randomly selecting a proportion of the instance list for training", null);
    private static final CommandOption.IntegerArray ordersOption = new CommandOption.IntegerArray(SimpleTagger.class, "orders", "COMMA-SEP-DECIMALS", true, new int[]{1}, "List of label Markov orders (main and backoff) ", null);
    private static final CommandOption.String forbiddenOption = new CommandOption.String(SimpleTagger.class, "forbidden", "REGEXP", true, "\\s", "label1,label2 transition forbidden if it matches this", null);
    private static final CommandOption.String allowedOption = new CommandOption.String(SimpleTagger.class, "allowed", "REGEXP", true, ".*", "label1,label2 transition allowed only if it matches this", null);
    private static final CommandOption.String defaultOption = new CommandOption.String(SimpleTagger.class, "default-label", "STRING", true, "O", "Label for initial context and uninteresting tokens", null);
    private static final CommandOption.Integer iterationsOption = new CommandOption.Integer(SimpleTagger.class, "iterations", "INTEGER", true, 500, "Number of training iterations", null);
    private static final CommandOption.Boolean viterbiOutputOption = new CommandOption.Boolean(SimpleTagger.class, "viterbi-output", "true|false", true, false, "Print Viterbi periodically during training", null);
    private static final CommandOption.Boolean connectedOption = new CommandOption.Boolean(SimpleTagger.class, "fully-connected", "true|false", true, true, "Include all allowed transitions, even those not in training data", null);
    private static final CommandOption.Boolean continueTrainingOption = new CommandOption.Boolean(SimpleTagger.class, "continue-training", "true|false", false, false, "Continue training from model specified by --model-file", null);
    private static final CommandOption.Integer nBestOption = new CommandOption.Integer(SimpleTagger.class, "n-best", "INTEGER", true, 1, "How many answers to output", null);
    private static final CommandOption.Integer cacheSizeOption = new CommandOption.Integer(SimpleTagger.class, "cache-size", "INTEGER", true, 100000, "How much state information to memoize in n-best decoding", null);
    private static final CommandOption.Boolean includeInputOption = new CommandOption.Boolean(SimpleTagger.class, "include-input", "true|false", true, false, "Whether to include the input features when printing decoding output", null);
    private static final CommandOption.List commandOptions = new CommandOption.List("Training, testing and running a generic tagger.", new CommandOption[]{gaussianVarianceOption, trainOption, iterationsOption, testOption, trainingFractionOption, modelOption, randomSeedOption, ordersOption, forbiddenOption, allowedOption, defaultOption, viterbiOutputOption, connectedOption, continueTrainingOption, nBestOption, cacheSizeOption, includeInputOption});

    /* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/fst/SimpleTagger$SimpleTaggerSentence2FeatureVectorSequence.class */
    public static class SimpleTaggerSentence2FeatureVectorSequence extends Pipe {
        public SimpleTaggerSentence2FeatureVectorSequence() {
            super(Alphabet.class, LabelAlphabet.class);
        }

        /* JADX WARN: Type inference failed for: r0v4, types: [java.lang.String[], java.lang.String[][]] */
        private String[][] parseSentence(String str) {
            String[] split = str.split(IOUtils.LINE_SEPARATOR_UNIX);
            ?? r0 = new String[split.length];
            for (int i = 0; i < split.length; i++) {
                r0[i] = split[i].split(" ");
            }
            return r0;
        }

        @Override // edu.umass.cs.mallet.base.pipe.Pipe
        public Instance pipe(Instance instance) {
            String[][] strArr;
            int length;
            Object data = instance.getData();
            Alphabet dataAlphabet = getDataAlphabet();
            if (data instanceof String) {
                strArr = parseSentence((String) data);
            } else {
                if (!(data instanceof String[][])) {
                    throw new IllegalArgumentException("Not a String or String[][]; got " + data);
                }
                strArr = (String[][]) data;
            }
            FeatureVector[] featureVectorArr = new FeatureVector[strArr.length];
            LabelSequence labelSequence = isTargetProcessing() ? new LabelSequence((LabelAlphabet) getTargetAlphabet(), strArr.length) : null;
            for (int i = 0; i < strArr.length; i++) {
                if (!isTargetProcessing()) {
                    length = strArr[i].length;
                } else {
                    if (strArr[i].length < 1) {
                        throw new IllegalStateException("Missing label at line " + i + " instance " + instance.getName());
                    }
                    length = strArr[i].length - 1;
                    labelSequence.add(strArr[i][length]);
                }
                int[] iArr = new int[length];
                for (int i2 = 0; i2 < length; i2++) {
                    iArr[i2] = dataAlphabet.lookupIndex(strArr[i][i2]);
                }
                featureVectorArr[i] = new FeatureVector(dataAlphabet, iArr);
            }
            instance.setData(new FeatureVectorSequence(featureVectorArr));
            if (isTargetProcessing()) {
                instance.setTarget(labelSequence);
            }
            return instance;
        }
    }

    private SimpleTagger() {
    }

    public static CRF4 train(InstanceList instanceList, InstanceList instanceList2, TransducerEvaluator transducerEvaluator, int[] iArr, String str, String str2, String str3, boolean z, int i, double d, CRF4 crf4) {
        Pattern compile = Pattern.compile(str2);
        Pattern compile2 = Pattern.compile(str3);
        if (crf4 == null) {
            crf4 = new CRF4(instanceList.getPipe(), (Pipe) null);
            String addOrderNStates = crf4.addOrderNStates(instanceList, iArr, null, str, compile, compile2, z);
            crf4.setGaussianPriorVariance(d);
            for (int i2 = 0; i2 < crf4.numStates(); i2++) {
                crf4.getState(i2).setInitialCost(Double.POSITIVE_INFINITY);
            }
            crf4.getState(addOrderNStates).setInitialCost(Transducer.ZERO_COST);
        }
        logger.info("Training on " + instanceList.size() + " instances");
        if (instanceList2 != null) {
            logger.info("Testing on " + instanceList2.size() + " instances");
        }
        crf4.train(instanceList, null, instanceList2, transducerEvaluator, i);
        return crf4;
    }

    public static void test(Transducer transducer, TransducerEvaluator transducerEvaluator, InstanceList instanceList) {
        transducerEvaluator.test(transducer, instanceList, "Testing", null);
    }

    public static Sequence[] apply(Transducer transducer, Sequence sequence, int i) {
        return i == 1 ? new Sequence[]{transducer.transduce(sequence)} : transducer.getViterbiLattice(sequence, null, cacheSizeOption.value()).outputNBest(i);
    }

    public static void main(String[] strArr) throws Exception {
        Pipe inputPipe;
        FileReader fileReader = null;
        FileReader fileReader2 = null;
        InstanceList instanceList = null;
        InstanceList instanceList2 = null;
        int processOptions = commandOptions.processOptions(strArr);
        if (processOptions == strArr.length) {
            commandOptions.printUsage(true);
            throw new IllegalArgumentException("Missing data file(s)");
        }
        if (trainOption.value) {
            fileReader = new FileReader(new File(strArr[processOptions]));
            if (testOption.value != null && processOptions < strArr.length - 1) {
                fileReader2 = new FileReader(new File(strArr[processOptions + 1]));
            }
        } else {
            fileReader2 = new FileReader(new File(strArr[processOptions]));
        }
        CRF4 crf4 = null;
        TransducerEvaluator transducerEvaluator = null;
        if (!continueTrainingOption.value && trainOption.value) {
            inputPipe = new SimpleTaggerSentence2FeatureVectorSequence();
            inputPipe.getTargetAlphabet().lookupIndex(defaultOption.value);
        } else {
            if (modelOption.value == null) {
                commandOptions.printUsage(true);
                throw new IllegalArgumentException("Missing model file option");
            }
            ObjectInputStream objectInputStream = new ObjectInputStream(new FileInputStream(modelOption.value));
            crf4 = (CRF4) objectInputStream.readObject();
            objectInputStream.close();
            inputPipe = crf4.getInputPipe();
        }
        if (testOption.value != null) {
            if (testOption.value.startsWith("lab")) {
                transducerEvaluator = new TokenAccuracyEvaluator(viterbiOutputOption.value);
            } else {
                if (!testOption.value.startsWith("seg=")) {
                    commandOptions.printUsage(true);
                    throw new IllegalArgumentException("Invalid test option: " + testOption.value);
                }
                String[] split = testOption.value.substring(4).split(",");
                if (split.length < 1) {
                    commandOptions.printUsage(true);
                    throw new IllegalArgumentException("Missing segment start/continue labels: " + testOption.value);
                }
                String[] strArr2 = new String[split.length];
                String[] strArr3 = new String[split.length];
                for (int i = 0; i < split.length; i++) {
                    String[] split2 = split[i].split("\\.");
                    if (split2.length != 2) {
                        commandOptions.printUsage(true);
                        throw new IllegalArgumentException("Incorrectly-specified segment start and end labels: " + split[i]);
                    }
                    strArr2[i] = split2[0];
                    strArr3[i] = split2[1];
                }
                transducerEvaluator = new MultiSegmentationEvaluator(strArr2, strArr3, viterbiOutputOption.value);
            }
        }
        if (trainOption.value) {
            inputPipe.setTargetProcessing(true);
            instanceList = new InstanceList(inputPipe);
            instanceList.add(new LineGroupIterator(fileReader, Pattern.compile("^\\s*$"), true));
            logger.info("Number of features in training data: " + inputPipe.getDataAlphabet().size());
            if (testOption.value != null) {
                if (fileReader2 != null) {
                    instanceList2 = new InstanceList(inputPipe);
                    instanceList2.add(new LineGroupIterator(fileReader2, Pattern.compile("^\\s*$"), true));
                } else {
                    InstanceList[] split3 = instanceList.split(new Random(randomSeedOption.value), new double[]{trainingFractionOption.value, 1.0d - trainingFractionOption.value});
                    instanceList = split3[0];
                    instanceList2 = split3[1];
                }
            }
        } else if (testOption.value != null) {
            inputPipe.setTargetProcessing(true);
            instanceList2 = new InstanceList(inputPipe);
            instanceList2.add(new LineGroupIterator(fileReader2, Pattern.compile("^\\s*$"), true));
        } else {
            inputPipe.setTargetProcessing(false);
            instanceList2 = new InstanceList(inputPipe);
            instanceList2.add(new LineGroupIterator(fileReader2, Pattern.compile("^\\s*$"), true));
        }
        logger.info("Number of predicates: " + inputPipe.getDataAlphabet().size());
        if (inputPipe.isTargetProcessing()) {
            Alphabet targetAlphabet = inputPipe.getTargetAlphabet();
            StringBuffer stringBuffer = new StringBuffer("Labels:");
            for (int i2 = 0; i2 < targetAlphabet.size(); i2++) {
                stringBuffer.append(" ").append(targetAlphabet.lookupObject(i2).toString());
            }
            logger.info(stringBuffer.toString());
        }
        if (trainOption.value) {
            CRF4 train = train(instanceList, instanceList2, transducerEvaluator, ordersOption.value, defaultOption.value, forbiddenOption.value, allowedOption.value, connectedOption.value, iterationsOption.value, gaussianVarianceOption.value, crf4);
            if (modelOption.value != null) {
                ObjectOutputStream objectOutputStream = new ObjectOutputStream(new FileOutputStream(modelOption.value));
                objectOutputStream.writeObject(train);
                objectOutputStream.close();
                return;
            }
            return;
        }
        if (crf4 == null) {
            if (modelOption.value == null) {
                commandOptions.printUsage(true);
                throw new IllegalArgumentException("Missing model file option");
            }
            ObjectInputStream objectInputStream2 = new ObjectInputStream(new FileInputStream(modelOption.value));
            crf4 = (CRF4) objectInputStream2.readObject();
            objectInputStream2.close();
        }
        if (transducerEvaluator != null) {
            test(crf4, transducerEvaluator, instanceList2);
            return;
        }
        boolean value = includeInputOption.value();
        for (int i3 = 0; i3 < instanceList2.size(); i3++) {
            Sequence sequence = (Sequence) instanceList2.getInstance(i3).getData();
            Sequence[] apply = apply(crf4, sequence, nBestOption.value);
            int length = apply.length;
            boolean z = false;
            for (int i4 = 0; i4 < length; i4++) {
                if (apply[i4].size() != sequence.size()) {
                    System.err.println("Failed to decode input sequence " + i3 + ", answer " + i4);
                    z = true;
                }
            }
            if (!z) {
                for (int i5 = 0; i5 < sequence.size(); i5++) {
                    StringBuffer stringBuffer2 = new StringBuffer();
                    for (Sequence sequence2 : apply) {
                        stringBuffer2.append(sequence2.get(i5).toString()).append(" ");
                    }
                    if (value) {
                        stringBuffer2.append(((FeatureVector) sequence.get(i5)).toString(true));
                    }
                    System.out.println(stringBuffer2.toString());
                }
                System.out.println();
            }
        }
    }
}
