-
Notifications
You must be signed in to change notification settings - Fork 2.8k
Open
Description
def dpo_loss(ref_probs, probs, mask, beta):
    """Reduce per-token (log-)probabilities to per-sequence scores and split
    the concatenated batch into chosen/rejected halves for DPO.

    NOTE(review): this snippet appears truncated — `beta` is never used in
    the visible lines and nothing is returned, so the actual DPO loss
    computation must live past the end of this excerpt.

    Args:
        ref_probs: reference-model per-token probabilities; assumed shape
            (2*batch_size, seq_len) with chosen rows stacked before rejected
            rows along dim 0 — TODO confirm against the data pipeline.
        probs: policy-model per-token probabilities, same assumed shape.
        mask: 0/1 padding mask, same assumed shape; 1 marks valid tokens.
        beta: DPO temperature (unused in the visible lines — see NOTE above).
    """
    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1): valid-token count per row
    # Masked mean over the sequence dimension -> one scalar score per example.
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
    # Separate the chosen and rejected examples.
    # `batch_size` here is the CONCATENATED size (chosen + rejected stacked
    # on dim 0), so splitting at the midpoint recovers the two halves —
    # presumably matching the torch.cat(..., dim=0) done at batch assembly;
    # verify against the caller.
    batch_size = ref_probs.shape[0]
    chosen_ref_probs = ref_probs[:batch_size // 2]  # ref_probs and probs are assumed to enter as (2*batch_size, seq_len)
    reject_ref_probs = ref_probs[batch_size // 2:]
    chosen_probs = probs[:batch_size // 2]
    reject_probs = probs[batch_size // 2:]
这里分别提取 0 到 bs//2 和 bs//2 到末尾两段;但在输入处理时,一个 batch 中的 chosen 和 rejected 已在 dim=0 维度上拼接,因此模型输入的第 0 维应当是 2*batch_size:
def train_epoch(epoch, wandb):
    """Run one training epoch: move each DPO batch to the device and stack
    chosen/rejected examples along dim 0 before the forward pass.

    NOTE(review): this snippet appears truncated — the forward pass, loss,
    and optimizer steps are not visible here. Relies on `args`,
    `train_loader`, and `torch` from the enclosing scope (not shown).

    Args:
        epoch: current epoch index (unused in the visible lines).
        wandb: experiment logger handle (unused in the visible lines).
    """
    for step, batch in enumerate(train_loader):
        # Each batch carries paired preference data: chosen vs. rejected.
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        # Downstream, ref_probs and probs should have shape
        # (2*batch_size, seq_len), because x_chosen and x_rejected are
        # concatenated along dim 0 here — chosen rows first, then rejected.
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)
Metadata
Metadata
Assignees
Labels
No labels