package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.maximize.LimitedMemoryBFGS;
import edu.umass.cs.mallet.base.maximize.Maximizable;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.ExpGain;
import edu.umass.cs.mallet.base.types.FeatureInducer;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.GradientGain;
import edu.umass.cs.mallet.base.types.InfoGain;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Label;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.Labeling;
import edu.umass.cs.mallet.base.types.MatrixOps;
import edu.umass.cs.mallet.base.types.RankedFeatureVector;
import edu.umass.cs.mallet.base.util.CommandOption;
import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.MalletProgressMessageLogger;
import edu.umass.cs.mallet.base.util.Maths;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.logging.Logger;
import org.apache.commons.io.IOUtils;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/classify/MaxEntTrainer.class */
/**
 * Trains {@link MaxEnt} classifiers by maximizing the instance-weighted conditional
 * log-likelihood of the training labels, minus a prior penalty on the parameters, with
 * {@link LimitedMemoryBFGS}.
 *
 * <p>Two priors are supported: a Gaussian prior (the default; see
 * {@link #setGaussianPriorVariance}) and a hyperbolic prior, which is close to an L1 penalty
 * (see {@link #setHyperbolicPriorSlope} and {@link #setHyperbolicPriorSharpness}).
 * {@link #trainWithFeatureInduction} interleaves training with rounds of feature induction
 * driven by the training instances that the current model misclassifies.
 */
public class MaxEntTrainer extends ClassifierTrainer implements Boostable, Serializable {

    /** Gain name accepted by {@code trainWithFeatureInduction}: uses {@code ExpGain}. */
    public static final String EXP_GAIN = "exp";
    /** Gain name accepted by {@code trainWithFeatureInduction}: uses {@code GradientGain}. */
    public static final String GRADIENT_GAIN = "grad";
    /** Gain name accepted by {@code trainWithFeatureInduction}: uses {@code InfoGain}. */
    public static final String INFORMATION_GAIN = "info";

    static final double DEFAULT_GAUSSIAN_PRIOR_VARIANCE = 1.0d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SLOPE = 0.2d;
    static final double DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS = 10.0d;

    private static final Logger logger = MalletLogger.getLogger(MaxEntTrainer.class.getName());
    private static final Logger progressLogger = MalletProgressMessageLogger.getLogger(MaxEntTrainer.class.getName() + "-pl");
    static final Class<?> DEFAULT_MAXIMIZER_CLASS = LimitedMemoryBFGS.class;

    static CommandOption.Boolean usingMultiConditionalTrainingOption = new CommandOption.Boolean(MaxEntTrainer.class, "useMCTraining", "true|false", false, false, "Use MultiConditional Training", null);
    static CommandOption.Boolean usingHyperbolicPriorOption = new CommandOption.Boolean(MaxEntTrainer.class, "useHyperbolicPrior", "true|false", false, false, "Use hyperbolic (close to L1 penalty) prior over parameters", null);
    // The command-line default variance is 10.0; note that this differs from
    // DEFAULT_GAUSSIAN_PRIOR_VARIANCE (1.0), the default used by the other constructors.
    static CommandOption.Double gaussianPriorVarianceOption = new CommandOption.Double(MaxEntTrainer.class, "gaussianPriorVariance", "FLOAT", true, 10.0d, "Variance of the gaussian prior over parameters", null);
    static CommandOption.Double hyperbolicPriorSlopeOption = new CommandOption.Double(MaxEntTrainer.class, "hyperbolicPriorSlope", "FLOAT", true, DEFAULT_HYPERBOLIC_PRIOR_SLOPE, "Slope of the (L1 penalty) hyperbolic prior over parameters", null);
    static CommandOption.Double hyperbolicPriorSharpnessOption = new CommandOption.Double(MaxEntTrainer.class, "hyperbolicPriorSharpness", "FLOAT", true, DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS, "Sharpness of the (L1 penalty) hyperbolic prior over parameters", null);
    static final CommandOption.List commandOptions = new CommandOption.List("Maximum Entropy Classifier", new CommandOption[]{usingHyperbolicPriorOption, gaussianPriorVarianceOption, hyperbolicPriorSlopeOption, hyperbolicPriorSharpnessOption, usingMultiConditionalTrainingOption});

    /** Number of objective-value evaluations performed so far (diagnostics only). */
    int numGetValueCalls = 0;
    /** Number of gradient evaluations performed so far (diagnostics only). */
    int numGetValueGradientCalls = 0;
    /**
     * Iteration budget set through {@link #setNumIterations}.
     * NOTE(review): {@link #train} does not currently read this value.
     */
    int numIterations = 10;
    /** When true, the gradient receives the additional multi-conditional training terms. */
    boolean usingMultiConditionalTraining = false;
    /** True selects the hyperbolic prior, false the Gaussian prior. */
    boolean usingHyperbolicPrior = false;
    double gaussianPriorVariance = DEFAULT_GAUSSIAN_PRIOR_VARIANCE;
    double hyperbolicPriorSlope = DEFAULT_HYPERBOLIC_PRIOR_SLOPE;
    double hyperbolicPriorSharpness = DEFAULT_HYPERBOLIC_PRIOR_SHARPNESS;
    /** NOTE(review): not consulted by {@link #train}, which always uses LimitedMemoryBFGS. */
    Class<?> maximizerClass = DEFAULT_MAXIMIZER_CLASS;

    /**
     * Exposes MaxEnt training as a {@link Maximizable.ByGradient} objective.
     *
     * <p>The value is the instance-weighted log-likelihood of each instance's best label minus the
     * prior penalty; the gradient is its derivative with respect to the parameters. Parameters are
     * laid out row-major as {@code parameters[label * numFeatures + feature]}, where the last
     * column ({@code defaultFeatureIndex}) is a default feature that contributes for every
     * instance.
     */
    public class MaximizableTrainer implements Maximizable.ByGradient {
        /** Model weights, indexed {@code label * numFeatures + feature}. */
        double[] parameters;
        /** Observed (empirical) feature counts per label, weighted by instance weight. */
        double[] constraints;
        /** Gradient buffer: getValue() fills in expectation terms, getValueGradient() completes it. */
        double[] cachedGradient;
        MaxEnt theClassifier;
        InstanceList trainingList;
        double cachedValue;
        boolean cachedValueStale;
        boolean cachedGradientStale;
        int numLabels;
        /** Size of the data alphabet plus one column for the default feature. */
        int numFeatures;
        int defaultFeatureIndex;
        FeatureSelection featureSelection;
        FeatureSelection[] perLabelFeatureSelection;

        /** Creates an empty, unusable objective; only for callers that populate the fields themselves. */
        public MaximizableTrainer() {
        }

        /**
         * Builds the objective over {@code instanceList} and precomputes the observed counts.
         *
         * <p>Side effects: the list's label alphabet stops growing, and the default feature index is
         * added to the list's feature selection(s), if any.
         *
         * @param instanceList training data; its target alphabet is cast to {@link LabelAlphabet}
         *     and each instance's data to {@link FeatureVector}
         * @param maxEnt classifier to continue training from (its parameters and feature selections
         *     are adopted), or null to start from all-zero parameters
         */
        public MaximizableTrainer(InstanceList instanceList, MaxEnt maxEnt) {
            this.trainingList = instanceList;
            Alphabet dataAlphabet = instanceList.getDataAlphabet();
            LabelAlphabet labelAlphabet = (LabelAlphabet) instanceList.getTargetAlphabet();
            labelAlphabet.stopGrowth();
            this.numLabels = labelAlphabet.size();
            this.numFeatures = dataAlphabet.size() + 1;
            this.defaultFeatureIndex = this.numFeatures - 1;
            // New Java arrays are zero-filled, so all three buffers start at 0.0.
            int size = this.numLabels * this.numFeatures;
            this.parameters = new double[size];
            this.constraints = new double[size];
            this.cachedGradient = new double[size];

            this.featureSelection = instanceList.getFeatureSelection();
            this.perLabelFeatureSelection = instanceList.getPerLabelFeatureSelection();
            if (this.featureSelection != null) {
                this.featureSelection.add(this.defaultFeatureIndex);
            }
            if (this.perLabelFeatureSelection != null) {
                for (FeatureSelection selection : this.perLabelFeatureSelection) {
                    selection.add(this.defaultFeatureIndex);
                }
            }
            assert this.featureSelection == null || this.perLabelFeatureSelection == null;

            if (maxEnt != null) {
                this.theClassifier = maxEnt;
                this.parameters = maxEnt.parameters;
                this.featureSelection = maxEnt.featureSelection;
                this.perLabelFeatureSelection = maxEnt.perClassFeatureSelection;
                this.defaultFeatureIndex = maxEnt.defaultFeatureIndex;
                assert maxEnt.getInstancePipe() == instanceList.getPipe();
            } else {
                this.theClassifier = new MaxEnt(instanceList.getPipe(), this.parameters, this.featureSelection, this.perLabelFeatureSelection);
            }
            this.cachedValueStale = true;
            this.cachedGradientStale = true;

            InstanceList.Iterator iter = this.trainingList.iterator();
            logger.fine("Number of instances in training list = " + this.trainingList.size());
            while (iter.hasNext()) {
                // The weight is read before nextInstance(), matching the iterator's usage elsewhere.
                double instanceWeight = iter.getInstanceWeight();
                Instance instance = iter.nextInstance();
                Labeling labeling = instance.getLabeling();
                FeatureVector fv = (FeatureVector) instance.getData();
                Alphabet fvAlphabet = fv.getAlphabet();
                assert fvAlphabet == dataAlphabet;
                int correctLabel = labeling.getBestIndex();
                assert !Double.isNaN(instanceWeight) : "instanceWeight is NaN";
                // Observed counts: add this instance's features to its best label's row.
                MatrixOps.rowPlusEquals(this.constraints, this.numFeatures, correctLabel, fv, instanceWeight);
                logNaNFeatures(instance, fv, fvAlphabet);
                // The default feature contributes a count of 1.0 (times the weight) for every instance.
                this.constraints[correctLabel * this.numFeatures + this.defaultFeatureIndex] += instanceWeight;
            }
        }

        /** Logs, at INFO level, each NaN-valued feature of {@code instance}, then the instance name. */
        private void logNaNFeatures(Instance instance, FeatureVector fv, Alphabet alphabet) {
            boolean sawNaN = false;
            for (int loc = 0; loc < fv.numLocations(); loc++) {
                if (Double.isNaN(fv.valueAtLocation(loc))) {
                    logger.info("NaN for feature " + alphabet.lookupObject(fv.indexAtLocation(loc)).toString());
                    sawNaN = true;
                }
            }
            if (sawNaN) {
                logger.info("NaN in instance: " + instance.getName());
            }
        }

        /** Returns the classifier whose parameters this objective optimizes. */
        public MaxEnt getClassifier() {
            return this.theClassifier;
        }

        @Override
        public double getParameter(int i) {
            return this.parameters[i];
        }

        @Override
        public void setParameter(int i, double d) {
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            this.parameters[i] = d;
        }

        @Override
        public int getNumParameters() {
            return this.parameters.length;
        }

        /**
         * Copies the current parameters into {@code buffer}.
         *
         * @param buffer destination array; must be non-null with length {@link #getNumParameters()}
         * @throws IllegalArgumentException if {@code buffer} is null or has the wrong length
         *     (previously such a call silently copied into a discarded array)
         */
        @Override
        public void getParameters(double[] buffer) {
            if (buffer == null || buffer.length != this.parameters.length) {
                throw new IllegalArgumentException("Parameter buffer must have length " + this.parameters.length
                        + " but was " + (buffer == null ? "null" : String.valueOf(buffer.length)));
            }
            System.arraycopy(this.parameters, 0, buffer, 0, this.parameters.length);
        }

        /**
         * Replaces the parameters with a copy of {@code newParameters} and invalidates the caches.
         * NOTE(review): if the length differs, a new array is allocated; a classifier sharing the
         * old array would then not see the new values.
         */
        @Override
        public void setParameters(double[] newParameters) {
            assert newParameters != null;
            this.cachedValueStale = true;
            this.cachedGradientStale = true;
            if (newParameters.length != this.parameters.length) {
                this.parameters = new double[newParameters.length];
            }
            System.arraycopy(newParameters, 0, this.parameters, 0, newParameters.length);
        }

        /**
         * Returns the instance-weighted log-likelihood of the best labels minus the prior penalty,
         * recomputing it only when the parameters changed.
         *
         * <p>As a side effect, recomputation refills {@code cachedGradient} with the negated,
         * instance-weighted expected feature counts (plus the multi-conditional terms when enabled),
         * which {@link #getValueGradient} then completes.
         */
        @Override
        public double getValue() {
            if (this.cachedValueStale) {
                MaxEntTrainer.this.numGetValueCalls++;
                this.cachedValue = 0.0;
                this.cachedGradientStale = true;
                MatrixOps.setAll(this.cachedGradient, 0.0);
                double[] scores = new double[this.trainingList.getTargetAlphabet().size()];
                InstanceList.Iterator iter = this.trainingList.iterator();
                while (iter.hasNext()) {
                    double instanceWeight = iter.getInstanceWeight();
                    Instance instance = iter.nextInstance();
                    Labeling labeling = instance.getLabeling();
                    this.theClassifier.getClassificationScores(instance, scores);
                    FeatureVector fv = (FeatureVector) instance.getData();
                    int correctLabel = labeling.getBestIndex();
                    // Weighted negative log-probability of the best label.
                    double nll = -(instanceWeight * Math.log(scores[correctLabel]));
                    if (Double.isNaN(nll)) {
                        logger.fine("MaxEntTrainer: Instance " + instance.getName() + " has NaN value. log(scores)= " + Math.log(scores[correctLabel]) + " scores = " + scores[correctLabel] + " has instance weight = " + instanceWeight);
                    }
                    if (Double.isInfinite(nll)) {
                        // NOTE(review): despite the message, this does not skip the instance. It aborts
                        // the pass, caches and returns the negated infinite term, and leaves
                        // cachedGradient only partially accumulated (cachedGradientStale stays true).
                        logger.warning("Instance " + instance.getSource() + " has infinite value; skipping value and gradient");
                        this.cachedValue -= nll;
                        this.cachedValueStale = false;
                        return -nll;
                    }
                    this.cachedValue += nll;
                    // Expected counts: subtract each label's probability-weighted feature vector.
                    for (int li = 0; li < scores.length; li++) {
                        if (scores[li] != 0.0) {
                            assert !Double.isInfinite(scores[li]);
                            double weight = -instanceWeight * scores[li];
                            MatrixOps.rowPlusEquals(this.cachedGradient, this.numFeatures, li, fv, weight);
                            this.cachedGradient[this.numFeatures * li + this.defaultFeatureIndex] += weight;
                        }
                    }
                    if (MaxEntTrainer.this.usingMultiConditionalTraining) {
                        int rowStart = this.numFeatures * correctLabel;
                        for (int fi = 0; fi < this.numFeatures; fi++) {
                            this.cachedGradient[rowStart + fi] += -instanceWeight * Math.exp(this.parameters[rowStart + fi]);
                        }
                    }
                }
                // Prior penalty, accumulated term by term into the (still negated) value.
                if (MaxEntTrainer.this.usingHyperbolicPrior) {
                    double slope = MaxEntTrainer.this.hyperbolicPriorSlope;
                    double sharpness = MaxEntTrainer.this.hyperbolicPriorSharpness;
                    for (int li = 0; li < this.numLabels; li++) {
                        for (int fi = 0; fi < this.numFeatures; fi++) {
                            this.cachedValue += (slope / sharpness) * Math.log(Maths.cosh(sharpness * this.parameters[(li * this.numFeatures) + fi]));
                        }
                    }
                } else {
                    for (int li = 0; li < this.numLabels; li++) {
                        for (int fi = 0; fi < this.numFeatures; fi++) {
                            double w = this.parameters[(li * this.numFeatures) + fi];
                            this.cachedValue += (w * w) / (2.0d * MaxEntTrainer.this.gaussianPriorVariance);
                        }
                    }
                }
                // Flip from (negative log-likelihood + penalty) to the value being maximized.
                this.cachedValue = -this.cachedValue;
                this.cachedValueStale = false;
                progressLogger.info("Value (loglikelihood) = " + this.cachedValue);
            }
            return this.cachedValue;
        }

        /**
         * Copies the gradient of {@link #getValue()} into {@code buffer}, recomputing it if stale.
         *
         * <p>gradient = observed counts - expected counts - d(prior penalty)/d(parameter). For the
         * Gaussian prior the last term is {@code parameter / variance}. For the hyperbolic prior it
         * is {@code slope * tanh(sharpness * parameter)}, which is the derivative of the
         * {@code (slope / sharpness) * log(cosh(sharpness * parameter))} penalty used in getValue().
         *
         * @param buffer destination array of length {@link #getNumParameters()}
         */
        @Override
        public void getValueGradient(double[] buffer) {
            if (this.cachedGradientStale) {
                MaxEntTrainer.this.numGetValueGradientCalls++;
                if (this.cachedValueStale) {
                    getValue();
                }
                MatrixOps.plusEquals(this.cachedGradient, this.constraints);
                if (MaxEntTrainer.this.usingMultiConditionalTraining) {
                    // Multi-conditional training adds the observed counts a second time.
                    MatrixOps.plusEquals(this.cachedGradient, this.constraints);
                }
                if (MaxEntTrainer.this.usingHyperbolicPrior) {
                    double slope = MaxEntTrainer.this.hyperbolicPriorSlope;
                    double sharpness = MaxEntTrainer.this.hyperbolicPriorSharpness;
                    for (int li = 0; li < this.numLabels; li++) {
                        for (int fi = 0; fi < this.numFeatures; fi++) {
                            int k = (li * this.numFeatures) + fi;
                            this.cachedGradient[k] -= slope * Math.tanh(sharpness * this.parameters[k]);
                        }
                    }
                } else {
                    MatrixOps.plusEquals(this.cachedGradient, this.parameters, -1.0d / MaxEntTrainer.this.gaussianPriorVariance);
                }
                MatrixOps.substitute(this.cachedGradient, Double.NEGATIVE_INFINITY, 0.0);
                // Presumably zeroes the gradient of features outside the (per-label) selection — TODO
                // confirm against MatrixOps.rowSetAll.
                if (this.perLabelFeatureSelection == null) {
                    for (int li = 0; li < this.numLabels; li++) {
                        MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, li, 0.0, this.featureSelection, false);
                    }
                } else {
                    for (int li = 0; li < this.numLabels; li++) {
                        MatrixOps.rowSetAll(this.cachedGradient, this.numFeatures, li, 0.0, this.perLabelFeatureSelection[li], false);
                    }
                }
                this.cachedGradientStale = false;
            }
            assert buffer != null && buffer.length == this.parameters.length;
            System.arraycopy(this.cachedGradient, 0, buffer, 0, this.cachedGradient.length);
        }
    }

    /** Returns the command-line options understood by {@link #MaxEntTrainer(CommandOption.List)}. */
    public static CommandOption.List getCommandOptionList() {
        return commandOptions;
    }

    /**
     * Creates a trainer configured from the static command-line option fields.
     *
     * @param list not read; the values come from the option fields listed in {@link #commandOptions}
     */
    public MaxEntTrainer(CommandOption.List list) {
        this.usingHyperbolicPrior = usingHyperbolicPriorOption.value;
        this.gaussianPriorVariance = gaussianPriorVarianceOption.value;
        this.hyperbolicPriorSlope = hyperbolicPriorSlopeOption.value;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpnessOption.value;
        this.usingMultiConditionalTraining = usingMultiConditionalTrainingOption.value;
    }

    /** Creates a trainer using a Gaussian prior with the default variance. */
    public MaxEntTrainer() {
        this(false);
    }

    /**
     * @param useHyperbolicPrior true for the hyperbolic prior with default slope and sharpness,
     *     false for the Gaussian prior with the default variance
     */
    public MaxEntTrainer(boolean useHyperbolicPrior) {
        this.usingHyperbolicPrior = useHyperbolicPrior;
    }

    /** @param gaussianPriorVariance variance of the Gaussian prior */
    public MaxEntTrainer(double gaussianPriorVariance) {
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * @param gaussianPriorVariance variance of the Gaussian prior
     * @param useMultiConditionalTraining whether to add the multi-conditional gradient terms
     */
    public MaxEntTrainer(double gaussianPriorVariance, boolean useMultiConditionalTraining) {
        this.usingMultiConditionalTraining = useMultiConditionalTraining;
        this.gaussianPriorVariance = gaussianPriorVariance;
    }

    /**
     * Creates a trainer using the hyperbolic prior.
     *
     * @param hyperbolicPriorSlope slope of the hyperbolic prior
     * @param hyperbolicPriorSharpness sharpness of the hyperbolic prior
     */
    public MaxEntTrainer(double hyperbolicPriorSlope, double hyperbolicPriorSharpness) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = hyperbolicPriorSlope;
        this.hyperbolicPriorSharpness = hyperbolicPriorSharpness;
    }

    /** Returns the training objective for {@code instanceList}, or an empty one when it is null. */
    public Maximizable.ByGradient getMaximizableTrainer(InstanceList instanceList) {
        return instanceList == null ? new MaximizableTrainer() : new MaximizableTrainer(instanceList, null);
    }

    /** Sets the iteration budget (see the note on {@link #numIterations}); returns this trainer. */
    public MaxEntTrainer setNumIterations(int numIterations) {
        this.numIterations = numIterations;
        return this;
    }

    /** Selects the hyperbolic (true) or Gaussian (false) prior; returns this trainer. */
    public MaxEntTrainer setUseHyperbolicPrior(boolean useHyperbolicPrior) {
        this.usingHyperbolicPrior = useHyperbolicPrior;
        return this;
    }

    /** Sets the Gaussian prior variance and switches to the Gaussian prior; returns this trainer. */
    public MaxEntTrainer setGaussianPriorVariance(double variance) {
        this.usingHyperbolicPrior = false;
        this.gaussianPriorVariance = variance;
        return this;
    }

    /** Sets the hyperbolic prior slope and switches to the hyperbolic prior; returns this trainer. */
    public MaxEntTrainer setHyperbolicPriorSlope(double slope) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSlope = slope;
        return this;
    }

    /** Sets the hyperbolic prior sharpness and switches to the hyperbolic prior; returns this trainer. */
    public MaxEntTrainer setHyperbolicPriorSharpness(double sharpness) {
        this.usingHyperbolicPrior = true;
        this.hyperbolicPriorSharpness = sharpness;
        return this;
    }

    /**
     * Trains a MaxEnt classifier on {@code trainingSet} with limited-memory BFGS.
     *
     * <p>NOTE(review): only {@code trainingSet} and {@code initialClassifier} are used. The
     * validation and test sets, the evaluator and {@link #numIterations} are ignored, and the
     * optimizer runs until {@code LimitedMemoryBFGS.maximize} returns.
     *
     * @param initialClassifier a {@link MaxEnt} to continue training from (it is updated and
     *     returned), or null to start from zero parameters
     * @return the trained classifier
     * @throws ClassCastException if {@code initialClassifier} is non-null and not a MaxEnt
     */
    @Override
    public Classifier train(InstanceList trainingSet, InstanceList validationSet, InstanceList testSet, ClassifierEvaluating evaluator, Classifier initialClassifier) {
        logger.fine("trainingSet.size() = " + trainingSet.size());
        MaximizableTrainer objective = new MaximizableTrainer(trainingSet, (MaxEnt) initialClassifier);
        new LimitedMemoryBFGS().maximize(objective);
        logger.info("MaxEnt ngetValueCalls:" + getValueCalls() + "\nMaxEnt ngetValueGradientCalls:" + getValueGradientCalls());
        // Terminate the progress output with a newline.
        progressLogger.info("\n");
        return objective.getClassifier();
    }

    /**
     * Feature-induction training starting from zero parameters and using {@link #EXP_GAIN}.
     *
     * @see #trainWithFeatureInduction(InstanceList, InstanceList, InstanceList, ClassifierEvaluating,
     *     MaxEnt, int, int, int, int, String)
     */
    public Classifier trainWithFeatureInduction(InstanceList trainingSet, InstanceList validationSet, InstanceList testSet, ClassifierEvaluating evaluator, int totalIterations, int iterationsBetweenInductions, int numFeatureInductions, int featuresPerInduction) {
        return trainWithFeatureInduction(trainingSet, validationSet, testSet, evaluator, null, totalIterations, iterationsBetweenInductions, numFeatureInductions, featuresPerInduction, EXP_GAIN);
    }

    /**
     * Alternates training with rounds of feature induction, then trains a final time.
     *
     * <p>In each of {@code numFeatureInductions} rounds:
     * <ol>
     *   <li>the model is trained, except in the first round;</li>
     *   <li>the misclassified training instances are collected;</li>
     *   <li>new features are induced from them into the training set (and the test set, if given);</li>
     *   <li>the model's parameters are reset to zero for the enlarged alphabet, discarding previously
     *       learned values.</li>
     * </ol>
     * A single feature selection is shared by the training, validation and test sets.
     *
     * @param maxEnt starting classifier, or null for zero parameters
     * @param totalIterations overall iteration budget; the remainder is passed to the final training
     * @param iterationsBetweenInductions iteration budget per round
     * @param numFeatureInductions number of induction rounds
     * @param featuresPerInduction passed to {@link FeatureInducer} as the number of features to add
     * @param gainName one of {@link #EXP_GAIN}, {@link #GRADIENT_GAIN}, {@link #INFORMATION_GAIN}
     * @throws IllegalArgumentException if {@code gainName} is not one of the supported names
     */
    public Classifier trainWithFeatureInduction(InstanceList trainingSet, InstanceList validationSet, InstanceList testSet, ClassifierEvaluating evaluator, MaxEnt maxEnt, int totalIterations, int iterationsBetweenInductions, int numFeatureInductions, int featuresPerInduction, String gainName) {
        Alphabet dataAlphabet = trainingSet.getDataAlphabet();
        Alphabet targetAlphabet = trainingSet.getTargetAlphabet();
        if (maxEnt == null) {
            maxEnt = new MaxEnt(trainingSet.getPipe(), new double[(1 + dataAlphabet.size()) * targetAlphabet.size()]);
        }
        int trainingIterations = 0;
        int numLabels = targetAlphabet.size();
        FeatureSelection featureSelection = shareFeatureSelection(trainingSet, validationSet, testSet);
        MaxEnt current = new MaxEnt(maxEnt.getInstancePipe(), maxEnt.getParameters(), featureSelection);

        for (int round = 0; round < numFeatureInductions; round++) {
            logger.info("Feature induction iteration " + round);
            // Do not train until some features have been induced.
            if (round != 0) {
                setNumIterations(iterationsBetweenInductions);
                current = (MaxEnt) train(trainingSet, validationSet, testSet, evaluator, current);
            }
            trainingIterations += iterationsBetweenInductions;
            logger.info("Starting feature induction with " + (1 + dataAlphabet.size()) + " features over " + numLabels + " labels.");

            // Collect the misclassified training instances and their predicted label distributions.
            InstanceList errorInstances = new InstanceList(trainingSet.getDataAlphabet(), trainingSet.getTargetAlphabet());
            errorInstances.setFeatureSelection(featureSelection);
            ArrayList<LabelVector> errorLabelings = new ArrayList<LabelVector>();
            for (int ii = 0; ii < trainingSet.size(); ii++) {
                Instance instance = trainingSet.getInstance(ii);
                FeatureVector fv = (FeatureVector) instance.getData();
                Label label = (Label) instance.getTarget();
                Classification classification = current.classify(instance);
                if (!classification.bestLabelIsCorrect()) {
                    errorInstances.add(fv, label, null, null);
                    errorLabelings.add(classification.getLabelVector());
                }
            }
            logger.info("Error instance list size = " + errorInstances.size());
            LabelVector[] errorLabelingArray = errorLabelings.toArray(new LabelVector[errorLabelings.size()]);

            RankedFeatureVector.Factory factory = gainFactory(gainName, errorLabelingArray);
            FeatureInducer inducer = new FeatureInducer(factory, errorInstances, featuresPerInduction, 2 * featuresPerInduction, 2 * featuresPerInduction);
            inducer.induceFeaturesFor(trainingSet, false, false);
            if (testSet != null) {
                inducer.induceFeaturesFor(testSet, false, false);
            }
            logger.info("MaxEnt FeatureSelection now includes " + featureSelection.cardinality() + " features");

            // Resize for the grown alphabet. The new (zero-filled) array deliberately replaces the
            // previously learned values.
            current.parameters = new double[(1 + dataAlphabet.size()) * targetAlphabet.size()];
            current.defaultFeatureIndex = dataAlphabet.size();
        }
        logger.info("Ended with " + featureSelection.cardinality() + " features.");
        setNumIterations(totalIterations - trainingIterations);
        return train(trainingSet, validationSet, testSet, evaluator, current);
    }

    /**
     * Returns the training set's feature selection, creating and installing one if absent, and
     * installs the same selection on the non-null validation and test sets.
     */
    private static FeatureSelection shareFeatureSelection(InstanceList trainingSet, InstanceList validationSet, InstanceList testSet) {
        FeatureSelection selection = trainingSet.getFeatureSelection();
        if (selection == null) {
            selection = new FeatureSelection(trainingSet.getDataAlphabet());
            trainingSet.setFeatureSelection(selection);
        }
        if (validationSet != null) {
            validationSet.setFeatureSelection(selection);
        }
        if (testSet != null) {
            testSet.setFeatureSelection(selection);
        }
        return selection;
    }

    /** Builds the feature-gain factory named by {@code gainName} for the given error labelings. */
    private RankedFeatureVector.Factory gainFactory(String gainName, LabelVector[] errorLabelings) {
        if (EXP_GAIN.equals(gainName)) {
            return new ExpGain.Factory(errorLabelings, this.gaussianPriorVariance);
        }
        if (GRADIENT_GAIN.equals(gainName)) {
            return new GradientGain.Factory(errorLabelings);
        }
        if (INFORMATION_GAIN.equals(gainName)) {
            return new InfoGain.Factory();
        }
        throw new IllegalArgumentException("Unsupported gain name: " + gainName);
    }

    /** Returns how many times the gradient has been recomputed by this trainer's objectives. */
    public int getValueGradientCalls() {
        return this.numGetValueGradientCalls;
    }

    /** Returns how many times the objective value has been recomputed by this trainer's objectives. */
    public int getValueCalls() {
        return this.numGetValueCalls;
    }

    @Override
    public String toString() {
        return "MaxEntTrainer,numIterations=" + this.numIterations + (this.usingHyperbolicPrior ? ",hyperbolicPriorSlope=" + this.hyperbolicPriorSlope + ",hyperbolicPriorSharpness=" + this.hyperbolicPriorSharpness : ",gaussianPriorVariance=" + this.gaussianPriorVariance);
    }
}
