library(tidyverse)
library(tidymodels)
library(textrecipes)
23 Регрессионные модели с tidymodels
23.1 Регрессионные алгоритмы
В машинном обучении проблемы, связанные с количественным откликом, называют проблемами регрессии, а проблемы, связанные с качественным откликом, проблемами классификации. Однако различие не всегда бывает четким: так, логистическая регрессия применяется для получения качественного бинарного отклика, а некоторые методы, такие как SVM, могут использоваться как для задач классификации, так и для задач регрессии.
В прошлом уроке мы познакомились с простой и множественной регрессией, но регрессионных алгоритмов великое множество. Вот лишь некоторые из них:
полиномиальная регрессия: расширение линейной регрессии, позволяющее учитывать нелинейные зависимости.
логистическая регрессия: используется для прогнозирования категориальных (бинарных) откликов.
регрессия опорных векторов (SVR): ищет гиперплоскость, позволяющую минимизировать ошибку в многомерном пространстве.
деревья регрессии: строят иерархическую древовидную модель, последовательно разбивая данные на подгруппы.
случайный лес: комбинирует предсказания множества деревьев для повышения точности и устойчивости.
Кроме того, существуют методы регуляризации линейных моделей, позволяющие существенно улучшить их качество на данных большой размерности (т.е. с большим количеством предкторов). К таким алгоритмам относятся гребневая регрессия и метод лассо. Первая “штрафует” регрессионные коэффициенты, позволяя тем самым избежать переобучения. Лассо-регрессия выполняет отбор переменных, сводя некоторые коэффициенты до нуля. За оба метода отвечает функция glmnet()
из одноименной библиотеки: при alpha=0
подгоняется гребневая регрессионная модель, а при alpha=1
– лассо-модель.
О математической стороне дела см. Г. Джеймс, Д. Уиттон, Т. Хасти, Р. Тибришани (2017). В этом уроке мы научимся работать с различными регрессионными алгоритмами, используя библиотеку tidymodels
.
23.2 Библиотека tidymodels
Библиотека tidymodels позволяет обучать модели и оценивать их эффективность с использованием принципов опрятных данных. Она представляет собой набор пакетов R, которые разработаны для работы с машинным обучением и являются частью более широкой экосистемы tidyverse
.
Вот некоторые из ключевых пакетов, входящих в состав tidymodels
:
parsnip
- универсальный интерфейс для различных моделей машинного обучения, который упрощает переключение между разными типами моделей;recipes
- фреймворк для создания и управления “рецептами” предварительной обработки данных перед тренировкой модели;rsample
- инструменты для разделения данных на обучающую и тестовую выборки, а также для кросс-валидации;tune
- функции для оптимизации гиперпараметров моделей машинного обучения;yardstick
- инструменты для оценки производительности моделей;workflow
позволяет объединить различные компоненты модели в единый объект: препроцессинг данных, модель машинного обучения, настройку гиперпараметров.
Мы также будем использовать пакет textrecipes
, который представляет собой аналог recipes
для текстовых данных.
23.3 Данные
Датасет для этого урока хранит данные о названиях, рейтингах, жанре, цене и числе отзывов на некоторые книги с Amazon. Мы попробуем построить регресионную модель, которая будет предсказывать цену книги.
<- readxl::read_xlsx("../files/AmazonBooks.xlsx")
books books
Данные не очень опрятны, и прежде всего их надо тайдифицировать.
colnames(books) <- tolower(colnames(books))
<- books |>
books rename(rating = `user rating`)
На графике ниже видно, что сильной корреляции между количественными переменными не прослеживается, так что задача перед нами стоит незаурядная. Посмотрим, что можно сделать в такой ситуации.
|>
books select_if(is.numeric) |>
cor() |>
::corrplot(method = "ellipse") corrplot
Мы видим, что количественные предикторы объясняют лишь ничтожную долю дисперсии (чуть более информативен жанр).
summary(lm(price ~ reviews + year + rating + genre, data = books))
Call:
lm(formula = price ~ reviews + year + rating + genre, data = books)
Residuals:
Min 1Q Median 3Q Max
-16.472 -5.050 -1.841 2.307 89.686
Coefficients:
Estimate Std. Error t value Pr(>|t|)
(Intercept) 8.987e+02 2.734e+02 3.287 0.00107 **
reviews 7.779e-07 3.181e-05 0.024 0.98050
year -4.324e-01 1.370e-01 -3.156 0.00168 **
rating -3.655e+00 1.933e+00 -1.891 0.05909 .
genreNon Fiction 3.920e+00 8.669e-01 4.522 7.41e-06 ***
---
Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1
Residual standard error: 10.16 on 595 degrees of freedom
Multiple R-squared: 0.06903, Adjusted R-squared: 0.06277
F-statistic: 11.03 on 4 and 595 DF, p-value: 1.235e-08
Посмотрим, можно ли как-то улучшить этот результат. Но сначала оценим визуально связь между ценой, с одной стороны, и годом и жанром, с другой.
<- books |>
g1 ggplot(aes(year, price, color = genre, group = genre)) +
geom_jitter(show.legend = FALSE, alpha = 0.7) +
geom_smooth(method = "lm", se = FALSE) +
theme_minimal()
<- books |>
g2 ggplot(aes(genre, price, color = genre)) +
geom_boxplot() +
theme_minimal()
::grid.arrange(g1, g2, nrow = 1) gridExtra
23.4 Обучающая и контрольная выборка
Вы уже знаете, при обучении модели мы стремимся к минимизации среднеквадратичной ошибки (MSE), однако в большинстве случаев нас интересует не то, как метод работает на обучающих данных, а то, как он покажет себя на контрольных данных. Чтобы избежать переобучения, очень важно в самом начале разделить доступные наблюдения на две группы.
<- books |>
books_split initial_split()
<- training(books_split)
books_train <- testing(books_split) books_test
23.5 Определение модели
Определение модели включает следующие шаги:
указывается тип модели на основе ее математической структуры (например, линейная регрессия, случайный лес, KNN и т. д.);
указывается механизм для подгонки модели – чаще всего это программный пакет, который должен быть использован, например
glmnet
. Это самостоятельные модели, иparsnip
обеспечивает согласованные интерфейсы, используя их в качестве движков для моделирования.при необходимости объявляется режим модели. Режим отражает тип прогнозируемого результата. Для числовых результатов режимом является регрессия, для качественных - классификация. Если алгоритм модели может работать только с одним типом результатов прогнозирования, например, линейной регрессией, режим уже задан.
23.6 Регрессия на опорных векторах
Начнем с регрессии на опорных векторах. Функция translate()
позволяет понять, как parsnip
переводит пользовательский код на язык пакета.
<- svm_linear() |>
svm_spec set_engine("LiblineaR") |>
set_mode("regression")
|>
svm_spec translate()
Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
Model fit template:
LiblineaR::LiblineaR(x = missing_arg(), y = missing_arg(), type = 11,
svr_eps = 0.1)
Пока это просто спецификация модели без данных и без формулы. Добавим ее к воркфлоу.
<- workflow() |>
svm_wflow add_model(svm_spec)
svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: None
Model: svm_linear()
── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
23.7 Дизайн переменных
Теперь нам нужен препроцессор. За него отвечает пакет recipes
. Если вы не уверены, какие шаги необходимы на этом этапе, можно заглянуть в шпаргалку. В случае с линейной регрессией это может быть логарифмическая трансформация, нормализация, отсев переменных с нулевой дисперсией (zero variance), добавление (impute) недостающих значений или удаление переменных, которые коррелируют с другими переменными.
Вот так выглядит наш первый рецепт. Обратите внимание, что формула записывается так же, как мы это делали ранее внутри функции lm()
.
<- recipe(price ~ year + genre + name,
books_rec data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name)
При желании можно посмотреть на результат предобработки.
prep(books_rec, books_train) |>
bake(new_data = NULL)
Добавляем препроцессор в воркфлоу.
<- svm_wflow |>
svm_wflow add_recipe(books_rec)
svm_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Linear Support Vector Machine Model Specification (regression)
Computational engine: LiblineaR
23.8 Подгонка модели
Теперь подгоним модель на обучающих данных.
<- svm_wflow |>
svm_fit fit(data = books_train)
Пакет broom
позволяет тайдифицировать модель. Посмотрим на слова, которые приводят к “удорожанию” книг. Видно, что в начале списка – слова, связанные с научными публикациями, что не лишено смысла.
|>
svm_fit tidy() |>
arrange(-estimate)
Оценим модель на контрольных данных.
<- tibble(truth = books_test$price,
pred_data estimate = predict(svm_fit, books_test)$.pred)
<- metric_set(rmse, rsq, mae)
books_metrics
books_metrics(pred_data, truth = truth, estimate = estimate)
23.9 Повторные выборки
Чтобы не распечатывать каждый раз тестовые данные (в идеале мы их используем один, максимум два раза!), задействуется ряд методов, позволяющих оценить ошибку путем исключения части обучающих наблюдений из процесса подгонки модели и последующего применения этой модели к исключенным наблюдениям.
В пакете rsample
из библиотеки tidymodels
реализованы, среди прочего, следующие методы повторных выборок для оценки производительности моделей машинного обучения:
Метод проверочной выборки – набор наблюдений делится на обучающую и проверочную, или удержанную, выборку (validation set): для этого используется
initial_validation_split()
.K-кратная перекрестная проверка – наблюдения разбиваются на k групп примерно одинакового размера, первый блок служит в качестве проверочной выборки, а модель подгоняется по остальным k-1 блокам; процедура повторяется k раз: функция
vfold_cv()
.Перекрестная проверка Монте-Карло – в отличие от предыдущего метода, создается множество случайных разбиений данных на обучающую и тестовую выборки: функция
mc_cv()
.Бутстреп – отбор наблюдений выполняется с возвращением, т.е. одно и то же наблюдение может встречаться несколько раз: функция
bootstraps()
.Перекрестная проверка по отдельным наблюдениям (leave-one-out сross-validation): одно наблюдение используется в качестве контрольного, а остальные составляют обучающую выборку; модель подгоняется по n-1 наблюдениям, что повторяется n раз: функция
loo_cv()
.
Эти методы повторных выборок позволяют получить надежные оценки производительности моделей машинного обучения, избегая переобучения и обеспечивая репрезентативность тестовых выборок.
set.seed(05102024)
<- vfold_cv(books_train, v = 10)
books_folds
set.seed(05102024)
<- fit_resamples(
svm_rs
svm_wflow,
books_folds,control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations A: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations A: x1
There were issues with some computations A: x1 B: x1 C: x1
Теперь соберем метрики и убедимся, что предыдущая оценка на контрольных данных была слишком оптимистичной. Однако результат не так уж плох: во всяком случае мы смогли добиться заметного улучшения по сравнению с нулевой моделью.
collect_metrics(svm_rs)
|>
svm_rs collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0,50), ylim = c(0,50))
23.10 Нулевая модель
Кстати, проверим, какой результат даст нулевая модель.
<- null_model() |>
null_reg set_engine("parsnip") |>
set_mode("regression")
<- workflow() |>
null_wflow add_model(null_reg) |>
add_recipe(books_rec)
<- fit_resamples(
null_rs
null_wflow,
books_folds,control = control_resamples(save_pred = TRUE)
)
→ A | warning: A correlation computation is required, but `estimate` is constant and has 0
standard deviation, resulting in a divide by 0 error. `NA` will be returned.
→ B | warning: max_tokens was set to '1000', but only 997 was available and selected.
→ C | warning: max_tokens was set to '1000', but only 984 was available and selected.
→ D | warning: max_tokens was set to '1000', but only 998 was available and selected.
collect_metrics(null_rs)
\(R^2\) в таком случае должен быть NaN.
23.11 Случайный лес
Уточним, какие движки доступны для случайных лесов.
show_engines("rand_forest")
Создадим спецификацию модели. Деревья используются как в задачах классификации, так и в задачах регрессии, поэтому задействуем функцию set_mode()
.
<- rand_forest(trees = 1000) |>
rf_spec set_engine("ranger") |>
set_mode("regression")
<- workflow() |>
rf_wflow add_model(rf_spec) |>
add_recipe(books_rec)
rf_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: rand_forest()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Random Forest Model Specification (regression)
Main Arguments:
trees = 1000
Computational engine: ranger
Обучение займет чуть больше времени.
<- fit_resamples(
rf_rs
rf_wflow,
books_folds,control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations A: x1
There were issues with some computations A: x1 B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations A: x1 B: x1
There were issues with some computations A: x1 B: x1 C: x1
There were issues with some computations A: x1 B: x1 C: x1
Мы видим, что среднеквадратическая ошибка уменьшилась, а доля объясненной дисперсии выросла.
collect_metrics(rf_rs)
Тем не менее на графике можно заметить нечто странное: наша модель систематически переоценивает низкие значения и недооценивает высокие. Это связано с тем, что случайные леса не очень подходят для работы с разреженными данными (Hvitfeldt и Silge 2022).
|>
rf_rs collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))
23.12 Градиентные бустинговые деревья
Также попробуем построить регрессию с использованием градиентных бустинговых деревьев. Это один из алгоритмов ансамблевого машинного обучения, который строит последовательность простых моделей решающих деревьев, каждая из которых работает над ошибками предыдущей. В 2023 г. эта техника показала хорошие результаты в эксперименте по датировке греческих документальных папирусов.
<-
xgb_spec boost_tree(mtry = 50, trees = 1000) |>
set_engine("xgboost") %>%
set_mode("regression")
<- workflow() |>
xgb_wflow add_model(xgb_spec) |>
add_recipe(books_rec)
xgb_wflow
══ Workflow ════════════════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: boost_tree()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
Boosted Tree Model Specification (regression)
Main Arguments:
mtry = 50
trees = 1000
Computational engine: xgboost
Проводим перекрестную проверку.
<- fit_resamples(
xgb_rs
xgb_wflow,
books_folds,control = control_resamples(save_pred = TRUE)
)
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
There were issues with some computations A: x1
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations A: x1
There were issues with some computations A: x1 B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations A: x1 B: x1
There were issues with some computations A: x1 B: x1 C: x1
collect_metrics(xgb_rs)
Метрики неплохие! Но если взглянуть на остатки, можно увидеть что-то вроде буквы S.
|>
rf_rs collect_predictions() |>
ggplot(aes(price, .pred, color = id)) +
geom_jitter(alpha = 0.3) +
geom_abline(lty = 2, color = "grey80") +
theme_minimal() +
coord_cartesian(xlim = c(0, 50), ylim = c(0, 50))
23.13 Удаление стопслов
Изменим рецепт приготовления данных.
<- function(stopwords_name) {
stopwords_rec recipe(price ~ year + genre + name, data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name) |>
step_stopwords(name, stopword_source = stopwords_name) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name)
}
Создадим воркфлоу.
<- workflow() |>
svm_wflow add_model(svm_spec)
И снова проведем перекрестную проверку, на этот раз с разными списками стоп-слов. На этом шаге команда вернет предупреждения о том, что число слов меньше 1000, это нормально, т.к. после удаления стопслов токенов стало меньше.
set.seed(123)
<- fit_resamples(
snowball_rs |> add_recipe(stopwords_rec("snowball")),
svm_wflow
books_folds
)
set.seed(234)
<- fit_resamples(
smart_rs |> add_recipe(stopwords_rec("smart")),
svm_wflow
books_folds
)
set.seed(345)
<- fit_resamples(
stopwords_iso_rs |> add_recipe(stopwords_rec("stopwords-iso")),
svm_wflow
books_folds )
collect_metrics(smart_rs)
collect_metrics(snowball_rs)
collect_metrics((stopwords_iso_rs))
В нашем случае удаление стоп-слов положительного эффекта не имело.
<- tibble(name = c("snowball", "smart", "stopwords-iso")) %>%
word_counts mutate(words = map_int(name, ~length(stopwords::stopwords(source = .))))
list(snowball = snowball_rs,
smart = smart_rs,
`stopwords-iso` = stopwords_iso_rs) |>
map_dfr(show_best, metric = "rmse", .id = "name") |>
left_join(word_counts, by = "name") |>
mutate(name = paste0(name, " (", words, " words)"),
name = fct_reorder(name, words)) |>
ggplot(aes(name, mean, color = name)) +
geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err), alpha = 0.6) +
geom_point(size = 3, alpha = 0.8) +
theme(legend.position = "none") +
theme_minimal()
23.14 Настройки числа n-grams
<- function(ngram_options) {
ngram_rec recipe(price ~ year + genre + name, data = books_train) |>
step_dummy(genre) |>
step_normalize(year) |>
step_tokenize(name, token = "ngrams", options = ngram_options) |>
step_tokenfilter(name, max_tokens = 1000) |>
step_tfidf(name)
}
<- function(ngram_options) {
fit_ngram fit_resamples(
%>% add_recipe(ngram_rec(ngram_options)),
svm_wflow
books_folds
) }
set.seed(123)
<- fit_ngram(list(n = 1)) unigram_rs
→ A | warning: max_tokens was set to '1000', but only 997 was available and selected.
→ B | warning: max_tokens was set to '1000', but only 984 was available and selected.
There were issues with some computations A: x1 B: x1
→ C | warning: max_tokens was set to '1000', but only 998 was available and selected.
There were issues with some computations A: x1 B: x1
There were issues with some computations A: x1 B: x1 C: x1
set.seed(234)
<- fit_ngram(list(n = 2, n_min = 1))
bigram_rs
set.seed(345)
<- fit_ngram(list(n = 3, n_min = 1)) trigram_rs
collect_metrics(unigram_rs)
collect_metrics(bigram_rs)
collect_metrics(trigram_rs)
Таким образом, униграмы дают лучший результат:
list(`1` = unigram_rs,
`1 and 2` = bigram_rs,
`1, 2, and 3` = trigram_rs) |>
map_dfr(collect_metrics, .id = "name") |>
filter(.metric == "rmse") |>
ggplot(aes(name, mean, color = name)) +
geom_crossbar(aes(ymin = mean - std_err, ymax = mean + std_err),
alpha = 0.6) +
geom_point(size = 3, alpha = 0.8) +
theme(legend.position = "none") +
labs(
y = "RMSE"
+
) theme_minimal()
23.15 Лучшая модель и оценка
<- svm_wflow |>
svm_fit add_recipe(books_rec) |>
fit(data = books_test)
Warning: max_tokens was set to '1000', but only 612 was available and selected.
svm_fit
══ Workflow [trained] ══════════════════════════════════════════════════════════
Preprocessor: Recipe
Model: svm_linear()
── Preprocessor ────────────────────────────────────────────────────────────────
5 Recipe Steps
• step_dummy()
• step_normalize()
• step_tokenize()
• step_tokenfilter()
• step_tfidf()
── Model ───────────────────────────────────────────────────────────────────────
$TypeDetail
[1] "L2-regularized L2-loss support vector regression primal (L2R_L2LOSS_SVR)"
$Type
[1] 11
$W
year genre_Non.Fiction tfidf_name_1 tfidf_name_10 tfidf_name_100
[1,] -0.9520146 3.485356 1.1305 -1.187969 -1.996428
tfidf_name_150 tfidf_name_16 tfidf_name_17 tfidf_name_1936 tfidf_name_1984
[1,] -0.6271174 0.1926996 1.80019 -0.2370809 -0.6116231
tfidf_name_2 tfidf_name_2.0 tfidf_name_22 tfidf_name_3 tfidf_name_30
[1,] -0.4068736 0.8091512 -0.09894225 -1.04787 0.8119458
tfidf_name_4 tfidf_name_451 tfidf_name_49 tfidf_name_5 tfidf_name_5,000
[1,] 0.3118533 -0.0736 1.196741 0.5180948 0.3178177
tfidf_name_500 tfidf_name_6 tfidf_name_6th tfidf_name_7 tfidf_name_8
[1,] 0.2111397 -3.974732 8.957552 1.933749 -0.09894225
tfidf_name_a tfidf_name_about tfidf_name_absurd tfidf_name_activity
[1,] -1.241315 0.1435666 0.7912706 -1.04787
tfidf_name_adult tfidf_name_adults tfidf_name_advanced
[1,] -1.085287 -0.2930574 0.2111397
tfidf_name_adventures tfidf_name_afterlife tfidf_name_aftermath
[1,] 0.4109577 -0.7802049 0.01016221
tfidf_name_again tfidf_name_ages tfidf_name_agreements tfidf_name_alaska
[1,] 0.8241497 -0.9075971 -1.555872 -0.2463905
tfidf_name_almost tfidf_name_alphabet tfidf_name_am tfidf_name_america
[1,] -0.9554437 -0.9075971 -0.3789486 0.6591157
tfidf_name_american tfidf_name_americans tfidf_name_an tfidf_name_ancient
[1,] 5.041504 0.1500841 -0.5035489 0.7991669
tfidf_name_and tfidf_name_animals tfidf_name_answers tfidf_name_antiracist
[1,] -0.5396753 -1.670664 0.7912706 1.388923
tfidf_name_approach tfidf_name_are tfidf_name_art tfidf_name_as
[1,] 0.6142527 -0.5674495 -0.6325811 -0.4421042
tfidf_name_asians tfidf_name_assault tfidf_name_association
[1,] -0.3779175 -1.910361 8.957552
tfidf_name_astrophysics tfidf_name_at tfidf_name_autobiography
[1,] -1.042553 -0.2370809 -1.295159
tfidf_name_awesome tfidf_name_back tfidf_name_badass tfidf_name_battling
[1,] -0.4442324 -1.753965 -0.8324738 -0.2089105
tfidf_name_be tfidf_name_bear tfidf_name_beasts tfidf_name_beautiful
[1,] 0.6550246 -1.708765 0.8600794 -0.3789486
tfidf_name_become tfidf_name_bed tfidf_name_beginner's
[1,] 0.01735057 -0.9415579 -0.7647593
tfidf_name_beginners tfidf_name_believing tfidf_name_berlin tfidf_name_big
[1,] 0.2111397 0.01735057 -0.2370809 -0.7848711
tfidf_name_bill tfidf_name_blood tfidf_name_boat tfidf_name_body
[1,] -1.910361 -0.8841644 -0.1870963 -1.04787
tfidf_name_book tfidf_name_books tfidf_name_born tfidf_name_bossypants
[1,] -4.197888 -0.1782664 1.572082 -0.3790329
tfidf_name_boxed tfidf_name_boy tfidf_name_boys tfidf_name_brave
...
and 290 more lines.
Взглянем на остатки. Для этого пригодится уже знакомая функция augment()
из пакета broom
.
<- augment(svm_fit, new_data = books_test) |>
svm_res mutate(res = price - .pred) |>
select(price, .pred, res)
svm_res
library(gridExtra)
<- svm_res |>
g1 mutate(res = price - .pred) |>
ggplot(aes(res)) +
geom_histogram(fill = "steelblue", color = "white") +
theme_minimal()
<- svm_res |>
g2 ggplot(aes(price, .pred)) +
geom_jitter(color = "steelblue", alpha = 0.7) +
geom_abline(linetype = 2, color = "grey80", linewidth = 2) +
theme_minimal()
grid.arrange(g1, g2, nrow = 1)
Соберем метрики.
<- metric_set(rmse, rsq, mae)
books_metrics books_metrics(svm_res, truth = price, estimate = .pred)
Также посмотрим, какие слова больше всего связаны с увеличением и с уменьшением цены.
|>
svm_fit tidy() |>
filter(term != "year") |>
filter(!str_detect(term, "genre")) |>
mutate(sign = case_when(estimate > 0 ~ "дороже",
.default = "дешевле"),
estimate = abs(estimate),
term = str_remove_all(term, "tfidf_name_")) |>
group_by(sign) |>
top_n(20, estimate) |>
ungroup() |>
ggplot(aes(x = estimate, y = fct_reorder(term, estimate),
fill = sign)) +
geom_col(show.legend = FALSE) +
scale_x_continuous(expand = c(0,0)) +
facet_wrap(~sign, scales = "free") +
labs(y = NULL,
title = "Связь слов с ценой книг") +
theme_minimal()
Любопытно: судя по нашему датасету, конституция США раздается на Амазоне бесплатно.