From e17e92c54d106f7cfbc7d480200b44e404049b99 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 3 Mar 2022 12:42:57 +0800 Subject: [PATCH] Polish sharded parameter (#297) * init shard param from shape tuple * add more unitest for shard param * add more unittests to shareded param --- colossalai/zero/shard_param/__init__.py | 3 - .../zero/sharded_model/sharded_model_v2.py | 46 +++++++------- colossalai/zero/sharded_param/__init__.py | 3 + .../sharded_param.py} | 61 ++++++++++++------- .../test_shard_param.py | 54 +++++++++++----- 5 files changed, 106 insertions(+), 61 deletions(-) delete mode 100644 colossalai/zero/shard_param/__init__.py create mode 100644 colossalai/zero/sharded_param/__init__.py rename colossalai/zero/{shard_param/shard_param.py => sharded_param/sharded_param.py} (51%) diff --git a/colossalai/zero/shard_param/__init__.py b/colossalai/zero/shard_param/__init__.py deleted file mode 100644 index bd7f5e46b..000000000 --- a/colossalai/zero/shard_param/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .shard_param import ShardParam - -__all__ = ['ShardParam'] \ No newline at end of file diff --git a/colossalai/zero/sharded_model/sharded_model_v2.py b/colossalai/zero/sharded_model/sharded_model_v2.py index a32afdff2..36e3e4b30 100644 --- a/colossalai/zero/sharded_model/sharded_model_v2.py +++ b/colossalai/zero/sharded_model/sharded_model_v2.py @@ -1,4 +1,3 @@ - import functools from typing import Any, Optional @@ -7,11 +6,10 @@ import torch.distributed as dist import torch.nn as nn from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc -from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, - register_ophooks_recursively) +from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively) from colossalai.engine.paramhooks import BaseParamHookMgr from colossalai.logging import get_dist_logger -from colossalai.zero.shard_param import ShardParam +from colossalai.zero.sharded_param import ShardedParam from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer from colossalai.zero.sharded_model.sharded_grad import ShardedGradient from torch.distributed import ProcessGroup @@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor class ShardedModelV2(nn.Module): - def __init__(self, - module: nn.Module, - process_group: Optional[ProcessGroup] = None, - reduce_scatter_process_group: Optional[ProcessGroup] = None, - reduce_scatter_bucket_size_mb: int = 25, - reshard_after_forward: bool = True, - mixed_precision: bool = False, - fp32_reduce_scatter: bool = False, - offload_config: Optional[dict] = None, - gradient_predivide_factor: Optional[float] = 1.0, - ): + + def __init__( + self, + module: nn.Module, + process_group: Optional[ProcessGroup] = None, + reduce_scatter_process_group: Optional[ProcessGroup] = None, + reduce_scatter_bucket_size_mb: int = 25, + reshard_after_forward: bool = True, + mixed_precision: bool = False, + fp32_reduce_scatter: bool = False, + offload_config: Optional[dict] = None, + gradient_predivide_factor: Optional[float] = 1.0, + ): r""" A demo to reconfigure zero1 shared_model. Currently do not consider the Optimizer States. 
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module): # Shard the parameters at first for _, param in self.module.named_parameters(): - param.ca_attr = ShardParam(param) + param.ca_attr = ShardedParam(param) param.ca_attr.shard() param._sharded_grad = ShardedGradient(param, self, offload_config) @@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module): self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem # So we use 1.0 as the default gradient_predivide_factor - # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically - self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \ + # However, if you set gradient_predivide_factor to None, + # we will set gradient_predivide_factor to a value >= 1.0 automatically + self.gradient_predivide_factor: float = \ + gradient_predivide_factor if gradient_predivide_factor is not None else \ get_gradient_predivide_factor(self.world_size) self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor @@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module): def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]: """ At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the - full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all + full gradient for the local batch. The reduce-scatter op will save + a single shard of the summed gradient across all GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example:: before reduce_scatter: @@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module): orig_grad_data = new_grad.data if self.world_size > 1: grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size()) - self.reducer.reduce_scatter_async( - grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param)) + self.reducer.reduce_scatter_async(grad_chunks, + group=self.reduce_scatter_process_group, + callback_fn=functools.partial(self._reduce_scatter_callback, param)) else: self._reduce_scatter_callback(param, new_grad) orig_grad_data.record_stream(self.comm_stream) diff --git a/colossalai/zero/sharded_param/__init__.py b/colossalai/zero/sharded_param/__init__.py new file mode 100644 index 000000000..527cf11d6 --- /dev/null +++ b/colossalai/zero/sharded_param/__init__.py @@ -0,0 +1,3 @@ +from .sharded_param import ShardedParam + +__all__ = ['ShardedParam'] diff --git a/colossalai/zero/shard_param/shard_param.py b/colossalai/zero/sharded_param/sharded_param.py similarity index 51% rename from colossalai/zero/shard_param/shard_param.py rename to colossalai/zero/sharded_param/sharded_param.py index 7bc36470f..f7363d0a5 100644 --- a/colossalai/zero/shard_param/shard_param.py +++ b/colossalai/zero/sharded_param/sharded_param.py @@ -1,41 +1,59 @@ -from enum import Enum - import torch import torch.distributed as dist from colossalai.context.parallel_mode import ParallelMode from colossalai.core import global_context as gpc from colossalai.zero.sharded_model._zero3_utils import get_shard +from typing import Union, Tuple, Optional +import numpy -class TensorType(Enum): - GRAD = 1 - DATA = 2 - - -class ShardParam(object): +class ShardedParam(object): r""" A wrapper to 
torch.nn.Parameter. Shard a param
-    on different processes.
+    on the memory space of different processes.
     """

-    def __init__(
-            self,
-            param: torch.nn.Parameter,
-            tensor_type: TensorType = TensorType.DATA,
-            process_group=None,
-    ) -> None:
+    def __init__(self,
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 is_sharded: bool = False,
+                 device: Optional[torch.device] = None) -> None:
+        r"""
+        other: either an existing torch parameter or a shape tuple, indicating that a new param should be allocated with the tuple as its shape.
+        process_group: the process group storing the sharded data.
+        is_sharded: whether to shard the param during __init__.
+        device: the device on which to place the param data payload.
+        """
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
-        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
-        self._payload_shape = None
-        self._payload_numel = None
-        self._origin_shape = param.shape
-        self._origin_numel = param.numel()
-        self._origin_dtype = param.dtype
         self.is_sharded = False

+        # Hijack the data payload of param
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
+            if is_sharded:
+                self.shard()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)
+
+            # TODO(jiaruifang) can be optimized. Directly allocate the payload with the sharded shape.
+            assert device is not None, "You have to specify a device to initialize a ShardedParam from a shape tuple"
+            self._param_payload = torch.empty(self._origin_shape, device=device)
+            if is_sharded:
+                self.shard()
+        else:
+            raise RuntimeError(f"Initialize ShardedParam failed. The argument `other` has the wrong type {type(other)}")
+
+        self._payload_numel = None
+
     def payload(self, target_device: torch.device):
+        r"""
+        Get the payload and move it to the target device.
+        """
         return self._param_payload.to(target_device)

     def shard(self):
@@ -50,6 +68,7 @@ class ShardParam(object):
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
+        The payload has to be moved to CUDA memory before communication.
""" if not self.is_sharded: return diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py index 9973ee524..642cd7f2b 100644 --- a/tests/test_zero_data_parallel/test_shard_param.py +++ b/tests/test_zero_data_parallel/test_shard_param.py @@ -1,50 +1,72 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from asyncio.log import logger from functools import partial import colossalai import pytest import torch import torch.multiprocessing as mp -from colossalai.zero.shard_param import ShardParam +from colossalai.zero.sharded_param import ShardedParam from colossalai.utils import free_port from colossalai.logging import get_dist_logger, disable_existing_loggers from tests.test_zero_data_parallel.common import Net, CONFIG + +def run_init_shard_param(rank, world_size, port): + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + param = torch.nn.Parameter(data=torch.rand(2, 3)) + sparam = ShardedParam(param, None, True) + payload = sparam.payload(torch.device('cuda')) + assert (list(payload.shape) == [3]) + del sparam + + param_shape = (2, 3) + sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu')) + payload = sparam.payload(torch.device('cuda')) + assert (list(payload.shape) == [3]) + + param_shape = (2, 3) + sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu')) + payload = sparam.payload(torch.device('cuda')) + assert (list(payload.shape) == [2, 3]) + + def run_shard_param_check(rank, world_size, port): - colossalai.launch(config=CONFIG, - rank=rank, - world_size=world_size, - host='localhost', - port=port, - backend='nccl') - + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + logger = get_dist_logger() model = Net() # add an attribute as ca_attr to hijack the access to param.data for _, param in model.named_parameters(): numel_ref = (param.numel() + world_size - 1) // world_size - param.ca_attr = ShardParam(param) + param.ca_attr = ShardedParam(param) param.ca_attr.shard() param_data = param.ca_attr.payload(torch.device('cpu')) - logger.info(f'shard {param_data.shape} {param_data}', ranks = [1]) - assert(numel_ref == param_data.numel()) + assert (numel_ref == param_data.numel()) for _, param in model.named_parameters(): param.ca_attr.gather() param_data = param.ca_attr.payload(torch.device('cpu')) - logger.info(f'gather {param_data.shape} {param_data}', ranks = [1]) - + disable_existing_loggers([logger]) + @pytest.mark.dist -def test_run_shard_shape(): +def test_shard_shape(): world_size = 2 run_func = partial(run_shard_param_check, world_size=world_size, port=free_port()) mp.spawn(run_func, nprocs=world_size) + +@pytest.mark.dist +def test_init_shard_param(): + world_size = 2 + run_func = partial(run_init_shard_param, world_size=world_size, port=free_port()) + mp.spawn(run_func, nprocs=world_size) + + if __name__ == '__main__': - test_run_shard_shape() \ No newline at end of file + test_shard_shape() + test_init_shard_param()