[hotfix] fixed bugs in ShardStrategy and PcieProfiler (#394)

pull/246/head
HELSON 2022-03-11 18:12:46 +08:00 committed by GitHub
parent 1e4bf85cdb
commit 7c079d9c33
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 4 additions and 3 deletions

View File

@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
if self.profiler.enabled: if self.profiler.enabled:
events = self.profiler.function_events events = self.profiler.function_events
for event in events: for event in events:
if event.name == "aten::_to_copy": if event.name == "aten::copy_":
t_shape = event.input_shapes[0] t_shape = event.input_shapes[0]
if len(t_shape) == 0 or event.cuda_time_total == 0: if len(t_shape) == 0 or event.cuda_time_total == 0:
continue continue

View File

@ -5,6 +5,7 @@ import torch.distributed as dist
from colossalai.zero.shard_utils import BaseShardStrategy from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy): class TensorShardStrategy(BaseShardStrategy):
@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
payload_numel = t.payload.numel() payload_numel = t.payload.numel()
for i in range(self.world_size): for i in range(self.world_size):
if i == self.local_rank: if i == self.local_rank:
buffer_list.append(t.payload.cuda()) buffer_list.append(t.payload.cuda(get_current_device()))
else: else:
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda()) buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
torch.distributed.all_gather(buffer_list, torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank], buffer_list[self.local_rank],