tmdb07

ds1
tidymodels
statlearning
tmdb
num
Published

May 17, 2023

Task

Sign up for the Kaggle competition TMDB Box Office Prediction - Can you predict a movie’s worldwide box office revenue?

You need an account for this; you can also sign in with your Google account.

This forecasting competition is about predicting how much worldwide box office revenue a set of movies will earn. Information such as budget, genre, and title is available as predictors. So, a classic “predictive competition” :-) However, a few difficulties can always crop up ;-)

Task

Build a regularized linear model with tidymodels!

Hints

  • Skip preprocessing.
  • Tune the usual hyperparameters.
  • Submit the model’s predictions and report your score.
  • Restrict yourself to the following predictors:
preds_chosen <- 
  c("id", "budget", "popularity", "runtime")
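The recipe in the solution uses the formula `revenue ~ budget + popularity + runtime`, i.e. the chosen columns without `id`, which identifies a movie and should not carry predictive information. A minimal sketch of deriving that formula from `preds_chosen` with base R’s `reformulate()` (the names `predictors` and `f` are illustrative):

```r
preds_chosen <- c("id", "budget", "popularity", "runtime")

# Drop the identifier column; it should not be used as a predictor
predictors <- setdiff(preds_chosen, "id")

# Build the model formula programmatically
f <- reformulate(predictors, response = "revenue")
f
```

`recipe(f, data = d_train)` would then describe the same model as the recipe in the solution.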











Solution

Load packages

library(tidyverse)
library(tidymodels)
library(finetune)
library(doParallel)
library(tictoc)
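The `glmnet` engine and `step_impute_bag()` rely on the packages glmnet and ipred, which are not loaded by the `library()` calls above and may not be installed. A small sketch for checking this up front (the names `needed`, `installed` and `missing_pkgs` are illustrative):

```r
# Packages used directly, plus those needed by the glmnet engine
# and by step_impute_bag() (ipred)
needed <- c("tidyverse", "tidymodels", "finetune", "doParallel", "tictoc",
            "glmnet", "ipred")

# requireNamespace() returns FALSE (quietly) if a package is not installed
installed <- vapply(needed, requireNamespace, logical(1), quietly = TRUE)
missing_pkgs <- needed[!installed]

if (length(missing_pkgs) > 0) {
  message("Missing packages: ", paste(missing_pkgs, collapse = ", "))
}
```

Any missing packages could then be installed with `install.packages(missing_pkgs)`.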

Import data

d_train_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/train.csv"
d_test_path <- "https://raw.githubusercontent.com/sebastiansauer/Lehre/main/data/tmdb-box-office-prediction/test.csv"

d_train <- read_csv(d_train_path)
d_test <- read_csv(d_test_path)

Let’s take a look at the data:

glimpse(d_train)
Rows: 3,000
Columns: 23
$ id                    <dbl> 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 1…
$ belongs_to_collection <chr> "[{'id': 313576, 'name': 'Hot Tub Time Machine C…
$ budget                <dbl> 1.40e+07, 4.00e+07, 3.30e+06, 1.20e+06, 0.00e+00…
$ genres                <chr> "[{'id': 35, 'name': 'Comedy'}]", "[{'id': 35, '…
$ homepage              <chr> NA, NA, "http://sonyclassics.com/whiplash/", "ht…
$ imdb_id               <chr> "tt2637294", "tt0368933", "tt2582802", "tt182148…
$ original_language     <chr> "en", "en", "en", "hi", "ko", "en", "en", "en", …
$ original_title        <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ overview              <chr> "When Lou, who has become the \"father of the In…
$ popularity            <dbl> 6.575393, 8.248895, 64.299990, 3.174936, 1.14807…
$ poster_path           <chr> "/tQtWuwvMf0hCc2QR2tkolwl7c3c.jpg", "/w9Z7A0GHEh…
$ production_companies  <chr> "[{'name': 'Paramount Pictures', 'id': 4}, {'nam…
$ production_countries  <chr> "[{'iso_3166_1': 'US', 'name': 'United States of…
$ release_date          <chr> "2/20/15", "8/6/04", "10/10/14", "3/9/12", "2/5/…
$ runtime               <dbl> 93, 113, 105, 122, 118, 83, 92, 84, 100, 91, 119…
$ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}]", "[{'…
$ status                <chr> "Released", "Released", "Released", "Released", …
$ tagline               <chr> "The Laws of Space and Time are About to be Viol…
$ title                 <chr> "Hot Tub Time Machine 2", "The Princess Diaries …
$ Keywords              <chr> "[{'id': 4379, 'name': 'time travel'}, {'id': 96…
$ cast                  <chr> "[{'cast_id': 4, 'character': 'Lou', 'credit_id'…
$ crew                  <chr> "[{'credit_id': '59ac067c92514107af02c8c8', 'dep…
$ revenue               <dbl> 12314651, 95149435, 13092000, 16000000, 3923970,…
glimpse(d_test)
Rows: 4,398
Columns: 22
$ id                    <dbl> 3001, 3002, 3003, 3004, 3005, 3006, 3007, 3008, …
$ belongs_to_collection <chr> "[{'id': 34055, 'name': 'Pokémon Collection', 'p…
$ budget                <dbl> 0.00e+00, 8.80e+04, 0.00e+00, 6.80e+06, 2.00e+06…
$ genres                <chr> "[{'id': 12, 'name': 'Adventure'}, {'id': 16, 'n…
$ homepage              <chr> "http://www.pokemon.com/us/movies/movie-pokemon-…
$ imdb_id               <chr> "tt1226251", "tt0051380", "tt0118556", "tt125595…
$ original_language     <chr> "ja", "en", "en", "fr", "en", "en", "de", "en", …
$ original_title        <chr> "ディアルガVSパルキアVSダークライ", "Attack of t…
$ overview              <chr> "Ash and friends (this time accompanied by newco…
$ popularity            <dbl> 3.851534, 3.559789, 8.085194, 8.596012, 3.217680…
$ poster_path           <chr> "/tnftmLMemPLduW6MRyZE0ZUD19z.jpg", "/9MgBNBqlH1…
$ production_companies  <chr> NA, "[{'name': 'Woolner Brothers Pictures Inc.',…
$ production_countries  <chr> "[{'iso_3166_1': 'JP', 'name': 'Japan'}, {'iso_3…
$ release_date          <chr> "7/14/07", "5/19/58", "5/23/97", "9/4/10", "2/11…
$ runtime               <dbl> 90, 65, 100, 130, 92, 121, 119, 77, 120, 92, 88,…
$ spoken_languages      <chr> "[{'iso_639_1': 'en', 'name': 'English'}, {'iso_…
$ status                <chr> "Released", "Released", "Released", "Released", …
$ tagline               <chr> "Somewhere Between Time & Space... A Legend Is B…
$ title                 <chr> "Pokémon: The Rise of Darkrai", "Attack of the 5…
$ Keywords              <chr> "[{'id': 11451, 'name': 'pok√©mon'}, {'id': 1155…
$ cast                  <chr> "[{'cast_id': 3, 'character': 'Tonio', 'credit_i…
$ crew                  <chr> "[{'credit_id': '52fe44e7c3a368484e03d683', 'dep…

Resampling / cross-validation scheme

cv_scheme <- vfold_cv(d_train)

You can choose small values of \(v\), such as \(v=3\), to save computation time; this is especially handy for debugging. For the “real” run, a larger value is better, e.g. \(v=10\) (the default).
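A minimal sketch of both variants with rsample’s `vfold_cv()`, using the built-in `mtcars` data as a stand-in for `d_train`. The `set.seed()` call is an addition not present in the solution; without it, the fold assignment (and hence the tuning results) changes from run to run.

```r
library(rsample)

set.seed(42)  # make the fold assignment reproducible

# Few folds: fast, e.g. for debugging
cv_debug <- vfold_cv(mtcars, v = 3)

# Default: v = 10 folds
cv_full <- vfold_cv(mtcars)

nrow(cv_debug)  # 3 resamples
nrow(cv_full)   # 10 resamples
```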

Recipe

rec1 <- 
  recipe(revenue ~ budget + popularity + runtime, data = d_train) %>% 
  step_impute_bag(all_predictors()) %>% 
  step_naomit(all_predictors()) 
rec1
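`step_impute_bag()` fills missing predictor values with predictions from bagged trees fitted on the other predictors (it requires the ipred package). The subsequent `step_naomit()` should rarely remove anything and, by default (`skip = TRUE`), is not applied to new data such as `d_test`. A small sketch on `mtcars` with a few hypothetical missing values in `hp`:

```r
library(recipes)

# Hypothetical toy data: mtcars with some values of hp removed
df <- mtcars[, c("mpg", "disp", "hp", "wt")]
df$hp[c(2, 5, 9)] <- NA

rec_demo <-
  recipe(mpg ~ disp + hp + wt, data = df) %>%
  step_impute_bag(all_predictors())  # bagged trees predict each missing value

# prep() fits the imputation models, bake() applies them to the training data
baked <- rec_demo %>% prep() %>% bake(new_data = NULL)

colSums(is.na(baked))  # no missing values left
```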

Model

model_lm <- linear_reg(penalty = tune(),
                       engine = "glmnet")
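`linear_reg()` has two regularization parameters: `penalty` (the amount of regularization, glmnet’s \(\lambda\)) and `mixture` (glmnet’s \(\alpha\); 1 = lasso, 0 = ridge). The solution tunes only `penalty`; since `mixture` is not set, glmnet’s default \(\alpha = 1\) (lasso) applies, which matches the printed `glmnet::glmnet()` call further below that has no `alpha` argument. If you also want to tune `mixture`, a sketch (the name `model_enet` is hypothetical):

```r
library(tidymodels)

# Mark both regularization parameters of the glmnet engine for tuning
model_enet <-
  linear_reg(penalty = tune(), mixture = tune()) %>%
  set_engine("glmnet")

# Which parameters will be tuned?
params <- extract_parameter_set_dials(model_enet)
params$name
```

`model_enet` could then replace `model_lm` in `add_model()`.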

Workflow

wf1 <-
  workflow() %>% 
  add_model(model_lm) %>% 
  add_recipe(rec1)

Fit (and tune) the model

Start parallel processing:

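cl <- makePSOCKcluster(4)  # start a cluster with 4 worker processes
registerDoParallel(cl)     # register it as the parallel backend for tuning
tic()
lm_fit1 <-
  wf1 %>% 
  tune_race_anova(resamples = cv_scheme)  # racing: drops poor candidates early
toc()
18.691 sec elapsed
stopCluster(cl)  # shut down the worker processes when done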
lm_fit1 %>% show_best(metric = "rmse")
# A tibble: 5 × 7
   penalty .metric .estimator      mean     n  std_err .config              
     <dbl> <chr>   <chr>          <dbl> <int>    <dbl> <chr>                
1 1.62e-10 rmse    standard   84540989.    10 5365801. Preprocessor1_Model01
2 3.06e- 9 rmse    standard   84540989.    10 5365801. Preprocessor1_Model02
3 1.51e- 8 rmse    standard   84540989.    10 5365801. Preprocessor1_Model03
4 9.21e- 7 rmse    standard   84540989.    10 5365801. Preprocessor1_Model04
5 2.83e- 6 rmse    standard   84540989.    10 5365801. Preprocessor1_Model05
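All five candidates in the table share the same mean RMSE (84540989), although their penalties differ by several orders of magnitude. A plausible explanation, stated here as an assumption: dials’ default `penalty()` range runs from \(10^{-10}\) to \(10^{0}\) (log10 scale), while the lambda path that glmnet prints for the final fit starts at about 103,500,000 because `revenue` is in dollars and not rescaled. Penalties of at most 1 thus lie far below that path and presumably all yield practically the same, nearly unpenalized fit. The sketch compares the default grid with a wider one; the bounds \(10^4\) to \(10^9\) are a hypothetical choice read off the printed lambda values, not a tuned recommendation.

```r
library(dials)

# Default penalty range: 10^-10 to 10^0 (log10 scale)
grid_default <- grid_regular(penalty(), levels = 5)

# Hypothetical wider range, 10^4 to 10^9, chosen to overlap the
# lambda values printed for the glmnet fit (assumption)
grid_wide <- grid_regular(penalty(range = c(4, 9)), levels = 6)

range(grid_default$penalty)
range(grid_wide$penalty)
```

Such a grid could be passed to the tuning call via its `grid` argument, e.g. `tune_race_anova(resamples = cv_scheme, grid = grid_wide)`.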

Finalize

wf1_final <-
  wf1 %>% 
  finalize_workflow(select_best(lm_fit1, metric = "rmse"))
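`select_best()` returns a one-row tibble with the winning parameter value(s) and a `.config` column; `finalize_workflow()` (or `finalize_model()` for a bare model spec) substitutes these values for the `tune()` placeholders. A sketch with a hypothetical penalty of 0.01 standing in for the tuning result:

```r
library(tidymodels)

spec <- linear_reg(penalty = tune()) %>% set_engine("glmnet")

# Hypothetical "best" parameters, shaped like the output of select_best()
best <- tibble(penalty = 0.01, .config = "Preprocessor1_Model01")

spec_final <- finalize_model(spec, best)

# No tune() placeholders are left, so nothing remains to be tuned
nrow(extract_parameter_set_dials(spec_final))
```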

Final Fit

fit1_final <-
  wf1_final %>% 
  fit(d_train)

fit1_final
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: linear_reg()

── Preprocessor ────────────────────────────────────────────────────────────────
2 Recipe Steps

• step_impute_bag()
• step_naomit()

── Model ───────────────────────────────────────────────────────────────────────

Call:  glmnet::glmnet(x = maybe_matrix(x), y = y, family = "gaussian") 

   Df  %Dev    Lambda
1   0  0.00 103500000
2   1  9.63  94340000
3   1 17.62  85960000
4   1 24.25  78320000
5   1 29.76  71370000
6   1 34.33  65030000
7   1 38.13  59250000
8   1 41.28  53990000
9   1 43.90  49190000
10  1 46.07  44820000
11  2 48.25  40840000
12  2 50.48  37210000
13  2 52.34  33900000
14  2 53.88  30890000
15  2 55.15  28150000
16  2 56.21  25650000
17  2 57.09  23370000
18  2 57.82  21290000
19  2 58.43  19400000
20  2 58.93  17680000
21  2 59.35  16110000
22  2 59.70  14680000
23  2 59.99  13370000
24  2 60.23  12180000
25  2 60.42  11100000
26  2 60.59  10120000
27  2 60.73   9217000
28  2 60.84   8398000
29  2 60.93   7652000
30  2 61.01   6973000
31  2 61.08   6353000
32  2 61.13   5789000
33  2 61.18   5274000
34  2 61.21   4806000
35  3 61.25   4379000
36  3 61.29   3990000
37  3 61.32   3635000
38  3 61.34   3313000
39  3 61.36   3018000
40  3 61.38   2750000
41  3 61.39   2506000
42  3 61.40   2283000
43  3 61.41   2080000
44  3 61.42   1896000
45  3 61.43   1727000
46  3 61.43   1574000

...
and 12 more lines.
preds <-
  fit1_final %>% 
  predict(d_test)

Submission df

submission_df <-
  d_test %>% 
  select(id) %>% 
  bind_cols(preds) %>% 
  rename(revenue = .pred)

head(submission_df)
# A tibble: 6 × 2
     id   revenue
  <dbl>     <dbl>
1  3001 -3508554.
2  3002 -7712533.
3  3003  8857329.
4  3004 31400199.
5  3005   101521.
6  3006 13470119.
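Some predictions are negative (ids 3001 and 3002), which is impossible for revenue; an unconstrained linear model is not bounded below. Since the competition is scored with RMSLE, which takes \(\log(p + 1)\) of each prediction \(p\), values below \(-1\) have no defined logarithm. One simple remedy, sketched here as an assumption rather than the approach used in this solution, is to floor the predictions at zero before saving. The demo uses the rounded values from the table above:

```r
library(dplyr)

# The first three predictions from the table above
submission_demo <- tibble(
  id      = c(3001, 3002, 3003),
  revenue = c(-3508554, -7712533, 8857329)
)

# Revenue cannot be negative: floor the predictions at zero
submission_clipped <- submission_demo %>%
  mutate(revenue = pmax(revenue, 0))

submission_clipped$revenue
```

For the real submission, the same `mutate()` call would be applied to `submission_df`.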

Save and submit:

#write_csv(submission_df, file = "submission.csv")

Kaggle Score

This submission achieved a score of 6.14787 (RMSLE).

sol <- 6.14787
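For reference, the RMSLE of predictions \(p_i\) against actual values \(a_i\) is \(\sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(\log(p_i + 1) - \log(a_i + 1)\right)^2}\). Because the model was tuned on the RMSE of raw revenue, the cross-validated RMSE reported above is not on the same scale as this score. A small base-R sketch (the function name `rmsle` is our own):

```r
# Root mean squared logarithmic error, as used for the Kaggle score
rmsle <- function(pred, actual) {
  sqrt(mean((log1p(pred) - log1p(actual))^2))
}

rmsle(c(100, 200), c(100, 200))  # 0: perfect predictions
rmsle(c(0, 0), c(99, 99))        # log(100), since log1p(99) = log(100)
suppressWarnings(rmsle(-5, 10))  # NaN: log1p() is undefined below -1
```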
