tidymodels1

ds1
tidymodels
prediction
yacsda
statlearning
dyn
schoice
Published

May 17, 2023

Aufgabe

Prof. Salzig übt sich im statistischen Lernen. Dazu will er das Überleben im Titanic-Unglück Vorhersagen; es handelt sich um eine klassische Aufgabe im statistischen Lernen. Betrachten Sie dazu den folgenden R-Code sowie die Kommentare dazu. Wählen Sie die am besten passende Aussage.

Zuerst lädt er die nötigen R-Pakete:

library(tidyverse)  # data wrangling
library(tidymodels)  # modelling
library(broom)  # tidy model output
library(parallel)  # multiple cores -- *nix only, d.h. Mac und Linux
library(finetune)  # tune race anova

Dann initialisiert er die Anzahl der Prozessoren auf seinem Computer:

cores <- parallel::detectCores(logical = FALSE)
cores
[1] 4

Daten importieren:

data_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre"
traindata_path_url  <- "/main/data/titanic/titanic_train.csv"
testdata_path_url <- "/main/data/titanic/titanic_test.csv"

traindata_url <- paste0(data_path, traindata_path_url)
testdata_url <- paste0(data_path, testdata_path_url)


# import the data:
train_raw <- read_csv(traindata_url)
test <- read_csv(testdata_url)

Und aufbereiten:

# drop unused variables:
train <-
  train_raw %>% 
  select(-c(Name, Cabin, Ticket))

# convert string to factors:
train2 <- 
  train %>% 
  mutate(across(where(is.character), as.factor))
  
# convert numeric outcome to nominal, to indicate classification:
train2 <- 
  train2 %>% 
  mutate(Survived = as.factor(Survived))

Gibt es fehlende Werte in der AV?

sum(is.na(train2$Survived))
[1] 0

Vorverarbeitung des Datensatzes macht er via ein recipe aus tidymodels:

titanic_recipe <- 
  
  # define model formula:
  recipe(Survived ~ ., data = train2) %>%
  
  # Use "ID" etc as ID, not as predictor:
  update_role(PassengerId, new_role = "ID") %>% 
  
   # impute missing values:
  step_impute_bag(all_predictors()) %>% 
  
  # convert to dummy variables:
  step_dummy(all_nominal_predictors())

Check no missings:

titanic_train_baked <- titanic_recipe %>% prep() %>% bake(new_data = NULL)

sum(is.na(titanic_train_baked))
[1] 0

Dann definiert ein ein Modell:

rf_mod2 <- 
  rand_forest(mtry = tune(), # tune mtry
              min_n = tune(), # tune minimal n per node
              trees = 1000) %>%  # set number of trees to 1000
  set_engine("ranger", 
             num.threads = cores) %>% 
  set_mode("classification")

… und ein Kreuzvalidierungsschema:

train_cv <- vfold_cv(train2, 
                     v = 10,
                     repeats = 1, 
                     strata = "Survived")

Aus der Hilfe zu vfold_cv:


V-Fold Cross-Validation

Description

V-fold cross-validation randomly splits the data into V groups of roughly equal size (called “folds”). A resample of the analysis data consisted of V-1 of the folds while the assessment set contains the final fold. In basic V-fold cross-validation (i.e. no repeats), the number of resamples is equal to V.

Usage

vfold_cv(data, v = 10, repeats = 1, strata = NULL, breaks = 4, ...)

Arguments

data A data frame.

v The number of partitions of the data set.

repeats The number of times to repeat the V-fold partitioning.

strata A variable that is used to conduct stratified sampling to create the folds. This could be a single character value or a variable name that corresponds to a variable that exists in the data frame.

breaks A single number giving the number of bins desired to stratify a numeric stratification variable.

... Not currently used.

Details

The strata argument causes the random sampling to be conducted within the stratification variable. This can help ensure that the number of data points in the analysis data is equivalent to the proportions in the original data set. (Strata below 10% of the total are pooled together.) When more than one repeat is requested, the basic V-fold cross-validation is conducted each time. For example, if three repeats are used with v = 10, there are a total of 30 splits which as three groups of 10 that are generated separately.


So entsteht dieser Workflow:

titanic_rf_wf2 <-
  workflow() %>% 
  add_model(rf_mod2) %>% 
  add_recipe(titanic_recipe)

Jetzt: Fit the grid!

set.seed(42)

n_candidates <- 2

rf_res2 <- 
  titanic_rf_wf2 %>% 
  tune_race_anova(
    resamples = train_cv,
    grid = n_candidates,  # test 25 different tuning parameter values
    #control = control_grid(save_pred = TRUE),
    metrics = metric_set(roc_auc))

Mit dem Parameter grid kann man die Anzahl der zu berechnenden Kandidaten-Modelle festlegen.

Für gute Vorhersagen bieten sich hohe Werte an; das kostet aber Rechenzeit.

Aus den Resampling-Kandidaten wählt er nun das beste aus:

rf_best2 <- 
  rf_res2 %>% 
  select_best(metric = "roc_auc")
rf_best2
# A tibble: 1 × 3
   mtry min_n .config             
  <int> <int> <chr>               
1     3    12 Preprocessor1_Model1

Das beste Kandidatenmodell nutzt er nun, um den ganzen Train-Datensatz zu “fitten”:

# write best parameter values to the workflow:
rf_final_wf2 <- 
  titanic_rf_wf2 %>% 
  finalize_workflow(rf_best2)

# fit the model:
rf_final_model2 <- 
rf_final_wf2 %>% 
  fit(train2)

Zum Abschluss speichert er die Vorhersagen, die er dann bei Kaggle einreichen will:

rf2_preds <- 
  predict(rf_final_model2, new_data = test)  # compute prediction on test set

Ein letzter Blick auf die Verteilung der vorhergesagten Werte:

count(rf2_preds, .pred_class)
# A tibble: 2 × 2
  .pred_class     n
  <fct>       <int>
1 0             284
2 1             134

Auf Basis dieser Analyse: Wählen Sie am besten passende Aussage!

Answerlist

  • Es wurden 2 Kandidaten von Tuningparameterwerten in die Analyse einbezogen.
  • Es wurde kein Parameter-Tuning durchgeführt.
  • Die Metrik \(AUC\) sollte nicht für Klassifikationsmodelle verwendet werden.
  • Es wurde eine 10-fache Kreuzvalidierung (ohne Wiederholungen) verwendet.
  • Die Anzahl der Bäume im Random Forest wurde hier nicht ins Parametertuning einbezogen; allerdings wäre es sinnvoll (und üblich), dies zu tun.
  • der Parameter mtry wurde hier nicht ins Parametertuning einbezogen.











Lösung

Answerlist

  • Wahr
  • Falsch
  • Falsch
  • Falsch
  • Falsch
  • Falsch

Categories:

  • ds1
  • tidymodels
  • prediction
  • yacsda
  • statlearning
  • dyn
  • schoice