moe sp + ep bug fix

colossalchat
haze188 2024-07-18 10:08:06 +00:00 committed by Hongxin Liu
parent 877d94bb8c
commit 2cddeac717
2 changed files with 16 additions and 12 deletions

@@ -221,7 +221,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )
         self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
             f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )
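
The only change in this file is the missing "=" in the last field of the log line: {self.sp_size} prints just the value, while {self.sp_size=} uses Python 3.8+ self-documenting f-string expressions to print the name together with the value, matching the ep_size, moe_dp_size, and moe_tp_size fields. A standalone sketch of the difference (not ColossalAI code):

    sp_size = 2
    print(f"{sp_size}")   # -> 2
    print(f"{sp_size=}")  # -> sp_size=2 (name and value together, like the other fields)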

@@ -37,7 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
     )
-    print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)

     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
@@ -65,9 +64,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
         rank = dist.get_rank()
-        # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-        for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+        name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
+        for n, p in shard_mixtral_model.named_parameters():
+            zero_grad = sharded_optimizer.get_param_grad(p)
+            if name_to_p[n].grad is None:
+                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+                continue
+            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)

     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
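
With ZeRO stage 1/2 enabled, the sharded model's gradients live in the zero optimizer's working buckets rather than on the parameters' .grad attributes, so the test now fetches them via sharded_optimizer.get_param_grad(p) and pairs parameters with the unsharded Mixtral model by name instead of zip-ping two named_parameters() iterators, which can silently match the wrong tensors when parameter order or count differs across EP ranks. A minimal sketch of this name-keyed comparison pattern (the helper name and signature are illustrative; get_param_grad is the optimizer method used in the diff above):

    import torch
    from torch.testing import assert_close

    def check_zero_grads_by_name(full_model, sharded_model, sharded_optimizer, atol=5e-3, rtol=5e-3):
        # Key the unsharded reference parameters by name so the pairing cannot drift.
        name_to_p = dict(full_model.named_parameters())
        for n, p in sharded_model.named_parameters():
            # Gradient accumulated for this working parameter by the ZeRO optimizer.
            zero_grad = sharded_optimizer.get_param_grad(p)
            ref = name_to_p[n]
            if ref.grad is None:
                # Reference model produced no grad here (e.g. an expert that saw no tokens);
                # fill with zeros so the later weight comparison still lines up.
                ref.grad = torch.zeros_like(ref.data)
                continue
            assert_close(ref.grad, zero_grad, atol=atol, rtol=rtol, check_dtype=False)
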
@@ -100,16 +103,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         grads_to_check.update(row_layer_grads)

     # check grads
+    # print(grads_to_check)
     check_all_grad_tensors(grads_to_check)
-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()

-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
@@ -170,10 +174,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "sp_size": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "zero_stage": 0,
+            "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp16",
             "initial_scale": 1,