[shardformer] integrate with data parallelism (#4103)

pull/4157/head
Frank Lee 2023-06-30 09:58:08 +08:00
parent f3b6aaa6b7
commit 6a88bae4ec
11 changed files with 97 additions and 50 deletions
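Roughly, the workflow this commit enables looks like the sketch below, modelled on the new test added at the end of this diff: the tensor-parallel process group is passed straight into `ShardConfig`, and data parallelism is layered on top by wrapping the sharded model in `DistributedDataParallel`. The GPT-2 stand-in model and the hard-coded 4-rank group layout are illustrative assumptions, not part of the commit.

```python
# Minimal sketch (not part of the diff) of the new workflow. Assumes a 4-rank
# job in which torch.distributed / colossalai.launch has already been initialized.
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from transformers import GPT2Config, GPT2Model

from colossalai.shardformer import ShardConfig, ShardFormer

rank = dist.get_rank()

# 2-way tensor parallelism x 2-way data parallelism over ranks [0, 1, 2, 3].
# Every rank must call dist.new_group for every group, even ones it is not in.
tp_group_a, tp_group_b = dist.new_group([0, 1]), dist.new_group([2, 3])
dp_group_a, dp_group_b = dist.new_group([0, 2]), dist.new_group([1, 3])
tp_group = tp_group_a if rank in (0, 1) else tp_group_b
dp_group = dp_group_a if rank in (0, 2) else dp_group_b

# Tensor-parallel sharding is configured with the TP group directly;
# no separate init_distributed() / ProcessGroupManager step is needed anymore.
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_fused_normalization=True)
model = GPT2Model(GPT2Config()).cuda()
sharded_model = ShardFormer(shard_config=shard_config).shard_model(model)

# Data parallelism is layered on top by wrapping the sharded model in DDP
# over the DP group, exactly as the new test does.
ddp_model = DDP(sharded_model, process_group=dp_group)
```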

View File

@@ -1,5 +1,8 @@
from dataclasses import dataclass
import torch.distributed as dist
from torch.distributed import ProcessGroup
from colossalai.cluster.dist_coordinator import DistCoordinator
__all__ = ['ShardConfig']
@@ -11,10 +14,10 @@ class ShardConfig:
The config for sharding the huggingface model
Args:
tensor_parallel_size (int): The size of tensor parallelism
tensor_parallel_process_group (ProcessGroup): The process group for tensor parallelism; defaults to None, which means the default (global) process group is used.
enable_fused_normalization (bool): Whether to use fused layernorm, default is False
"""
tensor_parallel_size: int
tensor_parallel_process_group: ProcessGroup = None
enable_fused_normalization: bool = False
# TODO: add support for tensor parallel
@@ -25,10 +28,5 @@ class ShardConfig:
# gather_output: bool = True
def __post_init__(self):
coordinator = DistCoordinator()
# ensure the parallel size can match the world size
world_size = coordinator.world_size
self.data_parallel_size = world_size // self.tensor_parallel_size
assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
# get the parallel size
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
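With the explicit size/assert bookkeeping removed, `tensor_parallel_size` is now inferred in `__post_init__` from whichever process group is supplied (the default, global group when `None`). A small sketch of that behaviour, assuming `torch.distributed` is already initialized on at least two ranks:

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# No group given: the default (global) group is used, so the inferred
# tensor parallel size is simply the global world size.
cfg = ShardConfig()
assert cfg.tensor_parallel_size == dist.get_world_size()

# Explicit group: the size comes from that group instead. Note that
# dist.new_group must be called by every rank, even non-members.
tp_group = dist.new_group(ranks=[0, 1])
if dist.get_rank() in (0, 1):
    cfg = ShardConfig(tensor_parallel_process_group=tp_group)
    assert cfg.tensor_parallel_size == 2
```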

View File

@@ -22,16 +22,10 @@ class ModelSharder(object):
shard_config: The settings for the distributed model
"""
def __init__(
self,
model: nn.Module,
policy: Policy,
shard_config: ShardConfig = None, # TODO
pg_manager: ProcessGroupManager = None) -> None:
def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
self.model = model
self.policy = get_autopolicy(self.model) if policy is None else policy
self.shard_config = shard_config
self.pg_manager = pg_manager
def shard(self) -> None:
r"""
@@ -198,7 +192,8 @@ class ModelSharder(object):
continue
try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
replace_layer = target_module.from_native_module(native_sub_module,
self.shard_config.tensor_parallel_process_group,
**kwargs)
except Exception as e:
raise RuntimeError(
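For layer implementers, the visible change is only where the tensor-parallel group comes from: `from_native_module` now receives `shard_config.tensor_parallel_process_group` instead of the old `pg_manager.pg_store['tp1d']` lookup. The class below is a hypothetical sketch of a module exposing that hook, not colossalai's actual layer API beyond the call signature shown above.

```python
import torch.nn as nn
from torch.distributed import ProcessGroup


class MyColumnLinear(nn.Module):
    """Hypothetical stand-in for a tensor-parallel layer used by the sharder."""

    def __init__(self, native: nn.Linear, process_group: ProcessGroup):
        super().__init__()
        self.native = native
        # The group over which this layer would shard its weight.
        self.process_group = process_group

    @classmethod
    def from_native_module(cls, module: nn.Linear, process_group: ProcessGroup, **kwargs) -> "MyColumnLinear":
        # The sharder now passes shard_config.tensor_parallel_process_group
        # here, instead of pg_manager.pg_store['tp1d'].
        return cls(module, process_group)
```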

View File

@@ -1,7 +1,6 @@
import torch.nn as nn
from torch.utils.data import Dataset
from colossalai.cluster import DistCoordinator, ProcessGroupManager
from colossalai.cluster import DistCoordinator
from ..policies.basepolicy import Policy
from .shard_config import ShardConfig
@@ -28,7 +27,6 @@ class ShardFormer:
tensor_parallel_mode='1d',
)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
model = shard_former.shard_model(org_model)
```
"""
@@ -41,19 +39,6 @@ class ShardFormer:
"""
self.coordinator = DistCoordinator()
self.shard_config = shard_config
self.pg_manager = None
def init_distributed(self) -> ProcessGroupManager:
"""
Initialize the distributed process group according to the parallel mode.
"""
# create process group manager and 1d process group
# TODO: may need to support other parallel modes when the config has such a field
pg_manager = ProcessGroupManager()
pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
self.pg_manager = pg_manager
return pg_manager
def shard_model(self, model: nn.Module, policy: Policy = None):
r"""
@@ -64,12 +49,6 @@ class ShardFormer:
shard_config (`ShardConfig`): the config describing how the model is distributed
policy (`Policy`): the custom policy for sharding
"""
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
sharder.shard()
return model
def shard_dataset(self, dataset: Dataset):
"""
Shard dataset for DP
"""
pass

View File

@@ -3,17 +3,15 @@ import copy
from colossalai.shardformer import ShardConfig, ShardFormer
def build_model(world_size, model_fn):
def build_model(model_fn):
# create new model
org_model = model_fn().cuda()
# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
shard_config = ShardConfig(enable_fused_normalization=True)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy).cuda()
return org_model, sharded_model

View File

@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
org_model, sharded_model = build_model(world_size, model_fn)
org_model, sharded_model = build_model(model_fn)
check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
torch.cuda.empty_cache()

View File

@@ -0,0 +1,77 @@
import pytest
import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
import colossalai
from colossalai.cluster import DistCoordinator
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer import ShardConfig, ShardFormer
from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
def check_shardformer_with_ddp(rank, world_size, port):
disable_existing_loggers()
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
# create shardformer
# ranks: [0, 1, 2, 3]
# tp ranks = [0, 1], [2, 3]
# dp ranks = [0, 2], [1, 3]
dp_process_group_1 = dist.new_group([0, 2])
dp_process_group_2 = dist.new_group([1, 3])
tp_process_group_1 = dist.new_group([0, 1])
tp_process_group_2 = dist.new_group([2, 3])
coordinator = DistCoordinator()
if coordinator.rank in [0, 1]:
tp_process_group = tp_process_group_1
else:
tp_process_group = tp_process_group_2
if coordinator.rank in [0, 2]:
dp_process_group = dp_process_group_1
else:
dp_process_group = dp_process_group_2
shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
shardformer = ShardFormer(shard_config=shard_config)
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
# create and shard model
model = model_fn().cuda()
sharded_model = shardformer.shard_model(model)
# add ddp
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}
# switch to train mode
sharded_ddp_model.train()
# run forward
output = sharded_ddp_model(**data)
loss = loss_fn(output)
# backward
loss.backward()
torch.cuda.empty_cache()
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_gpt2():
spawn(check_shardformer_with_ddp, 4)
if __name__ == "__main__":
test_gpt2()
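The process groups in this test are hard-coded for the 4-GPU layout. Purely as an illustrative aside (not part of the commit), the same 2D layout can be derived for any world size divisible by the TP size; with tp_size=2 on 4 ranks this reproduces exactly the groups used above.

```python
import torch.distributed as dist


def build_tp_dp_groups(tp_size: int):
    """Build (tp_group, dp_group) for this rank; consecutive ranks share a TP group."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    assert world_size % tp_size == 0, "world size must be divisible by the TP size"

    tp_group = dp_group = None
    # dist.new_group must be called by every rank for every group,
    # in the same order, even for groups this rank does not belong to.
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group
    for offset in range(tp_size):
        ranks = list(range(offset, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group
```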