from typing import Optional, Tuple, Union

import numpy
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard
from colossalai.zero.sharded_param import ShardedTensor


class ShardedParamV2(object):
    r"""
    Pair a torch.nn.Parameter with a ShardedTensor built from its data,
    plus slots for fp16/fp32 gradients and a backward-pass counter.
    """

    def __init__(self,
                 param: torch.nn.Parameter,
                 process_group: Optional[dist.ProcessGroup] = None,
                 rm_torch_payload=False) -> None:
        r"""
        param: the torch parameter whose ``.data`` is wrapped in a ShardedTensor.
        process_group: forwarded to ShardedTensor.
        rm_torch_payload: if True, replace ``param.data`` with an empty tensor
            right away (see ``remove_torch_payload``).
        """
        self._data_sharded_tensor: ShardedTensor = ShardedTensor(param.data, process_group)
        self.fp16_grad: Optional[torch.Tensor] = None
        self.fp32_grad: Optional[torch.Tensor] = None

        # make sure the sharded param is the only owner of payload
        # The param.data maybe used to init the other part of the model.
        # For example: File "resnet.py", line 190, in __init__
        #     nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
        # So we can not empty the .data at this time
        self.param = param
        if rm_torch_payload:
            self.remove_torch_payload()

        # Backward count for handle local grad accumulation
        # This value will increment by 1 in every pre-bwd hook
        # And will be reset to 0 in every final-bwd hook
        self.bwd_count = 0

    def remove_torch_payload(self):
        r"""
        Replace ``param.data`` with an empty 0-dim tensor of the same dtype and device.
        """
        self.param.data = torch.empty([], dtype=self.param.dtype, device=self.param.device)

    @property
    def data(self):
        r"""
        The ShardedTensor created from the parameter's data.
        """
        return self._data_sharded_tensor

    @property
    def param_is_sharded(self):
        r"""
        The ``is_sharded`` flag of the underlying ShardedTensor.
        """
        return self._data_sharded_tensor.is_sharded


class ShardedParam(object):
    r"""
    A wrapper to torch.nn.Parameter. Shard a param
    on memory space of different processes.
    """

    def __init__(self,
                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
                 process_group: Optional[dist.ProcessGroup] = None,
                 is_sharded: bool = False,
                 device: Optional[torch.device] = None) -> None:
        r"""
        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
        process_group: the process group storing the sharded data.
        is_sharded: whether to shard the param during __init__.
        device: the device to place param data payload on

        Raises RuntimeError if ``other`` is neither a Parameter nor a tuple.
        """
        self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
        self.world_size = dist.get_world_size(self.process_group)
        self.local_rank = dist.get_rank(self.process_group)
        self.is_sharded = False
        self.device = device

        # Hijack the data payload of param
        if isinstance(other, torch.nn.Parameter):
            self._param_payload = other.data.to(device)
            self._origin_shape = other.shape
            self._origin_numel = other.numel()
        elif isinstance(other, tuple):
            self._origin_shape = other
            # numpy.prod returns a numpy scalar (a float for an empty tuple);
            # store a plain int so it can be used as a length later on.
            self._origin_numel = int(numpy.prod(other))

            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
            self._param_payload = torch.empty(self._origin_shape, device=device)
        else:
            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")

        # Remember the dtype of the unsharded payload; the ``origin_dtype``
        # property reads this attribute, which was previously never assigned.
        self._origin_dtype = self._param_payload.dtype

        if is_sharded:
            self.shard()

        # Not read anywhere in this file; kept so the attribute set stays unchanged.
        self._payload_numel = None

    def payload(self, target_device: Optional[torch.device] = None):
        r"""
        get the payload and move it to target device
        """
        if target_device is None:
            return self._param_payload
        return self._param_payload.to(target_device)

    def set_payload(self, data: torch.Tensor):
        r"""
        set payload as data; ``data`` must have the same shape as the current payload.
        """
        assert self._param_payload.shape == data.shape
        self._param_payload.copy_(data)

    def shard(self):
        r"""
        Distribute the payload of param to all processes.
        Does nothing if the payload is already sharded.
        """
        if self.is_sharded:
            return
        # get_shard returns a pair; only the first element (the local shard) is kept.
        self._param_payload, _ = get_shard(self._param_payload, self.local_rank, self.world_size)
        self.is_sharded = True

    def gather(self):
        r"""
        Collect the payload of param from different processes to process of local rank.
        The payload has to be moved to cuda memory before communication.
        Does nothing if the payload is not sharded.
        """
        if not self.is_sharded:
            return

        local_payload = self._param_payload.cuda()
        shard_numel = self._param_payload.numel()
        # Receive buffers must share the local shard's dtype and device;
        # a default-dtype buffer would mismatch e.g. an fp16 payload.
        buffer_list = [
            local_payload if rank == self.local_rank else torch.zeros(
                shard_numel, dtype=local_payload.dtype, device=local_payload.device)
            for rank in range(self.world_size)
        ]

        dist.all_gather(buffer_list, buffer_list[self.local_rank], group=self.process_group, async_op=False)

        # Concatenate the shards, keep the first ``_origin_numel`` elements
        # and restore the original shape.
        flat_payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel)
        self._param_payload = flat_payload.view(self._origin_shape)
        self.is_sharded = False

    @property
    def origin_dtype(self):
        r"""
        dtype of the payload as recorded at construction time.
        """
        return self._origin_dtype