17. *Lab: LLM Pretraining#

Here we directly reuse the decoder architecture that we built in the previous sections.

import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
from omegaconf import OmegaConf
from llm_lab.model.vanilla_decoder import VanillaDecoderModel
from llm_lab.utils.collate_utils import default_data_collator
from llm_lab.utils.common_utils import move_to_device
from transformers import AutoTokenizer
from datasets import load_dataset
from itertools import chain
from functools import partial

%load_ext autoreload
%autoreload 2 

17.1. Data#

# Dataset selection: hub name, configuration and the column holding raw text.
dataset_name = "wikitext"
data_config = "wikitext-2-raw-v1"
text_column_name = "text"

# Checkpoint whose tokenizer is loaded below.
model_name_or_path = "openai-community/gpt2"

raw_datasets = load_dataset(path=dataset_name, name=data_config)
tokenizer = AutoTokenizer.from_pretrained(
    model_name_or_path,
    use_fast=True,
)
def tokenize(examples):
    """Run the module-level ``tokenizer`` on the text column of a batch.

    Args:
        examples: Mapping that holds the raw text under ``text_column_name``.

    Returns:
        Whatever ``tokenizer`` returns for those texts.
    """
    texts = examples[text_column_name]
    return tokenizer(texts)
def group_and_chunk(tokenized_examples, chunk_size=1024, chunk_key='input_ids'):
    """Concatenate every column of a batch and cut it into fixed-size chunks.

    Each column is flattened into one long list, then sliced at the same
    offsets ``0, chunk_size, 2 * chunk_size, ...``. The number of offsets is
    derived from the flattened length of ``chunk_key``; the trailing
    remainder of that column (shorter than ``chunk_size``) is dropped.
    Columns whose flattened length differs from ``chunk_key``'s may therefore
    produce shorter or empty chunks.

    Args:
        tokenized_examples: Mapping from column name to a list of iterables.
        chunk_size: Number of items per chunk.
        chunk_key: Column whose flattened length decides how many chunks
            are produced.

    Returns:
        dict mapping each column name to its list of chunks.
    """
    flattened = {}
    for column in tokenized_examples.keys():
        flattened[column] = list(chain.from_iterable(tokenized_examples[column]))

    reference_len = len(flattened[chunk_key])
    # Largest multiple of chunk_size that fits into the reference column.
    usable_len = reference_len - reference_len % chunk_size
    offsets = range(0, usable_len, chunk_size)

    return {
        column: [sequence[start:start + chunk_size] for start in offsets]
        for column, sequence in flattened.items()
    }
# Tokenize every split in batches.
tokenized_dataset = raw_datasets.map(tokenize, batched=True)

# Pack the token streams into blocks of 256 items
# (tokenizer.model_max_length is the alternative block size).
# NOTE(review): the printed features still list 'text' even though it is in
# remove_columns -- presumably because group_and_chunk re-emits every input
# key, including 'text'; confirm.
chunk_data = tokenized_dataset.map(
    partial(group_and_chunk, chunk_size=256),
    batched=True,
    remove_columns=['text'],
)
chunk_data
DatasetDict({
    test: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 1104
    })
    train: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 9327
    })
    validation: Dataset({
        features: ['text', 'input_ids', 'attention_mask'],
        num_rows: 964
    })
})

17.2. Model#

class DecoderCausalLM(nn.Module):
    """Decoder backbone followed by a linear projection onto the vocabulary.

    Args:
        config: Object exposing ``d_model`` and ``vocab_size``; it is also
            handed unchanged to ``VanillaDecoderModel``.
    """

    def __init__(self, config):
        super().__init__()
        self.decoder = VanillaDecoderModel(config)
        # Bias-free map from the hidden size to one logit per vocabulary entry.
        self.lm_head = nn.Linear(
            in_features=config.d_model,
            out_features=config.vocab_size,
            bias=False,
        )

    def forward(self, batch):
        """Return vocabulary logits for ``batch['input_ids']``.

        Only the ``input_ids`` entry of ``batch`` is read here; other keys
        such as ``attention_mask`` are not used by this method.
        """
        token_ids = batch['input_ids']
        return self.lm_head(self.decoder(input_ids=token_ids))

17.3. Training#

def compute_batch_loss(batch, model, device):
    """Return the mean next-token cross-entropy loss for one training batch.

    Args:
        batch: Mapping with ``input_ids`` and ``attention_mask`` entries.
        model: Callable module in training mode that maps the input dict to
            logits; the last axis of the logits is used as the class axis.
        device: Target device passed to ``move_to_device``.

    Returns:
        Scalar loss tensor from ``F.cross_entropy`` (default reduction).
    """
    assert model.training
    # The return value is unused, so this relies on move_to_device updating
    # `batch` in place -- TODO confirm against the helper.
    move_to_device(batch, device)

    token_ids = batch['input_ids']
    inputs = {'input_ids': token_ids, 'attention_mask': batch['attention_mask']}

    # Position t predicts token t + 1: drop the final prediction and the
    # first target so that logits and labels line up.
    shifted_logits = model(inputs)[:, :-1, :].contiguous()
    targets = token_ids[:, 1:].contiguous()

    num_classes = shifted_logits.shape[-1]
    return F.cross_entropy(shifted_logits.view(-1, num_classes), targets.view(-1))

def compute_eval_loss(eval_dataloader, model, device):
    """Return the mean per-token cross-entropy over an evaluation loader.

    Per-token losses from every batch are collected and averaged together,
    so each token carries equal weight regardless of batch size.

    Args:
        eval_dataloader: Iterable of batches with ``input_ids`` and
            ``attention_mask`` entries.
        model: Callable module in eval mode that maps the input dict to logits.
        device: Target device passed to ``move_to_device``.

    Returns:
        ``np.mean`` of all collected per-token losses.
    """
    assert not model.training
    token_losses = []
    with torch.no_grad():
        for batch in eval_dataloader:
            move_to_device(batch, device)
            token_ids = batch['input_ids']
            inputs = {'input_ids': token_ids, 'attention_mask': batch['attention_mask']}

            # Same one-step shift as in training: position t scores token t + 1.
            shifted_logits = model(inputs)[:, :-1, :].contiguous()
            targets = token_ids[:, 1:].contiguous()

            per_token = F.cross_entropy(
                shifted_logits.view(-1, shifted_logits.shape[-1]),
                targets.view(-1),
                reduction='none',
            )
            token_losses += per_token.tolist()

    return np.mean(token_losses)

def train_model_epoch(model,
                      train_loader,
                      val_loader,
                      optimizer,
                      device,
                      train_config):
    """Train ``model`` for ``train_config.num_epochs`` passes over ``train_loader``.

    Every ``train_config.log_freq`` optimizer steps the model is evaluated on
    ``val_loader`` and a record with the epoch, step, current batch loss and
    validation loss is printed and stored.

    Args:
        model: Module to train; it is moved to ``device`` first.
        train_loader: Iterable of training batches.
        val_loader: Iterable of validation batches.
        optimizer: Optimizer stepping the model parameters.
        device: Device used for the model and the batches.
        train_config: Object exposing ``num_epochs`` and ``log_freq``.

    Returns:
        list of the logged record dicts, in logging order.
    """
    step_count = 0
    history = []
    model = model.to(device)

    for epoch_idx in range(train_config.num_epochs):
        for batch in train_loader:
            # Evaluation below switches to eval mode, so re-enable training
            # mode before every step.
            model.train()
            batch_loss = compute_batch_loss(batch, model, device)
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
            step_count += 1

            if step_count % train_config.log_freq != 0:
                continue

            model.eval()
            val_loss = compute_eval_loss(val_loader, model, device)
            entry = {
                "epoch": epoch_idx,
                "step": step_count,
                "train_loss": batch_loss.detach().item(),
                "val_loss": val_loss,
            }
            print(entry)
            history.append(entry)

    return history
def train_main(model_config, train_settings, chunk_data):
    """Build the model, optimizer and data loaders, then run training.

    Args:
        model_config: Configuration passed to ``DecoderCausalLM``.
        train_settings: Object exposing ``seed``, ``learning_rate``,
            ``weight_decay`` and ``batch_size``; it is also forwarded to
            ``train_model_epoch`` as its ``train_config``.
        chunk_data: Mapping with ``'train'`` and ``'validation'`` datasets.

    Returns:
        tuple: ``(model, record_list)`` -- the trained model and the log
        records returned by ``train_model_epoch``. Previously nothing was
        returned, so the trained weights could not be saved or inspected by
        the caller; callers that ignore the return value are unaffected.
    """
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    model = DecoderCausalLM(config=model_config)
    # Move parameters to the target device before the optimizer captures them.
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)

    # Settings shared by both loaders; only shuffling differs between them.
    loader_kwargs = {
        "batch_size": train_settings.batch_size,
        "num_workers": 0,
        "collate_fn": default_data_collator,
    }
    train_loader = DataLoader(chunk_data['train'], shuffle=True, **loader_kwargs)
    val_loader = DataLoader(chunk_data['validation'], shuffle=False, **loader_kwargs)

    record_list = train_model_epoch(model=model,
                                    train_loader=train_loader,
                                    val_loader=val_loader,
                                    optimizer=optimizer,
                                    train_config=train_settings,
                                    device=device)

    return model, record_list
    
    
    

17.4. Training Entry#

# Model hyper-parameters, wrapped with OmegaConf so fields can be read as
# attributes (e.g. config.d_model).
model_config = OmegaConf.create(
    {
        "vocab_size": 50257,     # Vocabulary size
        "context_length": 1024,
        "d_model": 768,          # model dimension
        "num_heads": 4,          # Number of attention heads
        "num_layers": 6,         # Number of layers
        "dropout": 0.1,          # Dropout rate
        "qkv_bias": False,       # Query-key-value bias
    }
)

# Optimisation and logging settings.
train_settings = OmegaConf.create(
    {
        "learning_rate": 5e-4,
        "num_epochs": 1,
        "batch_size": 2,
        "weight_decay": 0.1,
        "seed": 1,
        "log_freq": 50,
    }
)

# train model
train_main(
    model_config=model_config,
    train_settings=train_settings,
    chunk_data=chunk_data,
)
    

# save model
#torch.save(model.state_dict(), "model.pth")

# training process
# {'epoch': 0, 'step': 3850, 'train_loss': 0.6468448042869568, 'val_loss': 0.7575275960154447}
# {'epoch': 0, 'step': 3900, 'train_loss': 0.6099380850791931, 'val_loss': 0.7526979942824544}
# {'epoch': 0, 'step': 3950, 'train_loss': 0.6703057885169983, 'val_loss': 0.7228095072978866}
# {'epoch': 0, 'step': 4000, 'train_loss': 0.6628137826919556, 'val_loss': 0.7049126117942794}
# {'epoch': 0, 'step': 4050, 'train_loss': 0.6893414855003357, 'val_loss': 0.6988437660309196}
# {'epoch': 0, 'step': 4100, 'train_loss': 0.7205791473388672, 'val_loss': 0.6762202845779502}
# {'epoch': 0, 'step': 4150, 'train_loss': 0.8741434812545776, 'val_loss': 0.6782277143252924}