From 6a88bae4ec26b11261047e5462a5c2ed6bfe41f4 Mon Sep 17 00:00:00 2001
From: Frank Lee
Date: Fri, 30 Jun 2023 09:58:08 +0800
Subject: [PATCH] [shardformer] integrate with data parallelism (#4103)

---
 colossalai/shardformer/shard/shard_config.py  | 16 ++--
 colossalai/shardformer/shard/sharder.py       | 11 +--
 colossalai/shardformer/shard/shardformer.py   | 25 +-----
 tests/test_shardformer/test_model/_utils.py   |  6 +-
 .../test_model/test_shard_bert.py             |  2 +-
 .../test_model/test_shard_bloom.py            |  2 +-
 .../test_model/test_shard_gpt2.py             |  2 +-
 .../test_model/test_shard_llama.py            |  2 +-
 .../test_model/test_shard_opt.py              |  2 +-
 .../test_model/test_shard_t5.py               |  2 +-
 tests/test_shardformer/test_with_torch_ddp.py | 77 +++++++++++++++++++
 11 files changed, 97 insertions(+), 50 deletions(-)
 create mode 100644 tests/test_shardformer/test_with_torch_ddp.py

diff --git a/colossalai/shardformer/shard/shard_config.py b/colossalai/shardformer/shard/shard_config.py
index 428ebc978..e83191210 100644
--- a/colossalai/shardformer/shard/shard_config.py
+++ b/colossalai/shardformer/shard/shard_config.py
@@ -1,5 +1,8 @@
 from dataclasses import dataclass
 
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
 from colossalai.cluster.dist_coordinator import DistCoordinator
 
 __all__ = ['ShardConfig']
@@ -11,10 +14,10 @@ class ShardConfig:
     The config for sharding the huggingface model
 
     Args:
-        tensor_parallel_size (int): The size of tensor parallel
+        tensor_parallel_process_group (ProcessGroup): The process group for tensor parallelism, defaults to None, which is the global process group.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False
     """
-    tensor_parallel_size: int
+    tensor_parallel_process_group: ProcessGroup = None
     enable_fused_normalization: bool = False
 
     # TODO: add support for tensor parallel
@@ -25,10 +28,5 @@ class ShardConfig:
     # gather_output: bool = True
 
     def __post_init__(self):
-        coordinator = DistCoordinator()
-
-        # ensure the parallel size can match the world size
-        world_size = coordinator.world_size
-        self.data_parallel_size = world_size // self.tensor_parallel_size
-        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
-            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
+        # get the parallel size
+        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py
index c2444e1f7..e9b27ea45 100644
--- a/colossalai/shardformer/shard/sharder.py
+++ b/colossalai/shardformer/shard/sharder.py
@@ -22,16 +22,10 @@ class ModelSharder(object):
         shard_config: The setting of distributed model
     """
 
-    def __init__(
-            self,
-            model: nn.Module,
-            policy: Policy,
-            shard_config: ShardConfig = None,    # TODO
-            pg_manager: ProcessGroupManager = None) -> None:
+    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config
-        self.pg_manager = pg_manager
 
     def shard(self) -> None:
         r"""
@@ -198,7 +192,8 @@ class ModelSharder(object):
                 continue
 
             try:
-                replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                replace_layer = target_module.from_native_module(native_sub_module,
+                                                                 self.shard_config.tensor_parallel_process_group,
                                                                  **kwargs)
             except Exception as e:
                 raise RuntimeError(
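A minimal usage sketch (not part of the patch) of the new ShardConfig contract introduced above: the config now takes a tensor-parallel ProcessGroup and derives tensor_parallel_size from it in __post_init__, while the default of None falls back to the global group. The sketch assumes torch.distributed is already initialized across four ranks, e.g. by colossalai.launch; the two-way split is illustrative only.

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# every rank must take part in every new_group call, then pick the group it belongs to
tp_group_01 = dist.new_group([0, 1])
tp_group_23 = dist.new_group([2, 3])
tp_group = tp_group_01 if dist.get_rank() in (0, 1) else tp_group_23

# __post_init__ derives the tensor parallel size from the group
config = ShardConfig(tensor_parallel_process_group=tp_group)
assert config.tensor_parallel_size == 2

# the default of None falls back to the global process group
default_config = ShardConfig()
assert default_config.tensor_parallel_size == dist.get_world_size()
```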
diff --git a/colossalai/shardformer/shard/shardformer.py b/colossalai/shardformer/shard/shardformer.py
index 1208a9d09..7c4220c3a 100644
--- a/colossalai/shardformer/shard/shardformer.py
+++ b/colossalai/shardformer/shard/shardformer.py
@@ -1,7 +1,6 @@
 import torch.nn as nn
-from torch.utils.data import Dataset
 
-from colossalai.cluster import DistCoordinator, ProcessGroupManager
+from colossalai.cluster import DistCoordinator
 
 from ..policies.basepolicy import Policy
 from .shard_config import ShardConfig
@@ -28,7 +27,6 @@ class ShardFormer:
             tensor_parallel_mode='1d',
         )
         shard_former = ShardFormer(shard_config=shard_config)
-        shard_former.init_distributed()
         model = shard_former.shard_model(org_model)
     ```
     """
@@ -41,19 +39,6 @@ class ShardFormer:
         """
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
-        self.pg_manager = None
-
-    def init_distributed(self) -> ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        # create process group manager and 1d process group
-        # TODO: may need to support other parallel mode when the config has such as field
-        pg_manager = ProcessGroupManager()
-        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
-        self.pg_manager = pg_manager
-
-        return pg_manager
 
     def shard_model(self, model: nn.Module, policy: Policy = None):
         r"""
@@ -64,12 +49,6 @@ class ShardFormer:
             shard_config (`ShardConfig`): the config for distribute information
             policy (`Policy`): the custom policy for sharding
         """
-        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
         sharder.shard()
         return model
-
-    def shard_dataset(self, dataset: Dataset):
-        """
-        Shard dataset for DP
-        """
-        pass
diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py
index e49b0246c..a6355bf1c 100644
--- a/tests/test_shardformer/test_model/_utils.py
+++ b/tests/test_shardformer/test_model/_utils.py
@@ -3,17 +3,15 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer
 
 
-def build_model(world_size, model_fn):
+def build_model(model_fn):
     # create new model
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     sharded_model = shard_former.shard_model(model_copy).cuda()
-
     return org_model, sharded_model
 
 
diff --git a/tests/test_shardformer/test_model/test_shard_bert.py b/tests/test_shardformer/test_model/test_shard_bert.py
index ad98e3d07..a089a1ab3 100644
--- a/tests/test_shardformer/test_model/test_shard_bert.py
+++ b/tests/test_shardformer/test_model/test_shard_bert.py
@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
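With init_distributed() removed, sharding is now a single ShardConfig plus ShardFormer call, as the updated docstring and build_model helper show. A runnable sketch of that simplified flow (not part of the patch); the torchrun launch and the bert-base-uncased checkpoint are illustrative assumptions only:

```python
import colossalai
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

# launched via torchrun, e.g. torchrun --nproc_per_node 2 this_script.py
colossalai.launch_from_torch(config={})

org_model = BertForMaskedLM.from_pretrained('bert-base-uncased').cuda()

# with the default process group, all launched ranks form one tensor-parallel group
shard_config = ShardConfig(enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)
sharded_model = shard_former.shard_model(org_model)
```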
diff --git a/tests/test_shardformer/test_model/test_shard_bloom.py b/tests/test_shardformer/test_model/test_shard_bloom.py
index 7e2e3dfa8..2e7ae7067 100644
--- a/tests/test_shardformer/test_model/test_shard_bloom.py
+++ b/tests/test_shardformer/test_model/test_shard_bloom.py
@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_gpt2.py b/tests/test_shardformer/test_model/test_shard_gpt2.py
index 676267c2c..4d4dc3c1e 100644
--- a/tests/test_shardformer/test_model/test_shard_gpt2.py
+++ b/tests/test_shardformer/test_model/test_shard_gpt2.py
@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py
index 8b672af50..763fb2a6b 100644
--- a/tests/test_shardformer/test_model/test_shard_llama.py
+++ b/tests/test_shardformer/test_model/test_shard_llama.py
@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_opt.py b/tests/test_shardformer/test_model/test_shard_opt.py
index 4d4c55770..d70b5d8e5 100644
--- a/tests/test_shardformer/test_model/test_shard_opt.py
+++ b/tests/test_shardformer/test_model/test_shard_opt.py
@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py
index 6074a902e..6f558e237 100644
--- a/tests/test_shardformer/test_model/test_shard_t5.py
+++ b/tests/test_shardformer/test_model/test_shard_t5.py
@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
diff --git a/tests/test_shardformer/test_with_torch_ddp.py b/tests/test_shardformer/test_with_torch_ddp.py
new file mode 100644
index 000000000..61b672650
--- /dev/null
+++ b/tests/test_shardformer/test_with_torch_ddp.py
@@ -0,0 +1,77 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
+from colossalai.cluster import DistCoordinator
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+
+
+def check_shardformer_with_ddp(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    # create shardformer
+    # ranks: [0, 1, 2, 3]
+    # tp ranks = [0, 1], [2, 3]
+    # dp ranks = [0, 2], [1, 3]
+    dp_process_group_1 = dist.new_group([0, 2])
+    dp_process_group_2 = dist.new_group([1, 3])
+    tp_process_group_1 = dist.new_group([0, 1])
+    tp_process_group_2 = dist.new_group([2, 3])
+
+    coordinator = DistCoordinator()
+
+    if coordinator.rank in [0, 1]:
+        tp_process_group = tp_process_group_1
+    else:
+        tp_process_group = tp_process_group_2
+
+    if coordinator.rank in [0, 2]:
+        dp_process_group = dp_process_group_1
+    else:
+        dp_process_group = dp_process_group_2
+
+    shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
+    shardformer = ShardFormer(shard_config=shard_config)
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        # create and shard model
+        model = model_fn().cuda()
+        sharded_model = shardformer.shard_model(model)
+
+        # add ddp
+        sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
+
+        # prepare input
+        data = data_gen_fn()
+        data = {k: v.cuda() for k, v in data.items()}
+
+        # switch to train mode
+        sharded_ddp_model.train()
+
+        # run forward
+        output = sharded_ddp_model(**data)
+        loss = loss_fn(output)
+
+        # backward
+        loss.backward()
+        torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2():
+    spawn(check_shardformer_with_ddp, 4)
+
+
+if __name__ == "__main__":
+    test_gpt2()
+    test_gpt2()
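The new test hard-codes a four-rank layout (TP groups as rows, DP groups as columns). The sketch below generalizes the same arrangement to any world size divisible by the tensor-parallel degree; it is not part of the patch. It assumes torch.distributed is already initialized (e.g. through colossalai.launch), build_tp_dp_groups is a hypothetical helper, and the randomly initialized GPT-2 is used purely for illustration.

```python
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import GPT2Config, GPT2Model

from colossalai.shardformer import ShardConfig, ShardFormer


def build_tp_dp_groups(tp_size: int):
    """Arrange ranks in a (dp, tp) grid: TP groups are rows, DP groups are columns."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % tp_size == 0, "world size must be divisible by tp_size"

    tp_group, dp_group = None, None
    # every rank must join every new_group call, in the same order
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group
    for offset in range(tp_size):
        ranks = list(range(offset, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group


# shard within each TP row, replicate and all-reduce gradients across DP columns
tp_group, dp_group = build_tp_dp_groups(tp_size=2)
shard_config = ShardConfig(tensor_parallel_process_group=tp_group, enable_fused_normalization=True)
sharded_model = ShardFormer(shard_config=shard_config).shard_model(GPT2Model(GPT2Config()).cuda())
ddp_model = DDP(sharded_model, process_group=dp_group)
```

With four ranks and tp_size=2 this reproduces exactly the groups used in the test: TP = [0, 1], [2, 3] and DP = [0, 2], [1, 3].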