[zero] a shard strategy in granularity of tensor (#307)

pull/394/head
Jiarui Fang 2022-03-04 11:59:35 +08:00 committed by Frank Lee
parent 80364c7686
commit 74f77e314b
4 changed files with 56 additions and 10 deletions

View File

@ -0,0 +1,4 @@
from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
__all__ = ['BaseShardStrategy', 'TensorShardStrategy']

View File

@ -0,0 +1,27 @@
from abc import ABC, abstractmethod
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
import torch.distributed as dist
from typing import List, Optional
class BaseShardStrategy(ABC):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
self.process_group = process_group
self.world_size = dist.get_world_size(self.process_group)
self.local_rank = dist.get_rank(self.process_group)
super().__init__()
@abstractmethod
def shard(self, tensor_list: List[ShardedTensor]):
r"""
sharded the memory of tensor on multiple processes.
"""
pass
@abstractmethod
def gather(self, tensor_list: List[ShardedTensor]):
r"""
duplicate tensor payload on each processes.
"""
pass

View File

@ -0,0 +1,18 @@
from colossalai.zero.shard_utils import BaseShardStrategy
import torch.distributed as dist
from typing import List, Optional
from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
class TensorShardStrategy(BaseShardStrategy):
def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
super().__init__(process_group)
def shard(self, tensor_list: List[ShardedTensor]):
for t in tensor_list:
t.shard()
def gather(self, tensor_list: List[ShardedTensor]):
for t in tensor_list:
t.gather()

View File

@ -7,6 +7,8 @@ import colossalai
import pytest
import torch
import torch.multiprocessing as mp
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
from colossalai.utils import free_port
from colossalai.logging import get_dist_logger, disable_existing_loggers
@ -18,19 +20,14 @@ def run_shard_tensor(rank, world_size, port):
t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
assert list(t.shape) == [world_size * 2, 3]
t.shard()
# The shape is flattened
assert list(t.shape) == [6]
# Do nothing
t.shard()
assert list(t.shape) == [6]
shard_strategy = TensorShardStrategy(process_group=None)
t.gather()
# test shard strategy
shard_strategy.shard([t])
assert list(t.shape) == [6]
shard_strategy.gather([t])
assert list(t.shape) == [world_size * 2, 3]
t.payload = torch.zeros(world_size * 2, 3)
assert torch.sum(t.payload).cpu() == 0
@pytest.mark.dist
def test_shard_tensor():