Fixed parameter initialization in FFNExpert (#251)

pull/394/head
HELSON 2022-02-27 14:01:25 +08:00 committed by Frank Lee
parent e13293bb4c
commit 36b8477228
2 changed files with 14 additions and 18 deletions
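This commit wraps the truncated-normal initialization of the expert weights and biases in FFNExperts in a seed(ParallelMode.MOE_MODEL) context, so the expert parameters are drawn from the RNG state tied to the MoE model-parallel mode rather than the default stream. The remaining hunks only collapse multi-line calls in the MoE autograd functions (AllToAll, MoeDispatch, MoeCombine) onto single lines.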


@@ -21,17 +21,15 @@ class AllToAll(torch.autograd.Function):
     """
+
     @staticmethod
-    def forward(ctx: Any,
-                inputs: Tensor,
-                parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
         if ctx is not None:
             ctx.parallel_mode = parallel_mode
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()
         if gpc.get_world_size(parallel_mode) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs,
-                               group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
         return output
 
     @staticmethod
@@ -58,8 +56,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossal_moe_cuda.dispatch_backward(
-            ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
@@ -76,9 +73,7 @@ class MoeCombine(torch.autograd.Function):
         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
-                                                    cb_input, logits,
-                                                    mask, dest_idx)
+        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens
         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -97,9 +92,8 @@ class MoeCombine(torch.autograd.Function):
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossal_moe_cuda.combine_backward(
-            ctx.s, ctx.e, ctx.c, ctx.h,
-            cb_grad, cb_input, logits, mask, dest_idx)
+        d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
+                                                                mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert
         return d_expert, d_logits, None, None, None
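For readers unfamiliar with the collective behind AllToAll.forward above, here is a minimal sketch of the same exchange written against plain torch.distributed (the function name and the bare group argument are illustrative, not Colossal-AI API). all_to_all_single splits the input evenly along dim 0, sends one chunk to every rank, and fills the output with the chunks received in return, which is why the input must be contiguous and why a single-rank group can return the input unchanged.

# Illustrative sketch only; assumes torch.distributed has been initialized
# and that `group` spans the ranks that exchange expert tokens.
import torch
import torch.distributed as dist

def all_to_all(inputs: torch.Tensor, group=None) -> torch.Tensor:
    if dist.get_world_size(group) == 1:
        return inputs                          # nothing to exchange on a single rank
    inputs = inputs.contiguous()               # all_to_all_single needs contiguous buffers
    output = torch.empty_like(inputs)          # equal-sized send/recv chunks per rank
    dist.all_to_all_single(output, inputs, group=group)
    return output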


@@ -62,10 +62,12 @@ class FFNExperts(nn.Module):
         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)
-        nn.init.trunc_normal_(self.w1, std=s1)
-        nn.init.trunc_normal_(self.b1, std=s1)
-        nn.init.trunc_normal_(self.w2, std=s2)
-        nn.init.trunc_normal_(self.b2, std=s2)
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+            nn.init.trunc_normal_(self.b2, std=s2)
+
         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)
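The key change above is the seed(ParallelMode.MOE_MODEL) context around the initializers. As a rough, self-contained sketch of the idea (not Colossal-AI's actual implementation; the per-mode state store and the hash-based first seed below are hypothetical), such a context manager can swap in a dedicated RNG stream for the duration of the block and restore the default stream afterwards:

import contextlib
import torch

_MODE_RNG_STATES = {}    # hypothetical per-parallel-mode RNG state store

@contextlib.contextmanager
def seed(mode):
    """Temporarily switch torch's CPU RNG to the stream tracked for `mode`."""
    backup = torch.get_rng_state()                        # remember the default stream
    if mode in _MODE_RNG_STATES:
        torch.set_rng_state(_MODE_RNG_STATES[mode])       # resume this mode's stream
    else:
        torch.manual_seed(hash(mode) & 0x7FFFFFFF)        # first use: distinct starting seed
    try:
        yield
    finally:
        _MODE_RNG_STATES[mode] = torch.get_rng_state()    # save where this mode left off
        torch.set_rng_state(backup)                       # restore the default stream

Under this pattern the trunc_normal_ calls inside the with block consume only the mode-specific stream, while anything initialized outside the block keeps using the default stream, which is presumably why the original unscoped initialization of the expert parameters was considered a bug.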