diff --git a/colossalai/shardformer/policies/mixtral.py b/colossalai/shardformer/policies/mixtral.py
index 10f54e1a4..8fed5ee5c 100644
--- a/colossalai/shardformer/policies/mixtral.py
+++ b/colossalai/shardformer/policies/mixtral.py
@@ -26,18 +26,8 @@ class MixtralPolicy(Policy):
         pass
 
     def preprocess(self):
+        self.tie_weight = self.tie_weight_check()
         self.origin_attn_implement = self.model.config._attn_implementation
-        # if self.shard_config.enable_tensor_parallelism:
-        #     # non-moe params tensor parallelism
-
-        #     # Resize embedding
-        #     vocab_size = self.model.config.vocab_size
-        #     world_size = self.shard_config.tensor_parallel_size
-
-        #     if vocab_size % world_size != 0:
-        #         new_vocab_size = vocab_size + world_size - vocab_size % world_size
-        #         self.model.resize_token_embeddings(new_vocab_size)
-
         return self.model
 
     def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
diff --git a/tests/test_shardformer/test_model/test_shard_mixtral.py b/tests/test_shardformer/test_model/test_shard_mixtral.py
index 2e2b675a4..e873f46f7 100644
--- a/tests/test_shardformer/test_model/test_shard_mixtral.py
+++ b/tests/test_shardformer/test_model/test_shard_mixtral.py
@@ -67,12 +67,7 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     rank = dist.get_rank()
     # for p1, p2 in zip(mixtral_model.parameters(), sharded_optimizer._master_param_groups_of_current_rank[0]):
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed grad: {n1}, {n2}")
-        except Exception as e:
-            print(f"{rank=},failed grad: {n1} {p1.grad[:2,:2]}, {n2} {p2.grad[:2, :2]}")
-            raise e
+        assert_close(p1.grad, p2.grad, atol=5e-3, rtol=5e-3, check_dtype=False)
 
     # Save gradient tensors for comparison between the original model and the sharded model before optimizer step.
     grads_to_check = {}
@@ -108,25 +103,14 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
     # print(grads_to_check)
     check_all_grad_tensors(grads_to_check)
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed param before step: {n1}, {n2}")
-        except Exception:
-            print(
-                f"{rank=},failed param before step: {n1} {p1[:2,:2] if p1 else None}, {n2} {p2[:2, :2] if p2 else None}"
-            )
+        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
     # optimizer executes step
     org_optimizer.step()
     sharded_optimizer.step()
     for (n1, p1), (n2, p2) in zip(mixtral_model.named_parameters(), shard_mixtral_model.named_parameters()):
-        try:
-            assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
-            print(f"{rank=},passed param after step: {n1}, {n2}")
-        except Exception as e:
-            print(
-                f"{rank=},failed param after step: {n1} {p1 if p1 is not None else None}, {n2} {p2 if p2 is not None else None}"
-            )
-            raise e
+        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)
+
     # check weights
     if stage_manager is None or stage_manager.is_first_stage():
        if test_config["precision"] == "fp32":
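Note (not part of the diff): below is a minimal, self-contained sketch of the comparison pattern the updated test keeps, walking two models' named parameters in lockstep and letting torch.testing.assert_close raise on the first mismatch. The check_params_match helper and the toy nn.Linear models are illustrative assumptions, not part of the ColossalAI test suite.

import torch
import torch.nn as nn
from torch.testing import assert_close


def check_params_match(model_a: nn.Module, model_b: nn.Module) -> None:
    # Walk both parameter lists in lockstep; assert_close raises on the first
    # mismatch. check_dtype=False tolerates dtype differences (e.g. fp16 vs
    # fp32) between an original and a sharded copy; tolerances mirror the
    # test's atol/rtol of 5e-3.
    for (n1, p1), (n2, p2) in zip(model_a.named_parameters(), model_b.named_parameters()):
        assert n1 == n2, f"parameter order mismatch: {n1} vs {n2}"
        assert_close(p1, p2, atol=5e-3, rtol=5e-3, check_dtype=False)


if __name__ == "__main__":
    torch.manual_seed(0)
    a = nn.Linear(4, 4)
    b = nn.Linear(4, 4)
    b.load_state_dict(a.state_dict())  # make the two copies identical
    check_params_match(a, b)           # passes; raises AssertionError on divergence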