[shardformer] integrate with data parallelism (#4103)

pull/4157/head
Frank Lee 2023-06-30 09:58:08 +08:00
parent f3b6aaa6b7
commit 6a88bae4ec
11 changed files with 97 additions and 50 deletions

@@ -1,5 +1,8 @@
 from dataclasses import dataclass

+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
 from colossalai.cluster.dist_coordinator import DistCoordinator

 __all__ = ['ShardConfig']
@@ -11,10 +14,10 @@ class ShardConfig:
     The config for sharding the huggingface model

     Args:
-        tensor_parallel_size (int): The size of tensor parallel
+        tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False
     """
-    tensor_parallel_size: int
+    tensor_parallel_process_group: int = None
     enable_fused_normalization: bool = False

     # TODO: add support for tensor parallel
@@ -25,10 +28,5 @@ class ShardConfig:
     # gather_output: bool = True

     def __post_init__(self):
-        coordinator = DistCoordinator()
-
-        # ensure the parallel size can match the world size
-        world_size = coordinator.world_size
-        self.data_parallel_size = world_size // self.tensor_parallel_size
-        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
-            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
+        # get the parallel size
+        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
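Note: after this change, `ShardConfig` no longer takes a `tensor_parallel_size`; it derives the size from the process group handed to it in `__post_init__`. A minimal usage sketch (not part of the diff), assuming a 4-rank job whose distributed environment has already been initialized, e.g. with `colossalai.launch`:

```python
# Sketch only: the new ShardConfig contract, assuming torch.distributed is
# already initialized on 4 ranks (e.g. via colossalai.launch).
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# Passing no group keeps the default (global) group, so
# tensor_parallel_size == world size.
default_config = ShardConfig(enable_fused_normalization=True)

# Every rank must take part in every dist.new_group call, then picks the
# 2-way tensor parallel group it belongs to.
tp_groups = [dist.new_group(ranks=[0, 1]), dist.new_group(ranks=[2, 3])]
my_tp_group = tp_groups[dist.get_rank() // 2]

tp_config = ShardConfig(tensor_parallel_process_group=my_tp_group, enable_fused_normalization=True)
assert tp_config.tensor_parallel_size == 2    # derived in __post_init__
```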

@@ -22,16 +22,10 @@ class ModelSharder(object):
         shard_config: The setting of distributed model
     """

-    def __init__(
-            self,
-            model: nn.Module,
-            policy: Policy,
-            shard_config: ShardConfig = None,    # TODO
-            pg_manager: ProcessGroupManager = None) -> None:
+    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config
-        self.pg_manager = pg_manager

     def shard(self) -> None:
         r"""
@@ -198,7 +192,8 @@ class ModelSharder(object):
                     continue

                 try:
-                    replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                    replace_layer = target_module.from_native_module(native_sub_module,
+                                                                     self.shard_config.tensor_parallel_process_group,
                                                                      **kwargs)
                 except Exception as e:
                     raise RuntimeError(
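The sharder now reads the tensor parallel group straight from `shard_config` and forwards it to each target module's `from_native_module`. A hypothetical replacement layer sketching the interface the sharder relies on (the class below is illustrative only, not a ColossalAI API):

```python
# Illustrative only: a hypothetical layer satisfying the from_native_module
# interface that ModelSharder calls with the tensor parallel process group.
import torch
import torch.nn as nn
from torch.distributed import ProcessGroup


class MyParallelLinear(nn.Module):    # hypothetical name, not in ColossalAI

    def __init__(self, native_linear: nn.Linear, process_group: ProcessGroup):
        super().__init__()
        self.native_linear = native_linear
        self.process_group = process_group    # a real layer would split/gather over this group

    @classmethod
    def from_native_module(cls, module: nn.Linear, process_group: ProcessGroup, **kwargs) -> "MyParallelLinear":
        # A real implementation would shard module.weight across the ranks of
        # process_group; here we simply wrap the original module.
        return cls(module, process_group)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.native_linear(x)
```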

@@ -1,7 +1,6 @@
 import torch.nn as nn
-from torch.utils.data import Dataset

-from colossalai.cluster import DistCoordinator, ProcessGroupManager
+from colossalai.cluster import DistCoordinator

 from ..policies.basepolicy import Policy
 from .shard_config import ShardConfig
@@ -28,7 +27,6 @@ class ShardFormer:
         tensor_parallel_mode='1d',
     )
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     model = shard_former.shard_model(org_model)
     ```
     """
@@ -41,19 +39,6 @@ class ShardFormer:
         """
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
-        self.pg_manager = None
-
-    def init_distributed(self) -> ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        # create process group manager and 1d process group
-        # TODO: may need to support other parallel mode when the config has such as field
-        pg_manager = ProcessGroupManager()
-        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
-        self.pg_manager = pg_manager
-        return pg_manager

     def shard_model(self, model: nn.Module, policy: Policy = None):
         r"""
@@ -64,12 +49,6 @@ class ShardFormer:
             shard_config (`ShardConfig`): the config for distribute information
             policy (`Policy`): the custom policy for sharding
         """
-        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
         sharder.shard()
         return model
-
-    def shard_dataset(self, dataset: Dataset):
-        """
-        Shard dataset for DP
-        """
-        pass
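With `init_distributed()` removed, sharding is now a two-step call. A minimal end-to-end sketch (assumptions: a distributed environment already launched via `colossalai.launch`, and a model type that ShardFormer ships a policy for, here a HuggingFace BERT):

```python
# Sketch of the simplified workflow after this change; assumes torch.distributed
# is already initialized on every rank (e.g. via colossalai.launch).
from transformers import BertConfig, BertModel

from colossalai.shardformer import ShardConfig, ShardFormer

org_model = BertModel(BertConfig(num_hidden_layers=2)).cuda()

# None -> global process group, i.e. pure tensor parallelism over all ranks
shard_config = ShardConfig(enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# no shard_former.init_distributed() call any more
sharded_model = shard_former.shard_model(org_model)
```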

@@ -3,17 +3,15 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer


-def build_model(world_size, model_fn):
+def build_model(model_fn):
     # create new model
     org_model = model_fn().cuda()

     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     sharded_model = shard_former.shard_model(model_copy).cuda()
     return org_model, sharded_model

@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')

     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

     torch.cuda.empty_cache()

@@ -0,0 +1,77 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
+from colossalai.cluster import DistCoordinator
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+
+
+def check_shardformer_with_ddp(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    # create shardformer
+    # ranks: [0, 1, 2, 3]
+    # tp ranks = [0, 1], [2, 3]
+    # dp ranks = [0, 2], [1, 3]
+    dp_process_group_1 = dist.new_group([0, 2])
+    dp_process_group_2 = dist.new_group([1, 3])
+    tp_process_group_1 = dist.new_group([0, 1])
+    tp_process_group_2 = dist.new_group([2, 3])
+
+    coordinator = DistCoordinator()
+
+    if coordinator.rank in [0, 1]:
+        tp_process_group = tp_process_group_1
+    else:
+        tp_process_group = tp_process_group_2
+
+    if coordinator.rank in [0, 2]:
+        dp_process_group = dp_process_group_1
+    else:
+        dp_process_group = dp_process_group_2
+
+    shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
+    shardformer = ShardFormer(shard_config=shard_config)
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        # create and shard model
+        model = model_fn().cuda()
+        sharded_model = shardformer.shard_model(model)
+
+        # add ddp
+        sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
+
+        # prepare input
+        data = data_gen_fn()
+        data = {k: v.cuda() for k, v in data.items()}
+
+        # switch to train mode
+        sharded_ddp_model.train()
+
+        # run forward
+        output = sharded_ddp_model(**data)
+        loss = loss_fn(output)
+
+        # backward
+        loss.backward()
+
+    torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2():
+    spawn(check_shardformer_with_ddp, 4)
+
+
+if __name__ == "__main__":
+    test_gpt2()
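The new test hard-codes a 2x2 mesh of process groups. A hedged sketch (the helper name and rank layout below are assumptions, not part of this PR) of how the same TP/DP groups could be built for any `world_size == dp_size * tp_size`:

```python
# Sketch only: derive TP and DP groups for an arbitrary tp_size, assuming ranks
# are laid out row-major with tensor parallelism as the fastest-moving axis.
# With world_size=4 and tp_size=2 this reproduces the groups used in the test:
# tp = [0, 1], [2, 3]; dp = [0, 2], [1, 3].
import torch.distributed as dist


def create_tp_dp_groups(tp_size: int):
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    dp_size = world_size // tp_size
    assert world_size == dp_size * tp_size

    tp_group, dp_group = None, None
    # every rank must participate in every new_group call
    for dp_rank in range(dp_size):
        ranks = list(range(dp_rank * tp_size, (dp_rank + 1) * tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            tp_group = group
    for tp_rank in range(tp_size):
        ranks = list(range(tp_rank, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group
    return tp_group, dp_group
```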