class: center, middle, inverse, title-slide

# Prediction and overfitting
🏃

### S. Mason Garrison

---

layout: true

<div class="my-footer">
<span>
<a href="https://DataScience4Psych.github.io/DataScience4Psych/" target="_blank">Data Science for Psychologists</a>
</span>
</div>

---

class: middle

# Prediction

---

## Goal: Building a spam filter

- Data: Set of emails; for each email we know whether it is spam or not, along with other features
- Use logistic regression to predict the probability that an incoming email is spam
- Use model selection to pick the model with the best predictive performance

--

- Building a model to predict the probability that an email is spam is only half of the battle! We also need a decision rule about which emails get flagged as spam (e.g., what probability should we use as our cutoff?)

--

- A simple approach: choose a single threshold probability; any email that exceeds that probability is flagged as spam

---

## A multiple regression approach

.pull-left[
.panel[.panel-name[Output]
.small[

```
## # A tibble: 22 x 5
##    term         estimate std.error statistic  p.value
##    <chr>           <dbl>     <dbl>     <dbl>    <dbl>
##  1 (Intercept)  -9.09e+1   9.80e+3  -0.00928 9.93e- 1
##  2 to_multiple1 -2.68e+0   3.27e-1  -8.21    2.25e-16
##  3 from1        -2.19e+1   9.80e+3  -0.00224 9.98e- 1
##  4 cc            1.88e-2   2.20e-2   0.855   3.93e- 1
##  5 sent_email1  -2.07e+1   3.87e+2  -0.0536  9.57e- 1
##  6 time          8.48e-8   2.85e-8   2.98    2.92e- 3
##  7 image        -1.78e+0   5.95e-1  -3.00    2.73e- 3
##  8 attach        7.35e-1   1.44e-1   5.09    3.61e- 7
##  9 dollar       -6.85e-2   2.64e-2  -2.59    9.64e- 3
## 10 winneryes     2.07e+0   3.65e-1   5.67    1.41e- 8
## 11 inherit       3.15e-1   1.56e-1   2.02    4.32e- 2
## 12 viagra        2.84e+0   2.22e+3   0.00128 9.99e- 1
## 13 password     -8.54e-1   2.97e-1  -2.88    4.03e- 3
## 14 num_char      5.06e-2   2.38e-2   2.13    3.35e- 2
## 15 line_breaks  -5.49e-3   1.35e-3  -4.06    4.91e- 5
## 16 format1      -6.14e-1   1.49e-1  -4.14    3.53e- 5
## 17 re_subj1     -1.64e+0   3.86e-1  -4.25    2.16e- 5
## 18 exclaim_subj  1.42e-1   2.43e-1   0.585   5.58e- 1
## 19 urgent_subj1  3.88e+0   1.32e+0   2.95    3.18e- 3
## 20 exclaim_mess  1.08e-2   1.81e-3   5.98    2.23e- 9
## 21 numbersmall  -1.19e+0   1.54e-1  -7.74    9.62e-15
## 22 numberbig    -2.95e-1   2.20e-1  -1.34    1.79e- 1
```

```
## 
## Call:
## glm(formula = spam ~ re_subj, family = "binomial", data = email)
## 
## Deviance Residuals: 
##     Min       1Q   Median       3Q      Max  
## -0.5145  -0.5145  -0.5145  -0.1252   3.1155  
## 
## Coefficients:
##             Estimate Std. Error z value Pr(>|z|)    
## (Intercept) -1.95542    0.05639 -34.677  < 2e-16 ***
## re_subj1    -2.88976    0.35937  -8.041  8.9e-16 ***
## ---
## Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
## 
## (Dispersion parameter for binomial family taken to be 1)
## 
##     Null deviance: 2437.2  on 3920  degrees of freedom
## Residual deviance: 2264.1  on 3919  degrees of freedom
## AIC: 2268.1
## 
## Number of Fisher Scoring iterations: 7
```

]
]
]

.pull-right[.panel[.panel-name[Code]

```r
logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ ., data = email, family = "binomial") %>%
  tidy() %>%
  print(n = 22)
```

```
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
```

```r
summary(glm(spam ~ re_subj, data = email, family = "binomial"))
```

]
]

---

## Prediction

- The mechanics of prediction are **easy**:
  - Plug in values of predictors to the model equation
  - Calculate the predicted value of the response variable, `\(\hat{y}\)`

--

- Getting it right is **hard**!
  - No guarantee that you have the correct model estimates
  - Or that your model will perform as well with new data as it did with your sample data
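---

## Prediction mechanics, in code

A minimal sketch of those mechanics (an addition to the deck, not part of the original analysis; it assumes the tidymodels packages and the `email` data are loaded, as elsewhere in these slides):

```r
# Fit a small logistic model with two numeric predictors
simple_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ exclaim_mess + num_char, data = email)

# A hypothetical new email, described only by the two predictors in the model
new_email <- tibble(exclaim_mess = 10, num_char = 1)

# Plug the predictor values into the model equation to get predicted probabilities
predict(simple_fit, new_email, type = "prob")
```

The easy part stops here: whether those probabilities mean anything for new emails is what the rest of the deck is about.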
---

## Underfitting and overfitting

<img src="d24_overfitting_files/figure-html/unnamed-chunk-3-1.png" width="70%" style="display: block; margin: auto;" />

---

## Spending our data

- Several steps to create a useful model:
  - parameter estimation,
  - model selection,
  - performance assessment, etc.

--

- Doing all of this on the entire data we have available can lead to **overfitting**

--

- Allocate specific subsets of data for different tasks,
  - as opposed to allocating the largest possible amount to model parameter estimation only
  - (which is what we've done so far)

---

class: middle

# Splitting data

---

## Splitting data

- **Training set:**
  - Sandbox for model building
  - Spend most of your time using the training set to develop the model
  - Majority of the data (usually 80%)

- **Testing set:**
  - Held in reserve to determine the efficacy of one or two chosen models
  - Critical to look at it only once; otherwise it becomes part of the modeling process
  - Remainder of the data (usually 20%)

---

## Performing the split

```r
# Fix random numbers by setting the seed
# Enables analysis to be reproducible when random numbers are used
set.seed(1066)

# Put 80% of the data into the training set
email_split <- initial_split(email, prop = 0.80)

# Create data frames for the two sets:
train_data <- training(email_split)
test_data  <- testing(email_split)
```

---

## Peek at the split

.small[
.pull-left[

```r
glimpse(train_data)
```

```
## Rows: 3,136
## Columns: 21
## $ spam         <fct> 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, ~
## $ to_multiple  <fct> 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 1, 0, 0, ~
## $ from         <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ~
## $ cc           <int> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ sent_email   <fct> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ time         <dttm> 2012-02-10 17:21:24, 2012-02-09 04:32:51,~
## $ image        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ attach       <dbl> 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, ~
## $ dollar       <dbl> 2, 0, 0, 2, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, ~
## $ winner       <fct> no, no, no, no, no, no, no, no, no, no, no~
## $ inherit      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ viagra       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ password     <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 0, ~
## $ num_char     <dbl> 18.649, 0.696, 2.792, 17.049, 8.592, 1.360~
## $ line_breaks  <int> 388, 20, 39, 386, 146, 28, 740, 42, 152, 7~
## $ format       <fct> 1, 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 0, 1, ~
## $ re_subj      <fct> 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ exclaim_subj <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ urgent_subj  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ exclaim_mess <dbl> 20, 1, 1, 0, 1, 0, 2, 0, 1, 0, 0, 0, 0, 1,~
## $ number       <fct> big, small, small, small, small, small, sm~
```

]

.pull-right[

```r
glimpse(test_data)
```

```
## Rows: 785
## Columns: 21
## $ spam         <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ to_multiple  <fct> 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, ~
## $ from         <fct> 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ~
## $ cc           <int> 0, 0, 0, 0, 0, 2, 1, 2, 0, 0, 0, 0, 0, 0, ~
## $ sent_email   <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, ~
## $ time         <dttm> 2012-01-01 04:09:49, 2012-01-01 22:00:18,~
## $ image        <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ attach       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ dollar       <dbl> 0, 0, 5, 0, 21, 0, 0, 0, 11, 18, 16, 0, 2,~
## $ winner       <fct> no, no, no, no, no, no, no, no, no, no, no~
## $ inherit      <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ viagra       <dbl> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ password     <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ num_char     <dbl> 13.256, 10.459, 18.037, 4.560, 16.647, 7.8~
## $ line_breaks  <int> 255, 191, 345, 64, 560, 97, 881, 185, 350,~
## $ format       <fct> 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, ~
## $ re_subj      <fct> 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0, ~
## $ exclaim_subj <dbl> 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, ~
## $ urgent_subj  <fct> 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, ~
## $ exclaim_mess <dbl> 48, 0, 20, 0, 3, 1, 5, 3, 2, 15, 16, 4, 6,~
## $ number       <fct> small, big, small, none, small, small, big~
```

]
]
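---

## Aside: a stratified split

Spam is a minority class here (the counts a few slides ahead show only about 9% of training emails are spam), so a purely random 80/20 split can leave the testing set with relatively few spam messages. One option, not used in the split above, is to stratify the split on the outcome:

```r
set.seed(1066)

# Same 80/20 split, but sampled within the spam and not-spam groups,
# so both sets keep roughly the original proportion of spam
email_split_strat <- initial_split(email, prop = 0.80, strata = spam)

train_data_strat <- training(email_split_strat)
test_data_strat  <- testing(email_split_strat)
```

This is just a sketch of the option; the rest of the deck uses the unstratified split from the previous slides.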
---

class: middle

# Modeling workflow

---

## Fit a model to the training dataset

```r
# Fitting with every predictor triggers a glm convergence warning (shown below)
email_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ ., data = train_data, family = "binomial")
```

```
## Warning: glm.fit: fitted probabilities numerically 0 or 1 occurred
```

```r
# Refit, suppressing the warning
email_fit <- logistic_reg() %>%
  set_engine("glm") %>%
  fit(spam ~ ., data = train_data, family = "binomial") %>%
  suppressWarnings(.)
```

---

## Categorical predictors

<img src="d24_overfitting_files/figure-html/unnamed-chunk-9-1.png" width="80%" style="display: block; margin: auto;" />

---

## `from` and `sent_email`

.pull-left[
- `from`: Whether the message was listed as from anyone (this is usually set by default for regular outgoing email)

```r
train_data %>%
  count(spam, from)
```

```
## # A tibble: 3 x 3
##   spam  from      n
##   <fct> <fct> <int>
## 1 0     1      2847
## 2 1     0         3
## 3 1     1       286
```

]

.pull-right[
- `sent_email`: Indicator for whether the sender had been sent an email in the last 30 days

```r
train_data %>%
  count(spam, sent_email)
```

```
## # A tibble: 3 x 3
##   spam  sent_email     n
##   <fct> <fct>      <int>
## 1 0     0           1981
## 2 0     1            866
## 3 1     0            289
```

]
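---

## Aside: producing grouped summaries

The next few slides show the numerical predictors summarized separately for spam and not-spam emails. The deck doesn't show the call that produced that output; one way to get a summary in that shape (a sketch, assuming the skimr package is available) is:

```r
library(skimr)

train_data %>%
  select(-time) %>%   # drop the date-time column
  group_by(spam) %>%  # summarize within the spam and not-spam groups
  skim()
```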
---

## Numerical predictors

```
## 
## -- Variable type: numeric --------------------------------------------------------------------------
## # A tibble: 22 x 11
##    skim_variable spam  n_missing complete_rate    mean      sd    p0   p25   p50   p75  p100
##  * <chr>         <fct>     <int>         <dbl>   <dbl>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
##  1 cc            0             0             1 0.416    2.73   0     0     0     0       68
##  2 cc            1             0             1 0.439    3.37   0     0     0     0       50
##  3 image         0             0             1 0.0534   0.502  0     0     0     0       20
##  4 image         1             0             1 0.00346  0.0588 0     0     0     0        1
##  5 attach        0             0             1 0.128    0.764  0     0     0     0       21
##  6 attach        1             0             1 0.235    0.629  0     0     0     0        2
##  7 dollar        0             0             1 1.61     5.43   0     0     0     0       64
##  8 dollar        1             0             1 0.599    1.95   0     0     0     0       26
*##  9 inherit       0             0             1 0.0365   0.242  0     0     0     0        6
*## 10 inherit       1             0             1 0.0727   0.564  0     0     0     0        9
## 11 viagra        0             0             1 0        0      0     0     0     0        0
## 12 viagra        1             0             1 0.0277   0.471  0     0     0     0        8
## 13 password      0             0             1 0.124    1.08   0     0     0     0       28
## 14 password      1             0             1 0.0208   0.185  0     0     0     0        2
## 15 num_char      0             0             1 11.3     14.3   0.003 1.95  6.91  15.8   161.
## 16 num_char      1             0             1 4.90     12.1   0.001 0.499 1.09  3.37   128.
## 17 line_breaks   0             0             1 245.     313.   2     42    138   319   3522
## 18 line_breaks   1             0             1 94.2     239.   1     14    23    66    2908
## 19 exclaim_subj  0             0             1 0.0804   0.272  0     0     0     0        1
## 20 exclaim_subj  1             0             1 0.0865   0.282  0     0     0     0        1
## 21 exclaim_mess  0             0             1 6.10     42.2   0     0     1     5     1203
## 22 exclaim_mess  1             0             1 4.70     55.4   0     0     0     1      939
```

---

## Numerical predictors

.tiny[

```
## 
## -- Variable type: numeric --------------------------------------------------------------------------
## # A tibble: 22 x 11
##    skim_variable spam  n_missing complete_rate    mean      sd    p0   p25   p50   p75  p100
##  * <chr>         <fct>     <int>         <dbl>   <dbl>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
##  1 cc            0             0             1 0.416    2.73   0     0     0     0       68
##  2 cc            1             0             1 0.439    3.37   0     0     0     0       50
##  3 image         0             0             1 0.0534   0.502  0     0     0     0       20
##  4 image         1             0             1 0.00346  0.0588 0     0     0     0        1
##  5 attach        0             0             1 0.128    0.764  0     0     0     0       21
##  6 attach        1             0             1 0.235    0.629  0     0     0     0        2
##  7 dollar        0             0             1 1.61     5.43   0     0     0     0       64
##  8 dollar        1             0             1 0.599    1.95   0     0     0     0       26
*##  9 inherit       0             0             1 0.0365   0.242  0     0     0     0        6
*## 10 inherit       1             0             1 0.0727   0.564  0     0     0     0        9
## 11 viagra        0             0             1 0        0      0     0     0     0        0
## 12 viagra        1             0             1 0.0277   0.471  0     0     0     0        8
## 13 password      0             0             1 0.124    1.08   0     0     0     0       28
## 14 password      1             0             1 0.0208   0.185  0     0     0     0        2
## 15 num_char      0             0             1 11.3     14.3   0.003 1.95  6.91  15.8   161.
## 16 num_char      1             0             1 4.90     12.1   0.001 0.499 1.09  3.37   128.
## 17 line_breaks   0             0             1 245.     313.   2     42    138   319   3522
## 18 line_breaks   1             0             1 94.2     239.   1     14    23    66    2908
## 19 exclaim_subj  0             0             1 0.0804   0.272  0     0     0     0        1
## 20 exclaim_subj  1             0             1 0.0865   0.282  0     0     0     0        1
## 21 exclaim_mess  0             0             1 6.10     42.2   0     0     1     5     1203
## 22 exclaim_mess  1             0             1 4.70     55.4   0     0     0     1      939
```

]

---

## Numerical predictors

.small[

```
## 
## -- Variable type: numeric --------------------------------------------------------------------------
## # A tibble: 22 x 11
##    skim_variable spam  n_missing complete_rate    mean      sd    p0   p25   p50   p75  p100
##  * <chr>         <fct>     <int>         <dbl>   <dbl>   <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
##  1 cc            0             0             1 0.416    2.73   0     0     0     0       68
##  2 cc            1             0             1 0.439    3.37   0     0     0     0       50
##  3 image         0             0             1 0.0534   0.502  0     0     0     0       20
##  4 image         1             0             1 0.00346  0.0588 0     0     0     0        1
##  5 attach        0             0             1 0.128    0.764  0     0     0     0       21
##  6 attach        1             0             1 0.235    0.629  0     0     0     0        2
##  7 dollar        0             0             1 1.61     5.43   0     0     0     0       64
##  8 dollar        1             0             1 0.599    1.95   0     0     0     0       26
*##  9 inherit       0             0             1 0.0365   0.242  0     0     0     0        6
*## 10 inherit       1             0             1 0.0727   0.564  0     0     0     0        9
## 11 viagra        0             0             1 0        0      0     0     0     0        0
## 12 viagra        1             0             1 0.0277   0.471  0     0     0     0        8
## 13 password      0             0             1 0.124    1.08   0     0     0     0       28
## 14 password      1             0             1 0.0208   0.185  0     0     0     0        2
## 15 num_char      0             0             1 11.3     14.3   0.003 1.95  6.91  15.8   161.
## 16 num_char      1             0             1 4.90     12.1   0.001 0.499 1.09  3.37   128.
## 17 line_breaks   0             0             1 245.     313.   2     42    138   319   3522
## 18 line_breaks   1             0             1 94.2     239.   1     14    23    66    2908
## 19 exclaim_subj  0             0             1 0.0804   0.272  0     0     0     0        1
## 20 exclaim_subj  1             0             1 0.0865   0.282  0     0     0     0        1
## 21 exclaim_mess  0             0             1 6.10     42.2   0     0     1     5     1203
## 22 exclaim_mess  1             0             1 4.70     55.4   0     0     0     1      939
```

]

---

## Fit a model to the training dataset

```r
email_fit <- logistic_reg() %>%
  set_engine("glm") %>%
* fit(spam ~ . - from - sent_email - viagra,
*     data = train_data, family = "binomial") %>%
  suppressWarnings(.)
```

.small[

```r
email_fit
```

```
## parsnip model object
## 
## Fit time: 21ms
## 
## Call:  stats::glm(formula = spam ~ . - from - sent_email - viagra, family = stats::binomial, 
##     data = data)
## 
## Coefficients:
##  (Intercept)  to_multiple1            cc          time         image        attach        dollar  
##   -1.112e+02    -2.646e+00     2.594e-02     8.323e-08    -3.611e+00     5.879e-01    -8.206e-02  
##    winneryes       inherit      password      num_char   line_breaks       format1      re_subj1  
##    2.241e+00     3.761e-01    -7.070e-01     5.872e-02    -5.843e-03    -7.830e-01    -2.841e+00  
## exclaim_subj  urgent_subj1  exclaim_mess   numbersmall     numberbig  
##    1.867e-01     2.391e+00     1.067e-02    -6.830e-01     2.070e-01  
## 
## Degrees of Freedom: 3135 Total (i.e. Null);  3117 Residual
## Null Deviance:      1929 
## Residual Deviance: 1425   AIC: 1463
```

]
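---

## Aside: coefficients as odds ratios

Not part of the original deck: the coefficients above are on the log-odds scale, so exponentiating them can make them easier to talk about. A sketch, using `tidy()` on the parsnip fit as the earlier slides do:

```r
# Odds ratios: exp() of a log-odds coefficient gives the multiplicative
# change in the odds of spam for a one-unit increase in that predictor
tidy(email_fit) %>%
  mutate(odds_ratio = exp(estimate)) %>%
  arrange(desc(odds_ratio))
```

For example, `winneryes` is estimated at about 2.24, so an email containing "winner" has its odds of being spam multiplied by roughly exp(2.24), about 9, holding the other predictors constant.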
---

## Predict outcome on the testing dataset

```r
predict(email_fit, test_data)
```

```
## # A tibble: 785 x 1
##    .pred_class
##    <fct>
##  1 0
##  2 0
##  3 0
##  4 0
##  5 0
##  6 0
##  7 0
##  8 0
##  9 0
## 10 0
## # ... with 775 more rows
```

---

## Predict probabilities on the testing dataset

```r
email_pred <- predict(email_fit, test_data, type = "prob") %>%
  bind_cols(test_data %>% select(spam, time))

email_pred
```

```
## # A tibble: 785 x 4
##    .pred_0  .pred_1 spam  time
##      <dbl>    <dbl> <fct> <dttm>
##  1   0.927 0.0734   0     2012-01-01 04:09:49
##  2   0.875 0.125    0     2012-01-01 22:00:18
##  3   0.982 0.0179   0     2012-01-02 00:42:16
##  4   0.852 0.148    0     2012-01-01 20:58:14
##  5   0.998 0.00181  0     2012-01-02 02:07:22
##  6   0.999 0.000845 0     2012-01-02 13:09:45
##  7   0.999 0.00129  0     2012-01-02 10:12:51
##  8   0.999 0.000545 0     2012-01-02 16:24:21
##  9   0.984 0.0155   0     2012-01-03 04:34:50
## 10   0.994 0.00594  0     2012-01-03 08:33:28
## # ... with 775 more rows
```

---

### A closer look at predictions

.pull-left-wide[

```r
email_pred %>%
  arrange(desc(.pred_1)) %>%
  print(n = 10)
```

```
## # A tibble: 785 x 4
##    .pred_0 .pred_1 spam  time
##      <dbl>   <dbl> <fct> <dttm>
##  1   0.187   0.813 0     2012-01-08 09:28:42
##  2   0.225   0.775 1     2012-01-19 19:43:14
*##  3   0.235   0.765 1     2012-03-17 12:20:57
##  4   0.246   0.754 1     2012-03-31 05:20:24
##  5   0.254   0.746 1     2012-03-17 06:13:27
##  6   0.274   0.726 1     2012-02-13 07:15:00
*##  7   0.302   0.698 0     2012-03-17 10:41:28
##  8   0.323   0.677 1     2012-02-28 05:19:20
##  9   0.372   0.628 0     2012-01-28 14:41:43
## 10   0.387   0.613 0     2012-02-04 05:39:37
## # ... with 775 more rows
```

]

---

## Evaluate the performance

**Receiver operating characteristic (ROC) curve**<sup>+</sup>, which plots the true positive rate vs. the false positive rate (1 - specificity)

.medi.pull-left-narrow[

```r
email_pred %>%
  roc_curve(
    truth = spam,
    .pred_1,
    event_level = "second"
  ) %>%
  autoplot()
```

]

.pull-right-wide[

<img src="d24_overfitting_files/figure-html/unnamed-chunk-20-1.png" width="100%" style="display: block; margin: auto;" />

]

.footnote[
.small[
<sup>+</sup>Originally developed for operators of military radar receivers, hence the name.
]
]

---

## Evaluate the performance

Find the area under the curve:

.medi.pull-left-narrow[

```r
email_pred %>%
  roc_auc(
    truth = spam,
    .pred_1,
    event_level = "second"
  )
```

```
## # A tibble: 1 x 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 roc_auc binary         0.862
```

]

.pull-right-wide[

<img src="d24_overfitting_files/figure-html/unnamed-chunk-22-1.png" width="100%" style="display: block; margin: auto;" />

]

---

class: middle

# Wrapping Up...

<br>

Sources:

- Mine Çetinkaya-Rundel's Data Science in a Box ([link](https://datasciencebox.org/))
- Julia Fukuyama's EDA ([link](https://jfukuyama.github.io/))
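---

## Appendix: from probabilities to decisions

The opening slides point out that the model is only half the battle: we still need a cutoff for flagging email. A sketch of that last step (an addition to the deck, not something it covers; `conf_mat()`, `sens()`, and `spec()` are yardstick functions, and the 0.5 cutoff is only an illustration):

```r
# Flag an email as spam if its predicted probability exceeds 0.5
email_decisions <- email_pred %>%
  mutate(
    .pred_class = factor(if_else(.pred_1 > 0.5, "1", "0"), levels = c("0", "1"))
  )

# Confusion matrix and two summary metrics on the testing set
email_decisions %>% conf_mat(truth = spam, estimate = .pred_class)
email_decisions %>% sens(truth = spam, estimate = .pred_class, event_level = "second")
email_decisions %>% spec(truth = spam, estimate = .pred_class, event_level = "second")
```

Lowering the cutoff catches more spam (higher sensitivity) at the cost of flagging more legitimate email (lower specificity), which is exactly the trade-off the ROC curve traces out.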