Polish sharded parameter (#297)

* init shard param from shape tuple

* add more unit tests for shard param

* add more unit tests for sharded param
pull/394/head
Jiarui Fang 3 years ago committed by Frank Lee
parent 7aef75ca42
commit e17e92c54d
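The headline change is that a ShardedParam can now be built directly from a shape tuple instead of an existing parameter. The sketch below illustrates both construction paths as they appear in the constructor and tests in this diff; the (2, 3) shape, the default process group, and the CPU/CUDA devices are arbitrary choices for illustration, and an initialized colossalai.launch context is assumed.

import torch
from colossalai.zero.sharded_param import ShardedParam

# Path 1: wrap an existing parameter and shard it right away.
param = torch.nn.Parameter(torch.rand(2, 3))
sparam = ShardedParam(param, process_group=None, is_sharded=True)

# Path 2 (new in this PR): allocate a fresh payload from a shape tuple; a device is required here.
sparam_from_shape = ShardedParam((2, 3), process_group=None, is_sharded=True, device=torch.device('cpu'))

# payload() returns this rank's view of the data, moved to the requested device.
local_shard = sparam_from_shape.payload(torch.device('cuda'))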

@@ -1,3 +0,0 @@
from .shard_param import ShardParam
__all__ = ['ShardParam']

@@ -1,4 +1,3 @@
import functools
from typing import Any, Optional
@@ -7,11 +6,10 @@ import torch.distributed as dist
import torch.nn as nn
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
register_ophooks_recursively)
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
from colossalai.engine.paramhooks import BaseParamHookMgr
from colossalai.logging import get_dist_logger
from colossalai.zero.shard_param import ShardParam
from colossalai.zero.sharded_param import ShardedParam
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
class ShardedModelV2(nn.Module):
def __init__(self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
):
def __init__(
self,
module: nn.Module,
process_group: Optional[ProcessGroup] = None,
reduce_scatter_process_group: Optional[ProcessGroup] = None,
reduce_scatter_bucket_size_mb: int = 25,
reshard_after_forward: bool = True,
mixed_precision: bool = False,
fp32_reduce_scatter: bool = False,
offload_config: Optional[dict] = None,
gradient_predivide_factor: Optional[float] = 1.0,
):
r"""
A demo to reconfigure a ZeRO-1 sharded model.
Optimizer states are currently not considered.
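A rough usage sketch of the wrapper described by this docstring (not taken from the PR): the toy nn.Linear, the input shape, the import path, and the assumption that the wrapper's forward simply delegates to the wrapped module are all illustrative, and an initialized colossalai.launch context with CUDA is assumed.

import torch
import torch.nn as nn
# Import path assumed from the package layout visible in the imports above.
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

module = nn.Linear(16, 16).cuda()
sharded_model = ShardedModelV2(module, reshard_after_forward=True, mixed_precision=False)
out = sharded_model(torch.rand(4, 16).cuda())  # forward is assumed to delegate to the wrapped module
out.sum().backward()                           # gradients end up reduce-scattered into param._sharded_grad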
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
# Shard the parameters at first
for _, param in self.module.named_parameters():
param.ca_attr = ShardParam(param)
param.ca_attr = ShardedParam(param)
param.ca_attr.shard()
param._sharded_grad = ShardedGradient(param, self, offload_config)
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
# We find that if gradient_predivide_factor != 1.0, there may be a precision problem,
# so we use 1.0 as the default gradient_predivide_factor
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
# However, if you set gradient_predivide_factor to None,
# we will set gradient_predivide_factor to a value >= 1.0 automatically
self.gradient_predivide_factor: float = \
gradient_predivide_factor if gradient_predivide_factor is not None else \
get_gradient_predivide_factor(self.world_size)
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
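The arithmetic behind the two factors: gradients are divided by gradient_predivide_factor before the reduce-scatter and by gradient_postdivide_factor afterwards, so the combined divisor is always world_size and the summed gradient comes out averaged. A worked example, where the value 2.0 is only an assumption about what get_gradient_predivide_factor might return for 8 ranks:

world_size = 8
gradient_predivide_factor = 2.0                                       # assumed return value, must be >= 1.0
gradient_postdivide_factor = world_size / gradient_predivide_factor   # 4.0
# grad / 2.0, summed across 8 ranks by the reduce-scatter, then / 4.0  ->  sum(grad) / 8, i.e. the average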
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
"""
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
full gradient for the local batch. The reduce-scatter op will save
a single shard of the summed gradient across all
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
before reduce_scatter:
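The docstring's own example is truncated by the hunk boundary here; a plain-tensor illustration of the same shard alignment (the values and world_size = 2 are made up, and no distributed launch is needed to run it):

import torch

full_grad_rank0 = torch.tensor([1., 2., 3., 4.])    # local full gradient on rank 0
full_grad_rank1 = torch.tensor([5., 6., 7., 8.])    # local full gradient on rank 1
summed = full_grad_rank0 + full_grad_rank1           # what the reduce computes
shard0, shard1 = summed.chunk(2)                     # what the scatter leaves on each rank
print(shard0)  # tensor([6., 8.])   -> rank 0's param._sharded_grad
print(shard1)  # tensor([10., 12.]) -> rank 1's param._sharded_grad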
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
orig_grad_data = new_grad.data
if self.world_size > 1:
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
self.reducer.reduce_scatter_async(
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
self.reducer.reduce_scatter_async(grad_chunks,
group=self.reduce_scatter_process_group,
callback_fn=functools.partial(self._reduce_scatter_callback, param))
else:
self._reduce_scatter_callback(param, new_grad)
orig_grad_data.record_stream(self.comm_stream)
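chunk_and_pad comes from _zero3_utils (imported above). Its presumed job is to split the flattened gradient into group-size chunks and zero-pad the last one so every rank receives an equally sized piece; the helper below is a local sketch of that assumed behavior, not the library's implementation.

import torch
import torch.nn.functional as F
from typing import List

def chunk_and_pad_sketch(tensor: torch.Tensor, num_chunks: int) -> List[torch.Tensor]:
    # Presumed behavior: equal-sized chunks, with the last one zero-padded to full size.
    chunks = list(torch.flatten(tensor).chunk(num_chunks))
    if len(chunks) < num_chunks:
        chunks += [torch.zeros_like(chunks[0])] * (num_chunks - len(chunks))
    chunks[-1] = F.pad(chunks[-1], [0, chunks[0].numel() - chunks[-1].numel()])
    return chunks

print([c.numel() for c in chunk_and_pad_sketch(torch.arange(10.), 4)])  # [3, 3, 3, 3]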

@@ -0,0 +1,3 @@
from .sharded_param import ShardedParam
__all__ = ['ShardedParam']

@@ -1,41 +1,59 @@
from enum import Enum
import torch
import torch.distributed as dist
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.zero.sharded_model._zero3_utils import get_shard
from typing import Union, Tuple, Optional
import numpy
class TensorType(Enum):
GRAD = 1
DATA = 2
class ShardParam(object):
class ShardedParam(object):
r"""
A wrapper to torch.nn.Parameter. Shard a param
on different processes.
on memory space of different processes.
"""
def __init__(
self,
param: torch.nn.Parameter,
tensor_type: TensorType = TensorType.DATA,
process_group=None,
) -> None:
def __init__(self,
other: Union[torch.nn.Parameter, Tuple[int, ...]],
process_group: Optional[dist.ProcessGroup] = None,
is_sharded: bool = False,
device: Optional[torch.device] = None) -> None:
r"""
other: either an existing torch.nn.Parameter, or a shape tuple used to allocate a new param with that shape.
process_group: the process group storing the sharded data.
is_sharded: whether to shard the param during __init__.
device: the device on which to place the param data payload.
"""
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
self._payload_shape = None
self._payload_numel = None
self._origin_shape = param.shape
self._origin_numel = param.numel()
self._origin_dtype = param.dtype
self.is_sharded = False
# Hijack the data payload of param
if isinstance(other, torch.nn.Parameter):
self._param_payload = other.data.to(device)
self._origin_shape = other.shape
self._origin_numel = other.numel()
if is_sharded:
self.shard()
elif isinstance(other, tuple):
self._origin_shape = other
self._origin_numel = numpy.prod(other)
# TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
assert device is not None, "You have to assign a device to initialize a ShardedParam from a shape tuple"
self._param_payload = torch.empty(self._origin_shape, device=device)
if is_sharded:
self.shard()
else:
raise RuntimeError(f"Initialize ShardedParam failed. The 2nd parameter has wrong type {type(other)}")
self._payload_numel = None
def payload(self, target_device: torch.device):
r"""
get the payload and move it to target device
"""
return self._param_payload.to(target_device)
def shard(self):
@@ -50,6 +68,7 @@ class ShardParam(object):
def gather(self):
r"""
Collect the payload of the param from the different processes to the process of the local rank.
The payload has to be moved to cuda memory before communication.
"""
if not self.is_sharded:
return
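Putting shard() and gather() together, a round trip over the wrapper looks like the sketch below. It mirrors the test further down and assumes a two-rank colossalai.launch context, so each rank's shard holds three of the six elements:

import torch
from colossalai.zero.sharded_param import ShardedParam

param = torch.nn.Parameter(torch.rand(2, 3))
sparam = ShardedParam(param)                          # wrap without sharding yet
sparam.shard()                                        # keep only this rank's 3-element slice
print(sparam.payload(torch.device('cpu')).numel())    # 3 on each of the two ranks
sparam.gather()                                       # collect the full payload back (moved via cuda internally)
print(sparam.payload(torch.device('cpu')).numel())    # the full 6 elements should be available again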

@@ -1,50 +1,72 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from asyncio.log import logger
from functools import partial
import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_param import ShardParam
from colossalai.zero.sharded_param import ShardedParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
from tests.test_zero_data_parallel.common import Net, CONFIG
def run_init_shard_param(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
param = torch.nn.Parameter(data=torch.rand(2, 3))
sparam = ShardedParam(param, None, True)
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [3])
del sparam
param_shape = (2, 3)
sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [3])
param_shape = (2, 3)
sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [2, 3])
def run_shard_param_check(rank, world_size, port):
colossalai.launch(config=CONFIG,
rank=rank,
world_size=world_size,
host='localhost',
port=port,
backend='nccl')
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
logger = get_dist_logger()
model = Net()
# add an attribute as ca_attr to hijack the access to param.data
for _, param in model.named_parameters():
numel_ref = (param.numel() + world_size - 1) // world_size
param.ca_attr = ShardParam(param)
param.ca_attr = ShardedParam(param)
param.ca_attr.shard()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'shard {param_data.shape} {param_data}', ranks=[1])
assert(numel_ref == param_data.numel())
assert (numel_ref == param_data.numel())
for _, param in model.named_parameters():
param.ca_attr.gather()
param_data = param.ca_attr.payload(torch.device('cpu'))
logger.info(f'gather {param_data.shape} {param_data}', ranks=[1])
disable_existing_loggers([logger])
@pytest.mark.dist
def test_run_shard_shape():
def test_shard_shape():
world_size = 2
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
@pytest.mark.dist
def test_init_shard_param():
world_size = 2
run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
if __name__ == '__main__':
test_run_shard_shape()
test_shard_shape()
test_init_shard_param()
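The numel_ref computed in run_shard_param_check is the usual ceiling division for shard sizes; for a (2, 3) parameter and world_size = 2 it reproduces the 3-element shards asserted in run_init_shard_param:

param_numel, world_size = 6, 2
numel_ref = (param_numel + world_size - 1) // world_size  # ceil(6 / 2) = 3 elements per rank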
