[zero] yet an improved sharded param (#311)

pull/394/head
Jiarui Fang 2022-03-04 15:49:23 +08:00 committed by Frank Lee
parent c9e7d9582d
commit 90d3aef62c
3 changed files with 82 additions and 21 deletions

colossalai/zero/sharded_param/__init__.py

@@ -1,4 +1,4 @@
-from colossalai.zero.sharded_param.sharded_param import ShardedParam
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_param.sharded_param import ShardedParam, ShardedParamV2
 
-__all__ = ['ShardedParam', 'ShardedTensor']
+__all__ = ['ShardedParam', 'ShardedTensor', 'ShardedParamV2']
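
With this change, ShardedParamV2 is re-exported from the package root alongside the existing classes; a trivial check of the new import surface (illustrative only, not part of the commit):

# Sanity check of the public surface added by this file.
from colossalai.zero.sharded_param import ShardedParam, ShardedParamV2, ShardedTensor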

colossalai/zero/sharded_param/sharded_param.py

@@ -6,6 +6,40 @@ import torch.distributed as dist
 from colossalai.context.parallel_mode import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.zero.sharded_model._zero3_utils import get_shard
+from colossalai.zero.sharded_param import ShardedTensor
 from typing import Union, Tuple, Optional
 import numpy
+
+
+class ShardedParamV2(object):
+
+    def __init__(self, param: torch.nn.Parameter, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        self._data_sharded_tensor = ShardedTensor(param.data, process_group)
+        if param.requires_grad and param.grad is not None:
+            self._grad_sharded_tensor = ShardedTensor(param.grad, process_group)
+            param.grad = None
+        else:
+            self._grad_sharded_tensor = None
+
+        # make sure the shared param is the only owner of payload
+        param.data = torch.empty([], dtype=param.dtype, device=param.device)
+
+    @property
+    def data(self):
+        return self._data_sharded_tensor.payload
+
+    @data.setter
+    def data(self, t: torch.Tensor):
+        self._data_sharded_tensor.payload = t
+
+    @property
+    def grad(self):
+        return self._grad_sharded_tensor.payload
+
+    @grad.setter
+    def grad(self, t: torch.Tensor):
+        self._grad_sharded_tensor.payload = t
+
+
 class ShardedParam(object):
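
For orientation, a minimal usage sketch of the ShardedParamV2 wrapper added above (not part of this commit). It assumes ShardedTensor simply holds the wrapped tensor as its payload until a shard strategy is applied; that behaviour is inferred from this diff rather than verified here:

import torch
from colossalai.zero.sharded_param import ShardedParamV2

param = torch.nn.Parameter(torch.randn(2, 3))
sparam = ShardedParamV2(param=param, process_group=None)

# The wrapper takes ownership of the payload; the original parameter keeps only
# a dimensionless placeholder tensor, hence numel() == 1 afterwards.
assert param.data.numel() == 1
print(sparam.data.shape)    # expected: torch.Size([2, 3])

# param.grad was None at wrap time, so no gradient tensor is wrapped here
# (sparam._grad_sharded_tensor is None in this case).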

tests/test_zero_data_parallel/test_shard_param.py

@@ -1,9 +1,11 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+from copy import deepcopy
 from functools import partial
 
 import colossalai
+from colossalai.zero.sharded_param.sharded_param import ShardedParamV2
 import pytest
 import torch
 import torch.multiprocessing as mp
@@ -11,7 +13,7 @@ from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
-from tests.test_zero_data_parallel.common import Net, CONFIG
+from tests.test_zero_data_parallel.common import Net, CONFIG, allclose
 
 
 def run_shard_tensor(rank, world_size, port):
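
The allclose helper newly imported above lives in the tests' common module and is not shown in this diff; a rough sketch of what such a helper is assumed to do (an illustrative assumption, not the actual implementation):

import torch

def allclose(tensor_a: torch.Tensor, tensor_b: torch.Tensor) -> None:
    # Assumed behaviour: bring both tensors onto a common device/dtype and
    # assert element-wise closeness, failing the test if they diverge.
    assert torch.allclose(tensor_a.detach().cpu().float(), tensor_b.detach().cpu().float())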
@@ -36,28 +38,33 @@ def test_shard_tensor():
     mp.spawn(run_func, nprocs=world_size)
 
 
-def run_init_shard_param(rank, world_size, port):
+def _run_shard_param_v2(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    param = torch.nn.Parameter(data=torch.rand(2, 3))
-    sparam = ShardedParam(param, None, True)
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-    del sparam
-
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [3])
-
-    param_shape = (2, 3)
-    sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
-    payload = sparam.payload(torch.device('cuda'))
-    assert (list(payload.shape) == [2, 3])
+
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
+
+    allclose(sparam.data, param_ref.data)
+    assert (param.data.numel() == 1)
 
 
-def run_shard_param_check(rank, world_size, port):
+@pytest.mark.dist
+def test_shard_param_v2():
+    world_size = 2
+    run_func = partial(_run_shard_param_v2, world_size=world_size, port=free_port())
+    mp.spawn(run_func, nprocs=world_size)
+
+
+def _run_test_shard_param(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    param = torch.nn.Parameter(torch.randn(2, 3))
+    param_ref = deepcopy(param)
+    sparam = ShardedParamV2(param=param, process_group=None)
+    print(sparam.data)
+    print(param_ref.data)
 
     logger = get_dist_logger()
     model = Net()
@ -77,12 +84,31 @@ def run_shard_param_check(rank, world_size, port):
@pytest.mark.dist
def test_shard_shape():
def test_shard_param():
world_size = 2
run_func = partial(run_shard_param_check, world_size=world_size, port=free_port())
run_func = partial(_run_test_shard_param, world_size=world_size, port=free_port())
mp.spawn(run_func, nprocs=world_size)
def run_init_shard_param(rank, world_size, port):
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
param = torch.nn.Parameter(data=torch.rand(2, 3))
sparam = ShardedParam(param, None, True)
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [3])
del sparam
param_shape = (2, 3)
sparam = ShardedParam(param_shape, process_group=None, is_sharded=True, device=torch.device('cpu'))
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [3])
param_shape = (2, 3)
sparam = ShardedParam(param_shape, process_group=None, is_sharded=False, device=torch.device('cpu'))
payload = sparam.payload(torch.device('cuda'))
assert (list(payload.shape) == [2, 3])
@pytest.mark.dist
def test_init_shard_param():
world_size = 2
@@ -92,5 +118,6 @@ def test_init_shard_param():
 if __name__ == '__main__':
     test_shard_tensor()
-    test_shard_shape()
+    test_shard_param()
+    test_shard_param_v2()
     test_init_shard_param()