mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] integrate with data parallelism (#4103)
parent f3b6aaa6b7
commit 6a88bae4ec
@@ -1,5 +1,8 @@
 from dataclasses import dataclass
 
+import torch.distributed as dist
+from torch.distributed import ProcessGroup
+
 from colossalai.cluster.dist_coordinator import DistCoordinator
 
 __all__ = ['ShardConfig']
@@ -11,10 +14,10 @@ class ShardConfig:
     The config for sharding the huggingface model
 
     Args:
-        tensor_parallel_size (int): The size of tensor parallel
+        tensor_parallel_process_group (int): The process group for tensor parallelism, defaults to None, which is the global process group.
         enable_fused_normalization (bool): Whether to use fused layernorm, default is False
     """
-    tensor_parallel_size: int
+    tensor_parallel_process_group: int = None
     enable_fused_normalization: bool = False
 
     # TODO: add support for tensor parallel
@@ -25,10 +28,5 @@ class ShardConfig:
     # gather_output: bool = True
 
     def __post_init__(self):
-        coordinator = DistCoordinator()
-
-        # ensure the parallel size can match the world size
-        world_size = coordinator.world_size
-        self.data_parallel_size = world_size // self.tensor_parallel_size
-        assert world_size == self.data_parallel_size * self.tensor_parallel_size, \
-            f"The world size ({world_size}) should be divisible by the data parallel size {self.data_parallel_size} and tensor parallel size {self.tensor_parallel_size}"
+        # get the parallel size
+        self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
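To make the new `ShardConfig` contract concrete, here is a minimal, hypothetical usage sketch based only on the hunks above: the tensor-parallel process group is passed in (or left as `None` for the global group) and `__post_init__` derives `tensor_parallel_size` from it. The 2-rank group is illustrative, not part of the commit.

```python
import torch.distributed as dist

from colossalai.shardformer import ShardConfig

# assumes torch.distributed is already initialized (e.g. via colossalai.launch)
tp_group = dist.new_group(ranks=[0, 1])    # illustrative 2-way tensor-parallel group

# None would fall back to the global process group, as the docstring notes
shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_fused_normalization=True)

# __post_init__ now reads the parallel size off the group instead of taking it as an argument
assert shard_config.tensor_parallel_size == dist.get_world_size(tp_group)
```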
@@ -22,16 +22,10 @@ class ModelSharder(object):
         shard_config: The setting of distributed model
     """
 
-    def __init__(
-            self,
-            model: nn.Module,
-            policy: Policy,
-            shard_config: ShardConfig = None,    # TODO
-            pg_manager: ProcessGroupManager = None) -> None:
+    def __init__(self, model: nn.Module, policy: Policy, shard_config: ShardConfig = None) -> None:
         self.model = model
         self.policy = get_autopolicy(self.model) if policy is None else policy
         self.shard_config = shard_config
-        self.pg_manager = pg_manager
 
     def shard(self) -> None:
         r"""
@@ -198,7 +192,8 @@ class ModelSharder(object):
                     continue
 
                 try:
-                    replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
+                    replace_layer = target_module.from_native_module(native_sub_module,
+                                                                     self.shard_config.tensor_parallel_process_group,
                                                                      **kwargs)
                 except Exception as e:
                     raise RuntimeError(
@@ -1,7 +1,6 @@
 import torch.nn as nn
-from torch.utils.data import Dataset
 
-from colossalai.cluster import DistCoordinator, ProcessGroupManager
+from colossalai.cluster import DistCoordinator
 
 from ..policies.basepolicy import Policy
 from .shard_config import ShardConfig
@@ -28,7 +27,6 @@ class ShardFormer:
         tensor_parallel_mode='1d',
     )
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     model = shard_former.shard_model(org_model)
     ```
     """
@@ -41,19 +39,6 @@ class ShardFormer:
         """
         self.coordinator = DistCoordinator()
         self.shard_config = shard_config
-        self.pg_manager = None
-
-    def init_distributed(self) -> ProcessGroupManager:
-        """
-        Initialize the distributed process group according to the
-        """
-        # create process group manager and 1d process group
-        # TODO: may need to support other parallel mode when the config has such as field
-        pg_manager = ProcessGroupManager()
-        pg_manager.create_process_group(name='tp1d', ranks=range(self.coordinator.world_size))
-        self.pg_manager = pg_manager
-
-        return pg_manager
 
     def shard_model(self, model: nn.Module, policy: Policy = None):
         r"""
@@ -64,12 +49,6 @@ class ShardFormer:
             shard_config (`ShardConfig`): the config for distribute information
             policy (`Policy`): the custom policy for sharding
         """
-        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy, pg_manager=self.pg_manager)
+        sharder = ModelSharder(model=model, shard_config=self.shard_config, policy=policy)
         sharder.shard()
         return model
-
-    def shard_dataset(self, dataset: Dataset):
-        """
-        Shard dataset for DP
-        """
-        pass
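Putting the ShardFormer changes together: `init_distributed()` is gone and the process group travels inside `ShardConfig`, so `shard_model` consumes the config directly. A hedged end-to-end sketch under assumed conditions (distributed backend already launched; the Bert model and the 2-rank group are illustrative placeholders, not part of the commit):

```python
import torch.distributed as dist
from transformers import BertForMaskedLM

from colossalai.shardformer import ShardConfig, ShardFormer

# illustrative model; any architecture covered by an autopolicy follows the same flow
org_model = BertForMaskedLM.from_pretrained('bert-base-uncased').cuda()

tp_group = dist.new_group(ranks=[0, 1])    # illustrative tensor-parallel group

shard_config = ShardConfig(tensor_parallel_process_group=tp_group,
                           enable_fused_normalization=True)
shard_former = ShardFormer(shard_config=shard_config)

# no shard_former.init_distributed() call anymore; the config carries the group
sharded_model = shard_former.shard_model(org_model)
```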
@@ -3,17 +3,15 @@ import copy
 from colossalai.shardformer import ShardConfig, ShardFormer
 
 
-def build_model(world_size, model_fn):
+def build_model(model_fn):
     # create new model
     org_model = model_fn().cuda()
 
     # shard model
-    shard_config = ShardConfig(tensor_parallel_size=world_size, enable_fused_normalization=True)
+    shard_config = ShardConfig(enable_fused_normalization=True)
     model_copy = copy.deepcopy(org_model)
     shard_former = ShardFormer(shard_config=shard_config)
-    shard_former.init_distributed()
     sharded_model = shard_former.shard_model(model_copy).cuda()
 
     return org_model, sharded_model
@@ -42,7 +42,7 @@ def check_bert(rank, world_size, port):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bert')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -42,7 +42,7 @@ def check_bloom(rank, world_size, port):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_bloom')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -43,7 +43,7 @@ def check_gpt2(rank, world_size, port):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -50,7 +50,7 @@ def check_llama(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -54,7 +54,7 @@ def check_OPTModel(rank, world_size, port):
 
     sub_model_zoo = model_zoo.get_sub_registry('transformers_opt')
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -42,7 +42,7 @@ def check_t5(rank, world_size, port):
     sub_model_zoo = model_zoo.get_sub_registry('transformers_t5')
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        org_model, sharded_model = build_model(world_size, model_fn)
+        org_model, sharded_model = build_model(model_fn)
         check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
 
     torch.cuda.empty_cache()
@@ -0,0 +1,77 @@
+import pytest
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DistributedDataParallel as DDP
+
+import colossalai
+from colossalai.cluster import DistCoordinator
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer import ShardConfig, ShardFormer
+from colossalai.testing import clear_cache_before_run, rerun_if_address_is_in_use, spawn
+from tests.kit.model_zoo import model_zoo
+
+
+def check_shardformer_with_ddp(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
+
+    sub_model_zoo = model_zoo.get_sub_registry('transformers_gpt')
+
+    # create shardformer
+    # ranks: [0, 1, 2, 3]
+    # tp ranks = [0, 1], [2, 3]
+    # dp ranks = [0, 2], [1, 3]
+    dp_process_group_1 = dist.new_group([0, 2])
+    dp_process_group_2 = dist.new_group([1, 3])
+    tp_process_group_1 = dist.new_group([0, 1])
+    tp_process_group_2 = dist.new_group([2, 3])
+
+    coordinator = DistCoordinator()
+
+    if coordinator.rank in [0, 1]:
+        tp_process_group = tp_process_group_1
+    else:
+        tp_process_group = tp_process_group_2
+
+    if coordinator.rank in [0, 2]:
+        dp_process_group = dp_process_group_1
+    else:
+        dp_process_group = dp_process_group_2
+
+    shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True)
+    shardformer = ShardFormer(shard_config=shard_config)
+
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        # create and shard model
+        model = model_fn().cuda()
+        sharded_model = shardformer.shard_model(model)
+
+        # add ddp
+        sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group)
+
+        # prepare input
+        data = data_gen_fn()
+        data = {k: v.cuda() for k, v in data.items()}
+
+        # switch to train mode
+        sharded_ddp_model.train()
+
+        # run forward
+        output = sharded_ddp_model(**data)
+        loss = loss_fn(output)
+
+        # backward
+        loss.backward()
+        torch.cuda.empty_cache()
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+@clear_cache_before_run()
+def test_gpt2():
+    spawn(check_shardformer_with_ddp, 4)
+
+
+if __name__ == "__main__":
+    test_gpt2()
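The new test hard-codes the process groups for a fixed 4-GPU layout (TP size 2 × DP size 2). As a hedged generalization of that layout, a small helper like the one below could derive the same tp/dp groups for any divisible world size; the helper name and the grouping convention are illustrative and not part of the commit.

```python
import torch.distributed as dist


def build_tp_dp_groups(rank: int, world_size: int, tp_size: int):
    """Illustrative helper: contiguous ranks form TP groups, strided ranks form DP groups,
    mirroring the layout in the test above (world_size=4, tp_size=2 -> tp [0,1],[2,3]; dp [0,2],[1,3])."""
    assert world_size % tp_size == 0, "world size must be divisible by tp size"
    dp_size = world_size // tp_size

    tp_group = None
    dp_group = None

    # TP groups: [0..tp_size-1], [tp_size..2*tp_size-1], ...
    for i in range(dp_size):
        ranks = list(range(i * tp_size, (i + 1) * tp_size))
        group = dist.new_group(ranks)    # every rank must call new_group for every group
        if rank in ranks:
            tp_group = group

    # DP groups: ranks sharing the same position inside their TP group
    for j in range(tp_size):
        ranks = list(range(j, world_size, tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            dp_group = group

    return tp_group, dp_group
```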