class: center, middle, inverse, title-slide # Cross validation
⛑ ### S. Mason Garrison --- layout: true <div class="my-footer"> <span> <a href="https://DataScience4Psych.github.io/DataScience4Psych/" target="_blank">Data Science for Psychologists</a> </span> </div> --- class: middle # Data and exploration --- background-image: url("img/the-office.jpeg") class: middle --- ## Data ```r office_ratings <- read_csv("data/office_ratings.csv") office_ratings ``` ``` ## # A tibble: 188 x 6 ## season episode title imdb_rating total_votes air_date ## <dbl> <dbl> <chr> <dbl> <dbl> <date> ## 1 1 1 Pilot 7.6 3706 2005-03-24 ## 2 1 2 Diversity D~ 8.3 3566 2005-03-29 ## 3 1 3 Health Care 7.9 2983 2005-04-05 ## 4 1 4 The Alliance 8.1 2886 2005-04-12 ## 5 1 5 Basketball 8.4 3179 2005-04-19 ## 6 1 6 Hot Girl 7.8 2852 2005-04-26 ## 7 2 1 The Dundies 8.7 3213 2005-09-20 ## 8 2 2 Sexual Hara~ 8.2 2736 2005-09-27 ## 9 2 3 Office Olym~ 8.4 2742 2005-10-04 ## 10 2 4 The Fire 8.4 2713 2005-10-11 ## # ... with 178 more rows ``` .footnote[ .small[ Source: The data come from [data.world](https://data.world/anujjain7/the-office-imdb-ratings-dataset), by way of [TidyTuesday](https://github.com/rfordatascience/tidytuesday/blob/master/data/2020/2020-03-17/readme.md). ] ] --- ## IMDB ratings .pull-left-wide[ <img src="d25_crossvalidation_files/figure-html/unnamed-chunk-3-1.png" width="100%" style="display: block; margin: auto;" /> ] .small.pull-right-narrow[ ```r ggplot(office_ratings, aes(x = imdb_rating)) + geom_histogram(binwidth = 0.25) + labs( title = "The Office Ratings", x = "IMDB Rating" ) ``` ] --- ## IMDB ratings vs. 
number of votes .pull-left-wide[ <img src="d25_crossvalidation_files/figure-html/unnamed-chunk-4-1.png" width="100%" style="display: block; margin: auto;" /> ] .small.pull-right-narrow[ ```r ggplot(office_ratings, aes(x = total_votes, y = imdb_rating, color = season)) + geom_jitter(alpha = 0.7) + labs( title = "The Office Ratings", x = "Total Votes", y = "IMDB Rating", color = "Season" ) ``` ] --- ## Outliers .pull-left-wide[ <img src="d25_crossvalidation_files/figure-html/unnamed-chunk-5-1.png" width="100%" style="display: block; margin: auto;" /> ] .small.pull-right-narrow[ ```r ggplot(office_ratings, aes(x = total_votes, y = imdb_rating)) + geom_jitter() + gghighlight(total_votes > 4000, label_key = title) + labs( title = "The Office Ratings", x = "Total Votes", y = "IMDB Rating" ) ``` ] .footnote[ .small[ If you like the [Dinner Party](https://www.imdb.com/title/tt1031477/) episode, I highly recommend this ["oral history"](https://www.rollingstone.com/tv/tv-features/that-one-night-the-oral-history-of-the-greatest-office-episode-ever-629472/) of the episode published in Rolling Stone magazine. ] ] --- ## IMDB ratings vs.
seasons .pull-left-wide[ <img src="d25_crossvalidation_files/figure-html/unnamed-chunk-6-1.png" width="100%" style="display: block; margin: auto;" /> ] .small.pull-right-narrow[ ```r ggplot(office_ratings, aes(x = factor(season), y = imdb_rating, color = season)) + geom_boxplot() + geom_jitter() + guides(color = "none") + labs( title = "The Office Ratings", x = "Season", y = "IMDB Rating" ) ``` ] --- class: middle # Modeling --- ## Train / Test - Create an initial split ```r set.seed(1122) office_split <- initial_split(office_ratings) # prop = 3/4 by default ``` -- .pull-left[ - Save training data .medi[ ```r office_train <- training(office_split) dim(office_train) ``` ``` ## [1] 141 6 ``` ]] -- .pull-right[ - Save testing data .medi[ ```r office_test <- testing(office_split) dim(office_test) ``` ``` ## [1] 47 6 ``` ]] --- ## Specify model ```r office_mod <- linear_reg() %>% set_engine("lm") office_mod ``` ``` ## Linear Regression Model Specification (regression) ## ## Computational engine: lm ``` --- ## Build recipe .panel[.panel-name[Code] ```r office_rec <- recipe(imdb_rating ~ ., data = office_train) %>% # title isn't a predictor, but keep around to ID update_role(title, new_role = "ID") %>% # extract month of air_date step_date(air_date, features = "month") %>% step_rm(air_date) %>% # make dummy variables of month step_dummy(contains("month")) %>% # remove zero variance predictors step_zv(all_predictors()) ``` ] --- ## Build recipe .panel[.panel-name[Output] .small[ ```r office_rec ``` ``` ## Recipe ## ## Inputs: ## ## role #variables ## ID 1 ## outcome 1 ## predictor 4 ## ## Operations: ## ## Date features from air_date ## Variables removed air_date ## Dummy variables from contains("month") ## Zero variance filter on all_predictors() ``` ] ] --- ## Build workflow .panelset[ .panel[.panel-name[Code] ```r office_wflow <- workflow() %>% add_model(office_mod) %>% add_recipe(office_rec) ``` ] .panel[.panel-name[Output] .small[ ```r office_wflow ``` ``` ## == 
Workflow ===================================================== ## Preprocessor: Recipe ## Model: linear_reg() ## ## -- Preprocessor ------------------------------------------------- ## 4 Recipe Steps ## ## * step_date() ## * step_rm() ## * step_dummy() ## * step_zv() ## ## -- Model -------------------------------------------------------- ## Linear Regression Model Specification (regression) ## ## Computational engine: lm ``` ] ] ] --- ## Fit model .panelset[ .panel[.panel-name[Code] ```r office_fit <- office_wflow %>% fit(data = office_train) ``` ] .panel[.panel-name[Output] .small[ ```r tidy(office_fit) %>% print(n = 12) ``` ``` ## # A tibble: 12 x 5 ## term estimate std.error statistic p.value ## <chr> <dbl> <dbl> <dbl> <dbl> ## 1 (Intercept) 7.23 0.205 35.4 3.14e-68 ## 2 season -0.0499 0.0157 -3.18 1.86e- 3 ## 3 episode 0.0353 0.0101 3.50 6.44e- 4 ## 4 total_votes 0.000352 0.0000448 7.85 1.39e-12 ## 5 air_date_month_Feb 0.0242 0.147 0.165 8.69e- 1 ## 6 air_date_month_Mar -0.145 0.144 -1.01 3.16e- 1 ## 7 air_date_month_Apr -0.106 0.140 -0.759 4.49e- 1 ## 8 air_date_month_May 0.0575 0.175 0.329 7.43e- 1 ## 9 air_date_month_Sep 0.440 0.191 2.30 2.30e- 2 ## 10 air_date_month_Oct 0.321 0.150 2.13 3.50e- 2 ## 11 air_date_month_Nov 0.237 0.138 1.72 8.81e- 2 ## 12 air_date_month_Dec 0.443 0.190 2.34 2.09e- 2 ``` ] ] ] --- class: middle # Evaluate model --- ## Make predictions for training data ```r office_train_pred <- predict(office_fit, office_train) %>% bind_cols(office_train %>% select(imdb_rating, title)) office_train_pred ``` ``` ## # A tibble: 141 x 3 ## .pred imdb_rating title ## <dbl> <dbl> <chr> ## 1 7.90 8.1 Garden Party ## 2 8.43 7.9 The Chump ## 3 7.81 7.1 Here Comes Treble ## 4 7.94 6.7 Get the Girl ## 5 7.92 7.9 Tallahassee ## 6 8.29 7.7 The Inner Circle ## 7 7.95 7.8 The Sting ## 8 8.00 7.8 WUPHF.com ## 9 9.56 9.6 Stress Relief ## 10 8.11 8.1 Manager and Salesman ## # ... 
with 131 more rows ``` --- ## R-squared Percentage of variability in the IMDB ratings explained by the model ```r rsq(office_train_pred, truth = imdb_rating, estimate = .pred) ``` ``` ## # A tibble: 1 x 3 ## .metric .estimator .estimate ## <chr> <chr> <dbl> ## 1 rsq standard 0.500 ``` -- .question[ Are models with a high or a low `\(R^2\)` preferable? ] --- ## RMSE An alternative model performance statistic: **root mean square error** $$ RMSE = \sqrt{\frac{\sum_{i = 1}^n (y_i - \hat{y}_i)^2}{n}} $$ -- ```r rmse(office_train_pred, truth = imdb_rating, estimate = .pred) ``` ``` ## # A tibble: 1 x 3 ## .metric .estimator .estimate ## <chr> <chr> <dbl> ## 1 rmse standard 0.373 ``` -- .question[ Are models with a high or a low RMSE preferable? ] --- ## Interpreting RMSE .question[ Is this RMSE considered low or high? ] ```r rmse(office_train_pred, truth = imdb_rating, estimate = .pred) ``` ``` ## # A tibble: 1 x 3 ## .metric .estimator .estimate ## <chr> <chr> <dbl> ## 1 rmse standard 0.373 ``` -- ```r office_train %>% summarize(min = min(imdb_rating), max = max(imdb_rating)) ``` ``` ## # A tibble: 1 x 2 ## min max ## <dbl> <dbl> ## 1 6.7 9.7 ``` --- class: middle .hand[ .light-blue[ but, really, who cares about predictions on .pink[training] data? ] ] --- ## Make predictions for testing data ```r office_test_pred <- predict(office_fit, office_test) %>% bind_cols(office_test %>% select(imdb_rating, title)) office_test_pred ``` ``` ## # A tibble: 47 x 3 ## .pred imdb_rating title ## <dbl> <dbl> <chr> ## 1 8.52 8.4 Office Olympics ## 2 8.54 8.6 The Client ## 3 8.90 8.8 Christmas Party ## 4 8.71 9 The Injury ## 5 8.50 8.2 Boys and Girls ## 6 8.46 8.4 Dwight's Speech ## 7 8.64 8.9 Gay Witch Hunt ## 8 8.35 8 Diwali ## 9 8.77 8.7 A Benihana Christmas ## 10 8.40 8.2 Ben Franklin ## # ... 
with 37 more rows ``` --- ## Evaluate performance on testing data - RMSE of the model's predictions on the testing data .medi[ ```r rmse(office_test_pred, truth = imdb_rating, estimate = .pred) ``` ``` ## # A tibble: 1 x 3 ## .metric .estimator .estimate ## <chr> <chr> <dbl> ## 1 rmse standard 0.386 ``` ] - `\(R^2\)` of the model's predictions on the testing data .medi[ ```r rsq(office_test_pred, truth = imdb_rating, estimate = .pred) ``` ``` ## # A tibble: 1 x 3 ## .metric .estimator .estimate ## <chr> <chr> <dbl> ## 1 rsq standard 0.556 ``` ] --- ## Training vs. testing <br> <table> <thead> <tr> <th style="text-align:left;"> metric </th> <th style="text-align:right;"> train </th> <th style="text-align:right;"> test </th> <th style="text-align:left;"> comparison </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> RMSE </td> <td style="text-align:right;"> 0.373 </td> <td style="text-align:right;"> 0.386 </td> <td style="text-align:left;"> RMSE lower for training </td> </tr> <tr> <td style="text-align:left;"> R-squared </td> <td style="text-align:right;"> 0.500 </td> <td style="text-align:right;"> 0.556 </td> <td style="text-align:left;"> R-squared higher for testing </td> </tr> </tbody> </table> --- ## Evaluating performance on training data - The training set does not have the capacity to be a good arbiter of performance. -- - It is not an independent piece of information; predicting the training set can only reflect what the model already knows. -- - Suppose you give a class a test, then give them the answers, then provide the same test. The students' scores on the second test do not accurately reflect what they know about the subject; these scores would probably be higher than their results on the first test. .footnote[ .small[ Source: [tidymodels.org](https://www.tidymodels.org/start/resampling/) ] ] --- class: middle # Wrapping Up... 
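To make the two metrics used above concrete, here is a minimal base-R sketch of RMSE (the formula on the RMSE slide) and of R-squared computed as the squared correlation between observed and predicted values, which is how yardstick's `rsq()` defines it. The helper names `rmse_manual()` and `rsq_manual()` are hypothetical, and the input vectors are the first five rows of `office_train_pred` as printed (predictions rounded to two decimals), so the results only approximate what yardstick would report on those rows.

```r
# Hypothetical helpers mirroring the formulas; not yardstick's implementation.

# RMSE: square root of the mean of the squared residuals
rmse_manual <- function(truth, estimate) {
  sqrt(mean((truth - estimate)^2))
}

# R-squared as the squared Pearson correlation between truth and estimate
rsq_manual <- function(truth, estimate) {
  cor(truth, estimate)^2
}

# First five rows of office_train_pred as printed on the slide (rounded)
truth <- c(8.1, 7.9, 7.1, 6.7, 7.9)
pred  <- c(7.90, 8.43, 7.81, 7.94, 7.92)

rmse_manual(truth, pred)  # about 0.687
rsq_manual(truth, pred)
```

Because RMSE is expressed in the units of the outcome, it can be read against the 6.7 to 9.7 range of IMDB ratings, whereas R-squared is unitless.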
--- class: middle # Cross validation --- ## Cross validation More specifically, **v-fold cross validation**: - Shuffle your data and split it into v partitions - Use 1 partition for validation and the remaining v-1 partitions for training - Repeat v times, so that each partition serves as the validation set exactly once .footnote[ .small[ You might also have heard this referred to as k-fold cross validation. ] ] --- ## Cross validation <img src="img/cross-validation.png" width="100%" style="display: block; margin: auto;" /> --- ## Split data into folds .pull-left[ ```r set.seed(345) folds <- vfold_cv(office_train, v = 5) folds ``` ``` ## # 5-fold cross-validation ## # A tibble: 5 x 2 ## splits id ## <list> <chr> ## 1 <split [112/29]> Fold1 ## 2 <split [113/28]> Fold2 ## 3 <split [113/28]> Fold3 ## 4 <split [113/28]> Fold4 ## 5 <split [113/28]> Fold5 ``` ] .pull-right[ <img src="img/cross-validation.png" width="100%" style="display: block; margin: auto 0 auto auto;" /> ] --- ## Fit resamples .pull-left[ ```r set.seed(456) office_fit_rs <- office_wflow %>% fit_resamples(folds) #office_fit_rs ``` ] .pull-right[ <img src="img/cross-validation-animated.gif" width="100%" style="display: block; margin: auto 0 auto auto;" /> ] --- ```r office_fit_rs ``` ``` ## # Resampling results ## # 5-fold cross-validation ## # A tibble: 5 x 4 ## splits id .metrics .notes ## <list> <chr> <list> <list> ## 1 <split [112/29]> Fold1 <tibble [2 x 4]> <tibble [0 x 1]> ## 2 <split [113/28]> Fold2 <tibble [2 x 4]> <tibble [0 x 1]> ## 3 <split [113/28]> Fold3 <tibble [2 x 4]> <tibble [0 x 1]> ## 4 <split [113/28]> Fold4 <tibble [2 x 4]> <tibble [0 x 1]> ## 5 <split [113/28]> Fold5 <tibble [2 x 4]> <tibble [0 x 1]> ``` --- ## Collect CV metrics ```r collect_metrics(office_fit_rs) ``` ``` ## # A tibble: 2 x 6 ## .metric .estimator mean n std_err .config ## <chr> <chr> <dbl> <int> <dbl> <chr> ## 1 rmse standard 0.403 5 0.0336 Preprocessor1_Model1 ## 2 rsq standard 0.413 5 0.0727 Preprocessor1_Model1 ``` --- ## Deeper look into CV metrics .panel[.panel-name[Raw] ```r 
collect_metrics(office_fit_rs, summarize = FALSE) %>% print(n = 10) ``` ``` ## # A tibble: 10 x 5 ## id .metric .estimator .estimate .config ## <chr> <chr> <chr> <dbl> <chr> ## 1 Fold1 rmse standard 0.430 Preprocessor1_Model1 ## 2 Fold1 rsq standard 0.134 Preprocessor1_Model1 ## 3 Fold2 rmse standard 0.368 Preprocessor1_Model1 ## 4 Fold2 rsq standard 0.496 Preprocessor1_Model1 ## 5 Fold3 rmse standard 0.452 Preprocessor1_Model1 ## 6 Fold3 rsq standard 0.501 Preprocessor1_Model1 ## 7 Fold4 rmse standard 0.289 Preprocessor1_Model1 ## 8 Fold4 rsq standard 0.529 Preprocessor1_Model1 ## 9 Fold5 rmse standard 0.475 Preprocessor1_Model1 ## 10 Fold5 rsq standard 0.403 Preprocessor1_Model1 ``` ] --- .panel[.panel-name[Tidy] <table> <thead> <tr> <th style="text-align:left;"> Fold </th> <th style="text-align:right;"> RMSE </th> <th style="text-align:right;"> R-squared </th> </tr> </thead> <tbody> <tr> <td style="text-align:left;"> Fold1 </td> <td style="text-align:right;"> 0.430 </td> <td style="text-align:right;"> 0.134 </td> </tr> <tr> <td style="text-align:left;"> Fold2 </td> <td style="text-align:right;"> 0.368 </td> <td style="text-align:right;"> 0.496 </td> </tr> <tr> <td style="text-align:left;"> Fold3 </td> <td style="text-align:right;"> 0.452 </td> <td style="text-align:right;"> 0.501 </td> </tr> <tr> <td style="text-align:left;"> Fold4 </td> <td style="text-align:right;"> 0.289 </td> <td style="text-align:right;"> 0.529 </td> </tr> <tr> <td style="text-align:left;"> Fold5 </td> <td style="text-align:right;"> 0.475 </td> <td style="text-align:right;"> 0.403 </td> </tr> </tbody> </table> ] --- ## How does RMSE compare to y? - Cross validation RMSE stats ``` ## # A tibble: 1 x 6 ## min max mean med sd IQR ## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> ## 1 0.289 0.475 0.403 0.430 0.0751 0.0841 ``` - Training data IMDB score stats ``` ## # A tibble: 1 x 6 ## min max mean med sd IQR ## <dbl> <dbl> <dbl> <dbl> <dbl> <dbl> ## 1 6.7 9.7 8.24 8.2 0.530 0.600 ``` --- ## What's next? 
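Before moving on to the final test-set evaluation, it may help to see what `vfold_cv()`, `fit_resamples()`, and `collect_metrics()` automate. The following is a minimal base-R sketch of v-fold cross validation, not the tidymodels implementation: the built-in `mtcars` data, the `mpg ~ wt + hp` formula, and the helpers `rmse_manual()` and `vfold_rmse()` are hypothetical stand-ins. The last lines show that the `std_err` reported by `collect_metrics()` for RMSE is consistent with the standard deviation of the per-fold estimates divided by the square root of the number of folds.

```r
# Minimal base-R sketch of v-fold cross validation (hypothetical helpers).
rmse_manual <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

vfold_rmse <- function(data, formula, v = 5, seed = 345) {
  set.seed(seed)
  # Shuffle fold labels so each row lands in one of v roughly equal folds
  fold_id <- sample(rep(seq_len(v), length.out = nrow(data)))
  outcome <- all.vars(formula)[1]
  rmses <- numeric(v)
  for (k in seq_len(v)) {
    analysis   <- data[fold_id != k, , drop = FALSE]  # v - 1 folds: fit
    assessment <- data[fold_id == k, , drop = FALSE]  # held-out fold: evaluate
    fit <- lm(formula, data = analysis)
    rmses[k] <- rmse_manual(assessment[[outcome]],
                            predict(fit, newdata = assessment))
  }
  rmses
}

# Stand-in example on the built-in mtcars data (not the Office data)
fold_rmse <- vfold_rmse(mtcars, mpg ~ wt + hp, v = 5)
c(mean = mean(fold_rmse), std_err = sd(fold_rmse) / sqrt(length(fold_rmse)))

# Per-fold RMSEs of the Office model, copied from the CV metrics slide
office_fold_rmse <- c(0.430, 0.368, 0.452, 0.289, 0.475)
mean(office_fold_rmse)                                 # about 0.403
sd(office_fold_rmse) / sqrt(length(office_fold_rmse))  # about 0.0336
```

Each row is held out exactly once, so every observation contributes to exactly one assessment-set RMSE; the mean of those v values is the cross-validated estimate.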
<img src="img/post-cv-testing.png" width="90%" style="display: block; margin: auto 0 auto auto;" /> --- class: middle # Wrapping Up... <br> Sources: - Mine Çetinkaya-Rundel's Data Science in a Box ([link](https://datasciencebox.org/)) - Julia Fukuyama's EDA ([link](https://jfukuyama.github.io/))