Polish sharded parameter (#297)

* init shard param from shape tuple

* add more unit tests for shard param

* add more unit tests for sharded param
pull/394/head
Jiarui Fang 2022-03-03 12:42:57 +08:00 committed by Frank Lee
parent 7aef75ca42
commit e17e92c54d
5 changed files with 106 additions and 61 deletions

View File

@@ -1,3 +0,0 @@
-from .shard_param import ShardParam
-
-__all__ = ['ShardParam']

View File

@@ -1,4 +1,3 @@
 import functools
 from typing import Any, Optional
@@ -7,11 +6,10 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
 class ShardedModelV2(nn.Module):

-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 ):
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        reshard_after_forward: bool = True,
+        mixed_precision: bool = False,
+        fp32_reduce_scatter: bool = False,
+        offload_config: Optional[dict] = None,
+        gradient_predivide_factor: Optional[float] = 1.0,
+    ):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
         # Shard the parameters at first
         for _, param in self.module.named_parameters():
-            param.ca_attr = ShardParam(param)
+            param.ca_attr = ShardedParam(param)
             param.ca_attr.shard()
             param._sharded_grad = ShardedGradient(param, self, offload_config)
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None,
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = \
+            gradient_predivide_factor if gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
        """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        full gradient for the local batch. The reduce-scatter op will save
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::

             before reduce_scatter:
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
             orig_grad_data = new_grad.data
             if self.world_size > 1:
                 grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-                self.reducer.reduce_scatter_async(
-                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+                self.reducer.reduce_scatter_async(grad_chunks,
+                                                  group=self.reduce_scatter_process_group,
+                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
             else:
                 self._reduce_scatter_callback(param, new_grad)
             orig_grad_data.record_stream(self.comm_stream)
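
Editor's note: as a minimal, single-process sketch of the gradient flow this file describes, each rank pre-divides its full local gradient, the chunks are reduce-scattered so every rank keeps one summed shard, and the shard is post-divided so the two factors together average over world_size. The simulated ranks and the chunk_and_pad_sketch helper below are illustrative stand-ins, not the library's chunk_and_pad or reducer code.

# Single-process simulation; world_size, the fake per-rank gradients and
# chunk_and_pad_sketch are illustrative assumptions, not library code.
import torch

world_size = 2
predivide_factor = 1.0                            # the default kept by this PR
postdivide_factor = world_size / predivide_factor

# Full local gradient of a 7-element parameter on each simulated rank.
local_grads = [torch.ones(7) * (rank + 1) for rank in range(world_size)]


def chunk_and_pad_sketch(grad: torch.Tensor, num_chunks: int):
    # Split into num_chunks equal pieces, zero-padding the tail.
    chunk_size = (grad.numel() + num_chunks - 1) // num_chunks
    padded = torch.nn.functional.pad(grad, (0, chunk_size * num_chunks - grad.numel()))
    return list(padded.split(chunk_size))


# Pre-divide, "reduce-scatter" (written here as an explicit sum), post-divide.
sharded_grads = []
for rank in range(world_size):
    per_rank_chunks = [chunk_and_pad_sketch(g / predivide_factor, world_size) for g in local_grads]
    summed_shard = sum(chunks[rank] for chunks in per_rank_chunks)   # the shard this rank keeps
    sharded_grads.append(summed_shard / postdivide_factor)

# Every rank ends up with a 4-element shard of the averaged gradient.
print([shard.numel() for shard in sharded_grads], sharded_grads[0])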

View File

@@ -0,0 +1,3 @@
+from .sharded_param import ShardedParam
+
+__all__ = ['ShardedParam']

View File

@@ -1,41 +1,59 @@
-from enum import Enum
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from typing import Union, Tuple, Optional
+import numpy


-class TensorType(Enum):
-    GRAD = 1
-    DATA = 2
-
-
-class ShardParam(object):
+class ShardedParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
-    on different processes.
+    on memory space of different processes.
     """

-    def __init__(
-        self,
-        param: torch.nn.Parameter,
-        tensor_type: TensorType = TensorType.DATA,
-        process_group=None,
-    ) -> None:
+    def __init__(self,
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 is_sharded: bool = False,
+                 device: Optional[torch.device] = None) -> None:
+        r"""
+        other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
+        process_group: the process group storing the shared data.
+        is_sharded: is shared the param during __init__.
+        device: the device to place param data payload on
+        """
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
-        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
-        self._payload_shape = None
-        self._payload_numel = None
-        self._origin_shape = param.shape
-        self._origin_numel = param.numel()
-        self._origin_dtype = param.dtype
         self.is_sharded = False

+        # Hijack the data payload of param
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
+            if is_sharded:
+                self.shard()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)
+
+            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
+            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
+            self._param_payload = torch.empty(self._origin_shape, device=device)
+            if is_sharded:
+                self.shard()
+        else:
+            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
+
+        self._payload_numel = None

     def payload(self, target_device: torch.device):
+        r"""
+        get the payload and move it to target device
+        """
         return self._param_payload.to(target_device)

     def shard(self):
@@ -50,6 +68,7 @@ class ShardParam(object):
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
+        The payload has to be moved to cuda memory before communication.
         """
         if not self.is_sharded:
             return
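
Editor's note: a hedged usage sketch of the reworked ShardedParam constructor shown above, covering both accepted types for `other` (an existing torch.nn.Parameter or a shape tuple). It assumes it runs inside a process initialized by colossalai.launch, as the tests below do, since the default process group comes from gpc; the function name sharded_param_demo is purely illustrative.

import torch
from colossalai.zero.sharded_param import ShardedParam


def sharded_param_demo():
    # Path 1: wrap an existing parameter; is_sharded=True shards the payload
    # across the data-parallel group right away.
    param = torch.nn.Parameter(torch.rand(2, 3))
    sharded = ShardedParam(param, process_group=None, is_sharded=True)

    # Path 2: allocate a fresh payload from a shape tuple; a device must be
    # given here because there is no existing tensor to take it from.
    from_shape = ShardedParam((2, 3), process_group=None, is_sharded=False,
                              device=torch.device('cpu'))

    # payload() copies the (possibly sharded) data to the requested device.
    full = from_shape.payload(torch.device('cpu'))    # unsharded, so shape stays (2, 3)
    return sharded, full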

View File

@@ -1,50 +1,72 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from asyncio.log import logger
 from functools import partial

 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG


+def run_init_shard_param(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    sparam = ShardedParam(param, None, True)
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])
+
+
 def run_shard_param_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

     logger = get_dist_logger()
     model = Net()

     # add an attribute as ca_attr to hijack the access to param.data
     for _, param in model.named_parameters():
         numel_ref = (param.numel() + world_size - 1) // world_size
-        param.ca_attr = ShardParam(param)
+        param.ca_attr = ShardedParam(param)
         param.ca_attr.shard()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
-        assert(numel_ref == param_data.numel())
+        assert (numel_ref == param_data.numel())

     for _, param in model.named_parameters():
         param.ca_attr.gather()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])

     disable_existing_loggers([logger])


 @pytest.mark.dist
-def test_run_shard_shape():
+def test_shard_shape():
     world_size = 2
     run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


+@pytest.mark.dist
+def test_init_shard_param():
+    world_size = 2
+    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_run_shard_shape()
+    test_shard_shape()
+    test_init_shard_param()
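
Editor's note: as a quick sanity check on the shapes these tests assert, assuming a shard is a flattened, evenly padded chunk of ceil(numel / world_size) elements per rank (the same arithmetic as numel_ref above):

# Plain-Python arithmetic only; no distributed setup needed.
world_size = 2
origin_shape = (2, 3)

numel = 1
for dim in origin_shape:
    numel *= dim                                       # 2 * 3 == 6 elements

shard_numel = (numel + world_size - 1) // world_size   # ceil(6 / 2) == 3
assert shard_numel == 3                                # why list(payload.shape) == [3] when sharded

# With is_sharded=False the payload is never split, so it keeps shape (2, 3).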