tmdb02

ds1
tidymodels
statlearning
tmdb
trees
num
Published

May 17, 2023

Aufgabe

Wir bearbeiten hier die Fallstudie TMDB Box Office Prediction - Can you predict a movie’s worldwide box office revenue?, ein Kaggle-Prognosewettbewerb.

Ziel ist es, genaue Vorhersagen zu machen, in diesem Fall für Filme.

Die Daten können Sie von der Kaggle-Projektseite beziehen oder so:

# Raw CSVs from the course's GitHub mirror of the Kaggle competition data;
# train and test files live side by side in the same directory.
tmdb_data_dir <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction"
d_train_path <- paste0(tmdb_data_dir, "/train.csv")
d_test_path <- paste0(tmdb_data_dir, "/test.csv")
Aufgabe

Reichen Sie bei Kaggle eine Submission für die Fallstudie ein! Berichten Sie den Kaggle-Score.

Hinweise:

  • Sie müssen sich bei Kaggle ein Konto anlegen (kostenlos und anonym möglich); alternativ können Sie sich mit einem Google-Konto anmelden.
  • Berechnen Sie einen Entscheidungsbaum und einen Random-Forest.
  • Tunen Sie nach Bedarf; Sie können dabei die Default-Werte der Tuning-Funktionen verwenden.
  • Verwenden Sie Tidymodels.











Lösung

Vorbereitung

library(tidyverse)   # data wrangling and plotting
library(tidymodels)  # unified modeling framework (recipes, parsnip, tune, ...)
library(tictoc)      # simple timing via tic()/toc()
library(doParallel)  # use multiple CPUs
library(finetune)    # racing methods such as tune_race_anova()
# Read train and test data from the paths defined above.
d_train <- read_csv(d_train_path)
d_test <- read_csv(d_test_path)

glimpse(d_train)
Rows: 3,000
Columns: 23
$ id                    <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…
$ belongs_to_collection <chr> "[{'id': 313576, 'name': 'Hot Tub Time Machine C…
$ budget                <dbl> 1.40e+07, 4.00e+07, 3.30e+06, 1.20e+06, 0.00e+00…
$ genres                <chr> "[{'id': 35, 'name': 'Comedy'}]", "[{'id': 35, '…
$ homepage              <chr> NA, NA, "http://sonyclassics.com/whiplash/", "ht…
$ imdb_id               <chr> "tt2637294", "tt0368933", "tt2582802", "tt182148…
$ original_language     <chr> "en", "en", "en", "hi", "ko", "en", "en", "en", …
$ original_title        <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ overview              <chr> "When Lou, who has become the \"father of the In…
$ popularity            <dbl> 6.575393, 8.248895, 64.299990, 3.174936, 1.14807…
$ poster_path           <chr> "/tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg", "/w9Z7A0GHEh…
$ production_companies  <chr> "[{'name': 'Paramount Pictures', 'id': 4}, {'nam…
$ production_countries  <chr> "[{'iso_3166_1': 'US', 'name': 'United States of…
$ release_date          <chr> "2/20/15", "8/6/04", "10/10/14", "3/9/12", "2/5/…
$ runtime               <dbl> 93, 113, 105, 122, 118, 83, 92, 84, 100, 91, 119…
$ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}]", "[{'…
$ status                <chr> "Released", "Released", "Released", "Released", …
$ tagline               <chr> "The Laws of Space and Time are About to be Viol…
$ title                 <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ Keywords              <chr> "[{'id': 4379, 'name': 'time travel'}, {'id': 96…
$ cast                  <chr> "[{'cast_id': 4, 'character': 'Lou', 'credit_id'…
$ crew                  <chr> "[{'credit_id': '59ac067c92514107af02c8c8', 'dep…
$ revenue               <dbl> 12314651, 95149435, 13092000, 16000000, 3923970,…
glimpse(d_test)
Rows: 4,398
Columns: 22
$ id                    <dbl> 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, …
$ belongs_to_collection <chr> "[{'id': 34055, 'name': 'Pokémon Collection', 'p…
$ budget                <dbl> 0.00e+00, 8.80e+04, 0.00e+00, 6.80e+06, 2.00e+06…
$ genres                <chr> "[{'id': 12, 'name': 'Adventure'}, {'id': 16, 'n…
$ homepage              <chr> "http://www.pokemon.com/us/movies/movie-pokemon-…
$ imdb_id               <chr> "tt1226251", "tt0051380", "tt0118556", "tt125595…
$ original_language     <chr> "ja", "en", "en", "fr", "en", "en", "de", "en", …
$ original_title        <chr> "ディアルガVSパルキアVSダークライ", "Attack of t…
$ overview              <chr> "Ash and friends (this time accompanied by newco…
$ popularity            <dbl> 3.851534, 3.559789, 8.085194, 8.596012, 3.217680…
$ poster_path           <chr> "/tnftmLMemPLduW6MRyZE0ZUD19z.jpg", "/9MgBNBqlH1…
$ production_companies  <chr> NA, "[{'name': 'Woolner Brothers Pictures Inc.',…
$ production_countries  <chr> "[{'iso_3166_1': 'JP', 'name': 'Japan'}, {'iso_3…
$ release_date          <chr> "7/14/07", "5/19/58", "5/23/97", "9/4/10", "2/11…
$ runtime               <dbl> 90, 65, 100, 130, 92, 121, 119, 77, 120, 92, 88,…
$ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}, {'iso_…
$ status                <chr> "Released", "Released", "Released", "Released", …
$ tagline               <chr> "Somewhere Between Time & Space... A Legend Is B…
$ title                 <chr> "Pokémon: The Rise of Darkrai", "Attack of the 5…
$ Keywords              <chr> "[{'id': 11451, 'name': 'pok√©mon'}, {'id': 1155…
$ cast                  <chr> "[{'cast_id': 3, 'character': 'Tonio', 'credit_i…
$ crew                  <chr> "[{'credit_id': '52fe44e7c3a368484e03d683', 'dep…

Rezept

Rezept definieren

# Preprocessing recipe:
# - demote all predictors to passive "id" columns, then promote only the
#   numeric columns we actually model with (popularity, runtime, budget),
# - floor budget at 10 so the log transform is well-defined (many movies
#   have budget 0),
# - log-transform budget (natural log),
# - impute missing predictor values via k nearest neighbors.
rec1 <-
  recipe(revenue ~ ., data = d_train) %>% 
  update_role(all_predictors(), new_role = "id") %>% 
  # `update_role()` defaults to new_role = "predictor". `revenue` is already
  # the outcome and must not be listed here — the original version listed it,
  # demoting it to predictor, and then had to re-promote it to "outcome".
  update_role(popularity, runtime, budget, new_role = "predictor") %>% 
  step_mutate(budget = ifelse(budget < 10, 10, budget)) %>% 
  step_log(budget) %>% 
  step_impute_knn(all_predictors())

rec1

Check das Rezept

# Estimate the recipe on the training data (prep() fits the preprocessing steps).
rec1_prepped <-
  prep(rec1, verbose = TRUE)
oper 1 step mutate [training] 
oper 2 step log [training] 
oper 3 step impute knn [training] 
The retained training set is ~ 28.71 Mb  in memory.
rec1_prepped
# Apply the prepped recipe to the training data; new_data = NULL returns the
# retained, already-processed training set.
d_train_baked <-
  rec1_prepped %>% 
  bake(new_data = NULL) 

head(d_train_baked)
# A tibble: 6 × 23
     id belongs_to_collection   budget genres homepage imdb_id original_language
  <dbl> <fct>                    <dbl> <fct>  <fct>    <fct>   <fct>            
1     1 [{'id': 313576, 'name'…  16.5  [{'id… <NA>     tt2637… en               
2     2 [{'id': 107674, 'name'…  17.5  [{'id… <NA>     tt0368… en               
3     3 <NA>                     15.0  [{'id… http://… tt2582… en               
4     4 <NA>                     14.0  [{'id… http://… tt1821… hi               
5     5 <NA>                      2.30 [{'id… <NA>     tt1380… ko               
6     6 <NA>                     15.9  [{'id… <NA>     tt0093… en               
# ℹ 16 more variables: original_title <fct>, overview <fct>, popularity <dbl>,
#   poster_path <fct>, production_companies <fct>, production_countries <fct>,
#   release_date <fct>, runtime <dbl>, spoken_languages <fct>, status <fct>,
#   tagline <fct>, title <fct>, Keywords <fct>, cast <fct>, crew <fct>,
#   revenue <dbl>

Die AV-Spalte sollte leer sein:

bake(rec1_prepped, new_data = head(d_test), all_outcomes())
# A tibble: 6 × 0
# Count missing values per column; the modeled predictors should be fully
# imputed now. (Uses dplyr's across() instead of the superseded purrr::map_df();
# the result is the same 1-row tibble of integer counts.)
d_train_baked %>% 
  summarise(across(everything(), ~ sum(is.na(.x))))
# A tibble: 1 × 23
     id belongs_to_collection budget genres homepage imdb_id original_language
  <int>                 <int>  <int>  <int>    <int>   <int>             <int>
1     0                  2396      0      7     2054       0                 0
# ℹ 16 more variables: original_title <int>, overview <int>, popularity <int>,
#   poster_path <int>, production_companies <int>, production_countries <int>,
#   release_date <int>, runtime <int>, spoken_languages <int>, status <int>,
#   tagline <int>, title <int>, Keywords <int>, cast <int>, crew <int>,
#   revenue <int>

Keine fehlenden Werte mehr in den Prädiktoren.

Nach fehlenden Werten könnte man z.B. auch so suchen:

datawizard::describe_distribution(d_train_baked)
Variable   |     Mean |       SD |      IQR |              Range | Skewness | Kurtosis |    n | n_Missing
---------------------------------------------------------------------------------------------------------
id         |  1500.50 |   866.17 |  1500.50 |    [1.00, 3000.00] |     0.00 |    -1.20 | 3000 |         0
budget     |    12.51 |     6.44 |    14.88 |      [2.30, 19.76] |    -0.87 |    -1.09 | 3000 |         0
popularity |     8.46 |    12.10 |     6.88 | [1.00e-06, 294.34] |    14.38 |   280.10 | 3000 |         0
runtime    |   107.85 |    22.08 |    24.00 |     [0.00, 338.00] |     1.02 |     8.20 | 3000 |         0
revenue    | 6.67e+07 | 1.38e+08 | 6.66e+07 |   [1.00, 1.52e+09] |     4.54 |    27.78 | 3000 |         0

So bekommt man gleich noch ein paar Infos über die Verteilung der Variablen. Praktische Sache.

Das Test-Sample backen wir auch mal:

# Apply the same (train-estimated) preprocessing to the test sample.
d_test_baked <-
  bake(rec1_prepped, new_data = d_test)

d_test_baked %>% 
  head()
# A tibble: 6 × 22
     id belongs_to_collection   budget genres homepage imdb_id original_language
  <dbl> <fct>                    <dbl> <fct>  <fct>    <fct>   <fct>            
1  3001 [{'id': 34055, 'name':…   2.30 [{'id… <NA>     <NA>    ja               
2  3002 <NA>                     11.4  [{'id… <NA>     <NA>    en               
3  3003 <NA>                      2.30 [{'id… <NA>     <NA>    en               
4  3004 <NA>                     15.7  <NA>   <NA>     <NA>    fr               
5  3005 <NA>                     14.5  [{'id… <NA>     <NA>    en               
6  3006 <NA>                      2.30 [{'id… <NA>     <NA>    en               
# ℹ 15 more variables: original_title <fct>, overview <fct>, popularity <dbl>,
#   poster_path <fct>, production_companies <fct>, production_countries <fct>,
#   release_date <fct>, runtime <dbl>, spoken_languages <fct>, status <fct>,
#   tagline <fct>, title <fct>, Keywords <fct>, cast <fct>, crew <fct>

Kreuzvalidierung

# 5-fold cross-validation, one repeat.
# Fix the RNG seed so the random fold assignment — and hence all downstream
# tuning results — is reproducible across runs.
set.seed(42)
cv_scheme <- vfold_cv(d_train,
                      v = 5, 
                      repeats = 1)

Modelle

Baum

# Regression tree specification; cost complexity and maximum depth are left
# open for tuning. (Default engine: rpart.)
mod_tree <-
  decision_tree(
    cost_complexity = tune(),
    tree_depth = tune()
  ) %>% 
  set_mode("regression")

Random Forest

# Random forest with 1000 trees; mtry and min_n are tuned.
# The ranger engine itself runs on 4 threads.
mod_rf <-
  rand_forest(
    mtry = tune(),
    min_n = tune(),
    trees = 1000
  ) %>% 
  set_mode("regression") %>% 
  set_engine("ranger", num.threads = 4)

Workflows

# Bundle the shared recipe with each model specification into one workflow.
wf_tree <-
  workflow() %>% 
  add_recipe(rec1) %>% 
  add_model(mod_tree)

wf_rf <-
  workflow() %>% 
  add_recipe(rec1) %>% 
  add_model(mod_rf)

Fitten und tunen

Um Rechenzeit zu sparen, kann man den Parameter grid bei tune_grid() auf einen kleinen Wert setzen. Der Default ist 10. Um gute Vorhersagen zu erzielen, sollte man den Wert tendenziell noch über 10 erhöhen.

Tree

Parallele Verarbeitung starten:

# Start parallel processing: spin up 4 worker processes and register them as
# the foreach backend used by tune during resampling.
cl <- makePSOCKcluster(4)  # Create 4 clusters
registerDoParallel(cl)
# NOTE(review): the cluster is never shut down in this document; once all
# tuning is finished, call stopCluster(cl) to release the worker processes.
tic()
# Tune the tree via ANOVA racing: clearly inferior candidates are dropped
# early, which is faster than evaluating the full grid everywhere.
tree_fit <-
  wf_tree %>% 
  tune_race_anova(
    resamples = cv_scheme,
    #grid = 2
  )
toc()
37.736 sec elapsed

Hilfe zu tune_grid() bekommt man in der Dokumentation des tune-Pakets (z. B. über ?tune_grid oder auf tune.tidymodels.org).

tree_fit
# Tuning results
# 5-fold cross-validation 
# A tibble: 5 × 5
  splits             id    .order .metrics          .notes          
  <list>             <chr>  <int> <list>            <list>          
1 <split [2400/600]> Fold1      3 <tibble [20 × 6]> <tibble [0 × 3]>
2 <split [2400/600]> Fold2      1 <tibble [20 × 6]> <tibble [0 × 3]>
3 <split [2400/600]> Fold3      2 <tibble [20 × 6]> <tibble [0 × 3]>
4 <split [2400/600]> Fold5      4 <tibble [16 × 6]> <tibble [0 × 3]>
5 <split [2400/600]> Fold4      5 <tibble [14 × 6]> <tibble [0 × 3]>

Steht was in den .notes?

tree_fit[[".notes"]][[2]]
# A tibble: 0 × 3
# ℹ 3 variables: location <chr>, type <chr>, note <chr>

Nein.

collect_metrics(tree_fit)
# A tibble: 14 × 8
   cost_complexity tree_depth .metric .estimator      mean     n std_err .config
             <dbl>      <int> <chr>   <chr>          <dbl> <int>   <dbl> <chr>  
 1        1.56e- 5         14 rmse    standard     8.95e+7     5 4.65e+6 Prepro…
 2        1.56e- 5         14 rsq     standard     5.82e-1     5 3.16e-2 Prepro…
 3        9.32e- 5         10 rmse    standard     8.91e+7     5 4.66e+6 Prepro…
 4        9.32e- 5         10 rsq     standard     5.85e-1     5 3.11e-2 Prepro…
 5        2.36e-10          5 rmse    standard     8.80e+7     5 4.57e+6 Prepro…
 6        2.36e-10          5 rsq     standard     5.92e-1     5 3.20e-2 Prepro…
 7        2.29e- 8         11 rmse    standard     8.93e+7     5 4.67e+6 Prepro…
 8        2.29e- 8         11 rsq     standard     5.83e-1     5 3.10e-2 Prepro…
 9        9.60e- 4          9 rmse    standard     8.84e+7     5 5.00e+6 Prepro…
10        9.60e- 4          9 rsq     standard     5.90e-1     5 3.22e-2 Prepro…
11        1.94e- 9         12 rmse    standard     8.95e+7     5 4.64e+6 Prepro…
12        1.94e- 9         12 rsq     standard     5.82e-1     5 3.10e-2 Prepro…
13        5.72e- 7          7 rmse    standard     8.83e+7     5 4.73e+6 Prepro…
14        5.72e- 7          7 rsq     standard     5.91e-1     5 3.38e-2 Prepro…
show_best(tree_fit)
Warning: No value of `metric` was given; metric 'rmse' will be used.
# A tibble: 5 × 8
  cost_complexity tree_depth .metric .estimator      mean     n  std_err .config
            <dbl>      <int> <chr>   <chr>          <dbl> <int>    <dbl> <chr>  
1        2.36e-10          5 rmse    standard   88038619.     5 4572618. Prepro…
2        5.72e- 7          7 rmse    standard   88262344.     5 4734314. Prepro…
3        9.60e- 4          9 rmse    standard   88397994.     5 5003102. Prepro…
4        9.32e- 5         10 rmse    standard   89140111.     5 4663576. Prepro…
5        2.29e- 8         11 rmse    standard   89330466.     5 4668641. Prepro…

Finalisieren

# Plug the best hyperparameters (by CV RMSE) into the tree workflow.
# Passing `metric` explicitly avoids the "No value of `metric`" warning.
best_tree_wf <-
  wf_tree %>% 
  finalize_workflow(select_best(tree_fit, metric = "rmse"))
Warning: No value of `metric` was given; metric 'rmse' will be used.
best_tree_wf
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_mutate()
• step_log()
• step_impute_knn()

── Model ───────────────────────────────────────────────────────────────────────
Decision Tree Model Specification (regression)

Main Arguments:
  cost_complexity = 2.36005153743282e-10
  tree_depth = 5

Computational engine: rpart 
# Refit the finalized tree on the complete training set.
tree_last_fit <-
  fit(best_tree_wf, data = d_train)

tree_last_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: decision_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
3 Recipe Steps

• step_mutate()
• step_log()
• step_impute_knn()

── Model ───────────────────────────────────────────────────────────────────────
n= 3000 

node), split, n, deviance, yval
      * denotes terminal node

 1) root 3000 5.672651e+19   66725850  
   2) budget< 18.32631 2845 1.958584e+19   46935270  
     4) budget< 17.19976 2252 5.443953e+18   25901120  
       8) popularity< 9.734966 1745 1.665118e+18   17076460  
        16) popularity< 5.761331 1019 3.184962e+17    8793730  
          32) budget< 15.44456 782 1.408243e+17    6074563 *
          33) budget>=15.44456 237 1.528117e+17   17765830 *
        17) popularity>=5.761331 726 1.178595e+18   28701940  
          34) budget< 16.15249 484 6.504138e+17   21093220 *
          35) budget>=16.15249 242 4.441208e+17   43919380 *
       9) popularity>=9.734966 507 3.175231e+18   56273980  
        18) budget< 15.36217 186 3.092335e+17   24880850  
          36) popularity< 14.04031 151 1.743659e+17   20728170 *
          37) popularity>=14.04031 35 1.210294e+17   42796710 *
        19) budget>=15.36217 321 2.576473e+18   74464390  
          38) popularity< 19.64394 300 2.025184e+18   68010500 *
          39) popularity>=19.64394 21 3.602808e+17  166662900 *
     5) budget>=17.19976 593 9.361685e+18  126815400  
      10) popularity< 19.63372 570 6.590372e+18  117422100  
        20) budget< 17.86726 374 2.692151e+18   94469490  
          40) popularity< 8.444193 149 6.363495e+17   68256660 *
          41) popularity>=8.444193 225 1.885623e+18  111828200 *
        21) budget>=17.86726 196 3.325222e+18  161219400  
          42) popularity< 11.60513 126 1.693483e+18  136587100 *
          43) popularity>=11.60513 70 1.417677e+18  205557600 *
      11) popularity>=19.63372 23 1.474624e+18  359605200  
        22) runtime>=109.5 16 9.882757e+17  299077200 *
        23) runtime< 109.5 7 2.937458e+17  497955000 *
   3) budget>=18.32631 155 1.557371e+19  429978800  
     6) popularity< 17.26579 101 4.711450e+18  299997300  
      12) budget< 18.73897 67 1.671489e+18  230290900  
        24) popularity< 12.66146 40 5.426991e+17  174328700  
          48) budget< 18.44536 18 1.099070e+17  134734600 *
          49) budget>=18.44536 22 3.814856e+17  206724000 *
        25) popularity>=12.66146 27 8.179336e+17  313197700  
          50) budget< 18.52944 13 1.273606e+17  234797100 *
          51) budget>=18.52944 14 5.364675e+17  385998300 *
      13) budget>=18.73897 34 2.072879e+18  437360100  
        26) runtime< 132.5 26 1.123840e+18  391271100  
          52) popularity< 11.34182 9 9.729505e+16  248614500 *
          53) popularity>=11.34182 17 7.464210e+17  466795200 *
        27) runtime>=132.5 8 7.143147e+17  587149400 *
     7) popularity>=17.26579 54 5.964228e+18  673092200  
      14) budget< 18.99438 33 2.082469e+18  534404700  
        28) popularity< 25.35778 19 5.425201e+17  416871200 *

...
and 4 more lines.

Vorhersage Test-Sample

predict(tree_last_fit, new_data = d_test)
# A tibble: 4,398 × 1
        .pred
        <dbl>
 1   6074563.
 2   6074563.
 3  21093221.
 4  21093221.
 5   6074563.
 6  21093221.
 7   6074563.
 8  68256659.
 9  43919378.
10 205557624.
# ℹ 4,388 more rows

RF

Fitten und Tunen

Um Rechenzeit zu sparen, kann man das Ergebnisobjekt, wenn es einmal berechnet ist, unter result_obj_path auf der Festplatte abspeichern und beim nächsten Mal von dort importieren; das geht schneller, als es neu zu berechnen.

Das könnte dann z.B. so aussehen:

# Caching pattern: load the tuning result from disk if it exists, otherwise
# compute it. NOTE(review): `result_obj_path` is not defined anywhere in this
# document; it must be set beforehand for this chunk to run.
if (file.exists(result_obj_path)) {
  rf_fit <- read_rds(result_obj_path)
} else {
  tic()
  rf_fit <-
    wf_rf %>% 
    tune_grid(
      resamples = cv_scheme)
  toc()
}

Achtung Ein Ergebnisobjekt von der Festplatte zu laden ist gefährlich. Wenn Sie Ihr Modell verändern, aber vergessen, das Objekt auf der Festplatte zu aktualisieren, werden Ihre Ergebnisse falsch sein (da auf dem veralteten Objekt beruhend), ohne dass Sie durch eine Fehlermeldung von R gewarnt würden!

So kann man das Ergebnisobjekt auf die Festplatte schreiben:

#write_rds(rf_fit, file = "objects/tmbd_rf_fit1.rds")

Aber wir berechnen lieber neu:

tic()
# Tune the random forest over the default space-filling grid (10 candidates).
rf_fit <-
  wf_rf %>% 
  tune_grid(
    resamples = cv_scheme
    #grid = 2
    )
toc()
34.282 sec elapsed
collect_metrics(rf_fit)
# A tibble: 20 × 8
    mtry min_n .metric .estimator         mean     n      std_err .config       
   <int> <int> <chr>   <chr>             <dbl> <int>        <dbl> <chr>         
 1     3    26 rmse    standard   81496992.        5 4420334.     Preprocessor1…
 2     3    26 rsq     standard          0.647     5       0.0319 Preprocessor1…
 3     1     8 rmse    standard   81104914.        5 4249148.     Preprocessor1…
 4     1     8 rsq     standard          0.651     5       0.0270 Preprocessor1…
 5     3    13 rmse    standard   82253761.        5 4204371.     Preprocessor1…
 6     3    13 rsq     standard          0.639     5       0.0316 Preprocessor1…
 7     2    16 rmse    standard   81466291.        5 4103501.     Preprocessor1…
 8     2    16 rsq     standard          0.646     5       0.0298 Preprocessor1…
 9     2    36 rmse    standard   81355080.        5 4051776.     Preprocessor1…
10     2    36 rsq     standard          0.649     5       0.0281 Preprocessor1…
11     3     5 rmse    standard   84125788.        5 4113181.     Preprocessor1…
12     3     5 rsq     standard          0.623     5       0.0347 Preprocessor1…
13     1    32 rmse    standard   82381636.        5 4069505.     Preprocessor1…
14     1    32 rsq     standard          0.645     5       0.0230 Preprocessor1…
15     1    33 rmse    standard   82130106.        5 3978566.     Preprocessor1…
16     1    33 rsq     standard          0.647     5       0.0231 Preprocessor1…
17     2    20 rmse    standard   81547269.        5 4189669.     Preprocessor1…
18     2    20 rsq     standard          0.647     5       0.0294 Preprocessor1…
19     2    23 rmse    standard   81351141.        5 4073682.     Preprocessor1…
20     2    23 rsq     standard          0.648     5       0.0285 Preprocessor1…
select_best(rf_fit)
Warning: No value of `metric` was given; metric 'rmse' will be used.
# A tibble: 1 × 3
   mtry min_n .config              
  <int> <int> <chr>                
1     1     8 Preprocessor1_Model02

Finalisieren

# Finalize the RF workflow with the best (CV-RMSE) hyperparameters; passing
# `metric` explicitly silences the "No value of `metric`" warning.
final_wf <-
  wf_rf %>% 
  finalize_workflow(select_best(rf_fit, metric = "rmse"))
Warning: No value of `metric` was given; metric 'rmse' will be used.
# Refit the finalized RF on all training data, predict the test set, and
# attach the predictions to the test data.
final_fit <-
  fit(final_wf, data = d_train)
final_preds <- 
  final_fit %>% 
  predict(new_data = d_test) %>% 
  bind_cols(d_test)
# Kaggle expects exactly two columns: id and revenue (the prediction).
submission <-
  final_preds %>% 
  select(id, revenue = .pred)

Abspeichern und einreichen:

write_csv(submission, file = "submission.csv")

Kaggle Score

Diese Submission erzielte einen Score von 2.7664 (RMSLE).

sol <- 2.7664

Categories:

  • ds1
  • tidymodels
  • statlearning
  • tmdb
  • trees
  • num