From 7c079d9c334c73b8d3d0cc44beca0dee3f203e77 Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Fri, 11 Mar 2022 18:12:46 +0800
Subject: [PATCH] [hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)

---
 colossalai/utils/profiler/pcie_profiler.py           | 2 +-
 colossalai/zero/shard_utils/tensor_shard_strategy.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/colossalai/utils/profiler/pcie_profiler.py b/colossalai/utils/profiler/pcie_profiler.py
index 2d325da69..a01a37489 100644
--- a/colossalai/utils/profiler/pcie_profiler.py
+++ b/colossalai/utils/profiler/pcie_profiler.py
@@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
         if self.profiler.enabled:
             events = self.profiler.function_events
             for event in events:
-                if event.name == "aten::_to_copy":
+                if event.name == "aten::copy_":
                     t_shape = event.input_shapes[0]
                     if len(t_shape) == 0 or event.cuda_time_total == 0:
                         continue
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 94d40e9fb..08ac39e7d 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device
 
 
 class TensorShardStrategy(BaseShardStrategy):
@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
         torch.distributed.all_gather(buffer_list,
                                      buffer_list[self.local_rank],
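
Note on the hunks above. The profiler hunk switches the matched event name to
"aten::copy_", the ATen op that performs the actual host/device transfer and
which also fires inside ".to()" calls, so it catches copies that
"aten::_to_copy" alone would miss. The shard-strategy hunk fixes device
placement: Tensor.cuda() and torch.zeros(...).cuda() land on PyTorch's current
CUDA device, which is only correct if every rank has already called
torch.cuda.set_device(...); passing the device explicitly via
get_current_device() makes placement deterministic and, for the zero-filled
receive buffers, avoids allocating on the CPU first and then copying.

Below is a minimal, self-contained sketch of the corrected gather pattern.
This is not the repository's code: current_device is a hypothetical stand-in
for colossalai.utils.get_current_device, and the process group is assumed to
be initialized before the function runs.

    import torch
    import torch.distributed as dist

    def current_device() -> torch.device:
        # Hypothetical stand-in for colossalai.utils.get_current_device:
        # returns the CUDA device this rank is bound to.
        return torch.device(f"cuda:{torch.cuda.current_device()}")

    def gather_payload(payload: torch.Tensor, world_size: int, rank: int) -> list:
        """Gather every rank's flattened shard into one buffer per rank."""
        buffer_list = []
        for i in range(world_size):
            if i == rank:
                # Move the local shard onto this rank's GPU explicitly.
                buffer_list.append(payload.cuda(current_device()))
            else:
                # Allocate the receive buffer directly on the target GPU;
                # torch.zeros(n, dtype=...).cuda() would build it on the
                # CPU first and then copy it across the PCIe bus.
                buffer_list.append(torch.zeros(payload.numel(),
                                               dtype=payload.dtype,
                                               device=current_device()))
        dist.all_gather(buffer_list, buffer_list[rank])
        return buffer_list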