mirror of https://github.com/hpcaitech/ColossalAI
[zero] sharded tensor (#305)
* init shard param from shape tuple * add more unitest for shard param * add set_payload method for ShardedParam * [zero] add shareded tensor class * polish codepull/394/head
parent
d344689274
commit
80364c7686
|
@ -1,3 +1,4 @@
|
||||||
from .sharded_param import ShardedParam
|
from colossalai.zero.sharded_param.sharded_param import ShardedParam
|
||||||
|
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
|
||||||
|
|
||||||
__all__ = ['ShardedParam']
|
__all__ = ['ShardedParam', 'ShardedTensor']
|
||||||
|
|
|
@ -56,6 +56,13 @@ class ShardedParam(object):
|
||||||
"""
|
"""
|
||||||
return self._param_payload.to(target_device)
|
return self._param_payload.to(target_device)
|
||||||
|
|
||||||
|
def set_payload(self, data: torch.Tensor):
|
||||||
|
r"""
|
||||||
|
set payload as data
|
||||||
|
"""
|
||||||
|
assert self._param_payload.numel() == data.numel()
|
||||||
|
self._param_payload.copy_(data)
|
||||||
|
|
||||||
def shard(self):
|
def shard(self):
|
||||||
r"""
|
r"""
|
||||||
Distributed the payload of param to all processes.
|
Distributed the payload of param to all processes.
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
|
from colossalai.zero.sharded_model._zero3_utils import get_shard
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
class ShardedTensor(object):
|
||||||
|
|
||||||
|
def __init__(self, tensor: torch.Tensor, process_group: Optional[dist.ProcessGroup] = None) -> None:
|
||||||
|
r"""
|
||||||
|
A tensor sharded in multiple processes.
|
||||||
|
"""
|
||||||
|
self._payload = tensor
|
||||||
|
self.process_group = process_group
|
||||||
|
self.world_size = dist.get_world_size(self.process_group)
|
||||||
|
self.local_rank = dist.get_rank(self.process_group)
|
||||||
|
self._is_sharded = False
|
||||||
|
self._payload = tensor
|
||||||
|
|
||||||
|
self._origin_shape = tensor.shape
|
||||||
|
self._origin_numel = tensor.numel()
|
||||||
|
self._origin_dtype = tensor.dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_sharded(self):
|
||||||
|
return self._is_sharded
|
||||||
|
|
||||||
|
@property
|
||||||
|
def payload(self):
|
||||||
|
return self._payload
|
||||||
|
|
||||||
|
@payload.setter
|
||||||
|
def payload(self, tensor):
|
||||||
|
self._payload.copy_(tensor)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def dtype(self):
|
||||||
|
return self._origin_dtype
|
||||||
|
|
||||||
|
@property
|
||||||
|
def shape(self):
|
||||||
|
return self._payload.shape
|
||||||
|
|
||||||
|
def shard(self):
|
||||||
|
if self._is_sharded:
|
||||||
|
return
|
||||||
|
self._payload, _ = get_shard(self._payload, self.local_rank, self.world_size)
|
||||||
|
self._is_sharded = True
|
||||||
|
|
||||||
|
def gather(self):
|
||||||
|
if not self._is_sharded:
|
||||||
|
return
|
||||||
|
|
||||||
|
buffer_list = []
|
||||||
|
payload_numel = self._payload.numel()
|
||||||
|
for i in range(self.world_size):
|
||||||
|
if i == self.local_rank:
|
||||||
|
buffer_list.append(self._payload.cuda())
|
||||||
|
else:
|
||||||
|
buffer_list.append(torch.zeros(payload_numel).cuda())
|
||||||
|
|
||||||
|
torch.distributed.all_gather(buffer_list,
|
||||||
|
buffer_list[self.local_rank],
|
||||||
|
group=self.process_group,
|
||||||
|
async_op=False)
|
||||||
|
self._payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
|
||||||
|
self._is_sharded = False
|
|
@ -7,12 +7,38 @@ import colossalai
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
import torch.multiprocessing as mp
|
import torch.multiprocessing as mp
|
||||||
from colossalai.zero.sharded_param import ShardedParam
|
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
from colossalai.logging import get_dist_logger, disable_existing_loggers
|
||||||
from tests.test_zero_data_parallel.common import Net, CONFIG
|
from tests.test_zero_data_parallel.common import Net, CONFIG
|
||||||
|
|
||||||
|
|
||||||
|
def run_shard_tensor(rank, world_size, port):
|
||||||
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
|
||||||
|
|
||||||
|
assert list(t.shape) == [world_size * 2, 3]
|
||||||
|
t.shard()
|
||||||
|
# The shape is flattened
|
||||||
|
assert list(t.shape) == [6]
|
||||||
|
# Do nothing
|
||||||
|
t.shard()
|
||||||
|
assert list(t.shape) == [6]
|
||||||
|
|
||||||
|
t.gather()
|
||||||
|
assert list(t.shape) == [world_size * 2, 3]
|
||||||
|
|
||||||
|
t.payload = torch.zeros(world_size * 2, 3)
|
||||||
|
assert torch.sum(t.payload).cpu() == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
def test_shard_tensor():
|
||||||
|
world_size = 2
|
||||||
|
run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
def run_init_shard_param(rank, world_size, port):
|
def run_init_shard_param(rank, world_size, port):
|
||||||
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
param = torch.nn.Parameter(data=torch.rand(2, 3))
|
param = torch.nn.Parameter(data=torch.rand(2, 3))
|
||||||
|
@ -68,5 +94,6 @@ def test_init_shard_param():
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
|
test_shard_tensor()
|
||||||
test_shard_shape()
|
test_shard_shape()
|
||||||
test_init_shard_param()
|
test_init_shard_param()
|
||||||
|
|
Loading…
Reference in New Issue