tidymodels-lasso2

tidymodels
statlearning
lasso
lm
string
Published

May 17, 2023

Aufgabe

Schreiben Sie eine minimale Analyse für ein Vorhersagemodell mit dem Lasso.

Hinweise:

  • Verzichten Sie auf Tuning der Penalisierung; setzen Sie den Wert auf 0.1
  • Verzichten Sie auf die Unterteilung von Train- und Test-Set.
  • Verzichten Sie auf Kreuzvalidierung.
  • Verwenden Sie Standardwerte, wo nicht anders angegeben.
  • Fixieren Sie Zufallszahlen auf den Startwert 42.
  • Verwenden Sie den Datensatz penguins.
  • Modellformel: body_mass_g ~ .











Lösung

# 2023-05-14

# Setup:
library(tidymodels)
library(tidyverse)
library(tictoc)  # Zeitmessung


# Data:
d_path <- "https://vincentarelbundock.github.io/Rdatasets/csv/palmerpenguins/penguins.csv"
d <- read_csv(d_path)

# drop rows with NA in outcome variable:
d <-
  d %>% 
  drop_na(body_mass_g)

set.seed(42)
d_split <- initial_split(d)
# d_train <- training(d_split)
# d_test <- testing(d_split)


# model:
mod_lasso <-
  linear_reg(mode = "regression",
             penalty = 0.1,
             mixture = 1,
             engine = "glmnet")

# cv:
# set.seed(42)
# rsmpl <- vfold_cv(d_train)


# recipe:
rec1_plain <- 
  recipe(body_mass_g ~  ., data = d) %>% 
  update_role("rownames", new_role = "id") %>% 
  step_normalize(all_numeric_predictors()) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_impute_bag(all_predictors())


# check:
d_train_baked <- 
  prep(rec1_plain) %>% bake(new_data = NULL)

na_n <- sum(is.na(d_train_baked))


# workflow:
wf1 <-
  workflow() %>% 
  add_model(mod_lasso) %>% 
  add_recipe(rec1_plain)


# tuning:
tic()
wf1_fit <-
  wf1 %>% 
  fit(data = d)
toc()
1.223 sec elapsed
# best candidate:
wf1_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_normalize()
• step_dummy()
• step_impute_bag()

── Model ───────────────────────────────────────────────────────────────────────

Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian",      alpha = ~1) 

   Df  %Dev Lambda
1   0  0.00 697.60
2   1 12.89 635.70
3   1 23.58 579.20
4   1 32.47 527.70
5   1 39.84 480.90
6   1 45.96 438.10
7   1 51.05 399.20
8   2 55.36 363.80
9   2 59.11 331.40
10  2 62.22 302.00
11  2 64.81 275.20
12  2 66.95 250.70
13  3 69.54 228.40
14  3 72.37 208.20
15  3 74.73 189.70
16  3 76.68 172.80
17  3 78.30 157.50
18  3 79.65 143.50
19  3 80.77 130.70
20  3 81.70 119.10
21  3 82.47 108.50
22  3 83.11  98.89
23  3 83.64  90.10
24  3 84.08  82.10
25  3 84.45  74.81
26  3 84.75  68.16
27  3 85.00  62.11
28  3 85.21  56.59
29  3 85.39  51.56
30  4 85.54  46.98
31  5 85.69  42.81
32  5 85.80  39.00
33  5 85.90  35.54
34  6 86.01  32.38
35  7 86.17  29.50
36  7 86.31  26.88
37  7 86.43  24.50
38  7 86.53  22.32
39  7 86.62  20.34
40  7 86.68  18.53
41  7 86.74  16.88
42  7 86.79  15.38
43  8 86.83  14.02
44  8 86.92  12.77
45  8 86.99  11.64
46  8 87.05  10.60

...
and 24 more lines.
# Modellgüte:

predict(wf1_fit, new_data = d) %>% 
  bind_cols(d %>% select(body_mass_g)) %>% 
  rmse(truth = body_mass_g,
       estimate = .pred)
# A tibble: 1 × 3
  .metric .estimator .estimate
  <chr>   <chr>          <dbl>
1 rmse    standard        285.

Man beachte: Für regulierte Modelle sind Zentrierung und Skalierung nötig.


Categories:

  • tidymodels
  • statlearning
  • lasso
  • lm
  • simple
  • string
  • template