23  Регрессионные модели с tidymodels

23.1 Регрессионные алгоритмы

В машинном обучении проблемы, связанные с количественным откликом, называют проблемами регрессии, а проблемы, связанные с качественным откликом, проблемами классификации. Однако различие не всегда бывает четким: так, логистическая регрессия применяется для получения качественного бинарного отклика, а некоторые методы, такие как SVM, могут использоваться как для задач классификации, так и для задач регрессии.

В прошлом уроке мы познакомились с простой и множественной регрессией, но регрессионных алгоритмов великое множество. Вот лишь некоторые из них:

  1. полиномиальная регрессия: расширение линейной регрессии, позволяющее учитывать нелинейные зависимости.

  2. логистическая регрессия: используется для прогнозирования категориальных (бинарных) откликов.

  3. регрессия опорных векторов (SVR): ищет гиперплоскость, позволяющую минимизировать ошибку в многомерном пространстве.

  4. деревья регрессии: строят иерархическую древовидную модель, последовательно разбивая данные на подгруппы.

  5. случайный лес: комбинирует предсказания множества деревьев для повышения точности и устойчивости.

Кроме того, существуют методы регуляризации линейных моделей, позволяющие существенно улучшить их качество на данных большой размерности (т.е. с большим количеством предкторов). К таким алгоритмам относятся гребневая регрессия и метод лассо. Первая “штрафует” регрессионные коэффициенты, позволяя тем самым избежать переобучения. Лассо-регрессия выполняет отбор переменных, сводя некоторые коэффициенты до нуля. За оба метода отвечает функция glmnet() из одноименной библиотеки: при alpha=0 подгоняется гребневая регрессионная модель, а при alpha=1 – лассо-модель.

О математической стороне дела см. Г. Джеймс, Д. Уиттон, Т. Хасти, Р. Тибришани (2017). В этом уроке мы научимся работать с различными регрессионными алгоритмами, используя библиотеку tidymodels.

23.2 Библиотека tidymodels

Библиотека tidymodels позволяет обучать модели и оценивать их эффективность с использованием принципов опрятных данных. Она представляет собой набор пакетов R, которые разработаны для работы с машинным обучением и являются частью более широкой экосистемы tidyverse.

Вот некоторые из ключевых пакетов, входящих в состав tidymodels:

  1. parsnip - универсальный интерфейс для различных моделей машинного обучения, который упрощает переключение между разными типами моделей;

  2. recipes - фреймворк для создания и управления “рецептами” предварительной обработки данных перед тренировкой модели;

  3. rsample - инструменты для разделения данных на обучающую и тестовую выборки, а также для кросс-валидации;

  4. tune - функции для оптимизации гиперпараметров моделей машинного обучения;

  5. yardstick - инструменты для оценки производительности моделей;

  6. workflow позволяет объединить различные компоненты модели в единый объект: препроцессинг данных, модель машинного обучения, настройку гиперпараметров.

Мы также будем использовать пакет textrecipes, который представляет собой аналог recipes для текстовых данных.

library(tidyverse)
library(tidymodels)
library(textrecipes)

23.3 Данные

Датасет для этого урока хранит данные о названиях, рейтингах, жанре, цене и числе отзывов на некоторые книги с Amazon. Мы попробуем построить регресионную модель, которая будет предсказывать цену книги.

books  <- readxl::read_xlsx("../files/AmazonBooks.xlsx")
books

Данные не очень опрятны, и прежде всего их надо тайдифицировать.

colnames(books) <- tolower(colnames(books))
books <- books |> 
  rename(rating = `user rating`)

На графике ниже видно, что сильной корреляции между количественными переменными не прослеживается, так что задача перед нами стоит незаурядная. Посмотрим, что можно сделать в такой ситуации.

books |> 
  select_if(is.numeric) |> 
  cor() |> 
  corrplot::corrplot(method = "ellipse")

Мы видим, что количественные предикторы объясняют лишь ничтожную долю дисперсии (чуть более информативен жанр).

summary(lm(price ~ reviews + year + rating + genre, data  = books))

Call:
lm(formula = price ~ reviews + year + rating + genre, data = books)

Residuals:
    Min      1Q  Median      3Q     Max 
-16.472  -5.050  -1.841   2.307  89.686 

Coefficients:
                   Estimate Std. Error t value Pr(>|t|)    
(Intercept)       8.987e+02  2.734e+02   3.287  0.00107 ** 
reviews           7.779e-07  3.181e-05   0.024  0.98050    
year             -4.324e-01  1.370e-01  -3.156  0.00168 ** 
rating           -3.655e+00  1.933e+00  -1.891  0.05909 .  
genreNon Fiction  3.920e+00  8.669e-01   4.522 7.41e-06 ***
---
Signif. codes:  0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1

Residual standard error: 10.16 on 595 degrees of freedom
Multiple R-squared:  0.06903,   Adjusted R-squared:  0.06277 
F-statistic: 11.03 on 4 and 595 DF,  p-value: 1.235e-08

Посмотрим, можно ли как-то улучшить этот результат. Но сначала оценим визуально связь между ценой, с одной стороны, и годом и жанром, с другой.

g1 <- books |> 
  ggplot(aes(year, price, color = genre, group = genre)) + 
  geom_jitter(show.legend = FALSE, alpha = 0.7) + 
  geom_smooth(method = "lm", se = FALSE) +
  theme_minimal()

g2 <- books |> 
  ggplot(aes(genre, price, color = genre)) + 
  geom_boxplot() + 
  theme_minimal()

gridExtra::grid.arrange(g1, g2, nrow = 1)

23.4 Обучающая и контрольная выборка

Вы уже знаете, при обучении модели мы стремимся к минимизации среднеквадратичной ошибки (MSE), однако в большинстве случаев нас интересует не то, как метод работает на обучающих данных, а то, как он покажет себя на контрольных данных. Чтобы избежать переобучения, очень важно в самом начале разделить доступные наблюдения на две группы.

books_split <- books |> 
  initial_split()

books_train <- training(books_split)
books_test <- testing(books_split)

23.5 Определение модели

Определение модели включает следующие шаги:

  • указывается тип модели на основе ее математической структуры (например, линейная регрессия, случайный лес, KNN и т. д.);

  • указывается механизм для подгонки модели – чаще всего это программный пакет, который должен быть использован, например glmnet. Это самостоятельные модели, и parsnip обеспечивает согласованные интерфейсы, используя их в качестве движков для моделирования.

  • при необходимости объявляется режим модели. Режим отражает тип прогнозируемого результата. Для числовых результатов режимом является регрессия, для качественных - классификация. Если алгоритм модели может работать только с одним типом результатов прогнозирования, например, линейной регрессией, режим уже задан.

23.6 Регрессия на опорных векторах

Начнем с регрессии на опорных векторах. Функция translate() позволяет понять, как parsnip переводит пользовательский код на язык пакета.

svm_spec <- svm_linear() |>
  set_engine("LiblineaR") |> 
  set_mode("regression")

svm_spec |> 
  translate()
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

Model fit template:
LiblineaR::LiblineaR(x = missing_arg(), y = missing_arg(), type = 11, 
    svr_eps = 0.1)

Пока это просто спецификация модели без данных и без формулы. Добавим ее к воркфлоу.

svm_wflow <- workflow() |> 
  add_model(svm_spec)

svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: None
Model: svm_linear()

── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

23.7 Дизайн переменных

Теперь нам нужен препроцессор. За него отвечает пакет recipes. Если вы не уверены, какие шаги необходимы на этом этапе, можно заглянуть в шпаргалку. В случае с линейной регрессией это может быть логарифмическая трансформация, нормализация, отсев переменных с нулевой дисперсией (zero variance), добавление (impute) недостающих значений или удаление переменных, которые коррелируют с другими переменными.

Вот так выглядит наш первый рецепт. Обратите внимание, что формула записывается так же, как мы это делали ранее внутри функции lm().

books_rec <- recipe(price ~ year + genre + name, 
                    data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name)  |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 

При желании можно посмотреть на результат предобработки.

prep(books_rec, books_train) |> 
  bake(new_data = NULL)

Добавляем препроцессор в воркфлоу.

svm_wflow <- svm_wflow |> 
  add_recipe(books_rec)

svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)

Computational engine: LiblineaR 

23.8 Подгонка модели

Теперь подгоним модель на обучающих данных.

svm_fit <- svm_wflow |>
  fit(data = books_train)

Пакет broom позволяет тайдифицировать модель. Посмотрим на слова, которые приводят к “удорожанию” книг. Видно, что в начале списка – слова, связанные с научными публикациями, что не лишено смысла.

svm_fit |> 
  tidy() |> 
  arrange(-estimate)

Оценим модель на контрольных данных.

pred_data <- tibble(truth = books_test$price,
                    estimate = predict(svm_fit, books_test)$.pred)
books_metrics <- metric_set(rmse, rsq, mae)

books_metrics(pred_data, truth = truth,  estimate = estimate)

23.9 Повторные выборки

Чтобы не распечатывать каждый раз тестовые данные (в идеале мы их используем один, максимум два раза!), задействуется ряд методов, позволяющих оценить ошибку путем исключения части обучающих наблюдений из процесса подгонки модели и последующего применения этой модели к исключенным наблюдениям.

В пакете rsample из библиотеки tidymodels реализованы, среди прочего, следующие методы повторных выборок для оценки производительности моделей машинного обучения:

  1. Метод проверочной выборки – набор наблюдений делится на обучающую и проверочную, или удержанную, выборку (validation set): для этого используется initial_validation_split().

  2. K-кратная перекрестная проверка – наблюдения разбиваются на k групп примерно одинакового размера, первый блок служит в качестве проверочной выборки, а модель подгоняется по остальным k-1 блокам; процедура повторяется k раз: функция vfold_cv().

  3. Перекрестная проверка Монте-Карло – в отличие от предыдущего метода, создается множество случайных разбиений данных на обучающую и тестовую выборки: функция mc_cv().

  4. Бутстреп – отбор наблюдений выполняется с возвращением, т.е. одно и то же наблюдение может встречаться несколько раз: функция bootstraps().

  5. Перекрестная проверка по отдельным наблюдениям (leave-one-out сross-validation): одно наблюдение используется в качестве контрольного, а остальные составляют обучающую выборку; модель подгоняется по n-1 наблюдениям, что повторяется n раз: функция loo_cv().

Эти методы повторных выборок позволяют получить надежные оценки производительности моделей машинного обучения, избегая переобучения и обеспечивая репрезентативность тестовых выборок.

set.seed(05102024)
books_folds <- vfold_cv(books_train, v = 10) 

set.seed(05102024)
svm_rs <- fit_resamples(
  svm_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations   A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations   A: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1   C: x1

Теперь соберем метрики и убедимся, что предыдущая оценка на контрольных данных была слишком оптимистичной. Однако результат не так уж плох: во всяком случае мы смогли добиться заметного улучшения по сравнению с нулевой моделью.

collect_metrics(svm_rs)
svm_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") + 
  theme_minimal() +
  coord_cartesian(xlim = c(0,50), ylim = c(0,50))

23.10 Нулевая модель

Кстати, проверим, какой результат даст нулевая модель.

null_reg <- null_model() |> 
  set_engine("parsnip") |> 
  set_mode("regression")

null_wflow <- workflow() |> 
    add_model(null_reg) |> 
    add_recipe(books_rec)

null_rs <- fit_resamples(
  null_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
  )
→ A | warning: A correlation computation is required, but `estimate` is constant and has 0
               standard deviation, resulting in a divide by 0 error. `NA` will be returned.
→ B | warning: max_tokens was set to '1000', but only 997 was available and selected.
→ C | warning: max_tokens was set to '1000', but only 984 was available and selected.
→ D | warning: max_tokens was set to '1000', but only 998 was available and selected.
collect_metrics(null_rs)

\(R^2\) в таком случае должен быть NaN.

23.11 Случайный лес

Уточним, какие движки доступны для случайных лесов.

show_engines("rand_forest")

Создадим спецификацию модели. Деревья используются как в задачах классификации, так и в задачах регрессии, поэтому задействуем функцию set_mode().

rf_spec <- rand_forest(trees = 1000) |> 
  set_engine("ranger") |> 
  set_mode("regression")
rf_wflow <- workflow() |> 
  add_model(rf_spec) |> 
  add_recipe(books_rec)

rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (regression)

Main Arguments:
  trees = 1000

Computational engine: ranger 

Обучение займет чуть больше времени.

rf_rs <- fit_resamples(
  rf_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations   A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations   A: x1   B: x1
There were issues with some computations   A: x1   B: x1   C: x1
There were issues with some computations   A: x1   B: x1   C: x1

Мы видим, что среднеквадратическая ошибка уменьшилась, а доля объясненной дисперсии выросла.

collect_metrics(rf_rs)

Тем не менее на графике можно заметить нечто странное: наша модель систематически переоценивает низкие значения и недооценивает высокие. Это связано с тем, что случайные леса не очень подходят для работы с разреженными данными (Hvitfeldt и Silge 2022).

rf_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") +
  theme_minimal() +
  coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))

23.12 Градиентные бустинговые деревья

Также попробуем построить регрессию с использованием градиентных бустинговых деревьев. Это один из алгоритмов ансамблевого машинного обучения, который строит последовательность простых моделей решающих деревьев, каждая из которых работает над ошибками предыдущей. В 2023 г. эта техника показала хорошие результаты в эксперименте по датировке греческих документальных папирусов.

xgb_spec <- 
  boost_tree(mtry = 50, trees = 1000)  |> 
  set_engine("xgboost") %>%
  set_mode("regression")
xgb_wflow <- workflow() |> 
  add_model(xgb_spec) |> 
  add_recipe(books_rec)

xgb_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (regression)

Main Arguments:
  mtry = 50
  trees = 1000

Computational engine: xgboost 

Проводим перекрестную проверку.

xgb_rs <- fit_resamples(
  xgb_wflow,
  books_folds,
  control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations   A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations   A: x1
There were issues with some computations   A: x1   B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations   A: x1   B: x1
There were issues with some computations   A: x1   B: x1   C: x1
collect_metrics(xgb_rs)

Метрики неплохие! Но если взглянуть на остатки, можно увидеть что-то вроде буквы S.

rf_rs |> 
  collect_predictions() |> 
  ggplot(aes(price, .pred, color = id)) +
  geom_jitter(alpha = 0.3) +
  geom_abline(lty = 2, color = "grey80") +
  theme_minimal() +
  coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))

23.13 Удаление стопслов

Изменим рецепт приготовления данных.

stopwords_rec <- function(stopwords_name) {
  recipe(price ~ year + genre + name, data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name)  |> 
  step_stopwords(name, stopword_source = stopwords_name) |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 
}

Создадим воркфлоу.

svm_wflow <- workflow() |> 
  add_model(svm_spec)

И снова проведем перекрестную проверку, на этот раз с разными списками стоп-слов. На этом шаге команда вернет предупреждения о том, что число слов меньше 1000, это нормально, т.к. после удаления стопслов токенов стало меньше.

set.seed(123)
snowball_rs <- fit_resamples(
  svm_wflow |>  add_recipe(stopwords_rec("snowball")),
  books_folds
)

set.seed(234)
smart_rs <- fit_resamples(
  svm_wflow |> add_recipe(stopwords_rec("smart")),
  books_folds
)

set.seed(345)
stopwords_iso_rs <- fit_resamples(
  svm_wflow |> add_recipe(stopwords_rec("stopwords-iso")),
  books_folds
)
collect_metrics(smart_rs)
collect_metrics(snowball_rs)
collect_metrics((stopwords_iso_rs))

В нашем случае удаление стоп-слов положительного эффекта не имело.

word_counts <- tibble(name = c("snowball", "smart", "stopwords-iso")) %>%
  mutate(words = map_int(name, ~length(stopwords::stopwords(source = .))))

list(snowball = snowball_rs,
     smart = smart_rs,
     `stopwords-iso` = stopwords_iso_rs)  |> 
  map_dfr(show_best, metric = "rmse", .id = "name")  |> 
  left_join(word_counts, by = "name")  |> 
  mutate(name = paste0(name, " (", words, " words)"),
         name = fct_reorder(name, words))  |> 
  ggplot(aes(name, mean, color = name)) +
  geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), alpha = 0.6) +
  geom_point(size = 3, alpha = 0.8) +
  theme(legend.position = "none") + 
  theme_minimal()

23.14 Настройки числа n-grams

ngram_rec <- function(ngram_options) {
  recipe(price ~ year + genre + name, data = books_train) |> 
  step_dummy(genre)  |> 
  step_normalize(year) |> 
  step_tokenize(name, token = "ngrams", options = ngram_options)  |> 
  step_tokenfilter(name, max_tokens = 1000)  |> 
  step_tfidf(name) 
}
fit_ngram <- function(ngram_options) {
  fit_resamples(
    svm_wflow %>% add_recipe(ngram_rec(ngram_options)),
    books_folds
  )
}
set.seed(123)
unigram_rs <- fit_ngram(list(n = 1))
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations   A: x1   B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations   A: x1   B: x1
There were issues with some computations   A: x1   B: x1   C: x1
set.seed(234)
bigram_rs <- fit_ngram(list(n = 2, n_min = 1))

set.seed(345)
trigram_rs <- fit_ngram(list(n = 3, n_min = 1))
collect_metrics(unigram_rs)
collect_metrics(bigram_rs)
collect_metrics(trigram_rs)

Таким образом, униграмы дают лучший результат:

list(`1` = unigram_rs,
     `1 and 2` = bigram_rs,
     `1, 2, and 3` = trigram_rs) |> 
  map_dfr(collect_metrics, .id = "name")  |> 
  filter(.metric == "rmse")  |> 
  ggplot(aes(name, mean, color = name)) +
  geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), 
                alpha = 0.6) +
  geom_point(size = 3, alpha = 0.8) +
  theme(legend.position = "none") +
  labs(
    y = "RMSE"
  ) + 
  theme_minimal()

23.15 Лучшая модель и оценка

svm_fit <- svm_wflow |>
  add_recipe(books_rec) |> 
  fit(data = books_test)
Warning: max_tokens was set to '1000', but only 612 was available and selected.
svm_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()

── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps

• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()

── Model ───────────────────────────────────────────────────────────────────────
$TypeDetail
[1] "L2-regularized L2-loss support vector regression primal (L2R_L2LOSS_SVR)"

$Type
[1] 11

$W
           year genre_Non.Fiction tfidf_name_1 tfidf_name_10 tfidf_name_100
[1,] -0.9520146          3.485356       1.1305     -1.187969      -1.996428
     tfidf_name_150 tfidf_name_16 tfidf_name_17 tfidf_name_1936 tfidf_name_1984
[1,]     -0.6271174     0.1926996       1.80019      -0.2370809      -0.6116231
     tfidf_name_2 tfidf_name_2.0 tfidf_name_22 tfidf_name_3 tfidf_name_30
[1,]   -0.4068736      0.8091512   -0.09894225     -1.04787     0.8119458
     tfidf_name_4 tfidf_name_451 tfidf_name_49 tfidf_name_5 tfidf_name_5,000
[1,]    0.3118533        -0.0736      1.196741    0.5180948        0.3178177
     tfidf_name_500 tfidf_name_6 tfidf_name_6th tfidf_name_7 tfidf_name_8
[1,]      0.2111397    -3.974732       8.957552     1.933749  -0.09894225
     tfidf_name_a tfidf_name_about tfidf_name_absurd tfidf_name_activity
[1,]    -1.241315        0.1435666         0.7912706            -1.04787
     tfidf_name_adult tfidf_name_adults tfidf_name_advanced
[1,]        -1.085287        -0.2930574           0.2111397
     tfidf_name_adventures tfidf_name_afterlife tfidf_name_aftermath
[1,]             0.4109577           -0.7802049           0.01016221
     tfidf_name_again tfidf_name_ages tfidf_name_agreements tfidf_name_alaska
[1,]        0.8241497      -0.9075971             -1.555872        -0.2463905
     tfidf_name_almost tfidf_name_alphabet tfidf_name_am tfidf_name_america
[1,]        -0.9554437          -0.9075971    -0.3789486          0.6591157
     tfidf_name_american tfidf_name_americans tfidf_name_an tfidf_name_ancient
[1,]            5.041504            0.1500841    -0.5035489          0.7991669
     tfidf_name_and tfidf_name_animals tfidf_name_answers tfidf_name_antiracist
[1,]     -0.5396753          -1.670664          0.7912706              1.388923
     tfidf_name_approach tfidf_name_are tfidf_name_art tfidf_name_as
[1,]           0.6142527     -0.5674495     -0.6325811    -0.4421042
     tfidf_name_asians tfidf_name_assault tfidf_name_association
[1,]        -0.3779175          -1.910361               8.957552
     tfidf_name_astrophysics tfidf_name_at tfidf_name_autobiography
[1,]               -1.042553    -0.2370809                -1.295159
     tfidf_name_awesome tfidf_name_back tfidf_name_badass tfidf_name_battling
[1,]         -0.4442324       -1.753965        -0.8324738          -0.2089105
     tfidf_name_be tfidf_name_bear tfidf_name_beasts tfidf_name_beautiful
[1,]     0.6550246       -1.708765         0.8600794           -0.3789486
     tfidf_name_become tfidf_name_bed tfidf_name_beginner's
[1,]        0.01735057     -0.9415579            -0.7647593
     tfidf_name_beginners tfidf_name_believing tfidf_name_berlin tfidf_name_big
[1,]            0.2111397           0.01735057        -0.2370809     -0.7848711
     tfidf_name_bill tfidf_name_blood tfidf_name_boat tfidf_name_body
[1,]       -1.910361       -0.8841644      -0.1870963        -1.04787
     tfidf_name_book tfidf_name_books tfidf_name_born tfidf_name_bossypants
[1,]       -4.197888       -0.1782664        1.572082            -0.3790329
     tfidf_name_boxed tfidf_name_boy tfidf_name_boys tfidf_name_brave

...
and 290 more lines.

Взглянем на остатки. Для этого пригодится уже знакомая функция augment() из пакета broom.

svm_res <- augment(svm_fit, new_data = books_test) |> 
  mutate(res = price - .pred) |> 
  select(price, .pred, res)

svm_res
library(gridExtra)

g1 <- svm_res |> 
  mutate(res = price - .pred) |> 
  ggplot(aes(res)) +
  geom_histogram(fill = "steelblue", color  = "white") +
  theme_minimal()

g2 <- svm_res |> 
  ggplot(aes(price, .pred)) +
  geom_jitter(color = "steelblue", alpha = 0.7) +
  geom_abline(linetype = 2, color = "grey80", linewidth = 2) +
  theme_minimal()

grid.arrange(g1, g2, nrow = 1)

Соберем метрики.

books_metrics <- metric_set(rmse, rsq, mae)
books_metrics(svm_res, truth = price,  estimate = .pred)

Также посмотрим, какие слова больше всего связаны с увеличением и с уменьшением цены.

svm_fit |> 
  tidy() |> 
  filter(term != "year") |> 
  filter(!str_detect(term, "genre")) |> 
  mutate(sign = case_when(estimate > 0 ~ "дороже",
                          .default = "дешевле"),
         estimate = abs(estimate), 
         term = str_remove_all(term, "tfidf_name_")) |> 
  group_by(sign) |> 
  top_n(20, estimate) |> 
  ungroup() |> 
  ggplot(aes(x = estimate, y = fct_reorder(term, estimate),
             fill = sign)) +
  geom_col(show.legend = FALSE) +
  scale_x_continuous(expand = c(0,0)) +
  facet_wrap(~sign, scales = "free") +
  labs(y = NULL, 
       title = "Связь слов с ценой книг") +
  theme_minimal()

Любопытно: судя по нашему датасету, конституция США раздается на Амазоне бесплатно.