mirror of https://github.com/hpcaitech/ColossalAI
Fixed parameter initialization in FFNExpert (#251)
parent e13293bb4c
commit 36b8477228
@@ -21,17 +21,15 @@ class AllToAll(torch.autograd.Function):
     """

     @staticmethod
-    def forward(ctx: Any,
-                inputs: Tensor,
-                parallel_mode: ParallelMode) -> Tensor:
+    def forward(ctx: Any, inputs: Tensor, parallel_mode: ParallelMode) -> Tensor:
         if ctx is not None:
             ctx.parallel_mode = parallel_mode
         if not inputs.is_contiguous():
             inputs = inputs.contiguous()

         if gpc.get_world_size(parallel_mode) == 1:
             return inputs
         output = torch.empty_like(inputs)
-        dist.all_to_all_single(output, inputs,
-                               group=gpc.get_group(parallel_mode))
+        dist.all_to_all_single(output, inputs, group=gpc.get_group(parallel_mode))
         return output

     @staticmethod
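For context, a minimal sketch of the `all_to_all_single` collective that `AllToAll.forward` wraps: each rank splits its input into world-size chunks and sends chunk i to rank i. The two-rank launch, `gloo` backend, and tensor values below are illustrative assumptions, not part of the commit.

# Run with: torchrun --nproc_per_node=2 this_script.py
import torch
import torch.distributed as dist

dist.init_process_group(backend="gloo")  # gloo for a CPU illustration;
# some PyTorch builds only support all_to_all on the NCCL/MPI backends.
rank = dist.get_rank()
world = dist.get_world_size()

# Each rank holds `world` values; value i is routed to rank i.
inputs = torch.arange(world, dtype=torch.float32) + rank * world
output = torch.empty_like(inputs)
dist.all_to_all_single(output, inputs)
print(rank, output)  # with 2 ranks: rank 0 gets [0., 2.], rank 1 gets [1., 3.]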
@@ -58,8 +56,7 @@ class MoeDispatch(torch.autograd.Function):
     @staticmethod
     def backward(ctx, output_grad):
         mask, dest_idx = ctx.saved_tensors
-        d_tokens = colossal_moe_cuda.dispatch_backward(
-            ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
+        d_tokens = colossal_moe_cuda.dispatch_backward(ctx.s, ctx.ec, ctx.h, output_grad, mask, dest_idx)
         return d_tokens, None, None, None
@@ -76,9 +73,7 @@ class MoeCombine(torch.autograd.Function):

         fp16_flag = (expert_tokens.dtype == torch.float16)
         cb_input = expert_tokens.to(torch.float32) if fp16_flag else expert_tokens
-        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h,
-                                                    cb_input, logits,
-                                                    mask, dest_idx)
+        ctokens = colossal_moe_cuda.combine_forward(s, e, c, h, cb_input, logits, mask, dest_idx)
         output = ctokens.to(torch.float16) if fp16_flag else ctokens

         ctx.save_for_backward(expert_tokens, logits, mask, dest_idx)
@@ -97,9 +92,8 @@ class MoeCombine(torch.autograd.Function):
         cb_grad = tokens_grad.to(torch.float32) if tokens_grad.dtype is torch.float16 \
             else tokens_grad
         cb_input = expert_tokens.to(torch.float32) if ctx.fp16_flag else expert_tokens
-        d_expert, d_logits = colossal_moe_cuda.combine_backward(
-            ctx.s, ctx.e, ctx.c, ctx.h,
-            cb_grad, cb_input, logits, mask, dest_idx)
+        d_expert, d_logits = colossal_moe_cuda.combine_backward(ctx.s, ctx.e, ctx.c, ctx.h, cb_grad, cb_input, logits,
+                                                                mask, dest_idx)
         d_expert = d_expert.to(torch.float16) if ctx.fp16_flag else d_expert

         return d_expert, d_logits, None, None, None
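The `MoeCombine` hunks above keep the existing fp16 handling: tensors are upcast to fp32 before the CUDA kernel and downcast on the way out, in both forward and backward. A minimal sketch of that pattern, with a toy elementwise op standing in for `colossal_moe_cuda` (the `CastedSquare` name is hypothetical):

import torch

class CastedSquare(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x):
        fp16_flag = (x.dtype == torch.float16)
        x32 = x.to(torch.float32) if fp16_flag else x  # "kernel" wants fp32
        out = x32 * x32
        ctx.save_for_backward(x32)
        ctx.fp16_flag = fp16_flag
        # downcast so callers see the dtype they passed in
        return out.to(torch.float16) if fp16_flag else out

    @staticmethod
    def backward(ctx, grad_out):
        x32, = ctx.saved_tensors
        g32 = grad_out.to(torch.float32) if ctx.fp16_flag else grad_out
        grad = 2 * x32 * g32  # d/dx x^2, computed in fp32
        return grad.to(torch.float16) if ctx.fp16_flag else grad

x = torch.randn(4, dtype=torch.float16, requires_grad=True)
CastedSquare.apply(x).sum().backward()  # grad comes back in fp16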
@@ -62,10 +62,12 @@ class FFNExperts(nn.Module):

         s1 = math.sqrt(0.1 / d_model)
         s2 = math.sqrt(0.1 / d_ff)
-        nn.init.trunc_normal_(self.w1, std=s1)
-        nn.init.trunc_normal_(self.b1, std=s1)
-        nn.init.trunc_normal_(self.w2, std=s2)
-        nn.init.trunc_normal_(self.b2, std=s2)
+
+        with seed(ParallelMode.MOE_MODEL):
+            nn.init.trunc_normal_(self.w1, std=s1)
+            nn.init.trunc_normal_(self.b1, std=s1)
+            nn.init.trunc_normal_(self.w2, std=s2)
+            nn.init.trunc_normal_(self.b2, std=s2)

         self.act = nn.GELU() if activation is None else activation
         self.drop = nn.Dropout(p=drop_rate)
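The substantive fix is this last hunk: wrapping initialization in `seed(ParallelMode.MOE_MODEL)`, presumably so that ranks in the MoE model-parallel group draw from distinct RNG streams and experts placed on different ranks no longer start with identical weights. A rough sketch of the idea using a hand-rolled context manager (`moe_seed` is a stand-in for illustration, not ColossalAI's actual seed manager):

import torch
from torch import nn
from contextlib import contextmanager

@contextmanager
def moe_seed(rank: int, base_seed: int = 1024):
    # fork the global RNG, reseed it per rank, restore it afterwards
    state = torch.get_rng_state()
    torch.manual_seed(base_seed + rank)
    try:
        yield
    finally:
        torch.set_rng_state(state)

w_rank0 = torch.empty(4, 4)
w_rank1 = torch.empty(4, 4)
with moe_seed(rank=0):
    nn.init.trunc_normal_(w_rank0, std=0.02)
with moe_seed(rank=1):
    nn.init.trunc_normal_(w_rank1, std=0.02)
assert not torch.equal(w_rank0, w_rank1)  # experts now start differently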