mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
128 lines
4.6 KiB
128 lines
4.6 KiB
import torch |
|
import pytest |
|
import os |
|
import torch.multiprocessing as mp |
|
import torch.distributed.rpc as rpc |
|
|
|
from torch import nn |
|
from torch._C._distributed_rpc import _is_current_rpc_agent_set |
|
from colossalai import launch |
|
from colossalai.logging import disable_existing_loggers |
|
from colossalai.pipeline.pipeline_process_group import ppg |
|
from colossalai.pipeline.rpc._pipeline_schedule import OneFOneBPipelineEngine |
|
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass |
|
from colossalai.fx import ColoTracer |
|
from colossalai.pipeline.middleware.adaptor import get_fx_topology |
|
from rpc_test_utils import MLP, DAG_MLP |
|
from functools import partial |
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use |
|
|
|
# Module-level settings shared by the helpers in this file.
batch_size: int = 16    # first dimension of every input tensor built here
dim: int = 10    # second dimension of the inputs; also the first model constructor argument

# Short alias for torch's private "is an RPC agent currently set?" query.
rpc_is_initialized = _is_current_rpc_agent_set
|
|
|
def create_partition_module(pp_rank: int, stage_num: int, model, data_kwargs):
    """Trace ``model``, split it into stages and return the one for ``pp_rank``.

    The model is switched to eval mode, traced by ``ColoTracer`` using
    meta-device copies of ``data_kwargs``, annotated by ``balanced_split_pass``
    for ``stage_num`` stages and finally split by
    ``split_with_split_nodes_pass``. The topology computed by
    ``get_fx_topology`` is stored as ``_topo`` on every split submodule that is
    a ``torch.fx.GraphModule``.

    Args:
        pp_rank: Rank whose submodule should be returned.
        stage_num: Number of stages handed to ``balanced_split_pass``.
        model: Module to trace and split.
        data_kwargs: Mapping of argument name to example tensor.

    Returns:
        The element at index ``pp_rank + 1`` of the split submodule list.
    """
    model.eval()

    # Tracing only receives meta-device copies of the example inputs.
    meta_inputs = {name: tensor.to('meta') for name, tensor in data_kwargs.items()}
    traced_graph = ColoTracer().trace(root=model, meta_args=meta_inputs)
    graph_module = torch.fx.GraphModule(model, traced_graph, model.__class__.__name__)

    top_module, stages = split_with_split_nodes_pass(
        balanced_split_pass(graph_module, stage_num),
        merge_output=True,
    )

    topology = get_fx_topology(top_module)
    for stage in stages:
        if not isinstance(stage, torch.fx.GraphModule):
            continue
        stage._topo = topology

    # NOTE(review): the +1 offset presumably skips a leading non-stage entry
    # in the split result -- confirm against split_with_split_nodes_pass.
    return stages[pp_rank + 1]
|
|
|
def partition(model, data_kwargs: dict, pp_rank: int, chunk: int, stage_num: int):
    """Seed torch's RNG with 1024 and build the submodule for ``pp_rank``.

    Thin wrapper around ``create_partition_module``. ``chunk`` is accepted as
    part of the signature but is not used here.

    Returns:
        Whatever ``create_partition_module`` returns for the given rank.
    """
    torch.manual_seed(1024)
    return create_partition_module(pp_rank, stage_num, model, data_kwargs)
|
|
|
def run_master(model_cls, world_size, forward_only):
    """Build a ``OneFOneBPipelineEngine`` for ``model_cls`` and run it.

    Seeds torch's RNG with 100, instantiates the model, hands the engine a
    partition function bound to the model and zero-filled example inputs, and
    then calls ``forward_backward`` for three epochs on random CUDA inputs.

    Args:
        model_cls: Model class to test; must be ``MLP`` or ``DAG_MLP``.
        world_size: Number of ranks; also used as the number of stages.
        forward_only: When true, no optimizer is initialized and ``labels`` is
            ``None``; otherwise an SGD optimizer is set up and ``labels`` is 1.

    Raises:
        ValueError: If ``model_cls`` is neither ``MLP`` nor ``DAG_MLP``.
    """
    torch.manual_seed(100)

    epoch = 3
    device = 'cuda'
    stage_num = world_size
    chunk = 1
    num_microbatches = 8
    # NOTE(review): 'store_true' looks like a leftover argparse action name;
    # it is truthy, so it presumably acts like True -- confirm before changing.
    use_checkpoint = 'store_true'

    # Names of the example inputs used to trace each supported model.
    if model_cls == MLP:
        input_names = ('x',)
    elif model_cls == DAG_MLP:
        input_names = ('x', 'y')
    else:
        # Fail here with a clear message rather than later on an unbound name.
        raise ValueError(f"unsupported model_cls: {model_cls!r}; expected MLP or DAG_MLP")

    model = model_cls(dim, stage_num * 3)
    labels = None if forward_only else 1
    data_kwargs = {name: torch.zeros((batch_size, dim)) for name in input_names}

    engine = OneFOneBPipelineEngine(
        partition_fn=partial(partition, model, data_kwargs),
        stage_num=stage_num,
        num_microbatches=num_microbatches,
        device=device,
        chunk=chunk,
        checkpoint=use_checkpoint,
    )
    if not forward_only:
        engine.initialize_optimizer(torch.optim.SGD, lr=1e-3)

    for _ in range(epoch):
        input_x = torch.randn((batch_size, dim), device=device)
        input_y = torch.randn((batch_size, dim), device=device)
        # NOTE(review): both 'x' and 'y' are passed even for MLP, which is
        # traced with 'x' only -- behavior kept as in the original.
        engine.forward_backward({'x': input_x, 'y': input_y}, labels=labels, forward_only=forward_only)
|
|
|
def run_worker(rank, model_cls, world_size, forward_only, master_func):
    """Per-process entry point: initialize the process, run the master, shut down.

    Exports ``MASTER_ADDR``/``MASTER_PORT``, disables existing loggers, calls
    ``launch`` with the ``nccl`` backend and registers global info on ``ppg``.
    Rank 0 then runs ``master_func``; every rank calls ``rpc.shutdown()`` if an
    RPC agent is set.

    Args:
        rank: Index of this process.
        model_cls: Model class forwarded to ``master_func``.
        world_size: Total number of processes.
        forward_only: Flag forwarded to ``master_func``.
        master_func: Callable invoked on rank 0 as
            ``master_func(model_cls, world_size, forward_only)``.
    """
    host, port = 'localhost', 29020
    os.environ.update({'MASTER_ADDR': host, 'MASTER_PORT': str(port)})

    disable_existing_loggers()

    launch({}, rank, world_size, host, port, 'nccl', verbose=False)
    ppg.set_global_info(
        rank=rank,
        world_size=world_size,
        dp_degree=1,
        tp_degree=1,
        num_worker_threads=128,
        device='cuda',
    )

    # In RPC mode only rank 0 drives the computation.
    is_master = rank == 0
    if is_master:
        master_func(model_cls, world_size, forward_only)

    # The original comment calls this shutdown the barrier point.
    if rpc_is_initialized():
        rpc.shutdown()
|
|
|
@pytest.mark.skip("skip due to CI torch version 1.11")
@parameterize('model_cls', [MLP, DAG_MLP])
@parameterize('forward_only', [True, False])
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_pp_middleware_fwd(model_cls, forward_only):
    """Spawn four worker processes that run ``run_master`` through ``run_worker``.

    Parameterized over ``model_cls`` (``MLP``, ``DAG_MLP``) and
    ``forward_only`` (``True``, ``False``); currently marked as skipped.
    """
    num_procs = 4
    worker_args = (model_cls, num_procs, forward_only, run_master)
    mp.spawn(run_worker, args=worker_args, nprocs=num_procs)
|
|
|
# Run the test directly when this file is executed as a script. It is called
# with no arguments; presumably the `parameterize` decorators supply them
# (verify against colossalai.testing).
if __name__ == '__main__':
    test_pp_middleware_fwd()