Decision trees are a classic “nonparametric” technique for making predictions. Let’s build one quickly for the relationship between pubs and waiting times.
library(tidyverse)
library(knitr)
library(caret)
times <- read_csv('../data/waiting_times.csv')
times %>% head(5) %>% kable()
| waiting  | pub          | pub_number | long_wait | drinktype |
|----------|--------------|------------|-----------|-----------|
| 5.896570 | The Berkeley | 1          | TRUE      | Beer      |
| 8.822136 | The Berkeley | 1          | TRUE      | Beer      |
| 4.521891 | The Berkeley | 1          | TRUE      | Beer      |
| 6.656845 | The Berkeley | 1          | TRUE      | Cocktail  |
| 1.005883 | The Berkeley | 1          | TRUE      | Cocktail  |
Previously, we fit this very simple regression model:
wait_model <- lm(waiting ~ 0 + pub, data=times)
Which I wrote to you as: \[y_i = \begin{cases} \alpha_b & \text{if at the Berkeley} \\ \alpha_g & \text{if at the Gallimaufry} \\ \alpha_m & \text{if at the Milk Thistle} \end{cases} + e_i\] And which, as you saw, gives you the same predictions as:
times %>% group_by(pub) %>% summarize(mean(waiting))
## # A tibble: 3 x 2
## pub `mean(waiting)`
## <chr> <dbl>
## 1 The Berkeley 4.70
## 2 The Gallimaufry 1.40
## 3 The Milk Thistle 6.27
And corresponds to predicting the means of the following sets of observations:
ggplot(times, aes(x=pub, y=waiting, color=pub)) +
geom_point()
Another way we can express this is as a tree:
library(rpart)
library(rpart.plot)
dtree <- rpart(waiting ~ pub, data=times,
control=list(minsplit=1, # override the default minimum split size
minbucket=1 # override the default minimum leaf size
)
)
prp(dtree)
You read decision trees top to bottom. The top part, called the root, indicates the “first split” that is most useful in describing the data. Then, reading downwards, each “split” represents one point at which the data is divided in order to make a prediction. You can see that if the pub is the Gallimaufry, we immediately split and predict a value of 1.4 for the waiting time. This is because the Galli has quite a different distribution from the other two pubs; its variance is much lower, and its waiting time is also quite low. Then, the tree splits on whether the pub is the Berkeley, predicting 4.7 if so and 6.3 if not. These, again, are the means of those groups of observations.
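As a quick sanity check (a minimal sketch, assuming `times` and `dtree` from above are still in memory), the tree’s predictions should match exactly the group means we computed with `group_by()`:

# each pub's tree prediction should match its group mean
times %>%
  mutate(tree_prediction = predict(dtree, newdata=times)) %>%
  group_by(pub) %>%
  summarize(group_mean = mean(waiting),
            tree_prediction = first(tree_prediction))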
So, decision trees are a way to “slice and dice” the data in order to make predictions about progressively smaller subsets of the data. They are intrinsically nonlinear, in that they generally don’t know anything about a functional form for the outcome variable, and they are usually very nonlinear in their predictions, since they can change their predictions rapidly based on very simple rules.
Another, more complicated, example: house prices over time:
library(sf)
## Linking to GEOS 3.9.1, GDAL 3.2.1, PROJ 7.2.1
weca = read_sf("../data/weca.gpkg") %>%
st_drop_geometry() %>%
pivot_longer(price_dec_1995:price_dec_2018,
names_to=c(NA, 'quarter', 'year'),
values_to='price', names_sep='_')
dtree_weca <- rpart(price ~ year + quarter + la_name,
data=weca %>% mutate(price = price/1000) # make the plot easier to read
)
prp(dtree_weca, faclen=20)
In this tree, the model branches immediately on whether or not we’re predicting after 2003. If we’re predicting before 2003, we then branch on whether we’re before 2001. If so, we predict a median house price of £71,000. If not, the prediction is £126,000. If instead we’re predicting after 2003, we split on whether the house was sold between 2004 and 2014. If yes, we enter the middle branch, and if not (it was sold after 2014), we enter the right branch. Both of those then split on whether the house is in Bath or in either Bristol or South Glos. If it’s in Bristol or South Glos and sold from 2004 to 2014, the prediction is £185,000. But if it’s in Bath during that time, the tree predicts £241,000. Alternatively, if it’s after 2014 and in Bath, we predict £329,000. This strategy works for both regression problems (where we predict the mean of a given group within the split) and classification problems (where we predict a single class for the whole split).
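For the classification case, the call is nearly identical; here is a minimal sketch using the `long_wait` flag from the pub data (assuming `times` is still loaded), where each leaf predicts a class rather than a mean:

# a classification tree: predict whether the wait will be "long" from the pub alone
class_tree <- rpart(factor(long_wait) ~ pub, data=times, method="class")
prp(class_tree)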
The actual algorithms used to build decision trees vary from implementation to implementation. Often, they look for “binary recursive splits”. That is, they look for a single split on a feature upon which to divide the dataset in two.1 As an aside, the information about the sales quarter does not significantly help the model here because quarter is a “seasonal” variate: it helps you make predictions around the yearly mean, but it won’t be useful by itself to predict house prices; all “March” houses won’t have similar prices in absolute terms, only prices that are similar relative to their year’s mean. Decision trees generally don’t like features like this. A split is “good” if the variance of the two halves is substantially smaller than the variance of the data overall. Once a split is made, another split is identified for each of the two halves of the data, ad nauseam.
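To make the “good split” idea concrete, here is a small sketch (not the actual `rpart` implementation) that scores a candidate split of the waiting times by how much it reduces the total spread of the data:

# score a candidate split: total sum of squares minus the within-group
# sum of squares after splitting; bigger scores mean better splits
split_score <- function(y, in_left) {
  ss <- function(v) sum((v - mean(v))^2)
  ss(y) - (ss(y[in_left]) + ss(y[!in_left]))
}
# the pub tree above split off the Gallimaufry first; this score shows why
split_score(times$waiting, times$pub == "The Gallimaufry")
split_score(times$waiting, times$pub == "The Berkeley")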
As you can tell, this procedure may make the model a bit too sensitive to the input data. The precision with which the data can be sliced and diced may make for predictions with very low prediction bias, sure.2 Remember: low bias means that the predictions are, on average, close to reality, rather than systematically above (or below) the correct value. But the prediction variance can be immense. For example, in the `weca` tree above, just changing the year from 2000 to 2001 nearly doubles the predicted house price!
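You can see this directly by asking `dtree_weca` for predictions on two observations that differ only in the year of sale (a sketch, reusing a row of `weca`; the predictions are in £1000s because of the `mutate()` above):

# take one real sale from 2000, then pretend it happened in 2001 instead;
# per the tree above, the prediction should jump from about 71 to about 126 (in £1000s)
example_sale <- weca %>% filter(year == "2000") %>% slice(1)
predict(dtree_weca, newdata=example_sale)
predict(dtree_weca, newdata=example_sale %>% mutate(year = "2001"))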
So, random forests provide a way to collect together decision trees to reduce this prediction variance, while keeping the bias acceptably low. Basically, a random forest works by fitting many trees3 Hence “forest,” a good name for a collection of trees 😏 and giving each tree a slightly different view of the data. They do this by randomizing the features that the decision trees “see.” For example, let’s think about a random forest trained to predict song genre using danceability, tempo, acousticness, instrumentalness, energy, and valence.
songs = read_csv("../data/midterm-songs.csv")
## Rows: 32833 Columns: 23
## -- Column specification --------------------------------------------------------
## Delimiter: ","
## chr (10): track_id, track_name, track_artist, track_album_id, track_album_na...
## dbl (13): track_popularity, danceability, energy, key, loudness, mode, speec...
##
## i Use `spec()` to retrieve the full column specification for this data.
## i Specify the column types or set `show_col_types = FALSE` to quiet this message.
dance_rf = train(playlist_genre ~ danceability + tempo + acousticness + instrumentalness + energy + valence,
data=songs,
method='ranger', # a faster random forest than method='rf'
importance='impurity', # save measures of how important each feature is
trControl= trainControl(method='repeatedcv', repeats=1, number=5))
In this forest, each decision tree “sees” a different set of variables. Some trees may look at danceability, tempo, and energy (like we’ve seen before), but others may look at valence, instrumentalness, and acousticness. By ensuring that the trees see many different perspectives on the data, we can ensure that they are relatively independent. That is, they will not all use the same splits, nor even see the same basic picture of the dataset.4 This idea, that a collection of independent but “weak” prediction algorithms (like decision trees) can perform learning tasks well, is a central idea of ensemble learning, of which random forests are a central component. In addition, another strategy that is used by default in random forest models is called “Bootstrap Aggregation,” or “bagging” for short. This is discussed in ISL 8.2.1, and (at its core) does to the rows of the training data what the random forest does to the columns of the dataframe.
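The number of features each tree considers at a split is itself a tuning parameter (ranger calls it `mtry`). `caret` tunes it for you by default, but you can also set the grid yourself; a sketch, where the specific grid values are just illustrative:

# try forests where each split considers 2, 4, or all 6 of the features
dance_rf_mtry = train(playlist_genre ~ danceability + tempo + acousticness + instrumentalness + energy + valence,
                      data=songs,
                      method='ranger',
                      importance='impurity',
                      tuneGrid=expand.grid(mtry=c(2, 4, 6),
                                           splitrule='gini',   # ranger's default split criterion for classification
                                           min.node.size=1),
                      trControl=trainControl(method='repeatedcv', repeats=1, number=5))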
One very nice trait of random forests is that they allow you to get a measurement of the importance of a variable for the prediction. This is kind of like an effect size (from a regression framework). Usually, the “best” feature is rated 100. For any model fit in the `caret` framework, you can get feature importance using the `varImp()` function (short for variable importance):
varImp(dance_rf)
## ranger variable importance
##
## Overall
## tempo 100.00
## danceability 75.82
## energy 51.12
## valence 47.44
## acousticness 43.33
## instrumentalness 0.00
Thus, you can see that the `tempo` variable is the most useful in predicting genre. The next most important feature is `danceability`. Energy, valence, and acousticness all hang around the same value. And, finally, we see that `instrumentalness` is not very useful at all in predicting the genre of a song’s playlist. A classic visualization of this data uses a lollipop plot:
varImp(dance_rf) %>% plot()
It’s also useful to note that, unlike KNN methods, decision trees are generally not sensitive to re-scaling of the data. This is because the trees find cut-points that reduce the variability of the data. Whether a cut-point is shifted left or right (as is done in centering) or stretched relative to zero (as is done in re-scaling) will not matter for finding the cut point.6 Show this for yourself using the `preProcess=` argument for `train()`.
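For instance, a sketch of that check: re-fit the same forest on centered and re-scaled features via `preProcess=` and compare the cross-validated accuracy to the forest above; it should be essentially unchanged.

# the same forest, but with centered and re-scaled features
dance_rf_scaled = train(playlist_genre ~ danceability + tempo + acousticness + instrumentalness + energy + valence,
                        data=songs,
                        method='ranger',
                        importance='impurity',
                        preProcess=c('center', 'scale'),
                        trControl=trainControl(method='repeatedcv', repeats=1, number=5))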
Using the `earnings.csv` data (which describes yearly earnings in thousands of dollars as a function of demographic variables), fit a single decision tree (using `rpart`) to predict earnings using all of the variables in the data except `age_band`.7 Note, you may need to drop NA values…
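If you want a starting point, a hedged sketch might look like the following (the file path and the name of the earnings column are assumptions; check `names()` on the data first):

# read the earnings data, drop rows with missing values, and fit one tree
# (assumes the outcome column is literally called "earnings")
earn <- read_csv('../data/earnings.csv') %>% drop_na()
earn_tree <- rpart(earnings ~ . - age_band, data=earn)
prp(earn_tree)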
Remember: you can use `predict(model, new_data)` to make predictions about new data. And, you can set up a range of values from `a` to `b`, where each value is separated by `step`, using `seq(a, b, step)`.
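A minimal sketch of that pattern, reusing the pub tree from earlier so that it runs as-is (for the earnings exercise, swap in your own tree, columns, and range):

# seq(a, b, step): evenly spaced values from a to b
seq(18, 65, 5)
# one new observation per pub, one prediction each
predict(dtree, data.frame(pub = c("The Berkeley", "The Gallimaufry", "The Milk Thistle")))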
Think hard about what the MSE measures in part 1, versus what the MSE shown by `caret` measures.