mirror of https://github.com/hpcaitech/ColossalAI
Polish sharded parameter (#297)
* init shard param from shape tuple
* add more unit tests for shard param
* add more unit tests for sharded param

parent 7aef75ca42
commit e17e92c54d
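For orientation, here is a minimal usage sketch of the reworked ShardedParam constructor that this PR introduces (wrapping an existing torch.nn.Parameter versus allocating a fresh payload from a shape tuple). It mirrors the unit tests added at the bottom of this diff and assumes a distributed context has already been set up via colossalai.launch so the default DATA process group exists; it is illustrative, not part of the commit itself.

# Illustrative sketch only; mirrors the unit tests added in this PR.
# Assumes colossalai.launch(...) has been called so that the default
# DATA process group and torch.distributed are initialized.
import torch
from colossalai.zero.sharded_param import ShardedParam

# Path 1: wrap an existing parameter and shard it right away.
param = torch.nn.Parameter(torch.rand(2, 3))
sparam = ShardedParam(param, process_group=None, is_sharded=True)

# Path 2: allocate a new payload directly from a shape tuple.
# A device is required in this case (see the assert in __init__).
sparam_from_shape = ShardedParam((2, 3),
                                 process_group=None,
                                 is_sharded=False,
                                 device=torch.device('cpu'))

# payload() returns the (possibly sharded) data moved to the target device.
full_payload = sparam_from_shape.payload(torch.device('cpu'))    # shape (2, 3), since not sharded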
@@ -1,3 +0,0 @@
-from .shard_param import ShardParam
-
-__all__ = ['ShardParam']
@@ -1,4 +1,3 @@
-
 import functools
 from typing import Any, Optional
 
@@ -7,11 +6,10 @@ import torch.distributed as dist
 import torch.nn as nn
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
-from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook,
-                                       register_ophooks_recursively)
+from colossalai.engine.ophooks import (ShardGradHook, ShardParamHook, register_ophooks_recursively)
 from colossalai.engine.paramhooks import BaseParamHookMgr
 from colossalai.logging import get_dist_logger
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.zero.sharded_model.reduce_scatter import ReduceScatterBucketer
 from colossalai.zero.sharded_model.sharded_grad import ShardedGradient
 from torch.distributed import ProcessGroup
@@ -21,17 +19,19 @@ from ._zero3_utils import chunk_and_pad, get_gradient_predivide_factor
 
 
 class ShardedModelV2(nn.Module):
-    def __init__(self,
-                 module: nn.Module,
-                 process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_process_group: Optional[ProcessGroup] = None,
-                 reduce_scatter_bucket_size_mb: int = 25,
-                 reshard_after_forward: bool = True,
-                 mixed_precision: bool = False,
-                 fp32_reduce_scatter: bool = False,
-                 offload_config: Optional[dict] = None,
-                 gradient_predivide_factor: Optional[float] = 1.0,
-                 ):
+
+    def __init__(
+        self,
+        module: nn.Module,
+        process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_process_group: Optional[ProcessGroup] = None,
+        reduce_scatter_bucket_size_mb: int = 25,
+        reshard_after_forward: bool = True,
+        mixed_precision: bool = False,
+        fp32_reduce_scatter: bool = False,
+        offload_config: Optional[dict] = None,
+        gradient_predivide_factor: Optional[float] = 1.0,
+    ):
         r"""
         A demo to reconfigure zero1 shared_model.
         Currently do not consider the Optimizer States.
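As a quick orientation for the reshaped signature above, a hedged sketch of constructing ShardedModelV2 with these keyword arguments. The import path is an assumption based on this PR's package layout, a plain nn.Linear stands in for whatever module a user would actually wrap, and a distributed context (colossalai.launch) is assumed to be initialized, since __init__ shards every parameter.

# Sketch only; keyword arguments follow the signature shown above.
# The import path is assumed from this PR's package layout, and
# colossalai.launch(...) is assumed to have initialized the process groups.
import torch.nn as nn
from colossalai.zero.sharded_model.sharded_model_v2 import ShardedModelV2

model = nn.Linear(16, 16)    # stand-in for a real model
sharded_model = ShardedModelV2(model,
                               reduce_scatter_bucket_size_mb=25,
                               reshard_after_forward=True,
                               mixed_precision=False,
                               fp32_reduce_scatter=False,
                               offload_config=None,
                               gradient_predivide_factor=1.0)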
@@ -49,7 +49,7 @@ class ShardedModelV2(nn.Module):
 
         # Shard the parameters at first
         for _, param in self.module.named_parameters():
-            param.ca_attr = ShardParam(param)
+            param.ca_attr = ShardedParam(param)
             param.ca_attr.shard()
             param._sharded_grad = ShardedGradient(param, self, offload_config)
 
@@ -64,8 +64,10 @@ class ShardedModelV2(nn.Module):
         self._cpu_offload: bool = offload_config.get('device', None) == 'cpu' if offload_config else False
         # We find if gradient_predivide_factor != 1.0, there may be wrong precision problem
         # So we use 1.0 as the default gradient_predivide_factor
-        # However, if you set gradient_predivide_factor to None, we will set gradient_predivide_factor to a value >= 1.0 automatically
-        self.gradient_predivide_factor: float = gradient_predivide_factor if gradient_predivide_factor is not None else \
+        # However, if you set gradient_predivide_factor to None,
+        # we will set gradient_predivide_factor to a value >= 1.0 automatically
+        self.gradient_predivide_factor: float = \
+            gradient_predivide_factor if gradient_predivide_factor is not None else \
             get_gradient_predivide_factor(self.world_size)
         self.gradient_postdivide_factor: float = self.world_size / self.gradient_predivide_factor
 
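The two factors above compose into a plain average: gradients are divided by gradient_predivide_factor before the reduce-scatter and by gradient_postdivide_factor afterwards, and since postdivide = world_size / predivide the combined division always equals world_size. A tiny illustrative check of that arithmetic (the actual scaling happens elsewhere in ShardedModelV2; this helper exists only for this note):

# Illustrative only: predivide * postdivide == world_size, so the
# two-step scaling is equivalent to averaging the summed gradient.
def scaled_grad(grad_sum: float, world_size: int, gradient_predivide_factor: float) -> float:
    gradient_postdivide_factor = world_size / gradient_predivide_factor
    pre = grad_sum / gradient_predivide_factor       # applied before reduce-scatter
    return pre / gradient_postdivide_factor          # applied after reduce-scatter

# Whatever predivide factor is chosen, the result is the mean over 4 ranks:
assert scaled_grad(8.0, world_size=4, gradient_predivide_factor=2.0) == 2.0
assert scaled_grad(8.0, world_size=4, gradient_predivide_factor=1.0) == 2.0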
@@ -107,7 +109,8 @@ class ShardedModelV2(nn.Module):
     def _grad_post_backward_hook(self, param: Parameter, grad: torch.Tensor) -> Optional[torch.Tensor]:
         """
         At the start of :func:`_grad_post_backward_hook`, ``param.grad`` contains the
-        full gradient for the local batch. The reduce-scatter op will save a single shard of the summed gradient across all
+        full gradient for the local batch. The reduce-scatter op will save
+        a single shard of the summed gradient across all
         GPUs to param._sharded_grad. This shard will align with the current GPU rank. For example::
 
         before reduce_scatter:
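The reflowed docstring above points at a "For example" block (outside this hunk) showing the before/after layout of reduce-scatter. For readers of the diff, a standalone sketch of the same idea in plain torch.distributed; it assumes an initialized process group and a gradient whose length divides evenly across ranks (the real code pads with chunk_and_pad):

# Sketch of the reduce-scatter layout described in the docstring above.
# Assumes torch.distributed is already initialized (e.g. world_size == 2).
import torch
import torch.distributed as dist

def reduce_scatter_shard(full_grad: torch.Tensor) -> torch.Tensor:
    # Every rank holds the full gradient of its local batch; afterwards,
    # rank i keeps only chunk i of the gradient summed over all ranks.
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    chunks = list(full_grad.chunk(world_size))    # padding omitted for brevity
    shard = torch.empty_like(chunks[rank])
    dist.reduce_scatter(shard, chunks)            # sums chunk i across ranks into rank i
    return shard                                  # the shard aligned with this rank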
@@ -139,8 +142,9 @@ class ShardedModelV2(nn.Module):
            orig_grad_data = new_grad.data
            if self.world_size > 1:
                grad_chunks = chunk_and_pad(orig_grad_data, self.reduce_scatter_process_group.size())
-                self.reducer.reduce_scatter_async(
-                    grad_chunks, group=self.reduce_scatter_process_group, callback_fn=functools.partial(self._reduce_scatter_callback, param))
+                self.reducer.reduce_scatter_async(grad_chunks,
+                                                  group=self.reduce_scatter_process_group,
+                                                  callback_fn=functools.partial(self._reduce_scatter_callback, param))
            else:
                self._reduce_scatter_callback(param, new_grad)
            orig_grad_data.record_stream(self.comm_stream)
@@ -0,0 +1,3 @@
+from .sharded_param import ShardedParam
+
+__all__ = ['ShardedParam']
@@ -1,41 +1,59 @@
-from enum import Enum
-
 import torch
 import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from typing import Union, Tuple, Optional
+import numpy
 
 
-class TensorType(Enum):
-    GRAD = 1
-    DATA = 2
-
-
-class ShardParam(object):
+class ShardedParam(object):
     r"""
     A wrapper to torch.nn.Parameter. Shard a param
-    on different processes.
+    on memory space of different processes.
     """
 
-    def __init__(
-        self,
-        param: torch.nn.Parameter,
-        tensor_type: TensorType = TensorType.DATA,
-        process_group=None,
-    ) -> None:
+    def __init__(self,
+                 other: Union[torch.nn.Parameter, Tuple[int, ...]],
+                 process_group: Optional[dist.ProcessGroup] = None,
+                 is_sharded: bool = False,
+                 device: Optional[torch.device] = None) -> None:
+        r"""
+        other: either an existing torch parameter or a shape tuple; a tuple means allocating a new param with that shape.
+        process_group: the process group storing the sharded data.
+        is_sharded: whether to shard the param during __init__.
+        device: the device on which to place the param data payload.
+        """
         self.process_group = process_group or gpc.get_group(ParallelMode.DATA)
         self.world_size = dist.get_world_size(self.process_group)
         self.local_rank = dist.get_rank(self.process_group)
-        self._param_payload = param.data if tensor_type == TensorType.DATA else param.grad
-        self._payload_shape = None
-        self._payload_numel = None
-        self._origin_shape = param.shape
-        self._origin_numel = param.numel()
-        self._origin_dtype = param.dtype
         self.is_sharded = False
 
+        # Hijack the data payload of param
+        if isinstance(other, torch.nn.Parameter):
+            self._param_payload = other.data.to(device)
+            self._origin_shape = other.shape
+            self._origin_numel = other.numel()
+            if is_sharded:
+                self.shard()
+        elif isinstance(other, tuple):
+            self._origin_shape = other
+            self._origin_numel = numpy.prod(other)
+
+            # TODO(jiaruifang) can be optimized. Directly allocate payload as the sharded shape.
+            assert device is not None, "You have to assign a device to initialize a ShardParam from a shape tuple"
+            self._param_payload = torch.empty(self._origin_shape, device=device)
+            if is_sharded:
+                self.shard()
+        else:
+            raise RuntimeError(f"Initialize ShardParam failed. The 2nd parameter is wrong type {type(other)}")
+
+        self._payload_numel = None
+
     def payload(self, target_device: torch.device):
+        r"""
+        Get the payload and move it to the target device.
+        """
         return self._param_payload.to(target_device)
 
     def shard(self):
@@ -50,6 +68,7 @@ class ShardParam(object):
     def gather(self):
         r"""
         Collect the payload of param from different processes to process of local rank.
+        The payload has to be moved to cuda memory before communication.
         """
         if not self.is_sharded:
             return
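The tests added below assert that each rank's shard holds ceil(numel / world_size) elements (the numel_ref line). A tiny worked example of that sizing rule, independent of the distributed runtime; the padding remark is an assumption about how the zero3 utilities (get_shard, chunk_and_pad) keep shards equal-sized:

# Worked example of the shard sizing asserted by numel_ref below:
# ceil-division of the element count over the number of ranks.
def shard_numel(total_numel: int, world_size: int) -> int:
    return (total_numel + world_size - 1) // world_size    # same formula as numel_ref

assert shard_numel(2 * 3, world_size=2) == 3    # the (2, 3) parameter used in the tests
assert shard_numel(7, world_size=2) == 4        # uneven sizes are padded to equal shards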
@@ -1,50 +1,72 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
-from asyncio.log import logger
 from functools import partial
 
 import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-from colossalai.zero.shard_param import ShardParam
+from colossalai.zero.sharded_param import ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
 from tests.test_zero_data_parallel.common import Net, CONFIG
 
 
+def run_init_shard_param(rank, world_size, port):
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+    param = torch.nn.Parameter(data=torch.rand(2, 3))
+    sparam = ShardedParam(param, None, True)
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+    del sparam
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [3])
+
+    param_shape = (2, 3)
+    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
+    payload = sparam.payload(torch.device('cuda'))
+    assert (list(payload.shape) == [2, 3])
+
+
 def run_shard_param_check(rank, world_size, port):
-    colossalai.launch(config=CONFIG,
-                      rank=rank,
-                      world_size=world_size,
-                      host='localhost',
-                      port=port,
-                      backend='nccl')
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
 
     logger = get_dist_logger()
     model = Net()
 
     # add an attribute as ca_attr to hijack the access to param.data
     for _, param in model.named_parameters():
         numel_ref = (param.numel() + world_size - 1) // world_size
-        param.ca_attr = ShardParam(param)
+        param.ca_attr = ShardedParam(param)
         param.ca_attr.shard()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'shard {param_data.shape} {param_data}', ranks = [1])
-        assert(numel_ref == param_data.numel())
+        assert (numel_ref == param_data.numel())
 
     for _, param in model.named_parameters():
         param.ca_attr.gather()
         param_data = param.ca_attr.payload(torch.device('cpu'))
-        logger.info(f'gather {param_data.shape} {param_data}', ranks = [1])
 
     disable_existing_loggers([logger])
 
 
 @pytest.mark.dist
-def test_run_shard_shape():
+def test_shard_shape():
     world_size = 2
     run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
 
+@pytest.mark.dist
+def test_init_shard_param():
+    world_size = 2
+    run_func = partial(run_init_shard_param, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
 if __name__ == '__main__':
-    test_run_shard_shape()
+    test_shard_shape()
+    test_init_shard_param()