diff --git a/colossalai/zero/shard_utils/commons.py b/colossalai/zero/shard_utils/commons.py index f24559644..71cef44c1 100644 --- a/colossalai/zero/shard_utils/commons.py +++ b/colossalai/zero/shard_utils/commons.py @@ -14,7 +14,9 @@ def get_shard(tensor: torch.Tensor, rank: int, world_size: int) -> Tuple[torch.T num_to_pad = chunks[0].numel() - chunks[rank].numel() assert num_to_pad >= 0, num_to_pad - shard = chunks[rank].clone() - if num_to_pad > 0: - shard = F.pad(shard, [0, num_to_pad]) + shard = torch.zeros_like(chunks[0]) + length = chunks[rank].size(0) + shard_temp = shard[:length] + shard_temp.copy_(chunks[rank]) + return shard, num_to_pad diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py index 8fba5f73a..8857d7ae4 100644 --- a/colossalai/zero/shard_utils/tensor_shard_strategy.py +++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py @@ -43,18 +43,16 @@ class TensorShardStrategy(BaseShardStrategy): if not t.is_sharded: return target_device = t.device - buffer_list = [] payload_numel = t.payload.numel() world_size = dist.get_world_size(process_group) rank = dist.get_rank(process_group) - for i in range(world_size): - if i == rank: - buffer_list.append(t.payload.cuda(get_current_device())) - else: - buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device())) + + buffer = torch.empty(payload_numel * world_size, dtype=t.payload.dtype, device=get_current_device()) + buffer_list = list(torch.chunk(buffer, chunks=world_size, dim=0)) + buffer_list[rank].copy_(t.payload) dist.all_gather(buffer_list, buffer_list[rank], group=process_group, async_op=False) - gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape) + gathered_payload = torch.narrow(buffer, 0, 0, t.origin_numel).reshape(t.origin_shape) t.reset_payload(gathered_payload) colo_model_data_tensor_move_inline(t, target_device) t.is_sharded = False