package weka.classifiers.bayes.net.search.global;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/* JADX WARN: Classes with same name are omitted:
  input_file:WEB-INF/lib/bibsonomy-scraper-2.0.1.jar:org/bibsonomy/scraper/ie/training/mallet.jar:weka/classifiers/bayes/net/search/global/GlobalScoreSearchAlgorithm.class
 */
/* loaded from: input_file:WEB-INF/lib/mallet-0.4-steuber.jar:weka/classifiers/bayes/net/search/global/GlobalScoreSearchAlgorithm.class */
public class GlobalScoreSearchAlgorithm extends SearchAlgorithm {
    /** Network currently being scored; (re)assigned by each cross-validation routine. */
    BayesNet m_BayesNet;

    /** Cross-validation type: leave-one-out. */
    static final int LOOCV = 0;
    /** Cross-validation type: k-fold. */
    static final int KFOLDCV = 1;
    /** Cross-validation type: cumulative (each instance is scored before it is added). */
    static final int CUMCV = 2;

    /** Tags for the supported cross-validation types; tag IDs are the constants above. */
    public static final Tag[] TAGS_CV_TYPE = {
        new Tag(LOOCV, "LOO-CV"), new Tag(KFOLDCV, "k-Fold-CV"), new Tag(CUMCV, "Cumulative-CV")
    };

    /** Score returned by the calcScoreWith* methods when the requested arc change does not apply. */
    private static final double INVALID_SCORE = -1.0E100d;

    /** If true, score with the predicted probability of the true class; otherwise with 0/1 accuracy. */
    boolean m_bUseProb = true;
    /** Number of folds used when the cross-validation type is {@link #KFOLDCV}. */
    int m_nNrOfFolds = 10;
    /** Selected cross-validation type: one of {@link #LOOCV}, {@link #KFOLDCV}, {@link #CUMCV}. */
    int m_nCVType = LOOCV;

    /**
     * Scores a network with the configured cross-validation strategy.
     *
     * @param bayesNet the network to score; it becomes the current network of this search
     * @return the weighted cross-validated accuracy estimate
     * @throws Exception if the cross-validation type is not recognized or scoring fails
     */
    public double calcScore(BayesNet bayesNet) throws Exception {
        switch (m_nCVType) {
            case LOOCV:
                return leaveOneOutCV(bayesNet);
            case KFOLDCV:
                return kFoldCV(bayesNet, m_nNrOfFolds);
            case CUMCV:
                return cumulativeCV(bayesNet);
            default:
                throw new Exception("Unrecognized cross validation type encountered: " + m_nCVType);
        }
    }

    /**
     * Scores the current network with node {@code i2} temporarily added as a parent of node
     * {@code i}. The parent set is restored before returning, even if scoring fails.
     *
     * @param i  index of the child node
     * @param i2 index of the candidate parent node
     * @return the score of the modified network, or -1.0E100 if {@code i2} already is a parent
     * @throws Exception if scoring fails
     */
    public double calcScoreWithExtraParent(int i, int i2) throws Exception {
        ParentSet parentSet = m_BayesNet.getParentSet(i);
        Instances instances = m_BayesNet.m_Instances;
        for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
            if (parentSet.getParent(iParent) == i2) {
                return INVALID_SCORE;
            }
        }
        parentSet.addParent(i2, instances);
        try {
            return calcScore(m_BayesNet);
        } finally {
            // undo the temporary arc so the caller's network is unchanged
            parentSet.deleteLastParent(instances);
        }
    }

    /**
     * Scores the current network with node {@code i2} temporarily removed from the parents of
     * node {@code i}. The parent set is restored before returning, even if scoring fails.
     *
     * @param i  index of the child node
     * @param i2 index of the parent node to remove
     * @return the score of the modified network, or -1.0E100 if {@code i2} is not a parent
     * @throws Exception if scoring fails
     */
    public double calcScoreWithMissingParent(int i, int i2) throws Exception {
        ParentSet parentSet = m_BayesNet.getParentSet(i);
        Instances instances = m_BayesNet.m_Instances;
        if (!parentSet.contains(i2)) {
            return INVALID_SCORE;
        }
        // remember where the parent was so it can be re-inserted at the same position
        int position = parentSet.deleteParent(i2, instances);
        try {
            return calcScore(m_BayesNet);
        } finally {
            parentSet.addParent(i2, position, instances);
        }
    }

    /**
     * Scores the current network with the arc {@code i2 -> i} temporarily reversed into
     * {@code i -> i2}. Both parent sets are restored before returning, even if scoring fails.
     *
     * @param i  index of the current child node
     * @param i2 index of the current parent node
     * @return the score of the modified network, or -1.0E100 if {@code i2} is not a parent of {@code i}
     * @throws Exception if scoring fails
     */
    public double calcScoreWithReversedParent(int i, int i2) throws Exception {
        ParentSet parentSet = m_BayesNet.getParentSet(i);
        ParentSet parentSet2 = m_BayesNet.getParentSet(i2);
        Instances instances = m_BayesNet.m_Instances;
        if (!parentSet.contains(i2)) {
            return INVALID_SCORE;
        }
        int position = parentSet.deleteParent(i2, instances);
        try {
            parentSet2.addParent(i, instances);
            try {
                return calcScore(m_BayesNet);
            } finally {
                parentSet2.deleteLastParent(instances);
            }
        } finally {
            parentSet.addParent(i2, position, instances);
        }
    }

    /**
     * Leave-one-out cross-validation. CPTs are estimated from all instances, then each instance
     * is withdrawn, scored, and put back.
     *
     * <p>Negating an instance's weight and calling {@code updateClassifier} is used to withdraw
     * it (assumes updateClassifier adds weight-scaled counts — TODO confirm in BayesNet).
     *
     * @param bayesNet the network to score; becomes the current network
     * @return weighted accuracy estimate
     * @throws Exception if estimation or classification fails
     */
    public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
        m_BayesNet = bayesNet;
        double fAccuracy = 0.0d;
        double fWeight = 0.0d;
        Instances instances = bayesNet.m_Instances;
        bayesNet.estimateCPTs();
        for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
            Instance instance = instances.instance(iInstance);
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
            // both sums use the negated weight, so their ratio is the ordinary weighted accuracy
            fAccuracy += accuracyIncrease(instance);
            fWeight += instance.weight();
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
        }
        return fAccuracy / fWeight;
    }

    /**
     * Cumulative cross-validation: starting from freshly initialised CPTs, each instance is
     * scored against the model built from the instances before it, then added to the model.
     *
     * @param bayesNet the network to score; becomes the current network
     * @return weighted accuracy estimate
     * @throws Exception if initialisation or classification fails
     */
    public double cumulativeCV(BayesNet bayesNet) throws Exception {
        m_BayesNet = bayesNet;
        double fAccuracy = 0.0d;
        double fWeight = 0.0d;
        Instances instances = bayesNet.m_Instances;
        bayesNet.initCPTs();
        for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
            Instance instance = instances.instance(iInstance);
            fAccuracy += accuracyIncrease(instance);
            bayesNet.updateClassifier(instance);
            fWeight += instance.weight();
        }
        return fAccuracy / fWeight;
    }

    /**
     * k-fold cross-validation over contiguous folds. CPTs are estimated from all instances; for
     * each fold the fold's instances are withdrawn, scored, and put back.
     *
     * @param bayesNet the network to score; becomes the current network
     * @param i        number of folds; must be at least 1
     * @return weighted accuracy estimate (same sign convention as the other CV methods)
     * @throws IllegalArgumentException if {@code i < 1}
     * @throws Exception if estimation or classification fails
     */
    public double kFoldCV(BayesNet bayesNet, int i) throws Exception {
        if (i < 1) {
            // a non-positive fold count would divide by zero or never advance the fold boundaries
            throw new IllegalArgumentException("Number of folds must be at least 1, got " + i);
        }
        m_BayesNet = bayesNet;
        double fAccuracy = 0.0d;
        double fWeight = 0.0d;
        Instances instances = bayesNet.m_Instances;
        bayesNet.estimateCPTs();
        int nFoldStart = 0;
        int nFoldEnd = instances.numInstances() / i;
        int iFold = 1;
        while (nFoldStart < instances.numInstances()) {
            // withdraw the fold from the estimates (weights become negative)
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            // score the fold with its original positive weight; accumulate the weight BEFORE
            // negating it again, otherwise fWeight gets the opposite sign of fAccuracy
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                fAccuracy += accuracyIncrease(instance);
                fWeight += instance.weight();
                instance.setWeight(-instance.weight());
            }
            // put the fold back and restore the original weights
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            nFoldStart = nFoldEnd;
            iFold++;
            // long arithmetic avoids overflow of iFold * numInstances on large data sets
            nFoldEnd = (int) (((long) iFold * instances.numInstances()) / i);
        }
        return fAccuracy / fWeight;
    }

    /**
     * Contribution of one instance to the accuracy estimate.
     *
     * @param instance the instance to classify
     * @return the predicted probability of the true class times the weight when probabilistic
     *         scoring is on; otherwise the weight if the predicted class is correct, else 0
     * @throws Exception if classification fails
     */
    double accuracyIncrease(Instance instance) throws Exception {
        if (m_bUseProb) {
            double[] distribution = m_BayesNet.distributionForInstance(instance);
            return distribution[(int) instance.classValue()] * instance.weight();
        }
        if (m_BayesNet.classifyInstance(instance) == instance.classValue()) {
            return instance.weight();
        }
        return 0.0d;
    }

    /** @return whether probabilistic (rather than 0/1) scoring is used */
    public boolean getUseProb() {
        return m_bUseProb;
    }

    /** @param z true for probabilistic scoring, false for 0/1 scoring */
    public void setUseProb(boolean z) {
        m_bUseProb = z;
    }

    /**
     * Sets the cross-validation type; tags from any other tag set are ignored.
     *
     * @param selectedTag a tag selected from {@link #TAGS_CV_TYPE}
     */
    public void setCVType(SelectedTag selectedTag) {
        if (selectedTag.getTags() == TAGS_CV_TYPE) {
            m_nCVType = selectedTag.getSelectedTag().getID();
        }
    }

    /** @return the current cross-validation type as a tag from {@link #TAGS_CV_TYPE} */
    public SelectedTag getCVType() {
        return new SelectedTag(m_nCVType, TAGS_CV_TYPE);
    }

    /**
     * Lists this class's options followed by the superclass options.
     *
     * @return an enumeration of option descriptions
     */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public Enumeration listOptions() {
        Vector vector = new Vector(2);
        vector.addElement(new Option("\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)\n", "S", 1, "-S [LOO-CV|k-Fold-CV|Cumulative-CV]"));
        vector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q"));
        Enumeration superOptions = super.listOptions();
        while (superOptions.hasMoreElements()) {
            vector.addElement(superOptions.nextElement());
        }
        return vector.elements();
    }

    /**
     * Parses options. {@code -S <type>} selects the CV type (unrecognised or absent values leave
     * it unchanged); {@code -Q} switches to 0/1 scoring. Remaining options go to the superclass.
     *
     * @param strArr the options; consumed entries are removed by Utils
     * @throws Exception if option parsing fails
     */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public void setOptions(String[] strArr) throws Exception {
        String cvType = Utils.getOption('S', strArr);
        if (cvType.compareTo("LOO-CV") == 0) {
            setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
        } else if (cvType.compareTo("k-Fold-CV") == 0) {
            setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
        } else if (cvType.compareTo("Cumulative-CV") == 0) {
            setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
        }
        setUseProb(!Utils.getFlag('Q', strArr));
        super.setOptions(strArr);
    }

    /**
     * Returns the current settings in a form accepted by {@link #setOptions(String[])}, followed
     * by the superclass options; unused slots are padded with empty strings.
     *
     * @return the option array
     */
    @Override // weka.classifiers.bayes.net.search.SearchAlgorithm, weka.core.OptionHandler
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        // room for "-S <type>", an optional "-Q", and the superclass options
        String[] options = new String[3 + superOptions.length];
        int current = 0;

        String cvTypeName = null;
        switch (m_nCVType) {
            case LOOCV:
                cvTypeName = "LOO-CV";
                break;
            case KFOLDCV:
                cvTypeName = "k-Fold-CV";
                break;
            case CUMCV:
                cvTypeName = "Cumulative-CV";
                break;
            default:
                break;
        }
        if (cvTypeName != null) {
            // emit -S only together with its value so setOptions never misreads the next token
            options[current++] = "-S";
            options[current++] = cvTypeName;
        }
        // setOptions treats -Q as "use 0/1 scoring", so emit it only when probabilistic scoring is off
        if (!getUseProb()) {
            options[current++] = "-Q";
        }
        System.arraycopy(superOptions, 0, options, current, superOptions.length);
        current += superOptions.length;
        // fill up the rest with empty strings, not nulls
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the CVType property */
    public String CVTypeTipText() {
        return "Select cross validation strategy to be used in searching for networks.\nLOO-CV = Leave one out cross validation\nk-Fold-CV = k fold cross validation\nCumulative-CV = cumulative cross validation.";
    }

    /** @return tip text for the useProb property */
    public String useProbTipText() {
        return "If set to true, the probability of the class is returned in the estimate of the accuracy. If set to false, the accuracy estimate is only increased if the classifier returns exactly the correct class.";
    }

    /** @return a short description of this search algorithm */
    public String globalInfo() {
        return "This Bayes Network learning algorithm uses cross validation to estimate classification accuracy.";
    }
}
