class MoEFeedForward:
    ....
        if self.training:
            x = x.repeat_interleave(self.config.num_experts_per_tok, dim=0)
            y = torch.empty_like(x, dtype=torch.float16)
            for i, expert in enumerate(self.experts):
                y[flat_topk_idx == i] = expert(x[flat_topk_idx == i]).to(y.dtype)
            y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)  # Should this be dim=1 or dim=2?
            y = y.view(*orig_shape)
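A minimal sketch for checking the shapes, assuming `x` has already been flattened to `(N, H)` and that `topk_idx` and `topk_weight` both have shape `(N, k)`, with `k = num_experts_per_tok`. The gate is elided in the snippet, so the softmax + `topk` gate below is a hypothetical stand-in, the toy sizes are arbitrary, and the buffer is kept in float32 rather than float16 so it can be compared against a naive per-token loop. Under these assumptions `y.view(*topk_weight.shape, -1)` has shape `(N, k, H)`, so `dim=1` is the axis of the k selected experts, while `dim=2` would sum over the hidden dimension instead.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Toy sizes (assumptions): N flattened tokens, k experts per token,
# H hidden size, E total experts.
N, k, H, E = 5, 2, 8, 4

experts = nn.ModuleList([nn.Linear(H, H) for _ in range(E)])
x = torch.randn(N, H)                      # already flattened to (N, H)
gate_logits = torch.randn(N, E)            # hypothetical stand-in for the gate
topk_weight, topk_idx = torch.softmax(gate_logits, dim=-1).topk(k, dim=-1)  # (N, k) each
flat_topk_idx = topk_idx.view(-1)          # (N*k,), row t*k + j <-> token t, slot j

# Training branch from the snippet (float32 buffer instead of float16).
xr = x.repeat_interleave(k, dim=0)         # (N*k, H): row t*k + j is a copy of x[t]
y = torch.empty_like(xr)
for i, expert in enumerate(experts):
    mask = flat_topk_idx == i
    y[mask] = expert(xr[mask])             # rows routed to expert i
y3 = y.view(*topk_weight.shape, -1)        # (N, k, H)
weighted = y3 * topk_weight.unsqueeze(-1)  # (N, k, 1) broadcasts over H
out = weighted.sum(dim=1)                  # sum over the k experts -> (N, H)
wrong = weighted.sum(dim=2)                # sums over H instead -> (N, k)

# Naive reference: each token's weighted sum over its own k experts.
ref = torch.stack([
    sum(topk_weight[t, j] * experts[int(topk_idx[t, j])](x[t]) for j in range(k))
    for t in range(N)
])

print(out.shape)    # torch.Size([5, 8])
print(wrong.shape)  # torch.Size([5, 2])
```

With `dim=1` the result has the `(N, H)` shape that `y.view(*orig_shape)` expects and should match the per-token reference up to floating-point rounding; with `dim=2` the hidden dimension collapses and the result no longer has that shape.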
Thanks!