
import math

import operator
import matplotlib.pyplot as plt
import TreePlotter

def createDataset():
    """Build the watermelon sample set used to train the decision tree.

    Returns:
        (dataSet, labels): dataSet is a list of 17 rows; each row holds six
        attribute values followed by the class label ('好瓜' = good melon,
        '坏瓜' = bad melon) in the last position.  labels lists the six
        attribute names, in column order.
    """
    # Columns: color, root, knock sound, texture, navel, touch, class.
    dataSet = [
        ['青绿', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['乌黑', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '蜷缩', '沉闷', '清晰', '凹陷', '硬滑', '好瓜'],
        ['浅白', '蜷缩', '浊响', '清晰', '凹陷', '硬滑', '好瓜'],
        ['青绿', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '稍糊', '稍凹', '软粘', '好瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '硬滑', '好瓜'],
        ['乌黑', '稍蜷', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
        ['青绿', '硬挺', '清脆', '清晰', '平坦', '软粘', '坏瓜'],
        ['浅白', '硬挺', '清脆', '模糊', '平坦', '硬滑', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '软粘', '坏瓜'],
        ['青绿', '稍蜷', '浊响', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['浅白', '稍蜷', '沉闷', '稍糊', '凹陷', '硬滑', '坏瓜'],
        ['乌黑', '稍蜷', '浊响', '清晰', '稍凹', '软粘', '坏瓜'],
        ['浅白', '蜷缩', '浊响', '模糊', '平坦', '硬滑', '坏瓜'],
        ['青绿', '蜷缩', '沉闷', '稍糊', '稍凹', '硬滑', '坏瓜'],
    ]
    # Attribute names for the six feature columns.
    labels = ['色泽', '根蒂', '敲击', '纹理', '脐部', '触感']
    return dataSet, labels

def calcShannonEnt(dataSet):
    """Compute the Shannon entropy of the class column of dataSet.

    Args:
        dataSet: list of rows; the class label is the last element of each row.

    Returns:
        Entropy in bits: -sum(p * log2(p)) over the class distribution.
    """
    total = len(dataSet)
    # Tally how many rows carry each class label.
    labelCounts = {}
    for row in dataSet:
        label = row[-1]
        labelCounts[label] = labelCounts.get(label, 0) + 1
    # Accumulate -p*log2(p) for every observed class.
    entropy = 0.0
    for count in labelCounts.values():
        prob = float(count) / total
        entropy -= prob * math.log(prob, 2)
    return entropy

def splitDataSet(dataSet, axis, value):
    """Select the rows where feature `axis` equals `value`, dropping that feature.

    Args:
        dataSet: list of rows (lists).
        axis: index of the feature column to test.
        value: required value of that feature.

    Returns:
        New list of new rows; each kept row has the `axis` column removed.
    """
    # Slice-and-concatenate builds each reduced row without the split column.
    return [row[:axis] + row[axis + 1:] for row in dataSet if row[axis] == value]

def chooseBestFeatureToSplit(dataSet):
    """Pick the feature with the highest information gain (ID3 criterion).

    Args:
        dataSet: rows of feature values with the class label in the last column.

    Returns:
        Index of the best feature, or -1 if no split improves on zero gain.
    """
    featureCount = len(dataSet[0]) - 1  # last column is the class label
    baseEntropy = calcShannonEnt(dataSet)  # entropy before any split
    bestGain = 0.0
    bestFeature = -1
    for featIdx in range(featureCount):
        # Conditional entropy: weighted entropy over each value's subset.
        uniqueVals = {row[featIdx] for row in dataSet}
        condEntropy = 0
        for val in uniqueVals:
            subset = splitDataSet(dataSet, featIdx, val)
            weight = len(subset) / float(len(dataSet))
            condEntropy += weight * calcShannonEnt(subset)
        gain = baseEntropy - condEntropy
        if gain > bestGain:
            bestGain = gain
            bestFeature = featIdx
    return bestFeature

def majorityCnt(classList):
    """Return the most frequent class label in classList (majority vote).

    Ties are broken in favor of the label encountered first, because the
    sort below is stable and dicts preserve insertion order.

    Args:
        classList: non-empty list of class-label strings.

    Returns:
        The label with the highest count.
    """
    classCount = {}
    for vote in classList:
        classCount[vote] = classCount.get(vote, 0) + 1
    # Sort (label, count) pairs by count, descending.
    # NOTE(review): removed leftover debug print() calls that polluted stdout.
    sortedClassCount = sorted(classCount.items(), key=operator.itemgetter(1), reverse=True)
    return sortedClassCount[0][0]

def createTree(dataSet, labels):
    """Recursively build an ID3 decision tree.

    Args:
        dataSet: rows of feature values with the class label in the last column.
        labels: attribute names matching the feature columns.  Unlike the
            original version, this list is NOT mutated (the original's
            del(labels[bestFeat]) destroyed the caller's list).

    Returns:
        A nested dict {attribute: {value: subtree_or_label}}, or a bare
        class-label string for a leaf node.
    """
    classList = [example[-1] for example in dataSet]
    # Leaf case 1: every sample has the same class.
    if classList.count(classList[0]) == len(dataSet):
        return classList[0]
    # Leaf case 2: no features left to split on -> majority vote.
    if len(dataSet[0]) == 1:
        return majorityCnt(classList)
    bestFeat = chooseBestFeatureToSplit(dataSet)
    bestFeatLabel = labels[bestFeat]
    myTree = {bestFeatLabel: {}}
    # Remove the used attribute on a copy so the caller's list is untouched.
    subLabels = labels[:bestFeat] + labels[bestFeat + 1:]
    featValues = [example[bestFeat] for example in dataSet]
    for value in set(featValues):  # one branch per observed attribute value
        # (Reconstructed: this assignment was split across two lines by the
        # document extraction in the original source.)
        myTree[bestFeatLabel][value] = createTree(
            splitDataSet(dataSet, bestFeat, value), subLabels[:])
    return myTree
# Script entry point: build the tree from the watermelon data, plot and print it.
# Guarded so that importing this module does not trigger plotting side effects.
if __name__ == '__main__':
    dataSet, labels = createDataset()
    myTree = createTree(dataSet, labels)
    TreePlotter.createPlot(myTree)
    print(myTree)
