package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.InstanceList;
import edu.umass.cs.mallet.base.types.Labeling;

import java.util.Arrays;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/classify/WinnowTrainer.class */
/**
 * Trains a {@link Winnow} classifier with the multiplicative-update Winnow algorithm.
 *
 * <p>One weight is kept per (label, feature index) pair, and every weight starts at {@code 1.0}.
 * A label's score for an instance is the sum of its weights over the feature indices present in
 * the instance's {@link FeatureVector}. Only the presence of an index matters; the feature values
 * are not read.
 *
 * <p>Training makes a single pass over the training list:
 * <ul>
 *   <li>if the correct label scores at or below {@code theta}, its weights for the present
 *       features are multiplied by {@code alpha} (promotion);</li>
 *   <li>if any other label scores above {@code theta}, its weights for the present features are
 *       divided by {@code beta} (demotion).</li>
 * </ul>
 *
 * <p>{@code theta} is computed as {@code nfactor} times the size of the data alphabet.
 */
public class WinnowTrainer extends ClassifierTrainer {
    /** Default promotion factor. */
    static final double DEFAULT_ALPHA = 2.0d;
    /** Default demotion divisor. */
    static final double DEFAULT_BETA = 2.0d;
    /** Default ratio between the threshold {@code theta} and the number of features. */
    static final double DEFAULT_NFACTOR = 0.5d;

    /** Factor by which the correct label's weights are multiplied on promotion. */
    double alpha;
    /** Divisor applied to a wrongly-firing label's weights on demotion. */
    double beta;
    /** Score threshold; set by {@link #train} to {@code nfactor * dataAlphabet.size()}. */
    double theta;
    /** Multiplier used to derive {@link #theta} from the number of features. */
    double nfactor;
    /** Weights indexed as {@code [labelIndex][featureIndex]}; allocated by {@link #train}. */
    double[][] weights;

    /** Creates a trainer using {@link #DEFAULT_ALPHA}, {@link #DEFAULT_BETA} and {@link #DEFAULT_NFACTOR}. */
    public WinnowTrainer() {
        this(DEFAULT_ALPHA, DEFAULT_BETA, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with the given update factors and {@link #DEFAULT_NFACTOR}.
     *
     * @param alpha promotion factor (NOTE(review): presumably expected to be &gt; 1)
     * @param beta demotion divisor (NOTE(review): presumably expected to be &gt; 1)
     */
    public WinnowTrainer(double alpha, double beta) {
        this(alpha, beta, DEFAULT_NFACTOR);
    }

    /**
     * Creates a trainer with fully specified parameters.
     *
     * @param alpha promotion factor applied to the correct label's weights
     * @param beta demotion divisor applied to wrongly-firing labels' weights
     * @param nfactor ratio of the threshold {@code theta} to the number of features
     */
    public WinnowTrainer(double alpha, double beta, double nfactor) {
        this.alpha = alpha;
        this.beta = beta;
        this.nfactor = nfactor;
    }

    /**
     * Runs one Winnow pass over {@code instanceList} and returns the resulting classifier.
     *
     * <p>This method stops growth of the data and target alphabets. Instances without a labeling
     * are skipped. The returned {@link Winnow} is constructed with this trainer's weight array
     * itself, not a copy.
     *
     * @param instanceList training instances; their data must be {@link FeatureVector}s
     * @param instanceList2 not used by this trainer
     * @param instanceList3 not used by this trainer
     * @param classifierEvaluating not used by this trainer
     * @param classifier not used by this trainer
     * @return the trained {@link Winnow} classifier
     * @throws UnsupportedOperationException if the list has a feature selection
     */
    @Override
    public Classifier train(
            InstanceList instanceList,
            InstanceList instanceList2,
            InstanceList instanceList3,
            ClassifierEvaluating classifierEvaluating,
            Classifier classifier) {
        if (instanceList.getFeatureSelection() != null) {
            throw new UnsupportedOperationException("FeatureSelection not yet implemented.");
        }
        Alphabet dataAlphabet = instanceList.getDataAlphabet();
        dataAlphabet.stopGrowth();
        instanceList.getTargetAlphabet().stopGrowth();

        Pipe pipe = instanceList.getPipe();
        int numLabels = instanceList.getTargetAlphabet().size();
        int numFeatures = dataAlphabet.size();

        this.theta = numFeatures * this.nfactor;
        this.weights = new double[numLabels][numFeatures];
        for (double[] row : this.weights) {
            Arrays.fill(row, 1.0d);
        }

        for (int i = 0; i < instanceList.size(); i++) {
            Instance instance = (Instance) instanceList.get(i);
            Labeling labeling = instance.getLabeling();
            if (labeling == null) {
                // An unlabeled instance gives no promote/demote signal; previously this path
                // failed with a NullPointerException on getBestIndex().
                continue;
            }
            FeatureVector fv = (FeatureVector) instance.getData(pipe);
            // Scores are computed for all labels before any update, so updates to one label
            // within this instance do not affect the decision for another.
            double[] scores = labelScores(fv, numLabels);
            int correctLabel = labeling.getBestIndex();
            for (int label = 0; label < numLabels; label++) {
                if (scores[label] > this.theta) {
                    if (label != correctLabel) {
                        demote(label, fv);
                    }
                } else if (label == correctLabel) {
                    promote(label, fv);
                }
            }
        }
        return new Winnow(pipe, this.weights, this.theta, numLabels, numFeatures);
    }

    /** Returns, for each label, the sum of its weights over the feature indices present in {@code fv}. */
    private double[] labelScores(FeatureVector fv, int numLabels) {
        double[] scores = new double[numLabels]; // Java zero-initializes the array
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            int featureIndex = fv.indexAtLocation(loc);
            for (int label = 0; label < numLabels; label++) {
                scores[label] += this.weights[label][featureIndex];
            }
        }
        return scores;
    }

    /** Multiplies {@code label}'s weights for every feature index present in {@code fv} by {@link #alpha}. */
    private void promote(int label, FeatureVector fv) {
        double[] row = this.weights[label];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            row[fv.indexAtLocation(loc)] *= this.alpha;
        }
    }

    /** Divides {@code label}'s weights for every feature index present in {@code fv} by {@link #beta}. */
    private void demote(int label, FeatureVector fv) {
        double[] row = this.weights[label];
        int numLocations = fv.numLocations();
        for (int loc = 0; loc < numLocations; loc++) {
            row[fv.indexAtLocation(loc)] /= this.beta;
        }
    }
}
