From 7c079d9c334c73b8d3d0cc44beca0dee3f203e77 Mon Sep 17 00:00:00 2001
From: HELSON <72907851+1SAA@users.noreply.github.com>
Date: Fri, 11 Mar 2022 18:12:46 +0800
Subject: [PATCH] [hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)

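Two fixes:

* PcieProfiler filtered profiler events by the name "aten::_to_copy",
  but the copy operations it is meant to account for are reported as
  "aten::copy_" events, so match on that name instead.

* The all-gather buffers in TensorShardStrategy were built with bare
  .cuda() calls, which place them on the default CUDA device rather
  than the device assigned to this process. Put both the local payload
  and the zero-filled receive buffers on get_current_device(), and
  allocate the zeros directly on that device instead of allocating on
  the CPU and copying.
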
---
 colossalai/utils/profiler/pcie_profiler.py           | 2 +-
 colossalai/zero/shard_utils/tensor_shard_strategy.py | 5 +++--
 2 files changed, 4 insertions(+), 3 deletions(-)

diff --git a/colossalai/utils/profiler/pcie_profiler.py b/colossalai/utils/profiler/pcie_profiler.py
index 2d325da69..a01a37489 100644
--- a/colossalai/utils/profiler/pcie_profiler.py
+++ b/colossalai/utils/profiler/pcie_profiler.py
@@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
         if self.profiler.enabled:
             events = self.profiler.function_events
             for event in events:
-                if event.name == "aten::_to_copy":
+                if event.name == "aten::copy_":
                     t_shape = event.input_shapes[0]
                     if len(t_shape) == 0 or event.cuda_time_total == 0:
                         continue
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
index 94d40e9fb..08ac39e7d 100644
--- a/colossalai/zero/shard_utils/tensor_shard_strategy.py
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device
 
 
 class TensorShardStrategy(BaseShardStrategy):
@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
 
         torch.distributed.all_gather(buffer_list,
                                      buffer_list[self.local_rank],