Building Attention Mechanisms from Scratch: A Deep Dive

Introduction: Why Attention Changed Everything

Before 2014, sequence-to-sequence models compressed entire input sequences into a single fixed-size vector, a bottleneck that could lose critical information on long inputs. Imagine summarizing an entire book into one sentence, then trying to translate it. The attention mechanism solved this by allowing models to dynamically focus on relevant parts of the input, revolutionizing NLP and beyond.

Why and How Self-Attention Changed Everything

The Fundamental Breakthrough: "Every Token Can Directly Talk to Every Other Token"

Before Self-Attention (The Sequential Prison)

Token1 → Token2 → Token3 → Token4 → Token5
         ↑ Must pass through intermediate states
  • Information bottleneck: Later tokens only know about earlier ones through a compressed hidden state
  • Gradient degradation: Learning long-range dependencies requires gradients to flow through many steps
  • Sequential computation: Can't process token 5 until you've processed tokens 1-4

After Self-Attention (The Democratic Network)

Token1 ←→ Token2 ←→ Token3 ←→ Token4 ←→ Token5
  ↑         ↑         ↑         ↑         ↑
  └─────────┴─────────┴─────────┴─────────┘
     Every token directly connects to every other
  • Direct information highway: Any token can directly access any other token's information
  • Gradient superhighway: Gradients flow directly between any pair of tokens
  • Parallel computation: All connections computed simultaneously

The Three Revolutionary Properties

1. "Constant Path Length" - O(1) operations between any two positions

# RNN/LSTM: O(n) operations to connect distant tokens
# h_1 = f(h_0, x_1) → h_2 = f(h_1, x_2) → ... → h_n = f(h_{n-1}, x_n)

# Self-Attention: O(1) operations - direct connection
attention_score = query_token_n @ key_token_1  # Direct computation!

2. "Content-Based Routing" - The model learns WHAT to look at, not WHERE

# Traditional: Fixed information flow
next_hidden = f(current_hidden, current_input)  # Always uses previous state

# Self-Attention: Dynamic information flow
relevance = softmax(query @ all_keys.T)  # Model decides what's relevant
output = relevance @ all_values         # Retrieves relevant information

3. "Parallelizable by Design" - No sequential dependencies

# RNN: Must compute sequentially (can't parallelize)
for t in range(seq_len):
    h[t] = f(h[t-1], x[t])  # Depends on previous step

# Self-Attention: Fully parallel (matrix multiplication)
ALL_OUTPUTS = softmax(Q @ K.T / math.sqrt(d)) @ V  # A few matrix multiplications cover every position at once

The Killer Equation That Changed NLP

Output = Softmax(QK^T/√d)V

Why this is revolutionary:

  • Q (Query): "What am I looking for?"
  • K (Key): "What information do I contain?"
  • V (Value): "Here's my actual information"
  • Result: Every position dynamically retrieves information from ALL other positions based on learned relevance

The Cascading Effects

Immediate Impact (2017-2018)

  • BERT: Deep bidirectional pretraining at scale (in every layer, each token attends to context on both sides)
  • GPT: Scaling became feasible (parallel training on massive data)

Secondary Revolution (2019-2021)

  • Vision Transformers: "Wait, images are just sequences of patches!"
  • Protein Folding: "Amino acids are just sequences!"
  • Music Generation: "Notes are just sequences!"

The Current Era (2022-2024)

  • Multimodal Models: Different modalities unified through attention
  • Scaling Laws: Transformer performance improves predictably as compute, data, and model size grow
  • Emergent Abilities: Large-scale attention networks show unexpected capabilities

The Core Insight in Three Phrases

1. "From Sequential Steps to Parallel Glimpses"

  • Computation changed from step-by-step to all-at-once

2. "From Fixed Patterns to Learned Relevance"

  • Information flow changed from hardcoded to dynamic

3. "From Local Context to Global Understanding"

  • Models changed from seeing windows to seeing everything

The Simplest Explanation

Traditional Models: Like reading a book where you can only remember the last few pages
Self-Attention: Like having the entire book spread on a table, able to connect any two ideas instantly

Why It Scales So Well

Computational Properties:
├── Matrix Multiplication → GPUs love this
├── No Recurrence → Perfect parallelization  
├── Fixed Computation Graph → Easy to optimize
└── Differentiable Everywhere → Gradient-friendly

The Philosophical Shift

From "Process in Order" to "Relate Everything to Everything"

This loosely mirrors how humans think:

  • We don't process information strictly sequentially
  • We constantly reference and cross-reference
  • We dynamically focus on what's relevant

Self-attention gave models a learned mechanism loosely analogous to human selective attention, and that changed everything.


"Self-attention replaced sequential processing with parallel relevance computation, enabling models to learn what to look at rather than being told how to look."

Below, we look in detail at the components that make self-attention possible.

Part 1: The Mathematical Foundation

Understanding Attention as Weighted Averaging

At its core, attention is surprisingly simple: it's a learned, dynamic weighted average. Given a query q and a set of key-value pairs (KV), attention computes:

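$\mathrm{Attention}(q, K, V) = \sum_i \alpha(q, k_i)\, v_i$

where $\alpha(q, k_i)$ is the attention weight: how much we should "pay attention" to value $v_i$ when processing query $q$. The weights are non-negative and sum to 1 over $i$, which is what makes the result a weighted average.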

The Critical Question: How Do We Compute α?

The genius lies in making α learnable and context-dependent. The standard approach:

  1. Score: Compute similarity between query and each key
  2. Normalize: Apply softmax to get probability distribution
  3. Aggregate: Weighted sum of values

import torch
import torch.nn.functional as F

def attention(query, keys, values):
    # Step 1: Compute scores (using dot product)
    scores = torch.matmul(query, keys.transpose(-1, -2))
    
    # Step 2: Normalize with softmax
    weights = F.softmax(scores, dim=-1)
    
    # Step 3: Weighted aggregation
    output = torch.matmul(weights, values)
    return output, weights

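Key Insight: The dot product measures how aligned two vectors are in high-dimensional space. A higher dot product gives a higher score and therefore more attention. Because the dot product is unnormalized, vector magnitude matters as well as direction.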
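To make the three steps concrete without any framework, here is a minimal pure-Python sketch of the same unscaled dot-product attention for a single query. The helper names (`softmax`, `dot`, `attend`) and the toy vectors are illustrative assumptions, not part of any library; the point is that the weights sum to 1 and the key most aligned with the query dominates the output.

```python
import math

def softmax(xs):
    # Subtract the max before exponentiating for numerical stability
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def attend(query, keys, values):
    # Step 1: score each key by its dot product with the query
    scores = [dot(query, k) for k in keys]
    # Step 2: normalize the scores into weights that sum to 1
    weights = softmax(scores)
    # Step 3: weighted average of the value vectors
    dim = len(values[0])
    output = [sum(w * v[j] for w, v in zip(weights, values)) for j in range(dim)]
    return output, weights

# Toy data (hypothetical): the query points in the same direction as key 0
query = [1.0, 0.0]
keys = [[2.0, 0.0], [0.0, 2.0], [-2.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [-10.0, 0.0]]

output, weights = attend(query, keys, values)
print([round(w, 3) for w in weights])
print([round(x, 3) for x in output])
```

Key 0 scores 2, key 1 scores 0 and key 2 scores -2, so the output vector is pulled mostly toward value 0.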

Part 2: Scaled Dot-Product Attention

The Scaling Problem

Why do we divide by √d_k in the Transformer? Let's analyze:

For query and key vectors whose components are independent with zero mean and unit variance, the dot product of two d-dimensional vectors has variance d. As d grows:

  • Dot products grow to large magnitudes
  • Softmax becomes peaky (approaching one-hot)
  • Gradients through the saturated softmax become vanishingly small
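
import math
import torch
import torch.nn.functional as F

def scaled_dot_product_attention(Q, K, V, mask=None):
    d_k = Q.size(-1)
    
    # Without scaling: variance of scores ≈ d_k
    # With scaling: variance of scores ≈ 1
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(d_k)
    
    # Optional boolean mask (True = blocked position, see Part 4); blocked scores become -inf
    if mask is not None:
        scores = scores.masked_fill(mask, float('-inf'))
    
    # Now softmax operates in a reasonable range
    attention_weights = F.softmax(scores, dim=-1)
    
    # Return the weights too, so callers (e.g. MultiHeadAttention) can inspect them
    return torch.matmul(attention_weights, V), attention_weights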

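Mathematical Proof (assuming the components $q_i$ and $k_i$ are independent, with zero mean and unit variance):

  • $\mathrm{Var}(q \cdot k) = \sum_{i=1}^{d} \mathrm{Var}(q_i k_i) = d \cdot \mathrm{Var}(q_1 k_1) = d$, because $\mathrm{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \left(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\right)^2 = 1 \cdot 1 - 0 = 1$
  • After scaling: $\mathrm{Var}(q \cdot k / \sqrt{d}) = \mathrm{Var}(q \cdot k) / d = 1$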
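A quick empirical check of the derivation, written as a pure-Python sketch that assumes standard-normal components: it estimates Var(q·k) for d = 64, then compares the largest softmax weight for one query against 10 random keys in d = 512 with and without the 1/√d factor. The function names, seeds, and sample sizes are arbitrary choices for illustration.

```python
import math
import random

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def sample_score_variance(d, n_samples=5000, seed=0):
    # Draw q, k with independent N(0, 1) components and measure Var(q.k) empirically
    rng = random.Random(seed)
    scores = []
    for _ in range(n_samples):
        q = [rng.gauss(0.0, 1.0) for _ in range(d)]
        k = [rng.gauss(0.0, 1.0) for _ in range(d)]
        scores.append(dot(q, k))
    mean = sum(scores) / n_samples
    return sum((s - mean) ** 2 for s in scores) / (n_samples - 1)

def max_softmax_weight(scores):
    # Largest weight produced by softmax: close to 1.0 means nearly one-hot
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    return max(exps) / sum(exps)

d = 64
var_raw = sample_score_variance(d)
var_scaled = var_raw / d  # Var(q.k / sqrt(d)) = Var(q.k) / d
print(f"Var(q.k) ~ {var_raw:.1f} (theory: {d}), scaled ~ {var_scaled:.2f} (theory: 1)")

# One query against 10 random keys in d = 512 dimensions
rng = random.Random(1)
d_big = 512
q = [rng.gauss(0.0, 1.0) for _ in range(d_big)]
keys = [[rng.gauss(0.0, 1.0) for _ in range(d_big)] for _ in range(10)]
raw = [dot(q, k) for k in keys]
scaled = [s / math.sqrt(d_big) for s in raw]
print(f"max weight unscaled: {max_softmax_weight(raw):.3f}, scaled: {max_softmax_weight(scaled):.3f}")
```

Dividing every score by the same positive constant can only flatten the softmax, so the scaled maximum weight is never larger than the unscaled one. With unscaled scores whose standard deviation is around √512 ≈ 22.6, the distribution is typically close to one-hot.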

Part 3: Multi-Head Attention - The Power of Diverse Perspectives

Why Multiple Heads?

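A single attention head produces one set of attention weights per query, so it can capture only one pattern of relationships at a time. Multi-head attention runs several heads in parallel, each in its own learned lower-dimensional subspace (d_k = d_model / n_heads), and concatenates their outputs: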

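import torch.nn as nn

# Relies on scaled_dot_product_attention(Q, K, V, mask) returning (output, weights)
class MultiHeadAttention(nn.Module):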
    def __init__(self, d_model=512, n_heads=8):
        super().__init__()
        assert d_model % n_heads == 0
        
        self.d_model = d_model
        self.n_heads = n_heads
        self.d_k = d_model // n_heads  # 64 in this case
        
        # Separate projections for each component
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)
        
    def forward(self, query, key, value, mask=None):
        batch_size = query.size(0)
        
        # 1. Linear projections in batch from d_model => h x d_k
        Q = self.W_q(query).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.n_heads, self.d_k).transpose(1, 2)
        
        # 2. Apply attention on all the projected vectors in batch
        attention_output, attention_weights = scaled_dot_product_attention(Q, K, V, mask)
        
        # 3. Concatenate heads and put through final linear layer
        attention_output = attention_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )
        
        return self.W_o(attention_output)

Critical Analysis: Each head can learn different types of relationships:

  • Head 1 might learn syntactic dependencies
  • Head 2 might learn semantic similarity
  • Head 3 might learn positional patterns
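The split, attend, and concatenate structure of `MultiHeadAttention.forward` can be sketched in pure Python. This is a simplified, hypothetical version: it drops the learned projections W_q, W_k, W_v and W_o (using the identity instead), so each head simply sees a different slice of the same token vectors. Even so, the two heads end up with different attention patterns for token 0.

```python
import math

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def single_head(Q, K, V):
    # Scaled dot-product attention for one head; Q, K, V are lists of row vectors
    d_k = len(Q[0])
    out, all_weights = [], []
    for q in Q:
        scores = [sum(a * b for a, b in zip(q, k)) / math.sqrt(d_k) for k in K]
        w = softmax(scores)
        all_weights.append(w)
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out, all_weights

def split_heads(X, n_heads):
    # (seq_len, d_model) -> n_heads x (seq_len, d_k); head h gets columns [h*d_k, (h+1)*d_k)
    d_k = len(X[0]) // n_heads
    return [[row[h * d_k:(h + 1) * d_k] for row in X] for h in range(n_heads)]

def merge_heads(heads):
    # n_heads x (seq_len, d_k) -> (seq_len, d_model): concatenate each token's head outputs
    return [[x for head in heads for x in head[t]] for t in range(len(heads[0]))]

def multi_head_self_attention(X, n_heads):
    # Hypothetical simplification: identity projections, so Q = K = V = X within each head
    heads_out, heads_weights = [], []
    for Xh in split_heads(X, n_heads):
        out, w = single_head(Xh, Xh, Xh)
        heads_out.append(out)
        heads_weights.append(w)
    return merge_heads(heads_out), heads_weights

# 3 tokens, d_model = 4, 2 heads of size 2 (toy numbers)
X = [[1.0, 0.0, 0.0, 1.0],
     [1.0, 0.0, 1.0, 0.0],
     [0.0, 1.0, 0.0, 1.0]]
out, weights = multi_head_self_attention(X, n_heads=2)
print(len(out), len(out[0]))
print([round(w, 3) for w in weights[0][0]])  # head 0, token 0
print([round(w, 3) for w in weights[1][0]])  # head 1, token 0
```

In this toy input, head 0 (columns 0 and 1) makes token 0 attend equally to itself and token 1, while head 1 (columns 2 and 3) pairs token 0 with token 2. Learned projections give real heads far more freedom than fixed slices.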

Part 4: The Mask - Controlling Information Flow

Causal Masking for Autoregressive Models

In decoder self-attention, we prevent positions from attending to future positions:

import torch

def create_causal_mask(size):
    # True above the diagonal marks future positions (j > i) that must be masked
    mask = torch.triu(torch.ones(size, size), diagonal=1).bool()
    return mask

def apply_mask(scores, mask):
    # Set masked positions to -inf before softmax
    # After softmax, these become 0
    return scores.masked_fill(mask, float('-inf'))

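Why -inf? After softmax, $e^{-\infty} / \sum_j e^{x_j} = 0$, so masked positions receive exactly zero weight. This requires at least one unmasked position per row: if every score in a row is -inf, softmax returns NaN.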

Padding Mask for Variable Length Sequences

Real-world sequences have different lengths. We use padding masks to ignore padded positions:

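import torch

def create_padding_mask(lengths, max_len):
    # lengths: 1-D tensor of true sequence lengths, shape (batch,)
    # Returns (batch, max_len) with True at padded positions; to mask keys in
    # scores of shape (batch, heads, q_len, k_len), use mask[:, None, None, :]
    batch_size = len(lengths)
    mask = torch.arange(max_len).expand(batch_size, max_len) >= lengths.unsqueeze(1)
    return mask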
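The two masks combine with a logical OR before the softmax. This pure-Python sketch mirrors `create_causal_mask`, `create_padding_mask` and `apply_mask` on plain lists (the lowercase helper names and the toy scores are hypothetical) and shows that masked positions get exactly zero weight while each row still sums to 1.

```python
import math

def causal_mask(size):
    # True above the diagonal: position i may not attend to any j > i
    return [[j > i for j in range(size)] for i in range(size)]

def padding_mask(lengths, max_len):
    # True at padded key positions (index >= the sequence's real length)
    return [[j >= n for j in range(max_len)] for n in lengths]

def masked_softmax(scores, mask_row):
    # Replace masked scores with -inf; math.exp(-inf) == 0.0, so they get zero weight
    filled = [float("-inf") if m else s for s, m in zip(scores, mask_row)]
    peak = max(filled)  # assumes at least one unmasked entry; a fully masked row yields NaN
    exps = [math.exp(x - peak) for x in filled]
    total = sum(exps)
    return [e / total for e in exps]

size = 4
scores = [[0.5, 1.0, -0.3, 2.0] for _ in range(size)]  # toy scores, identical rows
cmask = causal_mask(size)
causal_weights = [masked_softmax(scores[i], cmask[i]) for i in range(size)]

# Combine with padding: suppose this sequence really has length 3 (last slot is padding)
pmask = padding_mask([3], max_len=size)[0]
combined = [[c or p for c, p in zip(cmask[i], pmask)] for i in range(size)]
combined_weights = [masked_softmax(scores[i], combined[i]) for i in range(size)]

for row in combined_weights:
    print([round(w, 3) for w in row])
```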

Part 5: Position Information - The Missing Piece

Why Positional Encoding?

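Attention has no inherent notion of position: it is permutation-equivariant, so shuffling the input tokens simply shuffles the outputs in the same way. The Transformer injects order with a sinusoidal encoding that is added to the token embeddings: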

import math
import torch

def positional_encoding(max_len, d_model):
    # Assumes d_model is even so that sin and cos columns pair up
    pe = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len).unsqueeze(1).float()
    
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * 
                         -(math.log(10000.0) / d_model))
    
    pe[:, 0::2] = torch.sin(position * div_term)  # Even dimensions
    pe[:, 1::2] = torch.cos(position * div_term)  # Odd dimensions
    
    return pe

Mathematical Properties:

  • Unique encoding for each position
  • Allows the model to learn relative positions: for any fixed offset k, PE(pos + k) can be represented as a linear function of PE(pos)
  • Bounded values ([-1, 1])
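The "linear function of PE(pos)" property can be checked numerically. In this pure-Python sketch (helper names are hypothetical), each (sin, cos) pair of PE(pos) is multiplied by a 2×2 rotation matrix that depends only on the offset k and that pair's frequency, and the result matches PE(pos + k) to floating-point precision.

```python
import math

def positional_encoding(pos, d_model):
    # Sinusoidal encoding for one position; d_model is assumed to be even
    pe = [0.0] * d_model
    for i in range(0, d_model, 2):
        omega = 1.0 / (10000.0 ** (i / d_model))  # same frequencies as div_term above
        pe[i] = math.sin(pos * omega)
        pe[i + 1] = math.cos(pos * omega)
    return pe

def shift_by(pe, k, d_model):
    # Apply a fixed 2x2 rotation per (sin, cos) pair; the matrix depends on k, not on pos
    out = [0.0] * d_model
    for i in range(0, d_model, 2):
        omega = 1.0 / (10000.0 ** (i / d_model))
        c, s = math.cos(k * omega), math.sin(k * omega)
        sin_p, cos_p = pe[i], pe[i + 1]
        out[i] = c * sin_p + s * cos_p       # sin(a + b) = sin a cos b + cos a sin b
        out[i + 1] = -s * sin_p + c * cos_p  # cos(a + b) = cos a cos b - sin a sin b
    return out

d_model, pos, k = 16, 7, 5
predicted = shift_by(positional_encoding(pos, d_model), k, d_model)
actual = positional_encoding(pos + k, d_model)
print(max(abs(a - b) for a, b in zip(predicted, actual)))
```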

Part 6: Computational Complexity Analysis

Time and Space Complexity

For sequence length n and dimension d:

Operation          Time Complexity    Space Complexity
Self-Attention     O(n²·d)            O(n²)
Feed-Forward       O(n·d²)            O(n·d)

When does attention become the bottleneck?

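  • When n²·d > n·d², i.e. when n > d (for d = 512, sequences longer than about 512 tokens); constant factors such as the 4× wider feed-forward hidden layer in the original Transformer push the practical crossover later
  • This quadratic growth in n motivates efficient variants such as linear attention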
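The crossover rule is simple arithmetic. This sketch counts only the leading terms (n²·d for the attention scores, n·d² for one d×d projection) and ignores constant factors, so the printed ratio is exactly n/d; treat it as an order-of-magnitude guide rather than a FLOP count. The function name is hypothetical.

```python
def attention_vs_ffn_cost(n, d):
    # Leading-order multiply-accumulate counts, ignoring constant factors:
    # attention scores/mixing ~ n^2 * d, a d x d projection ~ n * d^2
    return n * n * d, n * d * d

d = 512
for n in (128, 512, 2048, 8192):
    attn, ffn = attention_vs_ffn_cost(n, d)
    print(f"n={n:5d}  attention/ffn ratio = {attn / ffn:.2f}")  # equals n / d
```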

Memory-Efficient Implementation

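import math
import torch
import torch.nn.functional as F

def memory_efficient_attention(Q, K, V, chunk_size=32):
    """
    Compute attention in query chunks to reduce peak memory usage.
    Only a (chunk_size x seq_len) block of scores exists at a time instead of
    the full (seq_len x seq_len) matrix; the result is mathematically identical.
    Under autograd the per-chunk weights are still saved for backward, so the
    savings apply mainly to inference (e.g. inside torch.no_grad()).
    Useful for very long sequences.
    """
    seq_len = Q.size(-2)
    outputs = []
    
    for i in range(0, seq_len, chunk_size):
        q_chunk = Q[..., i:i + chunk_size, :]  # slicing clips at seq_len
        
        # Compute attention for this chunk of queries against all keys
        scores = torch.matmul(q_chunk, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
        weights = F.softmax(scores, dim=-1)
        outputs.append(torch.matmul(weights, V))
    
    # Concatenating along the query axis also works when V's last dim differs from Q's
    return torch.cat(outputs, dim=-2)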
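Chunking over queries does not change the result, because each query's softmax row depends only on that query and the full set of keys. This pure-Python sketch (hypothetical helper names, random toy inputs) checks that processing queries in blocks of 3 reproduces full attention exactly.

```python
import math
import random

def softmax(xs):
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    s = sum(exps)
    return [e / s for e in exps]

def attention_rows(Q, K, V):
    # Reference: every query in Q attends over all keys in K
    d = len(Q[0])
    out = []
    for q in Q:
        w = softmax([sum(a * b for a, b in zip(q, k)) / math.sqrt(d) for k in K])
        out.append([sum(wi * v[j] for wi, v in zip(w, V)) for j in range(len(V[0]))])
    return out

def chunked_attention(Q, K, V, chunk_size):
    # Process queries in blocks of chunk_size, mirroring the PyTorch loop above
    out = []
    for i in range(0, len(Q), chunk_size):
        out.extend(attention_rows(Q[i:i + chunk_size], K, V))
    return out

rng = random.Random(0)
n, d = 10, 4
Q = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
K = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
V = [[rng.uniform(-1, 1) for _ in range(d)] for _ in range(n)]
full = attention_rows(Q, K, V)
chunked = chunked_attention(Q, K, V, chunk_size=3)
print(max(abs(a - b) for r1, r2 in zip(full, chunked) for a, b in zip(r1, r2)))
```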

Part 7: Common Pitfalls and Debugging

1. Gradient Vanishing/Exploding

Symptom: Attention weights become uniform or one-hot
Solution: Check scaling, use gradient clipping, layer normalization

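import torch

# Monitor attention entropy; clamping avoids 0 * log(0) = NaN at masked (zero-weight) positions
entropy = -(weights * torch.clamp(weights, min=1e-12).log()).sum(dim=-1).mean()
# Low entropy = focused attention; high entropy (at most log(num_keys)) = near-uniform attention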
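For intuition about the entropy values, a minimal pure-Python sketch: a uniform row over n keys reaches the maximum entropy log(n), a one-hot row has entropy 0, and terms with zero weight (for example masked positions) are skipped, since 0·log(0) is taken to be 0. The `attention_entropy` helper is hypothetical.

```python
import math

def attention_entropy(weights):
    # Shannon entropy (nats) of one attention row; zero weights contribute nothing
    return sum(-w * math.log(w) for w in weights if w > 0.0)

n = 8
uniform = [1.0 / n] * n
one_hot = [1.0] + [0.0] * (n - 1)
print(attention_entropy(uniform), math.log(n))  # uniform attention reaches the maximum, log(n)
print(attention_entropy(one_hot))               # fully focused attention has zero entropy
```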

2. Attention Collapse

Symptom: All queries attend to same key positions
Solution: Regularization, dropout on attention weights

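import math
import torch
import torch.nn.functional as F

def attention_with_dropout(Q, K, V, dropout_p=0.1, training=True):
    scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(Q.size(-1))
    weights = F.softmax(scores, dim=-1)
    # A plain function has no self.training; pass the mode explicitly (e.g. module.training)
    weights = F.dropout(weights, p=dropout_p, training=training)
    return torch.matmul(weights, V)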
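What dropout does to a row of attention weights: PyTorch's `F.dropout` zeroes each element with probability p during training and scales the survivors by 1/(1 - p). A pure-Python sketch of that inverted-dropout rule (the `dropout_weights` helper is hypothetical) shows that a dropped-out row generally no longer sums to 1, but each weight keeps its expected value, and evaluation mode leaves the weights untouched.

```python
import random

def dropout_weights(weights, p, training, rng):
    # Inverted dropout: zero each weight with probability p and rescale survivors by 1/(1-p),
    # so each weight keeps its expected value. At eval time weights pass through unchanged.
    if not training or p == 0.0:
        return list(weights)
    keep = 1.0 - p
    return [w / keep if rng.random() < keep else 0.0 for w in weights]

rng = random.Random(0)
weights = [0.5, 0.3, 0.2]
dropped = dropout_weights(weights, p=0.1, training=True, rng=rng)
print(dropped, sum(dropped))  # this sample's row sum can differ from 1
print(dropout_weights(weights, p=0.1, training=False, rng=rng))
```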

Conclusion: The Elegance of Attention

The attention mechanism's beauty lies in its simplicity: it's fundamentally about learning what to look at. From this simple concept emerges the ability to:

  • Model long-range dependencies
  • Process sequences in parallel
  • Inspect what a model attends to (attention visualization), though attention weights are only a partial explanation of its behavior
  • Build state-of-the-art models across domains

The key insights we've covered:

  1. Attention as learnable routing - dynamically choosing what information flows where
  2. Scaling for stability - mathematical necessity, not arbitrary choice
  3. Multi-head for diversity - multiple representation subspaces
  4. Masking for control - directing the flow of information
  5. Position encoding - adding sequential awareness
