Llama2 Architecture (P2)

LLAMA2 Architecture

Trong chương này, chúng ta sẽ cùng tìm hiểu về 3 chi tiết còn lại: RMS Norm, Group Multi Query Attention with KV cache, và Feed Forward SwiGLU. Hãy cùng khám phá những khái niệm mới này và tìm hiểu cách chúng hoạt động!

import torch
import torch.nn as nn
from transformers import AutoTokenizer
from datasets import load_dataset
import math
from einops import rearrange # einstein operation
sample = 20

dataset = load_dataset("roneneldan/TinyStories")
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-125M")
tokenizer.pad_token = tokenizer.eos_token

subset_dataset = dataset['train'][:sample]['text']
tokenized_dataset = tokenizer(
    subset_dataset,
    return_tensors='pt',
    padding=True,  # Enable padding
    truncation=True  # Enable truncation
)

data = tokenized_dataset['input_ids']
data.shape
Repo card metadata block was not found. Setting CardData to empty.
torch.Size([20, 219])
class ModelArgs:
    def __init__(self, sequence_len, vocab_size):

        self.rotary_dim = 3
        
        self.n_layer = 2
        self.batch_size = 16
        self.n_head = 4
        self.n_embd = 36
        self.sequence_len = sequence_len
        self.vocab_size = vocab_size


sequence_len = data.size(1) - 1
vocab_size = tokenizer.vocab_size

args = ModelArgs(sequence_len, vocab_size)
def get_batch(data, batch_size):
    idx = torch.randint(0, len(data), size=(batch_size,))
    batch = data[idx]

    xb = batch[:, :-1].contiguous()
    yb = batch[:, 1:].contiguous()
    
    return xb, yb

xb, yb = get_batch(data, args.batch_size)
xb.shape, yb.shape
(torch.Size([16, 218]), torch.Size([16, 218]))

Embedding

LLAMA2 Embedding

class Embedding(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        self.wte = nn.Embedding(args.vocab_size, args.n_embd)
        
    def forward(self, input_ids):
        input_ids_embd = self.wte(input_ids)
        
        return input_ids_embd
    
embd = Embedding(args)
x_embd = embd(xb)
x_embd.shape
torch.Size([16, 218, 36])

Rotary Embedding

Rotary Position

class RotaryPositionEmbedding(nn.Module):
    def __init__(self, args:ModelArgs, base = 10000):
        super().__init__()
        self.rotary_dim  = args.rotary_dim
        
        inv_freq = 1.0 / (base ** (torch.arange(0, self.rotary_dim, 2) / self.rotary_dim ))
        self.register_buffer("inv_freq", inv_freq)
        
        self.cos_cache = None
        self.sin_cache = None
        
    def forward(self, qkv):
        seqlen = qkv.shape[1]
        
        # Update cos sin cache
        t = torch.arange(seqlen)
        freqs = torch.outer(t, self.inv_freq)
        
        self.cos_cache = torch.cos(freqs)
        self.sin_cache = torch.sin(freqs)
        
        # Apply rotary qkv
        rotary_dim = self.cos_cache.shape[1]
        rotary_dim *= 2
        
        q_rot = qkv[:, :, 0, :, :rotary_dim]
        q_pass = qkv[:, :, 0, :, rotary_dim:]
        
        k_rot = qkv[:, :, 1, :, :rotary_dim]
        k_pass = qkv[:, :, 1, :, rotary_dim:]
        
        # Splits the queries and keys in half
        q1, q2 = q_rot.chunk(2, dim=-1)
        k1, k2 = k_rot.chunk(2, dim=-1)
        c, s = rearrange(self.cos_cache, "t d -> t 1 d"), rearrange(self.sin_cache, "t d -> t 1 d")
        
        # Computes the new keys and queries
        q_rot = torch.cat([q1 * c - q2 * s, q1 * s - q2 * c], dim=-1)
        k_rot = torch.cat([k1 * c - k2 * s, k1 * s - k2 * c], dim = -1)
        
        return torch.cat(
            [
                torch.cat([q_rot, q_pass], dim=-1).unsqueeze(2),
                torch.cat([k_rot, k_pass], dim=-1).unsqueeze(2),
                qkv[:, :, 2:3, :, :]
            ],
            dim=2
        )

RMS Norm

RMS Norm

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight
    
    
attn_norm = RMSNorm(args.n_embd)
x_embd_norm = attn_norm(x_embd)
x_embd_norm.shape
torch.Size([16, 218, 36])

RMS Norm thực chất là một biến thể của Layer Norm. Ý tưởng cơ bản là thay vì sử dụng Layer Norm làm quá trình normalize dữ liệu, họ chuyển sang sử dụng RMS Norm. Cụ thể, trong ảnh trên thay vì áp dụng layer norm cho x_embd để tính toán qkv, họ thay thế nó bằng RMS Norm. Điều này cũng áp dụng tương tự cho việc normalize output. Sự thay đổi này có thể mang lại một số ưu điểm cụ thể trong quá trình xử lý và huấn luyện mô hình.

Self Attention

Self Attention

import torch.nn.functional as F

class Attention(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        
        self.rotary_emb = RotaryPositionEmbedding(args)
        
        self.head_dim = args.n_embd // args.n_head
        opt_size = args.n_head * self.head_dim
        hidden_size = args.n_embd
        
        self.Wqkv = nn.Linear(hidden_size, 3 * opt_size)
        self.out_proj = nn.Linear(opt_size, hidden_size)
        
    def forward(self, input_ids_embd_norm):
        seq_len = input_ids_embd_norm.shape[1]
        
        qkv = self.Wqkv(input_ids_embd_norm)
        qkv = rearrange(qkv, 'b t (three h d) -> b t three h d', three=3, d=self.head_dim)
        
        # Rotary Query & Key
        qkv = self.rotary_emb(qkv)
        
        q, k, v = qkv.unbind(2)
        
        # New code
        # --------------------------------------------------------------------------------
        output = F.scaled_dot_product_attention(q, k, v, is_causal=True)

        # softmax_scale = 1.0 / math.sqrt(q.shape[-1])
        # scores = torch.einsum("bthd, bshd -> bhts", q, k * softmax_scale)
        
        # mask = torch.triu(torch.full((seq_len, seq_len), -10000), 1)
        # scores += mask
        
        # attention_weights = torch.softmax(scores, dim=-1)
        
        # output = torch.einsum("bhts, bshd -> bthd", attention_weights, v)
        # ----------------------------------------------------------------------------------
        
        output = rearrange(output, "... h d -> ... (h d)")

        attn_out = self.out_proj(output)
        
        return attn_out
    
# Normalize
attn_norm = RMSNorm(args.n_embd)
x_embd_norm = attn_norm(x_embd)

attn = Attention(args)
attn_out = attn(x_embd_norm)
# add residual
attn_out += x_embd
attn_out.shape
torch.Size([16, 218, 36])

Trong phương pháp trước đây, việc tính toán lại các scores một cách lặp đi lặp lại dẫn đến sự lãng phí đáng kể về hiệu suất tính toán. Ví dụ, trong câu “Tôi thích chạy bộ” khi chúng ta cố gắng dự đoán từ “thích” dựa trên từ “Tôi”, chúng ta thực hiện tính scores bằng cách nhân toàn bộ query của từ cần dự đoán với toàn bộ key của các từ khác, sau đó phải loại bỏ scores của các từ không cần thiết (trong trường hợp này là “thích chạy bộ”). Tương tự, khi dự đoán từ “chạy”, chúng ta lại thực hiện lại quá trình tính scores bằng cách nhân toàn bộ query với các key, sau đó phải loại bỏ scores của các từ không cần thiết (“chạy bộ”). Điều này dẫn đến một sự lãng phí lớn về hiệu suất tính toán.

Ý tưởng chính của phương pháp “Group Multi-Query Attention with KV cache” là giảm thiểu sự lãng phí này bằng cách tận dụng lại các kết quả đã được tính toán và lưu trữ trước đó thay vì tính toán lại từ đầu. Hàm F.scaled_dot_product_attention thực hiện ý tưởng này và đồng thời giúp mã nguồn trở nên rõ ràng hơn đáng kể. Sử dụng cách tiếp cận này giúp chúng ta tái sử dụng các kết quả trước đó đã được tính toán và lưu trữ, từ đó giảm thiểu việc tính toán lại và giúp mã nguồn trở nên dễ đọc và dễ hiểu hơn rất nhiều.

Feed Forward

Feed Forward

class FeedForward(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        multiple_of = 5
        
        hidden_dim = 4 * args.n_embd
        hidden_dim = int(2 * hidden_dim / 3)
        hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
        
        self.w1 = nn.Linear(args.n_embd, hidden_dim, bias=False)
        self.w2 = nn.Linear(hidden_dim, args.n_embd, bias=False)
        
        self.w3 = nn.Linear(args.n_embd, hidden_dim, bias=False)
        
        self.act = nn.SiLU()
        
    def forward(self, attn_out_norm):
        
        hidden_states = self.w1(attn_out_norm) * self.w3(attn_out_norm)
        hidden_states = self.act(hidden_states)
        
        ffwd_out = self.w2(hidden_states)
        
        return ffwd_out
    
# Normalize
ffwd_norm = RMSNorm(args.n_embd)
attn_out_norm = ffwd_norm(attn_out)

ffwd = FeedForward(args)
ffwd_out = ffwd(attn_out_norm)
# add residual
ffwd_out += attn_out
ffwd_out.shape
torch.Size([16, 218, 36])

Giống như trước đó, chúng ta sẽ sử dụng RMS Norm để normalize output attention thay vì Layer Norm.

Feed Forward SwiGLU là một cải tiến của phương pháp Feed Forward thông thường, nhằm tăng cường khả năng học và biểu diễn của mô hình. Bằng cách tăng cường phức tạp hóa cấu trúc của lớp feed forward, SwiGLU có thể học được các mối quan hệ phức tạp và đặc trưng của dữ liệu một cách hiệu quả hơn. Việc tăng cường tính phức tạp của kiến trúc này thường đi kèm với việc sử dụng các phép tính toán và hàm activation phức tạp hơn (SiLU), nhằm tăng tính linh hoạt và mạnh mẽ của mô hình trong việc xử lý dữ liệu phức tạp và đa dạng.

Transformer Block

Transformer Block

class TransfomerBlock(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        
        self.attention_norm = RMSNorm(args.n_embd)
        self.ffwd_norm = RMSNorm(args.n_embd)
        
        self.attn = Attention(args)
        self.ffwd = FeedForward(args)
        
    def forward(self, input_ids_embd):
        
        attn_out = input_ids_embd + self.attn(self.attention_norm(input_ids_embd))
        
        ffwd_out = attn_out + self.ffwd(self.ffwd_norm(attn_out))
        
        return ffwd_out
    
t_block = TransfomerBlock(args)
ffwd_out = t_block(x_embd)
ffwd_out.shape
torch.Size([16, 218, 36])

LLAMA2 architecture có những cải tiến hơn so với transformer architecture mà chúng ta đã học trước đó. Các cải tiến này chủ yếu tập trung vào các class Attention và Feed Forward. Do đó, các class còn lại cơ bản chỉ thay đổi việc sử dụng normalize từ Layer Norm sang RMS Norm, các code khác đều sẽ được giữ nguyên.

Transformer

Transformer

class TransformerHead(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        
        self.norm = RMSNorm(args.n_embd)
        self.linear = nn.Linear(args.n_embd, args.vocab_size)
        
    def forward(self, ffwd_out):
        ffwd_out_norm = self.norm(ffwd_out)
        logits = self.linear(ffwd_out_norm)
        
        return logits
    
t_head = TransformerHead(args)
logits = t_head(ffwd_out)
logits.shape
torch.Size([16, 218, 50257])
class TransformerSequential(nn.Module):
    def __init__(self, args:ModelArgs):
        super().__init__()
        
        modules = [Embedding(args)]
        modules += [TransfomerBlock(args) for _ in range(args.n_layer)]
        modules.append(TransformerHead(args))
        
        self.layers = nn.Sequential(*modules)
        
    def forward(self, input_ids):
        return self.layers(input_ids)
    
model = TransformerSequential(args)
logits = model(xb)
logits.shape
torch.Size([16, 218, 50257])

Loss

class TransformerLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.loss_fct = nn.CrossEntropyLoss()
        
    def forward(self, logits, labels):
        logits = logits.view(-1, logits.shape[-1])
        labels = labels.view(-1)                    
                             
        loss = self.loss_fct(logits, labels)

        return loss
    
t_loss = TransformerLoss()
loss = t_loss(logits, yb)
loss
tensor(10.8817, grad_fn=<NllLossBackward0>)
data = tokenized_dataset['input_ids']
sequence_len = data.size(1) - 1
vocab_size = tokenizer.vocab_size

args = ModelArgs(sequence_len, vocab_size)
xb, yb = get_batch(data, args.batch_size)

model = TransformerSequential(args)
logits = model(xb)

t_loss = TransformerLoss()
loss = t_loss(logits, yb)
loss
tensor(11.0817, grad_fn=<NllLossBackward0>)

Vậy là chúng ta đã hoàn thành kiến trúc transformer của LLAMA2.Trong chương tiếp theo, chúng ta sẽ đào sâu vào việc khởi tạo trọng số (weight initialization). Có vẻ như phương pháp khởi tạo trọng số mặc định của PyTorch không còn phù hợp nữa, và LLAMA2 đã sử dụng một phương pháp khởi tạo trọng số khác. Hãy cùng tìm hiểu về điều này trong chương tiếp theo.