The scope of this blog post is to show how to do binary text classification using standard tools such as the tidytext and caret packages. One of, if not the, most common binary text classification tasks is spam detection (spam vs non-spam), which happens in most email services, but the same setup has many other applications such as language identification (English vs non-English).
In this post I'll showcase 5 different classification methods to see how they compare on this data. The methods all land on the less complex side of the spectrum and thus do not include creating complex deep neural networks.
An expansion of this subject is multiclass text classification, which I might write about in the future.
Packages
We load the packages we need for this project: tidyverse for general data science work, tidytext for text manipulation, and caret for modeling.
library(tidyverse)
library(tidytext)
library(caret)
Data
The data we will be using for this demonstration will be some English social media disaster tweets discussed in this article. It consists of a number of tweets regarding accidents mixed in with a selection of control tweets (not about accidents). We start by loading in the data.
data <- read_csv("https://raw.githubusercontent.com/EmilHvitfeldt/blog/750dc28aa8d514e2c0b8b418ade584df8f4a8c92/data/socialmedia-disaster-tweets-DFE.csv")
And for this exercise we will only look at the body of the text. Furthermore, a handful of the tweets weren't classified, marked "Can't Decide", so we are removing those as well. Since we are working with tweet data we have the constraint that most of the tweets don't actually have that much information in them, as they are limited in characters and some only contain a couple of words.
We will at this stage remove what appears to be urls using some regex and str_replace_all, and we will select the columns id, disaster and text.
data_clean <- data %>%
  filter(choose_one != "Can't Decide") %>%
  mutate(id = `_unit_id`,
         disaster = choose_one == "Relevant",
         text = str_replace_all(text, " ?(f|ht)tp(s?)://(.*)[.][a-z]+", "")) %>%
  select(id, disaster, text)
First we take a quick look at the distribution of classes to see if they are balanced.
data_clean %>%
  ggplot(aes(disaster)) +
  geom_bar()
And we see that it is fairly balanced, so we don't have to worry about sampling this time.
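If you want the exact numbers behind the bar chart, a quick count works too (a small sketch using the same data_clean data):

data_clean %>%
  count(disaster)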
The representation we will be using in this post is the bag-of-words representation, in which we just count how many times each word appears in each tweet, disregarding grammar and even word order (mostly).
We will construct a tf-idf vector model in which each unique word is represented as a column and each document (tweet in our case) is a row of tf-idf values. This will create a very large matrix/data.frame (a column for each unique word in the total data set), which will overload a lot of the different models we could implement; furthermore, a lot of the words (or features, in ML slang) will not add considerable information. We have a trade-off between information and computational speed.
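To make the tf-idf weighting concrete, here is a minimal sketch on a made-up two-document corpus (the toy tibble below is purely for illustration and is not part of the tweet data):

# toy counts: 2 documents, a few words each; uses tidytext and the tidyverse loaded above
toy <- tibble(document = c(1, 1, 1, 2, 2),
              word     = c("fire", "in", "downtown", "in", "sunshine"),
              n        = c(2, 1, 1, 1, 3))

toy %>%
  bind_tf_idf(word, document, n)

Since "in" appears in both documents its idf is log(2/2) = 0, so its tf-idf is 0, while words unique to one document get a positive tf-idf.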
First we will remove all the stop words; this will ensure that common words that usually don't carry meaning don't take up space (and time) in our model. Next we will only look at words that appear in at least 10 different tweets. Lastly we will be looking at both unigrams and bigrams to hopefully get a better information extraction.
data_counts <- map_df(1:2,
                      ~ unnest_tokens(data_clean, word, text,
                                      token = "ngrams", n = .x)) %>%
  anti_join(stop_words, by = "word") %>%
  count(id, word, sort = TRUE)
We will only look at words that appear in at least 10 different tweets.
words_10 <- data_counts %>%
  group_by(word) %>%
  summarise(n = n()) %>%
  filter(n >= 10) %>%
  select(word)
We will right-join this to our data.frame before we calculate the tf-idf and cast it to a document-term matrix.
data_dtm <- data_counts %>%
  right_join(words_10, by = "word") %>%
  bind_tf_idf(word, id, n) %>%
  cast_dtm(id, word, tf_idf)
This leaves us with 2993 features. We create this meta data.frame which acts as an intermediate from our first data set, since some tweets might have disappeared completely after the reduction.
meta <- tibble(id = as.numeric(dimnames(data_dtm)[[1]])) %>%
  left_join(data_clean[!duplicated(data_clean$id), ], by = "id")
We also create the index (based on the meta data.frame) to separate the data into a training and test set.
set.seed(1234)
trainIndex <- createDataPartition(meta$disaster, p = 0.8, list = FALSE, times = 1)
Since a lot of the methods take data.frames as inputs, we will take the time and create these here:
data_df_train <- data_dtm[trainIndex, ] %>% as.matrix() %>% as.data.frame()
data_df_test <- data_dtm[-trainIndex, ] %>% as.matrix() %>% as.data.frame()

response_train <- meta$disaster[trainIndex]
Now each row in the data.frame is a document/tweet (yay tidy principles!!).
Missing tweets
In the feature selection earlier we decided to turn our focus towards certain words and word pairs, and with that we also turned our focus AWAY from certain words. Since the tweets are fairly short, it wouldn't be surprising if a handful of the tweets escaped our focus completely, as we noted earlier. Let's take a look at those tweets here.
data_clean %>%
  anti_join(meta, by = "id") %>%
  head(25) %>%
  pull(text)
We see that a lot of them appear to be parts of urls that our regex didn't detect; furthermore, it appears that in those tweets the sole text was the url, which wouldn't have helped us in this case anyway.
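As a quick check on that claim, we can look at how many of the dropped tweets still contain obvious url fragments (a small sketch; the "http|t\\.co" pattern is just an illustrative guess at what Twitter url leftovers look like):

data_clean %>%
  anti_join(meta, by = "id") %>%
  summarise(looks_like_url = mean(str_detect(text, "http|t\\.co")))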
Modeling
Now that we have the data all clean and tidy, we will turn our heads towards modeling. We will be using the wonderful caret package, which we will use to employ the following models:
- SVM with a linear kernel
- Naive-Bayes
- LogitBoost
- Random forest
- Neural network
These were chosen because of their frequent use (why SVMs are good at text classification) or because they are common in the classification field. They were also chosen because they were able to work with data with this number of variables in a reasonable time.
First time around we will not use a resampling method.
trctrl <- trainControl(method = "none")
SVM
The first model will be the svmLinearWeights2 model from the LiblineaR package, for which we specify the default parameters.
svm_mod <- train(x = data_df_train,
                 y = as.factor(response_train),
                 method = "svmLinearWeights2",
                 trControl = trctrl,
                 tuneGrid = data.frame(cost = 1,
                                       Loss = 0,
                                       weight = 1))
We predict on the test data set based on the fitted model.
svm_pred <- predict(svm_mod,
                    newdata = data_df_test)
Lastly we calculate the confusion matrix using the confusionMatrix function in the caret package.
svm_cm <- confusionMatrix(svm_pred, meta[-trainIndex, ]$disaster)
svm_cm
and we get an accuracy of 0.7461646.
Naive-Bayes
The second model will be the naive_bayes model from the naivebayes package, for which we specify the default parameters.
nb_mod <- train(x = data_df_train,
                y = as.factor(response_train),
                method = "naive_bayes",
                trControl = trctrl,
                tuneGrid = data.frame(laplace = 0,
                                      usekernel = FALSE,
                                      adjust = FALSE))
We predict on the test data set based on the fitted model.
nb_pred <- predict(nb_mod,
                   newdata = data_df_test)
Then we calculate the confusion matrix.
nb_cm <- confusionMatrix(nb_pred, meta[-trainIndex, ]$disaster)
nb_cm
and we get an accuracy of 0.5564854.
LogitBoost
The third model will be the LogitBoost model from the caTools package. We don't have to specify any parameters.
logitboost_mod <- train(x = data_df_train,
                        y = as.factor(response_train),
                        method = "LogitBoost",
                        trControl = trctrl)
We predict on the test data set based on the fitted model.
logitboost_pred <- predict(logitboost_mod,
                           newdata = data_df_test)
Then we calculate the confusion matrix.
logitboost_cm <- confusionMatrix(logitboost_pred, meta[-trainIndex, ]$disaster)
logitboost_cm
and we get an accuracy of 0.632729.
Random forest
The fourth model will be the ranger model from the ranger package, for which we specify the default parameters.
rf_mod <- train(x = data_df_train,
                y = as.factor(response_train),
                method = "ranger",
                trControl = trctrl,
                tuneGrid = data.frame(mtry = floor(sqrt(dim(data_df_train)[2])),
                                      splitrule = "gini",
                                      min.node.size = 1))
We predict on the test data set based on the fitted model.
rf_pred <- predict(rf_mod,
                   newdata = data_df_test)
Then we calculate the confusion matrix.
rf_cm <- confusionMatrix(rf_pred, meta[-trainIndex, ]$disaster)
rf_cm
and we get an accuracy of 0.7777778.
nnet
The fifth and final model will be the nnet model from the nnet package, for which we specify the default parameters. We will also specify MaxNWts = 5000 so that the model will fit: it needs to be larger than the number of weights in the network, which is roughly the number of columns multiplied by the hidden layer size.
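As a rough sanity check (a sketch, assuming the usual single-hidden-layer layout with one output unit and bias terms), the number of weights is approximately (p + 1) * size + (size + 1):

# p features feeding `size` hidden units plus biases, then one output unit
p    <- ncol(data_df_train)   # 2993 features
size <- 1
(p + 1) * size + (size + 1)   # 2996, comfortably below MaxNWts = 5000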
nnet_mod <- train(x = data_df_train,
                  y = as.factor(response_train),
                  method = "nnet",
                  trControl = trctrl,
                  tuneGrid = data.frame(size = 1,
                                        decay = 5e-4),
                  MaxNWts = 5000)
We predict on the test data set based on the fitted model.
nnet_pred <- predict(nnet_mod,
                     newdata = data_df_test)
Then we calculate the confusion matrix.
nnet_cm <- confusionMatrix(nnet_pred, meta[-trainIndex, ]$disaster)
nnet_cm
and we get an accuracy of 0.7173408.
Comparing models
To see how the different models stack up, we combine the metrics together in a data.frame.
mod_results <- rbind(
  svm_cm$overall,
  nb_cm$overall,
  logitboost_cm$overall,
  rf_cm$overall,
  nnet_cm$overall
  ) %>%
  as.data.frame() %>%
  mutate(model = c("SVM", "Naive-Bayes", "LogitBoost", "Random forest", "Neural network"))
We visualize the accuracy for the different models, with the red line being the "No Information Rate", that is, the accuracy of a model that always picks the most common class.
mod_results %>%
  ggplot(aes(model, Accuracy)) +
  geom_point() +
  ylim(0, 1) +
  geom_hline(yintercept = mod_results$AccuracyNull[1],
             color = "red")
As you can see, all but one approach does better than the "No Information Rate" on its first try, before tuning the hyperparameters.
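If you want to verify where that red line comes from, the "No Information Rate" is simply the proportion of the most common class in the test set (a quick sketch):

max(prop.table(table(meta[-trainIndex, ]$disaster)))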
Tuning hyperparameters
After trying out the different models we saw quite a spread in performance, but it is important to remember that the results might be due to good or bad default hyperparameters. There are a few different ways to handle this problem. I'll show one of them here, grid search, on the SVM model so you get the idea.
We will be using 3-fold cross-validation with 3 repeats, which will slow down the procedure but will help limit and reduce overfitting. We will be using a grid search approach to find optimal hyperparameters. For the sake of time I have fixed 2 of the hyperparameters and only let one vary. Remember that the time it takes to search through all combinations grows quickly as the number of hyperparameters increases.
fitControl <- trainControl(method = "repeatedcv",
                           number = 3,
                           repeats = 3,
                           search = "grid")
We have decided to limit the search to around the weight parameter's default value of 1.
svm_mod <- train(x = data_df_train,
                 y = as.factor(response_train),
                 method = "svmLinearWeights2",
                 trControl = fitControl,
                 tuneGrid = data.frame(cost = 0.01,
                                       Loss = 0,
                                       weight = seq(0.5, 1.5, 0.1)))
And once it has finished running, we can plot the train object to see which value gives the highest accuracy.
plot(svm_mod)
And we see that it appears to be just around 1. It is important to search multiple parameters at the SAME TIME, as it cannot be assumed that the parameters are independent of each other. The only reason I didn't do that here was to save time.
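For reference, a fuller search could look something like the sketch below, expanding the grid over both cost and weight (the particular values are only illustrative, and this will take considerably longer to run):

full_grid <- expand.grid(cost   = c(0.01, 0.1, 1, 10),
                         Loss   = 0,
                         weight = seq(0.5, 1.5, 0.25))

svm_mod_full <- train(x = data_df_train,
                      y = as.factor(response_train),
                      method = "svmLinearWeights2",
                      trControl = fitControl,
                      tuneGrid = full_grid)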
I will leave it to you, the reader, to find out which of the models has the highest accuracy after doing parameter tuning.
I hope you have enjoyed this overview of binary text classification.
session information
─ Session info ───────────────────────────────────────────────────────────────
 setting  value
 version  R version 4.1.0 (2021-05-18)
 os       macOS Big Sur 10.16
 system   x86_64, darwin17.0
 ui       X11
 language (EN)
 collate  en_US.UTF-8
 ctype    en_US.UTF-8
 tz       America/Los_Angeles
 date     2021-07-13

─ Packages ───────────────────────────────────────────────────────────────────
 package     * version date       lib source
 blogdown      1.3.2   2021-06-09 [1] Github (rstudio/blogdown@00a2090)
 bookdown      0.22    2021-04-22 [1] CRAN (R 4.1.0)
 bslib         0.2.5.1 2021-05-18 [1] CRAN (R 4.1.0)
 cli           3.0.0   2021-06-30 [1] CRAN (R 4.1.0)
 clipr         0.7.1   2020-10-08 [1] CRAN (R 4.1.0)
 crayon        1.4.1   2021-02-08 [1] CRAN (R 4.1.0)
 desc          1.3.0   2021-03-05 [1] CRAN (R 4.1.0)
 details     * 0.2.1   2020-01-12 [1] CRAN (R 4.1.0)
 digest        0.6.27  2020-10-24 [1] CRAN (R 4.1.0)
 evaluate      0.14    2019-05-28 [1] CRAN (R 4.1.0)
 htmltools     0.5.1.1 2021-01-22 [1] CRAN (R 4.1.0)
 httr          1.4.2   2020-07-20 [1] CRAN (R 4.1.0)
 jquerylib     0.1.4   2021-04-26 [1] CRAN (R 4.1.0)
 jsonlite      1.7.2   2020-12-09 [1] CRAN (R 4.1.0)
 knitr       * 1.33    2021-04-24 [1] CRAN (R 4.1.0)
 magrittr      2.0.1   2020-11-17 [1] CRAN (R 4.1.0)
 png           0.1-7   2013-12-03 [1] CRAN (R 4.1.0)
 R6            2.5.0   2020-10-28 [1] CRAN (R 4.1.0)
 rlang         0.4.11  2021-04-30 [1] CRAN (R 4.1.0)
 rmarkdown     2.9     2021-06-15 [1] CRAN (R 4.1.0)
 rprojroot     2.0.2   2020-11-15 [1] CRAN (R 4.1.0)
 sass          0.4.0   2021-05-12 [1] CRAN (R 4.1.0)
 sessioninfo   1.1.1   2018-11-05 [1] CRAN (R 4.1.0)
 stringi       1.6.2   2021-05-17 [1] CRAN (R 4.1.0)
 stringr       1.4.0   2019-02-10 [1] CRAN (R 4.1.0)
 withr         2.4.2   2021-04-18 [1] CRAN (R 4.1.0)
 xfun          0.24    2021-06-15 [1] CRAN (R 4.1.0)
 xml2          1.3.2   2020-04-23 [1] CRAN (R 4.1.0)
 yaml          2.2.1   2020-02-01 [1] CRAN (R 4.1.0)

[1] /Library/Frameworks/R.framework/Versions/4.1/Resources/library