@@ -1,25 +1,28 @@
 from enum import Enum
-from optparse import Option
 import torch
 from colossalai.zero.sharded_model._zero3_utils import get_shard
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 import torch.distributed as dist
 
+
 class TensorType(Enum):
     GRAD = 1
     DATA = 2
 
+
 class ShardParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
     on different processes.
     """
-    def __init__(self,
-                 param: torch.nn.Parameter,
-                 tensor_type: TensorType = TensorType.DATA,
-                 process_group = None,
-                 ) -> None:
+
+    def __init__(
+        self,
+        param: torch.nn.Parameter,
+        tensor_type: TensorType = TensorType.DATA,
+        process_group=None,
+    ) -> None:
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
@@ -27,25 +30,25 @@ class ShardParam(object):
         self._payload_numel = None
         self._origin_shape = param.shape
         self._origin_numel = param.numel()
-        self.is_shared = False
+        self.is_sharded = False
 
-    def payload(self, target_device : torch.device):
+    def payload(self, target_device: torch.device):
         return self._param_payload.to(target_device)
 
     def shard(self):
         r"""
         Distributed the payload of param to all processes.
         """
-        if self.is_shared:
+        if self.is_sharded:
             return
         self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
-        self.is_shared = True
+        self.is_sharded = True
 
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
         """
-        if not self.is_shared:
+        if not self.is_sharded:
             return
 
         buffer_list = []
@@ -56,8 +59,9 @@ class ShardParam(object):
             else:
                 buffer_list.append(torch.zeros(payload_numel).cuda())
 
-        torch.distributed.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)
-        print(buffer_list)
+        torch.distributed.all_gather(buffer_list,
+                                     buffer_list[self.local_rank],
+                                     group=self.process_group,
+                                     async_op=False)
         self._param_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self.is_shared = False
+        self.is_sharded = False
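
For context, a minimal usage sketch of the class touched by this patch. This is not part of the diff: it assumes a ColossalAI distributed environment has already been launched (so the default DATA process group behind gpc.get_group(ParallelMode.DATA) exists), and the import path and tensor shape below are illustrative assumptions rather than details taken from the patch.

# Usage sketch, not part of the patch. Assumes colossalai has already
# initialized its global context and the DATA parallel group.
import torch
from colossalai.zero.shard_param import ShardParam   # hypothetical module path

param = torch.nn.Parameter(torch.randn(1024, 256))   # illustrative shape
sharded = ShardParam(param)

sharded.shard()                                       # keep only this rank's slice
local_chunk = sharded.payload(torch.device('cuda'))   # view of the local shard

sharded.gather()                                      # all-gather the full payload back
full_tensor = sharded.payload(torch.device('cuda'))
assert full_tensor.shape == param.shape               # restored to the original shape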