
Question about how dpo_loss extracts the chosen and rejected logits #451

@googlehjx

Description

def dpo_loss(ref_probs, probs, mask, beta):
    # average the per-token log-probabilities over the valid (unmasked) tokens
    seq_lengths = mask.sum(dim=1, keepdim=True)  # (batch_size, 1)
    ref_probs = (ref_probs * mask).sum(dim=1) / seq_lengths.squeeze()
    probs = (probs * mask).sum(dim=1) / seq_lengths.squeeze()
    # split into the chosen and rejected halves
    batch_size = ref_probs.shape[0]
    chosen_ref_probs = ref_probs[:batch_size // 2]   # the incoming ref_probs and probs should have shape (2*batch_size, seq_len)
    reject_ref_probs = ref_probs[batch_size // 2:]
    chosen_probs = probs[:batch_size // 2]
    reject_probs = probs[batch_size // 2:]
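
For reference, the standard DPO objective would combine these four averaged log-probabilities roughly as follows. This is a minimal sketch of the usual formulation, not necessarily the exact continuation of this repo's function; it assumes F is torch.nn.functional:

    pi_logratios = chosen_probs - reject_probs            # policy log-ratio
    ref_logratios = chosen_ref_probs - reject_ref_probs   # reference log-ratio
    logits = pi_logratios - ref_logratios
    # standard DPO loss: -log sigmoid(beta * (policy ratio - reference ratio))
    loss = -F.logsigmoid(beta * logits)
    return loss.mean()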

Here the code extracts the slices [0 : bs // 2] and [bs // 2 : end]. But during input processing, the chosen and rejected samples of one batch are concatenated along dim=0, so dimension 0 of the model input should be 2*batch_size:

def train_epoch(epoch, wandb):
    for step, batch in enumerate(train_loader):
        x_chosen = batch['x_chosen'].to(args.device)
        x_rejected = batch['x_rejected'].to(args.device)
        y_chosen = batch['y_chosen'].to(args.device)
        y_rejected = batch['y_rejected'].to(args.device)
        mask_chosen = batch['mask_chosen'].to(args.device)
        mask_rejected = batch['mask_rejected'].to(args.device)
        # ref_probs and probs should have shape (2*batch_size, seq_len), because
        # x_chosen and x_rejected are concatenated along dim 0 before the forward pass
        x = torch.cat([x_chosen, x_rejected], dim=0)
        y = torch.cat([y_chosen, y_rejected], dim=0)
        mask = torch.cat([mask_chosen, mask_rejected], dim=0)
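
For reference, the index arithmetic in question can be checked in isolation. Note that inside dpo_loss, batch_size is read from ref_probs.shape[0], i.e. from the already-concatenated dimension. A minimal, self-contained sketch with made-up sizes (B and T are hypothetical):

    import torch

    B, T = 4, 8                               # hypothetical per-side batch size and seq length
    chosen = torch.zeros(B, T)                # stands in for the chosen half
    rejected = torch.ones(B, T)               # stands in for the rejected half

    x = torch.cat([chosen, rejected], dim=0)  # shape (2*B, T), as built in train_epoch
    bs = x.shape[0]                           # 2*B -- what dpo_loss calls batch_size
    assert (x[:bs // 2] == chosen).all()      # first half recovers the chosen samples
    assert (x[bs // 2:] == rejected).all()    # second half recovers the rejected samples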
