diff --git a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
index fc3340981..d4226b108 100644
--- a/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/moe_hybrid_parallel_plugin.py
@@ -221,7 +221,7 @@ class MoeHybridParallelPlugin(HybridParallelPlugin):
             )
 
         self.logger.info(
-            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size}\n"
+            f"{type(self).__name__}: {self.ep_size=} {self.moe_dp_size=} {self.moe_tp_size=} {self.sp_size=}\n"
             f"rank {dist.get_rank()} moe_dp_group {dist.get_process_group_ranks(self.moe_dp_group)} ep_group {dist.get_process_group_ranks(self.ep_group)} moe_tp_group {dist.get_process_group_ranks(self.moe_tp_group)} sp_group {dist.get_process_group_ranks(self.sp_group)}",
             ranks=[0],
         )
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 232e16f3b..4e9c594d2 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -37,7 +37,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     org_loss, org_output, sharded_loss, sharded_output = run_forward_backward_with_hybrid_plugin(
         org_model, sharded_model, sharded_optimizer, data_gen_fn, output_transform_fn, criterion, booster
     )
-    print(org_output.last_hidden_state.shape, sharded_output.last_hidden_state.shape)
 
     stage_manager = booster.plugin.stage_manager
     tp_group = booster.plugin.tp_group
@@ -65,9 +64,13 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
         and booster.plugin.shard_config.sequence_parallelism_mode == "all_to_all"
     ):
         rank = dist.get_rank()
-        # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
-        for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
+        name_to_p = {n: p for n, p in mixtral_model.named_parameters()}
+        for n, p in shard_mixtral_model.named_parameters():
+            zero_grad = sharded_optimizer.get_param_grad(p)
+            if name_to_p[n].grad is None:
+                name_to_p[n].grad = torch.zeros_like(name_to_p[n].data)
+                continue
+            assert_close(name_to_p[n].grad, zero_grad, atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
@@ -100,16 +103,17 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             grads_to_check.update(row_layer_grads)
 
     # check grads
-    # print(grads_to_check)
     check_all_grad_tensors(grads_to_check)
-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
-    for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
+    for n, p in shard_mixtral_model.named_parameters():
+        assert_close(name_to_p[n], p, atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
@@ -170,10 +174,10 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "tp_size": 1,
             "pp_size": 1,
             "sp_size": 2,
-            "ep_size": 1,
+            "ep_size": 2,
             "enable_sequence_parallelism": True,
             "sequence_parallelism_mode": "all_to_all",
-            "zero_stage": 0,
+            "zero_stage": 1,
             "overlap_communication": False,
             "precision": "fp16",
             "initial_scale": 1,
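The test change above replaces positional `zip()` pairing with name-keyed lookups, so the gradient and weight comparisons stay valid even if the two models expose parameters in a different order or some reference parameters never received a gradient; ZeRO gradients are read back through `sharded_optimizer.get_param_grad(p)`. Below is a minimal PyTorch-only sketch of the same pattern, with a hypothetical toy `Net` standing in for the Mixtral model pair and plain `.grad` standing in for the ZeRO gradient lookup:

```python
import torch
import torch.nn as nn
from torch.testing import assert_close

# Hypothetical stand-in for the original / sharded model pair.
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(8, 8)
        self.fc2 = nn.Linear(8, 2)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

torch.manual_seed(0)
org_model, shard_model = Net(), Net()
shard_model.load_state_dict(org_model.state_dict())

x = torch.randn(4, 8)
org_model(x).sum().backward()
shard_model(x).sum().backward()

# Name-keyed lookup instead of zip(): robust to parameter ordering and to
# reference parameters that received no gradient.
name_to_p = {n: p for n, p in org_model.named_parameters()}
for n, p in shard_model.named_parameters():
    ref = name_to_p[n]
    if ref.grad is None:
        # Mirror the test: substitute a zero gradient and skip the comparison.
        ref.grad = torch.zeros_like(ref.data)
        continue
    assert_close(ref.grad, p.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
```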