
Text classification with the torchtext library

In this tutorial, we will show how to use the torchtext library to build the dataset for text classification analysis. Users will have the flexibility to

  • Access the raw data as an iterator
  • Build a data processing pipeline to convert the raw text strings into torch.Tensor that can be used to train the model
  • Shuffle and iterate the data with torch.utils.data.DataLoader

Access to the raw dataset iterators

The torchtext library provides a few raw dataset iterators, which yield the raw text strings. For example, the AG_NEWS dataset iterators yield the raw data as a tuple of label and text.

import torch
from torchtext.datasets import AG_NEWS
train_iter = AG_NEWS(split='train')
next(train_iter)
>>> (3, "Wall St. Bears Claw Back Into the Black (Reuters) Reuters -
Short-sellers, Wall Street's dwindling\\band of ultra-cynics, are seeing green
again.")

next(train_iter)
>>> (3, 'Carlyle Looks Toward Commercial Aerospace (Reuters) Reuters - Private
investment firm Carlyle Group,\\which has a reputation for making well-timed
and occasionally\\controversial plays in the defense industry, has quietly
placed\\its bets on another part of the market.')

next(train_iter)
>>> (3, "Oil and Economy Cloud Stocks' Outlook (Reuters) Reuters - Soaring
crude prices plus worries\\about the economy and the outlook for earnings are
expected to\\hang over the stock market next week during the depth of
the\\summer doldrums.")

Prepare data processing pipelines

We have revisited the very basic components of the torchtext library, including the vocab, word vectors, and tokenizer. These are the basic data processing building blocks for raw text strings.

Here is an example of typical NLP data processing with the tokenizer and vocabulary. The first step is to build a vocabulary with the raw training dataset. Users can customize the vocab by setting arguments in the constructor of the Vocab class, for example the minimum frequency min_freq for a token to be included.

from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

tokenizer = get_tokenizer('basic_english')
train_iter = AG_NEWS(split='train')
counter = Counter()
for (label, line) in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter, min_freq=1)

The vocabulary block converts a list of tokens into integers.

[vocab[token] for token in ['here', 'is', 'an', 'example']]
>>> [476, 22, 31, 5298]

Prepare the text processing pipeline with the tokenizer and vocabulary. The text and label pipelines will be used to process the raw data strings from the dataset iterators.

text_pipeline = lambda x: [vocab[token] for token in tokenizer(x)]
label_pipeline = lambda x: int(x) - 1

The text pipeline converts a text string into a list of integers based on the lookup table defined in the vocabulary. The label pipeline converts the label into integers. For example,

text_pipeline('here is the an example')
>>> [475, 21, 2, 30, 5286]
label_pipeline('10')
>>> 9

Generate data batch and iterator

torch.utils.data.DataLoader is recommended for PyTorch users (see the data loading tutorial). It works with a map-style dataset that implements the __getitem__() and __len__() protocols and represents a map from indices/keys to data samples. It also works with an iterable dataset when the shuffle argument is set to False.
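
As a minimal sketch (not part of the tutorial code, and using a hypothetical PairDataset name), a map-style dataset only needs to implement __getitem__() and __len__():

from torch.utils.data import Dataset

class PairDataset(Dataset):
    # Hypothetical map-style wrapper around a list of (label, text) tuples,
    # e.g. list(AG_NEWS(split='train')).
    def __init__(self, samples):
        self.samples = samples

    def __len__(self):
        return len(self.samples)          # number of samples

    def __getitem__(self, idx):
        return self.samples[idx]          # returns one (label, text) tuple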

Before sending data to the model, the collate_fn function works on a batch of samples generated from the DataLoader. The input to collate_fn is a batch of data with the batch size specified in the DataLoader, and collate_fn processes it according to the data processing pipelines declared previously. Pay attention here and make sure that collate_fn is declared as a top-level def. This ensures that the function is available in each worker.

In this example, the text entries in the original data batch input are packed into a list and concatenated as a single tensor as the input of nn.EmbeddingBag. The offsets is a tensor of delimiters that represents the beginning index of each individual sequence in the text tensor. label is a tensor saving the labels of the individual text entries.
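
For instance (a small numeric sketch, not part of the tutorial code), if a batch contains three texts with 4, 2 and 3 tokens, the cumulative sum of the running lengths yields the start indices [0, 4, 6], which is exactly what the offsets[:-1].cumsum(dim=0) line in collate_batch below computes:

import torch

lengths = [4, 2, 3]                                       # token counts of the three texts
offsets = torch.tensor([0] + lengths[:-1]).cumsum(dim=0)  # drop the last length, cumulative sum
print(offsets)                                            # tensor([0, 4, 6])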

from torch.utils.data import DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def collate_batch(batch):
    label_list, text_list, offsets = [], [], [0]
    for (_label, _text) in batch:
        label_list.append(label_pipeline(_label))
        processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)
        text_list.append(processed_text)
        offsets.append(processed_text.size(0))
    label_list = torch.tensor(label_list, dtype=torch.int64)
    offsets = torch.tensor(offsets[:-1]).cumsum(dim=0)
    text_list = torch.cat(text_list)
    return label_list.to(device), text_list.to(device), offsets.to(device)

train_iter = AG_NEWS(split='train')
dataloader = DataLoader(train_iter, batch_size=8, shuffle=False, collate_fn=collate_batch)

Define the model

The model is composed of the nn.EmbeddingBag layer plus a linear layer for the classification purpose. nn.EmbeddingBag with the default mode of “mean” computes the mean value of a “bag” of embeddings. Although the text entries here have different lengths, the nn.EmbeddingBag module requires no padding since the text lengths are saved in offsets.

Additionally, since nn.EmbeddingBag accumulates the average across the embeddings on the fly, it can enhance the performance and memory efficiency when processing a sequence of tensors.
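
As a quick sanity check (a sketch, not part of the original tutorial), nn.EmbeddingBag with mode='mean' and an offsets tensor gives the same result as averaging nn.Embedding vectors per sequence:

import torch
from torch import nn

bag = nn.EmbeddingBag(10, 4, mode='mean')        # toy vocab of 10, embedding dim 4
emb = nn.Embedding(10, 4)
emb.weight.data.copy_(bag.weight.data)           # share the same weights

text = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])    # two sequences packed into one tensor
offsets = torch.tensor([0, 4])                   # the sequences start at index 0 and 4

out = bag(text, offsets)                         # shape (2, 4)
manual = torch.stack([emb(text[:4]).mean(0), emb(text[4:]).mean(0)])
print(torch.allclose(out, manual))               # True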

[Figure: the text classification model, an nn.EmbeddingBag layer followed by a linear layer]
from torch import nn

class TextClassificationModel(nn.Module):

    def __init__(self, vocab_size, embed_dim, num_class):
        super(TextClassificationModel, self).__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        initrange = 0.5
        self.embedding.weight.data.uniform_(-initrange, initrange)
        self.fc.weight.data.uniform_(-initrange, initrange)
        self.fc.bias.data.zero_()

    def forward(self, text, offsets):
        embedded = self.embedding(text, offsets)
        return self.fc(embedded)
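
A small sanity check (a sketch with made-up sizes, not part of the tutorial) shows that two packed sequences produce one row of logits each:

sanity_model = TextClassificationModel(vocab_size=100, embed_dim=8, num_class=4)
dummy_text = torch.tensor([5, 7, 2, 9, 1])              # sequence 1 = [5, 7, 2], sequence 2 = [9, 1]
dummy_offsets = torch.tensor([0, 3])                    # start index of each sequence
print(sanity_model(dummy_text, dummy_offsets).shape)    # torch.Size([2, 4]), one row of logits per sequence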

Initiate an instance

The AG_NEWS dataset has four labels and therefore the number of classes is four.

1 : World
2 : Sports
3 : Business
4 : Sci/Tec

We build a model with an embedding dimension of 64. The vocab size is equal to the length of the vocabulary instance, and the number of classes is equal to the number of labels.

train_iter = AG_NEWS(split='train')
num_class = len(set([label for (label, text) in train_iter]))
vocab_size = len(vocab)
emsize = 64
model = TextClassificationModel(vocab_size, emsize, num_class).to(device)

Define functions to train the model and evaluate results.

import time

def train(dataloader):
    model.train()
    total_acc, total_count = 0, 0
    log_interval = 500
    start_time = time.time()

    for idx, (label, text, offsets) in enumerate(dataloader):
        optimizer.zero_grad()
        predicted_label = model(text, offsets)
        loss = criterion(predicted_label, label)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optimizer.step()
        total_acc += (predicted_label.argmax(1) == label).sum().item()
        total_count += label.size(0)
        if idx % log_interval == 0 and idx > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches '
                  '| accuracy {:8.3f}'.format(epoch, idx, len(dataloader),
                                              total_acc/total_count))
            total_acc, total_count = 0, 0
            start_time = time.time()

def evaluate(dataloader):
    model.eval()
    total_acc, total_count = 0, 0

    with torch.no_grad():
        for idx, (label, text, offsets) in enumerate(dataloader):
            predicted_label = model(text, offsets)
            loss = criterion(predicted_label, label)
            total_acc += (predicted_label.argmax(1) == label).sum().item()
            total_count += label.size(0)
    return total_acc/total_count

Split the dataset and run the model

Since the original AG_NEWS dataset has no validation set, we split the training dataset into train/valid sets with a split ratio of 0.95 (train) and 0.05 (valid). Here we use the torch.utils.data.dataset.random_split function from the PyTorch core library.

The CrossEntropyLoss criterion combines nn.LogSoftmax() and nn.NLLLoss() in a single class. It is useful when training a classification problem with C classes. SGD implements the stochastic gradient descent method as the optimizer. The initial learning rate is set to 5.0. StepLR is used here to adjust the learning rate through epochs.
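
As a short illustration (a sketch, not part of the tutorial code), the equivalence between CrossEntropyLoss and LogSoftmax followed by NLLLoss can be checked on random logits:

import torch

logits = torch.randn(3, 4)                      # 3 samples, 4 classes
targets = torch.tensor([0, 2, 1])
ce = torch.nn.CrossEntropyLoss()(logits, targets)
nll = torch.nn.NLLLoss()(torch.nn.LogSoftmax(dim=1)(logits), targets)
print(torch.allclose(ce, nll))                  # True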

from torch.utils.data.dataset import random_split
# Hyperparameters
EPOCHS = 10 # epoch
LR = 5  # learning rate
BATCH_SIZE = 64 # batch size for training

criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None
train_iter, test_iter = AG_NEWS()
train_dataset = list(train_iter)
test_dataset = list(test_iter)
num_train = int(len(train_dataset) * 0.95)
split_train_, split_valid_ = \
    random_split(train_dataset, [num_train, len(train_dataset) - num_train])

train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,
                              shuffle=True, collate_fn=collate_batch)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE,
                             shuffle=True, collate_fn=collate_batch)

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_dataloader)
    accu_val = evaluate(valid_dataloader)
    if total_accu is not None and total_accu > accu_val:
        scheduler.step()
    else:
        total_accu = accu_val
    print('-' * 59)
    print('| end of epoch {:3d} | time: {:5.2f}s | '
          'valid accuracy {:8.3f} '.format(epoch,
                                           time.time() - epoch_start_time,
                                           accu_val))
    print('-' * 59)

Out:

| epoch   1 |   500/ 1782 batches | accuracy    0.688
| epoch   1 |  1000/ 1782 batches | accuracy    0.855
| epoch   1 |  1500/ 1782 batches | accuracy    0.876
-----------------------------------------------------------
| end of epoch   1 | time:  9.68s | valid accuracy    0.888
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.899
| epoch   2 |  1000/ 1782 batches | accuracy    0.898
| epoch   2 |  1500/ 1782 batches | accuracy    0.905
-----------------------------------------------------------
| end of epoch   2 | time:  9.39s | valid accuracy    0.898
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.915
| epoch   3 |  1000/ 1782 batches | accuracy    0.914
| epoch   3 |  1500/ 1782 batches | accuracy    0.914
-----------------------------------------------------------
| end of epoch   3 | time:  9.82s | valid accuracy    0.907
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.922
| epoch   4 |  1000/ 1782 batches | accuracy    0.926
| epoch   4 |  1500/ 1782 batches | accuracy    0.923
-----------------------------------------------------------
| end of epoch   4 | time:  9.35s | valid accuracy    0.909
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.932
| epoch   5 |  1000/ 1782 batches | accuracy    0.928
| epoch   5 |  1500/ 1782 batches | accuracy    0.927
-----------------------------------------------------------
| end of epoch   5 | time:  9.42s | valid accuracy    0.904
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.941
| epoch   6 |  1000/ 1782 batches | accuracy    0.944
| epoch   6 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch   6 | time:  9.70s | valid accuracy    0.913
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.945
| epoch   7 |  1000/ 1782 batches | accuracy    0.944
| epoch   7 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch   7 | time:  9.71s | valid accuracy    0.914
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.945
| epoch   8 |  1000/ 1782 batches | accuracy    0.944
| epoch   8 |  1500/ 1782 batches | accuracy    0.944
-----------------------------------------------------------
| end of epoch   8 | time:  9.68s | valid accuracy    0.912
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.947
| epoch   9 |  1000/ 1782 batches | accuracy    0.946
| epoch   9 |  1500/ 1782 batches | accuracy    0.945
-----------------------------------------------------------
| end of epoch   9 | time:  9.94s | valid accuracy    0.914
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.946
| epoch  10 |  1000/ 1782 batches | accuracy    0.946
| epoch  10 |  1500/ 1782 batches | accuracy    0.949
-----------------------------------------------------------
| end of epoch  10 | time:  9.84s | valid accuracy    0.914
-----------------------------------------------------------

Running the model on a GPU gives the following printout:

| epoch   1 |   500/ 1782 batches | accuracy    0.684
| epoch   1 |  1000/ 1782 batches | accuracy    0.852
| epoch   1 |  1500/ 1782 batches | accuracy    0.877
-----------------------------------------------------------
| end of epoch   1 | time:  8.33s | valid accuracy    0.867
-----------------------------------------------------------
| epoch   2 |   500/ 1782 batches | accuracy    0.895
| epoch   2 |  1000/ 1782 batches | accuracy    0.900
| epoch   2 |  1500/ 1782 batches | accuracy    0.903
-----------------------------------------------------------
| end of epoch   2 | time:  8.18s | valid accuracy    0.890
-----------------------------------------------------------
| epoch   3 |   500/ 1782 batches | accuracy    0.914
| epoch   3 |  1000/ 1782 batches | accuracy    0.914
| epoch   3 |  1500/ 1782 batches | accuracy    0.916
-----------------------------------------------------------
| end of epoch   3 | time:  8.20s | valid accuracy    0.897
-----------------------------------------------------------
| epoch   4 |   500/ 1782 batches | accuracy    0.926
| epoch   4 |  1000/ 1782 batches | accuracy    0.924
| epoch   4 |  1500/ 1782 batches | accuracy    0.921
-----------------------------------------------------------
| end of epoch   4 | time:  8.18s | valid accuracy    0.895
-----------------------------------------------------------
| epoch   5 |   500/ 1782 batches | accuracy    0.938
| epoch   5 |  1000/ 1782 batches | accuracy    0.935
| epoch   5 |  1500/ 1782 batches | accuracy    0.937
-----------------------------------------------------------
| end of epoch   5 | time:  8.16s | valid accuracy    0.902
-----------------------------------------------------------
| epoch   6 |   500/ 1782 batches | accuracy    0.939
| epoch   6 |  1000/ 1782 batches | accuracy    0.939
| epoch   6 |  1500/ 1782 batches | accuracy    0.938
-----------------------------------------------------------
| end of epoch   6 | time:  8.16s | valid accuracy    0.906
-----------------------------------------------------------
| epoch   7 |   500/ 1782 batches | accuracy    0.941
| epoch   7 |  1000/ 1782 batches | accuracy    0.939
| epoch   7 |  1500/ 1782 batches | accuracy    0.939
-----------------------------------------------------------
| end of epoch   7 | time:  8.19s | valid accuracy    0.903
-----------------------------------------------------------
| epoch   8 |   500/ 1782 batches | accuracy    0.942
| epoch   8 |  1000/ 1782 batches | accuracy    0.941
| epoch   8 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch   8 | time:  8.16s | valid accuracy    0.904
-----------------------------------------------------------
| epoch   9 |   500/ 1782 batches | accuracy    0.942
| epoch   9 |  1000/ 1782 batches | accuracy    0.941
| epoch   9 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch   9 | time:  8.16s | valid accuracy    0.904
-----------------------------------------------------------
| epoch  10 |   500/ 1782 batches | accuracy    0.940
| epoch  10 |  1000/ 1782 batches | accuracy    0.942
| epoch  10 |  1500/ 1782 batches | accuracy    0.942
-----------------------------------------------------------
| end of epoch  10 | time:  8.15s | valid accuracy    0.904
-----------------------------------------------------------

Evaluate the model with test dataset

Checking the results of the test dataset…

print('Checking the results of test dataset.')
accu_test = evaluate(test_dataloader)
print('test accuracy {:8.3f}'.format(accu_test))

Out:

Checking the results of test dataset.
test accuracy    0.904

Test on a random news item

Use the best model so far and test it on a golf news story.

ag_news_label = {1: "World",
                 2: "Sports",
                 3: "Business",
                 4: "Sci/Tec"}

def predict(text, text_pipeline):
    with torch.no_grad():
        text = torch.tensor(text_pipeline(text))
        output = model(text, torch.tensor([0]))
        return output.argmax(1).item() + 1

ex_text_str = "MEMPHIS, Tenn. – Four days ago, Jon Rahm was \
    enduring the season’s worst weather conditions on Sunday at The \
    Open on his way to a closing 75 at Royal Portrush, which \
    considering the wind and the rain was a respectable showing. \
    Thursday’s first round at the WGC-FedEx St. Jude Invitational \
    was another story. With temperatures in the mid-80s and hardly any \
    wind, the Spaniard was 13 strokes better in a flawless round. \
    Thanks to his best putting performance on the PGA Tour, Rahm \
    finished with an 8-under 62 for a three-stroke lead, which \
    was even more impressive considering he’d never played the \
    front nine at TPC Southwind."

model = model.to("cpu")

print("This is a %s news" %ag_news_label[predict(ex_text_str, text_pipeline)])

Out:

This is a Sports news

Total running time of the script: ( 1 minutes 47.377 seconds)

