/*
 * Decompiled with CFR 0.152.
 */
package weka.classifiers.bayes.net.search.global;

import java.util.Enumeration;
import java.util.Vector;
import weka.classifiers.bayes.BayesNet;
import weka.classifiers.bayes.net.ParentSet;
import weka.classifiers.bayes.net.search.SearchAlgorithm;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.RevisionUtils;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;

/**
 * Bayes network structure search algorithm that scores a candidate network by its
 * cross-validated classification accuracy on {@code m_BayesNet.m_Instances}.
 * <p>
 * Three validation strategies are available via {@link #setCVType(SelectedTag)}:
 * leave-one-out ({@code LOOCV}), k-fold ({@code KFOLDCV}, using {@code m_nNrOfFolds}
 * folds) and cumulative ({@code CUMCV}) cross validation. Each instance contributes
 * either its weighted probability of the true class, or its weight when the predicted
 * class is correct, depending on {@link #setUseProb(boolean)}.
 */
public class GlobalScoreSearchAlgorithm extends SearchAlgorithm {
    static final long serialVersionUID = 7341389867906199781L;
    /** Network being scored; (re)assigned by each cross-validation method. */
    BayesNet m_BayesNet;
    /** If true, score by predicted class probability; otherwise by 0/1 correctness. */
    boolean m_bUseProb = true;
    /** Number of folds used when the k-fold strategy is selected. */
    int m_nNrOfFolds = 10;
    /** Tag id: leave-one-out cross validation. */
    static final int LOOCV = 0;
    /** Tag id: k-fold cross validation. */
    static final int KFOLDCV = 1;
    /** Tag id: cumulative cross validation. */
    static final int CUMCV = 2;
    /** Selectable cross-validation strategies (ids match the constants above). */
    public static final Tag[] TAGS_CV_TYPE = new Tag[]{new Tag(LOOCV, "LOO-CV"), new Tag(KFOLDCV, "k-Fold-CV"), new Tag(CUMCV, "Cumulative-CV")};
    /** Currently selected cross-validation strategy; one of the tag ids above. */
    int m_nCVType = LOOCV;

    /**
     * Scores a network using the currently selected cross-validation strategy.
     *
     * @param bayesNet the network to score; also stored in {@code m_BayesNet}
     * @return the cross-validated accuracy estimate
     * @throws Exception if the strategy id is unknown or scoring fails
     */
    public double calcScore(BayesNet bayesNet) throws Exception {
        switch (this.m_nCVType) {
            case LOOCV:
                return this.leaveOneOutCV(bayesNet);
            case CUMCV:
                return this.cumulativeCV(bayesNet);
            case KFOLDCV:
                return this.kFoldCV(bayesNet, this.m_nNrOfFolds);
            default:
                throw new Exception("Unrecognized cross validation type encountered: " + this.m_nCVType);
        }
    }

    /**
     * Scores {@code m_BayesNet} as if {@code nCandidateParent} were added as a parent
     * of {@code nNode}. The parent set is restored afterwards, even if scoring fails.
     *
     * @param nNode index of the child node
     * @param nCandidateParent index of the node to add as parent
     * @return the score, or -1.0E100 if the candidate already is a parent
     * @throws Exception if scoring fails
     */
    public double calcScoreWithExtraParent(int nNode, int nCandidateParent) throws Exception {
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        Instances instances = this.m_BayesNet.m_Instances;
        if (oParentSet.contains(nCandidateParent)) {
            return -1.0E100;
        }
        oParentSet.addParent(nCandidateParent, instances);
        try {
            return this.calcScore(this.m_BayesNet);
        } finally {
            // Undo the temporary arc so a failed evaluation does not corrupt the structure.
            oParentSet.deleteLastParent(instances);
        }
    }

    /**
     * Scores {@code m_BayesNet} as if {@code nCandidateParent} were removed from the
     * parents of {@code nNode}. The parent is re-inserted at its original position
     * afterwards, even if scoring fails.
     *
     * @param nNode index of the child node
     * @param nCandidateParent index of the parent to remove
     * @return the score, or -1.0E100 if the candidate is not a parent
     * @throws Exception if scoring fails
     */
    public double calcScoreWithMissingParent(int nNode, int nCandidateParent) throws Exception {
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        Instances instances = this.m_BayesNet.m_Instances;
        if (!oParentSet.contains(nCandidateParent)) {
            return -1.0E100;
        }
        int iParent = oParentSet.deleteParent(nCandidateParent, instances);
        try {
            return this.calcScore(this.m_BayesNet);
        } finally {
            oParentSet.addParent(nCandidateParent, iParent, instances);
        }
    }

    /**
     * Scores {@code m_BayesNet} as if the arc {@code nCandidateParent -> nNode} were
     * reversed. Both parent sets are restored afterwards, even if scoring fails.
     *
     * @param nNode index of the current child node
     * @param nCandidateParent index of the current parent node
     * @return the score, or -1.0E100 if {@code nCandidateParent} is not a parent of {@code nNode}
     * @throws Exception if scoring fails
     */
    public double calcScoreWithReversedParent(int nNode, int nCandidateParent) throws Exception {
        ParentSet oParentSet = this.m_BayesNet.getParentSet(nNode);
        ParentSet oParentSet2 = this.m_BayesNet.getParentSet(nCandidateParent);
        Instances instances = this.m_BayesNet.m_Instances;
        if (!oParentSet.contains(nCandidateParent)) {
            return -1.0E100;
        }
        int iParent = oParentSet.deleteParent(nCandidateParent, instances);
        try {
            oParentSet2.addParent(nNode, instances);
            try {
                return this.calcScore(this.m_BayesNet);
            } finally {
                oParentSet2.deleteLastParent(instances);
            }
        } finally {
            oParentSet.addParent(nCandidateParent, iParent, instances);
        }
    }

    /**
     * Leave-one-out cross validation: CPTs are estimated on all instances, then each
     * instance in turn is removed, scored, and added back.
     * <p>
     * Removal is done by calling {@code updateClassifier} with the instance weight
     * negated (presumably subtracting its counts from the CPTs).
     *
     * @param bayesNet the network to score; stored in {@code m_BayesNet}
     * @return weighted accuracy estimate (NaN if the total weight is zero)
     * @throws Exception if updating or classifying fails
     */
    public double leaveOneOutCV(BayesNet bayesNet) throws Exception {
        this.m_BayesNet = bayesNet;
        double fAccuracy = 0.0;
        double fWeight = 0.0;
        Instances instances = bayesNet.m_Instances;
        bayesNet.estimateCPTs();
        for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
            Instance instance = instances.instance(iInstance);
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
            // Both terms are accumulated with the negated weight, so the signs cancel
            // in the final ratio.
            fAccuracy += this.accuracyIncrease(instance);
            fWeight += instance.weight();
            instance.setWeight(-instance.weight());
            bayesNet.updateClassifier(instance);
        }
        return fAccuracy / fWeight;
    }

    /**
     * Cumulative cross validation: starting from freshly initialised CPTs, each instance
     * is scored using only the instances before it, and is then added to the network.
     *
     * @param bayesNet the network to score; stored in {@code m_BayesNet}
     * @return weighted accuracy estimate (NaN if the total weight is zero)
     * @throws Exception if updating or classifying fails
     */
    public double cumulativeCV(BayesNet bayesNet) throws Exception {
        this.m_BayesNet = bayesNet;
        double fAccuracy = 0.0;
        double fWeight = 0.0;
        Instances instances = bayesNet.m_Instances;
        bayesNet.initCPTs();
        for (int iInstance = 0; iInstance < instances.numInstances(); iInstance++) {
            Instance instance = instances.instance(iInstance);
            fAccuracy += this.accuracyIncrease(instance);
            bayesNet.updateClassifier(instance);
            fWeight += instance.weight();
        }
        return fAccuracy / fWeight;
    }

    /**
     * k-fold cross validation over contiguous folds: fold {@code i} (1-based) covers
     * instance indices {@code [(i-1)*n/k, i*n/k)}. For each fold, its instances are
     * removed from the CPTs (negated-weight update), scored, and added back.
     *
     * @param bayesNet the network to score; stored in {@code m_BayesNet}
     * @param nNrOfFolds number of folds; must be at least 1
     * @return weighted accuracy estimate (NaN if the total weight is zero)
     * @throws IllegalArgumentException if {@code nNrOfFolds < 1}
     * @throws Exception if updating or classifying fails
     */
    public double kFoldCV(BayesNet bayesNet, int nNrOfFolds) throws Exception {
        if (nNrOfFolds < 1) {
            // 0 would divide by zero; negative values made the fold loop never terminate.
            throw new IllegalArgumentException("Number of folds must be at least 1: " + nNrOfFolds);
        }
        this.m_BayesNet = bayesNet;
        double fAccuracy = 0.0;
        double fWeight = 0.0;
        Instances instances = bayesNet.m_Instances;
        int nInstances = instances.numInstances();
        bayesNet.estimateCPTs();
        int nFoldStart = 0;
        for (int iFold = 1; nFoldStart < nInstances; iFold++) {
            // long arithmetic avoids int overflow of iFold * nInstances.
            int nFoldEnd = (int) ((long) iFold * nInstances / nNrOfFolds);
            // Remove the fold's influence from the CPTs.
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            // Score the held-out fold. The weight is temporarily made positive so that
            // both the accuracy contribution and the accumulated weight are positive.
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                fAccuracy += this.accuracyIncrease(instance);
                fWeight += instance.weight();
                instance.setWeight(-instance.weight());
            }
            // Restore the fold's influence on the CPTs.
            for (int iInstance = nFoldStart; iInstance < nFoldEnd; iInstance++) {
                Instance instance = instances.instance(iInstance);
                instance.setWeight(-instance.weight());
                bayesNet.updateClassifier(instance);
            }
            nFoldStart = nFoldEnd;
        }
        return fAccuracy / fWeight;
    }

    /**
     * Accuracy contribution of a single instance under the current network.
     *
     * @param instance the instance to score
     * @return weight times the predicted probability of the true class when
     *         {@code m_bUseProb} is set; otherwise the weight if the predicted class is
     *         correct, else 0
     * @throws Exception if classification fails
     */
    double accuracyIncrease(Instance instance) throws Exception {
        if (this.m_bUseProb) {
            double[] fProb = this.m_BayesNet.distributionForInstance(instance);
            return fProb[(int) instance.classValue()] * instance.weight();
        }
        if (this.m_BayesNet.classifyInstance(instance) == instance.classValue()) {
            return instance.weight();
        }
        return 0.0;
    }

    /** @return whether probabilistic (rather than 0/1) scoring is used */
    public boolean getUseProb() {
        return this.m_bUseProb;
    }

    /** @param useProb true for probabilistic scoring, false for 0/1 scoring */
    public void setUseProb(boolean useProb) {
        this.m_bUseProb = useProb;
    }

    /**
     * Selects the cross-validation strategy. Tags that do not belong to
     * {@link #TAGS_CV_TYPE} are ignored.
     *
     * @param newCVType the selected strategy tag
     */
    public void setCVType(SelectedTag newCVType) {
        if (newCVType.getTags() == TAGS_CV_TYPE) {
            this.m_nCVType = newCVType.getSelectedTag().getID();
        }
    }

    /** @return the currently selected cross-validation strategy */
    public SelectedTag getCVType() {
        return new SelectedTag(this.m_nCVType, TAGS_CV_TYPE);
    }

    @Override
    public void setMarkovBlanketClassifier(boolean bMarkovBlanketClassifier) {
        super.setMarkovBlanketClassifier(bMarkovBlanketClassifier);
    }

    @Override
    public boolean getMarkovBlanketClassifier() {
        return super.getMarkovBlanketClassifier();
    }

    /**
     * Lists this algorithm's options followed by those of the superclass.
     *
     * @return an enumeration of {@code Option} objects
     */
    @Override
    public Enumeration listOptions() {
        Vector<Option> newVector = new Vector<Option>();
        newVector.addElement(new Option("\tApplies a Markov Blanket correction to the network structure, \n\tafter a network structure is learned. This ensures that all \n\tnodes in the network are part of the Markov blanket of the \n\tclassifier node.", "mbc", 0, "-mbc"));
        newVector.addElement(new Option("\tScore type (LOO-CV,k-Fold-CV,Cumulative-CV)", "S", 1, "-S [LOO-CV|k-Fold-CV|Cumulative-CV]"));
        newVector.addElement(new Option("\tUse probabilistic or 0/1 scoring.\n\t(default probabilistic scoring)", "Q", 0, "-Q"));
        Enumeration enu = super.listOptions();
        while (enu.hasMoreElements()) {
            newVector.addElement((Option) enu.nextElement());
        }
        return newVector.elements();
    }

    /**
     * Parses {@code -mbc}, {@code -S <type>} and {@code -Q}, then passes the remaining
     * options to the superclass. An unrecognised {@code -S} value leaves the current
     * strategy unchanged.
     *
     * @param options the option array
     * @throws Exception if an option cannot be parsed
     */
    @Override
    public void setOptions(String[] options) throws Exception {
        this.setMarkovBlanketClassifier(Utils.getFlag("mbc", options));
        String sScore = Utils.getOption('S', options);
        if (sScore.compareTo("LOO-CV") == 0) {
            this.setCVType(new SelectedTag(LOOCV, TAGS_CV_TYPE));
        } else if (sScore.compareTo("k-Fold-CV") == 0) {
            this.setCVType(new SelectedTag(KFOLDCV, TAGS_CV_TYPE));
        } else if (sScore.compareTo("Cumulative-CV") == 0) {
            this.setCVType(new SelectedTag(CUMCV, TAGS_CV_TYPE));
        }
        this.setUseProb(!Utils.getFlag('Q', options));
        super.setOptions(options);
    }

    /**
     * Returns the current settings as an option array. Unused trailing slots are
     * filled with empty strings.
     *
     * @return the option array
     */
    @Override
    public String[] getOptions() {
        String[] superOptions = super.getOptions();
        // At most 4 own entries: -mbc, -S, <type>, -Q.
        String[] options = new String[4 + superOptions.length];
        int current = 0;
        if (this.getMarkovBlanketClassifier()) {
            options[current++] = "-mbc";
        }
        options[current++] = "-S";
        switch (this.m_nCVType) {
            case LOOCV:
                options[current++] = "LOO-CV";
                break;
            case KFOLDCV:
                options[current++] = "k-Fold-CV";
                break;
            case CUMCV:
                options[current++] = "Cumulative-CV";
                break;
            default:
                break;
        }
        if (!this.getUseProb()) {
            options[current++] = "-Q";
        }
        for (String superOption : superOptions) {
            options[current++] = superOption;
        }
        while (current < options.length) {
            options[current++] = "";
        }
        return options;
    }

    /** @return tip text for the CVType property */
    public String CVTypeTipText() {
        return "Select cross validation strategy to be used in searching for networks.\nLOO-CV = Leave one out cross validation\nk-Fold-CV = k fold cross validation\nCumulative-CV = cumulative cross validation.";
    }

    /** @return tip text for the useProb property */
    public String useProbTipText() {
        return "If set to true, the probability of the class is returned in the estimate of the accuracy. If set to false, the accuracy estimate is only increased if the classifier returns exactly the correct class.";
    }

    /** @return a short description of this search algorithm */
    public String globalInfo() {
        return "This Bayes Network learning algorithm uses cross validation to estimate classification accuracy.";
    }

    @Override
    public String markovBlanketClassifierTipText() {
        return super.markovBlanketClassifierTipText();
    }

    @Override
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 1.10 $");
    }
}

