import os

import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import pytest

from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized


def run_worker(rank, args):
    # Point torch.distributed at the rendezvous address before initialization.
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    device = args.device
    world_size = args.world_size
    dp_degree = args.dp_degree
    tp_degree = args.tp_degree
    num_worker_threads = args.num_worker_threads
    host = args.master_addr
    port = args.master_port
    # NCCL needs CUDA devices; fall back to gloo when running on CPU.
    backend = 'nccl' if device == 'cuda' else 'gloo'

    disable_existing_loggers()
    # Initialize the Colossal-AI distributed context for this worker process.
    launch(dict(), rank, world_size, host, int(port), backend, verbose=False)

    # Build the pipeline process group: the data-parallel / tensor-parallel
    # topology plus the RPC workers used for pipeline communication.
    pg = PipelineProcessGroup(rank=rank,
                              world_size=world_size,
                              dp_degree=dp_degree,
                              tp_degree=tp_degree,
                              num_worker_threads=num_worker_threads,
                              device=device)

    # Block until all RPC workers are done, then tear the RPC framework down.
    if rpc_is_initialized():
        rpc.shutdown()


if __name__ == "__main__":
    args = pg_parse_args()
    world_size = args.world_size
    # Spawn one process per rank; each runs run_worker(rank, args).
    mp.spawn(run_worker, args=(args,), nprocs=world_size)
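# Example invocation (a sketch only: the exact flag names are whatever
# pg_parse_args in rpc_test_utils defines, so adjust to match; the values
# below are illustrative):
#
#   python <this_file>.py --device cpu --world_size 4 \
#       --dp_degree 2 --tp_degree 2 --num_worker_threads 128 \
#       --master_addr localhost --master_port 29500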