19. Lab: DPO Training#
import math
from typing import List, Optional, Tuple, Union
import os
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
import json
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
from functools import partial
from datasets import load_dataset
import collections
from torch.utils.data import DataLoader
from llm_lab.utils.common_utils import move_to_device
19.1. Data#
raw_dataset = load_dataset("HumanLLMs/Human-Like-DPO-Dataset")
raw_dataset['train'][652]
{'prompt': 'Oh, you like [insert interest]? Me too! What do you love about it?',
'chosen': "Yeah! I'm super passionate about music! 🎵 There's just something about how a good song can evoke emotions and transport you to a different time and place, you know? 🕰️ I love how it can bring people together, too. What about you? What kind of music are you into? 🎶 Do you have a favorite artist or genre? 🤔",
'rejected': "Good day. As a digital entity, I don't have a physical presence or a circadian rhythm, so I neither wake up early nor stay up late. I am designed to operate 24/7, providing assistance and responding to inquiries at any time. My purpose is to provide accurate and helpful information, and I do not have personal preferences or experiences."}
model_name = "Qwen/Qwen2.5-0.5B"
model_name = "MiniLLM/MiniLLM-gpt2-120M" # a tiny model for fast debug
tokenizer = AutoTokenizer.from_pretrained(model_name)
tokenizer.pad_token_id = tokenizer.eos_token_id
def tokenize(example):
    result_dict = {}
    prompt_tokenized = tokenizer(example['prompt'])
    chosen_encoded = tokenizer(example['chosen'])
    rejected_encoded = tokenizer(example['rejected'])
    # Concatenate prompt + response. The loss_mask is 0 over the prompt tokens so that
    # only response tokens contribute to the sequence log-probability later on.
    result_dict['chosen_encoded'] = {
        'input_ids': prompt_tokenized['input_ids'] + chosen_encoded['input_ids'],
        'attention_mask': prompt_tokenized['attention_mask'] + chosen_encoded['attention_mask'],
        'loss_mask': [0] * len(prompt_tokenized['input_ids']) + [1] * len(chosen_encoded['input_ids']),
    }
    result_dict['rejected_encoded'] = {
        'input_ids': prompt_tokenized['input_ids'] + rejected_encoded['input_ids'],
        'attention_mask': prompt_tokenized['attention_mask'] + rejected_encoded['attention_mask'],
        'loss_mask': [0] * len(prompt_tokenized['input_ids']) + [1] * len(rejected_encoded['input_ids']),
    }
    example.update(result_dict)
    return example
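As a quick sanity check (illustrative, using the first training example), the chosen and rejected encodings share the prompt prefix, and the number of ones in loss_mask equals the number of response tokens:
sample = tokenize(raw_dataset['train'][0])
# total length vs. number of response tokens that will receive loss
print(len(sample['chosen_encoded']['input_ids']), sum(sample['chosen_encoded']['loss_mask']))
print(len(sample['rejected_encoded']['input_ids']), sum(sample['rejected_encoded']['loss_mask']))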
tokenized_dataset = raw_dataset.map(tokenize, remove_columns=['prompt','chosen','rejected'])
def custom_collate_fn(batch, tokenizer, ignore_idx=-100):
    # Pad every chosen and rejected sequence in the batch to a common length.
    max_len = max(len(e[enc]['input_ids']) for enc in ['chosen_encoded', 'rejected_encoded'] for e in batch)
    result_dict = {}
    for enc in ['chosen_encoded', 'rejected_encoded']:
        result_dict[enc] = collections.defaultdict(list)
        for e in batch:
            needed = max_len - len(e[enc]['input_ids'])
            result_dict[enc]['input_ids'].append(e[enc]['input_ids'] + [tokenizer.pad_token_id] * needed)
            result_dict[enc]['attention_mask'].append(e[enc]['attention_mask'] + [0] * needed)
            result_dict[enc]['loss_mask'].append(e[enc]['loss_mask'] + [0] * needed)
        for key in result_dict[enc]:
            result_dict[enc][key] = torch.LongTensor(result_dict[enc][key])
    return result_dict
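To see the collate function in action, here is a small illustrative check: two examples are padded to the same length and returned as LongTensors for both encodings.
toy_batch = custom_collate_fn([tokenized_dataset['train'][i] for i in range(2)], tokenizer=tokenizer)
print(toy_batch['chosen_encoded']['input_ids'].shape, toy_batch['rejected_encoded']['input_ids'].shape)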
dataset = tokenized_dataset['train'].train_test_split(test_size=0.05)
19.1.1. Preference Learning utility functions#
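The helpers in this section implement the DPO objective. Given a prompt $x$ with a preferred response $y_w$ and a dispreferred response $y_l$, the policy $\pi_\theta$ is trained against a frozen reference model $\pi_{\mathrm{ref}}$ with the per-example loss

$$
\mathcal{L}_{\mathrm{DPO}} = -\log \sigma\left(\beta \left[\log \frac{\pi_\theta(y_w \mid x)}{\pi_{\mathrm{ref}}(y_w \mid x)} - \log \frac{\pi_\theta(y_l \mid x)}{\pi_{\mathrm{ref}}(y_l \mid x)}\right]\right),
$$

where $\sigma$ is the sigmoid and $\beta$ controls how strongly the policy is penalized for drifting away from the reference. The implicit rewards $\beta \log \frac{\pi_\theta(y \mid x)}{\pi_{\mathrm{ref}}(y \mid x)}$ are returned only for monitoring, which is why they are detached from the computation graph in the code below.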
def preference_loss(
    chosen_log_probs: torch.FloatTensor,
    rejected_log_probs: torch.FloatTensor,
    reference_chosen_log_probs: torch.FloatTensor,
    reference_rejected_log_probs: torch.FloatTensor,
    beta: float = 0.1,  # suggested value in the DPO paper
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
    """
    Args:
        chosen_log_probs: log probabilities of the policy model for the chosen responses, shape: (batch_size,)
        rejected_log_probs: log probabilities of the policy model for the rejected responses, shape: (batch_size,)
        reference_chosen_log_probs: log probabilities of the reference model for the chosen responses, shape: (batch_size,)
        reference_rejected_log_probs: log probabilities of the reference model for the rejected responses, shape: (batch_size,)
        beta: temperature controlling how far the policy may move away from the reference.
    Returns:
        (losses, chosen_rewards, rejected_rewards), each of shape (batch_size,)
    """
    pi_logratios = chosen_log_probs - rejected_log_probs
    ref_logratios = reference_chosen_log_probs - reference_rejected_log_probs
    logratios_difference = pi_logratios - ref_logratios
    losses = -F.logsigmoid(beta * logratios_difference)
    # Implicit rewards are logged for monitoring only, so gradients are detached.
    chosen_rewards = beta * (chosen_log_probs - reference_chosen_log_probs).detach()
    rejected_rewards = beta * (rejected_log_probs - reference_rejected_log_probs).detach()
    return losses, chosen_rewards, rejected_rewards
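A quick illustrative check: when the policy and the reference assign identical log probabilities, the log-ratio difference is zero, so the loss should equal -log sigmoid(0) = log 2 ≈ 0.693 and both implicit rewards should be zero.
lp = torch.tensor([-5.0])
losses, chosen_r, rejected_r = preference_loss(lp, lp, lp, lp)
print(losses, chosen_r, rejected_r)  # ~0.6931, 0.0, 0.0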
def _get_sequence_log_probs(
    logits: torch.FloatTensor,
    labels: torch.LongTensor,
    loss_mask: torch.LongTensor,
    average_log_prob: bool = False,
) -> torch.FloatTensor:
    """
    Args:
        logits: logits of the model output. Shape: (batch_size, seq_length, vocab_size)
        labels: token ids used as prediction targets. Shape: (batch_size, seq_length)
        loss_mask: 1 for response tokens, 0 for prompt/padding tokens. Shape: (batch_size, seq_length)
        average_log_prob: if True, return the mean per-token log probability instead of the sum.
    Returns:
        The masked sum (or mean) of per-token log probabilities. Shape: (batch_size,)
    """
    assert logits.shape[:-1] == labels.shape
    assert labels.shape == loss_mask.shape
    # Shift by one position: for a sequence A, B, C, D,
    # labels[:, 1:] are B, C, D and logits[:, :-1, :] predict B, C, D.
    labels = labels[:, 1:].clone()
    logits = logits[:, :-1, :]
    loss_mask = loss_mask[:, 1:]
    # log_probs shape: (batch_size, seq_len - 1, vocab_size)
    log_probs = logits.log_softmax(-1)
    # labels.unsqueeze(2) has shape (batch_size, seq_len - 1, 1);
    # gather then squeeze gives per_token_logps of shape (batch_size, seq_len - 1)
    per_token_logps = torch.gather(log_probs, dim=2, index=labels.unsqueeze(2)).squeeze(2)
    if average_log_prob:
        return (per_token_logps * loss_mask).sum(-1) / loss_mask.sum(-1)
    else:
        return (per_token_logps * loss_mask).sum(-1)
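An illustrative shape check with random tensors: the helper returns one scalar per sequence (the per-token dimension is one shorter internally because of the shift).
B, T, V = 2, 6, 11
fake_logits = torch.randn(B, T, V)
fake_labels = torch.randint(0, V, (B, T))
fake_mask = torch.ones(B, T, dtype=torch.long)
print(_get_sequence_log_probs(fake_logits, fake_labels, fake_mask).shape)  # torch.Size([2])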
19.2. Training#
def compute_batch_loss(batch, model, ref_model, device):
    assert model.training == True
    assert ref_model.training == False
    chosen_batch = move_to_device(batch['chosen_encoded'], device)
    rejected_batch = move_to_device(batch['rejected_encoded'], device)
    # Policy forward passes (with gradients); labels are the input_ids themselves,
    # the one-position shift is handled inside _get_sequence_log_probs.
    chosen_logits = model(input_ids=chosen_batch['input_ids'], attention_mask=chosen_batch['attention_mask']).logits
    rejected_logits = model(input_ids=rejected_batch['input_ids'], attention_mask=rejected_batch['attention_mask']).logits
    chosen_sequence_log_probs = _get_sequence_log_probs(chosen_logits, labels=chosen_batch['input_ids'], loss_mask=chosen_batch['loss_mask'])
    rejected_sequence_log_probs = _get_sequence_log_probs(rejected_logits, labels=rejected_batch['input_ids'], loss_mask=rejected_batch['loss_mask'])
    # Frozen reference forward passes (no gradients).
    with torch.no_grad():
        chosen_logits_ref = ref_model(input_ids=chosen_batch['input_ids'], attention_mask=chosen_batch['attention_mask']).logits
        rejected_logits_ref = ref_model(input_ids=rejected_batch['input_ids'], attention_mask=rejected_batch['attention_mask']).logits
        chosen_sequence_log_probs_ref = _get_sequence_log_probs(chosen_logits_ref, labels=chosen_batch['input_ids'], loss_mask=chosen_batch['loss_mask'])
        rejected_sequence_log_probs_ref = _get_sequence_log_probs(rejected_logits_ref, labels=rejected_batch['input_ids'], loss_mask=rejected_batch['loss_mask'])
    losses, chosen_rewards, rejected_rewards = preference_loss(chosen_sequence_log_probs, rejected_sequence_log_probs, chosen_sequence_log_probs_ref, rejected_sequence_log_probs_ref)
    return losses, chosen_rewards, rejected_rewards
def train_dpo(model, ref_model, optimizer, train_loader, train_settings, device):
    global_steps = 0
    record_list = []
    model = model.to(device)
    ref_model = ref_model.to(device)
    for epoch in range(train_settings.num_epochs):
        for batch in train_loader:
            model.train()
            optimizer.zero_grad()
            losses, chosen_rewards, rejected_rewards = compute_batch_loss(batch, model, ref_model, device)
            loss = losses.mean()
            chosen_reward = chosen_rewards.mean()
            rejected_reward = rejected_rewards.mean()
            loss.backward()
            optimizer.step()
            global_steps += 1
            if global_steps % train_settings.log_freq == 0:
                record = {"epoch": epoch,
                          "step": global_steps,
                          "train_loss": loss.detach().item(),
                          "chosen_reward": chosen_reward.item(),
                          "rejected_reward": rejected_reward.item()}
                print(record)
                record_list.append(record)
    return record_list
from omegaconf import OmegaConf
train_settings = {
    "pretrained_model_name": model_name,
    "learning_rate": 5e-6,
    "num_epochs": 10,
    "batch_size": 4,
    "weight_decay": 0.1,
    "seed": 1,
    "log_freq": 50
}
batch_size = 16  # the DataLoader uses this value rather than train_settings["batch_size"]
train_dataloader = DataLoader(dataset['train'],
batch_size= batch_size,
#num_workers=num_workers,
shuffle=True,
collate_fn=partial(custom_collate_fn, tokenizer=tokenizer))
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = AutoModelForCausalLM.from_pretrained(model_name)
ref_model = AutoModelForCausalLM.from_pretrained(model_name)
# ref_model.load_state_dict(model.state_dict())
ref_model.eval()  # the reference model stays in eval mode and receives no gradient updates
for param in ref_model.parameters():
    param.requires_grad = False
optimizer = torch.optim.Adam(model.parameters(), lr=5e-6)
# train model
train_dpo(model, ref_model, optimizer, train_dataloader, OmegaConf.create(train_settings), device)
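Since the policy and the reference start from the same checkpoint, the loss begins at log 2 ≈ 0.693 with both implicit rewards at zero. In a healthy run the logged margin between chosen_reward and rejected_reward grows over training while the loss falls below log 2.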