mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #401 from hpcaitech/develop
commit fc5101f24c
@@ -11,7 +11,7 @@
 <a href="https://github.com/hpcaitech/ColossalAI/discussions"> 论坛 </a> |
 <a href="https://medium.com/@hpcaitech"> 博客 </a></h3>

-[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml)
+[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
 [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
 [![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
 [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
@@ -11,7 +11,7 @@
 <a href="https://github.com/hpcaitech/ColossalAI/discussions"> Forum </a> |
 <a href="https://medium.com/@hpcaitech"> Blog </a></h3>

-[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/PR_CI.yml)
+[![Build](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml/badge.svg)](https://github.com/hpcaitech/ColossalAI/actions/workflows/build.yml)
 [![Documentation](https://readthedocs.org/projects/colossalai/badge/?version=latest)](https://colossalai.readthedocs.io/en/latest/?badge=latest)
 [![codebeat badge](https://codebeat.co/badges/bfe8f98b-5d61-4256-8ad2-ccd34d9cc156)](https://codebeat.co/projects/github-com-hpcaitech-colossalai-main)
 [![slack badge](https://img.shields.io/badge/Slack-join-blueviolet?logo=slack&)](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w)
@@ -79,7 +79,7 @@ class PcieProfiler(BaseProfiler):
         if self.profiler.enabled:
             events = self.profiler.function_events
             for event in events:
-                if event.name == "aten::_to_copy":
+                if event.name == "aten::copy_":
                     t_shape = event.input_shapes[0]
                     if len(t_shape) == 0 or event.cuda_time_total == 0:
                         continue
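This hunk retargets the PCIe profiler from `aten::_to_copy` to `aten::copy_`, the ATen operator the autograd profiler records for host-device transfers. Below is a minimal sketch of inspecting such events with PyTorch's legacy autograd profiler; it is an illustration, not ColossalAI's `PcieProfiler`, and it assumes a CUDA device is available (`record_shapes=True` is required for `input_shapes` to be populated):

```python
import torch
from torch.autograd import profiler

x = torch.randn(1024, 1024)
with profiler.profile(use_cuda=True, record_shapes=True) as prof:
    y = x.cuda()  # host-to-device transfer, recorded as aten::copy_
    z = y.cpu()   # device-to-host transfer

for event in prof.function_events:
    if event.name == "aten::copy_":
        t_shape = event.input_shapes[0]
        # Skip scalar copies and events with no CUDA time, as the patch does.
        if len(t_shape) == 0 or event.cuda_time_total == 0:
            continue
        print(t_shape, event.cuda_time_total)  # input shape, CUDA time in microseconds
```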
@@ -5,6 +5,7 @@ import torch.distributed as dist
 from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.utils import get_current_device


 class TensorShardStrategy(BaseShardStrategy):
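This hunk only adds the import used by the change below. For context, a plausible minimal sketch of such a current-device helper follows; this is an assumption for illustration, not the actual code in `colossalai.utils`:

```python
import torch

def get_current_device() -> torch.device:
    # Hypothetical sketch: return the CUDA device bound to this process,
    # falling back to CPU when CUDA is unavailable.
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    return torch.device("cpu")
```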
@@ -35,9 +36,9 @@ class TensorShardStrategy(BaseShardStrategy):
         payload_numel = t.payload.numel()
         for i in range(self.world_size):
             if i == self.local_rank:
-                buffer_list.append(t.payload.cuda())
+                buffer_list.append(t.payload.cuda(get_current_device()))
             else:
-                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype).cuda())
+                buffer_list.append(torch.zeros(payload_numel, dtype=t.dtype, device=get_current_device()))

         torch.distributed.all_gather(buffer_list,
                                      buffer_list[self.local_rank],
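The fix matters in multi-process training: a bare `.cuda()` (and `torch.zeros(...).cuda()`) places tensors on the default device, typically `cuda:0`, on every rank, whereas `get_current_device()` pins each buffer to the device assigned to the calling process; the `torch.zeros(..., device=...)` form also avoids allocating on the CPU and copying to the GPU afterwards. A self-contained sketch of the same gather pattern follows; the function name and setup are illustrative, not ColossalAI's API, and it assumes `torch.distributed` is initialized with one CUDA device per rank and equally sized shards:

```python
import torch
import torch.distributed as dist

def gather_shards(shard: torch.Tensor, world_size: int, rank: int) -> torch.Tensor:
    # Each rank must allocate its buffers on its own device; a bare .cuda()
    # would land every rank's buffers on cuda:0.
    device = torch.device(f"cuda:{torch.cuda.current_device()}")
    flat = shard.flatten().to(device)
    buffer_list = []
    for i in range(world_size):
        if i == rank:
            buffer_list.append(flat)
        else:
            buffer_list.append(torch.zeros(flat.numel(), dtype=flat.dtype, device=device))
    # all_gather requires equally sized tensors; each rank contributes its own shard.
    dist.all_gather(buffer_list, buffer_list[rank])
    return torch.cat(buffer_list)
```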