From 74f77e314baaab202f155263dbd13353bc89a04f Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Fri, 4 Mar 2022 11:59:35 +0800
Subject: [PATCH] [zero] a shard strategy in granularity of tensor (#307)

---
 colossalai/zero/shard_utils/__init__.py        |  4 +++
 .../zero/shard_utils/base_shard_strategy.py    | 27 +++++++++++++++++++
 .../zero/shard_utils/tensor_shard_strategy.py  | 18 +++++++++++++
 .../test_shard_param.py                        | 17 +++++-------
 4 files changed, 56 insertions(+), 10 deletions(-)
 create mode 100644 colossalai/zero/shard_utils/__init__.py
 create mode 100644 colossalai/zero/shard_utils/base_shard_strategy.py
 create mode 100644 colossalai/zero/shard_utils/tensor_shard_strategy.py

diff --git a/colossalai/zero/shard_utils/__init__.py b/colossalai/zero/shard_utils/__init__.py
new file mode 100644
index 000000000..417e201e8
--- /dev/null
+++ b/colossalai/zero/shard_utils/__init__.py
@@ -0,0 +1,4 @@
+from colossalai.zero.shard_utils.base_shard_strategy import BaseShardStrategy
+from colossalai.zero.shard_utils.tensor_shard_strategy import TensorShardStrategy
+
+__all__ = ['BaseShardStrategy', 'TensorShardStrategy']
diff --git a/colossalai/zero/shard_utils/base_shard_strategy.py b/colossalai/zero/shard_utils/base_shard_strategy.py
new file mode 100644
index 000000000..e3f57eca4
--- /dev/null
+++ b/colossalai/zero/shard_utils/base_shard_strategy.py
@@ -0,0 +1,27 @@
+from abc import ABC, abstractmethod
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+import torch.distributed as dist
+from typing import List, Optional
+
+
+class BaseShardStrategy(ABC):
+
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        self.process_group = process_group
+        self.world_size = dist.get_world_size(self.process_group)
+        self.local_rank = dist.get_rank(self.process_group)
+        super().__init__()
+
+    @abstractmethod
+    def shard(self, tensor_list: List[ShardedTensor]):
+        r"""
+        sharded the memory of tensor on multiple processes.
+        """
+        pass
+
+    @abstractmethod
+    def gather(self, tensor_list: List[ShardedTensor]):
+        r"""
+        duplicate tensor payload on each processes.
+        """
+        pass
diff --git a/colossalai/zero/shard_utils/tensor_shard_strategy.py b/colossalai/zero/shard_utils/tensor_shard_strategy.py
new file mode 100644
index 000000000..2c8f3c904
--- /dev/null
+++ b/colossalai/zero/shard_utils/tensor_shard_strategy.py
@@ -0,0 +1,18 @@
+from colossalai.zero.shard_utils import BaseShardStrategy
+import torch.distributed as dist
+from typing import List, Optional
+from colossalai.zero.sharded_param.sharded_tensor import ShardedTensor
+
+
+class TensorShardStrategy(BaseShardStrategy):
+
+    def __init__(self, process_group: Optional[dist.ProcessGroup] = None) -> None:
+        super().__init__(process_group)
+
+    def shard(self, tensor_list: List[ShardedTensor]):
+        for t in tensor_list:
+            t.shard()
+
+    def gather(self, tensor_list: List[ShardedTensor]):
+        for t in tensor_list:
+            t.gather()
diff --git a/tests/test_zero_data_parallel/test_shard_param.py b/tests/test_zero_data_parallel/test_shard_param.py
index 4f6eb52b2..4341cf5ff 100644
--- a/tests/test_zero_data_parallel/test_shard_param.py
+++ b/tests/test_zero_data_parallel/test_shard_param.py
@@ -7,6 +7,8 @@ import colossalai
 import pytest
 import torch
 import torch.multiprocessing as mp
+
+from colossalai.zero.shard_utils import TensorShardStrategy
 from colossalai.zero.sharded_param import ShardedTensor, ShardedParam
 from colossalai.utils import free_port
 from colossalai.logging import get_dist_logger, disable_existing_loggers
@@ -18,19 +20,14 @@ def run_shard_tensor(rank, world_size, port):
     t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))
     assert list(t.shape) == [world_size * 2, 3]
 
-    t.shard()
-    # The shape is flattened
-    assert list(t.shape) == [6]
-    # Do nothing
-    t.shard()
-    assert list(t.shape) == [6]
+    shard_strategy = TensorShardStrategy(process_group=None)
 
-    t.gather()
+    # test shard strategy
+    shard_strategy.shard([t])
+    assert list(t.shape) == [6]
+    shard_strategy.gather([t])
     assert list(t.shape) == [world_size * 2, 3]
 
-    t.payload = torch.zeros(world_size * 2, 3)
-    assert torch.sum(t.payload).cpu() == 0
-
 
 @pytest.mark.dist
 def test_shard_tensor():
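A minimal usage sketch of the API this patch introduces, modeled on the updated test above. The worker name demo_shard_and_gather is hypothetical, and a torch.distributed process group is assumed to already be initialized on every rank (e.g. via colossalai.launch inside mp.spawn) before it runs; process_group=None falls back to the default group.

import torch

from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_param import ShardedTensor


def demo_shard_and_gather(world_size: int) -> None:
    # Hypothetical worker; assumes the default process group is initialized.
    # Wrap a plain tensor; its payload is what the strategy shards.
    t = ShardedTensor(tensor=torch.randn(world_size * 2, 3))

    # None selects the default (global) process group.
    strategy = TensorShardStrategy(process_group=None)

    strategy.shard([t])   # each rank keeps only its flattened slice of the payload
    strategy.gather([t])  # every rank duplicates the full payload again
    assert list(t.shape) == [world_size * 2, 3]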