
[Machine Learning 5] The most understandable decision tree tutorial on the web (with source code)

2022-06-09 09:39:00 Ofter Data Science

Most machine learning algorithms can be classified as either supervised learning or unsupervised learning. In supervised learning, every instance in the dataset must include a value for the target attribute. Therefore, before a supervised learning algorithm can be used to train a model, considerable time and effort must go into creating a dataset labeled with target values.

Linear regression, neural networks, and decision trees are all supervised learning methods. Ofter covered linear regression and neural networks in detail in the previous two articles; interested readers can take a look. Linear regression and neural networks are suited to numerical inputs; if the input attributes of a dataset are mainly nominal (categorical), a decision tree model is usually a better fit.

1、 Introduction to decision trees

 

1.1 What it does

A decision tree builds a classification or regression model in the form of a tree. It makes a series of judgments, and each leaf gives the final answer: either a class (classification, e.g. "Should I accept this job offer?" → "No") or a value (regression, e.g. "How much can this second-hand computer be sold for?" → "10 yuan").
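To make "a series of judgments ending in a leaf" concrete, here is a minimal sketch that stores a tree as nested dicts, `{feature: {value: subtree_or_leaf}}`, the same shape the source code in section 4 produces. The job-offer tree, its feature names, and the `predict` helper are hypothetical, made up for illustration.

```python
# Hypothetical tree in nested-dict form: {feature: {value: subtree or leaf}}
job_offer_tree = {
    "salary_ok": {
        "no": "decline",
        "yes": {"commute_short": {"no": "decline", "yes": "accept"}},
    }
}


def predict(tree, sample):
    """Follow one branch per judgment until a leaf (any non-dict value) is reached."""
    while isinstance(tree, dict):
        feature = next(iter(tree))  # the single feature tested at this node
        tree = tree[feature][sample[feature]]
    return tree


print(predict(job_offer_tree, {"salary_ok": "yes", "commute_short": "no"}))  # prints "decline"
```

For a regression tree the leaves would hold numbers (for example a predicted price) instead of class labels; the traversal is unchanged.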

1.2 Construction algorithms and metrics

ID3 (information gain), C4.5 (information gain ratio), C5.0 (an improved version of C4.5), and CART (Gini index).

1.3 Construction method

Whichever algorithm and metric is used, the idea behind constructing the tree nodes is the same. For example, consider the dataset shown in the figure below:

Figure 1-1 Loan eligibility dataset

By computing one or more metrics, we decide which attribute (age group / has a job / owns a house / credit rating) should be placed at which tree node; "whether to grant the loan" is the target class being predicted. If an attribute scores too low on the metric, the tree can be pruned, i.e. that attribute is not shown. Using the C4.5 algorithm, the resulting decision tree looks like this:

Figure 1-2 C4.5 decision tree
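Because every algorithm grows the tree the same way and only the split metric changes, the construction loop can be written once with the metric passed in as a function. The sketch below is a simplified stand-in, not the article's tree.py: it refers to features by column index instead of by name, does no pruning, hardcodes the 16 rows of dataset.txt from section 4, and plugs in information gain (ID3) as the example metric. `build_tree`, `info_gain`, and `entropy` are our own names.

```python
from collections import Counter
from math import log2

# The 16 rows of dataset.txt, one string per row:
# age group, has job, owns house, credit rating, class (grant loan?)
ROWS = "00000 00010 01011 01101 00000 10000 10010 11111 10121 10121 20121 20111 21011 21021 20000 20020"
DATA = [list(r) for r in ROWS.split()]


def entropy(rows):
    """Entropy of the class labels (last column)."""
    n = len(rows)
    return -sum(c / n * log2(c / n) for c in Counter(r[-1] for r in rows).values())


def info_gain(rows, i):
    """ID3 score for column i: entropy before minus size-weighted entropy after the split."""
    n = len(rows)
    after = 0.0
    for v in {r[i] for r in rows}:
        part = [r for r in rows if r[i] == v]
        after += len(part) / n * entropy(part)
    return entropy(rows) - after


def build_tree(rows, features, score):
    """Greedy top-down construction; only `score` (higher is better) is algorithm-specific."""
    labels = [r[-1] for r in rows]
    if len(set(labels)) == 1:
        return labels[0]  # pure node: leaf
    if not features:
        return Counter(labels).most_common(1)[0][0]  # no features left: majority leaf
    best = max(features, key=lambda i: score(rows, i))
    remaining = [f for f in features if f != best]
    return {best: {v: build_tree([r for r in rows if r[best] == v], remaining, score)
                   for v in sorted({r[best] for r in rows})}}


print(build_tree(DATA, [0, 1, 2, 3], info_gain))
```

Passing a gain-ratio function would give C4.5-style selection, and passing the negated weighted Gini index (so that higher is better) would give CART-style selection; the recursion itself does not change.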

2、 Construction algorithms

2.1 ID3

$$\text{Information Gain} = \text{Entropy}(\text{before}) - \text{Entropy}(\text{after})$$

The ID3 algorithm selects attributes by information gain. Applying ID3 to the dataset in Figure 1-1 produces the following decision tree:

Figure 2-1 ID3 decision tree

Now let's look at how it is computed. The calculation for the first tree node is shown in the following figure:

Figure 2-2 ID3: the best feature for the first node

Feature 2 (zero-based index, i.e. "owns a house") clearly has the highest information gain.
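The root-node calculation can be reproduced with a short sketch that implements the entropy and information-gain formulas independently of tree.py. It assumes the 16 rows of dataset.txt from section 4; the figure may have been computed on a slightly different version of the data, so its exact numbers could differ. The helper names are our own.

```python
from collections import Counter
from math import log2

# The 16 rows of dataset.txt, one string per row:
# age group, has job, owns house, credit rating, class (grant loan?)
ROWS = "00000 00010 01011 01101 00000 10000 10010 11111 10121 10121 20121 20111 21011 21021 20000 20020"
DATA = [list(r) for r in ROWS.split()]
FEATURES = ["age group", "has job", "owns house", "credit rating"]


def entropy(rows):
    """Ent(D) = -sum_k p_k * log2(p_k) over the class labels in the last column."""
    n = len(rows)
    counts = Counter(r[-1] for r in rows)
    return -sum(c / n * log2(c / n) for c in counts.values())


def info_gain(rows, i):
    """Entropy before the split minus the size-weighted entropy after splitting on column i."""
    n = len(rows)
    after = 0.0
    for value in {r[i] for r in rows}:
        subset = [r for r in rows if r[i] == value]
        after += len(subset) / n * entropy(subset)
    return entropy(rows) - after


print(f"Ent(D) = {entropy(DATA):.3f}")
gains = [info_gain(DATA, i) for i in range(len(FEATURES))]
for i, g in enumerate(gains):
    print(f"feature {i} ({FEATURES[i]}): gain = {g:.3f}")
print("best feature index:", max(range(len(gains)), key=gains.__getitem__))
```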

Figure 2-3 Dataset attributes

However, information gain has a drawback: it favors attributes that take many distinct values in the dataset. The C4.5 algorithm was introduced to address this.
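The bias is easy to see with a deliberately useless attribute: a hypothetical row-ID column in which every sample has its own value. Each branch then holds a single sample and is trivially pure, so the information gain equals Ent(D), the largest value possible, even though the attribute cannot generalize to new samples. A minimal sketch on the class and "owns house" columns of dataset.txt (helper names are our own):

```python
from collections import Counter
from math import log2


def entropy(labels):
    """-sum p*log2(p) over the class labels."""
    n = len(labels)
    return -sum(c / n * log2(c / n) for c in Counter(labels).values())


def info_gain(values, labels):
    """Information gain of splitting `labels` by the parallel list `values`."""
    n = len(labels)
    after = 0.0
    for v in set(values):
        part = [y for x, y in zip(values, labels) if x == v]
        after += len(part) / n * entropy(part)
    return entropy(labels) - after


labels = list("0011000111111100")      # class column of dataset.txt
owns_house = list("0001000111110000")  # "owns house" column of dataset.txt
row_id = [str(k) for k in range(16)]   # hypothetical ID column: one unique value per row

print("owns house:", round(info_gain(owns_house, labels), 3))
print("row id    :", round(info_gain(row_id, labels), 3), "= Ent(D) =", round(entropy(labels), 3))
```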

2.2 C4.5

$$\mathrm{GainRatio}(A) = \frac{\mathrm{Gain}(A)}{\mathrm{SplitInfo}(A)}$$

The C4.5 algorithm selects attributes by information gain ratio. Applying C4.5 to the dataset in Figure 1-1 produces the following decision tree:

Figure 2-4 C4.5 decision tree

Now let's look at how it is computed. The calculation for the first tree node is shown in the following figure:

Figure 2-5 C4.5: the best feature for the first node

Feature 2 (zero-based index, i.e. "owns a house") clearly has the highest information gain ratio.
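Here $\mathrm{SplitInfo}(A) = -\sum_v \frac{|D^v|}{|D|}\log_2\frac{|D^v|}{|D|}$ is the entropy of attribute A's own value distribution; tree.py computes it as `IV`. The sketch below recomputes the root-node gain ratios on the dataset.txt rows and also scores a hypothetical row-ID column (one unique value per row): its gain is maximal, but its split information is $\log_2 16 = 4$, which pulls its ratio below that of "owns house". Helper names are our own. One caveat: Quinlan's C4.5 additionally restricts the choice to attributes whose information gain is at least average before comparing gain ratios; tree.py and this sketch both omit that step.

```python
from collections import Counter
from math import log2

# The 16 rows of dataset.txt: age group, has job, owns house, credit rating, class
ROWS = "00000 00010 01011 01101 00000 10000 10010 11111 10121 10121 20121 20111 21011 21021 20000 20020"
DATA = [list(r) for r in ROWS.split()]


def entropy_of(values):
    """-sum p*log2(p) over the distinct values of a list."""
    n = len(values)
    return -sum(c / n * log2(c / n) for c in Counter(values).values())


def gain_ratio(rows, i):
    """C4.5 score for column i: Gain(A) / SplitInfo(A)."""
    n = len(rows)
    after = 0.0
    for v in {r[i] for r in rows}:
        branch_labels = [r[-1] for r in rows if r[i] == v]
        after += len(branch_labels) / n * entropy_of(branch_labels)
    gain = entropy_of([r[-1] for r in rows]) - after
    split_info = entropy_of([r[i] for r in rows])  # entropy of the attribute's own values
    # A single-valued attribute has SplitInfo = 0; score it 0 instead of dividing by zero
    return gain / split_info if split_info > 0 else 0.0


ratios = [gain_ratio(DATA, i) for i in range(4)]
print("gain ratios:", [round(r, 3) for r in ratios])
print("best feature index:", max(range(4), key=ratios.__getitem__))

# Hypothetical row-ID column prepended as column 0
with_id = [[str(k)] + r for k, r in enumerate(DATA)]
print("row id gain ratio:", round(gain_ratio(with_id, 0), 3))
```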

2.3 CART

The CART algorithm selects attributes by the Gini index. Applying CART to the dataset in Figure 1-1 produces the following decision tree:

Figure 2-6 CART decision tree

Now let's look at how it is computed. The calculation for the first tree node is shown in the following figure:

Figure 2-7 CART: the best feature for the first node

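A note on interpretation: the Gini index used here measures impurity, $\mathrm{Gini}(D) = 1 - \sum_k p_k^2$, where $p_k$ is the proportion of class $k$ in node $D$. Its minimum is 0, reached when every sample in the node belongs to the same class (a perfectly classified node). It is largest when the classes are evenly mixed: 0.5 for two classes and, in general, $1 - 1/K$ for $K$ classes, so it approaches but never reaches 1. A value closer to 0 therefore means a purer node, and we choose the feature whose split gives the lowest (size-weighted) Gini index.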
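The Gini computation for the root node can be checked with a short sketch on the dataset.txt rows (helper names are our own). Like tree.py, it splits on every value of a feature (a multiway split); note that the original CART algorithm builds binary trees, so a faithful CART would instead compare binary partitions such as "value == v" versus "value != v".

```python
from collections import Counter

# The 16 rows of dataset.txt: age group, has job, owns house, credit rating, class
ROWS = "00000 00010 01011 01101 00000 10000 10010 11111 10121 10121 20121 20111 21011 21021 20000 20020"
DATA = [list(r) for r in ROWS.split()]


def gini(labels):
    """Gini impurity: 1 - sum_k p_k^2 (0 for a pure node)."""
    n = len(labels)
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def weighted_gini(rows, i):
    """Size-weighted Gini impurity of the branches produced by splitting on column i."""
    n = len(rows)
    total = 0.0
    for value in {r[i] for r in rows}:
        branch = [r[-1] for r in rows if r[i] == value]
        total += len(branch) / n * gini(branch)
    return total


print(gini(["1", "1", "1"]))  # pure node: 0.0
print(gini(["0", "1"]))       # evenly mixed, two classes: 0.5
scores = [weighted_gini(DATA, i) for i in range(4)]
print("weighted Gini:", [round(s, 4) for s in scores])
print("best feature index:", min(range(4), key=scores.__getitem__))
```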

3、 Applications of decision trees

The real challenge in applying machine learning is finding the algorithm whose learning bias best matches a particular dataset. We therefore need to understand the scenarios each model/algorithm is suited to.

Case study 1: [classification] A bank's system reviews a loan applicant's eligibility

Case study 2: [regression / probability] Whether an employee will leave

Case study 3: [regression / value] Forecasting the value of second-hand goods

4、 Source code

The complete source code used in this example:

tree.py

from math import log
import operator
import treePlotter
from collections import Counter
pre_pruning = True
post_pruning = True
def read_dataset(filename):
    """
    age group: 0 = youth, 1 = middle-aged, 2 = elderly;
    has job: 0 = no, 1 = yes;
    owns house: 0 = no, 1 = yes;
    credit rating: 0 = fair, 1 = good, 2 = very good;
    class (grant the loan?): 0 = no, 1 = yes
    """
    with open(filename, 'r') as fr:
        all_lines = fr.readlines()  # a list with one str per line
    # plain ASCII names: the translated "There's work" contained an apostrophe that broke the string literal
    labels = ['age group', 'has job', 'owns house', 'credit rating']
    dataset = []
    for line in all_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines, e.g. a trailing newline at the end of the file
        dataset.append(line.split(','))  # split each line on commas
    return dataset, labels


def read_testset(testfile):
    """
    Reads the test set, which uses the same encoding as read_dataset:
    age group: 0 = youth, 1 = middle-aged, 2 = elderly;
    has job: 0 = no, 1 = yes;
    owns house: 0 = no, 1 = yes;
    credit rating: 0 = fair, 1 = good, 2 = very good;
    class (grant the loan?): 0 = no, 1 = yes
    """
    with open(testfile, 'r') as fr:
        all_lines = fr.readlines()
    testset = []
    for line in all_lines:
        line = line.strip()
        if not line:
            continue  # skip blank lines
        testset.append(line.split(','))  # split each line on commas
    return testset


# Compute the information entropy of a dataset (class label in the last column)
def cal_entropy(dataset):
    numEntries = len(dataset)
    labelCounts = {}
    # count how many samples fall into each class
    for featVec in dataset:
        currentlabel = featVec[-1]
        if currentlabel not in labelCounts.keys():
            labelCounts[currentlabel] = 0
        labelCounts[currentlabel] += 1
    Ent = 0.0
    for key in labelCounts:
        p = float(labelCounts[key]) / numEntries
        Ent = Ent - p * log(p, 2)  # base-2 logarithm
    return Ent


# Split the dataset: keep the rows whose column `axis` equals `value`, with that column removed
def splitdataset(dataset, axis, value):
    retdataset = []  # rows of the returned subset
    for featVec in dataset:
        if featVec[axis] == value:
            reducedfeatVec = featVec[:axis]  # columns before axis
            reducedfeatVec.extend(featVec[axis + 1:])  # columns after axis
            retdataset.append(reducedfeatVec)
    return retdataset


'''
Choose the best feature to split the dataset on:
ID3: select the split attribute by information gain
C4.5: select the split attribute by gain ratio
CART: select the split attribute by the Gini index
'''


# ID3 algorithm: pick the feature with the largest information gain
def ID3_chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1  # the last column is the class label
    baseEnt = cal_entropy(dataset)
    bestInfoGain = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all features
        featList = [example[i] for example in dataset]
        uniqueVals = set(featList)  # distinct values of feature i
        newEnt = 0.0
        for value in uniqueVals:  # weighted entropy of each branch
            subdataset = splitdataset(dataset, i, value)
            p = len(subdataset) / float(len(dataset))
            newEnt += p * cal_entropy(subdataset)
        infoGain = baseEnt - newEnt
        print(u"ID3: information gain of feature %d is %.3f" % (i, infoGain))
        if infoGain > bestInfoGain:
            bestInfoGain = infoGain  # keep the best information gain so far
            bestFeature = i
    return bestFeature


# C4.5 algorithm: pick the feature with the largest gain ratio
def C45_chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1
    baseEnt = cal_entropy(dataset)
    bestInfoGain_ratio = 0.0
    bestFeature = -1
    for i in range(numFeatures):  # iterate over all features
        featList = [example[i] for example in dataset]
        uniqueVals = set(featList)  # distinct values of feature i
        newEnt = 0.0
        IV = 0.0  # split information (intrinsic value) of feature i
        for value in uniqueVals:  # weighted entropy of each branch
            subdataset = splitdataset(dataset, i, value)
            p = len(subdataset) / float(len(dataset))
            newEnt += p * cal_entropy(subdataset)
            IV = IV - p * log(p, 2)
        infoGain = baseEnt - newEnt
        if IV == 0:  # single-valued feature: split info is 0, avoid division by zero
            continue
        infoGain_ratio = infoGain / IV  # gain ratio of this feature
        print(u"C4.5: gain ratio of feature %d is %.3f" % (i, infoGain_ratio))
        if infoGain_ratio > bestInfoGain_ratio:  # keep the largest gain ratio
            bestInfoGain_ratio = infoGain_ratio
            bestFeature = i
    return bestFeature


# CART algorithm (one branch per value; assumes binary class labels '0'/'1')
def CART_chooseBestFeatureToSplit(dataset):
    numFeatures = len(dataset[0]) - 1
    bestGini = float('inf')
    bestFeature = -1
    for i in range(numFeatures):
        featList = [example[i] for example in dataset]
        uniqueVals = set(featList)
        gini = 0.0
        for value in uniqueVals:
            subdataset = splitdataset(dataset, i, value)
            p = len(subdataset) / float(len(dataset))
            # fraction of samples in this branch whose class label is '0'
            subp = sum(1 for vec in subdataset if vec[-1] == '0') / float(len(subdataset))
            # accumulate the weighted Gini impurity of every branch (must be inside the loop)
            gini += p * (1.0 - pow(subp, 2) - pow(1 - subp, 2))
        print(u"CART: Gini index of feature %d is %.3f" % (i, gini))
        if gini < bestGini:
            bestGini = gini
            bestFeature = i
    return bestFeature


def majorityCnt(classList):
    '''
    All attributes have been used but the class labels are still mixed, so we must
    decide what the leaf predicts; here we use a majority vote over the class labels.
    '''
    classCont = {}
    for vote in classList:
        if vote not in classCont.keys():
            classCont[vote] = 0
        classCont[vote] += 1
    sortedClassCont = sorted(classCont.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCont[0][0]


# Build a decision tree with the ID3 algorithm
def ID3_createTree(dataset, labels, test_dataset):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        # all samples belong to the same class: stop splitting
        return classList[0]
    if len(dataset[0]) == 1:
        # every feature has been used: return the majority class
        return majorityCnt(classList)
    bestFeat = ID3_chooseBestFeatureToSplit(dataset)
    if bestFeat == -1:
        # no feature has positive information gain: make a majority-class leaf
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    print(u"Best feature at this node: " + bestFeatLabel)

    ID3Tree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    # all values the chosen feature takes at this node
    featValues = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValues)
    if pre_pruning:
        # validation labels that reach this node
        ans = [vec[-1] for vec in test_dataset]
        # accuracy if this node becomes a leaf predicting its training-majority class
        node_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        root_acc = cal_acc(test_output=[node_output] * len(test_dataset), label=ans)
        # accuracy if we split once and every child predicts its own majority class
        outputs = []
        ans = []
        for value in uniqueVals:
            cut_testset = splitdataset(test_dataset, bestFeat, value)
            cut_dataset = splitdataset(dataset, bestFeat, value)
            ans += [vec[-1] for vec in cut_testset]
            child_output = Counter(vec[-1] for vec in cut_dataset).most_common(1)[0][0]
            outputs += [child_output] * len(cut_testset)
        cut_acc = cal_acc(test_output=outputs, label=ans)
        # splitting does not improve validation accuracy: stop here with the node's own majority class
        if cut_acc <= root_acc:
            return node_output
    for value in uniqueVals:
        subLabels = labels[:]
        ID3Tree[bestFeatLabel][value] = ID3_createTree(
            splitdataset(dataset, bestFeat, value),
            subLabels,
            splitdataset(test_dataset, bestFeat, value))
    if post_pruning:
        # test_dataset still has this node's columns: re-insert the label deleted above
        featLabels = labels[:bestFeat] + [bestFeatLabel] + labels[bestFeat:]
        tree_output = classifytest(ID3Tree, featLabels=featLabels, testDataSet=test_dataset)
        ans = [vec[-1] for vec in test_dataset]
        root_acc = cal_acc(tree_output, ans)
        # accuracy if the whole subtree is replaced by a majority-class leaf
        leaf_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        cut_acc = cal_acc([leaf_output] * len(test_dataset), ans)
        if cut_acc >= root_acc:
            return leaf_output


    return ID3Tree


def C45_createTree(dataset, labels, test_dataset):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        # all samples belong to the same class: stop splitting
        return classList[0]
    if len(dataset[0]) == 1:
        # every feature has been used: return the majority class
        return majorityCnt(classList)
    bestFeat = C45_chooseBestFeatureToSplit(dataset)
    if bestFeat == -1:
        # no usable feature (e.g. all remaining features are constant): majority-class leaf
        return majorityCnt(classList)
    bestFeatLabel = labels[bestFeat]
    print(u"Best feature at this node: " + bestFeatLabel)
    C45Tree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    # all values the chosen feature takes at this node
    featValues = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValues)


    if pre_pruning:
        # validation labels that reach this node
        ans = [vec[-1] for vec in test_dataset]
        # accuracy if this node becomes a leaf predicting its training-majority class
        node_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        root_acc = cal_acc(test_output=[node_output] * len(test_dataset), label=ans)
        # accuracy if we split once and every child predicts its own majority class
        outputs = []
        ans = []
        for value in uniqueVals:
            cut_testset = splitdataset(test_dataset, bestFeat, value)
            cut_dataset = splitdataset(dataset, bestFeat, value)
            ans += [vec[-1] for vec in cut_testset]
            child_output = Counter(vec[-1] for vec in cut_dataset).most_common(1)[0][0]
            outputs += [child_output] * len(cut_testset)
        cut_acc = cal_acc(test_output=outputs, label=ans)
        # splitting does not improve validation accuracy: stop here with the node's own majority class
        if cut_acc <= root_acc:
            return node_output


    for value in uniqueVals:
        subLabels = labels[:]
        C45Tree[bestFeatLabel][value] = C45_createTree(
            splitdataset(dataset, bestFeat, value),
            subLabels,
            splitdataset(test_dataset, bestFeat, value))


    if post_pruning:
        # test_dataset still has this node's columns: re-insert the label deleted above
        featLabels = labels[:bestFeat] + [bestFeatLabel] + labels[bestFeat:]
        tree_output = classifytest(C45Tree, featLabels=featLabels, testDataSet=test_dataset)
        ans = [vec[-1] for vec in test_dataset]
        root_acc = cal_acc(tree_output, ans)
        # accuracy if the whole subtree is replaced by a majority-class leaf
        leaf_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        cut_acc = cal_acc([leaf_output] * len(test_dataset), ans)
        if cut_acc >= root_acc:
            return leaf_output


    return C45Tree




def CART_createTree(dataset, labels, test_dataset):
    classList = [example[-1] for example in dataset]
    if classList.count(classList[0]) == len(classList):
        # all samples belong to the same class: stop splitting
        return classList[0]
    if len(dataset[0]) == 1:
        # every feature has been used: return the majority class
        return majorityCnt(classList)
    bestFeat = CART_chooseBestFeatureToSplit(dataset)
    bestFeatLabel = labels[bestFeat]
    print(u"Best feature at this node: " + bestFeatLabel)
    CARTTree = {bestFeatLabel: {}}
    del (labels[bestFeat])
    # all values the chosen feature takes at this node
    featValues = [example[bestFeat] for example in dataset]
    uniqueVals = set(featValues)


    if pre_pruning:
        # validation labels that reach this node
        ans = [vec[-1] for vec in test_dataset]
        # accuracy if this node becomes a leaf predicting its training-majority class
        node_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        root_acc = cal_acc(test_output=[node_output] * len(test_dataset), label=ans)
        # accuracy if we split once and every child predicts its own majority class
        outputs = []
        ans = []
        for value in uniqueVals:
            cut_testset = splitdataset(test_dataset, bestFeat, value)
            cut_dataset = splitdataset(dataset, bestFeat, value)
            ans += [vec[-1] for vec in cut_testset]
            child_output = Counter(vec[-1] for vec in cut_dataset).most_common(1)[0][0]
            outputs += [child_output] * len(cut_testset)
        cut_acc = cal_acc(test_output=outputs, label=ans)
        # splitting does not improve validation accuracy: stop here with the node's own majority class
        if cut_acc <= root_acc:
            return node_output


    for value in uniqueVals:
        subLabels = labels[:]
        CARTTree[bestFeatLabel][value] = CART_createTree(
            splitdataset(dataset, bestFeat, value),
            subLabels,
            splitdataset(test_dataset, bestFeat, value))


    # post-pruning runs once, after all children are built (it was wrongly indented inside the loop)
    if post_pruning:
        # test_dataset still has this node's columns: re-insert the label deleted above
        featLabels = labels[:bestFeat] + [bestFeatLabel] + labels[bestFeat:]
        tree_output = classifytest(CARTTree, featLabels=featLabels, testDataSet=test_dataset)
        ans = [vec[-1] for vec in test_dataset]
        root_acc = cal_acc(tree_output, ans)
        # accuracy if the whole subtree is replaced by a majority-class leaf
        leaf_output = Counter(vec[-1] for vec in dataset).most_common(1)[0][0]
        cut_acc = cal_acc([leaf_output] * len(test_dataset), ans)
        if cut_acc >= root_acc:
            return leaf_output


    return CARTTree




def classify(inputTree, featLabels, testVec):
    """
    Input: decision tree, feature labels, one test sample
    Output: the predicted class
    Walks the tree from the root to a leaf.
    """
    if not isinstance(inputTree, dict):
        return inputTree  # the whole tree was pruned down to a single leaf
    firstStr = list(inputTree.keys())[0]
    secondDict = inputTree[firstStr]
    featIndex = featLabels.index(firstStr)
    classLabel = '0'  # default when the sample's value never appeared at this node in training
    for key in secondDict.keys():
        if testVec[featIndex] == key:
            if type(secondDict[key]).__name__ == 'dict':
                classLabel = classify(secondDict[key], featLabels, testVec)
            else:
                classLabel = secondDict[key]
    return classLabel




def classifytest(inputTree, featLabels, testDataSet):
    """
    Input: decision tree, feature labels, test dataset
    Output: list of predicted classes, one per test sample
    """
    classLabelAll = []
    for testVec in testDataSet:
        classLabelAll.append(classify(inputTree, featLabels, testVec))
    return classLabelAll




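def cal_acc(test_output, label):
    """
    :param test_output: predicted labels for the test set
    :param label: the true labels
    :return: the accuracy (fraction of matches); 0.0 for an empty test set
    """
    assert len(test_output) == len(label)
    if len(test_output) == 0:
        return 0.0  # no test samples reached this node; avoid division by zero
    count = 0
    for index in range(len(test_output)):
        if test_output[index] == label[index]:
            count += 1
    return count / len(test_output)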


if __name__ == '__main__':
    filename = 'dataset.txt'
    testfile = 'testset.txt'  # samples used both for pruning and for the final predictions
    dataset, labels = read_dataset(filename)
    print('dataset', dataset)
    print("---------------------------------------------")
    print(u"Dataset size:", len(dataset))
    print("Ent(D):", cal_entropy(dataset))
    print("---------------------------------------------")
    print(u"Searching for the best feature at the root node:\n")
    print(u"ID3 best feature index: " + str(ID3_chooseBestFeatureToSplit(dataset)))
    print("--------------------------------------------------")
    print(u"C4.5 best feature index: " + str(C45_chooseBestFeatureToSplit(dataset)))
    print("--------------------------------------------------")
    print(u"CART best feature index: " + str(CART_chooseBestFeatureToSplit(dataset)))
    print(u"Root-node search finished!")
    print("---------------------------------------------")
    print(u"Building the decision tree -------")


    while True:
        dec_tree = '3'  # choose the tree to build: '1' = ID3, '2' = C4.5, '3' = CART
        # ID3 decision tree
        if dec_tree == '1':
            labels_tmp = labels[:]  # copy, because createTree deletes entries from the labels list
            ID3desicionTree = ID3_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile))
            print('ID3desicionTree:\n', ID3desicionTree)
            if isinstance(ID3desicionTree, dict):  # a tree pruned down to one leaf cannot be plotted
                treePlotter.ID3_Tree(ID3desicionTree)
            testSet = read_testset(testfile)
            print("Predictions on the test set:")
            print('ID3_TestSet_classifyResult:\n', classifytest(ID3desicionTree, labels, testSet))
            print("---------------------------------------------")

        # C4.5 decision tree
        if dec_tree == '2':
            labels_tmp = labels[:]  # copy, because createTree deletes entries from the labels list
            C45desicionTree = C45_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile))
            print('C45desicionTree:\n', C45desicionTree)
            if isinstance(C45desicionTree, dict):  # a tree pruned down to one leaf cannot be plotted
                treePlotter.C45_Tree(C45desicionTree)
            testSet = read_testset(testfile)
            print("Predictions on the test set:")
            print('C4.5_TestSet_classifyResult:\n', classifytest(C45desicionTree, labels, testSet))
            print("---------------------------------------------")

        # CART decision tree
        if dec_tree == '3':
            labels_tmp = labels[:]  # copy, because createTree deletes entries from the labels list
            CARTdesicionTree = CART_createTree(dataset, labels_tmp, test_dataset=read_testset(testfile))
            print('CARTdesicionTree:\n', CARTdesicionTree)
            if isinstance(CARTdesicionTree, dict):  # a tree pruned down to one leaf cannot be plotted
                treePlotter.CART_Tree(CARTdesicionTree)
            testSet = read_testset(testfile)
            print("Predictions on the test set:")
            print('CART_TestSet_classifyResult:\n', classifytest(CARTdesicionTree, labels, testSet))
        break

Contents of dataset.txt (columns: age group, has job, owns house, credit rating, class):

0,0,0,0,0
0,0,0,1,0
0,1,0,1,1
0,1,1,0,1
0,0,0,0,0
1,0,0,0,0
1,0,0,1,0
1,1,1,1,1
1,0,1,2,1
1,0,1,2,1
2,0,1,2,1
2,0,1,1,1
2,1,0,1,1
2,1,0,2,1
2,0,0,0,0
2,0,0,2,0

treePlotter.py

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['font.sans-serif'] = ['SimHei']  # font for Chinese labels; matplotlib falls back to its default if missing
decisionNode = dict(boxstyle="sawtooth", fc="0.8")
leafNode = dict(boxstyle="round4", fc="0.8")
arrow_args = dict(arrowstyle="<-")


def plotNode(nodeTxt, centerPt, parentPt, nodeType):
    createPlot.ax1.annotate(nodeTxt, xy=parentPt, xycoords='axes fraction', \
                            xytext=centerPt, textcoords='axes fraction', \
                            va="center", ha="center", bbox=nodeType, arrowprops=arrow_args)


def getNumLeafs(myTree):
    numLeafs = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            numLeafs += getNumLeafs(secondDict[key])
        else:
            numLeafs += 1
    return numLeafs


def getTreeDepth(myTree):
    maxDepth = 0
    firstStr = list(myTree.keys())[0]
    secondDict = myTree[firstStr]
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            thisDepth = getTreeDepth(secondDict[key]) + 1
        else:
            thisDepth = 1
        if thisDepth > maxDepth:
            maxDepth = thisDepth
    return maxDepth


def plotMidText(cntrPt, parentPt, txtString):
    xMid = (parentPt[0] - cntrPt[0]) / 2.0 + cntrPt[0]
    yMid = (parentPt[1] - cntrPt[1]) / 2.0 + cntrPt[1]
    createPlot.ax1.text(xMid, yMid, txtString)


def plotTree(myTree, parentPt, nodeTxt):
    numLeafs = getNumLeafs(myTree)
    depth = getTreeDepth(myTree)
    firstStr = list(myTree.keys())[0]
    cntrPt = (plotTree.xOff + (1.0 + float(numLeafs)) / 2.0 / plotTree.totalw, plotTree.yOff)
    plotMidText(cntrPt, parentPt, nodeTxt)
    plotNode(firstStr, cntrPt, parentPt, decisionNode)
    secondDict = myTree[firstStr]
    plotTree.yOff = plotTree.yOff - 1.0 / plotTree.totalD
    for key in secondDict.keys():
        if type(secondDict[key]).__name__ == 'dict':
            plotTree(secondDict[key], cntrPt, str(key))
        else:
            plotTree.xOff = plotTree.xOff + 1.0 / plotTree.totalw
            plotNode(secondDict[key], (plotTree.xOff, plotTree.yOff), cntrPt, leafNode)
            plotMidText((plotTree.xOff, plotTree.yOff), cntrPt, str(key))
    plotTree.yOff = plotTree.yOff + 1.0 / plotTree.totalD


def createPlot(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    #plt.show()
#ID3 Decision tree 
def ID3_Tree(inTree):
    fig = plt.figure(1, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.title("ID3 Decision tree ",fontsize=12,color='red')
    plt.show()


#C4.5 Decision tree 
def C45_Tree(inTree):
    fig = plt.figure(2, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.title("C4.5 Decision tree ",fontsize=12,color='red')
    plt.show()


#CART Decision tree 
def CART_Tree(inTree):
    fig = plt.figure(3, facecolor='white')
    fig.clf()
    axprops = dict(xticks=[], yticks=[])
    createPlot.ax1 = plt.subplot(111, frameon=False, **axprops)
    plotTree.totalw = float(getNumLeafs(inTree))
    plotTree.totalD = float(getTreeDepth(inTree))
    plotTree.xOff = -0.5 / plotTree.totalw
    plotTree.yOff = 1.0
    plotTree(inTree, (0.5, 1.0), '')
    plt.title("CART Decision tree ",fontsize=12,color='red')
    plt.show()

Copyright notice
This article was created by [Ofter Data Science]. Please include a link to the original when reposting. Thanks.
https://yzsam.com/2022/160/202206090911219729.html