pull/6124/head
flybird11111 2024-11-11 18:01:02 +08:00
parent 89a9a600bc
commit 393c31da61
1 changed files with 0 additions and 5 deletions

View File

@ -111,11 +111,6 @@ def search_tp_partition_dim(current_shape: torch.Size, original_shape: torch.Siz
if length > current_shape[dim]: if length > current_shape[dim]:
partition_dim = dim partition_dim = dim
break break
if partition_dim is not None:
assert (
original_shape[partition_dim] == tp_size * current_shape[partition_dim]
), f"The parameter isn't evenly distributed among tensor parallel group: \
shape before sharding {original_shape}, shape after sharding {current_shape}"
return partition_dim return partition_dim