package edu.umass.cs.mallet.base.classify;

import edu.umass.cs.mallet.base.pipe.Pipe;
import edu.umass.cs.mallet.base.types.Alphabet;
import edu.umass.cs.mallet.base.types.DenseVector;
import edu.umass.cs.mallet.base.types.FeatureSelection;
import edu.umass.cs.mallet.base.types.FeatureVector;
import edu.umass.cs.mallet.base.types.Instance;
import edu.umass.cs.mallet.base.types.LabelAlphabet;
import edu.umass.cs.mallet.base.types.LabelVector;
import edu.umass.cs.mallet.base.types.MatrixOps;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/classify/MaxEnt.class */
public class MaxEnt extends Classifier implements Serializable {
    /**
     * Weights, one row per label. Each row holds {@code defaultFeatureIndex + 1} entries:
     * columns {@code 0 .. defaultFeatureIndex - 1} are per-feature weights and column
     * {@code defaultFeatureIndex} is the label's default weight, which is added to the label's
     * score unconditionally.
     */
    double[] parameters;

    /**
     * Column of the default weight inside each parameter row; also the number of regular
     * feature columns per row. Captured from the pipe's data alphabet size at construction
     * time (or restored on deserialization).
     */
    int defaultFeatureIndex;

    /** Feature selection used for every label, or null. */
    FeatureSelection featureSelection;

    /** Per-label feature selections, or null. Takes precedence over {@link #featureSelection}. */
    FeatureSelection[] perClassFeatureSelection;

    private static final long serialVersionUID = 1;
    private static final int CURRENT_SERIAL_VERSION = 1;
    static final int NULL_INTEGER = -1;

    /**
     * Creates a MaxEnt classifier over the given weights.
     *
     * @param pipe the instance pipe; the current size of its data alphabet becomes
     *     {@link #defaultFeatureIndex}
     * @param parameters weights laid out as described on {@link #parameters}; stored by
     *     reference, not copied
     * @param featureSelection selection shared by all labels, or null
     * @param perClassFeatureSelection per-label selections, or null; at most one of this and
     *     {@code featureSelection} may be non-null (checked by assertion)
     */
    public MaxEnt(Pipe pipe, double[] parameters, FeatureSelection featureSelection,
            FeatureSelection[] perClassFeatureSelection) {
        super(pipe);
        assert featureSelection == null || perClassFeatureSelection == null
                : "specify either a shared or a per-class feature selection, not both";
        this.parameters = parameters;
        this.featureSelection = featureSelection;
        this.perClassFeatureSelection = perClassFeatureSelection;
        this.defaultFeatureIndex = pipe.getDataAlphabet().size();
    }

    /** Creates a classifier that applies one feature selection to every label. */
    public MaxEnt(Pipe pipe, double[] parameters, FeatureSelection featureSelection) {
        this(pipe, parameters, featureSelection, null);
    }

    /** Creates a classifier with a separate feature selection per label. */
    public MaxEnt(Pipe pipe, double[] parameters, FeatureSelection[] perClassFeatureSelection) {
        this(pipe, parameters, null, perClassFeatureSelection);
    }

    /** Creates a classifier with no feature selection. */
    public MaxEnt(Pipe pipe, double[] parameters) {
        this(pipe, parameters, null, null);
    }

    /**
     * Returns the live weight array (not a copy); writes to it change this classifier.
     *
     * @return the weight array described on {@link #parameters}
     */
    public double[] getParameters() {
        return this.parameters;
    }

    /**
     * Sets a single weight.
     *
     * @param classIndex label (row) index
     * @param featureIndex feature column; {@link #defaultFeatureIndex} addresses the default weight
     * @param value the new weight
     */
    public void setParameter(int classIndex, int featureIndex, double value) {
        // Must use the same stride as the scoring code. getAlphabet().size() is read live,
        // while the parameter layout was fixed when defaultFeatureIndex was recorded; if the
        // two ever differ, the alphabet-based stride addresses the wrong cell.
        this.parameters[classIndex * rowStride() + featureIndex] = value;
    }

    /**
     * Computes raw (un-normalized) per-label scores: the label's default weight plus
     * {@code MatrixOps.rowDotProduct} of the label's weight row with the instance's feature
     * vector, under that label's feature selection.
     *
     * @param instance the instance to score; its data (after the instance pipe) is cast to
     *     {@link FeatureVector}
     * @param scores output array, one slot per label; must have the label alphabet's size
     */
    public void getUnnormalizedClassificationScores(Instance instance, double[] scores) {
        int numLabels = getLabelAlphabet().size();
        assert scores.length == numLabels;
        FeatureVector fv = (FeatureVector) instance.getData(this.instancePipe);
        assert this.instancePipe == null
                || fv.getAlphabet() == this.instancePipe.getDataAlphabet();
        int stride = rowStride();
        for (int li = 0; li < numLabels; li++) {
            scores[li] = this.parameters[li * stride + this.defaultFeatureIndex]
                    + MatrixOps.rowDotProduct(this.parameters, stride, li, fv,
                            this.defaultFeatureIndex, selectionFor(li));
        }
    }

    /**
     * Computes per-label probabilities by applying a softmax to the raw scores from
     * {@link #getUnnormalizedClassificationScores}.
     *
     * @param instance the instance to score
     * @param scores output array, one slot per label; must have the label alphabet's size
     */
    public void getClassificationScores(Instance instance, double[] scores) {
        getUnnormalizedClassificationScores(instance, scores);
        int numLabels = getLabelAlphabet().size();
        // Shift by the maximum before exponentiating so every exponent is <= 0 and exp()
        // cannot overflow; the shift cancels out in the normalization below.
        double max = DenseVector.max(scores);
        double sum = 0.0d;
        for (int li = 0; li < numLabels; li++) {
            scores[li] = Math.exp(scores[li] - max);
            sum += scores[li];
        }
        for (int li = 0; li < numLabels; li++) {
            scores[li] /= sum;
        }
    }

    /**
     * Classifies an instance.
     *
     * @param instance the instance to classify
     * @return a {@link Classification} holding a {@link LabelVector} of the softmax scores
     */
    @Override // edu.umass.cs.mallet.base.classify.Classifier
    public Classification classify(Instance instance) {
        LabelAlphabet labels = getLabelAlphabet();
        double[] scores = new double[labels.size()];
        getClassificationScores(instance, scores);
        return new Classification(instance, this, new LabelVector(labels, scores));
    }

    /** Prints every label's default weight and per-feature weights to standard output. */
    @Override // edu.umass.cs.mallet.base.classify.Classifier
    public void print() {
        Alphabet alphabet = getAlphabet();
        LabelAlphabet labelAlphabet = getLabelAlphabet();
        // Same stride as the scoring code (see setParameter for why not alphabet.size() + 1).
        int stride = rowStride();
        int numLabels = labelAlphabet.size();
        for (int li = 0; li < numLabels; li++) {
            int rowStart = li * stride;
            System.out.println("FEATURES FOR CLASS " + labelAlphabet.lookupObject(li));
            System.out.println(" <default> " + this.parameters[rowStart + this.defaultFeatureIndex]);
            for (int fi = 0; fi < this.defaultFeatureIndex; fi++) {
                System.out.println(" " + alphabet.lookupObject(fi) + " " + this.parameters[rowStart + fi]);
            }
        }
    }

    /** Number of entries in one label's row of {@link #parameters}. */
    private int rowStride() {
        return this.defaultFeatureIndex + 1;
    }

    /** Feature selection to pass when scoring the given label; may be null. */
    private FeatureSelection selectionFor(int labelIndex) {
        return this.perClassFeatureSelection == null
                ? this.featureSelection
                : this.perClassFeatureSelection[labelIndex];
    }

    /**
     * Custom serialized form: version, pipe, parameter count and values, default-feature
     * index, optional shared selection, then the per-class selection count (or
     * {@link #NULL_INTEGER}) followed by one optional selection per class.
     */
    private void writeObject(ObjectOutputStream out) throws IOException {
        out.writeInt(CURRENT_SERIAL_VERSION);
        out.writeObject(getInstancePipe());
        out.writeInt(this.parameters.length);
        for (double weight : this.parameters) {
            out.writeDouble(weight);
        }
        out.writeInt(this.defaultFeatureIndex);
        writeOptional(out, this.featureSelection);
        if (this.perClassFeatureSelection == null) {
            out.writeInt(NULL_INTEGER);
            return;
        }
        out.writeInt(this.perClassFeatureSelection.length);
        for (FeatureSelection selection : this.perClassFeatureSelection) {
            writeOptional(out, selection);
        }
    }

    /** Writes a presence flag (1 if present, {@link #NULL_INTEGER} if null), then the object. */
    private static void writeOptional(ObjectOutputStream out, Object value) throws IOException {
        if (value == null) {
            out.writeInt(NULL_INTEGER);
        } else {
            out.writeInt(1);
            out.writeObject(value);
        }
    }

    /**
     * Reads the form written by {@link #writeObject}.
     *
     * <p>NOTE(review): this deserializes arbitrary Pipe/FeatureSelection objects; only read
     * streams from trusted sources.
     */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        int version = in.readInt();
        if (version != CURRENT_SERIAL_VERSION) {
            throw new ClassNotFoundException("Mismatched MaxEnt versions: wanted "
                    + CURRENT_SERIAL_VERSION + ", got " + version);
        }
        this.instancePipe = (Pipe) in.readObject();
        int numParameters = in.readInt();
        this.parameters = new double[numParameters];
        for (int i = 0; i < numParameters; i++) {
            this.parameters[i] = in.readDouble();
        }
        this.defaultFeatureIndex = in.readInt();
        if (in.readInt() == 1) {
            this.featureSelection = (FeatureSelection) in.readObject();
        }
        int numPerClass = in.readInt();
        if (numPerClass >= 0) {
            this.perClassFeatureSelection = new FeatureSelection[numPerClass];
            for (int i = 0; i < numPerClass; i++) {
                if (in.readInt() == 1) {
                    this.perClassFeatureSelection[i] = (FeatureSelection) in.readObject();
                }
            }
        }
    }
}
