Predicting Book from Avatar: The Last Airbender scripts
This week’s data can be found here. It contains the script text for Avatar: The Last Airbender (the TV show). However, since this data comes straight out of the appa package, I may as well use that package directly!
# devtools::install_github("averyrobbins1/appa")
library(magrittr)  # attaches the %>% pipe used throughout this post
avatar <- appa::appa
colnames(avatar)
[1] "id" "book" "book_num"
[4] "chapter" "chapter_num" "character"
[7] "full_text" "character_words" "scene_description"
[10] "writer" "director" "imdb_rating"
For this blog post, I will use supervised machine learning for text analysis to attempt to predict, for each text snippet, which ‘Book’ it comes from. The textrecipes package is designed to fit neatly into a tidymodels workflow.
Avatar is divided into Books (i.e. seasons), depending on where the main characters travel. The unique values are:
unique(avatar$book)
[1] Water Earth Fire
Levels: Water Earth Fire
Since there are 3 Books, we will need to use multinomial classification (rather than binary classification, which is more common but only applies when there are exactly 2 classes).
To classify, we will use the full_text, character, and imdb_rating columns.
The text is incredibly clean: no spelling mistakes or odd punctuation. However, some character-spoken text comes with scene descriptions in square brackets. Since those descriptions are still part of the character’s contribution at that scene, I’ll just remove the brackets but leave the text.
avatar_all <- dplyr::transmute(avatar, book, full_text,
                               character = as.factor(character),
                               imdb_rating) %>%
  tidyr::drop_na() %>%
  dplyr::mutate(full_text = gsub('\\[|\\]', '', full_text))
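To make the bracket-stripping step concrete, here is a minimal sketch on a made-up line of dialogue (the text is invented for illustration, not taken from the dataset). It uses the same regular expression as the code above.

```r
# Hypothetical line, invented for illustration (not from the appa data)
example_text <- "[Smiling.] Let's go penguin sledding!"

# Same pattern as above: match a literal '[' or ']' and replace it with nothing
cleaned <- gsub('\\[|\\]', '', example_text)
cleaned
```

The words inside the brackets are kept; only the bracket characters themselves are dropped.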
Next, we split the cleaned data into a training dataset and a testing dataset. The training dataset we will split up further into k-fold cross-validation groups, which gives us out-of-fold error estimates before we ever touch the testing set.
avatar_split <- rsample::initial_split(avatar_all, strata = book)
avatar_train <- rsample::training(avatar_split)
avatar_test <- rsample::testing(avatar_split)
avatar_folds <- rsample::vfold_cv(avatar_train)
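With their defaults, initial_split() keeps 3/4 of the rows for training and vfold_cv() creates 10 folds. To show what stratifying on book is meant to achieve, here is a rough base-R sketch on toy data. It illustrates the idea only and is not how rsample implements it; the data frame, the seed, and the choice of 3 folds are all assumptions made for the example.

```r
set.seed(123)  # arbitrary seed so the sketch is reproducible

# Toy data with deliberately unbalanced classes (hypothetical, not the Avatar data)
toy <- data.frame(
  id   = 1:28,
  book = factor(rep(c("Water", "Earth", "Fire"), times = c(16, 8, 4)))
)

# Stratified split: sample 75% of the rows *within* each Book
train_ids <- unlist(lapply(
  split(toy$id, toy$book),
  function(ids) ids[sample.int(length(ids), size = 0.75 * length(ids))]
))
toy_train <- toy[toy$id %in% train_ids, ]
toy_test  <- toy[!toy$id %in% train_ids, ]

# Class proportions match in both parts
# (exactly here, because the toy class sizes divide evenly)
prop.table(table(toy_train$book))
prop.table(table(toy_test$book))

# k-fold assignment: shuffle fold labels over the training rows (3 folds here)
toy_train$fold <- sample(rep(1:3, length.out = nrow(toy_train)))
table(toy_train$fold)
```

Without stratification, a rare class could by chance end up under-represented in either part, which makes the held-out estimate noisier.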
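Now we need to set up a preprocessing ‘recipe’. Its parameters (such as the token vocabulary and the normalisation constants) are estimated from the training dataset only, and the same fitted steps can subsequently be applied to the testing data, so that no information from the testing data leaks into the preprocessing.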
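avatar_rec <-
  recipes::recipe(book ~ ., data = avatar_train) %>%
  recipes::step_dummy(character) %>%                                  # one-hot encode the speaker
  recipes::step_normalize(imdb_rating) %>%                            # centre and scale the rating
  textrecipes::step_tokenize(full_text) %>%                           # split text into word tokens
  textrecipes::step_stopwords(full_text) %>%                          # drop common stop words
  textrecipes::step_tokenfilter(full_text, max_tokens = 500) %>%      # keep the 500 most frequent tokens
  textrecipes::step_tfidf(full_text) %>%                              # weight tokens by tf-idf
  recipes::step_zv(recipes::all_predictors())                         # remove zero-variance columns
# Estimate the recipe on the training data (handy for inspecting the result)
rec_prep <- recipes::prep(avatar_rec)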
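The tf-idf step is the least self-explanatory part of the recipe, so here is a small base-R sketch of the textbook calculation on three invented snippets. This is an illustration under assumptions: textrecipes may use a different normalisation or smoothing, so the numbers will not necessarily match what step_tfidf() produces.

```r
# Three hypothetical snippets (invented, not from the scripts)
docs <- c(
  "avatar water tribe water",
  "avatar earth kingdom",
  "avatar fire nation fire"
)

# Tokenise on whitespace and build the vocabulary
tokens <- strsplit(tolower(docs), "\\s+")
vocab  <- sort(unique(unlist(tokens)))

# Document-term count matrix: one row per snippet, one column per word
counts <- t(vapply(
  tokens,
  function(tok) as.numeric(table(factor(tok, levels = vocab))),
  numeric(length(vocab))
))
colnames(counts) <- vocab

# Term frequency: share of a snippet's tokens taken up by each word
tf <- counts / rowSums(counts)

# Inverse document frequency: log(number of snippets / snippets containing the word)
idf <- log(length(docs) / colSums(counts > 0))

# tf-idf: multiply each column of tf by that word's idf
tfidf <- sweep(tf, 2, idf, `*`)
round(tfidf, 2)
```

A word that occurs in every snippet ("avatar" here) gets an idf of zero and therefore carries no weight, while words concentrated in one snippet are weighted up. That is the property that should help separate the Books.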
For the actual machine learning model, we will use a Support Vector Machine algorithm. This group of algorithms tends to be particularly strong at high-dimensional problems (like text).
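The liquidSVM engine that parsnip (part of tidymodels) offered for svm_rbf() at the time of writing automatically detects whether the problem is binary or multiclass, so there is no need to specify this up front. (liquidSVM has since been archived on CRAN and, as far as I know, newer parsnip releases no longer support this engine.)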
# Support vector machine
avatar_svm <- parsnip::svm_rbf() %>%
  parsnip::set_mode("classification") %>%
  parsnip::set_engine("liquidSVM")
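If the liquidSVM engine is not available in your parsnip installation, one possible substitute is the kernlab engine, which also handles multiclass problems. The sketch below is an alternative I did not use for the results in this post; it assumes the parsnip and kernlab packages are installed, and the object name avatar_svm_kernlab is made up for the example.

```r
# Sketch only: the same model specification with the kernlab engine.
# Assumes the parsnip and kernlab packages are installed.
avatar_svm_kernlab <- parsnip::set_engine(
  parsnip::set_mode(parsnip::svm_rbf(), "classification"),
  "kernlab"
)
avatar_svm_kernlab
```

Because the workflow only refers to the model specification, swapping the engine should not require changes to the recipe; the metrics would of course differ from the ones reported here.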
Tidymodels works on the basis of workflows: you put together the recipe (which specifies data pre-processing) and the model (which specifies the machine learning algorithm).
I also define a metric set with multiple metrics I’d like to calculate after fitting.
svm_wf <- workflows::workflow() %>%
  workflows::add_recipe(avatar_rec) %>%
  workflows::add_model(avatar_svm)
multi_met <- yardstick::metric_set(yardstick::accuracy,
                                   yardstick::precision,
                                   yardstick::recall,
                                   yardstick::spec)
With all the preparatory steps done, it’s time to fit the actual model! We fit it to each of the folds separately.
svm_rs <- tune::fit_resamples(
  svm_wf,
  avatar_folds,
  metrics = yardstick::metric_set(yardstick::accuracy),
  control = tune::control_resamples(save_pred = TRUE)
)
Because we fit this model using k-fold cross-validation, we have to use the collect_predictions function to get predictions back.
Then, we apply the metric set object we created before to get a neat tibble of model metrics.
svm_rs_predictions <- tune::collect_predictions(svm_rs)
metrics <- svm_rs_predictions %>%
  multi_met(truth = book, estimate = .pred_class)
metrics
# A tibble: 4 x 3
.metric .estimator .estimate
<chr> <chr> <dbl>
1 accuracy multiclass 0.650
2 precision macro 0.649
3 recall macro 0.649
4 spec macro 0.825
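The macro estimator in the output above means that each metric is computed per Book (that Book versus the rest) and then averaged without weighting. Here is a minimal base-R sketch of that calculation on a hypothetical confusion matrix; the counts are invented and are not my results.

```r
# Hypothetical confusion matrix: rows are predictions, columns are the truth
books <- c("Water", "Earth", "Fire")
cm <- matrix(
  c(8, 1, 1,
    2, 7, 1,
    0, 2, 8),
  nrow = 3, byrow = TRUE,
  dimnames = list(predicted = books, truth = books)
)

tp <- diag(cm)                                   # correctly predicted, per Book
tn <- sum(cm) - rowSums(cm) - colSums(cm) + tp   # neither predicted nor truly that Book

precision   <- tp / rowSums(cm)                  # TP / everything predicted as that Book
recall      <- tp / colSums(cm)                  # TP / everything truly that Book
specificity <- tn / (sum(cm) - colSums(cm))      # TN / everything truly another Book

# Accuracy is a single overall figure; the others are unweighted (macro) means
c(
  accuracy  = sum(tp) / sum(cm),
  precision = mean(precision),
  recall    = mean(recall),
  spec      = mean(specificity)
)
```

This also shows why specificity tends to look better than the other metrics in a three-class problem: for any one Book, most snippets belong to the other two, so there are many true negatives to get right.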
The metrics are… not great. But what types of errors are we making?
svm_rs_predictions %>%
  dplyr::group_by(.pred_class, book) %>%
  dplyr::summarise(count = dplyr::n(), .groups = 'drop') %>%
  ggplot2::ggplot(ggplot2::aes(y = .pred_class, x = book, fill = count)) +
  ggplot2::geom_tile(width = 0.95, height = 0.95) +
  ggplot2::geom_text(ggplot2::aes(label = count), colour = 'white') +
  ggplot2::scale_fill_gradient(low = 'firebrick', high = 'chartreuse4') +
  ggplot2::theme_minimal() +
  ggplot2::theme(legend.position = 'none') +
  ggplot2::labs(title = 'Avatar: The Last Airbender Book supervised ML for text',
                subtitle = 'Confusion matrix of true versus predicted Book (i.e. season); cross-validation',
                x = 'True Book',
                y = 'Predicted Book')
![Confusion matrix of true versus predicted Book, cross-validation]()
From the confusion matrix, the errors seem fairly evenly distributed. However, it appears to be easier to distinguish Fire from the other two Books (especially Water) than it is to distinguish Water from Earth. This makes sense from a story perspective.
Just to be sure, do these results hold up if we don’t use cross-validation within the training data, but instead fit the workflow to the whole training data and subsequently test on the held-out testing data?
First, we re-fit the algorithm on the whole training data.
svm_fit <- parsnip::fit(svm_wf, data = avatar_train)
Then, we gather predictions for the testing data, and create the same confusion matrix as before. Thankfully, the results look similar: not too shabby!
# Append the predicted class as a named column rather than by position
avatar_test <- dplyr::bind_cols(avatar_test,
                                predict(svm_fit, new_data = avatar_test))
avatar_test %>%
  dplyr::group_by(.pred_class, book) %>%
  dplyr::summarise(count = dplyr::n(), .groups = 'drop') %>%
  ggplot2::ggplot(ggplot2::aes(y = .pred_class, x = book, fill = count)) +
  ggplot2::geom_tile(width = 0.95, height = 0.95) +
  ggplot2::geom_text(ggplot2::aes(label = count), colour = 'white') +
  ggplot2::scale_fill_gradient(low = 'firebrick', high = 'chartreuse4') +
  ggplot2::theme_minimal() +
  ggplot2::theme(legend.position = 'none') +
  ggplot2::labs(title = 'Avatar: The Last Airbender Book supervised ML for text',
                subtitle = 'Confusion matrix of true versus predicted Book (i.e. season); testing data',
                x = 'True Book',
                y = 'Predicted Book')
![Confusion matrix of true versus predicted Book, testing data]()