Ошибка Tidymodels / XGBoost в last_fit со значением rsplit

Я пытаюсь следовать этому руководству здесь — https://juliasilge.com/blog/xgboost-tune-volleyball/

Я использую его в самом последнем наборе данных Tidy Tuesday о рыбалке в великих озерах — пытаюсь предсказать влияние на основе многих других значений.

ВСЕ приведенный ниже код работает, кроме последней строки, где я получаю следующую ошибку:

> final_res <- last_fit(final_xgb, stock_folds)
Error: Each element of `splits` must be an `rsplit` object.

Я поискал эту ошибку и попал на эту страницу — https://github.com/tidymodels/rsample/issues/175 На этом сайте это называется ошибкой и, кажется, исправлено, но я использую именно initial_time_split, а не initial_split. Я бы предпочел не менять его, потому что тогда мне пришлось бы повторно запускать xgboost, что заняло 9 часов. Что здесь пошло не так?

# Setup ----
library(tidyverse)
library(tidymodels)

stocked <- readr::read_csv('https://raw.githubusercontent.com/rfordatascience/tidytuesday/master/data/2021/2021-06-08/stocked.csv')

stocked_modeling <- stocked %>% 
  mutate(AGENCY = case_when(
    AGENCY != "OMNR" ~ "other",
    TRUE ~ AGENCY
  )) %>% 
  select(-SID, -MONTH, -DAY, -LATITUDE, -LONGITUDE, -GRID, -STRAIN, -AGEMONTH,
         -MARK_EFF, -TAG_NO, -TAG_RET, -LENGTH, -WEIGHT, - CONDITION, -LOT_CODE,
         -NOTES, - VALIDATION, -LS_MGMT, -STAT_DIST, -ST_SITE, -YEAR_CLASS, -STOCK_METH) %>% 
  mutate_if(is.character, factor) %>% 
  drop_na()

# Start making model ----
set.seed(123)
stock_split <- initial_split(stocked_modeling, strata = AGENCY)
stock_train <- training(stock_split)
stock_test <- testing(stock_split)

xgb_spec <- boost_tree(
  trees = 1000,
  tree_depth = tune(), min_n = tune(), loss_reduction = tune(),
  sample_size = tune(), mtry = tune(),
  learn_rate = tune()
) %>% 
  set_engine("xgboost") %>% 
  set_mode("classification")

xgb_grid <- grid_latin_hypercube(
  tree_depth(),
  min_n(),
  loss_reduction(),
  sample_size = sample_prop(),
  finalize(mtry(), stock_train),
  learn_rate(),
  size = 20
)

xgb_workflow <- workflow() %>% 
  add_formula(AGENCY ~ .) %>% 
  add_model(xgb_spec)

set.seed(123)
stock_folds <- vfold_cv(stock_train, strata = AGENCY)

doParallel::registerDoParallel()


# BEWARE, THIS CODE BELOW TOOK 9 HOURS TO RUN
set.seed(234)
xgb_res <- tune_grid(
  xgb_workflow,
  resamples = stock_folds,
  grid = xgb_grid,
  control = control_grid(save_pred = TRUE)
)

# Explore results
best_auc <- select_best(xgb_res, "roc_auc")

final_xgb <- finalize_workflow(
  xgb_workflow, 
  best_auc)

final_res <- last_fit(final_xgb, stock_folds)

См. также:  Как проверить числовые даты на неправильные значения?
Понравилась статья? Поделиться с друзьями:
IT Шеф
Комментарии: 1
  1. Indescribled

    Если мы посмотрим на документацию last_fit (), мы увидим, что split должен быть

    Объект rsplit, созданный из `rsample :: initial_split ().

    Вы случайно передали объект сворачивания перекрестной проверки stock_folds в split, но вместо этого вы должны были передать rsplit объект stock_split

    final_res <- last_fit(final_xgb, stock_split)
    
Добавить комментарий

;-) :| :x :twisted: :smile: :shock: :sad: :roll: :razz: :oops: :o :mrgreen: :lol: :idea: :grin: :evil: :cry: :cool: :arrow: :???: :?: :!: