
[moe] fix MoE bugs (#1628)

* remove forced FP32 modules

* correct no_shard-contexts' positions
HELSON, 2 years ago, committed by GitHub
commit f7f2248771
Changed files:

  1. colossalai/nn/layer/moe/experts.py (2 lines changed)
  2. colossalai/nn/layer/moe/layers.py (31 lines changed)
  3. colossalai/zero/init_ctx/init_context.py (3 lines changed)
  4. tests/test_moe/test_kernel.py (7 lines changed)
  5. tests/test_moe/test_moe_zero_init.py (8 lines changed)
  6. tests/test_moe/test_moe_zero_optim.py (6 lines changed)
  7. tests/test_zero/common.py (2 lines changed)

colossalai/nn/layer/moe/experts.py (2 lines changed)

@@ -24,6 +24,7 @@ class MoeExperts(nn.Module):
         self.num_local_experts, self.dist_info = MOE_CONTEXT.get_info(num_experts)


+@no_shard_zero_decrator(is_replicated=False)
 class Experts(MoeExperts):
     """A wrapper class to create experts. It will create E experts across the
     moe model parallel group, where E is the number of experts. Every expert
@@ -35,7 +36,6 @@ class Experts(MoeExperts):
         expert_args: Args used to initialize experts, the args could be found in corresponding expert class
     """

-    @no_shard_zero_decrator(is_replicated=False)
     def __init__(self, expert_cls: Type[nn.Module], num_experts: int, **expert_args):
         super().__init__("all_to_all", num_experts)

colossalai/nn/layer/moe/layers.py (31 lines changed)

@@ -228,6 +228,7 @@ class FP32LinearGate(nn.Module):
         return F.linear(x, self.weight)


+@no_shard_zero_decrator(is_replicated=True)
 class MoeLayer(nn.Module):
     """A MoE layer, that puts its input tensor to its gate and uses the output logits
     to router all tokens, is mainly used to exchange all tokens for every expert across
@@ -241,12 +242,11 @@ class MoeLayer(nn.Module):
         experts (:class:`torch.nn.Module`): Instance of experts generated by Expert.
     """

-    @no_shard_zero_decrator(is_replicated=True)
     def __init__(self, dim_model: int, num_experts: int, router: nn.Module, experts: MoeExperts):
         super().__init__()
         self.d_model = dim_model
         self.num_experts = num_experts
-        self.gate = FP32LinearGate(dim_model, num_experts)
+        self.gate_weight = torch.nn.Parameter(torch.empty(num_experts, dim_model))
         self.router = router
         self.experts = experts
         self.use_kernel = True if COL_MOE_KERNEL_FLAG and MOE_CONTEXT.use_kernel_optim else False
@@ -254,16 +254,14 @@ class MoeLayer(nn.Module):
         self.ep_size = experts.dist_info.ep_size
         self.num_local_experts = experts.num_local_experts
+        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / dim_model))

     def a2a_process(self, dispatch_data: torch.Tensor):
         expert_input = AllToAll.apply(dispatch_data, self.ep_group)
         input_shape = expert_input.shape
         expert_input = expert_input.reshape(self.ep_size, self.num_local_experts, -1, self.d_model)
         expert_output = self.experts(expert_input)
         expert_output = expert_output.reshape(input_shape)
         expert_output = AllToAll.apply(expert_output, self.ep_group)
         return expert_output
@@ -274,16 +272,22 @@ class MoeLayer(nn.Module):
         return expert_out

     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        # reshape the input tokens
         tokens = inputs.reshape(-1, self.d_model)
-        fp32_input = tokens.to(torch.float32) if inputs.dtype != torch.float32 else tokens
-        gate_output = self.gate(fp32_input)
-        router_res = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)
+
+        # the data type of the inputs in the gating should be fp32
+        fp32_input = tokens.to(torch.float)
+        fp32_weight = self.gate_weight.to(torch.float)
+        gate_output = F.linear(fp32_input, fp32_weight)
+
+        # the result from the router
+        route_result_list = self.router(inputs=gate_output, use_kernel=self.use_kernel, ep_group=self.ep_group)

         if self.use_kernel:
-            dispatch_data = MoeDispatch.apply(tokens, *router_res[1:])
+            dispatch_data = MoeDispatch.apply(tokens, *route_result_list[1:])
             dispatch_data = dispatch_data.reshape(self.num_experts, -1, self.d_model)
         else:
-            sec_mask_f = router_res[1].type_as(inputs)
+            sec_mask_f = route_result_list[1].type_as(inputs)
             dispatch_data = torch.matmul(sec_mask_f.permute(1, 2, 0), tokens)

         # dispatch_data [e, c, h]
@@ -295,12 +299,11 @@ class MoeLayer(nn.Module):
             raise NotImplementedError("This kind of communication has not been implemented yet.\n Please use Experts "
                                       "build function.")
         # expert_output [e, c, h]
         if self.use_kernel:
             expert_output = expert_output.reshape(-1, self.d_model)
-            ans = MoeCombine.apply(expert_output, *router_res)
+            ans = MoeCombine.apply(expert_output, *route_result_list)
         else:
-            combine_weights = router_res[0].type_as(inputs)
+            combine_weights = route_result_list[0].type_as(inputs)
             combine_weights = combine_weights.view(combine_weights.shape[0], -1)
             expert_output = expert_output.view(-1, expert_output.shape[-1])
             ans = torch.matmul(combine_weights, expert_output)
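The heart of "remove forced FP32 modules": the dedicated FP32LinearGate module, whose weight had to be special-cased as fp32 by the ZeRO init test and the apex AMP test further below, is replaced by a plain `gate_weight` parameter that follows the model's dtype, and only the gating computation is upcast to fp32 inside `forward`. Because the upcast happens at compute time, the parameter can be halved and sharded like any other weight, which is why the fp32 special cases in the tests could be dropped. A minimal standalone sketch of the same pattern in plain PyTorch (the class name `Fp32GatingDemo` is illustrative, not part of the diff):

import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class Fp32GatingDemo(nn.Module):
    """Toy gate: the weight is stored in the model's dtype (e.g. fp16 after .half()),
    but routing logits are always computed in fp32."""

    def __init__(self, d_model: int, num_experts: int):
        super().__init__()
        # a plain parameter, cast and sharded together with the rest of the model
        self.gate_weight = nn.Parameter(torch.empty(num_experts, d_model))
        nn.init.trunc_normal_(self.gate_weight, std=math.sqrt(0.1 / d_model))

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        # upcast both operands only for the gating matmul
        fp32_input = tokens.to(torch.float)
        fp32_weight = self.gate_weight.to(torch.float)
        return F.linear(fp32_input, fp32_weight)    # [num_tokens, num_experts] logits in fp32


if __name__ == "__main__":
    gate = Fp32GatingDemo(d_model=8, num_experts=4).half()
    logits = gate(torch.randn(16, 8).half())
    assert gate.gate_weight.dtype == torch.half
    assert logits.dtype == torch.float32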

colossalai/zero/init_ctx/init_context.py (3 lines changed)

@@ -258,7 +258,8 @@ def no_shard_zero_decrator(is_replicated: bool = True):
         def _no_shard(*args, **kwargs):
             with no_shard_zero_context(is_replicated):
-                init_func(*args, **kwargs)
+                ret = init_func(*args, **kwargs)
+            return ret

         return _no_shard
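Why the wrapper must now forward the return value: with the decorator moved from `__init__` to the class itself (see `experts.py` and `layers.py` above), `init_func` is the class, so `init_func(*args, **kwargs)` builds and returns the module instance; without `return ret`, `Experts(...)` and `MoeLayer(...)` would evaluate to `None`. Decorating the class also wraps the entire construction, including whatever ZeRO's init context runs after `__init__`, inside the no-shard context, which is presumably the "correct no_shard-contexts' positions" part of this fix. A self-contained sketch of the pattern, with a dummy context manager standing in for `no_shard_zero_context` (the names `no_shard_decorator`, `dummy_no_shard_context` and `ToyExperts` are illustrative):

from contextlib import contextmanager


@contextmanager
def dummy_no_shard_context(is_replicated: bool):
    # stand-in for no_shard_zero_context: only marks where construction happens
    print(f"enter no-shard context (is_replicated={is_replicated})")
    yield
    print("exit no-shard context")


def no_shard_decorator(is_replicated: bool = True):
    # roughly the same shape as no_shard_zero_decrator after the fix

    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            with dummy_no_shard_context(is_replicated):
                ret = init_func(*args, **kwargs)
            return ret    # essential when init_func is a class: ret is the new instance

        return _no_shard

    return _wrapper


@no_shard_decorator(is_replicated=False)
class ToyExperts:
    def __init__(self, num_experts: int):
        self.num_experts = num_experts


if __name__ == "__main__":
    experts = ToyExperts(4)
    # without `return ret`, `experts` would be None instead of a ToyExperts instance
    assert experts is not None and experts.num_experts == 4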

tests/test_moe/test_kernel.py (7 lines changed)

@@ -38,6 +38,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     expert_factor = dict(in_features=hidden_size, out_features=hidden_size, device=get_current_device())
     expert = Experts(expert_module, NUM_EXPERTS, **expert_factor)
     layer = MoeLayer(hidden_size, NUM_EXPERTS, router(capacity_factor_train=1.0), expert)
+    layer = layer.to(get_current_device())
     if data_type == torch.float16:
         layer = layer.half()
@@ -50,11 +51,11 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     # save all results
     o_tk_grad = tokens.grad.data.clone()
-    o_gt_grad = layer.gate.weight.grad.data.clone()
+    o_gt_grad = layer.gate_weight.grad.data.clone()

     # reset all gradients
     tokens.grad.zero_()
-    layer.gate.weight.grad.zero_()
+    layer.gate_weight.grad.zero_()

     layer.use_kernel = True
     new_out = layer(tokens)    # get ouputs through colossal kernel
@@ -67,7 +68,7 @@ def run_routing(rank, world_size, port, rs=2, hidden_size=128, data_type=torch.f
     new_out.backward(grad)    # get new type gradient
     n_tk_grad = tokens.grad.data.clone()
-    n_gt_grad = layer.gate.weight.grad.data.clone()
+    n_gt_grad = layer.gate_weight.grad.data.clone()

     if data_type == torch.float32:
         check_equal(o_tk_grad, n_tk_grad)

tests/test_moe/test_moe_zero_init.py (8 lines changed)

@@ -58,15 +58,9 @@ def run_moe_zero_init(init_device_type, shard_strategy_class):
     for name, param in model.named_parameters():
         assert hasattr(param, 'colo_attr')

-        # the weights in the gate should be fp32
-        if 'gate' in name:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.float32
-        else:
-            assert param.colo_attr.sharded_data_tensor.dtype == torch.half
-
         # the parameters in moe experts and its gate should not be sharded
         if ('experts' in name) or ('gate' in name) or ('residual_combine' in name):
-            assert not param.colo_attr.sharded_data_tensor.is_sharded
+            assert not param.colo_attr.sharded_data_tensor.is_sharded, "`{}` parameter has problem".format(name)
         else:
             assert param.colo_attr.sharded_data_tensor.is_sharded

tests/test_moe/test_moe_zero_optim.py (6 lines changed)

@@ -94,12 +94,6 @@ def _run_test_sharded_optim_v2(cpu_offload,
     apex_model, apex_optimizer = convert_to_apex_amp(model, optim, amp_config)
     apex_grad_handler = MoeGradientHandler(model)

-    # Since MOE is not compatible with apex_amp now, we need to convert gate weight to fp32
-    for (n, p), zp in zip(apex_model.named_parameters(), zero_model.parameters()):
-        if 'gate' in n:
-            p.data = p.float()
-            p.data.copy_(zp.colo_attr.data_payload)
-
     for i, (data, label) in enumerate(train_dataloader):
         if i > 5:
             break

tests/test_zero/common.py (2 lines changed)

@@ -135,5 +135,5 @@ def check_sharded_model_params(model, zero_model, loose=False, reuse_fp16_shard=
         else:
             zero_p = zero_p.colo_attr.data_payload.to(p.device)

-        assert p.dtype == zero_p.dtype
+        assert p.dtype == zero_p.dtype, "Parameter `{}`:\n{} vs {}".format(name, p.dtype, zero_p.dtype)
         assert allclose(p, zero_p, loose=loose), f'{p} vs {zero_p}'
