Browse Source

fixed typo in ShardParam (#294)

pull/394/head
Frank Lee 3 years ago
parent
commit
9afb5c8b2d
  1. 28
      colossalai/zero/shard_param/shard_param.py

28
colossalai/zero/shard_param/shard_param.py

@ -1,24 +1,27 @@
from enum import Enum
from optparse import Option
import torch
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
import torch.distributed as dist
class TensorType(Enum):
GRAD = 1
DATA = 2
class ShardParam(object):
r"""
A wrapper to torch.nn.Parameter. Shard a param
on different processes.
"""
def __init__(self,
def __init__(
self,
param: torch.nn.Parameter,
tensor_type: TensorType = TensorType.DATA,
process_group = None,
process_group=None,
) -> None:
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.world_size = dist.get_world_size(self.process_group)
@ -27,25 +30,25 @@ class ShardParam(object):
self._payload_numel = None
self._origin_shape = param.shape
self._origin_numel = param.numel()
self.is_shared = False
self.is_sharded = False
def payload(self, target_device : torch.device):
def payload(self, target_device: torch.device):
return self._param_payload.to(target_device)
def shard(self):
r"""
Distributed the payload of param to all processes.
"""
if self.is_shared:
if self.is_sharded:
return
self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
self.is_shared = True
self.is_sharded = True
def gather(self):
r"""
Collect the payload of param from different processes to process of local rank.
"""
if not self.is_shared:
if not self.is_sharded:
return
buffer_list = []
@ -56,8 +59,9 @@ class ShardParam(object):
else:
buffer_list.append(torch.zeros(payload_numel).cuda())
torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
print(buffer_list)
torch.distributed.all_gather(buffer_list,
buffer_list[self.local_rank],
group=self.process_group,
async_op=False)
self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
self.is_shared = False
self.is_sharded = False

Loading…
Cancel
Save