| Function | Works |
|---|---|
| `tidypredict_fit()`, `tidypredict_sql()`, `parse_model()` | ✔ |
| `tidypredict_to_column()` | ✔ |
| `tidypredict_test()` | ✔ |
| `tidypredict_interval()`, `tidypredict_sql_interval()` | ✗ |
| parsnip | ✔ |

## `tidypredict_` functions
``` r
library(lightgbm)

# Prepare data
X <- data.matrix(mtcars[, c("mpg", "cyl", "disp")])
y <- mtcars$hp
dtrain <- lgb.Dataset(X, label = y, colnames = c("mpg", "cyl", "disp"))

model <- lgb.train(
  params = list(
    num_leaves = 4L,
    learning_rate = 0.5,
    objective = "regression",
    min_data_in_leaf = 1L
  ),
  data = dtrain,
  nrounds = 10L,
  verbose = -1L
)
```
Create the R formula:

``` r
tidypredict_fit(model)
#> case_when((cyl <= 7 | is.na(cyl)) ~ 122.371527777778, cyl > 7 &
#>     (mpg <= 15.1 | is.na(mpg)) & (disp <= 334 | is.na(disp)) ~
#>     240.84375, cyl > 7 & (mpg <= 15.1 | is.na(mpg)) & disp >
#>     334 ~ 187.34375, cyl > 7 & mpg > 15.1 ~ 164.21875) + case_when((mpg <=
#>     20.35 | is.na(mpg)) & (mpg <= 15.1 | is.na(mpg)) ~ 24.7864583333333,
#>     (mpg <= 20.35 | is.na(mpg)) & mpg > 15.1 ~ 7.36516196529071,
#>     mpg > 20.35 & (mpg <= 23.6 | is.na(mpg)) ~ -9.47147832598005,
#>     mpg > 20.35 & mpg > 23.6 ~ -24.4000499589103) + case_when((mpg <=
#>     21.45 | is.na(mpg)) & (mpg <= 17.55 | is.na(mpg)) & (mpg <=
#>     15.65 | is.na(mpg)) ~ 6.33150075541602, (mpg <= 21.45 | is.na(mpg)) &
#>     (mpg <= 17.55 | is.na(mpg)) & mpg > 15.65 ~ 18.2080434163411,
#>     (mpg <= 21.45 | is.na(mpg)) & mpg > 17.55 ~ 0.0642609119415283,
#>     mpg > 21.45 ~ -11.2250245332718) + case_when((disp <= 334 |
#>     is.na(disp)) & (mpg <= 15.1 | is.na(mpg)) ~ 31.5191459655761,
#>     (disp <= 334 | is.na(disp)) & mpg > 15.1 ~ -4.16611832639445,
#>     disp > 334 & (disp <= 380 | is.na(disp)) ~ 16.3295566439629,
#>     disp > 334 & disp > 380 ~ -0.254162490367889) + case_when((disp <=
#>     78.85 | is.na(disp)) ~ -10.7901678085327, disp > 78.85 &
#>     (mpg <= 28.85 | is.na(mpg)) & (disp <= 334 | is.na(disp)) ~
#>     -0.749505842104554, disp > 78.85 & (mpg <= 28.85 | is.na(mpg)) &
#>     disp > 334 ~ 4.01884852349758, disp > 78.85 & mpg > 28.85 ~
#>     15.2098321914673) + case_when((disp <= 78.85 | is.na(disp)) ~
#>     -5.39508358637492, disp > 78.85 & (disp <= 466 | is.na(disp)) &
#>     (mpg <= 15.1 | is.na(mpg)) ~ 4.51956310272216, disp > 78.85 &
#>     (disp <= 466 | is.na(disp)) & mpg > 15.1 ~ 0.0956797075012456,
#>     disp > 78.85 & disp > 466 ~ -8.61319732666016) + case_when((mpg <=
#>     21.45 | is.na(mpg)) & (disp <= 153.35 | is.na(disp)) & (cyl <=
#>     5 | is.na(cyl)) ~ 0.427817046642302, (mpg <= 21.45 | is.na(mpg)) &
#>     (disp <= 153.35 | is.na(disp)) & cyl > 5 ~ 25.0094966888427,
#>     (mpg <= 21.45 | is.na(mpg)) & disp > 153.35 ~ -0.436469084024429,
#>     mpg > 21.45 ~ -1.67079323530197) + case_when((disp <= 334 |
#>     is.na(disp)) & (mpg <= 15.1 | is.na(mpg)) ~ 14.0927782058715,
#>     (disp <= 334 | is.na(disp)) & mpg > 15.1 & (disp <= 288.4 |
#>     is.na(disp)) ~ -0.208523882286889, (disp <= 334 | is.na(disp)) &
#>     mpg > 15.1 & disp > 288.4 ~ -11.3294992446899, disp >
#>     334 ~ 1.61815292341635) + case_when((mpg <= 13.8 | is.na(mpg)) ~
#>     -3.70564748346806, mpg > 13.8 & (mpg <= 17.55 | is.na(mpg)) &
#>     (disp <= 334 | is.na(disp)) ~ -0.805894414583842, mpg > 13.8 &
#>     (mpg <= 17.55 | is.na(mpg)) & disp > 334 ~ 9.19054534534613,
#>     mpg > 13.8 & mpg > 17.55 ~ -0.580966327060014) + case_when((disp <=
#>     380 | is.na(disp)) & (mpg <= 17.55 | is.na(mpg)) ~ 1.89159697956509,
#>     (disp <= 380 | is.na(disp)) & mpg > 17.55 & (mpg <= 18.95 |
#>     is.na(mpg)) ~ -6.20051162441571, (disp <= 380 | is.na(disp)) &
#>     mpg > 17.55 & mpg > 18.95 ~ 0.834156103432178, disp >
#>     380 ~ -2.94233404099941)
```
Add the prediction to the original table:

``` r
library(dplyr)

mtcars %>%
  tidypredict_to_column(model) %>%
  glimpse()
#> Rows: 32
#> Columns: 12
#> $ mpg  <dbl> 21.0, 21.0, 22.8, 21.4, 18.7, 18.1, 14.3, 24.4, 22.8, 19…
#> $ cyl  <dbl> 6, 6, 4, 6, 8, 6, 8, 4, 4, 6, 6, 8, 8, 8, 8, 8, 8, 4, 4,…
#> $ disp <dbl> 160.0, 160.0, 108.0, 258.0, 360.0, 225.0, 360.0, 146.7, …
#> $ hp   <dbl> 110, 110, 93, 110, 175, 105, 245, 62, 95, 123, 123, 180,…
#> $ drat <dbl> 3.90, 3.90, 3.85, 3.08, 3.15, 2.76, 3.21, 3.69, 3.92, 3.…
#> $ wt   <dbl> 2.620, 2.875, 2.320, 3.215, 3.440, 3.460, 3.570, 3.190, …
#> $ qsec <dbl> 16.46, 17.02, 18.61, 19.44, 17.02, 20.22, 15.84, 20.00, …
#> $ vs   <dbl> 0, 0, 1, 1, 0, 1, 0, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 1, 1,…
#> $ am   <dbl> 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1,…
#> $ gear <dbl> 4, 4, 4, 3, 3, 3, 3, 4, 4, 4, 4, 3, 3, 3, 3, 3, 3, 4, 4,…
#> $ carb <dbl> 4, 4, 1, 1, 2, 1, 4, 2, 2, 4, 4, 3, 3, 3, 4, 4, 4, 1, 2,…
#> $ fit  <dbl> 107.75256, 107.75256, 95.22895, 107.75256, 186.49246, 11…
```
Confirm that `tidypredict` results match the model's `predict()` results. The `xg_df` argument expects the matrix data set:

``` r
tidypredict_test(model, xg_df = X)
#> tidypredict test results
#> Difference threshold: 1e-12
#>
#> All results are within the difference threshold
```
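Since the same parsed model also drives `tidypredict_sql()`, the prediction can be rendered as SQL for in-database scoring. A minimal sketch, using dbplyr's `simulate_dbi()` as a stand-in for a live DBI connection:

``` r
library(dbplyr)

# Render the fitted trees as a SQL CASE WHEN expression; swap
# simulate_dbi() for a real DBI connection to target your database.
tidypredict_sql(model, simulate_dbi())
```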
## Supported objectives

LightGBM supports many objective functions. The following examples show objectives supported by tidypredict:
### Binary classification example

``` r
X_bin <- data.matrix(mtcars[, c("mpg", "cyl", "disp")])
y_bin <- mtcars$am
dtrain_bin <- lgb.Dataset(X_bin, label = y_bin, colnames = c("mpg", "cyl", "disp"))

model_bin <- lgb.train(
  params = list(
    num_leaves = 4L,
    learning_rate = 0.5,
    objective = "binary",
    min_data_in_leaf = 1L
  ),
  data = dtrain_bin,
  nrounds = 10L,
  verbose = -1L
)
```
``` r
tidypredict_test(model_bin, xg_df = X_bin)
#> tidypredict test results
#> Difference threshold: 1e-12
#>
#> All results are within the difference threshold
```
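The binary formula can also be evaluated directly inside `mutate()`. A short sketch (assuming the translated formula returns the predicted probability of the positive class, consistent with `predict()`):

``` r
library(dplyr)

# Evaluate the binary formula in place; prob_am is assumed to be
# the predicted probability that am == 1.
mtcars %>%
  mutate(prob_am = !!tidypredict_fit(model_bin)) %>%
  select(am, prob_am) %>%
  head()
```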
### Multiclass classification

For multiclass models, `tidypredict_fit()` returns a named list of formulas, one for each class:
``` r
X_iris <- data.matrix(iris[, 1:4])
colnames(X_iris) <- c("Sepal.Length", "Sepal.Width", "Petal.Length", "Petal.Width")
y_iris <- as.integer(iris$Species) - 1L
dtrain_iris <- lgb.Dataset(X_iris, label = y_iris, colnames = colnames(X_iris))

model_multi <- lgb.train(
  params = list(
    num_leaves = 4L,
    learning_rate = 0.5,
    objective = "multiclass",
    num_class = 3L,
    min_data_in_leaf = 1L
  ),
  data = dtrain_iris,
  nrounds = 5L,
  verbose = -1L
)

fit_formulas <- tidypredict_fit(model_multi)
names(fit_formulas)
#> [1] "class_0" "class_1" "class_2"
```

Each formula produces the predicted probability for that class:
``` r
iris %>%
  mutate(
    prob_setosa = !!fit_formulas$class_0,
    prob_versicolor = !!fit_formulas$class_1,
    prob_virginica = !!fit_formulas$class_2
  ) %>%
  select(Species, starts_with("prob_")) %>%
  head()
#>   Species prob_setosa prob_versicolor prob_virginica
#> 1  setosa   0.9786973      0.01046491      0.0108378
#> 2  setosa   0.9786973      0.01046491      0.0108378
#> 3  setosa   0.9786973      0.01046491      0.0108378
#> 4  setosa   0.9786973      0.01046491      0.0108378
#> 5  setosa   0.9786973      0.01046491      0.0108378
#> 6  setosa   0.9786973      0.01046491      0.0108378
```

Note: `tidypredict_test()` does not support multiclass models. Use `tidypredict_fit()` directly.
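With `tidypredict_test()` unavailable, the formulas can still be spot-checked against `predict()` by hand. A minimal sketch (it assumes a recent lightgbm where `predict()` returns an n-by-num_class probability matrix for multiclass models):

``` r
library(dplyr)

# Evaluate the three class formulas over iris...
probs_formula <- iris %>%
  transmute(
    p0 = !!fit_formulas$class_0,
    p1 = !!fit_formulas$class_1,
    p2 = !!fit_formulas$class_2
  ) %>%
  as.matrix()

# ...and compare against LightGBM's own probabilities.
probs_lgb <- predict(model_multi, X_iris)
max(abs(probs_formula - probs_lgb))
```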
## Categorical features

LightGBM supports native categorical features. When a feature is marked as categorical, tidypredict generates the appropriate `%in%` conditions:
``` r
set.seed(123)
n <- 200
cat_data <- data.frame(
  cat_feat = sample(0:3, n, replace = TRUE),
  y = NA
)
cat_data$y <- ifelse(cat_data$cat_feat %in% c(0, 1), 10, -10) + rnorm(n, sd = 2)

X_cat <- matrix(cat_data$cat_feat, ncol = 1)
colnames(X_cat) <- "cat_feat"
dtrain_cat <- lgb.Dataset(
  X_cat,
  label = cat_data$y,
  categorical_feature = "cat_feat"
)

model_cat <- lgb.train(
  params = list(
    num_leaves = 4L,
    learning_rate = 1.0,
    objective = "regression",
    min_data_in_leaf = 1L
  ),
  data = dtrain_cat,
  nrounds = 2L,
  verbose = -1L
)
```
``` r
tidypredict_fit(model_cat)
#> case_when(cat_feat %in% 0:1 ~ 9.22111156962135, (!(cat_feat %in%
#>     0:1) | is.na(cat_feat)) ~ -9.19527530561794) + case_when(cat_feat %in%
#>     0:1 ~ 0.837108638881579, (!(cat_feat %in% 0:1) | is.na(cat_feat)) ~
#>     -0.837108347632668)
```
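The translated formula can be applied to new rows just like the numeric-split formulas. A small sketch scoring each category level:

``` r
library(dplyr)

# Score every category level; the %in% membership tests stand in
# for the usual numeric threshold splits.
data.frame(cat_feat = 0:3) %>%
  mutate(fit = !!tidypredict_fit(model_cat))
```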
## parsnip

parsnip fitted models (via the bonsai package) are also supported by tidypredict:
``` r
library(parsnip)
library(bonsai)

p_model <- boost_tree(
  trees = 10,
  tree_depth = 3,
  min_n = 1
) %>%
  set_engine("lightgbm") %>%
  set_mode("regression") %>%
  fit(hp ~ mpg + cyl + disp, data = mtcars)

# Extract the underlying lgb.Booster
lgb_model <- p_model$fit

tidypredict_test(lgb_model, xg_df = X)
#> tidypredict test results
#> Difference threshold: 1e-12
#>
#> All results are within the difference threshold
```
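As an aside, instead of reaching into `p_model$fit`, the booster can also be pulled out with parsnip's `extract_fit_engine()`; a brief sketch:

``` r
# Equivalent extraction through parsnip's accessor
lgb_model2 <- extract_fit_engine(p_model)
tidypredict_test(lgb_model2, xg_df = X)
```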
## Parse model spec

Here is an example of the model spec:
``` r
pm <- parse_model(model)
str(pm, 2)
#> List of 2
#> $ general:List of 9
#> ..$ model : chr "lgb.Booster"
#> ..$ type : chr "lgb"
#> ..$ version : num 1
#> ..$ params :List of 8
#> ..$ feature_names : chr [1:3] "mpg" "cyl" "disp"
#> ..$ nfeatures : int 3
#> ..$ num_class : int 1
#> ..$ num_tree_per_iteration: int 1
#> ..$ niter : int 10
#> $ trees :List of 10
#> ..$ 0:List of 4
#> ..$ 1:List of 4
#> ..$ 2:List of 4
#> ..$ 3:List of 4
#> ..$ 4:List of 4
#> ..$ 5:List of 4
#> ..$ 6:List of 4
#> ..$ 7:List of 4
#> ..$ 8:List of 4
#> ..$ 9:List of 4
#> - attr(*, "class")= chr [1:3] "parsed_model" "pm_lgb" "list"
```
``` r
str(pm$trees[1])
#> List of 1
#> $ 0:List of 4
#> ..$ :List of 2
#> .. ..$ prediction: num 122
#> .. ..$ path :List of 1
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "cyl"
#> .. .. .. ..$ val : num 7
#> .. .. .. ..$ op : chr "less-equal"
#> .. .. .. ..$ missing: logi TRUE
#> ..$ :List of 2
#> .. ..$ prediction: num 241
#> .. ..$ path :List of 3
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "cyl"
#> .. .. .. ..$ val : num 7
#> .. .. .. ..$ op : chr "more"
#> .. .. .. ..$ missing: logi FALSE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "mpg"
#> .. .. .. ..$ val : num 15.1
#> .. .. .. ..$ op : chr "less-equal"
#> .. .. .. ..$ missing: logi TRUE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "disp"
#> .. .. .. ..$ val : num 334
#> .. .. .. ..$ op : chr "less-equal"
#> .. .. .. ..$ missing: logi TRUE
#> ..$ :List of 2
#> .. ..$ prediction: num 187
#> .. ..$ path :List of 3
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "cyl"
#> .. .. .. ..$ val : num 7
#> .. .. .. ..$ op : chr "more"
#> .. .. .. ..$ missing: logi FALSE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "mpg"
#> .. .. .. ..$ val : num 15.1
#> .. .. .. ..$ op : chr "less-equal"
#> .. .. .. ..$ missing: logi TRUE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "disp"
#> .. .. .. ..$ val : num 334
#> .. .. .. ..$ op : chr "more"
#> .. .. .. ..$ missing: logi FALSE
#> ..$ :List of 2
#> .. ..$ prediction: num 164
#> .. ..$ path :List of 2
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "cyl"
#> .. .. .. ..$ val : num 7
#> .. .. .. ..$ op : chr "more"
#> .. .. .. ..$ missing: logi FALSE
#> .. .. ..$ :List of 5
#> .. .. .. ..$ type : chr "conditional"
#> .. .. .. ..$ col : chr "mpg"
#> .. .. .. ..$ val : num 15.1
#> .. .. .. ..$ op : chr "more"
#> .. .. .. ..$ missing: logi FALSE
```
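Because the parsed model is a plain list, it can be saved and reloaded without keeping the original `lgb.Booster` around. A minimal sketch using the yaml package together with tidypredict's `as_parsed_model()`:

``` r
library(yaml)

# Write the parsed spec to a YAML file...
write_yaml(pm, "lgb_model.yml")

# ...and later rebuild the prediction formula from the file alone.
pm_reloaded <- as_parsed_model(read_yaml("lgb_model.yml"))
tidypredict_fit(pm_reloaded)
```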
## Limitations

- Ranking objectives (`lambdarank`, `rank_xendcg`) are not supported
- Prediction intervals are not supported
- `tidypredict_test()` does not support multiclass models