1 R Setup and Required Packages

In this project, the open-source R programming language is used to help understand how classification metrics vary with class imbalance, sample size, and the number of classes. R is maintained by an international team of developers who make the language available through the Comprehensive R Archive Network (CRAN). Readers interested in reusing our code and reproducing our results should have R installed locally on their machines. R can be installed on a number of different operating systems (see the Windows, Mac, and Linux installation instructions on CRAN). We also recommend using the RStudio interface for R, which can be downloaded for free. For readers new to R, we recommend Hands-On Programming with R for a brief overview of the software’s functionality. Hereafter, we assume that the reader has an introductory understanding of the R programming language.

In the code chunk below, we load the packages used to support our analysis. The code in this chunk (and in every other chunk) can be hidden by clicking its ‘Hide’ button, which makes the document easier to navigate. To hide all code and/or download the Rmd file associated with this document, click the Code button in the top right corner of this document. Our input and output files can also be accessed and downloaded from fmegahed/metric_interpretation.

# installing the pacman (PACkage MANager) package if it is not already installed
if (!require(pacman)) install.packages("pacman")

# load (and if not installed, install) the needed packages
pacman::p_load(
  tidyverse, # for data manipulation
  caret, # for predictive performance metrics
  plotly, # for interactive plots
  DT # for nicely formatted tables
) 

2 Preparing for the Simulation

2.1 Simulation Parameters

sim_num = 1:10^4 # setting the number of simulations per exp run to 10,000

num_classes = c(2, 3, 5, 10) # number of classes

num_observations = c(100, 10^3, 10^4, 10^5) # number of observations

first_class_perc = c(0.01, 0.05, 0.1, 0.2, (1/3), 0.5) # share of observations in class_1 (imbalance)

pred_approach = c('uniform', 'proportional', 'most_frequent') # how predictions are generated

quant_of_interest = 0.99 # quantile of interest
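As a quick check on the computational budget, the sketch below (plain base R, reusing the parameter values above) counts the experimental conditions and simulation runs implied by the full factorial design, plus the total number of reference labels drawn across all runs. The names `n_conditions`, `n_runs` and `n_labels` are introduced here purely for illustration.

```r
# parameter values copied from the chunk above
sim_num = 1:10^4
num_classes = c(2, 3, 5, 10)
num_observations = c(100, 10^3, 10^4, 10^5)
first_class_perc = c(0.01, 0.05, 0.1, 0.2, (1/3), 0.5)
pred_approach = c('uniform', 'proportional', 'most_frequent')

# number of distinct experimental conditions (i.e., distinct values of exp_num)
n_conditions = length(num_classes) * length(num_observations) *
  length(first_class_perc) * length(pred_approach)

# each condition is replicated length(sim_num) times
n_runs = n_conditions * length(sim_num)

# reference labels generated over all runs (each run draws num_observations labels)
n_labels = length(sim_num) * length(num_classes) * length(first_class_perc) *
  length(pred_approach) * sum(num_observations)

n_conditions  # 288
n_runs        # 2880000
n_labels      # 7.9992e+10
```

The last figure is a reminder that the largest sample size dominates the run time, which is worth keeping in mind before launching the full experiment.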

2.2 Full Factorial Setup

In the chunk below, we create a nested object named sim_setup, which is a tibble containing three variables:

  1. exp_num, a unique ID for each combination of num_classes, num_observations, first_class_perc and pred_approach.
  2. sim_num, which takes values from 1 to 10,000 for each ID (i.e., each value of exp_num is paired with sim_num values 1 to 10,000).
  3. data, which is a list column, where each cell/row contains a one-row tibble with the four variables num_classes, num_observations, first_class_perc and pred_approach.

Note that the data column is the input to our custom sim_function(), which is defined in a subsequent subsection.

# the full factorial setup
sim_setup = 
  # creating a tibble of all combinations of sim_num and the four experimental factors
  expand_grid( sim_num, num_classes, num_observations, first_class_perc,
               pred_approach) |>
  # using group_by to create ids which are NOT dependent on the sim_num
  group_by(num_classes, num_observations, first_class_perc, pred_approach) |>
  # using the dplyr::cur_group_id() to create the IDs
  mutate( exp_num = cur_group_id() ) |>
  # ungrouping since it was only performed to create unique ids
  ungroup() |>
  # moving the exp_num to the beginning
  relocate( exp_num ) |>
  # creating a list column from all the columns except the ID-type columns
  nest(data = -c(exp_num, sim_num) ) |>
  # sort the experimental data by exp_num
  arrange(exp_num)


# saving the results as an rds file under the results folder
write_rds(
  x = sim_setup, file = '../results/sim_setup.rds'
)
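To make the structure of sim_setup concrete, the sketch below runs the same pipeline on a deliberately tiny, hypothetical grid (2 simulation runs, 2 values of num_classes and 2 prediction approaches; num_observations and first_class_perc are dropped) so that the result can be printed and inspected in full. The object name `toy_setup` is illustrative only.

```r
library(dplyr)
library(tidyr)

# a hypothetical, much smaller design used only to inspect the structure
toy_setup =
  expand_grid(sim_num = 1:2, num_classes = c(2, 3),
              pred_approach = c('uniform', 'most_frequent')) |>
  # IDs depend only on the experimental factors, not on sim_num
  group_by(num_classes, pred_approach) |>
  mutate(exp_num = cur_group_id()) |>
  ungroup() |>
  relocate(exp_num) |>
  # everything except the two ID columns is packed into a one-row tibble
  nest(data = -c(exp_num, sim_num)) |>
  arrange(exp_num)

toy_setup           # 8 rows: 4 experimental conditions x 2 simulation runs
toy_setup$data[[1]] # a 1 x 2 tibble holding num_classes and pred_approach
```

Because the IDs are created within groups that exclude sim_num, every simulation run of the same condition shares one exp_num, which is what later allows the results to be summarized by exp_num.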

2.3 Custom Functions

2.3.1 g-mean Calculation

The g_mean() function below computes the geometric mean of two numbers (here, sensitivity and specificity).

g_mean = function(x, y){
  
  result = sqrt(x*y)
  
  return(result)
}
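A few worked values illustrate why the geometric mean is used: unlike the arithmetic mean, it collapses to 0 whenever either component is 0, so a classifier that never predicts class_1 cannot score well. The function is repeated below (same computation) so the chunk runs on its own.

```r
# same computation as g_mean() above, repeated so this chunk is self-contained
g_mean = function(x, y){
  sqrt(x*y)
}

# high sensitivity but poor specificity
g_mean(x = 0.9, y = 0.1)   # approximately 0.3, versus an arithmetic mean of 0.5

# a classifier that never predicts the reference class (sensitivity = 0)
g_mean(x = 0, y = 1)       # 0

# sqrt() and * are vectorized, so g_mean() also works element-wise on vectors
g_mean(x = c(0.25, 1), y = c(1, 0.64))  # 0.5 0.8
```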

2.3.2 Big Simulation Function for Vectorized/Functional Programming

The sim_function() is a custom function, designed to be applied row-wise with functional programming tools (here, purrr::map()), that generates:

  1. the simulated observations, which mimic the true (aka reference) values against which classification performance is measured; and
  2. the simulated predictions.

The two vectors of observations and predictions are based on the values set in the Simulation Parameters section (i.e., num_classes, num_observations, first_class_perc and pred_approach).
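In sim_function(), class_1 receives first_class_perc of the observations and the remaining observations are split evenly across the other classes. Because rep() truncates non-integer counts towards zero, any shortfall is added to the last class. The sketch below reproduces only that counting step (without the shuffling) in a hypothetical helper, `ref_counts()`, to show the class sizes actually produced for two small settings.

```r
# hypothetical helper mirroring the reference-generation step of sim_function()
ref_counts = function(num_classes, num_observations, first_class_perc,
                      base_class_name = 'class_'){
  x = paste0(base_class_name, 1:num_classes)
  ref = c(rep(x[1], num_observations*first_class_perc),
          rep(x[-1], each = num_observations*(1-first_class_perc)/(num_classes-1)))
  # rep() truncates non-integer counts, so pad the last class up to num_observations
  if (length(ref) != num_observations) {
    ref = c(ref, rep(x[num_classes], num_observations - length(ref)))
  }
  table(factor(ref, levels = x))
}

# counts: 33, 33, 34
ref_counts(num_classes = 3, num_observations = 100, first_class_perc = 1/3)

# class_1 = 50, class_2 to class_9 = 5 each, class_10 = 10 (5 + 5 padded)
ref_counts(num_classes = 10, num_observations = 100, first_class_perc = 0.5)
```

For the smallest sample size this padding can make the last class noticeably larger than the others, which is worth remembering when interpreting the 100-observation results.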

The function returns seven classification metrics of interest:

  1. overall accuracy,
  2. sensitivity for the reference class class_1,
  3. specificity for the reference class class_1,
  4. the geometric mean of sensitivity and specificity,
  5. precision for the reference class class_1,
  6. recall for the reference class class_1 (by definition, identical to sensitivity), and
  7. f_measure for the reference class class_1.
sim_function = function(data = NULL, base_class_name = 'class_'){
  
  # extracting the experimental conditions from the one-row tibble in `data`
  number_classes = data$num_classes
  number_observations = data$num_observations
  first_class_percentage = data$first_class_perc
  prediction_approach = data$pred_approach
  
  # class labels (class_1, ..., class_k), also used as the factor levels
  x = paste0(base_class_name, 1:number_classes)
  
  # generating the reference data: class_1 receives first_class_percentage of the
  # observations and the remainder is split evenly across the other classes
  ref = c(rep(x[1], number_observations*first_class_percentage),
          rep(x[-1], each = number_observations*(1-first_class_percentage)/(number_classes-1)))
  
  # rep() truncates non-integer counts, so pad with the last class if ref is short
  if (length(ref) != number_observations) {
    ref = c(ref, rep(x[number_classes], number_observations - length(ref)))
  }
  
  # shuffling the reference labels
  ref = sample(ref, number_observations, replace = FALSE) |> factor(levels = x)
  
  # generating the predictions
  if (prediction_approach == 'uniform') {
    # each class is predicted with probability 1/number_classes
    predictions = sample(x, size = number_observations, replace = TRUE,
                         prob = rep(1/number_classes, number_classes))
  } else if (prediction_approach == 'proportional') {
    # each class is predicted with its relative frequency in ref
    predictions = sample(x, size = number_observations, replace = TRUE,
                         prob = as.numeric(table(ref)/length(ref)))
  } else {
    # 'most_frequent': always predict the most frequent reference class
    # (which.max() returns the first class in case of ties)
    predictions = rep(x[which.max(table(ref))], times = number_observations)
  }
  predictions = factor(predictions, levels = x)
  
  conf_matrix = caret::confusionMatrix(data = predictions, reference = ref)
  
  # byClass is a named vector for two classes (where the positive class is the
  # first factor level, class_1) and a class-by-metric matrix otherwise
  if (number_classes == 2) {
    class_1_metrics = conf_matrix$byClass
  } else {
    class_1_metrics = conf_matrix$byClass[paste0("Class: ", x[1]), ]
  }
  
  # computing the acc metrics of interest for class_1
  acc_metrics = tibble(
    accuracy = conf_matrix$overall['Accuracy'],
    sensitivity = class_1_metrics["Sensitivity"],
    specificity = class_1_metrics["Specificity"],
    g_mean = g_mean(x = sensitivity, y = specificity),
    precision = class_1_metrics["Precision"],
    recall = class_1_metrics["Recall"],
    f_measure = class_1_metrics["F1"]
  )
  
  return(acc_metrics)
}
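caret::confusionMatrix() handles the bookkeeping above; for more than two classes, its per-class rows are one-vs-rest summaries. To make explicit what the seven metrics for class_1 mean, the sketch below recomputes them in base R from the four one-vs-rest cell counts (TP, FP, FN, TN) using a hypothetical helper, `one_vs_rest_metrics()`, on a small hand-built example. It is an illustration, not a replacement for caret; in degenerate cases (e.g., class_1 is never predicted) this sketch returns NaN, whereas caret may report NA.

```r
# hypothetical helper: one-vs-rest metrics for a single `positive` class,
# computed from the cells of the confusion matrix
one_vs_rest_metrics = function(pred, ref, positive = 'class_1'){
  tp = sum(pred == positive & ref == positive)
  fp = sum(pred == positive & ref != positive)
  fn = sum(pred != positive & ref == positive)
  tn = sum(pred != positive & ref != positive)
  
  sensitivity = tp / (tp + fn)
  specificity = tn / (tn + fp)
  precision = tp / (tp + fp)     # 0/0 = NaN if `positive` is never predicted
  
  data.frame(
    accuracy = mean(pred == ref),  # overall accuracy uses all classes
    sensitivity = sensitivity,
    specificity = specificity,
    g_mean = sqrt(sensitivity * specificity),
    precision = precision,
    recall = sensitivity,          # recall is another name for sensitivity
    f_measure = 2 * precision * sensitivity / (precision + sensitivity)
  )
}

# 20 class_1 and 80 class_2 reference labels with hand-picked predictions:
# TP = 10, FN = 10, FP = 20, TN = 60
ref = c(rep('class_1', 20), rep('class_2', 80))
pred = c(rep('class_1', 10), rep('class_2', 10),
         rep('class_1', 20), rep('class_2', 60))

one_vs_rest_metrics(pred, ref)
# accuracy 0.7, sensitivity 0.5, specificity 0.75, g_mean ~0.612,
# precision ~0.333, recall 0.5, f_measure 0.4
```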

3 Running the Experiment

With sim_function(), we can use a functional programming approach to compute the seven classification metrics of interest for each simulation run. The results are stored in an object named experimental_results, which contains the three original columns from sim_setup (exp_num, sim_num and data) and a new list column named acc_metrics.

experimental_results = 
  sim_setup |> # setup from a previous chunk
  # creating a new list column titled acc_metrics, where the seven classification
  # metrics of choice are stored for each simulation run
  
  # The map function from purrr (part of tidyverse) allows us to perform row-wise 
  # computations (i.e., one call of sim_function() per simulation run)
  # on the `data` list column, where we use the variables within it to
  # generate the random values and compute the classification metrics.
  mutate(
    acc_metrics = map(.x = data, .f = sim_function, base_class_name = 'class_')
  )
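The mutate() + map() pattern above can be easier to follow on a toy example. The sketch below applies the same pattern with a hypothetical helper, `majority_accuracy()`, that needs no random numbers: for the most_frequent approach, accuracy should be close to the share of the most frequent reference class (apart from the count truncation done inside sim_function()). The tibble `toy` is a hypothetical stand-in for sim_setup, and the resulting values could serve as a rough sanity check on the most_frequent rows of the summary.

```r
library(dplyr)
library(purrr)
library(tibble)
library(tidyr)

# hypothetical helper (not part of our pipeline): the accuracy of always
# predicting the most frequent class, ignoring the truncation/padding of counts
majority_accuracy = function(data){
  other_class_perc = (1 - data$first_class_perc) / (data$num_classes - 1)
  tibble(accuracy = max(data$first_class_perc, other_class_perc))
}

# a toy stand-in for sim_setup: one row per condition, with a list column `data`
toy = tibble(
  exp_num = 1:3,
  data = list(
    tibble(num_classes = 2, first_class_perc = 0.2),
    tibble(num_classes = 3, first_class_perc = 0.5),
    tibble(num_classes = 10, first_class_perc = 0.01)
  )
)

# map() calls majority_accuracy() once per row and returns a list of tibbles,
# which unnest() then expands into ordinary columns
toy_results = toy |>
  mutate(acc_metrics = map(.x = data, .f = majority_accuracy)) |>
  unnest(acc_metrics)

toy_results$accuracy  # 0.80 0.50 0.11
```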

# saving the experimental results
write_rds(
  x = experimental_results, file = '../results/experimental_results.rds'
)

4 Summarizing the Experimental Results

Below, we compute the mean, standard deviation and 99th percentile (quant_of_interest) of each metric, grouped by the ID (i.e., exp_num). We save the summary results in both RDS and CSV formats (to facilitate their use by R and non-R users) and print them as an interactive table.

summary_results = 
  experimental_results |>
  # unnesting the acc_metrics (i.e., generating columns for each of the metrics)
  unnest( acc_metrics ) |>
  # grouping by exp_num 
  # (and data which is redundant but to keep it in the summary table)
  group_by(exp_num, data) |>
  # summaries of means, sds and quantile of interest for each metric by exp_num
  summarise(
    across(accuracy:f_measure,  
           list(mean = ~ mean(.x, na.rm = TRUE),
                sd = ~ sd(.x, na.rm = TRUE),
                quantile = ~ quantile(.x, probs = quant_of_interest, na.rm = TRUE) ), 
           .names = "{.fn}_{.col}" )
  ) |>
  # expanding the data into the different experimental conditions for interpretation
  unnest(data)

# storing the results as rds
write_rds(
  x = summary_results, file = '../results/summary_results.rds'
)

# we also store the results as CSV
write_csv(
  x = summary_results, file = '../results/summary_results.csv'
)
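The across() call above generates one column per metric and summary function, named by the .names template "{.fn}_{.col}" (hence mean_accuracy, sd_accuracy, quantile_accuracy, and so on). The sketch below reproduces that step on a small, hypothetical table of simulation results (`toy_runs`, with made-up values) so the naming and the effect of na.rm = TRUE can be checked directly; NA values arise in practice when, for example, precision is undefined because class_1 is never predicted.

```r
library(dplyr)

# hypothetical per-run results for two experimental conditions
toy_runs = tibble(
  exp_num = rep(1:2, each = 4),
  accuracy = c(0.70, 0.72, 0.74, 0.76, 0.50, 0.50, 0.52, 0.48),
  recall = c(0.1, NA, 0.3, 0.2, 0, 0, 0, 0)
)

toy_summary = toy_runs |>
  group_by(exp_num) |>
  summarise(
    across(accuracy:recall,
           list(mean = ~ mean(.x, na.rm = TRUE),
                sd = ~ sd(.x, na.rm = TRUE),
                quantile = ~ quantile(.x, probs = 0.99, na.rm = TRUE)),
           .names = "{.fn}_{.col}")
  )

names(toy_summary)
# exp_num, mean_accuracy, sd_accuracy, quantile_accuracy,
# mean_recall, sd_recall, quantile_recall

toy_summary$mean_recall  # 0.2 0.0 (the NA is dropped before averaging)
```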

# printing the results
DT::datatable(
  summary_results |>
    # rounding numeric columns to three decimals (a lambda is used because passing
    # extra arguments through across()'s `...` is deprecated as of dplyr 1.1.0)
    mutate(across(where(is.numeric), ~ round(.x, 3))) |>
    # moving and renaming columns purely for output presentation purposes
    relocate(exp_num, num_classes, first_class_perc, pred_approach, num_observations) |>
    rename(n_obs = num_observations, 
           ref_prob = first_class_perc,
           pred = pred_approach, 
           mean_acc = mean_accuracy, sd_acc = sd_accuracy,
           mean_sens = mean_sensitivity, sd_sens = sd_sensitivity, 
           mean_spec = mean_specificity, sd_spec = sd_specificity,
           mean_prec = mean_precision, sd_prec = sd_precision,
           mean_f_meas = mean_f_measure, sd_f_meas = sd_f_measure), 
  filter = 'top',
  extensions = c('Buttons', 'FixedColumns'), 
  rownames = FALSE,
  options = list(
    initComplete = JS("function(settings, json) {$(this.api().table().container()).css({'font-size' : '10pt'});}"), 
    dom = 'Bfrtip',
    buttons = c('copy', 'csv', 'excel', 'pdf', 'print'),
    pageLength = 20,
    scrollX = TRUE,
    fixedColumns = list(leftColumns = 4)
  ) 
)

The results can be downloaded from summary_results.csv.


5 App for Citizen Data Scientists

We have created a flexdashboard-based web application, which allows users to input the simulation parameters of interest and output the corresponding simulation summary and histograms of our five classification metrics of interest. The application can be accessed here and the code used to create the application can be downloaded from here.


Appendix

In this appendix, we print all the R packages used in our analysis and their versions to assist with reproducing our results/analysis.

pacman::p_load(pander)
pander::pander(sessionInfo(), compact = TRUE) # printing the session info

R version 4.2.2 (2022-10-31 ucrt)

Platform: x86_64-w64-mingw32/x64 (64-bit)

locale: LC_COLLATE=English_United States.utf8, LC_CTYPE=English_United States.utf8, LC_MONETARY=English_United States.utf8, LC_NUMERIC=C and LC_TIME=English_United States.utf8

attached base packages: stats, graphics, grDevices, utils, datasets, methods and base

other attached packages: pander(v.0.6.5), DT(v.0.27), plotly(v.4.10.1), caret(v.6.0-93), lattice(v.0.20-45), lubridate(v.1.9.2), forcats(v.1.0.0), stringr(v.1.5.0), dplyr(v.1.1.0), purrr(v.1.0.1), readr(v.2.1.4), tidyr(v.1.3.0), tibble(v.3.2.0), tidyverse(v.2.0.0), pacman(v.0.5.1), RColorBrewer(v.1.1-3) and ggplot2(v.3.4.1)

loaded via a namespace (and not attached): nlme(v.3.1-162), bit64(v.4.0.5), httr(v.1.4.5), tools(v.4.2.2), bslib(v.0.4.2), utf8(v.1.2.3), R6(v.2.5.1), rpart(v.4.1.19), lazyeval(v.0.2.2), colorspace(v.2.1-0), nnet(v.7.3-18), withr(v.2.5.0), tidyselect(v.1.2.0), bit(v.4.0.5), compiler(v.4.2.2), cli(v.3.6.0), sass(v.0.4.5), scales(v.1.2.1), proxy(v.0.4-27), digest(v.0.6.31), rmarkdown(v.2.20), pkgconfig(v.2.0.3), htmltools(v.0.5.4), parallelly(v.1.34.0), fastmap(v.1.1.1), htmlwidgets(v.1.6.1), rlang(v.1.0.6), rstudioapi(v.0.14), jquerylib(v.0.1.4), generics(v.0.1.3), jsonlite(v.1.8.4), crosstalk(v.1.2.0), vroom(v.1.6.1), ModelMetrics(v.1.2.2.2), magrittr(v.2.0.3), Matrix(v.1.5-3), Rcpp(v.1.0.10), munsell(v.0.5.0), fansi(v.1.0.4), lifecycle(v.1.0.3), stringi(v.1.7.12), pROC(v.1.18.0), yaml(v.2.3.7), MASS(v.7.3-58.3), plyr(v.1.8.8), recipes(v.1.0.5), grid(v.4.2.2), parallel(v.4.2.2), listenv(v.0.9.0), crayon(v.1.5.2), splines(v.4.2.2), hms(v.1.1.2), knitr(v.1.42), pillar(v.1.8.1), future.apply(v.1.10.0), reshape2(v.1.4.4), codetools(v.0.2-19), stats4(v.4.2.2), glue(v.1.6.2), evaluate(v.0.20), data.table(v.1.14.8), vctrs(v.0.5.2), tzdb(v.0.3.0), foreach(v.1.5.2), gtable(v.0.3.1), future(v.1.32.0), cachem(v.1.0.7), xfun(v.0.37), gower(v.1.0.1), prodlim(v.2019.11.13), e1071(v.1.7-13), class(v.7.3-21), survival(v.3.5-3), viridisLite(v.0.4.1), timeDate(v.4022.108), iterators(v.1.0.14), hardhat(v.1.2.0), lava(v.1.7.2.1), timechange(v.0.2.0), globals(v.0.16.2), ellipsis(v.0.3.2) and ipred(v.0.9-14)


  1. Email: | Phone: +1-513-529-4185 | Website: Miami University Official

  2. Email: | Phone: +1-937-229-2405 | Website: University of Dayton Official

  3. Email: | Phone: +1-513-529-4823 | Website: Miami University Official

  4. Email: | Website: Saint Louis University Official

