Live code

Live code
Class Imbalance
Published

April 11, 2023

library(tidyverse)
library(pROC)

Introduction

The haberman dataset contains cases from a study that was conducted between 1958 and 1970 at the University of Chicago’s Billings Hospital on the survival of patients who had undergone surgery for breast cancer.

library(imbalance)
# Load the Haberman survival data (shipped with the imbalance package loaded
# above) and preview its first five rows.
data(haberman)
head(haberman, n = 5)
  Age Year Positive    Class
1  38   59        2 negative
2  39   63        4 negative
3  49   62        1 negative
4  53   60        2 negative
5  47   68        4 negative
  • Variables:
    • Age: age of patient at time of operation
    • Year: patient’s year of operation
    • Positive: number of positive axillary nodes detected
    • Class: two possible survival statuses: “positive” (the patient survived less than 5 years), “negative” (the patient survived 5 years or more)
  • We may be interested in predicting the probability of survival

Imbalanced data

# Tabulate the outcome classes and attach each class's share of all rows.
haberman %>%
  count(Class) %>%
  mutate(prop = prop.table(n))
     Class   n      prop
1 negative 225 0.7352941
2 positive  81 0.2647059
  • What do you notice?

    • “Class imbalance”: one (or more, if \(J \geq 2\)) of the possible labels is severely underrepresented in the data
  • Discuss: I claim that class imbalance is an issue for predictive models! Why do you think that is?

Logistic regression

# Baseline: fit a logistic regression on the (imbalanced) training data and
# evaluate it on a held-out test set.
set.seed(2)  # fix the RNG so the train/test split is reproducible

# Recode the outcome as numeric 0/1 (1 = "positive", 0 = anything else).
# Note: this overwrites `haberman`, so this chunk should only be run once
# per fresh load of the data.
haberman <- haberman %>%
  mutate(Class = ifelse(Class == "positive", 1, 0))

# 80/20 train/test split on row positions.
# floor() makes the rounding of the fractional size explicit (sample() would
# otherwise truncate it silently), and seq_len() stays correct when n == 0.
n <- nrow(haberman)
n_train <- floor(0.8 * n)
train_ids <- sample(seq_len(n), n_train)
train_dat <- haberman[train_ids, ]
test_dat <- haberman[-train_ids, ]

# Model the probability that Class == 1 from the Positive column alone.
log_mod <- glm(Class ~ Positive, data = train_dat, family = "binomial")

# Predicted probabilities on the test set, turned into 0/1 labels with a
# 0.5 cut-off, then cross-tabulated against the true labels.
pred_probs <- predict(log_mod, newdata = test_dat, type = "response")
pred_class <- as.numeric(pred_probs >= 0.5)
table(preds = pred_class, true = test_dat$Class)
     true
preds  0  1
    0 44 12
    1  3  3
# ROC curve / AUC of the baseline model on the test set.
# NOTE(review): pred_class holds 0/1 labels obtained by thresholding at 0.5,
# so this curve has a single operating point; passing pred_probs as the
# predictor would score the model across all thresholds.
roc(response = test_dat$Class, predictor = pred_class)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = test_dat$Class, predictor = pred_class)

Data: pred_class in 47 controls (test_dat$Class 0) < 15 cases (test_dat$Class 1).
Area under the curve: 0.5681

Oversampling

# Random oversampling: duplicate randomly chosen minority-class rows and
# append them to the training data before refitting the model.
set.seed(3)  # make the resampling reproducible

# Row positions (within train_dat) of each class; Class is coded 0/1 and
# 1 is the minority class.
train_minority <- which(train_dat$Class == 1)
train_majority <- which(train_dat$Class == 0)
n_min <- length(train_minority)
n_maj <- length(train_majority)

# Draw 40 extra minority rows with replacement. The 40 is a fixed choice and
# does not necessarily balance the classes (n_maj - n_min extra rows would).
over_ids <- sample(train_minority, size = 40, replace = TRUE)

# Training set = original rows plus the duplicated minority rows.
train_dat_oversample <- rbind(train_dat, train_dat[over_ids, ])
mod_oversample <- glm(Class ~ Positive, data = train_dat_oversample,
                      family = "binomial")

# Evaluate on the untouched test set with the same 0.5 cut-off.
pred_probs <- predict(mod_oversample, newdata = test_dat, type = "response")
pred_class <- as.numeric(pred_probs >= 0.5)
table(preds = pred_class, true = test_dat$Class)
     true
preds  0  1
    0 41 10
    1  6  5
# Caveat: random oversampling adds exact copies of minority-class rows, which
# can make overfitting more likely.
# ROC curve / AUC of the oversampled model, computed from the thresholded
# 0/1 predictions in pred_class.
roc(response = test_dat$Class, predictor = pred_class)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = test_dat$Class, predictor = pred_class)

Data: pred_class in 47 controls (test_dat$Class 0) < 15 cases (test_dat$Class 1).
Area under the curve: 0.6028

Undersampling

# Random undersampling: drop randomly chosen majority-class rows from the
# training data so that both classes end up with n_min rows.
set.seed(3)  # make the resampling reproducible

# train_majority, n_maj and n_min were computed in the oversampling step.
remove_ids <- sample(train_majority, n_maj - n_min, replace = FALSE)

# Select the rows to keep by position rather than by negative indexing:
# train_dat[-integer(0), ] would return zero rows if nothing were removed.
keep_ids <- setdiff(seq_len(nrow(train_dat)), remove_ids)

mod_undersample <- glm(Class ~ Positive, data = train_dat[keep_ids, ],
                       family = "binomial")

# Evaluate on the untouched test set with the same 0.5 cut-off.
pred_probs <- predict(mod_undersample, newdata = test_dat, type = "response")
pred_class <- as.numeric(pred_probs >= 0.5)
table(preds = pred_class, true = test_dat$Class)
     true
preds  0  1
    0 34  8
    1 13  7
# Caveat: random undersampling throws away majority-class rows, so the model
# is fit on less data and may lose information those rows carried.
# ROC curve / AUC of the undersampled model, computed from the thresholded
# 0/1 predictions in pred_class.
roc(response = test_dat$Class, predictor = pred_class)
Setting levels: control = 0, case = 1
Setting direction: controls < cases

Call:
roc.default(response = test_dat$Class, predictor = pred_class)

Data: pred_class in 47 controls (test_dat$Class 0) < 15 cases (test_dat$Class 1).
Area under the curve: 0.595