diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py
index becd08501..ba30549bc 100644
--- a/tests/test_tensor/test_model.py
+++ b/tests/test_tensor/test_model.py
@@ -15,7 +15,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
-from _utils import set_seed
+from _utils import tensor_equal, tensor_shard_equal, set_seed
 
 
 def init_1d_row_linear(weight):
@@ -144,20 +144,8 @@ def run_1d_hybrid_tp(model_name):
 
         with torch.no_grad():
             # check param
-            for p1, p2 in zip(model.parameters(), model_torch.parameters()):
-                if p1.size() == p2.size():
-                    assert torch.allclose(p1, p2)
-                else:
-                    # TODO(jzy) Only check 1D spec. Need to be replaced by new DistSpec.
-                    if p1.size(-1) < p2.size(-1):    # col
-                        world_size = p2.size(-1) // p1.size(-1)
-                        split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
-
-                    elif p1.size(0) < p2.size(0):    # row
-                        world_size = p2.size(0) // p1.size(0)
-                        split_p2 = torch.chunk(p2, world_size, dim=0)[0]
-
-                    assert torch.allclose(p1, split_p2)
+            for p, torch_p in zip(model.parameters(), model_torch.parameters()):
+                assert tensor_shard_equal(torch_p, p)
 
         if i > 5:
             break
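
For context: the deleted inline logic is what `tensor_shard_equal` factors out into the shared `_utils` module. Below is a minimal sketch of what the two imported helpers plausibly look like, reconstructed from the removed code; it is an assumption, not the actual `_utils` implementation, which may differ (for example, by selecting the chunk for the current rank via `gpc` rather than always chunk 0, or by using looser tolerances).

```python
import torch


def tensor_equal(t_a: torch.Tensor, t_b: torch.Tensor) -> bool:
    # Plain element-wise comparison for replicated (unsharded) tensors.
    return torch.allclose(t_a, t_b)


def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor) -> bool:
    # Same shape: the parameter is replicated, so compare directly.
    if tensor.size() == shard.size():
        return tensor_equal(tensor, shard)
    # 1D column sharding: the shard is narrower along the last dim.
    if shard.size(-1) < tensor.size(-1):
        world_size = tensor.size(-1) // shard.size(-1)
        return tensor_equal(torch.chunk(tensor, world_size, dim=-1)[0], shard)
    # 1D row sharding: the shard is shorter along the first dim.
    if shard.size(0) < tensor.size(0):
        world_size = tensor.size(0) // shard.size(0)
        return tensor_equal(torch.chunk(tensor, world_size, dim=0)[0], shard)
    raise NotImplementedError("only 1D row/col sharding is handled here")
```

Note the argument order at the new call site: `tensor_shard_equal(torch_p, p)` passes the full unsharded torch parameter first and the (possibly sharded) parallel parameter second, matching the sketch above.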