package edu.umass.cs.mallet.base.types;

import edu.umass.cs.mallet.base.util.MalletLogger;
import edu.umass.cs.mallet.base.util.Maths;
import java.awt.geom.Point2D;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Comparator;
import java.util.Hashtable;
import java.util.logging.Logger;

/* loaded from: input_file:org/bibsonomy/scraper/ie/training/mallet.jar:edu/umass/cs/mallet/base/types/GainRatio.class */
public class GainRatio extends RankedFeatureVector {
    private static final Logger logger;
    private static final long serialVersionUID = 1;
    public static final double log2;
    double[] m_splitPoints;
    double m_baseEntropy;
    LabelVector m_baseLabelDistribution;
    int m_numSplitPointsForBestFeature;
    int m_minNumInsts;
    static final /* synthetic */ boolean $assertionsDisabled;

    static {
        $assertionsDisabled = !GainRatio.class.desiredAssertionStatus();
        logger = MalletLogger.getLogger(GainRatio.class.getName());
        log2 = Math.log(2.0d);
    }

    /**
     * Computes, for every feature in the data alphabet, the best gain ratio obtainable by a
     * binary threshold split of the selected instances, together with the threshold itself.
     *
     * <p>For each feature the instances are sorted by that feature's value. A candidate split
     * is placed halfway between two adjacent instances whose feature values differ (per
     * {@code Maths.almostEquals}) and whose labelings have different string forms. Candidates
     * that would leave fewer than {@code i} instances on either side are counted but not scored.
     * Among the scored candidates of a feature, only those whose information gain is at least
     * the average gain (sum of scored gains divided by the number of all counted candidates)
     * compete; the one with the highest gain ratio wins, ties broken by higher information gain.
     *
     * @param instanceList the instances; data must be {@code FeatureVector}s and the target
     *                     alphabet a {@code LabelAlphabet}
     * @param iArr         indices into {@code instanceList} of the instances to consider
     * @param i            minimum number of instances required on each side of a split
     * @return a five-element array:
     *         <ol start="0">
     *           <li>{@code double[]} best gain ratio per feature (0 when none qualified)</li>
     *           <li>{@code double[]} split point per feature; {@code NaN} when no candidate
     *               qualified for that feature, or all zeros when no split was found at all</li>
     *           <li>{@code Double} entropy (base 2) of the label distribution</li>
     *           <li>{@code LabelVector} label distribution over the selected instances</li>
     *           <li>{@code Integer} number of scored split points of the best feature</li>
     *         </ol>
     */
    protected static Object[] calcGainRatios(InstanceList instanceList, int[] iArr, int i) {
        final int numInsts = iArr.length;
        final Alphabet dataAlphabet = instanceList.getDataAlphabet();
        final LabelAlphabet labelAlphabet = (LabelAlphabet) instanceList.getTargetAlphabet();
        final int numLabels = labelAlphabet.size();
        final int numFeatures = dataAlphabet.size();

        // Total label mass per label over the selected instances.
        double[] labelCounts = new double[numLabels];
        for (int instIndex : iArr) {
            Labeling labeling = instanceList.getInstance(instIndex).getLabeling();
            double labelWeightSum = 0.0d;
            for (int loc = 0; loc < labeling.numLocations(); loc++) {
                int labelIndex = labeling.indexAtLocation(loc);
                double weight = labeling.valueAtLocation(loc);
                labelWeightSum += weight;
                labelCounts[labelIndex] += weight;
            }
            // Each labeling is expected to be a proper distribution.
            if (!$assertionsDisabled && !Maths.almostEquals(labelWeightSum, 1.0d)) {
                throw new AssertionError();
            }
        }

        // Base label distribution and its entropy in bits.
        double[] baseDistribution = new double[numLabels];
        double baseEntropy = 0.0d;
        for (int li = 0; li < numLabels; li++) {
            double p = labelCounts[li] / numInsts;
            baseDistribution[li] = p;
            if (p > 0.0d) {
                baseEntropy -= (p * Math.log(p)) / log2;
            }
        }
        LabelVector baseLabelDistribution = new LabelVector(labelAlphabet, baseDistribution);

        double infoGainSum = 0.0d;
        int totalNumSplitPoints = 0;
        double[] lowerLabelCounts = new double[numLabels];
        // Per feature: split value -> (information gain, gain ratio) stored as (x, y).
        ArrayList<Hashtable<Double, Point2D.Double>> splitInfoByFeature =
                new ArrayList<Hashtable<Double, Point2D.Double>>(numFeatures);

        for (int fi = 0; fi < numFeatures; fi++) {
            if ((fi + 1) % 1000 == 0) {
                logger.info("at feature " + (fi + 1) + " / " + numFeatures);
            }
            Hashtable<Double, Point2D.Double> splitInfo = new Hashtable<Double, Point2D.Double>();
            splitInfoByFeature.add(splitInfo);
            Arrays.fill(lowerLabelCounts, 0.0d);
            iArr = sortInstances(instanceList, iArr, fi);

            for (int pos = 0; pos < numInsts - 1; pos++) {
                Instance lower = instanceList.getInstance(iArr[pos]);
                Instance upper = instanceList.getInstance(iArr[pos + 1]);
                FeatureVector lowerVector = (FeatureVector) lower.getData();
                FeatureVector upperVector = (FeatureVector) upper.getData();
                double lowerValue = lowerVector.value(fi);
                double upperValue = upperVector.value(fi);

                // Running label mass of everything at or below the current position.
                Labeling lowerLabeling = lower.getLabeling();
                for (int loc = 0; loc < lowerLabeling.numLocations(); loc++) {
                    int labelIndex = lowerLabeling.indexAtLocation(loc);
                    lowerLabelCounts[labelIndex] += lowerLabeling.valueAtLocation(loc);
                }

                // A boundary exists only where both the feature value and the labeling change.
                if (Maths.almostEquals(lowerValue, upperValue)
                        || lower.getLabeling().toString().equals(upper.getLabeling().toString())) {
                    continue;
                }
                totalNumSplitPoints++;

                // pos ranges over [0, numInsts - 2], so both sides hold at least one instance.
                double numLower = pos + 1;
                double numUpper = numInsts - numLower;
                if (numLower < i || numUpper < i) {
                    continue;
                }
                double fractionLower = numLower / numInsts;
                if (Maths.almostEquals(fractionLower, 0.0d) || Maths.almostEquals(fractionLower, 1.0d)) {
                    continue;
                }

                double entropyLower = 0.0d;
                double entropyUpper = 0.0d;
                for (int li = 0; li < numLabels; li++) {
                    double pLower = lowerLabelCounts[li] / numLower;
                    if (pLower > 0.0d) {
                        entropyLower -= (pLower * Math.log(pLower)) / log2;
                    }
                    double pUpper = (labelCounts[li] - lowerLabelCounts[li]) / numUpper;
                    if (pUpper > 0.0d) {
                        entropyUpper -= (pUpper * Math.log(pUpper)) / log2;
                    }
                }

                double infoGain = (baseEntropy - (fractionLower * entropyLower))
                        - ((1.0d - fractionLower) * entropyUpper);
                infoGainSum += infoGain;

                // Entropy of the two-way partition itself; normalizes the gain into a ratio.
                double splitEntropy = (((-fractionLower) * Math.log(fractionLower)) / log2)
                        - (((1.0d - fractionLower) * Math.log(1.0d - fractionLower)) / log2);
                double splitValue = (lowerValue + upperValue) / 2.0d;
                splitInfo.put(Double.valueOf(splitValue), new Point2D.Double(infoGain, infoGain / splitEntropy));
            }
        }

        double[] gainRatios = new double[numFeatures];
        double[] splitPoints = new double[numFeatures];
        int numSplitPointsForBestFeature = 0;
        if (totalNumSplitPoints == 0 || Maths.almostEquals(infoGainSum, 0.0d)) {
            return new Object[]{gainRatios, splitPoints, Double.valueOf(baseEntropy), baseLabelDistribution,
                    Integer.valueOf(0)};
        }

        double avgInfoGain = infoGainSum / totalNumSplitPoints;
        double bestGainRatio = 0.0d;
        double gainForBestGainRatio = 0.0d;
        int numBelowAverage = 0;
        for (int fi = 0; fi < numFeatures; fi++) {
            Hashtable<Double, Point2D.Double> splitInfo = splitInfoByFeature.get(fi);
            double maxGainRatio = 0.0d;
            double gainForMaxGainRatio = 0.0d;
            double splitPoint = Double.NaN;
            for (Double key : splitInfo.keySet()) {
                Point2D.Double gains = splitInfo.get(key);
                double infoGain = gains.getX();
                double gainRatio = gains.getY();
                if (infoGain < avgInfoGain) {
                    numBelowAverage++;
                } else if (gainRatio > maxGainRatio
                        || (gainRatio == maxGainRatio && infoGain > gainForMaxGainRatio)) {
                    maxGainRatio = gainRatio;
                    gainForMaxGainRatio = infoGain;
                    splitPoint = key.doubleValue();
                }
            }
            // splitPoint legitimately stays NaN when no candidate of this feature qualified.
            // The former check "splitPoint == Double.NaN" could never be true (NaN compares
            // unequal to everything, itself included), so it was dead code and has been removed.
            gainRatios[fi] = maxGainRatio;
            splitPoints[fi] = splitPoint;
            if (maxGainRatio > bestGainRatio
                    || (maxGainRatio == bestGainRatio && gainForMaxGainRatio > gainForBestGainRatio)) {
                bestGainRatio = maxGainRatio;
                gainForBestGainRatio = gainForMaxGainRatio;
                numSplitPointsForBestFeature = splitInfo.size();
            }
        }
        logger.info("label distrib:\n" + baseLabelDistribution);
        logger.info("base entropy=" + baseEntropy + ", info gain sum=" + infoGainSum + ", total num split points=" + totalNumSplitPoints + ", avg info gain=" + avgInfoGain + ", num splits with < avg gain=" + numBelowAverage);
        return new Object[]{gainRatios, splitPoints, Double.valueOf(baseEntropy), baseLabelDistribution,
                Integer.valueOf(numSplitPointsForBestFeature)};
    }

    /**
     * Returns the given instance indices reordered by ascending value of one feature.
     *
     * <p>Instances with equal feature values are ordered by ascending instance index. The
     * comparison is a proper total order: duplicate indices compare as equal, and NaN feature
     * values sort after all other values, so the sort cannot fail with a comparator-contract
     * violation.
     *
     * @param instanceList the instances; each instance's data must be a {@code FeatureVector}
     * @param iArr         indices into {@code instanceList} to sort; not modified
     * @param i            index of the feature whose value is the primary sort key
     * @return a new array holding the same indices in sorted order
     */
    public static int[] sortInstances(InstanceList instanceList, int[] iArr, int i) {
        // Each point carries (instance index, feature value) as (x, y).
        ArrayList<Point2D.Double> indexedValues = new ArrayList<Point2D.Double>(iArr.length);
        for (int instanceIndex : iArr) {
            FeatureVector vector = (FeatureVector) instanceList.getInstance(instanceIndex).getData();
            indexedValues.add(new Point2D.Double(instanceIndex, vector.value(i)));
        }
        Collections.sort(indexedValues, new Comparator<Point2D.Double>() {
            @Override
            public int compare(Point2D.Double a, Point2D.Double b) {
                if (a.y < b.y) {
                    return -1;
                }
                if (a.y > b.y) {
                    return 1;
                }
                if (a.y != b.y) {
                    // Neither less nor greater yet unequal: a NaN is involved.
                    // Double.compare orders NaN after every other value and equal to itself.
                    int nanOrder = Double.compare(a.y, b.y);
                    if (nanOrder != 0) {
                        return nanOrder;
                    }
                }
                // Same feature value (-0.0 and 0.0 count as the same): fall back to the index.
                return Double.compare(a.x, b.x);
            }
        });
        int[] sorted = new int[iArr.length];
        for (int pos = 0; pos < sorted.length; pos++) {
            sorted[pos] = (int) indexedValues.get(pos).getX();
        }
        return sorted;
    }

    /**
     * Builds a {@code GainRatio} over every instance in the list, using a minimum of 2
     * instances per side of a split.
     *
     * @param instanceList the instances to analyze
     * @return the gain-ratio ranking computed over all instances
     */
    public static GainRatio createGainRatio(InstanceList instanceList) {
        final int numInstances = instanceList.size();
        final int[] allIndices = new int[numInstances];
        int next = 0;
        while (next < numInstances) {
            allIndices[next] = next;
            next++;
        }
        return createGainRatio(instanceList, allIndices, 2);
    }

    /**
     * Builds a {@code GainRatio} over the selected instances.
     *
     * @param instanceList the instances to analyze
     * @param iArr         indices into {@code instanceList} of the instances to consider
     * @param i            minimum-instance threshold forwarded to {@code calcGainRatios}
     * @return the gain-ratio ranking assembled from the five results of {@code calcGainRatios}
     */
    public static GainRatio createGainRatio(InstanceList instanceList, int[] iArr, int i) {
        final Object[] results = calcGainRatios(instanceList, iArr, i);
        final double[] gainRatios = (double[]) results[0];
        final double[] splitPoints = (double[]) results[1];
        final double baseEntropy = ((Double) results[2]).doubleValue();
        final LabelVector baseLabelDistribution = (LabelVector) results[3];
        final int numSplitPointsForBestFeature = ((Integer) results[4]).intValue();
        return new GainRatio(
                instanceList.getDataAlphabet(),
                gainRatios,
                splitPoints,
                baseEntropy,
                baseLabelDistribution,
                numSplitPointsForBestFeature,
                i);
    }

    /**
     * Stores precomputed gain-ratio results; the values array is handed to the superclass.
     *
     * @param alphabet    the feature alphabet
     * @param dArr        per-feature values passed to the superclass constructor
     * @param dArr2       per-feature split points
     * @param d           base entropy
     * @param labelVector base label distribution
     * @param i           number of split points for the best feature
     * @param i2          minimum-instance threshold used for the computation
     */
    protected GainRatio(Alphabet alphabet, double[] dArr, double[] dArr2, double d, LabelVector labelVector, int i, int i2) {
        super(alphabet, dArr);
        m_minNumInsts = i2;
        m_numSplitPointsForBestFeature = i;
        m_baseLabelDistribution = labelVector;
        m_baseEntropy = d;
        m_splitPoints = dArr2;
    }

    /**
     * Returns the split point of the feature at rank 0.
     *
     * @return the threshold reported by {@link #getThresholdAtRank(int)} for rank 0
     */
    public double getMaxValuedThreshold() {
        final int topRank = 0;
        return this.getThresholdAtRank(topRank);
    }

    /**
     * Returns the stored split point of the feature found at the given rank.
     *
     * @param i the rank to look up
     * @return the split point stored for the feature index that {@code getIndexAtRank(i)} yields
     */
    public double getThresholdAtRank(int i) {
        final int featureIndex = getIndexAtRank(i);
        return m_splitPoints[featureIndex];
    }

    /**
     * Returns the base entropy supplied at construction.
     *
     * @return the stored base entropy
     */
    public double getBaseEntropy() {
        return m_baseEntropy;
    }

    /**
     * Returns the base label distribution supplied at construction.
     *
     * @return the stored label distribution (the same object, not a copy)
     */
    public LabelVector getBaseLabelDistribution() {
        return m_baseLabelDistribution;
    }

    /**
     * Returns the split-point count supplied at construction for the best feature.
     *
     * @return the stored number of split points
     */
    public int getNumSplitPointsForBestFeature() {
        return m_numSplitPointsForBestFeature;
    }
}
