fixed typo in ShardParam (#294)

pull/394/head
Frank Lee 2022-03-02 17:26:23 +08:00
parent 27155b8513
commit 9afb5c8b2d
1 changed files with 23 additions and 19 deletions

View File

@ -1,25 +1,28 @@
from enum import Enum
from optparse import Option
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist
class TensorType(Enum):
GRAD = 1
DATA = 2
class ShardParam(object):
r"""
A wrapper to torch.nn.Parameter. Shard a param
on different processes.
"""
def __init__(self,
param: torch.nn.Parameter,
tensor_type: TensorType = TensorType.DATA,
process_group = None,
) -> None:
def __init__(
self,
param: torch.nn.Parameter,
tensor_type: TensorType = TensorType.DATA,
process_group=None,
) -> None:
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
@ -27,27 +30,27 @@ class ShardParam(object):
self._payload_numel = None
self._origin_shape = param.shape
self._origin_numel = param.numel()
self.is_shared = False
def payload(self, target_device : torch.device):
self.is_sharded = False
def payload(self, target_device: torch.device):
return self._param_payload.to(target_device)
def shard(self):
r"""
Distributed the payload of param to all processes.
"""
if self.is_shared:
if self.is_sharded:
return
self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
self.is_shared = True
self.is_sharded = True
def gather(self):
r"""
Collect the payload of param from different processes to process of local rank.
"""
if not self.is_shared:
if not self.is_sharded:
return
buffer_list = []
payload_numel = self._param_payload.numel()
for i in range(self.world_size):
@ -55,9 +58,10 @@ class ShardParam(object):
buffer_list.append(self._param_payload.cuda())
else:
buffer_list.append(torch.zeros(payload_numel).cuda())
torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
print(buffer_list)
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
self.is_shared = False
torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank],
group=self.process_group,
async_op=False)
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
self.is_sharded = False