mirror of https://github.com/hpcaitech/ColossalAI
[zero] polish shard strategy (#310)
* init shard param from shape tuple
* add more unit tests for shard param
* add set_payload method for ShardedParam
* [zero] add sharded tensor class
* polish code
* add shard strategy
* move shard and gather logic from the sharded tensor to the shard strategy
* polish code

pull/394/head
parent 3092317b80
commit c9e7d9582d
@@ -1,7 +1,11 @@
-from colossalai.zero.shard_utils import BaseShardStrategy
+import torch
 import torch.distributed as dist
+
 from typing import List, Optional
+
+from colossalai.zero.shard_utils import BaseShardStrategy
 from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+from colossalai.zero.sharded_model._zero3_utils import get_shard


 class TensorShardStrategy(BaseShardStrategy):
@@ -11,8 +15,35 @@ class TensorShardStrategy(BaseShardStrategy):

     def shard(self, tensor_list: List[ShardedTensor]):
         for t in tensor_list:
-            t.shard()
+            self._shard_tensor(t)

     def gather(self, tensor_list: List[ShardedTensor]):
         for t in tensor_list:
-            t.gather()
+            self._gather_tensor(t)
+
+    def _shard_tensor(self, t: ShardedTensor):
+        if t.is_sharded:
+            return
+        sharded_payload, _ = get_shard(t.payload, self.local_rank, self.world_size)
+        t.reset_payload(sharded_payload)
+        t.is_sharded = True
+
+    def _gather_tensor(self, t: ShardedTensor):
+        if not t.is_sharded:
+            return
+
+        buffer_list = []
+        payload_numel = t.payload.numel()
+        for i in range(self.world_size):
+            if i == self.local_rank:
+                buffer_list.append(t.payload.cuda())
+            else:
+                buffer_list.append(torch.zeros(payload_numel).cuda())
+
+        torch.distributed.all_gather(buffer_list,
+                                     buffer_list[self.local_rank],
+                                     group=self.process_group,
+                                     async_op=False)
+        gathered_payload = torch.narrow(torch.cat(buffer_list), 0, 0, t.origin_numel).reshape(t.origin_shape)
+        t.reset_payload(gathered_payload)
+        t.is_sharded = False
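The new _gather_tensor above follows a standard pattern: each rank contributes its flattened shard to an all_gather, the per-rank buffers are concatenated, and the result is narrowed back to the original element count and reshaped to the original shape. Below is a minimal, self-contained sketch of the same shard/gather round trip using plain torch.distributed on CPU with the gloo backend; the helpers shard_flat and gather_flat, the padding scheme, and the hard-coded port are illustrative assumptions, not ColossalAI's get_shard implementation.

import os
import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def shard_flat(tensor: torch.Tensor, rank: int, world_size: int):
    """Illustrative stand-in for get_shard: flatten, pad to a multiple of
    world_size, and keep only this rank's chunk."""
    flat = tensor.flatten()
    chunk_numel = (flat.numel() + world_size - 1) // world_size
    padded = torch.cat([flat, flat.new_zeros(chunk_numel * world_size - flat.numel())])
    return padded[rank * chunk_numel:(rank + 1) * chunk_numel].clone()


def gather_flat(shard: torch.Tensor, origin_numel: int, origin_shape, world_size: int):
    """Mirror of the all_gather -> cat -> narrow -> reshape pattern above."""
    buffer_list = [torch.zeros_like(shard) for _ in range(world_size)]
    dist.all_gather(buffer_list, shard)
    return torch.narrow(torch.cat(buffer_list), 0, 0, origin_numel).reshape(origin_shape)


def run(rank: int, world_size: int):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29511'   # assumed free port for this demo
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    full = torch.randn(world_size * 2, 3)
    dist.broadcast(full, src=0)           # make every rank shard the same tensor

    shard = shard_flat(full, rank, world_size)
    restored = gather_flat(shard, full.numel(), full.shape, world_size)
    assert torch.equal(restored, full)
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(run, args=(2,), nprocs=2)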
@@ -1,6 +1,5 @@
 import torch
 import torch.distributed as dist
-from colossalai.zero.sharded_model._zero3_utils import get_shard
 from typing import Optional


@@ -21,47 +20,38 @@ class ShardedTensor(object):
         self._origin_numel = tensor.numel()
         self._origin_dtype = tensor.dtype

+    @property
+    def origin_numel(self):
+        return self._origin_numel
+
+    @property
+    def origin_shape(self):
+        return self._origin_shape
+
     @property
     def is_sharded(self):
         return self._is_sharded

+    @is_sharded.setter
+    def is_sharded(self, flag: bool):
+        self._is_sharded = flag
+
     @property
     def payload(self):
         return self._payload

-    @payload.setter
-    def payload(self, tensor):
+    def copy_payload(self, tensor):
         self._payload.copy_(tensor)

+    def reset_payload(self, tensor):
+        del self._payload
+        self._payload = tensor
+
     @property
     def dtype(self):
+        assert self._payload.dtype == self._origin_dtype
         return self._origin_dtype

     @property
     def shape(self):
         return self._payload.shape
-
-    def shard(self):
-        if self._is_sharded:
-            return
-        self._payload, _ = get_shard(self._payload, self.local_rank, self.world_size)
-        self._is_sharded = True
-
-    def gather(self):
-        if not self._is_sharded:
-            return
-
-        buffer_list = []
-        payload_numel = self._payload.numel()
-        for i in range(self.world_size):
-            if i == self.local_rank:
-                buffer_list.append(self._payload.cuda())
-            else:
-                buffer_list.append(torch.zeros(payload_numel).cuda())
-
-        torch.distributed.all_gather(buffer_list,
-                                     buffer_list[self.local_rank],
-                                     group=self.process_group,
-                                     async_op=False)
-        self._payload = torch.narrow(torch.cat(buffer_list), 0, 0, self._origin_numel).view(self._origin_shape)
-        self._is_sharded = False
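The ShardedTensor hunk above separates two payload operations: copy_payload writes new values into the existing storage (the element count must match), while reset_payload drops the old reference and rebinds _payload to a different tensor, which is what the shard strategy needs once the payload changes shape. A small illustration of the two behaviours with plain tensors (not ColossalAI code), assuming a (4, 3) tensor sharded across two ranks:

import torch

# In-place copy keeps the original storage, so source and destination
# must have the same number of elements -- fine for refreshing values.
payload = torch.zeros(4, 3)
payload.copy_(torch.randn(4, 3))      # what copy_payload does

# After sharding, the payload changes shape (e.g. (4, 3) -> (6,)), so the
# old storage cannot simply be overwritten; the attribute has to be rebound
# to the new tensor -- which is what reset_payload does.
sharded = torch.randn(6)
del payload                           # drop the reference to the old storage
payload = sharded                     # rebind, as reset_payload(sharded) would
assert payload.shape == torch.Size([6])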
@@ -7,7 +7,6 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
-
 from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
@@ -18,15 +17,16 @@ from tests.test_zero_data_parallel.common import Net, CONFIG
 def run_shard_tensor(rank, world_size, port):
     colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
+    assert list(t.origin_shape) == [world_size * 2, 3]
     assert list(t.shape) == [world_size * 2, 3]

     shard_strategy = TensorShardStrategy(process_group=None)

     # test shard strategy
     shard_strategy.shard([t])
-    assert list(t.shape) == [6]
+    assert list(t.shape) == [6], f"{list(t.shape)} vs 6"
     shard_strategy.gather([t])
-    assert list(t.shape) == [world_size * 2, 3]
+    assert list(t.shape) == [world_size * 2, 3], f"{list(t.shape)} vs {[world_size * 2, 3]}"


 @pytest.mark.dist
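The visible part of the test ends at the @pytest.mark.dist decorator; the driver that actually spawns run_shard_tensor lies outside the hunk. A sketch of how such a test is typically launched in the same module, one process per rank via mp.spawn and a free port; the function name test_shard_tensor and the use of functools.partial are assumptions, not taken from the diff:

from functools import partial

import pytest
import torch.multiprocessing as mp

from colossalai.utils import free_port


@pytest.mark.dist
def test_shard_tensor():
    world_size = 2
    # One process per rank; each process runs run_shard_tensor(rank, world_size, port)
    run_func = partial(run_shard_tensor, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_shard_tensor()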