mirror of https://github.com/hpcaitech/ColossalAI
aibig-modeldata-parallelismdeep-learningdistributed-computingfoundation-modelsheterogeneous-traininghpcinferencelarge-scalemodel-parallelismpipeline-parallelism
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
86 lines
2.5 KiB
86 lines
2.5 KiB
from contextlib import nullcontext |
|
|
|
import pytest |
|
import torch |
|
import torch.distributed as dist |
|
from torch.nn.parallel import DistributedDataParallel as DDP |
|
|
|
import colossalai |
|
from colossalai.cluster import DistCoordinator |
|
from colossalai.lazy import LazyInitContext |
|
from colossalai.logging import disable_existing_loggers |
|
from colossalai.shardformer import ShardConfig, ShardFormer |
|
from colossalai.testing import clear_cache_before_run, parameterize, rerun_if_address_is_in_use, spawn |
|
from tests.kit.model_zoo import model_zoo |
|
|
|
|
|
@parameterize("lazy_init", [True, False]) |
|
def check_shardformer_with_ddp(lazy_init: bool): |
|
sub_model_zoo = model_zoo.get_sub_registry("transformers_gpt") |
|
|
|
# create shardformer |
|
# ranks: [0, 1, 2, 3] |
|
# tp ranks = [0, 1], [2, 3] |
|
# dp ranks = [0, 2], [1, 3] |
|
dp_process_group_1 = dist.new_group([0, 2]) |
|
dp_process_group_2 = dist.new_group([1, 3]) |
|
tp_process_group_1 = dist.new_group([0, 1]) |
|
tp_process_group_2 = dist.new_group([2, 3]) |
|
|
|
coordinator = DistCoordinator() |
|
|
|
if coordinator.rank in [0, 1]: |
|
tp_process_group = tp_process_group_1 |
|
else: |
|
tp_process_group = tp_process_group_2 |
|
|
|
if coordinator.rank in [0, 2]: |
|
dp_process_group = dp_process_group_1 |
|
else: |
|
dp_process_group = dp_process_group_2 |
|
|
|
shard_config = ShardConfig(tensor_parallel_process_group=tp_process_group, enable_fused_normalization=True) |
|
shardformer = ShardFormer(shard_config=shard_config) |
|
|
|
ctx = LazyInitContext() if lazy_init else nullcontext() |
|
|
|
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): |
|
# create and shard model |
|
with ctx: |
|
model = model_fn().cuda() |
|
sharded_model, _ = shardformer.optimize(model) |
|
|
|
# add ddp |
|
sharded_ddp_model = DDP(sharded_model, process_group=dp_process_group) |
|
|
|
# prepare input |
|
data = data_gen_fn() |
|
data = {k: v.cuda() for k, v in data.items()} |
|
|
|
# switch to train mode |
|
sharded_ddp_model.train() |
|
|
|
# run forward |
|
output = sharded_ddp_model(**data) |
|
loss = loss_fn(output) |
|
|
|
# backward |
|
loss.backward() |
|
torch.cuda.empty_cache() |
|
|
|
|
|
def run_dist(rank, world_size, port): |
|
disable_existing_loggers() |
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl") |
|
check_shardformer_with_ddp() |
|
|
|
|
|
@pytest.mark.dist |
|
@rerun_if_address_is_in_use() |
|
@clear_cache_before_run() |
|
def test_gpt2(): |
|
spawn(run_dist, 4) |
|
|
|
|
|
if __name__ == "__main__": |
|
test_gpt2()
|
|
|