In the last lecture, we focused on data wrangling (import, tidy, transform, visualize). Now we progress to the modeling part. In this lecture, we focus on predictive modeling using machine learning methods (logistic, random forest, neural network, …). In the next lecture, we focus on policy evaluation using double machine learning.
Supervised vs unsupervised learning, logistic regression, ROC curve, overfitting, regularization, L1 and L2 penalties, elastic net, cross-validation, hyperparameter tuning, model evaluation, and interpretation.
3 Machine learning overview
3.1 Supervised vs unsupervised learning
Supervised learning: input(s) -> output.
Prediction (or regression): the output is continuous (income, weight, BMI, …).
Classification: the output is categorical (disease or not, pattern recognition, …).
Unsupervised learning: no output. We learn relationships and structure in the data.
Clustering.
Dimension reduction.
Embedding.
In modern applications, the line between supervised and unsupervised learning is blurred.
Matrix completion: the Netflix problem. Both supervised and unsupervised techniques are used.
Large language models (LLMs) combine supervised learning and reinforcement learning.
4 Logistic regression
We load the Food Security Supplement household data we curated earlier. Our goal is to predict food insecurity status using the household's socio-economic status.
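A minimal sketch of loading the curated data, assuming the tibble from the wrangling lecture was saved as fss_household.rds (a hypothetical file name):
library(tidyverse)

# Load the curated Food Security Supplement household data
# ("fss_household.rds" is an assumed file name; use the path from the wrangling lecture)
data_clean <- read_rds("fss_household.rds")
data_clean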
# A tibble: 30,162 × 13
HRFS12M1 HRNUMHOU HRHTYPE GEREG PRCHLD HRPOOR PRTAGE PEEDUCA PEMLR PTDTRACE
<fct> <dbl> <fct> <fct> <fct> <fct> <dbl> <fct> <fct> <fct>
1 Food Secu… 1 Indivi… West NoChi… NotPo… 35 HighSc… Empl… AIAN
2 Food Secu… 1 Indivi… South NoChi… NotPo… 36 Colleg… Empl… White
3 Food Secu… 3 Marrie… South NoChi… NotPo… 55 HighSc… Empl… White
4 Food Secu… 2 Marrie… South NoChi… NotPo… 85 Colleg… NotI… White
5 Food Secu… 2 Marrie… West NoChi… Poor 69 HighSc… NotI… AIAN
6 Food Secu… 1 Indivi… Nort… NoChi… NotPo… 51 HighSc… Empl… White
7 Low Food … 2 Unmarr… Midw… Child… Poor 54 HighSc… NotI… White
8 Food Secu… 3 Marrie… South Child… NotPo… 46 Colleg… Empl… White
9 Food Secu… 2 Marrie… Midw… NoChi… NotPo… 69 HighSc… NotI… White
10 Food Secu… 2 Marrie… South NoChi… NotPo… 75 HighSc… NotI… Black
# ℹ 30,152 more rows
# ℹ 3 more variables: HHSUPWGT <dbl>, PEHSPNON <fct>, HRFS12M1_binary <fct>
4.1 Why not linear regression?
We are interested in predicting whether a household will be food insecure on the basis of household size, marital status, whether the household has children, geographical region, education level, employment status, and race.
The response HRFS12M1_binary falls into one of two categories, Food Insecure (1) or Food Secure (0). Rather than modeling this response \(Y\) directly, logistic regression models the probability that \(Y\) belongs to a particular category.
The parameter \(p_i = \mathbb{E}(Y_i)\) will be related to the predictors \(\mathbf{x}_i\) via
\[
p_i = \frac{e^{\eta_i}}{1 + e^{\eta_i}},
\] where \(\eta_i\) is the linear predictor (or systematic component)
\[
\eta_i = \mathbf{x}_i^T \boldsymbol{\beta} = \beta_0 + \beta_1 x_{i1} + \beta_2 x_{i2} + \dots + \beta_q x_{iq}.
\] In other words, logistic regression models the log-odds of the probability of success as a linear function of the predictors:
\[
\log \frac{p_i}{1 - p_i} = \eta_i = \mathbf{x}_i^T \boldsymbol{\beta}.
\]
Therefore \(\beta_1\) can be interpreted as follows: a unit increase in \(x_1\), with the other predictors held fixed, increases the log-odds of success by \(\beta_1\), or equivalently multiplies the odds of success by a factor of \(e^{\beta_1}\).
4.2 Logistic regression on food security data
To further investigate the factors that are associated with food insecurity, we can use logistic regression to model the probability of a household being in low food security.
# Fit logistic regression
logit_model <- glm(
  # All predictors except HRFS12M1 and HHSUPWGT
  HRFS12M1_binary ~ . - HRFS12M1 - HHSUPWGT,
  data = data_clean,
  family = "binomial"
)
logit_model
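To connect the fitted coefficients with the odds-ratio interpretation above, we can exponentiate them; a minimal sketch using the logit_model object just created:
# Exponentiated coefficients are odds ratios: each value multiplies the odds of
# food insecurity for a one-unit increase in that predictor, others held fixed
exp(coef(logit_model)) |> round(2)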
If we want to use this model for prediction, we need to evaluate its performance. There are many metrics related to classification models, such as accuracy, precision, recall, sensitivity, specificity, …
If we set the threshold to 0.5, the accuracy is 0.90, noticeably higher than with a threshold of 0.1. However, the sensitivity is only 0.02, meaning the model captures only 2% of the households in low food security. In contrast, if we set the threshold to 0.1, the sensitivity rises to 0.71, so the model captures 71% of the households in low food security, but the specificity decreases to 0.74. This trade-off is common in classification models.
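We can check this trade-off directly; a minimal sketch in base R plus tibble, assuming the positive class of HRFS12M1_binary is labeled "Food Insecure" (the label is an assumption about the factor coding):
# Predicted probabilities from the fitted logistic regression
prob <- predict(logit_model, type = "response")
# TRUE if the household is in low food security (assumed factor label)
truth <- data_clean$HRFS12M1_binary == "Food Insecure"

# Accuracy, sensitivity, and specificity at a given classification threshold
threshold_metrics <- function(threshold) {
  pred <- prob >= threshold
  tibble(
    threshold   = threshold,
    accuracy    = mean(pred == truth),
    sensitivity = sum(pred & truth) / sum(truth),
    specificity = sum(!pred & !truth) / sum(!truth)
  )
}

bind_rows(threshold_metrics(0.5), threshold_metrics(0.1))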
Therefore, we need a metric that can evaluate the model’s performance under different thresholds. The ROC curve is a good choice. The ROC curve is a popular graphic for simultaneously displaying the two types of errors for all possible thresholds. The name “ROC” is historic, and comes from communications theory. It is an acronym for receiver operating characteristics.
The overall performance of a classifier, summarized over all possible thresholds, is given by the area under the ROC curve (AUC). An ideal ROC curve will hug the top left corner, so the larger the AUC, the better the classifier. We expect a classifier that performs no better than chance to have an AUC of 0.5.
There is another, related plot called the precision-recall (PR) curve, which puts recall on the x-axis and precision on the y-axis. A classifier whose ROC curve dominates another's will also dominate it on the PR curve, but a higher ROC AUC does not guarantee a higher PR AUC; the PR curve is particularly informative when the classes are imbalanced.
# roc(), auc(), and ggroc() are provided by the pROC package
data_clean <- data_clean |>
  mutate(prob = predict(logit_model, type = "response"))
roc_data <- roc(data_clean$HRFS12M1_binary, data_clean$prob)
ggroc(roc_data, legacy.axes = TRUE) +
  labs(
    title = "ROC Curve for Logistic Regression Model",
    x = "1 - Specificity",
    y = "Sensitivity"
  ) +
  theme_minimal() +
  annotate(
    "text", x = 0.5, y = 0.5,
    label = paste("AUC =", round(auc(roc_data), 3))
  )
The logistic regression model has an AUC of 0.794, which indicates reasonably good discrimination. However, this AUC is computed on the same data used to fit the model, so simply fitting models and calculating an in-sample AUC is not enough to evaluate predictive performance.
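For an imbalanced outcome such as food insecurity, the PR curve mentioned above is also worth inspecting; a minimal sketch using the yardstick package, where event_level = "second" is an assumption about which level of HRFS12M1_binary codes food insecurity:
library(yardstick)

# Precision-recall curve from the predicted probabilities stored in data_clean
pr_data <- data_clean |>
  pr_curve(truth = HRFS12M1_binary, prob, event_level = "second")
autoplot(pr_data)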
5 Assessing model accuracy
5.1 Measuring the quality of fit
In order to evaluate the performance of a statistical learning method on a given data set, we need some way to measure how well its predictions actually match the observed data. That is, we need to quantify the extent to which the predicted response value for a given observation is close to the true response value for that observation. In the regression setting, the most commonly-used measure is the mean squared error (MSE), given by
\[
\text{MSE} = \frac{1}{n} \sum_{i=1}^{n} (y_i - \hat{f}(x_i))^2.
\] The MSE will be small if the predicted responses are very close to the true responses, and will be large if for some of the observations, the predicted and true responses differ substantially.
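As a quick illustration of the formula (in a generic regression setting, since the food security outcome is categorical), the MSE is one line of R:
# Mean squared error of predictions yhat against observed responses y
mse <- function(y, yhat) mean((y - yhat)^2)

# Toy example
mse(y = c(1, 2, 3), yhat = c(1.1, 1.9, 3.2))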
The MSE above is computed using the training data that was used to fit the model. But in general, we do not really care how well the method works on the training data. Rather, we are interested in the accuracy of the predictions we obtain when we apply our method to previously unseen test data. We would like to select the model for which the test MSE is as small as possible.
As model flexibility increases, training MSE will decrease, but the test MSE may not. When a given method yields a small training MSE but a large test MSE, we are said to be overfitting the data.
Tip
Does it mean simpler models are always better?
No Free Lunch Theorem [David Wolpert, William Macready]: Any two optimization algorithms are equivalent when their performance is averaged across all possible problems.
The black points represent the training data. There are two models, A and B, where model B is more flexible than model A.
The white points represent the test data. In the left panel, model A has a smaller test MSE than model B; in the right panel, model B has a smaller test MSE than model A. Therefore, we cannot say that simpler models are always better.
In practice, one can usually compute the training MSE with relative ease, but estimating the test MSE is considerably more difficult because usually no test data are available. The flexibility level corresponding to the minimal test MSE can vary considerably among data sets. One important approach is cross-validation, a general method for estimating the test MSE using the training data.
5.2 Cross-validation
K-fold cross-validation randomly divides the set of observations into k groups, or folds, of approximately equal size. The first fold is treated as a validation set, and the method is fit on the remaining k − 1 folds. The mean squared error, \(\text{MSE}_1\), is then computed on the observations in the held-out fold. This procedure is repeated k times; each time, a different group of observations is treated as the validation set. This process results in k estimates of the test error, \(\text{MSE}_1\), \(\text{MSE}_2\), …, \(\text{MSE}_k\). The k-fold CV estimate is computed by averaging these values,
\[
\text{CV}_{(k)} = \frac{1}{k} \sum_{i=1}^{k} \text{MSE}_i.
\]
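A minimal sketch of constructing the folds with the rsample package (part of tidymodels); the choice of 5 folds and stratification by the outcome are illustrative:
library(rsample)

set.seed(2024)
# 5-fold CV, stratified so each fold has a similar share of food-insecure households
folds <- vfold_cv(data_clean, v = 5, strata = HRFS12M1_binary)
folds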
Feature engineering: coding qualitative predictors, transformation of predictors (e.g., log), extracting key features from raw variables (e.g., getting the day of the week out of a date variable), interaction terms, … (recipes package);
7 Elastic-net (enet) regularization and shrinkage methods
Subset selection methods such as best subset selection, forward stepwise selection, and backward stepwise selection have limitations: they are computationally expensive and can lead to overfitting. Shrinkage methods are an alternative approach to subset selection. We fit a model containing all p predictors using a technique that constrains or regularizes the coefficient estimates, or equivalently, shrinks the coefficient estimates towards zero. It may not be immediately obvious why such a constraint should improve the fit, but it turns out that shrinking the coefficient estimates can significantly reduce their variance. The two best-known techniques for shrinking the regression coefficients towards zero are ridge regression and the lasso.
In logistic regression, for ridge regression (\(L_2\) penalty), we need to optimize the following objective function: \[
\ell^*(\boldsymbol\beta) = \ell(\boldsymbol\beta) - \lambda \sum_{j=1}^{p} \beta_j^2,
\] where the penalty term is \(\lambda \sum_{j=1}^{p} \beta_j^2\).
For the lasso (\(L_1\) penalty), we need to optimize the following objective function: \[
\ell^*(\boldsymbol\beta) = \ell(\boldsymbol\beta) - \lambda \sum_{j=1}^{p} |\beta_j|,
\] where the penalty term is \(\lambda \sum_{j=1}^{p} |\beta_j|\).
The elastic net combines the ridge and lasso penalties, and the penalty term is \(\lambda \left( \alpha \sum_{j=1}^{p} |\beta_j| + \frac{1 - \alpha}{2} \sum_{j=1}^{p} \beta_j^2 \right)\), where \(\alpha\) controls the relative weight of the two penalties.
Implementing ridge regression and the lasso requires a method for selecting a value for the tuning parameter \(\lambda\) (and \(\alpha\) if we use the elastic net). Cross-validation provides a simple way to tackle this problem. We choose a grid of tuning parameter values and compute the cross-validation error for each. We then select the tuning parameter values for which the cross-validation error is smallest. Finally, the model is re-fit using all of the available observations and the selected values of the tuning parameters.
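A minimal sketch of this tuning procedure using the glmnet package directly (the tidymodels workflow below handles the same preprocessing through a recipe); alpha = 1 gives the lasso, alpha = 0 gives ridge regression, and intermediate values give the elastic net:
library(glmnet)

# Dummy-coded predictor matrix and binary response
x <- model.matrix(
  HRFS12M1_binary ~ . - HRFS12M1 - HHSUPWGT,
  data = data_clean
)[, -1]
y <- data_clean$HRFS12M1_binary

set.seed(2024)
# 10-fold cross-validation over a grid of lambda values for the lasso (alpha = 1)
cv_fit <- cv.glmnet(x, y, family = "binomial", alpha = 1)
cv_fit$lambda.min  # lambda minimizing the cross-validated binomial deviance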
We randomly split the data into 25% test data and 75% non-test data, stratifying on food security status.
# For reproducibility
set.seed(2024)
data_split <- data_clean |>
  initial_split(
    # stratify by HRFS12M1_binary
    strata = "HRFS12M1_binary",
    prop = 0.75
  )
data_split
<Training/Testing/Total>
<22621/7541/30162>
data_other <- training(data_split)
dim(data_other)
[1] 22621 14
data_test <- testing(data_split)
dim(data_test)
[1] 7541 14
8.2 Recipe
Recipe for preprocessing the data:
recipe <- recipe(
  HRFS12M1_binary ~ .,
  data = data_other
) |>
  # remove the weights and original HRFS12M1
  step_rm(HHSUPWGT, HRFS12M1) |>
  # create dummy variables for categorical predictors
  step_dummy(all_nominal_predictors()) |>
  # zero-variance filter
  step_zv(all_numeric_predictors()) |>
  # center and scale numeric data
  step_normalize(all_numeric_predictors()) |>
  # estimate the means and standard deviations
  print()
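In a tidymodels workflow, this recipe would typically be paired with a model specification; a minimal sketch, assuming an elastic-net logistic regression fit with the glmnet engine (the object names logit_mod and logit_wf are illustrative):
library(tidymodels)

# Elastic-net logistic regression with tunable penalty (lambda) and mixture (alpha)
logit_mod <- logistic_reg(penalty = tune(), mixture = tune()) |>
  set_engine("glmnet")

# Bundle the preprocessing recipe and the model into a single workflow
logit_wf <- workflow() |>
  add_recipe(recipe) |>
  add_model(logit_mod)
logit_wf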