package weka.classifiers;

import java.io.Reader;
import java.io.StreamTokenizer;
import java.util.Random;
import weka.core.Instances;
import weka.core.Matrix;
import weka.core.Utils;

/* loaded from: input_file:org/bibsonomy/scraper/ie/training/mallet.jar:weka/classifiers/CostMatrix.class */
public class CostMatrix extends Matrix {
    /**
     * File-name extension string for cost matrices, presumably used when naming saved cost files
     * (usage not visible here). NOTE(review): not {@code final}, so any code may reassign it.
     */
    public static String FILE_EXTENSION = ".cost";

    /**
     * Creates an element-by-element copy of another cost matrix.
     *
     * @param costMatrix the matrix whose size and elements are copied
     */
    public CostMatrix(CostMatrix costMatrix) {
        super(costMatrix.size(), costMatrix.size());
        int n = costMatrix.size();
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                setElement(row, col, costMatrix.getElement(row, col));
            }
        }
    }

    /**
     * Creates a square {@code numClasses x numClasses} matrix. Element values are whatever
     * {@link Matrix}'s constructor initialises them to (presumably zero — confirm); call
     * {@link #initialize()} for the default 0/1 cost matrix.
     *
     * @param numClasses number of rows and columns
     */
    public CostMatrix(int numClasses) {
        super(numClasses, numClasses);
    }

    /**
     * Reads a matrix using {@link Matrix}'s reader-based constructor and checks that it is square.
     *
     * @param reader source of the matrix text, in the format {@code Matrix(Reader)} expects
     * @throws Exception if reading fails or the matrix is not square
     */
    public CostMatrix(Reader reader) throws Exception {
        super(reader);
        if (numRows() != numColumns()) {
            throw new Exception("Trying to create a non-square cost matrix");
        }
    }

    /** Resets this matrix to the default 0/1 costs: 0 on the diagonal, 1 everywhere else. */
    public void initialize() {
        int n = size();
        for (int row = 0; row < n; row++) {
            for (int col = 0; col < n; col++) {
                setElement(row, col, row == col ? 0.0d : 1.0d);
            }
        }
    }

    /**
     * Returns the number of classes this matrix covers, taken as its number of columns. The
     * constructors of this class only ever produce square matrices.
     *
     * @return the number of columns
     */
    public int size() {
        return numColumns();
    }

    /**
     * Derives cost-sensitive instance weights from this matrix and applies them to the data.
     *
     * <p>For each class {@code c}, the weight factor is
     * {@code rowSum(c) * W / sum_k(rowSum(k) * W_k)}. Here {@code rowSum(c)} is the sum of row
     * {@code c} of the diagonal-normalised matrix, {@code W_k} is the total weight of instances
     * of class {@code k}, and {@code W} is the total weight. Each instance's weight is multiplied
     * by the factor of its class. Algebraically this preserves the total instance weight, up to
     * rounding.
     *
     * <p>If any diagonal element is non-zero, a normalised copy (see {@link #normalize()}) is used
     * for the computation. This matrix itself is never modified.
     *
     * @param instances the data; its class index must be set and its number of classes must equal
     *     {@link #size()}
     * @param random if non-null, the result is {@code instances.resampleWithWeights(random, w)};
     *     if null, the result is {@code new Instances(instances)} with each instance's weight set
     *     to the derived value
     * @return the resampled or reweighted data
     * @throws Exception if the class index is not set, the matrix size does not match the number
     *     of classes, a cost is negative, or (for non-empty data) the weighted sum of row costs is
     *     zero or not finite, which would otherwise yield NaN/infinite weights
     */
    public Instances applyCostMatrix(Instances instances, Random random) throws Exception {
        if (instances.classIndex() < 0) {
            throw new Exception("Class index is not set!");
        }
        int numClasses = instances.numClasses();
        if (size() != numClasses) {
            throw new Exception("Misclassification cost matrix has wrong format!");
        }

        // The weighting assumes zero self-costs. Use a normalised copy directly rather than
        // recursing. A NaN/infinite diagonal never normalises to exactly zero, so recursion on it
        // would never terminate.
        CostMatrix costs = this;
        for (int c = 0; c < numClasses; c++) {
            if (!Utils.eq(getElement(c, c), 0.0d)) {
                costs = new CostMatrix(this);
                costs.normalize();
                break;
            }
        }

        int numInstances = instances.numInstances();
        double[] classWeights = new double[numClasses];
        for (int i = 0; i < numInstances; i++) {
            // NOTE(review): Java casts NaN to 0. If missing class values are encoded as NaN, such
            // instances are counted as class 0 here. Confirm whether they should be skipped.
            classWeights[(int) instances.instance(i).classValue()] += instances.instance(i).weight();
        }
        double totalWeight = Utils.sum(classWeights);

        double[] weightFactors = new double[numClasses];
        double normalizer = 0.0d;
        for (int c = 0; c < numClasses; c++) {
            double rowSum = 0.0d;
            for (int j = 0; j < numClasses; j++) {
                double cost = costs.getElement(c, j);
                if (Utils.sm(cost, 0.0d)) {
                    throw new Exception("Neg. weights in misclassification cost matrix!");
                }
                rowSum += cost;
            }
            weightFactors[c] = rowSum * totalWeight;
            normalizer += rowSum * classWeights[c];
        }

        // Empty data has nothing to reweight. For non-empty data, a zero or non-finite normaliser
        // would turn every weight into NaN/Infinity, so fail loudly instead.
        if (numInstances > 0
                && (normalizer == 0.0d || Double.isNaN(normalizer) || Double.isInfinite(normalizer))) {
            throw new Exception("Cannot derive instance weights: weighted sum of misclassification costs is "
                    + normalizer + "!");
        }
        for (int c = 0; c < numClasses; c++) {
            weightFactors[c] /= normalizer;
        }

        double[] newWeights = new double[numInstances];
        for (int i = 0; i < numInstances; i++) {
            newWeights[i] = instances.instance(i).weight()
                    * weightFactors[(int) instances.instance(i).classValue()];
        }

        if (random != null) {
            return instances.resampleWithWeights(random, newWeights);
        }
        Instances reweighted = new Instances(instances);
        for (int i = 0; i < numInstances; i++) {
            reweighted.instance(i).setWeight(newWeights[i]);
        }
        return reweighted;
    }

    /**
     * Computes, for every row {@code i}, {@code sum_j probs[j] * getElement(i, j)}.
     *
     * @param probs per-class probability estimates; its length must equal {@link #size()}
     * @return a new array whose entry {@code i} is the weighted sum of row {@code i}
     * @throws Exception if {@code probs.length != size()}
     */
    public double[] expectedCosts(double[] probs) throws Exception {
        if (probs.length != size()) {
            throw new Exception("Length of probability estimates don't match cost matrix");
        }
        int n = size();
        double[] costs = new double[n];
        for (int row = 0; row < n; row++) {
            double total = 0.0d;
            for (int col = 0; col < n; col++) {
                total += probs[col] * getElement(row, col);
            }
            costs[row] = total;
        }
        return costs;
    }

    /**
     * Returns the largest element in the given row.
     *
     * @param classIndex the row to scan
     * @return the row maximum, or {@link Double#NEGATIVE_INFINITY} if the matrix has no columns
     */
    public double getMaxCost(int classIndex) {
        double max = Double.NEGATIVE_INFINITY;
        for (int col = 0; col < size(); col++) {
            max = Math.max(max, getElement(classIndex, col));
        }
        return max;
    }

    /**
     * Normalises this matrix in place. Each column's diagonal element is subtracted from every
     * entry of that column, so the diagonal becomes zero (for finite diagonal values).
     */
    public void normalize() {
        int n = size();
        for (int col = 0; col < n; col++) {
            // Read before the loop: the loop overwrites the diagonal entry itself.
            double diagonal = getElement(col, col);
            for (int row = 0; row < n; row++) {
                setElement(row, col, getElement(row, col) - diagonal);
            }
        }
    }

    /**
     * Reads costs in the legacy line-based format.
     *
     * <p>The matrix is first reset with {@link #initialize()}, so unlisted entries keep the
     * default 0/1 costs. Each entry consists of three numbers: a first class index, a second class
     * index, and a strictly positive cost, which is stored at element {@code (first, second)}.
     * {@code %} starts a comment that runs to the end of the line. Blank lines are ignored. The
     * reader is not closed.
     *
     * @param reader the text source
     * @throws Exception on non-numeric tokens, non-integer or out-of-range class indices, diagonal
     *     entries, non-positive costs, or a line/file that ends mid-entry
     */
    public void readOldFormat(Reader reader) throws Exception {
        StreamTokenizer tokenizer = new StreamTokenizer(reader);
        initialize();
        tokenizer.commentChar('%');
        tokenizer.eolIsSignificant(true);
        int token;
        while ((token = tokenizer.nextToken()) != StreamTokenizer.TT_EOF) {
            if (token == StreamTokenizer.TT_EOL) {
                // Blank/comment-only lines, and the line break after each entry.
                continue;
            }
            if (token != StreamTokenizer.TT_NUMBER) {
                throw new Exception("Only numbers and comments allowed in cost file!");
            }
            int first = toClassIndex(tokenizer.nval, "First number in line has to be index of a class!");
            int second = toClassIndex(readRequiredNumber(tokenizer),
                    "Second number in line has to be index of a class!");
            if (first == second) {
                throw new Exception("Diagonal of cost matrix non-zero!");
            }
            double cost = readRequiredNumber(tokenizer);
            if (!Utils.gr(cost, 0.0d)) {
                throw new Exception("Only positive weights allowed!");
            }
            setElement(first, second, cost);
        }
    }

    /**
     * Validates that {@code value} is a class index within {@code [0, size())} and returns it.
     * Both bounds are checked, so negative indices are rejected here rather than failing later in
     * {@code setElement}.
     */
    private int toClassIndex(double value, String notAnIndexMessage) throws Exception {
        if (!Utils.eq((int) value, value)) {
            throw new Exception(notAnIndexMessage);
        }
        int index = (int) value;
        if (index < 0 || index >= size()) {
            throw new Exception("Class index out of range!");
        }
        return index;
    }

    /** Reads the next token, which must be a number on the current line, and returns its value. */
    private static double readRequiredNumber(StreamTokenizer tokenizer) throws Exception {
        int token = tokenizer.nextToken();
        if (token == StreamTokenizer.TT_EOF) {
            throw new Exception("Premature end of file!");
        }
        if (token == StreamTokenizer.TT_EOL) {
            throw new Exception("Premature end of line!");
        }
        if (token != StreamTokenizer.TT_NUMBER) {
            throw new Exception("Only numbers and comments allowed in cost file!");
        }
        return tokenizer.nval;
    }
}
