package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.classify.Classification;
import edu.umass.cs.mallet.base.fst.Transducer;
import edu.umass.cs.mallet.base.util.MalletLogger;
import java.util.logging.Logger;

/* loaded from: input_file:WEB-INF/lib/mallet-0.4-jaeschke.jar:edu/umass/cs/mallet/base/types/KLGain.class */
/**
 * A {@link RankedFeatureVector} over the data alphabet of an {@link InstanceList} whose value for
 * each feature is a "KL gain" score.
 *
 * <p>The score is derived from two smoothed class/feature weight tables:
 * <ul>
 *   <li>{@code p}, accumulated from the instances' true labelings;</li>
 *   <li>{@code q}, accumulated from per-instance predicted label distributions.</li>
 * </ul>
 *
 * <p>Only (class, feature) cells whose log-odds ratio
 * {@code alpha = log(p(1-q) / (q(1-p)))} is strictly positive and finite contribute to a
 * feature's score.
 */
public class KLGain extends RankedFeatureVector {
    private static final Logger logger = MalletLogger.getLogger(KLGain.class.getName());

    /**
     * Computes the per-feature KL gain values.
     *
     * <p>Every instance's data is expected to be a binary {@link FeatureVector}; this is checked
     * with an {@code assert} when assertions are enabled.
     *
     * @param ilist the instances; must be non-empty (checked with {@code assert})
     * @param labelVectors predicted label distributions, indexed in parallel with {@code ilist};
     *     each must use the instance list's target alphabet (checked with {@code assert})
     * @return one gain value per entry of the data alphabet
     */
    private static double[] calcKLGains(InstanceList ilist, LabelVector[] labelVectors) {
        final int numInstances = ilist.size();
        final int numClasses = ilist.getTargetAlphabet().size();
        final int numFeatures = ilist.getDataAlphabet().size();
        assert numInstances > 0;

        // p[c][f] / q[c][f]: smoothed weight of class c co-occurring with feature f, using the
        // true labels and the predicted labels respectively.
        double[][] p = new double[numClasses][numFeatures];
        double[][] q = new double[numClasses][numFeatures];
        logger.info("Starting klgains, #instances=" + numInstances);

        double trueLabelWeightSum = 0.0;
        double modelLabelWeightSum = 0.0;
        // Each instance's label weights are scaled by 1/(n+1). The uniform prior spreads a further
        // 1/(n+1) of mass over all cells, so the label-weight sums below should total 1.0 when
        // every labeling sums to 1.
        final double normalizer = numInstances + 1.0;
        final double prior = 1.0 / ((normalizer * numFeatures) * numClasses);
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                p[c][f] = prior;
                q[c][f] = prior;
                trueLabelWeightSum += prior;
                modelLabelWeightSum += prior;
            }
        }

        for (int i = 0; i < numInstances; i++) {
            assert labelVectors[i].getLabelAlphabet() == ilist.getTargetAlphabet();
            Instance inst = ilist.getInstance(i);
            Labeling labeling = inst.getLabeling();
            FeatureVector fv = (FeatureVector) inst.getData();
            for (int c = 0; c < numClasses; c++) {
                double trueWeight = labeling.value(c) / normalizer;
                double modelWeight = labelVectors[i].value(c) / normalizer;
                trueLabelWeightSum += trueWeight;
                modelLabelWeightSum += modelWeight;
                if (trueWeight == 0.0 && modelWeight == 0.0) {
                    continue; // nothing to add for this class
                }
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int f = fv.indexAtLocation(loc);
                    assert fv.valueAtLocation(loc) == 1.0;
                    p[c][f] += trueWeight;
                    q[c][f] += modelWeight;
                }
            }
        }
        // Written as "< tolerance" so that a NaN sum also fails the sanity check.
        assert Math.abs(trueLabelWeightSum - 1.0) < 0.001
                : "trueLabelWeightSum should be 1.0, it was " + trueLabelWeightSum;
        assert Math.abs(modelLabelWeightSum - 1.0) < 0.001
                : "modelLabelWeightSum should be 1.0, it was " + modelLabelWeightSum;

        // alphas[c][f]: log-odds ratio of p against q for each cell.
        double[][] alphas = new double[numClasses][numFeatures];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                alphas[c][f] = Math.log((p[c][f] * (1.0 - q[c][f])) / (q[c][f] * (1.0 - p[c][f])));
            }
        }

        // qeag[c][f]: total predicted-label weight (scaled by 1/n), plus, for every instance
        // containing f, that instance's weight times (exp(alpha) - 1).
        double[][] qeag = new double[numClasses][numFeatures];
        double modelWeightTotal = 0.0;
        for (int i = 0; i < numInstances; i++) {
            assert labelVectors[i].getLabelAlphabet() == ilist.getTargetAlphabet();
            FeatureVector fv = (FeatureVector) ilist.getInstance(i).getData();
            for (int c = 0; c < numClasses; c++) {
                double weight = labelVectors[i].value(c) / numInstances;
                modelWeightTotal += weight;
                for (int loc = 0; loc < fv.numLocations(); loc++) {
                    int f = fv.indexAtLocation(loc);
                    qeag[c][f] += (weight * Math.exp(alphas[c][f])) - weight;
                }
            }
        }
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                qeag[c][f] += modelWeightTotal;
            }
        }

        double[] klgains = new double[numFeatures];
        for (int c = 0; c < numClasses; c++) {
            for (int f = 0; f < numFeatures; f++) {
                double alpha = alphas[c][f];
                if (alpha > 0.0 && !Double.isInfinite(alpha)) {
                    klgains[f] += (alpha * p[c][f]) - Math.log(qeag[c][f]);
                }
            }
        }
        logger.info("klgains.length=" + klgains.length);
        logSampledFeatures(ilist.getDataAlphabet(), p, q, alphas, qeag, klgains);
        return klgains;
    }

    /**
     * Logs the intermediate statistics for an evenly spaced sample of features: every
     * {@code (numFeatures / 100)}-th feature, or every feature when there are fewer than 200.
     */
    private static void logSampledFeatures(Alphabet dict, double[][] p, double[][] q,
            double[][] alphas, double[][] qeag, double[] klgains) {
        final int numFeatures = klgains.length;
        // Clamp to 1: with fewer than 100 features, numFeatures / 100 is 0 and a modulo by it
        // would throw ArithmeticException.
        final int step = Math.max(1, numFeatures / 100);
        for (int f = 0; f < numFeatures; f += step) {
            Object name = dict.lookupObject(f);
            for (int c = 0; c < p.length; c++) {
                logger.info("c=" + c + " p[" + name + "] = " + p[c][f]);
                logger.info("c=" + c + " q[" + name + "] = " + q[c][f]);
                logger.info("c=" + c + " alphas[" + name + "] = " + alphas[c][f]);
                logger.info("c=" + c + " qeag[" + name + "] = " + qeag[c][f]);
            }
            logger.info("klgains[" + name + "] = " + klgains[f]);
        }
    }

    /**
     * Ranks the features of {@code ilist} by KL gain using explicit predicted label distributions.
     *
     * @param ilist the instances whose data alphabet is ranked
     * @param labelVectors one predicted label distribution per instance, in the same order
     */
    public KLGain(InstanceList ilist, LabelVector[] labelVectors) {
        super(ilist.getDataAlphabet(), calcKLGains(ilist, labelVectors));
    }

    /** Extracts the label vector of each classification, preserving order. */
    private static LabelVector[] getLabelVectorsFromClassifications(Classification[] classifications) {
        LabelVector[] labelVectors = new LabelVector[classifications.length];
        for (int i = 0; i < classifications.length; i++) {
            labelVectors[i] = classifications[i].getLabelVector();
        }
        return labelVectors;
    }

    /**
     * Ranks the features of {@code ilist} by KL gain using each classification's label vector as
     * the predicted distribution.
     *
     * @param ilist the instances whose data alphabet is ranked
     * @param classifications one classification per instance, in the same order
     */
    public KLGain(InstanceList ilist, Classification[] classifications) {
        super(ilist.getDataAlphabet(),
                calcKLGains(ilist, getLabelVectorsFromClassifications(classifications)));
    }
}
