17. Lab: LLM Pretraining#
In this lab we directly reuse the decoder architecture built in the previous sections and pretrain it as a causal language model on WikiText-2.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch.utils.data import DataLoader
from torch import nn
import numpy as np
from omegaconf import OmegaConf
from llm_lab.model.vanilla_decoder import VanillaDecoderModel
from llm_lab.utils.collate_utils import default_data_collator
from llm_lab.utils.common_utils import move_to_device
from transformers import AutoTokenizer
from datasets import load_dataset
from itertools import chain
from functools import partial
%load_ext autoreload
%autoreload 2
17.1. Data#
dataset_name = "wikitext"
data_config = "wikitext-2-raw-v1"
text_column_name = "text"
# model parameters
model_name_or_path="openai-community/gpt2"
raw_datasets = load_dataset(dataset_name, data_config)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_fast=True)
def tokenize(examples):
    return tokenizer(examples[text_column_name])
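As a quick check of what the tokenizer produces (illustrative; run after the cells above), each string is mapped to a list of GPT-2 BPE subword IDs plus an attention mask:
sample = tokenizer("Hello world")
print(sample['input_ids'])                    # e.g. [15496, 995]
print(tokenizer.decode(sample['input_ids']))  # 'Hello world'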
def group_and_chunk(tokenized_examples, chunk_size=1024, chunk_key='input_ids'):
    keys = list(tokenized_examples.keys())
    # use chain to flatten the per-example lists into one long token stream
    concat_examples = {k: list(chain(*tokenized_examples[k])) for k in keys}
    total_length = len(concat_examples[chunk_key])
    # drop the trailing remainder so every chunk has exactly chunk_size tokens
    total_length = (total_length // chunk_size) * chunk_size
    result_dict = {
        k: [v[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, v in concat_examples.items()
    }
    return result_dict
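To see what group_and_chunk does, here is a toy run on made-up token IDs (a minimal sketch, not WikiText data): the per-example lists are concatenated, the remainder is dropped, and the stream is split into fixed-size blocks.
toy_batch = {'input_ids': [[1, 2, 3], [4, 5], [6, 7, 8, 9]]}
print(group_and_chunk(toy_batch, chunk_size=4))
# {'input_ids': [[1, 2, 3, 4], [5, 6, 7, 8]]}  -> token 9 is dropped as the remainder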
tokenized_dataset = raw_datasets.map(
    tokenize,
    batched=True)
chunk_data = tokenized_dataset.map(
    partial(group_and_chunk,
            chunk_size=256),
            # chunk_size=tokenizer.model_max_length),
    batched=True,
    remove_columns=['text'])
chunk_data
DatasetDict({
    test: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 1104
    })
    train: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 9327
    })
    validation: Dataset({
        features: ['input_ids', 'attention_mask'],
        num_rows: 964
    })
})
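Each row of the chunked dataset is now a fixed-length block of 256 token IDs; decoding one back to text is a useful sanity check (assuming the cells above have been run):
sample = chunk_data['train'][0]
print(len(sample['input_ids']))                    # 256
print(tokenizer.decode(sample['input_ids'][:32]))  # first 32 tokens rendered back as text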
17.2. Model#
class DecoderCausalLM(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.decoder = VanillaDecoderModel(config)
        self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)

    def forward(self, batch):
        hidden_states = self.decoder(input_ids=batch['input_ids'])
        logits = self.lm_head(hidden_states)
        return logits
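A quick shape check on a tiny hypothetical config (not the full configuration from the Training Entry section, but assuming VanillaDecoderModel only needs the same config fields) confirms the forward pass maps (batch, seq_len) token IDs to (batch, seq_len, vocab_size) logits:
toy_config = OmegaConf.create({"vocab_size": 100, "context_length": 32, "d_model": 16,
                               "num_heads": 2, "num_layers": 1, "dropout": 0.0,
                               "qkv_bias": False})
toy_model = DecoderCausalLM(toy_config)
toy_batch = {'input_ids': torch.randint(0, 100, (2, 8))}
print(toy_model(toy_batch).shape)   # expected: torch.Size([2, 8, 100])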
17.3. Training#
def compute_batch_loss(batch, model, device):
    assert model.training
    move_to_device(batch, device)
    model_input = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
    # shift: logits at position t are scored against the token at position t+1
    logits = model(model_input)[:, :-1, :].contiguous()
    labels = batch['input_ids'][:, 1:].contiguous()
    flat_labels = labels.view(-1)
    flat_logits = logits.view(-1, logits.shape[-1])
    loss = F.cross_entropy(flat_logits, flat_labels)
    return loss
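The slicing implements the standard next-token objective: the prediction at position t is compared to the token at position t+1, so the last logit and the first label are dropped. A toy illustration:
ids = torch.tensor([[11, 12, 13, 14, 15]])      # one sequence of 5 token IDs
toy_logits = torch.randn(1, 5, 50257)           # pretend model output
shifted_logits = toy_logits[:, :-1, :]          # predictions for positions 0..3
shifted_labels = ids[:, 1:]                     # targets are the next tokens: 12..15
loss = F.cross_entropy(shifted_logits.reshape(-1, 50257), shifted_labels.reshape(-1))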
def compute_eval_loss(eval_dataloader, model, device):
    assert not model.training
    all_losses = []
    with torch.no_grad():
        for batch in eval_dataloader:
            move_to_device(batch, device)
            model_input = {'input_ids': batch['input_ids'], 'attention_mask': batch['attention_mask']}
            logits = model(model_input)[:, :-1, :].contiguous()
            labels = batch['input_ids'][:, 1:].contiguous()
            flat_labels = labels.view(-1)
            flat_logits = logits.view(-1, logits.shape[-1])
            # keep per-token losses so the final mean is over tokens, not batches
            losses = F.cross_entropy(flat_logits, flat_labels, reduction='none').tolist()
            all_losses.extend(losses)
    mean_loss = np.mean(all_losses)
    return mean_loss
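The mean token-level cross-entropy is commonly reported as perplexity, its exponential; for example (with a hypothetical loss value):
import math
eval_loss = 3.5             # hypothetical value returned by compute_eval_loss
print(math.exp(eval_loss))  # ~33.1, i.e. the model is roughly as uncertain as a uniform choice over 33 tokens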
def train_model_epoch(model,
                      train_loader,
                      val_loader,
                      optimizer,
                      device,
                      train_config):
    global_steps = 0
    record_list = []
    model = model.to(device)
    for epoch in range(train_config.num_epochs):
        for batch in train_loader:
            model.train()
            loss = compute_batch_loss(batch, model, device)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            global_steps += 1
            if global_steps % train_config.log_freq == 0:
                model.eval()
                val_loss = compute_eval_loss(val_loader, model, device)
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.detach().item(),
                          "val_loss": val_loss}
                print(record)
                record_list.append(record)
    return record_list
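The returned record_list can be plotted to track the two loss curves (a sketch, assuming matplotlib is installed):
import matplotlib.pyplot as plt

def plot_records(record_list):
    steps = [r['step'] for r in record_list]
    plt.plot(steps, [r['train_loss'] for r in record_list], label='train loss')
    plt.plot(steps, [r['val_loss'] for r in record_list], label='val loss')
    plt.xlabel('step')
    plt.ylabel('cross-entropy loss')
    plt.legend()
    plt.show()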
def train_main(model_config, train_settings, chunk_data):
    torch.manual_seed(train_settings.seed)
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    model = DecoderCausalLM(config=model_config)
    model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=train_settings.learning_rate,
                                  weight_decay=train_settings.weight_decay)
    train_loader = DataLoader(chunk_data['train'],
                              batch_size=train_settings.batch_size,
                              shuffle=True,
                              num_workers=0,
                              collate_fn=default_data_collator)
    val_loader = DataLoader(chunk_data['validation'],
                            batch_size=train_settings.batch_size,
                            shuffle=False,
                            num_workers=0,
                            collate_fn=default_data_collator)
    record_list = train_model_epoch(model=model,
                                    train_loader=train_loader,
                                    val_loader=val_loader,
                                    optimizer=optimizer,
                                    train_config=train_settings,
                                    device=device)
    return model, record_list
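The collate function turns a list of chunked examples into batched tensors; llm_lab's default_data_collator is assumed to behave roughly like this minimal sketch:
def simple_collator(features):
    # stack each field across examples into a (batch, seq_len) LongTensor
    return {k: torch.tensor([f[k] for f in features], dtype=torch.long)
            for k in features[0].keys()}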
17.4. Training Entry#
model_config = {
    "vocab_size": 50257,     # vocabulary size
    "context_length": 1024,
    "d_model": 768,          # model dimension
    "num_heads": 4,          # number of attention heads
    "num_layers": 6,         # number of layers
    "dropout": 0.1,          # dropout rate
    "qkv_bias": False        # query-key-value bias
}
model_config = OmegaConf.create(model_config)
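With this configuration the model is small enough for a single GPU; a rough way to check its size (run after creating model_config) is to count parameters directly:
sized_model = DecoderCausalLM(config=model_config)
n_params = sum(p.numel() for p in sized_model.parameters())
print(f"{n_params / 1e6:.1f}M parameters")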
train_settings = {
    "learning_rate": 5e-4,
    "num_epochs": 1,
    "batch_size": 2,
    "weight_decay": 0.1,
    "seed": 1,
    "log_freq": 50
}
train_settings = OmegaConf.create(train_settings)
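With these settings, one pass over the 9327 training chunks takes on the order of 4.7k optimizer steps, which is consistent with the step counts in the log excerpt below:
print(len(chunk_data['train']) // train_settings.batch_size)  # ~4663 full batches per epoch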
# train model
model, records = train_main(model_config=model_config, train_settings=train_settings, chunk_data=chunk_data)
# save model
# torch.save(model.state_dict(), "model.pth")
# training process
# {'epoch': 0, 'step': 3850, 'train_loss': 0.6468448042869568, 'val_loss': 0.7575275960154447}
# {'epoch': 0, 'step': 3900, 'train_loss': 0.6099380850791931, 'val_loss': 0.7526979942824544}
# {'epoch': 0, 'step': 3950, 'train_loss': 0.6703057885169983, 'val_loss': 0.7228095072978866}
# {'epoch': 0, 'step': 4000, 'train_loss': 0.6628137826919556, 'val_loss': 0.7049126117942794}
# {'epoch': 0, 'step': 4050, 'train_loss': 0.6893414855003357, 'val_loss': 0.6988437660309196}
# {'epoch': 0, 'step': 4100, 'train_loss': 0.7205791473388672, 'val_loss': 0.6762202845779502}
# {'epoch': 0, 'step': 4150, 'train_loss': 0.8741434812545776, 'val_loss': 0.6782277143252924}
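If the commented-out torch.save line above is used, the weights can later be restored into a freshly constructed model (a sketch; "model.pth" must exist):
restored = DecoderCausalLM(config=model_config)
restored.load_state_dict(torch.load("model.pth", map_location="cpu"))
restored.eval()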