Live code:

Live code
Classification trees
Published

April 20, 2023

Data

library(tidyverse)
# Read the obesity data from the course repository and convert every
# character column to a factor. across() + where() replaces the superseded
# scoped verb mutate_if(); the resulting data frame is the same.
obesity <- read.csv("https://raw.githubusercontent.com/math218-spring2023/class-data/main/obesity.csv") %>%
  mutate(across(where(is.character), factor))

# Number of observations in each level of the response `class`.
obesity %>%
  count(class)
          class   n
1  Insufficient 272
2        Normal 287
3       Obese_I 351
4      Obese_II 297
5     Obese_III 324
6  Overweight_I 290
7 Overweight_II 290

Train/test split

# Seeded 80/20 split: draw 80% of the row indices (rounded down to a whole
# number of rows) for training; every row not drawn goes to the test set.
set.seed(1)
n <- nrow(obesity)
train_ids <- sample(seq_len(n), size = floor(0.8 * n))
train_dat <- obesity[train_ids, ]
test_dat <- obesity[-train_ids, ]

Fit classification tree + prune

library(tree)
# Grow a classification tree for `class` on the training data, with the
# minimum node size lowered to 2 through tree.control().
my_tree <- tree(
  class ~ .,
  data = train_dat,
  control = tree.control(nobs = nrow(train_dat), minsize = 2)
)

# Cross-validate the pruning sequence, using prune.misclass as the pruning
# function.
cv_tree <- cv.tree(my_tree, FUN = prune.misclass)

# Of all tree sizes that attain the lowest cross-validated deviance, keep
# the smallest one.
best_size <- with(cv_tree, min(size[which(dev == min(dev))]))

## equivalent to: prune.tree(my_tree, best = best_size, method = "misclass")
prune_tree <- prune.misclass(my_tree, best = best_size)

# Draw the pruned tree and label its splits and leaves.
plot(prune_tree)
text(prune_tree, pretty = 0, cex = 0.6)

Obtain predictions

# Default predictions from the pruned tree: one row per test observation and
# one column of estimated class probabilities per level of `class`
# (first six rows shown).
predict(prune_tree, newdata = test_dat) %>%
  head()
   Insufficient     Normal   Obese_I   Obese_II Obese_III Overweight_I
3    0.00000000 0.09395973 0.0000000 0.00000000         0    0.5302013
4    0.00000000 0.09395973 0.0000000 0.00000000         0    0.5302013
6    0.01190476 0.89285714 0.0000000 0.00000000         0    0.0952381
10   0.00000000 0.11382114 0.0000000 0.00000000         0    0.8536585
18   0.00000000 0.00000000 1.0000000 0.00000000         0    0.0000000
24   0.00000000 0.00000000 0.9519231 0.02884615         0    0.0000000
   Overweight_II
3     0.37583893
4     0.37583893
6     0.00000000
10    0.03252033
18    0.00000000
24    0.01923077
# With type = "class" the pruned tree returns a factor of predicted class
# labels instead of probabilities (first six shown).
predict(prune_tree, newdata = test_dat, type = "class") %>%
  head()
[1] Overweight_I Overweight_I Normal       Overweight_I Obese_I     
[6] Obese_I     
7 Levels: Insufficient Normal Obese_I Obese_II Obese_III ... Overweight_II
# Predicted class label for every test observation.
tree_preds <- predict(prune_tree, newdata = test_dat, type = "class")

# Test misclassification rate: share of test rows whose predicted label
# differs from the observed `class`.
mean(tree_preds != test_dat[["class"]])
[1] 0.1583924

Bagged classification trees

library(randomForest)
randomForest 4.7-1.1
Type rfNews() to see new features/changes/bug fixes.

Attaching package: 'randomForest'
The following object is masked from 'package:dplyr':

    combine
The following object is masked from 'package:ggplot2':

    margin
# Bagging: a forest of 100 trees in which every predictor (all columns
# except the response) is a candidate at each split.
# NOTE(review): the forest is fit on all of `obesity`, not on `train_dat`.
# The error below relies on the fit's stored `predicted` values for the
# held-out rows (presumably out-of-bag predictions; confirm in the
# randomForest documentation), so it is not a conventional test-set error.
my_bag <- randomForest(
  class ~ .,
  data = obesity,
  ntree = 100,
  mtry = ncol(obesity) - 1
)
varImpPlot(my_bag)

# Misclassification rate on the rows that are not in `train_ids`
# (the same rows that make up `test_dat`).
bag_preds <- my_bag$predicted[-train_ids]
mean(bag_preds != obesity$class[-train_ids])
[1] 0.02600473

Random forest classification

# Random forest: at each split consider roughly sqrt(p) predictors, where p
# is the number of columns other than the response.
m <- round(sqrt(ncol(obesity) - 1))

# NOTE(review): the forest is fit on all of `obesity`, not on `train_dat`.
# The error below relies on the fit's stored `predicted` values for the
# held-out rows (presumably out-of-bag predictions; confirm in the
# randomForest documentation), so it is not a conventional test-set error.
my_rf <- randomForest(
  class ~ .,
  data = obesity,
  ntree = 100,
  mtry = m
)
varImpPlot(my_rf)

# Misclassification rate on the rows that are not in `train_ids`
# (the same rows that make up `test_dat`).
rf_preds <- my_rf$predicted[-train_ids]
mean(rf_preds != obesity$class[-train_ids])
[1] 0.04964539