| Function | Works |
|---|---|
tidypredict_fit(), tidypredict_sql(),
parse_model()
|
✔ |
tidypredict_to_column() |
✗ |
tidypredict_test() |
✗ |
tidypredict_interval(),
tidypredict_sql_interval()
|
✗ |
parsnip |
✔ |
How it works
Here is a simple randomForest() model using the
mtcars dataset:
library(dplyr)
library(tidypredict)
library(randomForest)
model <- randomForest(mpg ~ ., data = mtcars, ntree = 5, proximity = TRUE)Under the hood
The parser is based on the output from the
randomForest::getTree() function. It will return as many
decision paths as there are non-NA rows in the prediction
field.
getTree(model, labelVar = TRUE) %>%
head()
#> left daughter right daughter split var split point status prediction
#> 1 2 3 carb 1.500 -3 20.12813
#> 2 4 5 gear 3.500 -3 28.82222
#> 3 6 7 disp 221.700 -3 16.72609
#> 4 0 0 <NA> 0.000 -1 18.10000
#> 5 8 9 hp 65.500 -3 31.88571
#> 6 10 11 wt 3.295 -3 21.38333The output from parse_model() is transformed into a
dplyr, a.k.a Tidy Eval, formula. Each decision tree becomes
one dplyr::case_when() statement, which are then
combined.
tidypredict_fit(model)
#> (case_when(carb <= 1.5 ~ case_when(gear <= 3.5 ~ 18.1, .default = case_when(hp <=
#> 65.5 ~ 33.9, .default = case_when(qsec <= 19.185 ~ 27.3,
#> .default = 32.4))), .default = case_when(disp <= 221.7 ~
#> case_when(wt <= 3.295 ~ 21.82, .default = 19.2), .default = case_when(wt <=
#> 4.66 ~ case_when(qsec <= 16.96 ~ 14.94, .default = case_when(hp <=
#> 162.5 ~ 15.2, .default = case_when(carb <= 2.5 ~ 19.2, .default = 16.34))),
#> .default = 10.4))) + case_when(hp <= 95 ~ case_when(hp <=
#> 65.5 ~ 31.5666666666667, .default = case_when(wt <= 2.23 ~
#> 26.325, .default = 22.8)), .default = case_when(hp <= 190 ~
#> case_when(drat <= 3.385 ~ case_when(hp <= 130 ~ 19.2, .default = 17.025),
#> .default = case_when(gear <= 4.5 ~ case_when(vs <= 0.5 ~
#> 21, .default = 21.5), .default = 19.7)), .default = case_when(qsec <=
#> 17.62 ~ 14.325, .default = 10.4))) + case_when(disp <= 142.9 ~
#> case_when(qsec <= 18.755 ~ 22.8, .default = 28.74), .default = case_when(wt <=
#> 3.49 ~ case_when(disp <= 163.8 ~ 20.1333333333333, .default = case_when(gear <=
#> 3.5 ~ 19.2, .default = 17.4)), .default = case_when(qsec <=
#> 17.71 ~ case_when(hp <= 212.5 ~ case_when(drat <= 2.915 ~
#> 15.5, .default = 18.25), .default = case_when(carb <= 6 ~
#> 13.7, .default = 15)), .default = 10.4))) + case_when(am <=
#> 0.5 ~ case_when(cyl <= 5 ~ 22.55, .default = case_when(hp <=
#> 192.5 ~ case_when(hp <= 136.5 ~ 18.5, .default = case_when(drat <=
#> 3.11 ~ 17.075, .default = 17.3)), .default = 13.84)), .default = case_when(wt <=
#> 2.2775 ~ 29.78, .default = 18.42)) + case_when(hp <= 116.5 ~
#> case_when(cyl <= 5 ~ case_when(drat <= 4.165 ~ case_when(hp <=
#> 77.5 ~ 24.4, .default = case_when(disp <= 114.05 ~ 22.8,
#> .default = 21.7)), .default = 32.15), .default = 21.1),
#> .default = case_when(qsec <= 18.14 ~ case_when(hp <= 192.5 ~
#> case_when(disp <= 337.9 ~ 16.76, .default = 19.2), .default = case_when(drat <=
#> 3.105 ~ 10.4, .default = 13.975)), .default = 18.2666666666667)))/5From there, the Tidy Eval formula can be used anywhere where it can
be operated. tidypredict provides three paths:
- Use directly inside
dplyr,mutate(iris, !! tidypredict_fit(model)) - Use
tidypredict_to_column(model)to a piped command set - Use
tidypredict_to_sql(model)to retrieve the SQL statement
parsnip
tidypredict also supports randomForest
model objects fitted via the parsnip package.
library(parsnip)
parsnip_model <- rand_forest(mode = "regression", trees = 5) %>%
set_engine("randomForest") %>%
fit(mpg ~ ., data = mtcars)
tidypredict_fit(parsnip_model)
#> (case_when(disp <= 107.6 ~ case_when(wt <= 2.0675 ~ 30.5, .default = 32.4),
#> .default = case_when(wt <= 3.16 ~ case_when(drat <= 4.175 ~
#> case_when(hp <= 96 ~ 22.8, .default = 20.925), .default = 26),
#> .default = case_when(cyl <= 7 ~ case_when(qsec <= 18.6 ~
#> 19.2, .default = 17.9), .default = case_when(hp <=
#> 177.5 ~ case_when(wt <= 3.4375 ~ 15.2, .default = case_when(drat <=
#> 2.92 ~ 15.5, .default = 18.8)), .default = case_when(drat <=
#> 3.14 ~ 16.55, .default = 15.05))))) + case_when(hp <=
#> 79.5 ~ 29.8, .default = case_when(cyl <= 7 ~ case_when(disp <=
#> 142.9 ~ 22.125, .default = case_when(hp <= 116.5 ~ 21.1333333333333,
#> .default = case_when(qsec <= 18.6 ~ 19.5, .default = 17.8))),
#> .default = case_when(disp <= 456 ~ case_when(drat <= 3.115 ~
#> case_when(drat <= 3.075 ~ 15.725, .default = 19.2), .default = 15.34),
#> .default = 10.4))) + case_when(drat <= 3.75 ~ case_when(wt <=
#> 3.3125 ~ 22.05, .default = case_when(drat <= 3.035 ~ 10.4,
#> .default = case_when(carb <= 2.5 ~ 16.3666666666667, .default = case_when(hp <=
#> 290 ~ case_when(disp <= 395 ~ 13.68, .default = 14.7),
#> .default = 15)))), .default = case_when(wt <= 2.23 ~
#> 29.6, .default = case_when(gear <= 4.5 ~ case_when(vs <=
#> 0.5 ~ 21, .default = case_when(qsec <= 0 ~ 24.4, .default = 22.8)),
#> .default = 15.8))) + case_when(cyl <= 5 ~ case_when(wt <=
#> 2.26 ~ case_when(qsec <= 16.8 ~ 26, .default = 31.5), .default = 22.56),
#> .default = case_when(disp <= 266.9 ~ case_when(qsec <= 19.83 ~
#> case_when(hp <= 116.5 ~ 21.08, .default = 19.2), .default = 18.1),
#> .default = case_when(qsec <= 17.71 ~ case_when(drat <=
#> 3.19 ~ case_when(drat <= 3.075 ~ 16.4, .default = 18.95),
#> .default = 15.25), .default = 12.8))) + case_when(hp <=
#> 80.5 ~ 29.9, .default = case_when(carb <= 2.5 ~ case_when(gear <=
#> 4.5 ~ case_when(disp <= 182.9 ~ 21.7, .default = case_when(wt <=
#> 3.3275 ~ 21.4, .default = case_when(cyl <= 7 ~ 18.1, .default = case_when(disp <=
#> 380 ~ 18.7, .default = 19.2)))), .default = 30.4), .default = case_when(hp <=
#> 192.5 ~ 15.725, .default = case_when(hp <= 222.5 ~ 10.4,
#> .default = case_when(drat <= 3.88 ~ 14.925, .default = 15.8))))))/5