Building a Simple Transformer using PyTorch [Code Included]
A code-walkthrough on how to code a transformer from scratch
Introduction
As we draw the curtain on our journey through the transformative world of Transformer models, we aim to demystify the complexities of these architectures by putting theory into practice. This week, we take a hands-on approach to consolidate our understanding, presenting a simplified demonstration of building a Transformer model using PyTorch. This post will not only solidify the concepts we've explored but also give a tangible sense of how these models are constructed and how they operate.
Our exploration over the past weeks has taken us from the foundational principles of self-attention and positional encodings to deep dives into landmark models like BERT, the GPT series, and T5. Each step has revealed the intricate mechanics and broad applications of Transformers, showcasing their unparalleled impact on the field of natural language processing and beyond.
In this demonstration, we focus specifically on the encoder side of the Transformer architecture, streamlining our model to highlight its key functionality without a decoder component. This keeps the code simple enough for educational purposes while still covering the elements that make the Transformer work:
A Single-Head Self-Attention Mechanism: We distill the concept of multi-head self-attention to its core, demonstrating the fundamental operation of self-attention in processing sequences.
A Simple Position-Wise Feed-Forward Network: Through a minimalist version of the feed-forward network, we illustrate how Transformers apply the same transformation to the representation at each position independently, adding non-linear processing after attention has mixed information across positions.
Skip Connections and Layer Normalization: These essential components are integrated to ensure the model's training stability and efficiency, showcasing their role in facilitating effective learning in deep architectures.
Simple Positional Encoding: We underscore the critical role of positional information in Transformer models by incorporating a straightforward method of positional encoding, ensuring our model acknowledges the order of elements in a sequence.
By focusing solely on the encoding side and omitting the decoder, our aim is to show the foundational aspects of Transformer models in a more digestible format. This simplified Transformer model provides a clear window into the inner workings of more complex architectures, laying the groundwork for understanding and innovation.
Self-Attention
This section walks through SimpleSelfAttention, a PyTorch module implementing a simplified self-attention mechanism: the part of our demo Transformer that relates the elements of a sequence to one another. For simplicity, the module uses a single attention head and includes:
Initialization Parameters: embed_size sets the size of the input vectors and of the query, key, and value vectors. The heads parameter is fixed to one here; the module stores it but does not use it, so it only marks where a multi-head extension would plug in.
Linear Transformations: The self.values, self.keys, and self.queries layers transform the inputs into value, key, and query matrices, and self.fc_out processes the attention output.
Forward Pass Steps:
Generating Matrices: The input is projected into query, key, and value matrices.
Calculating Attention Scores: The query and key matrices are multiplied to produce raw scores, which are divided by the square root of embed_size and then normalized with softmax, so that each query's attention weights sum to one.
Applying Attention: The attention weights are used to form a weighted sum of the value vectors, which is then passed through self.fc_out.
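Written as equations, the three steps compute scaled dot-product attention. Here $X$ is the input sequence (query, key, and value all receive the same $X$ in this post), $W_Q$, $W_K$, $W_V$ are the bias-free projections performed by self.queries, self.keys, and self.values, and $d$ is embed_size; these symbols are our notation, not names from the code:

$$
Q = XW_Q, \qquad K = XW_K, \qquad V = XW_V
$$

$$
\operatorname{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d}}\right) V
$$

The softmax is taken over each row, that is, over all keys for a given query, which corresponds to dim=-1 in the code; self.fc_out then applies one more linear layer to the result.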
Because the attention weights are computed from the input itself, self-attention can capture dependencies between any two positions in a sequence regardless of their distance, which makes it effective for tasks like natural language processing and sequence prediction.
```python
import torch
import torch.nn as nn


class SimpleSelfAttention(nn.Module):
    def __init__(self, embed_size, heads=1):
        super(SimpleSelfAttention, self).__init__()
        self.embed_size = embed_size
        self.heads = heads  # stored for a future multi-head version; unused below
        self.values = nn.Linear(embed_size, embed_size, bias=False)
        self.keys = nn.Linear(embed_size, embed_size, bias=False)
        self.queries = nn.Linear(embed_size, embed_size, bias=False)
        self.fc_out = nn.Linear(embed_size, embed_size)

    def forward(self, value, key, query):
        # Inputs have shape (batch, seq_len, embed_size)
        # Get Q, K, V matrices
        queries = self.queries(query)
        keys = self.keys(key)
        values = self.values(value)
        # Calculate the attention scores: (batch, query_len, key_len)
        energy = torch.bmm(queries, keys.transpose(1, 2))
        # Scale by sqrt(embed_size), then softmax over the key dimension
        attention = torch.softmax(
            energy / (self.embed_size ** (1 / 2)), dim=-1
        )
        # Get the weighted value vectors: (batch, query_len, embed_size)
        out = torch.bmm(attention, values)
        out = self.fc_out(out)
        return out
```
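As a sanity check on the math, the sketch below recomputes scaled dot-product attention by hand, the same way SimpleSelfAttention.forward does, and compares it with PyTorch's built-in torch.nn.functional.scaled_dot_product_attention (available in PyTorch 2.0 and later). The random q, k, and v tensors are stand-ins for the outputs of self.queries, self.keys, and self.values; the learned projections and fc_out are left out, and the sizes are arbitrary illustrative values.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

batch, seq_len, embed_size = 2, 3, 8
# Random stand-ins for the projected query, key, and value matrices
q = torch.randn(batch, seq_len, embed_size)
k = torch.randn(batch, seq_len, embed_size)
v = torch.randn(batch, seq_len, embed_size)

# Manual version, mirroring SimpleSelfAttention.forward
energy = torch.bmm(q, k.transpose(1, 2))
weights = torch.softmax(energy / math.sqrt(embed_size), dim=-1)
manual = torch.bmm(weights, v)

# Built-in version; its default scale is 1 / sqrt(q.size(-1))
builtin = F.scaled_dot_product_attention(q, k, v)

print(torch.allclose(manual, builtin, atol=1e-5))  # True
print(weights.sum(dim=-1))  # every row of attention weights sums to 1
```

In a larger model you would typically call the built-in function, which can dispatch to faster fused kernels; the manual version is kept in this post because it makes each step visible.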
SimpleTransformerBlock: The Core of Our Transformer Model
The SimpleTransformerBlock class encapsulates the essence of a Transformer block, streamlined for our demonstration. It combines self-attention with the other basic components of the Transformer architecture, namely skip connections, layer normalization, and a simple feed-forward network.
Components of SimpleTransformerBlock:
Self-Attention Mechanism: At the heart of the block is the SimpleSelfAttention module defined above. It lets each position weigh the importance of every other position in the input sequence, which is how the Transformer builds context.
Skip Connections and Normalization Layers: The block applies two layer normalization (LayerNorm) steps, each to the sum of a sublayer's output and that sublayer's input. The first (norm1) normalizes the self-attention output plus the query, and the second (norm2) normalizes the feed-forward output plus its input. The skip connections preserve the original signal and give gradients a direct path through the block, while the normalization keeps each position's activations on a consistent scale, which helps stabilize training.
Feed-Forward Network: A simple position-wise feed-forward network follows the attention step. It consists of two linear transformations with a ReLU activation in between and is applied to each position of the sequence independently, refining each representation after attention has mixed in context from other positions. The network expands the dimensionality from embed_size to embed_size * 4 and then projects it back to embed_size, giving the model extra non-linear capacity at each position.
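To see what "position-wise" means in practice, the sketch below builds a feed-forward network with the same layout as the one in SimpleTransformerBlock (embed_size = 8 and the input sizes are arbitrary illustrative values) and checks that applying it to a whole sequence gives the same result as applying it to each position separately.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8
# Same layout as the block's feed-forward network: expand 4x, ReLU, project back
feed_forward = nn.Sequential(
    nn.Linear(embed_size, embed_size * 4),
    nn.ReLU(),
    nn.Linear(embed_size * 4, embed_size),
)

x = torch.randn(2, 5, embed_size)  # (batch, seq_len, embed_size)

# nn.Linear acts on the last dimension, so the whole sequence is processed at once...
full = feed_forward(x)
# ...and the result matches running each position through the network on its own
per_position = torch.stack([feed_forward(x[:, t, :]) for t in range(x.size(1))], dim=1)

print(full.shape)  # torch.Size([2, 5, 8])
print(torch.allclose(full, per_position, atol=1e-5))  # True
```

No information flows between positions inside this network; mixing across positions happens only in the attention step.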
Forward Pass:
The forward pass of SimpleTransformerBlock takes three inputs, value, key, and query, which are processed by the self-attention mechanism. The attention output is added to query (a skip connection) and normalized, passed through the feed-forward network, and then added to the feed-forward input and normalized again. The result is a transformed representation of the input sequence that combines the contextual information gathered by attention with the per-position processing of the feed-forward network.
```python
class SimpleTransformerBlock(nn.Module):
    def __init__(self, embed_size):
        super(SimpleTransformerBlock, self).__init__()
        self.attention = SimpleSelfAttention(embed_size)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        # Position-wise feed-forward network: expand to 4x, ReLU, project back
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.ReLU(),
            nn.Linear(embed_size * 4, embed_size),
        )

    def forward(self, value, key, query):
        attention = self.attention(value, key, query)
        # Add skip connection, followed by LayerNorm
        x = self.norm1(attention + query)
        forward = self.feed_forward(x)
        # Add skip connection, followed by LayerNorm
        out = self.norm2(forward + x)
        return out
```
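The skip-connection-plus-LayerNorm step used twice in forward can be examined in isolation. This minimal sketch uses random tensors as stand-ins for a sublayer's input and output (embed_size = 8 is illustrative) and shows that, with LayerNorm's freshly initialized scale of 1 and shift of 0, every position comes out with mean close to 0 and variance close to 1, whatever the scale of the input.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

embed_size = 8
norm = nn.LayerNorm(embed_size)

x = torch.randn(2, 3, embed_size) * 5 + 3     # sublayer input with a large offset and scale
sublayer_out = torch.randn(2, 3, embed_size)  # stand-in for an attention output

# Residual step as in SimpleTransformerBlock: norm(sublayer(x) + x)
y = norm(sublayer_out + x)

# Statistics are computed over the embedding dimension of each position
print(y.mean(dim=-1))                  # all close to 0
print(y.var(dim=-1, unbiased=False))   # all close to 1
```

During training, LayerNorm's learned scale and shift can move these statistics away from 0 and 1 where that helps the model.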
Positional Encoding in Transformer Models
Positional encoding is a fundamental component of Transformer models: it adds information about the order of the sequence to the model's input. Self-attention on its own is order-agnostic (permuting the input positions simply permutes the outputs), so without positional information a Transformer could not tell "dog bites man" from "man bites dog". Positional encodings let it account for the order of input data such as words in a sentence or points in a time series.
Implementation Overview
The PositionalEncoding module generates an encoding for each position in a sequence up to a maximum length (max_len). It uses sine and cosine functions of different frequencies, sine on the even embedding dimensions and cosine on the odd ones, so that each position gets a distinct vector. This encoding is added to the input embeddings, allowing the model to distinguish identical tokens by where they occur, a critical feature for tasks involving syntax or temporal patterns.
Key Components
Encoding Matrix: Pre-computed at initialization, this matrix holds the positional encodings for all positions up to max_len. The wavelength of the sine and cosine functions varies across the dimensions of the embedding space, from 2π for the first pair of dimensions to almost 10000 · 2π for the last, allowing each position to have a unique encoding.
Forward Method: The forward pass adds the positional encoding to the input embeddings. The encoding is sliced to the input's sequence length, so any sequence of up to max_len tokens can be handled. The addition is element-wise, blending the original embeddings with positional information.
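For reference, this is the sinusoidal scheme introduced with the original Transformer ("Attention Is All You Need"). For position $pos$, dimension-pair index $k$, and embedding size $d$ (embed_size):

$$
PE_{(pos,\,2k)} = \sin\!\left(\frac{pos}{10000^{2k/d}}\right), \qquad
PE_{(pos,\,2k+1)} = \cos\!\left(\frac{pos}{10000^{2k/d}}\right)
$$

The denominator $10000^{2k/d}$ is what the code calls div_term, and the wavelength of pair $k$ is $2\pi \cdot 10000^{2k/d}$.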
Positional encodings are what enable Transformer models to perform order-dependent tasks effectively while still processing all positions in parallel. Through this elegant solution, Transformers keep their efficiency while gaining the ability to recognize and use the order of input sequences, a cornerstone of their success in various applications.
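```python
class PositionalEncoding(nn.Module):
    def __init__(self, embed_size, max_len=100):
        super(PositionalEncoding, self).__init__()
        # embed_size is assumed to be even (one sine and one cosine per pair)
        # Column of positions 0..max_len-1, shape (max_len, 1)
        position = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        # div_term = 10000^(i / embed_size) for each even dimension index i
        div_term = torch.pow(
            10000.0, torch.arange(0, embed_size, 2, dtype=torch.float32) / embed_size
        )
        encoding = torch.zeros(max_len, embed_size)
        encoding[:, 0::2] = torch.sin(position / div_term)  # even dimensions
        encoding[:, 1::2] = torch.cos(position / div_term)  # odd dimensions
        # Shape (1, max_len, embed_size) so it broadcasts over the batch.
        # A buffer follows model.to(device) but is not a trainable parameter.
        self.register_buffer("encoding", encoding.unsqueeze(0))

    def forward(self, x):
        # x: (batch, seq_len, embed_size); add encodings for the first seq_len positions
        return x + self.encoding[:, : x.size(1), :]
```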
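The formula can also be checked by hand. This pure-Python sketch uses a hypothetical positional_encoding helper, written only for illustration, that computes the encoding for a single position; it confirms that position 0 maps to alternating zeros and ones, that each sine/cosine pair has unit norm, and that different positions get different vectors.

```python
import math


def positional_encoding(pos, embed_size):
    """Hypothetical helper: sinusoidal encoding for one position (embed_size assumed even)."""
    pe = [0.0] * embed_size
    for i in range(0, embed_size, 2):
        div_term = 10000 ** (i / embed_size)
        pe[i] = math.sin(pos / div_term)      # even dimension: sine
        pe[i + 1] = math.cos(pos / div_term)  # odd dimension: cosine
    return pe


# Position 0 is always [0, 1, 0, 1, ...] because sin(0) = 0 and cos(0) = 1
print(positional_encoding(0, 4))  # [0.0, 1.0, 0.0, 1.0]

# Each (sin, cos) pair lies on the unit circle
pe3 = positional_encoding(3, 8)
print([round(pe3[i] ** 2 + pe3[i + 1] ** 2, 6) for i in range(0, 8, 2)])  # [1.0, 1.0, 1.0, 1.0]

# Different positions get different vectors
print(positional_encoding(1, 8) != positional_encoding(2, 8))  # True
```

Because every pair has unit norm, the encoding has the same overall magnitude at every position; only the angles of the pairs change with position.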
SimpleTransformer Class: Building a Transformer from Scratch
The SimpleTransformer class encapsulates the essence of a Transformer model in a compact, understandable format. This class is designed to demonstrate the fundamental operations within a Transformer, including embedding inputs, adding positional encodings, processing through a Transformer block, and producing output predictions.
Core Components of SimpleTransformer
Embedding Layer: The nn.Embedding layer maps each input token to a dense vector of size embed_size. This is how discrete tokens (integers in our simplified example) are turned into a form the model can process.
Positional Encoding: The PositionalEncoding module adds sequence-order information to the embeddings. Since self-attention does not inherently account for token order, positional encodings let the model recognize the position of each token in the sequence.
Transformer Block: The SimpleTransformerBlock is the heart of the model: its self-attention mechanism and feed-forward network transform the input based on both content and context.
Output Layer: Finally, a linear layer (nn.Linear) projects the Transformer block's output at each position to output_size scores. In our simple demonstration, the scores at the last position represent the prediction for the next number in the sequence.
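The embedding step is simpler than it may sound: nn.Embedding stores one learned row per token, and looking up a token returns that row. The sketch below uses the demo's vocabulary of 10 digits; embed_size = 4 is an illustrative value.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

vocab_size, embed_size = 10, 4  # 10 digit tokens; embed_size is illustrative
embed = nn.Embedding(vocab_size, embed_size)

tokens = torch.tensor([[1, 2, 3]])  # (batch=1, seq_len=3)
vectors = embed(tokens)             # (1, 3, embed_size)

print(vectors.shape)  # torch.Size([1, 3, 4])
# An embedding lookup is row indexing into the learned weight matrix
print(torch.equal(vectors[0, 0], embed.weight[1]))  # True
```

Since the rows are parameters, training adjusts the vector of each digit just like any other weight.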
Forward Pass Explained
The forward pass of the SimpleTransformer model follows these steps:
Embedding: Input tokens are passed through the embedding layer, transforming them into dense vectors that the model can process.
Positional Encoding Addition: The positional encodings are added to the embeddings, infusing the model with the ability to recognize the order of tokens within the sequence.
Processing Through Transformer Block: The embedded and positionally encoded inputs are then fed into the Transformer block. Here, through self-attention and subsequent processing, the model dynamically adjusts its focus on different parts of the input to generate a contextually enriched representation.
Generating Output: The Transformer block's output is passed through a final linear layer, which shapes the output to match the desired prediction task. In our case, this step aims to predict the next number in a sequence based on the context provided by the input sequence.
This concise implementation of a Transformer model in PyTorch illustrates the core principles behind more complex architectures like BERT and GPT. By understanding and experimenting with this simplified model, learners and practitioners can gain insights into the workings of Transformers and their applications in various tasks within the realm of artificial intelligence.
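```python
class SimpleTransformer(nn.Module):
    def __init__(self, embed_size, max_len, output_size):
        super(SimpleTransformer, self).__init__()
        # output_size doubles as the vocabulary size: inputs and outputs are both digits
        self.embed = nn.Embedding(output_size, embed_size)
        self.pos_encoder = PositionalEncoding(embed_size, max_len)
        self.transformer_block = SimpleTransformerBlock(embed_size)
        self.fc_out = nn.Linear(embed_size, output_size)

    def forward(self, x):
        # x: (batch, seq_len) tensor of token indices
        embedding = self.embed(x)  # (batch, seq_len, embed_size)
        # Add positional encoding; pos_encoder already returns embedding + encoding,
        # so assign the result rather than adding the embedding a second time
        embedding = self.pos_encoder(embedding)
        # Self-attention: query, key, and value all come from the same sequence
        transformer_out = self.transformer_block(
            embedding, embedding, embedding
        )
        out = self.fc_out(transformer_out)  # (batch, seq_len, output_size) logits
        return out
```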
Running the SimpleTransformer Model on Sample Input Data
In this section, we demonstrate how to run our SimpleTransformer model on a simple input sequence. The task is to predict the next digit in a sequence, using a vocabulary of 10 digits (0-9) and a sequence length of 3 for demonstration purposes.
Setting Up the Model
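First, we initialize the SimpleTransformer model with an embedding size, a maximum sequence length of 3, and an output size equal to our vocabulary size of 10 digits. The post does not state which embedding size it uses; 32 below is an illustrative value.
```python
embed_size, sequence_length, vocab_size = 32, 3, 10  # embed_size is illustrative
model = SimpleTransformer(embed_size, sequence_length, vocab_size)
```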
Preparing the Input
Our sample input is the sequence of 3 digits [1, 2, 3], and we want the model to predict the next digit, 4. To feed this sequence into the model, we convert it into a tensor and add a batch dimension, giving the (batch, seq_len) shape the model expects.
```python
sample_sequence = [1, 2, 3]
sample_tensor = torch.tensor(sample_sequence, dtype=torch.long).unsqueeze(0)  # shape (1, 3)
```
Running the Model
With the model and input prepared, we run the forward pass. The model returns logits of shape (1, 3, 10): one score for each of the 10 digits at each of the 3 positions. Because the next digit follows the last input token, we read the prediction from the last position.
```python
model.eval()  # Set the model to evaluation mode
with torch.no_grad():  # Disable gradient computation for inference
    predictions = model(sample_tensor)  # logits, shape (1, 3, 10)
    predicted_index = predictions.argmax(-1)  # highest-scoring digit at each position, shape (1, 3)
    # The last position's prediction is the model's guess for the next number
    predicted_number = predicted_index[0, -1].item()  # Convert to a Python int
print(f"Input Sequence: {sample_sequence}")
print(f"Predicted Next Number: {predicted_number}")
```
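Sample Output (after the model has been trained on examples of this pattern; with freshly initialized weights the predicted digit is essentially arbitrary):
```
Input Sequence: [1, 2, 3]
Predicted Next Number: 4
```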
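Because the model outputs raw logits, applying softmax over the last dimension turns each position's scores into probabilities, which shows how confident a prediction is. The logits in this sketch are made up for illustration rather than produced by the model.

```python
import torch

# Hypothetical logits for one sequence of 3 positions over 10 digits, shape (1, 3, 10)
logits = torch.zeros(1, 3, 10)
logits[0, -1, 4] = 3.0  # the last position strongly favors digit 4
logits[0, -1, 5] = 1.0

probs = torch.softmax(logits, dim=-1)  # probabilities over digits at each position
last = probs[0, -1]                    # distribution over the next digit

print(last.argmax().item())  # 4
print(round(last.sum().item(), 4))  # 1.0
top_probs, top_digits = last.topk(2)
print(top_digits.tolist())  # [4, 5]
```

Softmax preserves the ordering of the logits, so argmax over probabilities and argmax over logits pick the same digit.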
Interpretation
The model's prediction, the index of the highest logit at the last position, is the digit it expects to follow the input sequence [1, 2, 3]. For this simplified example we want it to predict 4, which would show that the model has learned the sequential pattern. Getting there requires training the model first, a step this walkthrough does not show.
Note that this example is oversimplified and needs training, hyperparameter tuning, and other adjustments to make it more robust.
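Here is a minimal training sketch. Every specific in it is an assumption rather than part of the original code: the toy dataset (every run of three consecutive digits, wrapping from 9 to 0, labeled with the digit that follows), the Adam optimizer, the learning rate, and the step count. The loss is computed only at the last position because SimpleTransformer has no causal mask, so earlier positions can attend to later inputs, and scoring them against the next token would let the model copy the answer. To keep the block self-contained it trains a small stand-in model (an embedding followed by a linear layer); passing SimpleTransformer(embed_size, 3, 10) to the same function should work, since it has the same input and output shapes.

```python
import torch
import torch.nn as nn


def make_next_digit_data(seq_len=3, vocab_size=10):
    """Hypothetical toy dataset: runs of consecutive digits (wrapping 9 to 0) and the digit after each."""
    inputs = [[(start + j) % vocab_size for j in range(seq_len)] for start in range(vocab_size)]
    targets = [(start + seq_len) % vocab_size for start in range(vocab_size)]
    return torch.tensor(inputs), torch.tensor(targets)


def train_next_digit(model, steps=300, lr=0.01):
    """Train any model mapping (batch, seq_len) tokens to (batch, seq_len, vocab) logits."""
    inputs, targets = make_next_digit_data()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    loss_fn = nn.CrossEntropyLoss()
    losses = []
    model.train()
    for _ in range(steps):
        optimizer.zero_grad()
        logits = model(inputs)                     # (10, 3, 10)
        loss = loss_fn(logits[:, -1, :], targets)  # score only the last position
        loss.backward()
        optimizer.step()
        losses.append(loss.item())
    return losses


torch.manual_seed(0)
# Stand-in model so the sketch runs on its own
stand_in = nn.Sequential(nn.Embedding(10, 16), nn.Linear(16, 10))
losses = train_next_digit(stand_in)
print(f"first loss {losses[0]:.3f}, last loss {losses[-1]:.3f}")

stand_in.eval()
with torch.no_grad():
    logits = stand_in(torch.tensor([[1, 2, 3]]))
print("Predicted next digit:", logits[0, -1].argmax().item())
```

This toy task only needs the last token to predict the next one, so it exercises the training loop rather than attention itself; a task where the answer depends on several earlier tokens would be a better test of the full model.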
For those interested in diving deeper into the code and trying out modifications or more complex examples, the complete code for this demonstration, including the SimpleTransformer model and sample input data handling, is available via the provided GitHub link below.
Conclusion
As we conclude our series on Transformer models, we've journeyed from the theoretical foundations to the cutting-edge applications that have reshaped the landscape of artificial intelligence. This final post, providing a hands-on demonstration of building a simple Transformer model in PyTorch, aims to solidify your understanding of these powerful architectures.
The simplified model we've discussed encapsulates the core components of Transformers, offering a practical insight into their operation. While this model is a basic representation, it serves as a stepping stone for diving deeper into the complexities and capabilities of more advanced Transformer architectures.
For those eager to explore the code and experiment with the model, we invite you to access the full implementation via the GitHub link provided. This resource is not just code; it's an invitation to tinker, adapt, and innovate on the foundations of Transformer models.
Thank you for joining us on this enlightening journey through the world of Transformers. We look forward to seeing how you'll leverage these insights and tools to push the boundaries of what's possible in AI.
Thank you for reading!