You are on page 1 of 4

# package for trees
library(rpart)
# package including data from Elements of Statistical Learning
library(ElemStatLearn)
data(spam)

# make response a 0-1 outcome
# spam$spam = ifelse(spam$spam=="spam",1,0)

# row indices of each class, so the train/test split below is stratified
spam.sub <- which(spam$spam == "spam")
nospam.sub <- which(spam$spam == "email")

# use 2/3 for training, 1/3 for test
train.spam <- sample(spam.sub, floor(length(spam.sub) * 2 / 3))
train.email <- sample(nospam.sub, floor(length(nospam.sub) * 2 / 3))
train <- c(train.spam, train.email)
train.set <- spam[train, ]
test.set <- spam[-train, ]

rpart.spam <- rpart(spam ~ ., data = train.set, method = "class",
                    parms = list(split = "gini"))

# take a look at the decision rule
print(summary(rpart.spam))

png("spam_tree.png", height = 600, width = 900)
# visualize it (gets difficult for bigger trees);
# filename = "" makes post() draw on the current device (the png opened above)
post(rpart.spam, filename = "")
dev.off()

# predict the labels for the test set: predict() returns one column of class
# probabilities per label, and the predicted label is the column with the
# largest probability in each row
predict.spam <- predict(rpart.spam, test.set)
plabels.spam <- colnames(predict.spam)[apply(predict.spam, 1, which.max)]

# compute the various measures of accuracy
#
# Arguments:
#   plabels  predicted labels, each "spam" or "email"
#   tlabels  true labels, each "spam" or "email", same length as plabels
#
# Returns a list with
#   A      accuracy, (TP + TN) / (TP + TN + FP + FN)
#   TP, FP, TN, FN   the four cell counts ("spam" is the positive class)
#   C      2x2 confusion matrix (rows: truth, columns: prediction)
#   sens   sensitivity, TP / (TP + FN)
#   spec   specificity, TN / (TN + FP)
#   prec   precision, TP / (TP + FP); NaN when nothing is labelled spam
classification.summary <- function(plabels, tlabels) {
  # mismatched lengths would otherwise be recycled silently
  stopifnot(length(plabels) == length(tlabels))

  # Every count uses the `plabels` argument. Reading the global `plabels.spam`
  # here instead would make the result ignore what the caller passed in.

  # true positives: things we labelled spam that are spam
  TP <- sum(plabels == "spam" & tlabels == "spam")
  # false positives: things we labelled spam that are email
  FP <- sum(plabels == "spam" & tlabels == "email")
  # true negatives: things we labelled email that are email
  TN <- sum(plabels == "email" & tlabels == "email")
  # false negatives: things we labelled email that are spam
  FN <- sum(plabels == "email" & tlabels == "spam")

  # accuracy
  A <- (TP + TN) / (TP + TN + FP + FN)
  # sensitivity
  sens <- TP / (TP + FN)
  # specificity
  spec <- TN / (TN + FP)
  # precision: share of predicted spam that is truly spam, so the denominator
  # is TP + FP (TP + FN would just repeat the sensitivity)
  prec <- TP / (TP + FP)

  # confusion matrix; matrix() fills column by column, so the first column
  # (predicted spam) is TP, FP and the second (predicted email) is FN, TN
  C <- matrix(c(TP, FP, FN, TN), 2, 2)
  colnames(C) <- c("predicted spam", "predicted email")
  rownames(C) <- c("truly spam", "truly email")

  list(A = A, TP = TP, FP = FP, TN = TN, FN = FN, C = C,
       sens = sens, spec = spec, prec = prec)
}

s <- classification.summary(plabels.spam, test.set$spam)
print(s)

png("spam_cptree.png", height = 1200, width = 800)
# you can control some aspects of the tree building process
# with rpart.control
rpart.spam.deeper <- rpart(spam ~ ., data = train.set, method = "class",
                           parms = list(split = "gini"),
                           control = rpart.control(cp = 0.00001, xval = 20))
# plot the deeper tree that was just fit (not the earlier default tree)
post(rpart.spam.deeper, filename = "")
dev.off()

# let's look at the stability of the tree: redraw the 2/3 - 1/3 split five
# times, refit with the default settings, and save each tree to
# spam_repeat0.png ... spam_repeat4.png
for (i in 0:4) {
  png(paste0("spam_repeat", i, ".png"), height = 600, width = 600)

  train.spam <- sample(spam.sub, floor(length(spam.sub) * 2 / 3))
  train.email <- sample(nospam.sub, floor(length(nospam.sub) * 2 / 3))
  train <- c(train.spam, train.email)
  train.set <- spam[train, ]
  test.set <- spam[-train, ]

  rpart.spam <- rpart(spam ~ ., data = train.set, method = "class",
                      parms = list(split = "gini"))
  post(rpart.spam, filename = "")
  dev.off()
}

# ROC curve for the tree and test set left over from the last repeat above
png("spamROC.png", height = 600, width = 600)
predict.spam <- predict(rpart.spam, test.set)

# every distinct predicted spam probability is used as a threshold
l <- sort(unique(predict.spam[, "spam"]))
sens <- numeric(length(l))
spec <- numeric(length(l))
for (k in seq_along(l)) {
  # label as spam everything whose predicted probability reaches the threshold;
  # plabels.spam is (re)assigned at top level, as elsewhere in this script
  plabels.spam <- rep("email", nrow(predict.spam))
  plabels.spam[predict.spam[, "spam"] >= l[k]] <- "spam"
  s <- classification.summary(plabels.spam, test.set$spam)
  sens[k] <- s$sens
  spec[k] <- s$spec
}

# add the end points so the curve runs from (1, 1) to (0, 0) in
# (1 - specificity, sensitivity) coordinates
sens <- c(1, sens, 0)
spec <- c(0, spec, 1)
plot(1 - spec, sens, type = "l", col = "red", lwd = 2)
# diagonal reference line: sensitivity equal to 1 - specificity
abline(0, 1, lwd = 2, lty = 2, col = "blue")
dev.off()

You might also like