regr-tree02

statlearning
trees
tidymodels
string
Published

May 17, 2023

library(tidymodels)
── Attaching packages ────────────────────────────────────── tidymodels 1.1.1 ──
✔ broom        1.0.5     ✔ recipes      1.0.8
✔ dials        1.2.0     ✔ rsample      1.2.0
✔ dplyr        1.1.3     ✔ tibble       3.2.1
✔ ggplot2      3.4.4     ✔ tidyr        1.3.0
✔ infer        1.0.5     ✔ tune         1.1.2
✔ modeldata    1.2.0     ✔ workflows    1.1.3
✔ parsnip      1.1.1     ✔ workflowsets 1.0.1
✔ purrr        1.0.2     ✔ yardstick    1.2.0
── Conflicts ───────────────────────────────────────── tidymodels_conflicts() ──
✖ purrr::discard() masks scales::discard()
✖ dplyr::filter()  masks stats::filter()
✖ dplyr::lag()     masks stats::lag()
✖ recipes::step()  masks stats::step()
• Use tidymodels_prefer() to resolve common conflicts.

Task

Fit a simple predictive model based on a decision tree.

Model formula: am ~ . (data set mtcars)

Report the model performance (ROC-AUC).

Notes:

  • Tune all parameters that the engine offers.
  • Use 2-fold cross-validation (v = 2), because the sample is so small.
  • Follow the usual guidelines.











Solution

Setup

library(tidymodels)
data(mtcars)
library(tictoc)  # timing with tic()/toc() (optional, not used below)

For classification, tidymodels requires a nominal outcome (a factor), not a numeric one:

mtcars <-
  mtcars %>% 
  mutate(am = factor(am))
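A quick base-R check of what the conversion produces. The level order matters because yardstick treats the first factor level as the event by default (`event_level = "first"`). Here that means `"0"` (automatic) is the event class for ROC-AUC. This is only an inspection sketch and does not change the workflow.

```r
# Base R only: inspect the outcome after the factor conversion
data(mtcars)
mtcars$am <- factor(mtcars$am)

levels(mtcars$am)   # "0" = automatic, "1" = manual; "0" is the first level
table(mtcars$am)    # class counts in the full data set
```

The counts are 19 automatic and 13 manual cars, so the classes are somewhat unbalanced. This matters more as the data subsets below get smaller.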

Splitting the data

d_split <- initial_split(mtcars)
d_train <- training(d_split)
d_test <- testing(d_split)
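By default, `initial_split()` keeps 3/4 of the rows for training, so the 32 cars become 24 training rows and 8 test rows. The call above sets no random seed, so every run produces a different split and different results below.

Below is a minimal base-R sketch of the same idea with a fixed seed. The seed value 2023 is arbitrary. In the tidymodels code, one would simply call `set.seed()` before `initial_split()`. This illustrates the idea and is not rsample's implementation.

```r
data(mtcars)
mtcars$am <- factor(mtcars$am)

set.seed(2023)                      # arbitrary seed, makes the split reproducible
n       <- nrow(mtcars)             # 32 rows
n_train <- floor(n * 3/4)           # same proportion as initial_split()'s default
idx     <- sample(n, n_train)       # random row indices for the training set

train_sketch <- mtcars[idx, ]       # 24 rows
test_sketch  <- mtcars[-idx, ]      # the remaining 8 rows

c(train = nrow(train_sketch), test = nrow(test_sketch))
```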

Model(s)

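mod_tree <-
  decision_tree(mode = "classification",
                cost_complexity = tune(),  # rpart: cp
                tree_depth = tune(),       # rpart: maxdepth
                min_n = tune()) %>%        # rpart: minsplit
  set_engine("rpart")                      # default engine, made explicit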
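`decision_tree()` uses the `rpart` engine unless told otherwise. parsnip translates the three tunable arguments into `rpart.control()` settings:

- `cost_complexity` becomes `cp`
- `tree_depth` becomes `maxdepth`
- `min_n` becomes `minsplit`

The sketch below fits one fixed candidate directly with rpart, to show what a single grid point amounts to. The values are arbitrary picks, not tuned ones.

```r
library(rpart)   # recommended package, usually installed with R

data(mtcars)
mtcars$am <- factor(mtcars$am)

# One arbitrary candidate; in the tidymodels code these are tune() placeholders
ctrl <- rpart.control(cp = 0.01,      # cost_complexity
                      maxdepth = 5,   # tree_depth
                      minsplit = 10)  # min_n

fit_rpart <- rpart(am ~ ., data = mtcars, method = "class", control = ctrl)
head(predict(fit_rpart, type = "prob"))  # one probability column per level of am
```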

Recipe(s)

rec1 <- 
  recipe(am ~ ., data = d_train)

Resampling

rsmpl <- vfold_cv(d_train, v = 2)
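With v = 2, `vfold_cv()` partitions the 24 training rows into two folds of 12 rows each. Each resample fits on one fold (12 rows) and evaluates on the other.

The base-R sketch below reproduces that fold assignment (it is not rsample's code). It makes the number 12 visible, which is the row count the warnings during tuning refer to.

```r
set.seed(2023)                         # arbitrary seed
n_train <- 24                          # rows in d_train (3/4 of 32)
v       <- 2

# Assign each training row to one of v folds of (nearly) equal size
fold_id <- sample(rep(seq_len(v), length.out = n_train))

table(fold_id)                         # 12 rows per fold

# Resample 1: fit on fold 2, assess on fold 1 (resample 2 swaps the roles)
analysis_rows   <- which(fold_id != 1)
assessment_rows <- which(fold_id == 1)
length(analysis_rows)                  # 12 rows available for fitting
```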

Workflow

wf1 <-
  workflow() %>%  
  add_recipe(rec1) %>% 
  add_model(mod_tree)

Tuning/Fitting

fit1 <-
  tune_grid(object = wf1,
            metrics = metric_set(roc_auc),
            resamples = rsmpl)
→ A | warning: 30 samples were requested but there were 12 rows in the data. 12 will be used.
→ B | warning: 18 samples were requested but there were 12 rows in the data. 12 will be used.
→ C | warning: 27 samples were requested but there were 12 rows in the data. 12 will be used.
→ D | warning: 17 samples were requested but there were 12 rows in the data. 12 will be used.
→ E | warning: 33 samples were requested but there were 12 rows in the data. 12 will be used.
→ F | warning: 22 samples were requested but there were 12 rows in the data. 12 will be used.
→ G | warning: 37 samples were requested but there were 12 rows in the data. 12 will be used.
There were issues with some computations   A: x2   B: x2   C: x2   D: x2   E: x…
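Without a `grid` argument, `tune_grid()` evaluates 10 candidate combinations. By default, dials draws:

- `cost_complexity` from 10^-10 to 10^-1 (log scale)
- `tree_depth` from 1 to 15
- `min_n` from 2 to 40

A `min_n` above the 12 rows of an analysis set cannot be honored, so it is reduced to 12, as the warning messages report.

The sketch below draws a simple random grid in base R and applies the same cap. tune actually uses a space-filling design, so its concrete values differ.

```r
set.seed(2023)                              # arbitrary seed
n_cand     <- 10                            # tune_grid()'s default number of candidates
n_analysis <- 12                            # rows per analysis set (24 / 2)

# Simple random grid over the default ranges (tune uses a space-filling design instead)
grid <- data.frame(
  cost_complexity = 10^runif(n_cand, min = -10, max = -1),  # log10 scale
  tree_depth      = sample(1:15, n_cand, replace = TRUE),
  min_n           = sample(2:40, n_cand, replace = TRUE)
)

# A min_n larger than the analysis set cannot be used; it is capped at 12
grid$min_n_used <- pmin(grid$min_n, n_analysis)

grid[grid$min_n > n_analysis, c("min_n", "min_n_used")]
```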

Best candidate

autoplot(fit1)

show_best(fit1, metric = "roc_auc")
# A tibble: 5 × 9
  cost_complexity tree_depth min_n .metric .estimator  mean     n std_err
            <dbl>      <int> <int> <chr>   <chr>      <dbl> <int>   <dbl>
1        5.46e- 2          8     3 roc_auc binary     0.879     2  0.121 
2        4.23e- 5         13    30 roc_auc binary     0.816     2  0.0589
3        1.06e- 7          2    18 roc_auc binary     0.816     2  0.0589
4        2.41e- 5         15     8 roc_auc binary     0.816     2  0.0589
5        9.18e-10         10    11 roc_auc binary     0.816     2  0.0589
# ℹ 1 more variable: .config <chr>
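With v = 2, `mean` is the average of just two fold-level AUCs, and `std_err` is their standard deviation divided by √2. For two values a and b, this simplifies to |a − b| / 2. The two fold AUCs can therefore be recovered as `mean ± std_err`.

For the top candidate, that gives about 0.758 and 1.000 (up to rounding in the printed output). The "best" estimate thus rests on two very different 12-row assessments. The sketch below assumes std_err = sd / √n, which is how tune summarises resamples.

```r
# Recover the two fold-level values from the printed mean and std_err (v = 2)
folds_from_summary <- function(mean, std_err) c(mean - std_err, mean + std_err)

best <- folds_from_summary(0.879, 0.121)
best                                   # about 0.758 and 1.000

# Check: summarising the two values again reproduces the printed numbers
c(mean = mean(best), std_err = sd(best) / sqrt(length(best)))
```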

Finalizing

wf1_finalized <-
  wf1 %>% 
  finalize_workflow(select_best(fit1, metric = "roc_auc"))

Last Fit

final_fit <- 
  last_fit(object = wf1_finalized, d_split)

collect_metrics(final_fit)
# A tibble: 2 × 4
  .metric  .estimator .estimate .config             
  <chr>    <chr>          <dbl> <chr>               
1 accuracy binary             1 Preprocessor1_Model1
2 roc_auc  binary             1 Preprocessor1_Model1
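The final test set holds only 8 rows (a quarter of 32). `accuracy = 1` and `roc_auc = 1` therefore mean that all 8 cars were classified and ranked correctly, which is a fragile estimate.

ROC-AUC equals the share of (event, non-event) pairs in which the event case receives the higher event probability. The base-R sketch below applies that rank formula to hypothetical probabilities (5 versus 3 cases, not the actual test set). It shows how coarse the metric is at this size: one misordered pair out of 5 · 3 = 15 costs 1/15 ≈ 0.067.

```r
# ROC-AUC via ranks (Mann-Whitney): share of event/non-event pairs ordered correctly
roc_auc_rank <- function(truth, prob_event, event) {
  is_event <- truth == event
  n1 <- sum(is_event)
  n0 <- sum(!is_event)
  r  <- rank(prob_event)                     # ties get average ranks
  (sum(r[is_event]) - n1 * (n1 + 1) / 2) / (n1 * n0)
}

# Hypothetical test set: 8 cars, event = "0" (first level, as in yardstick's default)
truth <- factor(c("0", "0", "0", "1", "1", "0", "1", "0"), levels = c("0", "1"))
p0    <- c(0.90, 0.80, 0.70, 0.20, 0.10, 0.95, 0.30, 0.60)  # hypothetical P(am = "0")

roc_auc_rank(truth, p0, event = "0")         # 1: every pair is ordered correctly

# Swap one event/non-event pair (0.60 vs 0.30): AUC drops by 1/15
p0_swapped <- p0
p0_swapped[c(7, 8)] <- c(0.60, 0.30)
roc_auc_rank(truth, p0_swapped, event = "0") # 14/15
```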
