mirror of https://github.com/hpcaitech/ColossalAI
moe sp + ep bug fix
parent 877d94bb8c
commit 2cddeac717
@@ -221,7 +221,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
         )

         self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
             f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )

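Note on this hunk: the only change is appending "=" inside the last placeholder. Python's f-string "=" specifier (Python 3.8+) prints the expression together with its value, so the log now reads "self.sp_size=2" instead of a bare, unlabeled number. A minimal sketch of the difference, using made-up values rather than the plugin's real attributes:

    ep_size, sp_size = 2, 4
    print(f"{ep_size=} {sp_size}")   # -> ep_size=2 4   (value printed without its label)
    print(f"{ep_size=} {sp_size=}")  # -> ep_size=2 sp_size=4
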
@@ -37,7 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
     )
-    print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group

@@ -65,9 +64,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
-        rank = dist.get_rank()
-        # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-        for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+        name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
+        for n, p in shard_mixtral_model.named_parameters():
+            zero_grad = sharded_optimizer.get_param_grad(p)
+            if name_to_p[n].grad is None:
+                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+                continue
+            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+

     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}

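Note on this hunk: instead of zipping the two models' parameters positionally and reading .grad (which a ZeRO-sharded optimizer does not keep on the parameter), the test now builds a name-to-parameter map over the unsharded Mixtral model, asks the sharded optimizer for each parameter's gradient, and treats a missing gradient on the reference model as zero. A minimal sketch of the same name-keyed pattern in plain PyTorch; the toy Linear modules and the local get_param_grad helper are stand-ins for illustration only, not ColossalAI APIs (the real test uses sharded_optimizer.get_param_grad as shown above):

    import torch
    import torch.nn as nn
    from torch.testing import assert_close

    # Toy stand-ins: `ref` plays the unsharded model, `shard` the sharded one.
    ref, shard = nn.Linear(4, 4), nn.Linear(4, 4)
    shard.load_state_dict(ref.state_dict())

    ref(torch.ones(2, 4)).sum().backward()
    shard(torch.ones(2, 4)).sum().backward()

    def get_param_grad(p):
        # Stand-in for sharded_optimizer.get_param_grad(p); here grads still live on p.grad.
        return p.grad

    name_to_p = {n: p for n, p in ref.named_parameters()}
    for n, p in shard.named_parameters():
        grad = get_param_grad(p)
        if name_to_p[n].grad is None:
            # Mirrors the test: a parameter the reference model never touched counts as zero grad.
            name_to_p[n].grad = torch.zeros_like(name_to_p[n])
            continue
        assert_close(name_to_p[n].grad, grad, atol=5e-3, rtol=5e-3, check_dtype=False)
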
@@ -100,16 +103,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         grads_to_check.update(row_layer_grads)

     # check grads
-    # print(grads_to_check)
     check_all_grad_tensors(grads_to_check)
-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)

     # check weights
     if stage_manager is None or stage_manager.is_first_stage():

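Note on this hunk: the weight checks before and after optimizer.step() get the same treatment, replacing the positional zip over two named_parameters() iterators with a lookup into name_to_p. The hazard with zip is that it pairs tensors by position and truncates to the shorter iterator, so a sharded model that registers or wraps parameters in a different order is silently compared against the wrong reference tensors. A small illustration with hypothetical modules:

    import torch.nn as nn

    class A(nn.Module):
        def __init__(self):
            super().__init__()
            self.w1 = nn.Linear(2, 2)
            self.w2 = nn.Linear(2, 2)

    class B(nn.Module):
        def __init__(self):
            super().__init__()
            self.w2 = nn.Linear(2, 2)  # same parameters, registered in the opposite order
            self.w1 = nn.Linear(2, 2)

    pairs = [(n1, n2) for (n1, _), (n2, _) in zip(A().named_parameters(), B().named_parameters())]
    print(pairs)
    # [('w1.weight', 'w2.weight'), ('w1.bias', 'w2.bias'), ('w2.weight', 'w1.weight'), ('w2.bias', 'w1.bias')]
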
@@ -170,10 +174,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "sp_size": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "zero_stage": 0,
+            "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp16",
             "initial_scale": 1,

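Note on this hunk: the all_to_all sequence-parallel test case now also exercises expert parallelism (ep_size 1 -> 2) and ZeRO-1 (zero_stage 0 -> 1), the sp + ep combination named in the commit title and the reason the gradient check above goes through the sharded optimizer. Read together with the context lines, the entry becomes the dict below; only the keys visible in this hunk are shown, and any neighbouring entries in the test-config list are omitted:

    {
        "tp_size": 1,
        "pp_size": 1,
        "sp_size": 2,
        "ep_size": 2,
        "enable_sequence_parallelism": True,
        "sequence_parallelism_mode": "all_to_all",
        "zero_stage": 1,
        "overlap_communication": False,
        "precision": "fp16",
        "initial_scale": 1,
    }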