
Sequence-to-Sequence Modeling with nn.Transformer and TorchText

This is a tutorial on how to train a sequence-to-sequence model that uses the nn.Transformer module.

The PyTorch 1.2 release includes a standard transformer module based on the paper Attention Is All You Need. The transformer model has proven superior in quality for many sequence-to-sequence problems while being more parallelizable. The nn.Transformer module relies entirely on an attention mechanism (implemented as another module, nn.MultiheadAttention) to draw global dependencies between input and output. The nn.Transformer module is highly modularized, so that a single component (like the nn.TransformerEncoder used in this tutorial) can easily be adapted or composed.

[Figure: the Transformer model architecture]

Define the model

In this tutorial, we train an nn.TransformerEncoder model on a language modeling task. The language modeling task is to assign a probability for the likelihood of a given word (or a sequence of words) to follow a sequence of words. A sequence of tokens is passed to the embedding layer first, followed by a positional encoding layer to account for the order of the words (see the next section for more details). The nn.TransformerEncoder consists of multiple layers of nn.TransformerEncoderLayer. Along with the input sequence, a square attention mask is required because the self-attention layers in nn.TransformerEncoder are only allowed to attend to earlier positions in the sequence. For the language modeling task, any tokens in future positions should be masked. To produce a probability distribution over output words, the output of the nn.TransformerEncoder model is sent to a final Linear layer, which is followed by a log-softmax function.

import math
import torch
import torch.nn as nn
import torch.nn.functional as F

class TransformerModel(nn.Module):

    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer
        self.model_type = 'Transformer'
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)

        self.init_weights()

    def generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src, src_mask):
        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, src_mask)
        output = self.decoder(output)
        return output
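
As a quick sanity check (an illustrative snippet, not part of the original tutorial), the same mask logic can be run for a sequence length of 3. Positions a token is allowed to attend to are 0.0, and future positions are -inf:

# Illustrative only: build the additive attention mask for a tiny sequence length.
sz = 3
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])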

The PositionalEncoding module injects some information about the relative or absolute position of the tokens in the sequence. The positional encodings have the same dimension as the embeddings so that the two can be summed. Here, we use sine and cosine functions of different frequencies.
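
In the notation of Attention Is All You Need, for position pos and embedding dimension index i, the encoding implemented below is

\[\begin{split}PE_{(pos, 2i)} = \sin\left(pos / 10000^{2i/d_{\text{model}}}\right) \\ PE_{(pos, 2i+1)} = \cos\left(pos / 10000^{2i/d_{\text{model}}}\right)\end{split}\]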

class PositionalEncoding(nn.Module):

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)
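
A minimal shape check (illustrative, assuming a model dimension of 20 just for this example): the module expects input of shape (sequence length, batch size, embedding dimension) and returns a tensor of the same shape.

# Illustrative shape check: (S, N, E) in, (S, N, E) out.
pos_encoder = PositionalEncoding(d_model=20, dropout=0.1)
dummy = torch.zeros(35, 4, 20)
print(pos_encoder(dummy).shape)  # torch.Size([35, 4, 20])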

Load and batch data

This tutorial uses torchtext to generate the Wikitext-2 dataset. The vocab object is built from the training dataset and is used to numericalize tokens into tensors. Starting from sequential data, the batchify() function arranges the dataset into columns, trimming off any tokens that remain after the data has been divided into batches of size batch_size. For instance, with the alphabet as the sequence (total length of 26) and a batch size of 4, we would divide the alphabet into 4 sequences of length 6:

\[\begin{split}\begin{bmatrix} \text{A} & \text{B} & \text{C} & \ldots & \text{X} & \text{Y} & \text{Z} \end{bmatrix} \Rightarrow \begin{bmatrix} \begin{bmatrix}\text{A} \\ \text{B} \\ \text{C} \\ \text{D} \\ \text{E} \\ \text{F}\end{bmatrix} & \begin{bmatrix}\text{G} \\ \text{H} \\ \text{I} \\ \text{J} \\ \text{K} \\ \text{L}\end{bmatrix} & \begin{bmatrix}\text{M} \\ \text{N} \\ \text{O} \\ \text{P} \\ \text{Q} \\ \text{R}\end{bmatrix} & \begin{bmatrix}\text{S} \\ \text{T} \\ \text{U} \\ \text{V} \\ \text{W} \\ \text{X}\end{bmatrix} \end{bmatrix}\end{split}\]

These columns are treated as independent by the model, which means that the dependence of G on F cannot be learned, but it allows more efficient batch processing.

import io
import torch
from torchtext.datasets import WikiText2
from torchtext.data.utils import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

train_iter = WikiText2(split='train')
tokenizer = get_tokenizer('basic_english')
counter = Counter()
for line in train_iter:
    counter.update(tokenizer(line))
vocab = Vocab(counter)

def data_process(raw_text_iter):
    # Numericalize each line of raw text and concatenate, dropping empty lines.
    data = [torch.tensor([vocab[token] for token in tokenizer(item)],
                         dtype=torch.long) for item in raw_text_iter]
    return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))

train_iter, val_iter, test_iter = WikiText2()
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def batchify(data, bsz):
    # Divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

batch_size = 20
eval_batch_size = 10
train_data = batchify(train_data, batch_size)
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
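
To reproduce the alphabet example above (a hypothetical check in which the letters A..Z stand in for the token indices 0..25), batchify() can be called directly; on a CPU run this prints:

# Hypothetical check of the alphabet example: 26 tokens, batch size 4.
# The two leftover tokens (Y, Z) are trimmed, leaving 4 columns of length 6.
alphabet = torch.arange(26)
print(batchify(alphabet, 4))
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19],
#         [ 2,  8, 14, 20],
#         [ 3,  9, 15, 21],
#         [ 4, 10, 16, 22],
#         [ 5, 11, 17, 23]])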

Functions to generate input and target sequence

The get_batch() function generates the input and target sequences for the transformer model. It subdivides the source data into chunks of length bptt. For the language modeling task, the model needs the following words as the target. For example, with a bptt value of 2, we would get the following two tensors for i = 0:

[Figure: input and target sequences for bptt = 2 and i = 0]

It should be noted that the chunks are along dimension 0, consistent with the S dimension in the Transformer model. The batch dimension N is along dimension 1.

bptt = 35
def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].reshape(-1)
    return data, target
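
Reusing the alphabet tensor from the batchify example above (an illustration, not part of the original tutorial), the bptt = 2, i = 0 case from the figure can be reproduced by hand; the target is simply the input shifted by one row and flattened:

# Illustrative only: replicate the get_batch() slicing with a chunk length of 2.
source = batchify(torch.arange(26), 4)     # columns A..F, G..L, M..R, S..X as indices
i, chunk_len = 0, 2
data = source[i:i + chunk_len]             # rows 0-1: [A, G, M, S] and [B, H, N, T]
target = source[i + 1:i + 1 + chunk_len].reshape(-1)
print(data)
# tensor([[ 0,  6, 12, 18],
#         [ 1,  7, 13, 19]])
print(target)
# tensor([ 1,  7, 13, 19,  2,  8, 14, 20])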

Initiate an instance

The model is set up with the hyperparameters below. The vocab size is equal to the length of the vocab object.

ntokens = len(vocab.stoi) # the size of vocabulary
emsize = 200 # embedding dimension
nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder
nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
nhead = 2 # the number of heads in the multiheadattention models
dropout = 0.2 # the dropout value
model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
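
As a quick sanity check (an illustrative snippet, not in the original tutorial), you can count the trainable parameters of the freshly built model:

# Illustrative: count the trainable parameters of the model.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('The model has {:,} trainable parameters'.format(n_params))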

Run the model

CrossEntropyLoss is used as the loss function, and SGD implements the stochastic gradient descent method as the optimizer. The initial learning rate is set to 5.0, and StepLR is applied to adjust the learning rate through the epochs. During training, we use the nn.utils.clip_grad_norm_ function to scale all the gradients together to prevent them from exploding.

criterion = nn.CrossEntropyLoss()
lr = 5.0 # learning rate
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.95)

import time
def train():
    model.train() # Turn on the train mode
    total_loss = 0.
    start_time = time.time()
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        optimizer.zero_grad()
        if data.size(0) != bptt:
            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
        output = model(data, src_mask)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()

        total_loss += loss.item()
        log_interval = 200
        if batch % log_interval == 0 and batch > 0:
            cur_loss = total_loss / log_interval
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | '
                  'lr {:02.2f} | ms/batch {:5.2f} | '
                  'loss {:5.2f} | ppl {:8.2f}'.format(
                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],
                    elapsed * 1000 / log_interval,
                    cur_loss, math.exp(cur_loss)))
            total_loss = 0
            start_time = time.time()

def evaluate(eval_model, data_source):
    eval_model.eval() # Turn on the evaluation mode
    total_loss = 0.
    src_mask = model.generate_square_subsequent_mask(bptt).to(device)
    with torch.no_grad():
        for i in range(0, data_source.size(0) - 1, bptt):
            data, targets = get_batch(data_source, i)
            if data.size(0) != bptt:
                src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)
            output = eval_model(data, src_mask)
            output_flat = output.view(-1, ntokens)
            total_loss += len(data) * criterion(output_flat, targets).item()
    return total_loss / (len(data_source) - 1)

Loop over epochs. Save the model if the validation loss is the best we’ve seen so far. Adjust the learning rate after each epoch.

best_val_loss = float("inf")
epochs = 3 # The number of epochs
best_model = None

for epoch in range(1, epochs + 1):
    epoch_start_time = time.time()
    train()
    val_loss = evaluate(model, val_data)
    print('-' * 89)
    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '
          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),
                                     val_loss, math.exp(val_loss)))
    print('-' * 89)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_model = copy.deepcopy(model)  # keep a snapshot of the best weights so later epochs don't overwrite it

    scheduler.step()
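
If you also want to persist the best checkpoint to disk (an optional step that the original run does not perform; the file name below is only an example), torch.save can be used on the model's state_dict:

# Optional: save the best weights to disk (illustrative file name).
torch.save(best_model.state_dict(), 'transformer_wikitext2.pt')

# They can later be restored into a freshly constructed model:
# model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
# model.load_state_dict(torch.load('transformer_wikitext2.pt'))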

Out:

| epoch   1 |   200/ 2928 batches | lr 5.00 | ms/batch 11.39 | loss  8.19 | ppl  3595.76
| epoch   1 |   400/ 2928 batches | lr 5.00 | ms/batch 11.10 | loss  6.91 | ppl  1004.89
| epoch   1 |   600/ 2928 batches | lr 5.00 | ms/batch 11.18 | loss  6.45 | ppl   633.12
| epoch   1 |   800/ 2928 batches | lr 5.00 | ms/batch 11.19 | loss  6.31 | ppl   548.56
| epoch   1 |  1000/ 2928 batches | lr 5.00 | ms/batch 11.22 | loss  6.20 | ppl   490.80
| epoch   1 |  1200/ 2928 batches | lr 5.00 | ms/batch 11.12 | loss  6.16 | ppl   474.24
| epoch   1 |  1400/ 2928 batches | lr 5.00 | ms/batch 10.92 | loss  6.11 | ppl   452.07
| epoch   1 |  1600/ 2928 batches | lr 5.00 | ms/batch 10.91 | loss  6.11 | ppl   448.91
| epoch   1 |  1800/ 2928 batches | lr 5.00 | ms/batch 10.93 | loss  6.03 | ppl   413.67
| epoch   1 |  2000/ 2928 batches | lr 5.00 | ms/batch 10.92 | loss  6.02 | ppl   413.20
| epoch   1 |  2200/ 2928 batches | lr 5.00 | ms/batch 10.94 | loss  5.89 | ppl   362.73
| epoch   1 |  2400/ 2928 batches | lr 5.00 | ms/batch 10.91 | loss  5.97 | ppl   391.27
| epoch   1 |  2600/ 2928 batches | lr 5.00 | ms/batch 10.91 | loss  5.96 | ppl   386.19
| epoch   1 |  2800/ 2928 batches | lr 5.00 | ms/batch 10.92 | loss  5.88 | ppl   359.08
-----------------------------------------------------------------------------------------
| end of epoch   1 | time: 33.98s | valid loss  5.79 | valid ppl   326.28
-----------------------------------------------------------------------------------------
| epoch   2 |   200/ 2928 batches | lr 4.75 | ms/batch 10.99 | loss  5.87 | ppl   352.85
| epoch   2 |   400/ 2928 batches | lr 4.75 | ms/batch 10.92 | loss  5.86 | ppl   349.93
| epoch   2 |   600/ 2928 batches | lr 4.75 | ms/batch 10.93 | loss  5.68 | ppl   292.76
| epoch   2 |   800/ 2928 batches | lr 4.75 | ms/batch 10.93 | loss  5.70 | ppl   298.14
| epoch   2 |  1000/ 2928 batches | lr 4.75 | ms/batch 10.94 | loss  5.66 | ppl   286.44
| epoch   2 |  1200/ 2928 batches | lr 4.75 | ms/batch 10.95 | loss  5.69 | ppl   296.05
| epoch   2 |  1400/ 2928 batches | lr 4.75 | ms/batch 10.97 | loss  5.69 | ppl   295.17
| epoch   2 |  1600/ 2928 batches | lr 4.75 | ms/batch 11.08 | loss  5.71 | ppl   303.03
| epoch   2 |  1800/ 2928 batches | lr 4.75 | ms/batch 11.17 | loss  5.66 | ppl   286.32
| epoch   2 |  2000/ 2928 batches | lr 4.75 | ms/batch 11.11 | loss  5.68 | ppl   291.77
| epoch   2 |  2200/ 2928 batches | lr 4.75 | ms/batch 10.98 | loss  5.56 | ppl   258.88
| epoch   2 |  2400/ 2928 batches | lr 4.75 | ms/batch 10.98 | loss  5.65 | ppl   284.54
| epoch   2 |  2600/ 2928 batches | lr 4.75 | ms/batch 10.97 | loss  5.65 | ppl   284.97
| epoch   2 |  2800/ 2928 batches | lr 4.75 | ms/batch 10.99 | loss  5.59 | ppl   267.61
-----------------------------------------------------------------------------------------
| end of epoch   2 | time: 33.84s | valid loss  5.64 | valid ppl   280.90
-----------------------------------------------------------------------------------------
| epoch   3 |   200/ 2928 batches | lr 4.51 | ms/batch 11.07 | loss  5.61 | ppl   273.18
| epoch   3 |   400/ 2928 batches | lr 4.51 | ms/batch 11.09 | loss  5.63 | ppl   278.60
| epoch   3 |   600/ 2928 batches | lr 4.51 | ms/batch 11.09 | loss  5.43 | ppl   227.48
| epoch   3 |   800/ 2928 batches | lr 4.51 | ms/batch 10.96 | loss  5.48 | ppl   241.02
| epoch   3 |  1000/ 2928 batches | lr 4.51 | ms/batch 10.96 | loss  5.44 | ppl   230.42
| epoch   3 |  1200/ 2928 batches | lr 4.51 | ms/batch 10.99 | loss  5.48 | ppl   240.83
| epoch   3 |  1400/ 2928 batches | lr 4.51 | ms/batch 11.05 | loss  5.49 | ppl   242.41
| epoch   3 |  1600/ 2928 batches | lr 4.51 | ms/batch 11.07 | loss  5.52 | ppl   248.83
| epoch   3 |  1800/ 2928 batches | lr 4.51 | ms/batch 11.06 | loss  5.47 | ppl   236.81
| epoch   3 |  2000/ 2928 batches | lr 4.51 | ms/batch 11.07 | loss  5.49 | ppl   242.39
| epoch   3 |  2200/ 2928 batches | lr 4.51 | ms/batch 11.09 | loss  5.37 | ppl   214.86
| epoch   3 |  2400/ 2928 batches | lr 4.51 | ms/batch 11.13 | loss  5.48 | ppl   238.73
| epoch   3 |  2600/ 2928 batches | lr 4.51 | ms/batch 11.09 | loss  5.47 | ppl   238.32
| epoch   3 |  2800/ 2928 batches | lr 4.51 | ms/batch 11.11 | loss  5.41 | ppl   223.43
-----------------------------------------------------------------------------------------
| end of epoch   3 | time: 34.03s | valid loss  5.57 | valid ppl   262.42
-----------------------------------------------------------------------------------------

Evaluate the model with the test dataset

Apply the best model to check the result with the test dataset.

test_loss = evaluate(best_model, test_data)
print('=' * 89)
print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(
    test_loss, math.exp(test_loss)))
print('=' * 89)

Out:

=========================================================================================
| End of training | test loss  5.48 | test ppl   238.69
=========================================================================================

Total running time of the script: (1 minute 50.895 seconds)

