Merge pull request #401 from hpcaitech/develop

pull/402/head
Frank Lee 2022-03-13 11:09:17 +08:00 committed by GitHub
commit fc5101f24c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 6 additions and 5 deletions

View File

@ -11,7 +11,7 @@
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
<a href="https://medium.com/@hpcaitech"> 博客 </a></h3>
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml)
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
[![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)

View File

@ -11,7 +11,7 @@
<a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
<a href="https://medium.com/@hpcaitech"> Blog </a></h3>
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml)
[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
[![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
[![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
[![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&amp)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)

View File

@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
if self.profiler.enabled:
events = self.profiler.function_events
for event in events:
if event.name == "aten::_to_copy":
if event.name == "aten::copy_":
t_shape = event.input_shapes[0]
if len(t_shape) == 0 or event.cuda_time_total == 0:
continue

View File

@ -5,6 +5,7 @@ import torch.distributed as dist
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
from colossalai.utils import get_current_device
class TensorShardStrategy(BaseShardStrategy):
@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
payload_numel = t.payload.numel()
for i in range(self.world_size):
if i == self.local_rank:
buffer_list.append(t.payload.cuda())
buffer_list.append(t.payload.cuda(get_current_device()))
else:
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))
torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank],