Implementing Transformer Models in PyTorch: A Guided Walkthrough

June 5, 2024

In recent years, transformer models have revolutionized the field of natural language processing (NLP) and have found applications in various other domains such as computer vision and time series forecasting. Their ability to handle long-range dependencies and parallelize training has made them the go-to architecture for many state-of-the-art models.

Nowadays, transformers and their variants are everywhere, so let's dive in and understand the code from scratch. In this article, we will explore the implementation of transformer models in PyTorch, leveraging the excellent tutorial and GitHub repository by Umar Jamil.

We will follow along with Umar Jamil's comprehensive YouTube tutorial and reference his GitHub repository to understand the intricate details of transformer models. This article is designed for those who already have a solid foundation in machine learning and PyTorch and are looking to expand their knowledge by delving into advanced models.

While following these resources, I am also building my own GitHub repository to update the code to the latest library versions and incorporate improvements, taking advantage of my RTX 4070 Ti GPU for efficient training and experimentation.

Prerequisite Knowledge

  1. Python
  2. PyTorch
  3. Fundamentals of artificial neural networks

If you are unfamiliar with artificial neural networks, feel free to explore my blog post on the topic via the provided link: Artificial Neural Network. You can access my repository using the provided link: transformer

How to run this?

pip install -r requirements.txt

and run train_test.py

Transformers Architecture

Don't worry about the architecture diagram; I'll break everything down step by step and explain each component and block of the transformer model.

Understanding Transformer Models in Simple Terms

  1. Attention: Imagine you're reading a story. Some words are more important for understanding than others. Attention in a transformer is like giving more focus to the important words while reading.

  2. Encoder: Think of this as the part of the model that understands the story you're reading. It figures out the meaning of each word and how they relate to each other.

  3. Decoder: This part of the model uses what the encoder understood to create a new story, maybe in a different language. It's like translating or summarizing.

  4. Positional Encoding: Since the model doesn't naturally know the order of words, positional encoding is like giving each word a number to show its place in the story.

  5. Word Embedding: When the model reads words, it doesn't see them like we do. Instead, it turns each word into a special code called a vector. These vectors help the model understand the meaning of words based on how they're used in the story.

  6. Feed-Forward Neural Networks (FFNN): These are like super smart calculators. They help the model understand complex patterns in the story.

  7. Residual Connections: Imagine if you're building a tall tower with blocks. Sometimes, to make sure the tower doesn't fall, you add extra support. Residual connections are like that extra support for the model, helping it learn better.

  8. Layer Normalization: This is like making sure each part of the model is using the same scale or rules to understand the story. It keeps everything fair and balanced.

  9. Multi-Head Attention: Think of this as having multiple pairs of eyes reading the story at the same time, each looking for different important parts.

  10. Masking: Just like you wouldn't peek ahead in a book to spoil the ending, masking ensures the model focuses only on the parts of the story it has already "read." It's like covering up the pages ahead, so the model can't cheat by looking into the future while learning.

Deep Dive into Transformers with code

In this deep dive, we will explore the Transformer model, focusing on a practical use case: translating text from English to Italian. We will break the process down into several stages, starting with data preprocessing.

Data Preprocessing

Tokenization is important for transformer models in natural language processing because it breaks text down into smaller parts, called tokens, that the model can understand. This helps the model learn patterns in the data and handle different input lengths and unknown words, improving its ability to understand and generate human language.

To put it simply, tokenization assigns each word a specific index, like a unique key in a database. These indexed words are then passed into the transformer model, which uses them to perform complex mathematical calculations.

train_test.py

def get_or_build_tokenizer(config, ds, lang):
    # Construct the file path for the tokenizer using the language-specific format
    tokenizer_path = Path(config['tokenizer_file'].format(lang))

    # Check if the tokenizer file exists at the specified path
    if not Path.exists(tokenizer_path):
        # If the tokenizer file does not exist, create a new WordLevel tokenizer
        tokenizer = Tokenizer(WordLevel(unk_token="[UNK]"))

        # Set the pre-tokenizer to split text by whitespace
        tokenizer.pre_tokenizer = Whitespace()

        # Define a trainer for the tokenizer with special tokens and a minimum frequency threshold
        trainer = WordLevelTrainer(special_tokens=["[UNK]", "[PAD]", "[SOS]", "[EOS]"], min_frequency=2)

        # Train the tokenizer on sentences from the dataset for the specified language
        # 'get_all_sentences(ds, lang)' returns an iterator over all sentences in the given language
        tokenizer.train_from_iterator(get_all_sentences(ds, lang), trainer=trainer)

        # Save the trained tokenizer to the specified file path
        tokenizer.save(str(tokenizer_path))
    else:
        # If the tokenizer file exists, load the tokenizer from the file
        tokenizer = Tokenizer.from_file(str(tokenizer_path))

    # Return the tokenizer, either newly created or loaded from the file
    return tokenizer
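
As a quick sanity check, here is a small sketch (not part of the original repository) showing how the returned tokenizer maps words to indices; config and ds_raw are assumed to be the same objects used in get_ds further below, and the exact ids depend on your training data.

# Illustrative only: inspect what the word-level tokenizer produces
tokenizer = get_or_build_tokenizer(config, ds_raw, config['lang_src'])

encoding = tokenizer.encode("I love reading books")
print(encoding.tokens)                  # e.g. ['I', 'love', 'reading', 'books'] (unknown words become [UNK])
print(encoding.ids)                     # the integer index assigned to each token
print(tokenizer.token_to_id("[PAD]"))   # id reserved for the padding token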

torch.utils.data.Dataset is a powerful tool in PyTorch for creating custom datasets. By defining how to access and retrieve data, you can handle complex data pipelines and integrate seamlessly with PyTorch's data loading utilities. "Dataset" is an abstract class, meaning you don't use it directly. Instead, you create a subclass and implement specific methods to define how data should be accessed and manipulated. In the code provided below, we have implemented our own custom logic to include an additional token in our sentences.

Let's skip the causal_mask function for now; I'll cover what masking is in more detail later on in our walkthrough.

dataset.py

import torch
from torch.utils.data import Dataset

class BilingualDataset(Dataset):

    def __init__(self, ds, tokenizer_src, tokenizer_tgt, src_lang, tgt_lang, seq_len):
        super().__init__()
        self.seq_len = seq_len

        self.ds = ds
        self.tokenizer_src = tokenizer_src
        self.tokenizer_tgt = tokenizer_tgt
        self.src_lang = src_lang
        self.tgt_lang = tgt_lang

        self.sos_token = torch.tensor([tokenizer_tgt.token_to_id("[SOS]")], dtype=torch.int64)
        self.eos_token = torch.tensor([tokenizer_tgt.token_to_id("[EOS]")], dtype=torch.int64)
        self.pad_token = torch.tensor([tokenizer_tgt.token_to_id("[PAD]")], dtype=torch.int64)

    def __len__(self):
        return len(self.ds)

    def __getitem__(self, idx):
        src_target_pair = self.ds[idx]
        src_text = src_target_pair['translation'][self.src_lang]
        tgt_text = src_target_pair['translation'][self.tgt_lang]

        # Transform the text into tokens
        enc_input_tokens = self.tokenizer_src.encode(src_text).ids
        dec_input_tokens = self.tokenizer_tgt.encode(tgt_text).ids

        # Add sos, eos and padding to each sentence
        enc_num_padding_tokens = self.seq_len - len(enc_input_tokens) - 2  # We will add <s> and </s>
        # We only add <s> to the decoder input; </s> is added only to the label
        dec_num_padding_tokens = self.seq_len - len(dec_input_tokens) - 1

        # Make sure the number of padding tokens is not negative. If it is, the sentence is too long
        if enc_num_padding_tokens < 0 or dec_num_padding_tokens < 0:
            raise ValueError("Sentence is too long")

        # Add <s> and </s> token
        encoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(enc_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * enc_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only <s> token
        decoder_input = torch.cat(
            [
                self.sos_token,
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Add only </s> token
        label = torch.cat(
            [
                torch.tensor(dec_input_tokens, dtype=torch.int64),
                self.eos_token,
                torch.tensor([self.pad_token] * dec_num_padding_tokens, dtype=torch.int64),
            ],
            dim=0,
        )

        # Double check the size of the tensors to make sure they are all seq_len long
        assert encoder_input.size(0) == self.seq_len
        assert decoder_input.size(0) == self.seq_len
        assert label.size(0) == self.seq_len

        return {
            "encoder_input": encoder_input,  # (seq_len)
            "decoder_input": decoder_input,  # (seq_len)
            "encoder_mask": (encoder_input != self.pad_token).unsqueeze(0).unsqueeze(0).int(), # (1, 1, seq_len)
            "decoder_mask": (decoder_input != self.pad_token).unsqueeze(0).int() & causal_mask(decoder_input.size(0)), # (1, seq_len) & (1, seq_len, seq_len),
            "label": label,  # (seq_len)
            "src_text": src_text,
            "tgt_text": tgt_text,
        }
    
def causal_mask(size):
    mask = torch.triu(torch.ones((1, size, size)), diagonal=1).type(torch.int)
    return mask == 0


Train/validation split

train_test.py

def get_ds(config):
    # Load the raw dataset based on the datasource and language pair from the configuration
    # The dataset only has the train split, so we divide it ourselves into training and validation sets
    ds_raw = load_dataset(f"{config['datasource']}", f"{config['lang_src']}-{config['lang_tgt']}", split='train')

    # Build or retrieve tokenizers for both source and target languages using the raw dataset
    tokenizer_src = get_or_build_tokenizer(config, ds_raw, config['lang_src'])
    tokenizer_tgt = get_or_build_tokenizer(config, ds_raw, config['lang_tgt'])

    # Calculate the sizes for training and validation datasets (90% for training, 10% for validation)
    train_ds_size = int(0.9 * len(ds_raw))
    val_ds_size = len(ds_raw) - train_ds_size
    
    # Randomly split the raw dataset into training and validation sets based on the calculated sizes
    train_ds_raw, val_ds_raw = random_split(ds_raw, [train_ds_size, val_ds_size])

    train_ds = BilingualDataset(train_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'],
                                config['seq_len'])
    val_ds = BilingualDataset(val_ds_raw, tokenizer_src, tokenizer_tgt, config['lang_src'], config['lang_tgt'],
                              config['seq_len'])

    # Find the maximum length of each sentence in the source and target sentence
    #region to check max len; other than that, this block is not used anywhere else
    max_len_src = 0
    max_len_tgt = 0

    for item in ds_raw:
        src_ids = tokenizer_src.encode(item['translation'][config['lang_src']]).ids
        tgt_ids = tokenizer_tgt.encode(item['translation'][config['lang_tgt']]).ids
        max_len_src = max(max_len_src, len(src_ids))
        max_len_tgt = max(max_len_tgt, len(tgt_ids))

    print(f'Max length of source sentence: {max_len_src}')
    print(f'Max length of target sentence: {max_len_tgt}')
    #endregion

    train_dataloader = DataLoader(train_ds, batch_size=config['batch_size'], shuffle=True)
    val_dataloader = DataLoader(val_ds, batch_size=1, shuffle=True)

    return train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt
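
To verify the whole preprocessing pipeline, a short sketch like the one below (my own addition; config is assumed to come from get_config() as in the main block at the end) pulls one batch out of the dataloader and prints the tensor shapes produced by BilingualDataset.

# Illustrative only: inspect one batch produced by get_ds
train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)

batch = next(iter(train_dataloader))
print(batch['encoder_input'].shape)   # (batch_size, seq_len)
print(batch['decoder_input'].shape)   # (batch_size, seq_len)
print(batch['encoder_mask'].shape)    # (batch_size, 1, 1, seq_len)
print(batch['decoder_mask'].shape)    # (batch_size, 1, seq_len, seq_len)
print(batch['label'].shape)           # (batch_size, seq_len)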

Word Embedding

Word embeddings encode semantic information about words. Words with similar meanings are represented by vectors that are closer together in the embedding space. This allows the model to understand the relationships between words and capture semantic similarities.

Word embeddings transform words or tokens into dense numerical vectors in a continuous vector space. Each word in a vocabulary is represented by a unique vector, typically of fixed length.

Here we are using a dimension of 512.

model.py

class InputEmbeddings(nn.Module):

    def __init__(self, d_model: int, vocab_size: int) -> None:
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(vocab_size, d_model)

    def forward(self, x):
        # (batch, seq_len) --> (batch, seq_len, d_model)
        # Multiply by sqrt(d_model) to scale the embeddings according to the paper
        return self.embedding(x) * math.sqrt(self.d_model)
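
Here is a minimal, standalone usage sketch (the vocabulary size of 1000 is arbitrary and not from the repository) showing the shape transformation performed by InputEmbeddings:

import torch

# Illustrative only: token ids in, scaled d_model-dimensional vectors out
embed = InputEmbeddings(d_model=512, vocab_size=1000)
token_ids = torch.randint(0, 1000, (2, 10))   # (batch=2, seq_len=10)
vectors = embed(token_ids)
print(vectors.shape)                          # torch.Size([2, 10, 512])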

These embeddings are learned parameters, so the model's ability to capture semantic similarity between words improves as the transformer is trained.

Positional Encoding

Because transformers process input tokens in parallel rather than sequentially (unlike RNNs), they need a way to encode the position of each token in the sequence. Positional encoding provides a way to inject information about the position of each token within the sequence into the model. This helps the transformer differentiate between different positions and understand the order of the tokens.

Positional Encodings are added to the input embeddings at the bottom of the encoder and decoder stacks. There are several ways to implement positional encoding, but the most common method used in the original transformer paper "Attention is All You Need" employs sinusoidal functions of different frequencies.

Sine Component

\text{PE}_{(pos, 2i)} = \sin\left(\frac{pos}{10000^{2i/d_{model}}}\right)

Cosine Component

\text{PE}_{(pos, 2i+1)} = \cos\left(\frac{pos}{10000^{2i/d_{model}}}\right)

Key Reasons for Using Multiple Frequencies

  1. Encoding Different Scales of Relationships:

    • Different frequency components enable the model to capture both short-term and long-term dependencies.
    • Higher frequencies correspond to fine-grained (local) positional differences, whereas lower frequencies capture broader (global) positional patterns.
  2. Uniqueness of Position Representation:

    • Using a combination of sine and cosine functions at different frequencies ensures that each position has a unique encoding. This uniqueness helps the model distinguish between different positions effectively.
    • If we used a single frequency, the positional encodings might not be distinct enough to represent different positions clearly.
  3. Periodicity and Continuity:

    • The sine and cosine functions are periodic, which means they repeat their values in regular intervals. This periodicity can help the model understand cyclical patterns in the data.
    • The continuous nature of these functions allows for smooth interpolation between positions, which can be beneficial when dealing with sequences of varying lengths.

model.py

class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, seq_len: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model
        self.seq_len = seq_len
        self.dropout = nn.Dropout(dropout)
        # Create a matrix of shape (seq_len, d_model)
        pe = torch.zeros(seq_len, d_model)
        # Create a vector of shape (seq_len)
        position = torch.arange(0, seq_len, dtype=torch.float).unsqueeze(1) # (seq_len, 1)
        # Create a vector of shape (d_model)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) # (d_model / 2)
        # Apply sine to even indices
        pe[:, 0::2] = torch.sin(position * div_term) # sin(position / (10000 ** (2i / d_model)))
        # Apply cosine to odd indices
        pe[:, 1::2] = torch.cos(position * div_term) # cos(position / (10000 ** (2i / d_model)))
        # Add a batch dimension to the positional encoding
        pe = pe.unsqueeze(0) # (1, seq_len, d_model)
        # Register the positional encoding as a buffer
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + (self.pe[:, :x.shape[1], :]).requires_grad_(False) # (batch, seq_len, d_model)
        return self.dropout(x)
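
As a quick, hedged check (not from the repository), we can instantiate the layer with small, arbitrary sizes and look at the registered buffer; even-indexed columns hold the sine terms and odd-indexed columns the cosine terms, so position 0 alternates between 0 and 1.

import torch

# Illustrative only: inspect the positional encoding buffer
pos_enc = PositionalEncoding(d_model=512, seq_len=50, dropout=0.0)
print(pos_enc.pe.shape)       # torch.Size([1, 50, 512])
print(pos_enc.pe[0, 0, :4])   # tensor([0., 1., 0., 1.]) for position 0

x = torch.zeros(2, 50, 512)   # dummy batch of embeddings
out = pos_enc(x)              # positional information added element-wise
print(out.shape)              # torch.Size([2, 50, 512])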

Multi-Head Attention

"Attention is All You Need," revolutionized the field of natural language processing (NLP) by leveraging attention mechanisms to process sequential data. Central to the Transformer’s success is the multi-head attention mechanism, which allows the model to attend to different parts of a sequence simultaneously, capturing complex relationships and dependencies more effectively than previous architectures like recurrent neural networks (RNNs) or convolutional neural networks (CNNs). In this section, we will delve into the intricacies of the multi-head attention mechanism, explaining its components and functionality step by step.

Linear Projections

Queries (Q): Queries are representations of the input sequence that the model uses to ask questions about the importance of different tokens. For each token in the sequence, a query vector is generated to interact with the keys of all tokens. The query essentially determines "which parts of the input sequence should be attended to" when processing a particular token.

Keys (K): Keys are another set of vectors derived from the input sequence, representing the tokens in a way that makes them comparable with the queries. Each key vector can be thought of as an answer to a potential query. When a query from one token interacts with the keys of all tokens, it determines how much focus or attention each token should receive relative to the query token.

Values (V): Values are the representations of the input sequence that are used to produce the final output of the attention mechanism. While queries and keys determine the attention weights, values are the actual content that gets combined to form the output. The values are weighted by the attention scores to reflect the importance of each token as determined by the interaction of queries and keys.

For self-attention, the same input embeddings are used to generate queries (Q), keys (K), and values (V) through learned linear transformations. Each token’s embedding is projected into these three distinct spaces:

Q = XW_Q
K = XW_K
V = XW_V

Where W_Q, W_K, and W_V are learned weight matrices of dimensions d × d_k. Consequently, Q, K, and V are matrices of size n × d_k.

Scaled Dot-Product Attention

Each attention head computes the attention scores using the scaled dot-product attention mechanism. This involves three main steps:

Compute Dot Products: The dot products between the query vectors and key vectors are computed to get raw attention scores:

scores = QK^T

Scale the Scores: The raw attention scores are scaled by the square root of the dimension of the key vectors, d_k. This helps in stabilizing the gradients during training:

scaled_scores = \frac{scores}{\sqrt{d_k}}

Apply Softmax: The scaled scores are passed through a softmax function to obtain the attention weights. This normalizes the scores to a probability distribution:

attention_weights = softmax(scaled_scores)

Weighted Sum of Values: The attention weights are used to compute a weighted sum of the value vectors, producing the output for each attention head:

head_output = attention_weights ⋅ V
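
Before looking at the full PyTorch module, here is a compact toy sketch of these operations for a single head (the sizes and random tensors are illustrative, not from the repository):

import math
import torch

# Toy single-head scaled dot-product attention (illustrative sizes)
n, d_k = 4, 8                      # sequence length and key dimension
X = torch.randn(n, 16)             # token embeddings of dimension d = 16
W_q, W_k, W_v = (torch.randn(16, d_k) for _ in range(3))

Q, K, V = X @ W_q, X @ W_k, X @ W_v                 # linear projections: (n, d_k)
scores = Q @ K.T                                    # raw attention scores: (n, n)
scaled_scores = scores / math.sqrt(d_k)             # scale by sqrt(d_k)
attention_weights = scaled_scores.softmax(dim=-1)   # each row sums to 1
head_output = attention_weights @ V                 # weighted sum of values: (n, d_k)
print(head_output.shape)                            # torch.Size([4, 8])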

Multi-Head Attention

While a single attention mechanism is powerful, the Transformer employs multiple attention heads to capture various aspects of the relationships in the data. Each head operates independently, and their outputs are concatenated:

concat_output = Concat(head_1, head_2, ..., head_h)

Final Linear Projection

The concatenated output from all the heads is then linearly projected back to the original embedding dimension d using another learned weight matrix W_O:

multi_head_output = concat_output ⋅ W_O

model.py

class MultiHeadAttentionBlock(nn.Module):

    def __init__(self, d_model: int, h: int, dropout: float) -> None:
        super().__init__()
        self.d_model = d_model # Embedding vector size
        self.h = h # Number of heads
        # Make sure d_model is divisible by h
        assert d_model % h == 0, "d_model is not divisible by h"

        self.d_k = d_model // h # Dimension of vector seen by each head
        self.w_q = nn.Linear(d_model, d_model, bias=False) # Wq
        self.w_k = nn.Linear(d_model, d_model, bias=False) # Wk
        self.w_v = nn.Linear(d_model, d_model, bias=False) # Wv
        self.w_o = nn.Linear(d_model, d_model, bias=False) # Wo
        self.dropout = nn.Dropout(dropout)

    @staticmethod
    def attention(query, key, value, mask, dropout: nn.Dropout):
        d_k = query.shape[-1]
        # Just apply the formula from the paper
        # (batch, h, seq_len, d_k) --> (batch, h, seq_len, seq_len)
        attention_scores = (query @ key.transpose(-2, -1)) / math.sqrt(d_k)
        if mask is not None:
            # Write a very low value (indicating -inf) to the positions where mask == 0
            attention_scores.masked_fill_(mask == 0, -1e9)
        attention_scores = attention_scores.softmax(dim=-1) # (batch, h, seq_len, seq_len) # Apply softmax
        if dropout is not None:
            attention_scores = dropout(attention_scores)
        # (batch, h, seq_len, seq_len) --> (batch, h, seq_len, d_k)
        # return attention scores which can be used for visualization
        return (attention_scores @ value), attention_scores

    def forward(self, q, k, v, mask):
        query = self.w_q(q) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        key = self.w_k(k) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)
        value = self.w_v(v) # (batch, seq_len, d_model) --> (batch, seq_len, d_model)

        # (batch, seq_len, d_model) --> (batch, seq_len, h, d_k) --> (batch, h, seq_len, d_k)
        query = query.view(query.shape[0], query.shape[1], self.h, self.d_k).transpose(1, 2)
        key = key.view(key.shape[0], key.shape[1], self.h, self.d_k).transpose(1, 2)
        value = value.view(value.shape[0], value.shape[1], self.h, self.d_k).transpose(1, 2)

        # Calculate attention
        x, self.attention_scores = MultiHeadAttentionBlock.attention(query, key, value, mask, self.dropout)
        
        # Combine all the heads together
        # (batch, h, seq_len, d_k) --> (batch, seq_len, h, d_k) --> (batch, seq_len, d_model)
        x = x.transpose(1, 2).contiguous().view(x.shape[0], -1, self.h * self.d_k)

        # Multiply by Wo
        # (batch, seq_len, d_model) --> (batch, seq_len, d_model)  
        return self.w_o(x)
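
A quick shape check for the block (my own snippet, with arbitrary sizes): in self-attention the same tensor is passed as q, k, and v, and the (batch, seq_len, d_model) shape is preserved.

import torch

# Illustrative only: self-attention call with made-up sizes
mha = MultiHeadAttentionBlock(d_model=512, h=8, dropout=0.1)
x = torch.randn(2, 10, 512)          # (batch, seq_len, d_model)
out = mha(x, x, x, mask=None)        # self-attention: q = k = v = x
print(out.shape)                     # torch.Size([2, 10, 512])
print(mha.attention_scores.shape)    # torch.Size([2, 8, 10, 10])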

Masked Multi-Head Attention

While standard MHA is powerful, it lacks the ability to preserve the autoregressive property, essential for tasks where the model should not have access to future information during inference. This is where Masked Multi-Head Attention comes into play. By applying a mask to the attention scores, Masked MHA ensures that each token in the sequence attends only to preceding tokens, preventing it from peeking into the future.

The masking mechanism in Masked MHA involves modifying the attention scores to enforce the autoregressive constraint. Specifically, a mask matrix is applied to the attention scores, setting future positions to very large negative values (e.g., −∞). This effectively blocks attention to future tokens, ensuring that the model's predictions depend only on the tokens that precede them in the sequence.

Below is some sample code to illustrate how a mask works.


import torch

# Define the mask tensor
mask = torch.tensor([
    [1, 0, 0],
    [1, 1, 0],
    [1, 1, 1]
])

# Attention scores (example)
scores = torch.tensor([
    [0.8, 0.1, 0.2],
    [0.3, 0.9, 0.5],
    [0.2, 0.5, 0.7]
])

# Apply mask to attention scores
masked_scores = torch.where(mask == 1, scores, float('-inf'))

# Softmax function
softmax = torch.nn.Softmax(dim=-1)

# Apply softmax to the masked scores
attention_weights = softmax(masked_scores)

print("Mask Tensor (mask):")
print(mask)

print("\nAttention Scores:")
print(scores)

print("\nMasked Attention Scores:")
print(masked_scores)

print("\nAttention Weights (after softmax):")
print(attention_weights)

output

Mask Tensor (mask):
tensor([[1, 0, 0],
        [1, 1, 0],
        [1, 1, 1]])

Attention Scores:
tensor([[0.8000, 0.1000, 0.2000],
        [0.3000, 0.9000, 0.5000],
        [0.2000, 0.5000, 0.7000]])

Masked Attention Scores:
tensor([[0.8000,   -inf,   -inf],
        [0.3000, 0.9000,   -inf],
        [0.2000, 0.5000, 0.7000]])

Attention Weights (after softmax):
tensor([[1.0000, 0.0000, 0.0000],
        [0.3543, 0.6457, 0.0000],
        [0.2501, 0.3376, 0.4123]])

Residual Connections

Residual connections are designed to help mitigate the vanishing gradient problem, making it easier to train deep neural networks. In a transformer, residual connections add the input of a layer to its output before passing it to the next layer. Mathematically, for an input x and a layer operation F(x), the output with a residual connection is given by:

Output = F(x) + x

This simple addition ensures that the gradient can flow directly through the network, preventing it from becoming too small during backpropagation. This is crucial for deep models like transformers, which consist of many layers.
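
The claim about gradient flow can be checked directly. In the small sketch below (my own, not from the repository), the gradient of F(x) + x with respect to x picks up an identity term, so the signal reaching x cannot vanish even if the sublayer's own gradient is tiny.

import torch
import torch.nn as nn

# Illustrative only: gradient of a residual block y = F(x) + x
x = torch.randn(3, requires_grad=True)
F = nn.Linear(3, 3)

y = (F(x) + x).sum()
y.backward()

# The "+ 1" term comes from the identity (skip) path
print(x.grad)
print(F.weight.detach().sum(dim=0) + 1)   # matches x.grad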

Layer Normalization in Residual Connections

Layer normalization is a technique used to stabilize and accelerate the training of deep neural networks by normalizing the inputs to each layer. Unlike batch normalization, which normalizes across the batch dimension, layer normalization normalizes across the features for each training example.

Mathematically, for an input x with mean μ and variance σ², layer normalization is defined as:

\text{LayerNorm}(x) = \gamma\left(\frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}\right) + \beta

Here, γ and β are learnable parameters that scale and shift the normalized output, and ε is a small constant to prevent division by zero.

model.py


class LayerNormalization(nn.Module):

    def __init__(self, features: int, eps:float=10**-6) -> None:
        super().__init__()
        self.eps = eps
        self.alpha = nn.Parameter(torch.ones(features)) # alpha is a learnable parameter
        self.bias = nn.Parameter(torch.zeros(features)) # bias is a learnable parameter

    def forward(self, x):
        # x: (batch, seq_len, hidden_size)
        # Keep the dimension for broadcasting
        mean = x.mean(dim = -1, keepdim = True) # (batch, seq_len, 1)
        # Keep the dimension for broadcasting
        std = x.std(dim = -1, keepdim = True) # (batch, seq_len, 1)
        # eps is to prevent dividing by zero or when std is very small
        return self.alpha * (x - mean) / (std + self.eps) + self.bias


class ResidualConnection(nn.Module):
    
        def __init__(self, features: int, dropout: float) -> None:
            super().__init__()
            self.dropout = nn.Dropout(dropout)
            self.norm = LayerNormalization(features)
    
        def forward(self, x, sublayer):
            return x + self.dropout(sublayer(self.norm(x)))
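
One small design note: this implementation normalizes before the sublayer (x + dropout(sublayer(norm(x)))), a pre-norm arrangement, whereas the original paper applies the normalization after the addition. Below is an illustrative usage sketch with a stand-in sublayer (the sizes are arbitrary):

import torch
import torch.nn as nn

# Illustrative only: ResidualConnection with a stand-in sublayer
residual = ResidualConnection(features=512, dropout=0.1)
sublayer = nn.Linear(512, 512)    # placeholder for attention or feed-forward

x = torch.randn(2, 10, 512)       # (batch, seq_len, d_model)
out = residual(x, sublayer)       # x + dropout(sublayer(norm(x)))
print(out.shape)                  # torch.Size([2, 10, 512])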

Feed-Forward Networks

Feed-forward networks introduce non-linearities into the model. This is crucial because the self-attention mechanism alone is a linear operation, and without non-linearity, the model would not be able to capture complex patterns and interactions within the data. The non-linear activation functions in FFNs allow the model to learn more intricate representations of the input data.

\text{FFN}(x) = \text{ReLU}(xW_1 + b_1)W_2 + b_2

model.py

class FeedForwardBlock(nn.Module):

    def __init__(self, d_model: int, d_ff: int, dropout: float) -> None:
        super().__init__()
        self.linear_1 = nn.Linear(d_model, d_ff) # w1 and b1
        self.dropout = nn.Dropout(dropout)
        self.linear_2 = nn.Linear(d_ff, d_model) # w2 and b2

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, d_ff) --> (batch, seq_len, d_model)
        return self.linear_2(self.dropout(torch.relu(self.linear_1(x))))
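
A quick sketch (arbitrary batch and sequence sizes) of the expand-and-project pattern: each position is mapped from d_model = 512 up to d_ff = 2048 and back down, independently of the other positions.

import torch

# Illustrative only: position-wise feed-forward, 512 -> 2048 -> 512
ffn = FeedForwardBlock(d_model=512, d_ff=2048, dropout=0.1)
x = torch.randn(2, 10, 512)   # (batch, seq_len, d_model)
out = ffn(x)
print(out.shape)              # torch.Size([2, 10, 512])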

Encoder, Decoder and Transformers

Now, let us gather all the layers and construct the encoder, the decoder, and the full transformer.

model.py

class EncoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(2)])

    def forward(self, x, src_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, src_mask))
        x = self.residual_connections[1](x, self.feed_forward_block)
        return x
    
class Encoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, mask):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class DecoderBlock(nn.Module):

    def __init__(self, features: int, self_attention_block: MultiHeadAttentionBlock, cross_attention_block: MultiHeadAttentionBlock, feed_forward_block: FeedForwardBlock, dropout: float) -> None:
        super().__init__()
        self.self_attention_block = self_attention_block
        self.cross_attention_block = cross_attention_block
        self.feed_forward_block = feed_forward_block
        self.residual_connections = nn.ModuleList([ResidualConnection(features, dropout) for _ in range(3)])

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        x = self.residual_connections[0](x, lambda x: self.self_attention_block(x, x, x, tgt_mask))
        x = self.residual_connections[1](x, lambda x: self.cross_attention_block(x, encoder_output, encoder_output, src_mask))
        x = self.residual_connections[2](x, self.feed_forward_block)
        return x
    
class Decoder(nn.Module):

    def __init__(self, features: int, layers: nn.ModuleList) -> None:
        super().__init__()
        self.layers = layers
        self.norm = LayerNormalization(features)

    def forward(self, x, encoder_output, src_mask, tgt_mask):
        for layer in self.layers:
            x = layer(x, encoder_output, src_mask, tgt_mask)
        return self.norm(x)

class ProjectionLayer(nn.Module):

    def __init__(self, d_model, vocab_size) -> None:
        super().__init__()
        self.proj = nn.Linear(d_model, vocab_size)

    def forward(self, x):
        # (batch, seq_len, d_model) --> (batch, seq_len, vocab_size)
        return self.proj(x)
    
class Transformer(nn.Module):

    def __init__(self, encoder: Encoder, decoder: Decoder, src_embed: InputEmbeddings, tgt_embed: InputEmbeddings, src_pos: PositionalEncoding, tgt_pos: PositionalEncoding, projection_layer: ProjectionLayer) -> None:
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.src_pos = src_pos
        self.tgt_pos = tgt_pos
        self.projection_layer = projection_layer

    def encode(self, src, src_mask):
        # (batch, seq_len, d_model)
        src = self.src_embed(src)
        src = self.src_pos(src)
        return self.encoder(src, src_mask)
    
    def decode(self, encoder_output: torch.Tensor, src_mask: torch.Tensor, tgt: torch.Tensor, tgt_mask: torch.Tensor):
        # (batch, seq_len, d_model)
        tgt = self.tgt_embed(tgt)
        tgt = self.tgt_pos(tgt)
        return self.decoder(tgt, encoder_output, src_mask, tgt_mask)
    
    def project(self, x):
        # (batch, seq_len, vocab_size)
        return self.projection_layer(x)

Creating instances of every class

model.py

def build_transformer(src_vocab_size: int, tgt_vocab_size: int, src_seq_len: int, tgt_seq_len: int, d_model: int=512, N: int=6, h: int=8, dropout: float=0.1, d_ff: int=2048) -> Transformer:
    # Create the embedding layers
    src_embed = InputEmbeddings(d_model, src_vocab_size)
    tgt_embed = InputEmbeddings(d_model, tgt_vocab_size)

    # Create the positional encoding layers
    src_pos = PositionalEncoding(d_model, src_seq_len, dropout)
    tgt_pos = PositionalEncoding(d_model, tgt_seq_len, dropout)
    
    # Create the encoder blocks
    encoder_blocks = []
    for _ in range(N):
        encoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        encoder_block = EncoderBlock(d_model, encoder_self_attention_block, feed_forward_block, dropout)
        encoder_blocks.append(encoder_block)

    # Create the decoder blocks
    decoder_blocks = []
    for _ in range(N):
        decoder_self_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        decoder_cross_attention_block = MultiHeadAttentionBlock(d_model, h, dropout)
        feed_forward_block = FeedForwardBlock(d_model, d_ff, dropout)
        decoder_block = DecoderBlock(d_model, decoder_self_attention_block, decoder_cross_attention_block, feed_forward_block, dropout)
        decoder_blocks.append(decoder_block)
    
    # Create the encoder and decoder
    encoder = Encoder(d_model, nn.ModuleList(encoder_blocks))
    decoder = Decoder(d_model, nn.ModuleList(decoder_blocks))
    
    # Create the projection layer
    projection_layer = ProjectionLayer(d_model, tgt_vocab_size)
    
    # Create the transformer
    transformer = Transformer(encoder, decoder, src_embed, tgt_embed, src_pos, tgt_pos, projection_layer)
    
    # Initialize the parameters
    for p in transformer.parameters():
        if p.dim() > 1:
            nn.init.xavier_uniform_(p)
    
    return transformer
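
As a final sanity check (my own snippet with made-up vocabulary sizes; masks are skipped by passing None, which the attention code handles), we can build a model, count its parameters, and run a dummy forward pass:

import torch

# Illustrative only: build a model with made-up sizes and inspect it
model = build_transformer(src_vocab_size=15000, tgt_vocab_size=22000,
                          src_seq_len=350, tgt_seq_len=350)
num_params = sum(p.numel() for p in model.parameters())
print(f"Parameter count: {num_params:,}")   # depends mostly on the vocabulary sizes

src = torch.randint(0, 15000, (1, 350))   # dummy source token ids
tgt = torch.randint(0, 22000, (1, 350))   # dummy target token ids
enc_out = model.encode(src, src_mask=None)
dec_out = model.decode(enc_out, None, tgt, None)
logits = model.project(dec_out)
print(logits.shape)                       # torch.Size([1, 350, 22000])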

Training

train_test.py


def get_model(config, vocab_src_len, vocab_tgt_len):
    model = build_transformer(vocab_src_len, vocab_tgt_len, config["seq_len"], config['seq_len'], d_model=config['d_model'])
    return model

def train_model(config):
    # Define the device
    device = "cuda" if torch.cuda.is_available() else "mps" if torch.has_mps or torch.backends.mps.is_available() else "cpu"
    print("Using device:", device)
    if (device == 'cuda'):
        print(f"Device name: {torch.cuda.get_device_name(device.index)}")
        print(f"Device memory: {torch.cuda.get_device_properties(device.index).total_memory / 1024 ** 3} GB")
    elif (device == 'mps'):
        print(f"Device name: <mps>")
    else:
        print("NOTE: If you have a GPU, consider using it for training.")
        print("      On a Windows machine with NVidia GPU, check this video: https://www.youtube.com/watch?v=GMSjDTU8Zlc")
        print("      On a Mac machine, run: pip3 install --pre torch torchvision torchaudio torchtext --index-url https://download.pytorch.org/whl/nightly/cpu")
    device = torch.device(device)

    # Make sure the weights folder exists
    Path(f"{config['datasource']}_{config['model_folder']}").mkdir(parents=True, exist_ok=True)

    train_dataloader, val_dataloader, tokenizer_src, tokenizer_tgt = get_ds(config)
    model = get_model(config, tokenizer_src.get_vocab_size(), tokenizer_tgt.get_vocab_size()).to(device)
    # Tensorboard
    writer = SummaryWriter(config['experiment_name'])

    optimizer = torch.optim.Adam(model.parameters(), lr=config['lr'], eps=1e-9)

    # If the user specified a model to preload before training, load it
    initial_epoch = 0
    global_step = 0
    preload = config['preload']
    model_filename = latest_weights_file_path(config) if preload == 'latest' else get_weights_file_path(config, preload) if preload else None
    if model_filename:
        print(f'Preloading model {model_filename}')
        state = torch.load(model_filename)
        model.load_state_dict(state['model_state_dict'])
        initial_epoch = state['epoch'] + 1
        optimizer.load_state_dict(state['optimizer_state_dict'])
        global_step = state['global_step']
    else:
        print('No model to preload, starting from scratch')

    loss_fn = nn.CrossEntropyLoss(ignore_index=tokenizer_src.token_to_id('[PAD]'), label_smoothing=0.1).to(device)

    for epoch in range(initial_epoch, config['num_epochs']):
        torch.cuda.empty_cache()
        model.train()
        batch_iterator = tqdm(train_dataloader, desc=f"Processing Epoch {epoch:02d}")
        for batch in batch_iterator:

            encoder_input = batch['encoder_input'].to(device) # (b, seq_len)
            decoder_input = batch['decoder_input'].to(device) # (B, seq_len)
            encoder_mask = batch['encoder_mask'].to(device) # (B, 1, 1, seq_len)
            decoder_mask = batch['decoder_mask'].to(device) # (B, 1, seq_len, seq_len)

            # Run the tensors through the encoder, decoder and the projection layer
            encoder_output = model.encode(encoder_input, encoder_mask) # (B, seq_len, d_model)
            decoder_output = model.decode(encoder_output, encoder_mask, decoder_input, decoder_mask) # (B, seq_len, d_model)
            proj_output = model.project(decoder_output) # (B, seq_len, vocab_size)

            # Compare the output with the label
            label = batch['label'].to(device) # (B, seq_len)

            # Reshape the model output and labels to the required shapes for loss computation
            reshaped_output = proj_output.view(-1, tokenizer_tgt.get_vocab_size())
            reshaped_label = label.view(-1)

            # Compute the loss using a simple cross-entropy
            loss = loss_fn(reshaped_output, reshaped_label)

            batch_iterator.set_postfix({"loss": f"{loss.item():6.3f}"})

            # Log the loss
            writer.add_scalar('train loss', loss.item(), global_step)
            writer.flush()

            # Backpropagate the loss
            loss.backward()

            # Update the weights
            optimizer.step()
            optimizer.zero_grad(set_to_none=True)

            global_step += 1

        # Run validation at the end of every epoch
        run_validation(model, val_dataloader, tokenizer_src, tokenizer_tgt, config['seq_len'], device, lambda msg: batch_iterator.write(msg), global_step, writer)

        # Save the model at the end of every epoch
        model_filename = get_weights_file_path(config, f"{epoch:02d}")
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'global_step': global_step
        }, model_filename)

Testing

The function works token by token, building up the decoded sequence incrementally. Each iteration of the loop generates the next token in the sequence until the end-of-sequence token is reached or the maximum length constraint is met.

train_test.py

def greedy_decode(model, source, source_mask, tokenizer_src, tokenizer_tgt, max_len, device):
    sos_idx = tokenizer_tgt.token_to_id('[SOS]')
    eos_idx = tokenizer_tgt.token_to_id('[EOS]')

    # Precompute the encoder output and reuse it for every step
    encoder_output = model.encode(source, source_mask)
    # Initialize the decoder input with the sos token
    decoder_input = torch.empty(1, 1).fill_(sos_idx).type_as(source).to(device)
    while True:
        if decoder_input.size(1) == max_len:
            break

        # build mask for target
        decoder_mask = causal_mask(decoder_input.size(1)).type_as(source_mask).to(device)

        # calculate output
        out = model.decode(encoder_output, source_mask, decoder_input, decoder_mask)

        # get next token
        prob = model.project(out[:, -1])
        _, next_word = torch.max(prob, dim=1)
        decoder_input = torch.cat(
            [decoder_input, torch.empty(1, 1).type_as(source).fill_(next_word.item()).to(device)], dim=1
        )

        if next_word == eos_idx:
            break

    return decoder_input.squeeze(0)

train_test.py

def run_validation(model, validation_ds, tokenizer_src, tokenizer_tgt, max_len, device, print_msg, global_step, writer,
                   num_examples=2):
    model.eval()
    count = 0

    source_texts = []
    expected = []
    predicted = []

    try:
        # get the console window width
        with os.popen('stty size', 'r') as console:
            _, console_width = console.read().split()
            console_width = int(console_width)
    except:
        # If we can't get the console width, use 80 as default
        console_width = 80

    with torch.no_grad():
        for batch in validation_ds:
            count += 1
            encoder_input = batch["encoder_input"].to(device)  # (b, seq_len)
            encoder_mask = batch["encoder_mask"].to(device)  # (b, 1, 1, seq_len)

            # check that the batch size is 1
            assert encoder_input.size(
                0) == 1, "Batch size must be 1 for validation"

            model_out = greedy_decode(model, encoder_input, encoder_mask, tokenizer_src, tokenizer_tgt, max_len, device)

            source_text = batch["src_text"][0]
            target_text = batch["tgt_text"][0]
            model_out_text = tokenizer_tgt.decode(model_out.detach().cpu().numpy())

            source_texts.append(source_text)
            expected.append(target_text)
            predicted.append(model_out_text)

            # Print the source, target and model output
            print_msg('-' * console_width)
            print_msg(f"{f'SOURCE: ':>12}{source_text}")
            print_msg(f"{f'TARGET: ':>12}{target_text}")
            print_msg(f"{f'PREDICTED: ':>12}{model_out_text}")

            if count == num_examples:
                print_msg('-' * console_width)
                break

if __name__ == '__main__':
    #warnings.filterwarnings("ignore")
    config = get_config()
    train_model(config)

Note: I'm not familiar with Italian, but here's the result from our fourth epoch.

Using device: cuda  
Device name: NVIDIA GeForce RTX 4070 Ti  
Device memory: 11.99365234375 GB  
Max length of source sentence: 309  
Max length of target sentence: 274  
Preloading model opus_books_weights\tmodel_03.pt  
Processing Epoch 04: 100%|██████████| 3638/3638 [07:10<00:00,  8.45it/s, loss=5.541]  
'stty' is not recognized as an internal or external command,  
operable program or batch file.  
--------------------------------------------------------------------------------  
    SOURCE: And this I must observe, with grief, too, that the discomposure of my mind had great impression also upon the religious part of my thoughts; for the dread and terror of falling into the hands of savages and cannibals lay so upon my spirits, that I seldom found myself in a due temper for application to my Maker; at least, not with the sedate calmness and resignation of soul which I was wont to do: I rather prayed to God as under great affliction and pressure of mind, surrounded with danger, and in expectation every night of being murdered and devoured before morning; and I must testify, from my experience, that a temper of peace, thankfulness, love, and affection, is much the more proper frame for prayer than that of terror and discomposure: and that under the dread of mischief impending, a man is no more fit for a comforting performance of the duty of praying to God than he is for a repentance on a sick-bed; for these discomposures affect the mind, as the others do the body; and the discomposure of the mind must necessarily be as great a disability as that of the body, and much greater; praying to God being properly an act of the mind, not of the body.  
    TARGET: Io pregava Dio com’uomo oppresso dal peso di una grande afflizione e costernazione, com’uomo cinto di pericoli d’ogni intorno e che si aspettava ogni notte di essere ucciso, ogni mattina di essere divorato. Posso dire dietro l’esperimento fattone su me stesso, che una disposizione pacifica, grata, lieta, affettuosa è molto più propria alla preghiera che quella d’un animo scompigliato ed atterrito.  
 PREDICTED: E questo io , e la sua natura , che mi la mia vita di , e la mia vita era stata in mente il mio corpo , e la mia vita , e la mia vita , e la mia vita mi , e in quel momento in quel momento , come io non mi più di più , e mi in un ’ altra parte di , e mi in un altro luogo , che mi in un altro luogo , e , e in un ’ altra parte di , e , come mi , come un ’ altra parte di , e , e , come mi , e , e , e , come mi , e , e , come io non mi , e , come mi , come mi , come mi , come mi , come mi , e , come mi , come un altro , come mi , e , e , come mi , come mi , come io non mi , come mi , come mi , come mi , come mi , come mi , come mi , come mi , come mi , come io , come io , come mi , come io non mi , come io non mi , come mi , come mi , come io non mi , come io non mi , come mi , come mi , come io non mi , come io non mi , come mi , come io non mi , come io non mi , come io non mi , come io non mi , come mi , come mi , come  
--------------------------------------------------------------------------------  
    SOURCE: These two circumstances were as follows. From the fact that when he had met Karenin in the street the previous day the latter had treated him with cold stiffness, and had not called or even informed them of his arrival – from this, added to the rumour about Anna and Vronsky that had reached him, Oblonsky concluded that all was not as it should be between the husband and wife.  
    TARGET: Esse erano: la prima, che il giorno avanti, incontrato per via Aleksej Aleksandrovic, aveva notato ch’egli era stato asciutto e brusco con lui e, associando questa espressione del viso di Aleksej Aleksandrovic e il fatto che non era venuto da loro e non aveva fatto sapere nulla di sé con le voci che circolavano sul conto di Anna e Vronskij, Stepan Arkad’ic indovinò che c’era qualcosa che non andava tra marito e moglie.  
 PREDICTED: Questi due anni , come erano queste , quando , quando , quando Aleksej Aleksandrovic aveva visto , aveva visto il giorno prima , aveva sentito il giorno e Aleksej Aleksandrovic , non aveva visto nulla di più , e che Anna aveva fatto che Anna aveva fatto e che non aveva visto e che Anna aveva visto e che aveva fatto e che Anna non aveva visto e che Anna non aveva fatto e che Anna aveva fatto e che la principessa non aveva visto e che la principessa e che Anna non aveva fatto e che Anna .  
--------------------------------------------------------------------------------  
