| Function | Works |
|---|---|
| `tidypredict_fit()`, `tidypredict_sql()`, `parse_model()` | ✔ |
| `tidypredict_to_column()` | ✗ |
| `tidypredict_test()` | ✗ |
| `tidypredict_interval()`, `tidypredict_sql_interval()` | ✗ |
| `parsnip` | ✔ |
How it works
Here is a simple `ranger()` model using the `iris` dataset:
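A minimal sketch of such a fit is below. The seed and `num.trees = 100` are assumptions for illustration, so the trees it grows need not match the output shown in this article.

```r
library(ranger)

# Assumed settings for illustration: a fixed seed and 100 trees.
set.seed(100)
model <- ranger(Species ~ ., data = iris, num.trees = 100)

# Species is a factor, so ranger grows a classification forest.
model$treetype
```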
Under the hood
The parser is based on the output from the `ranger::treeInfo()` function. It returns as many decision paths as there are non-NA rows in the `prediction` column, that is, one path per terminal node.
treeInfo(model) %>%
head()
#> nodeID leftChild rightChild splitvarID splitvarName splitval terminal
#> 1 0 1 2 3 Petal.Width 1.75 FALSE
#> 2 1 3 4 2 Petal.Length 2.45 FALSE
#> 3 2 5 6 2 Petal.Length 4.85 FALSE
#> 4 3 NA NA NA <NA> NA TRUE
#> 5 4 7 8 2 Petal.Length 5.40 FALSE
#> 6 5 NA NA NA <NA> NA TRUE
#> prediction
#> 1 <NA>
#> 2 <NA>
#> 3 <NA>
#> 4 setosa
#> 5 <NA>
#> 6 virginica
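The path count can be checked directly on one tree. This sketch assumes a small 10-tree forest and a fixed seed chosen only for illustration; it relies on `treeInfo()` leaving `prediction` as `NA` on split nodes, as in the output above.

```r
library(ranger)

# Assumed model for illustration; any classification forest works.
set.seed(100)
model <- ranger(Species ~ ., data = iris, num.trees = 10)

# treeInfo() returns one row per node of the requested tree.
info <- treeInfo(model, tree = 1)

# Split nodes carry NA in `prediction`; terminal nodes carry a class,
# so each non-NA row marks the end of one decision path.
n_paths <- sum(!is.na(info$prediction))
n_paths == sum(info$terminal)
```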
The output from `parse_model()` is transformed into a `dplyr` (a.k.a. Tidy Eval) formula. Each decision tree becomes one `dplyr::case_when()` statement, so `tidypredict_fit()` returns a list with one formula per tree. Here is the first one:
tidypredict_fit(model)[1]
#> [[1]]
#> case_when(Petal.Length < 2.45 & Petal.Width < 1.75 ~ "setosa",
#> Petal.Length < 4.85 & Petal.Width >= 1.75 ~ "virginica",
#> Petal.Length >= 4.85 & Petal.Width >= 1.75 ~ "virginica",
#> Petal.Length < 5.4 & Petal.Length >= 2.45 & Petal.Width <
#> 1.75 ~ "versicolor", Petal.Length >= 5.4 & Petal.Length >=
#> 2.45 & Petal.Width < 1.75 ~ "virginica")
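To make the transformation concrete, here is a simplified re-implementation of the idea, not tidypredict's actual code; `tree_paths()` and `paths_to_case_when()` are hypothetical helper names. They walk one tree from the root, join the split conditions along each path with `&`, and assemble a single `case_when()` call. The comparison operators mirror the `<` / `>=` pairs in the output above; the seed and forest size are assumptions.

```r
library(ranger)
library(dplyr)

# Hypothetical helper: walk a tree from the root (nodeID 0) and record
# one condition string and one predicted class per terminal node.
tree_paths <- function(info) {
  paths <- list()
  walk <- function(id, conds) {
    row <- info[info$nodeID == id, ]
    if (row$terminal) {
      paths[[length(paths) + 1]] <<- list(
        condition  = if (length(conds)) paste(conds, collapse = " & ") else "TRUE",
        prediction = as.character(row$prediction)
      )
    } else {
      # Left child gets `<`, right child gets `>=`, as in the output above.
      walk(row$leftChild,  c(conds, sprintf("%s < %s",  row$splitvarName, row$splitval)))
      walk(row$rightChild, c(conds, sprintf("%s >= %s", row$splitvarName, row$splitval)))
    }
  }
  walk(0, character(0))
  paths
}

# Hypothetical helper: turn the paths into one case_when() expression.
paths_to_case_when <- function(paths) {
  arms <- vapply(paths, function(p) sprintf('%s ~ "%s"', p$condition, p$prediction),
                 character(1))
  rlang::parse_expr(sprintf("case_when(%s)", paste(arms, collapse = ", ")))
}

set.seed(100)
model <- ranger(Species ~ ., data = iris, num.trees = 10)
expr <- paths_to_case_when(tree_paths(treeInfo(model, tree = 1)))

# Splice the expression into mutate(), the same way a tidypredict formula is used.
scored <- mutate(iris, tree_1 = !!expr)
```

Because the `<` / `>=` pairs split each node into two complementary halves, every row of the data satisfies exactly one arm, so the new column has no missing values.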
From there, the Tidy Eval formula can be used anywhere it can be evaluated. `tidypredict` provides three paths:
- Use it directly inside `dplyr`, one tree at a time: `mutate(iris, !! tidypredict_fit(model)[[1]])`
- Use `tidypredict_to_column(model)` in a piped command set (not supported for `ranger`, per the table above)
- Use `tidypredict_sql(model, con)` to retrieve the SQL statement
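The first and third paths might look like this for a `ranger` forest. This is a sketch assuming `dbplyr` is installed; `dbplyr::simulate_dbi()` stands in for a real database connection, and the seed and forest size are arbitrary choices for illustration.

```r
library(dplyr)
library(ranger)
library(tidypredict)

set.seed(100)
model <- ranger(Species ~ ., data = iris, num.trees = 10)

# For ranger, tidypredict_fit() returns a list of formulas (one per tree),
# so pick a single element before splicing it with !!.
fits <- tidypredict_fit(model)
scored <- mutate(iris, tree_1 = !! fits[[1]])

# tidypredict_sql() needs a connection; a simulated one is used here.
sql <- tidypredict_sql(model, dbplyr::simulate_dbi())
```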
parsnip
`tidypredict` also supports `ranger` model objects fitted via the `parsnip` package.
library(parsnip)
parsnip_model <- rand_forest(mode = "classification") %>%
set_engine("ranger") %>%
fit(Species ~ ., data = iris)
tidypredict_fit(parsnip_model)[[1]]
#> case_when(Petal.Width < 0.8 ~ "setosa", Petal.Length >= 5.05 &
#> Petal.Width >= 0.8 ~ "virginica", Petal.Width < 1.65 & Petal.Length <
#> 4.75 & Petal.Length < 5.05 & Petal.Width >= 0.8 ~ "versicolor",
#> Petal.Width >= 1.65 & Petal.Length < 4.75 & Petal.Length <
#> 5.05 & Petal.Width >= 0.8 ~ "virginica", Sepal.Length >=
#> 6.5 & Petal.Length >= 4.75 & Petal.Length < 5.05 & Petal.Width >=
#> 0.8 ~ "versicolor", Sepal.Width < 3.1 & Sepal.Length <
#> 6.5 & Petal.Length >= 4.75 & Petal.Length < 5.05 & Petal.Width >=
#> 0.8 ~ "virginica", Sepal.Width >= 3.1 & Sepal.Length <
#> 6.5 & Petal.Length >= 4.75 & Petal.Length < 5.05 & Petal.Width >=
#> 0.8 ~ "versicolor")
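A plausible reason the `parsnip` route works is that a fitted `parsnip` model keeps the engine's own object in its `fit` element. The sketch below only shows where that underlying `ranger` object lives; it makes no claim about how tidypredict unwraps it, and the seed is an assumption.

```r
library(parsnip)
library(ranger)

set.seed(100)
spec <- set_engine(rand_forest(mode = "classification"), "ranger")
parsnip_model <- fit(spec, Species ~ ., data = iris)

# The engine object lives in the `fit` element of the model_fit.
engine_fit <- parsnip_model$fit
class(engine_fit)
```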