diff --git a/colossalai/checkpoint_io/utils.py b/colossalai/checkpoint_io/utils.py
index b3917bd9d..786d2df86 100644
--- a/colossalai/checkpoint_io/utils.py
+++ b/colossalai/checkpoint_io/utils.py
@@ -111,11 +111,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
         if length > current_shape[dim]:
             partition_dim = dim
             break
-    if partition_dim is not None:
-        assert (
-            original_shape[partition_dim] == tp_size * current_shape[partition_dim]
-        ), f"The parameter isn't evenly distributed among tensor parallel group: \
-            shape before sharding {original_shape}, shape after sharding {current_shape}"
     return partition_dim
 
 