fixed typo in ShardParam (#294)

pull/394/head
Frank Lee 3 years ago
parent 27155b8513
commit 9afb5c8b2d

@ -1,25 +1,28 @@
from enum import Enum from enum import Enum
from optparse import Option
import torch import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.context.parallel_mode import ParallelMode from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc from colossalai.core import global_context as gpc
import torch.distributed as dist import torch.distributed as dist
class TensorType(Enum): class TensorType(Enum):
GRAD = 1 GRAD = 1
DATA = 2 DATA = 2
class ShardParam(object): class ShardParam(object):
r""" r"""
A wrapper to torch.nn.Parameter. Shard a param A wrapper to torch.nn.Parameter. Shard a param
on different processes. on different processes.
""" """
def __init__(self,
param: torch.nn.Parameter, def __init__(
tensor_type: TensorType = TensorType.DATA, self,
process_group = None, param: torch.nn.Parameter,
) -> None: tensor_type: TensorType = TensorType.DATA,
process_group=None,
) -> None:
self.process_group = process_group or gpc.get_group(ParallelMode.DATA) self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.world_size = dist.get_world_size(self.process_group) self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group) self.local_rank = dist.get_rank(self.process_group)
@ -27,25 +30,25 @@ class ShardParam(object):
self._payload_numel = None self._payload_numel = None
self._origin_shape = param.shape self._origin_shape = param.shape
self._origin_numel = param.numel() self._origin_numel = param.numel()
self.is_shared = False self.is_sharded = False
def payload(self, target_device : torch.device): def payload(self, target_device: torch.device):
return self._param_payload.to(target_device) return self._param_payload.to(target_device)
def shard(self): def shard(self):
r""" r"""
Distributed the payload of param to all processes. Distributed the payload of param to all processes.
""" """
if self.is_shared: if self.is_sharded:
return return
self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size) self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
self.is_shared = True self.is_sharded = True
def gather(self): def gather(self):
r""" r"""
Collect the payload of param from different processes to process of local rank. Collect the payload of param from different processes to process of local rank.
""" """
if not self.is_shared: if not self.is_sharded:
return return
buffer_list = [] buffer_list = []
@ -56,8 +59,9 @@ class ShardParam(object):
else: else:
buffer_list.append(torch.zeros(payload_numel).cuda()) buffer_list.append(torch.zeros(payload_numel).cuda())
torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False) torch.distributed.all_gather(buffer_list,
print(buffer_list) buffer_list[self.local_rank],
group=self.process_group,
async_op=False)
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape) self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
self.is_shared = False self.is_sharded = False

Loading…
Cancel
Save