25  Многоклассовая классификация

# Placeholder banner shown while the chapter is still a draft.
cowsay::say(what = "Это заготовка для главы. Загляните позже.", by = "egret")

 ------------------------------------------- 
<Это заготовка для главы. Загляните позже.>
 ------------------------------------------- 
        \
         \

           _,
      -==<' `
          ) /
         / (_.
        |  ,-,`\
         \\   \ \
          `\,  \ \
           ||\  \`|,
 jgs      _|| `=`-'
         ~~`~`

Многоклассовая классификация может использоваться для определения автора, жанра, тематики или эмоциональной тональности текста.

library(tidyverse)
library(textrecipes)
library(tidymodels)

25.1 Данные

# Load the saved workspace file; the code below relies on it providing
# `greek_corpus`.
load("../data/greek_corpus.Rdata")

# Horizontal bar chart of the number of rows per `corpus`, ordered by count,
# with the count printed next to each bar.
greek_corpus |>
  count(corpus) |>
  ggplot(aes(x = reorder(corpus, n), y = n, fill = corpus)) +
  geom_col(show.legend = FALSE) +
  geom_text(aes(label = n), nudge_y = 4, color = "grey40") +
  scale_fill_viridis_d() +
  coord_flip() +
  labs(x = NULL, y = NULL) +
  theme_light()

# Reproducible train/test split, stratified by the outcome `corpus`.
set.seed(06022025)
data_split <- initial_split(greek_corpus, strata = corpus)

# Both partitions drop the second column of `greek_corpus` (selected by
# position, so this depends on the column order of the loaded data).
data_train <- select(training(data_split), -2)
data_test <- select(testing(data_split), -2)

data_train

# Five stratified cross-validation folds built from the training partition.
set.seed(06022025)
folds <- vfold_cv(data_train, v = 5, strata = corpus)
folds

25.2 Recipe

library(themis)

# Baseline recipe: `corpus` is the outcome, every other column of the
# training data is a predictor; the only step is down-sampling on `corpus`.
base_rec <-
  recipe(corpus ~ ., data = data_train) |>
  themis::step_downsample(corpus)

base_rec
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:      1
predictor: 1001
── Operations 
• Down-sampling based on: corpus
# Same formula as the baseline recipe, but all predictors are centered and
# scaled before the down-sampling step on `corpus`.
norm_rec <-
  recipe(corpus ~ ., data = data_train) |>
  step_normalize(all_predictors()) |>
  themis::step_downsample(corpus)

norm_rec
── Recipe ──────────────────────────────────────────────────────────────────────
── Inputs 
Number of variables by role
outcome:      1
predictor: 1001
── Operations 
• Centering and scaling for: all_predictors()
• Down-sampling based on: corpus

25.3 KNN

# K-nearest-neighbours classifier on the kknn engine; the number of
# neighbours is left as a tuning parameter.
knn_spec <- nearest_neighbor(
  mode = "classification",
  engine = "kknn",
  neighbors = tune()
)

knn_spec
K-Nearest Neighbor Model Specification (classification)

Main Arguments:
  neighbors = tune()

Computational engine: kknn 
# Candidate values for `neighbors`: 1, 3 and 5.
knn_grid <- tibble(neighbors = seq(1, 5, by = 2))
knn_grid

25.4 SVM

# Linear-kernel SVM classifier (kernlab engine) with `cost` left for tuning.
svm_lin_spec <- svm_linear(
  mode = "classification",
  engine = "kernlab",
  cost = tune()
)

# Polynomial-kernel SVM classifier (kernlab engine) with `degree` left for
# tuning.
svm_pol_spec <- svm_poly(
  mode = "classification",
  engine = "kernlab",
  degree = tune()
)

# Tuning grid for the linear SVM: a regular grid with 3 levels built from the
# parameter set extracted from the model specification.
svm_lin_param <- extract_parameter_set_dials(svm_lin_spec)
svm_lin_grid <- grid_regular(svm_lin_param, levels = 3)

svm_lin_grid

# Tuning grid for the polynomial SVM: two hand-picked degrees.
svm_pol_grid <- tibble(degree = c(1.5, 2))

svm_pol_grid

26 Workflow_set

# Build the workflow set by crossing both recipes with all three model
# specifications (2 x 3 = 6 workflows, ids of the form "<recipe>_<model>").
# `workflow_set()` is itself the constructor and takes the preprocessors as
# its first argument, so it is called directly rather than being fed an empty
# `workflow()` through a pipe.
wflow_set <- workflow_set(
  preproc = list(
    base = base_rec,
    norm = norm_rec
  ),
  models = list(
    knn = knn_spec,
    svml = svm_lin_spec,
    svmp = svm_pol_spec
  ),
  cross = TRUE
)

wflow_set

# Attach the matching tuning grid to every workflow in the set.
wflow_set_final <- wflow_set |>
  option_add(grid = knn_grid, id = "base_knn") |>
  option_add(grid = knn_grid, id = "norm_knn") |>
  option_add(grid = svm_lin_grid, id = "base_svml") |>
  option_add(grid = svm_lin_grid, id = "norm_svml") |>
  option_add(grid = svm_pol_grid, id = "base_svmp") |>
  option_add(grid = svm_pol_grid, id = "norm_svmp")

Внимание: следующий шаг выполняется долго (в примере ниже — около 3,5 минуты на пяти ядрах).

library(doMC)
library(tictoc)
# Register a multicore foreach backend with 5 workers for the tuning run.
registerDoMC(cores = 5)

# Tune every workflow in the set on the cross-validation folds, using the
# grid attached to each workflow, and keep the out-of-fold predictions.
# tic()/toc() report the elapsed wall-clock time.
set.seed(07022025)
tic()
wflow_set_fit <-
  workflow_map(
    wflow_set_final,
    fn = "tune_grid",
    resamples = folds,
    metrics = metric_set(accuracy, roc_auc),
    # The mapped function is tune_grid(), so the control object is built with
    # its own constructor, control_grid().
    control = control_grid(save_pred = TRUE),
    verbose = TRUE
  )
toc()

# Switch foreach back to sequential execution once tuning is done.
registerDoSEQ()
# i 1 of 6 tuning:     base_knn
# ✔ 1 of 6 tuning:     base_knn (4.3s)
# i 2 of 6 tuning:     base_svml
# ✔ 2 of 6 tuning:     base_svml (1m 8.8s)
# i 3 of 6 tuning:     base_svmp
# ✔ 3 of 6 tuning:     base_svmp (34.9s)
# i 4 of 6 tuning:     norm_knn
# ✔ 4 of 6 tuning:     norm_knn (5.1s)
# i 5 of 6 tuning:     norm_svml
# ✔ 5 of 6 tuning:     norm_svml (1m 9.7s)
# i 6 of 6 tuning:     norm_svmp
# ✔ 6 of 6 tuning:     norm_svmp (30.6s)
# 218.219 sec elapsed
# Accuracy of all tuned candidates, each labelled with its workflow id.
autoplot(wflow_set_fit, metric = "accuracy") +
  theme_light() +
  theme(legend.position = "none") +
  geom_text(
    aes(y = mean - 2 * std_err, label = wflow_id),
    angle = 90,
    hjust = 1
  ) +
  # Labels are anchored at mean - 2 * std_err, which can fall below 0 for weak
  # candidates; lims() removes anything outside the range, so the lower limit
  # is set slightly below 0 to keep every label.
  lims(y = c(-0.1, 1))
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_text()`).

26.1 Boost_tree

# Boosted-tree classifier on the xgboost engine; the number of trees is left
# as a tuning parameter.
bt_spec <- boost_tree(
  mode = "classification",
  engine = "xgboost",
  trees = tune()
)

# Candidate numbers of trees.
bt_grid <- tibble(trees = c(100, 500, 1000))

# Workflow combining the baseline (down-sampling only) recipe with the model.
bt_wflow <- workflow() |>
  add_recipe(base_rec) |>
  add_model(bt_spec)

bt_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
1 Recipe Step

• step_downsample()

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (classification)

Main Arguments:
  trees = tune()

Computational engine: xgboost 
# Tune the boosted-tree workflow over `bt_grid` on the cross-validation folds.
# Predictions are kept for later inspection and the workflow is saved with the
# results (this chapter later converts `bt_tune` with as_workflow_set()).
set.seed(08022025)
bt_tune <- tune_grid(
  bt_wflow,
  # Passed by name instead of positionally between named arguments.
  resamples = folds,
  grid = bt_grid,
  metrics = metric_set(accuracy, roc_auc),
  # control_grid() is the control constructor that belongs to tune_grid().
  control = control_grid(save_pred = TRUE, save_workflow = TRUE)
)
# → A | warning: ✖ No observations were detected in `truth` for levels: Antiphon,
#                  Callimachus, Philostratus the Athenian, and Xenophon.
#                ℹ Computation will proceed by ignoring those levels.
# → B | warning: ✖ No observations were detected in `truth` for levels: Antiphon, Arrian,
#                  Hyperides, and Philostratus the Athenian.
#                ℹ Computation will proceed by ignoring those levels.
# → C | warning: ✖ No observations were detected in `truth` for levels: Aeschylus, Antiphon,
#                  Callimachus, and Hyperides.
#                ℹ Computation will proceed by ignoring those levels.
# → D | warning: ✖ No observations were detected in `truth` for levels: Aristophanes,
#                  Hyperides, Isaeus, and Sophocles.
#                ℹ Computation will proceed by ignoring those levels.
# → E | warning: ✖ No observations were detected in `truth` for levels: Isaeus and
#                  Sophocles.
#                ℹ Computation will proceed by ignoring those levels.
# Plot the tuning results of the boosted-tree workflow.
bt_tune |> autoplot()

26.2 Добавить в workflow_set

# Append the separately tuned boosted-tree results to the fitted workflow set
# under the id "bt_base".
wflow_set_final <- wflow_set_fit |>
  bind_rows(as_workflow_set(bt_base = bt_tune))

# Accuracy of all candidates, labelled with their workflow ids.
autoplot(wflow_set_final, metric = "accuracy") +
  theme_light() +
  theme(legend.position = "none") +
  geom_text(
    aes(y = mean - 2 * std_err, label = wflow_id),
    angle = 90,
    hjust = 1
  ) +
  lims(y = c(-0.1, 1))

# ROC AUC of all candidates, labelled with their workflow ids.
autoplot(wflow_set_final, metric = "roc_auc") +
  theme_light() +
  theme(legend.position = "none") +
  geom_text(
    aes(y = mean - 2 * std_err, label = wflow_id),
    angle = 90,
    hjust = 1
  ) +
  # lims() removes labels whose anchor (mean - 2 * std_err) lies outside the
  # range; a lower limit of 0.4 dropped one of them, so the axis now spans the
  # whole 0-1 range of the metric.
  lims(y = c(0, 1))
Warning: Removed 1 row containing missing values or values outside the scale range
(`geom_text()`).

26.3 Выбор модели и окончательная настройка

# Rank every tuned candidate in the set by accuracy.
rank_results(wflow_set_final, rank_metric = "accuracy")

# Tuning results of the normalised linear SVM workflow.
autoplot(wflow_set_fit, id = "norm_svml") +
  theme_light()

# Pick the best hyperparameter combination of "norm_svml".
# NOTE(review): candidates are ranked by accuracy above, while the best
# configuration here is selected by roc_auc - confirm this is intended.
norm_svml_tuning <- extract_workflow_set_result(wflow_set_final, id = "norm_svml")
best_results <- select_best(norm_svml_tuning, metric = "roc_auc")
best_results

# Finalise the workflow with the selected parameters, then fit it on the
# training partition and evaluate it on the test partition of `data_split`.
norm_svml_wflow <- extract_workflow(wflow_set_final, id = "norm_svml")
svml_test_results <- norm_svml_wflow |>
  finalize_workflow(best_results) |>
  last_fit(split = data_split)
→ A | warning: Variable(s) `' constant. Cannot scale data.
There were issues with some computations   A: x1
→ B | warning: ✖ No observations were detected in `truth` for levels: Arrian and Philostratus
                 the Athenian.
               ℹ Computation will proceed by ignoring those levels.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1

26.4 Оценка

# Test-set metrics of the final linear SVM fit.
collect_metrics(svml_test_results)

# Confusion matrix of the test-set predictions, drawn as a heatmap with a
# single dark-blue accent colour for tiles, grid lines and text.
accent_col <- "#233857"

svml_conf_mat <- svml_test_results |>
  collect_predictions() |>
  conf_mat(truth = corpus, estimate = .pred_class)

autoplot(svml_conf_mat, type = "heatmap") +
  scale_fill_gradient(low = "white", high = accent_col) +
  theme(
    panel.grid.major = element_line(colour = accent_col),
    axis.text = element_text(color = accent_col),
    axis.title = element_text(color = accent_col),
    plot.title = element_text(color = accent_col),
    axis.text.x = element_text(angle = 90)
  )

Многовато классов, ничего не видно, позже поправлю.