mirror of https://github.com/hpcaitech/ColossalAI
Polish sharded parameter (#297)
* init shard param from shape tuple * add more unitest for shard param * add more unittests to shareded parampull/394/head
parent
7aef75ca42
commit
e17e92c54d
|
@ -1,3 +0,0 @@
|
|||
from .shard_param import ShardParam
|
||||
|
||||
__all__ = ['ShardParam']
|
|
@ -1,4 +1,3 @@
|
|||
|
||||
import functools
|
||||
from typing import Any, Optional
|
||||
|
||||
|
@ -7,11 +6,10 @@ import torch.distributed as dist
|
|||
import torch.nn as nn
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
|
||||
register_ophooks_recursively)
|
||||
from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
|
||||
from colossalai.engine.paramhooks import BaseParamHookMgr
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.zero.shard_param import ShardParam
|
||||
from colossalai.zero.sharded_param import ShardedParam
|
||||
from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
|
||||
from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
|
||||
from torch.distributed import ProcessGroup
|
||||
|
@ -21,7 +19,9 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
|
|||
|
||||
|
||||
class ShardedModelV2(nn.Module):
|
||||
def __init__(self,
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
module: nn.Module,
|
||||
process_group: Optional[ProcessGroup] = None,
|
||||
reduce_scatter_process_group: Optional[ProcessGroup] = None,
|
||||
|
@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
|
|||
|
||||
# Shard the parameters at first
|
||||
for _, param in self.module.named_parameters():
|
||||
param.ca_attr = ShardParam(param)
|
||||
param.ca_attr = ShardedParam(param)
|
||||
param.ca_attr.shard()
|
||||
param._sharded_grad = ShardedGradient(param, self, offload_config)
|
||||
|
||||
|
@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
|
|||
self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
|
||||
# We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
|
||||
# So we use 1.0 as the default gradient_predivide_factor
|
||||
# However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
|
||||
self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
|
||||
# However, if you set gradient_predivide_factor to None,
|
||||
# we will set gradient_predivide_factor to a value >= 1.0 automatically
|
||||
self.gradient_predivide_factor: float = \
|
||||
gradient_predivide_factor if gradient_predivide_factor is not None else \
|
||||
get_gradient_predivide_factor(self.world_size)
|
||||
self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
|
||||
|
||||
|
@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
|
|||
def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
|
||||
"""
|
||||
At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
|
||||
full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
|
||||
full gradient for the local batch. The reduce-scatter op will save
|
||||
a single shard of the summed gradient across all
|
||||
GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
|
||||
|
||||
before reduce_scatter:
|
||||
|
@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
|
|||
orig_grad_data = new_grad.data
|
||||
if self.world_size > 1:
|
||||
grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
|
||||
self.reducer.reduce_scatter_async(
|
||||
grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
|
||||
self.reducer.reduce_scatter_async(grad_chunks,
|
||||
group=self.reduce_scatter_process_group,
|
||||
callback_fn=functools.partial(self._reduce_scatter_callback, param))
|
||||
else:
|
||||
self._reduce_scatter_callback(param, new_grad)
|
||||
orig_grad_data.record_stream(self.comm_stream)
|
||||
|
|
|
@ -0,0 +1,3 @@
|
|||
from .sharded_param import ShardedParam
|
||||
|
||||
__all__ = ['ShardedParam']
|
|
@ -1,41 +1,59 @@
|
|||
from enum import Enum
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||
from typing import Union, Tuple, Optional
|
||||
import numpy
|
||||
|
||||
|
||||
class TensorType(Enum):
|
||||
GRAD = 1
|
||||
DATA = 2
|
||||
|
||||
|
||||
class ShardParam(object):
|
||||
class ShardedParam(object):
|
||||
r"""
|
||||
A wrapper to torch.nn.Parameter. Shard a param
|
||||
on different processes.
|
||||
on memory space of different processes.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
param: torch.nn.Parameter,
|
||||
tensor_type: TensorType = TensorType.DATA,
|
||||
process_group=None,
|
||||
) -> None:
|
||||
def __init__(self,
|
||||
other: Union[torch.nn.Parameter, Tuple[int, ...]],
|
||||
process_group: Optional[dist.ProcessGroup] = None,
|
||||
is_sharded: bool = False,
|
||||
device: Optional[torch.device] = None) -> None:
|
||||
r"""
|
||||
other: either an existing torch parameter or a tuple, indicate allocate a new param with the tuple as shape.
|
||||
process_group: the process group storing the shared data.
|
||||
is_sharded: is shared the param during __init__.
|
||||
device: the device to place param data payload on
|
||||
"""
|
||||
self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
|
||||
self.world_size = dist.get_world_size(self.process_group)
|
||||
self.local_rank = dist.get_rank(self.process_group)
|
||||
self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
|
||||
self._payload_shape = None
|
||||
self._payload_numel = None
|
||||
self._origin_shape = param.shape
|
||||
self._origin_numel = param.numel()
|
||||
self._origin_dtype = param.dtype
|
||||
self.is_sharded = False
|
||||
|
||||
# Hijack the data payload of param
|
||||
if isinstance(other, torch.nn.Parameter):
|
||||
self._param_payload = other.data.to(device)
|
||||
self._origin_shape = other.shape
|
||||
self._origin_numel = other.numel()
|
||||
if is_sharded:
|
||||
self.shard()
|
||||
elif isinstance(other, tuple):
|
||||
self._origin_shape = other
|
||||
self._origin_numel = numpy.prod(other)
|
||||
|
||||
# TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
|
||||
assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
|
||||
self._param_payload = torch.empty(self._origin_shape, device=device)
|
||||
if is_sharded:
|
||||
self.shard()
|
||||
else:
|
||||
raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
|
||||
|
||||
self._payload_numel = None
|
||||
|
||||
def payload(self, target_device: torch.device):
|
||||
r"""
|
||||
get the payload and move it to target device
|
||||
"""
|
||||
return self._param_payload.to(target_device)
|
||||
|
||||
def shard(self):
|
||||
|
@ -50,6 +68,7 @@ class ShardParam(object):
|
|||
def gather(self):
|
||||
r"""
|
||||
Collect the payload of param from different processes to process of local rank.
|
||||
The payload has to be moved to cuda memory before communication.
|
||||
"""
|
||||
if not self.is_sharded:
|
||||
return
|
|
@ -1,25 +1,39 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from asyncio.log import logger
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
from colossalai.zero.shard_param import ShardParam
|
||||
from colossalai.zero.sharded_param import ShardedParam
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
||||
from tests.test_zero_data_parallel.common import Net, CONFIG
|
||||
|
||||
|
||||
def run_init_shard_param(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
param = torch.nn.Parameter(data=torch.rand(2, 3))
|
||||
sparam = ShardedParam(param, None, True)
|
||||
payload = sparam.payload(torch.device('cuda'))
|
||||
assert (list(payload.shape) == [3])
|
||||
del sparam
|
||||
|
||||
param_shape = (2, 3)
|
||||
sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
|
||||
payload = sparam.payload(torch.device('cuda'))
|
||||
assert (list(payload.shape) == [3])
|
||||
|
||||
param_shape = (2, 3)
|
||||
sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
|
||||
payload = sparam.payload(torch.device('cuda'))
|
||||
assert (list(payload.shape) == [2, 3])
|
||||
|
||||
|
||||
def run_shard_param_check(rank, world_size, port):
|
||||
colossalai.launch(config=CONFIG,
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
host='localhost',
|
||||
port=port,
|
||||
backend='nccl')
|
||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
logger = get_dist_logger()
|
||||
model = Net()
|
||||
|
@ -27,24 +41,32 @@ def run_shard_param_check(rank, world_size, port):
|
|||
# add an attribute as ca_attr to hijack the access to param.data
|
||||
for _, param in model.named_parameters():
|
||||
numel_ref = (param.numel() + world_size - 1) // world_size
|
||||
param.ca_attr = ShardParam(param)
|
||||
param.ca_attr = ShardedParam(param)
|
||||
param.ca_attr.shard()
|
||||
param_data = param.ca_attr.payload(torch.device('cpu'))
|
||||
logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
|
||||
assert (numel_ref == param_data.numel())
|
||||
|
||||
for _, param in model.named_parameters():
|
||||
param.ca_attr.gather()
|
||||
param_data = param.ca_attr.payload(torch.device('cpu'))
|
||||
logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
|
||||
|
||||
disable_existing_loggers([logger])
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_run_shard_shape():
|
||||
def test_shard_shape():
|
||||
world_size = 2
|
||||
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
def test_init_shard_param():
|
||||
world_size = 2
|
||||
run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_run_shard_shape()
|
||||
test_shard_shape()
|
||||
test_init_shard_param()
|
||||
|
|
Loading…
Reference in New Issue