19. *Lab: DPO Training

import math
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import json
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import partial
from datasets import load_dataset
import collections
from torch.utils.data import DataLoader
from llm_lab.utils.common_utils import move_to_device

19.1. Data

# Load the preference dataset; each record holds a 'prompt' together with a
# 'chosen' and a 'rejected' response (see the sample record printed below).
raw_dataset = load_dataset("HumanLLMs/Human-Like-DPO-Dataset")
# Display one training record for inspection.
raw_dataset['train'][652]
{'prompt': 'Oh, you like [insert interest]? Me too! What do you love about it?',
 'chosen': "Yeah! I'm super passionate about music! 🎵 There's just something about how a good song can evoke emotions and transport you to a different time and place, you know? 🕰️ I love how it can bring people together, too. What about you? What kind of music are you into? 🎶 Do you have a favorite artist or genre? 🤔",
 'rejected': "Good day. As a digital entity, I don't have a physical presence or a circadian rhythm, so I neither wake up early nor stay up late. I am designed to operate 24/7, providing assistance and responding to inquiries at any time. My purpose is to provide accurate and helpful information, and I do not have personal preferences or experiences."}
# Full-size candidate model; NOTE: this value is immediately overridden by
# the next assignment, so the tiny debug model is the one actually used.
model_name = "Qwen/Qwen2.5-0.5B"
model_name = "MiniLLM/MiniLLM-gpt2-120M" # a tiny model for fast debug
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Reuse the EOS token id as the padding id (presumably because the
# tokenizer ships without a pad token -- TODO confirm).
tokenizer.pad_token_id = tokenizer.eos_token_id
def tokenize(example):
    """Add prompt+response token sequences for both responses of an example.

    Uses the module-level ``tokenizer``. The prompt, the chosen response and
    the rejected response are tokenized separately; each response encoding is
    then appended to the prompt encoding.

    Args:
        example: mapping with the text fields 'prompt', 'chosen', 'rejected'.

    Returns:
        The same ``example`` object, updated with two new entries,
        'chosen_encoded' and 'rejected_encoded'. Each is a dict with
        'input_ids' and 'attention_mask' (prompt followed by response) and
        'loss_mask' (0 for every prompt token, 1 for every response token).
    """
    prompt_enc = tokenizer(example['prompt'])
    prompt_ids = prompt_enc['input_ids']
    prompt_attention = prompt_enc['attention_mask']
    # Prompt tokens never contribute to the loss.
    prompt_loss_mask = [0] * len(prompt_ids)

    encoded = {}
    for text_key, out_key in (('chosen', 'chosen_encoded'),
                              ('rejected', 'rejected_encoded')):
        response_enc = tokenizer(example[text_key])
        response_ids = response_enc['input_ids']
        encoded[out_key] = {
            'input_ids': prompt_ids + response_ids,
            'attention_mask': prompt_attention + response_enc['attention_mask'],
            'loss_mask': prompt_loss_mask + [1] * len(response_ids),
        }

    example.update(encoded)
    return example
    
    
# Tokenize every example and drop the raw text columns, keeping only the
# 'chosen_encoded' / 'rejected_encoded' entries added by tokenize().
tokenized_dataset = raw_dataset.map(tokenize, remove_columns=['prompt','chosen','rejected'])
def custom_collate_fn(batch, tokenizer, ignore_idx=-100):
    """Right-pad tokenized preference examples and stack them into tensors.

    Chosen and rejected sequences are all padded to one common length: the
    longest 'input_ids' list of either kind found in the batch. The input
    examples are left untouched (padding is applied to copies), so the same
    example objects can safely be collated more than once.

    Args:
        batch: non-empty list of examples. Each example maps
            'chosen_encoded' and 'rejected_encoded' to a dict holding the
            equally long lists 'input_ids', 'attention_mask' and 'loss_mask'.
        tokenizer: object exposing ``pad_token_id``, used to pad 'input_ids'.
        ignore_idx: unused; kept only for backward compatibility.

    Returns:
        ``{'chosen_encoded': {...}, 'rejected_encoded': {...}}`` where each
        inner dict maps 'input_ids', 'attention_mask' and 'loss_mask' to a
        ``torch.long`` tensor of shape (len(batch), max_len). 'attention_mask'
        and 'loss_mask' are padded with 0.

    Raises:
        ValueError: if ``batch`` is empty.
    """
    if not batch:
        raise ValueError("custom_collate_fn received an empty batch")

    sequence_kinds = ('chosen_encoded', 'rejected_encoded')
    # Padding value for each field, in the order the fields are emitted.
    pad_values = {
        'input_ids': tokenizer.pad_token_id,
        'attention_mask': 0,
        'loss_mask': 0,
    }

    max_len = max(
        len(example[kind]['input_ids'])
        for kind in sequence_kinds
        for example in batch
    )

    result_dict = {}
    for kind in sequence_kinds:
        collated = {}
        for field, pad_value in pad_values.items():
            rows = []
            for example in batch:
                needed = max_len - len(example[kind]['input_ids'])
                # Build a new list instead of extending in place so the
                # caller's example is never mutated.
                rows.append(list(example[kind][field]) + [pad_value] * needed)
            collated[field] = torch.tensor(rows, dtype=torch.long)
        result_dict[kind] = collated

    return result_dict
    
# Hold out 5% of the tokenized training split as a test split.
# NOTE(review): no seed is passed here, so the split presumably differs
# from run to run -- confirm whether reproducibility is required.
dataset = tokenized_dataset['train'].train_test_split(test_size=0.05)
# for batch in train_dataloader:
#     #print(batch)
#     chosen_batch = move_to_device(batch['chosen_encoded'], device)
    
#     chosen_logits = model(input_ids=chosen_batch['input_ids'], attention_mask=chosen_batch['attention_mask']).logits    
#     print(chosen_logits)
#     break

19.1.1. Preference Learning utility functions

def preference_loss(
    chosen_log_probs: torch.FloatTensor,
    rejected_log_probs: torch.FloatTensor,
    reference_chosen_log_probs: torch.FloatTensor,
    reference_rejected_log_probs: torch.FloatTensor,
    beta: float = 0.1,  # suggested value in the DPO paper
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """Compute the per-example DPO loss and the implicit rewards.

    The loss is ``-logsigmoid(beta * ((chosen - rejected) - (ref_chosen -
    ref_rejected)))``, i.e. it shrinks as the policy prefers the chosen
    response over the rejected one by a wider margin than the reference does.

    Args:
        chosen_log_probs: policy log probabilities of the chosen responses,
            shape (batch_size,).
        rejected_log_probs: policy log probabilities of the rejected
            responses; expected to match ``chosen_log_probs`` in shape.
        reference_chosen_log_probs: reference-model log probabilities of the
            chosen responses; same shape expected.
        reference_rejected_log_probs: reference-model log probabilities of
            the rejected responses; same shape expected.
        beta: scale applied to the log-ratio difference and to the rewards.

    Returns:
        A tuple ``(losses, chosen_rewards, rejected_rewards)``. The rewards
        are ``beta * (policy - reference)`` log-prob differences, detached
        from the autograd graph (for logging only).
    """
    # Implicit rewards: how far the policy moved away from the reference.
    chosen_rewards = beta * (chosen_log_probs - reference_chosen_log_probs).detach()
    rejected_rewards = beta * (rejected_log_probs - reference_rejected_log_probs).detach()

    # Preference margin of each model: log p(chosen) - log p(rejected).
    policy_margin = chosen_log_probs - rejected_log_probs
    reference_margin = reference_chosen_log_probs - reference_rejected_log_probs

    scaled_margin_gap = beta * (policy_margin - reference_margin)
    losses = -F.logsigmoid(scaled_margin_gap)

    return losses, chosen_rewards, rejected_rewards

def _get_squence_log_probs(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    loss_mask: torch.LongTensor,
    average_log_prob: bool = False,
    ) -> torch.FloatTensor:
    
    """
    Args:
        logits: logits of the model output. Shape: (batch_size, seq_length, vocab_size)
        labels: labels for which token's log probability; label = -100 indicates ignore. Shape (batch_size, seq_length)

    """

    assert logits.shape[:-1] == labels.shape
    #assert labels.shape == loss_mask.shape
    # let the sequence be A, B, C, D
    # labels[:,1:] are B, C, D
    # logits corresponds to B, C, D, X
    # logits[:,:-1,:] corresponds to B, C, D
    labels = labels[:,1:].clone() # labels 
    logits = logits[:,:-1,:]
    loss_mask = loss_mask[:, 1:]
    
    
    # log_probs shape (batch_size, seq_len - 1, vocab_size)
    # label shape before unsqueeze - (batch_size, seq_len - 1), after - (batch_size, seq_len - 1, vocab_size)
    log_probs = logits.log_softmax(-1)
    # per_token_logps shape (batch_size, seq_len - 1)
    per_token_logps = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2) # squeeze on the last dim
    
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
        

19.2. Training

def compute_batch_loss(batch, model, ref_model, device):
    """Compute the DPO losses and implicit rewards for one collated batch.

    Args:
        batch: dict with 'chosen_encoded' and 'rejected_encoded' entries,
            each providing 'input_ids', 'attention_mask' and 'loss_mask'.
        model: policy model being trained; must be in training mode.
        ref_model: reference model; must NOT be in training mode. It is
            evaluated under ``torch.no_grad()``.
        device: target passed to ``move_to_device`` for both sub-batches.

    Returns:
        ``(losses, chosen_rewards, rejected_rewards)`` exactly as returned by
        ``preference_loss`` (called with its default ``beta``).
    """
    assert model.training
    assert not ref_model.training

    chosen_batch = move_to_device(batch['chosen_encoded'], device)
    rejected_batch = move_to_device(batch['rejected_encoded'], device)

    def sequence_log_probs(net, encoded):
        # One forward pass, then the summed log-prob of the masked tokens.
        logits = net(input_ids=encoded['input_ids'],
                     attention_mask=encoded['attention_mask']).logits
        return _get_squence_log_probs(logits,
                                      labels=encoded['input_ids'],
                                      loss_mask=encoded['loss_mask'])

    policy_chosen = sequence_log_probs(model, chosen_batch)
    policy_rejected = sequence_log_probs(model, rejected_batch)

    # The reference model only supplies constants, so no graph is needed.
    with torch.no_grad():
        reference_chosen = sequence_log_probs(ref_model, chosen_batch)
        reference_rejected = sequence_log_probs(ref_model, rejected_batch)

    losses, chosen_rewards, rejected_rewards = preference_loss(
        policy_chosen, policy_rejected, reference_chosen, reference_rejected
    )
    return losses, chosen_rewards, rejected_rewards
def train_dpo(model, ref_model, optimizer, train_loader, train_settings, device):
    """Run the DPO training loop and return the logged metrics.

    Args:
        model: policy model to optimize; moved to ``device``.
        ref_model: frozen reference model; moved to ``device`` and switched
            to eval mode.
        optimizer: optimizer over ``model``'s parameters.
        train_loader: iterable of collated batches accepted by
            ``compute_batch_loss``.
        train_settings: object exposing ``num_epochs`` and ``log_freq``.
        device: device on which the models and batches are placed.

    Returns:
        List of dicts, one per logged step (every ``log_freq`` optimizer
        steps), with keys 'epoch', 'step', 'train_loss', 'chosen_reward'
        and 'rejected_reward'.
    """
    global_steps = 0
    record_list = []
    model = model.to(device)
    ref_model = ref_model.to(device)
    # compute_batch_loss asserts the reference model is not in training mode;
    # eval mode also disables dropout so its log-probs are deterministic.
    ref_model.eval()

    for epoch in range(train_settings.num_epochs):
        for batch in train_loader:
            model.train()
            optimizer.zero_grad()

            losses, chosen_rewards, rejected_rewards = compute_batch_loss(batch, model, ref_model, device)

            loss = losses.mean()
            chosen_reward = chosen_rewards.mean()
            rejected_reward = rejected_rewards.mean()

            loss.backward()
            optimizer.step()

            global_steps += 1
            if global_steps % train_settings.log_freq == 0:
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.item(),
                          "chosen_reward": chosen_reward.item(),
                          "rejected_reward": rejected_reward.item()
                          }
                print(record)
                record_list.append(record)

    return record_list
    
from omegaconf import OmegaConf

# Training hyper-parameters. train_dpo() reads 'num_epochs' and 'log_freq';
# 'learning_rate' feeds the optimizer below.
train_settings = {
    "pretrained_model_name": "Qwen/Qwen2.5-0.5B",
    "learning_rate": 5e-6,
    "num_epochs": 10,
    "batch_size": 4,
    "weight_decay": 0.1,
    "seed": 1,
    "log_freq": 50
}
# NOTE(review): the DataLoader uses this value (16), not
# train_settings["batch_size"] (4) -- confirm which one is intended.
batch_size = 16
train_dataloader = DataLoader(dataset['train'],
                              batch_size=batch_size,
                              shuffle=True,
                              collate_fn=partial(custom_collate_fn, tokenizer=tokenizer))

# Fall back to CPU so the cell still runs on machines without a GPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
# Freeze the reference model: eval mode plus no gradients on its weights.
# (The attribute is `requires_grad`; a misspelled name would merely attach
# an unused attribute and leave the parameters trainable.)
ref_model.eval()
for param in ref_model.parameters():
    param.requires_grad = False


optimizer = torch.optim.Adam(model.parameters(), lr=train_settings["learning_rate"])
# train model
train_dpo(model, ref_model, optimizer, train_dataloader, OmegaConf.create(train_settings), device)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In [14], line 29
     27 optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)
     28 # train model
---> 29 train_dpo(model, ref_model, optimizer, train_dataloader, OmegaConf.create(train_settings), device)

Cell In [13], line 14, in train_dpo(model, ref_model, optimizer, train_loader, train_settings, device)
     11 model.train()
     12 optimizer.zero_grad()
---> 14 losses, chosen_rewards, rejected_rewards = compute_batch_loss(batch, model, ref_model, device)
     16 loss = losses.mean()
     17 chosen_reward = chosen_rewards.mean()

Cell In [12], line 22, in compute_batch_loss(batch, model, ref_model, device)
     19     chosen_logits_ref = ref_model(input_ids=chosen_batch['input_ids'], attention_mask=chosen_batch['attention_mask']).logits
     20     rejected_logits_ref = ref_model(input_ids=rejected_batch['input_ids'], attention_mask=rejected_batch['attention_mask']).logits
---> 22     chosen_sequence_log_probs_ref = _get_squence_log_probs(chosen_logits_ref, labels = chosen_batch['input_ids'], loss_mask = chosen_batch['loss_mask'])
     23     rejected_sequence_log_probs_ref = _get_squence_log_probs(rejected_logits_ref, labels = rejected_batch['input_ids'], loss_mask = rejected_batch['loss_mask'])
     25 losses, chosen_rewards, rejected_rewards = preference_loss(chosen_sequence_log_probs, rejected_sequence_log_probs, chosen_sequence_log_probs_ref, rejected_sequence_log_probs_ref)

Cell In [11], line 46, in _get_squence_log_probs(logits, labels, loss_mask, average_log_prob)
     40 assert logits.shape[:-1] == labels.shape
     41 #assert labels.shape == loss_mask.shape
     42 # let the sequence be A, B, C, D
     43 # labels[:,1:] are B, C, D
     44 # logits corresponds to B, C, D, X
     45 # logits[:,:-1,:] corresponds to B, C, D
---> 46 labels = labels[:,1:].clone() # labels 
     47 logits = logits[:,:-1,:]
     48 loss_mask = loss_mask[:, 1:]

KeyboardInterrupt: