From f1e1836218a4ecc6e8acf85c7031492cc4e6b590 Mon Sep 17 00:00:00 2001
From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com>
Date: Thu, 1 Sep 2022 17:45:47 +0800
Subject: [PATCH] [pipeline/pipeline_process_group] finish PipelineProcessGroup
 to manage local and global ranks in TP, DP and PP (#1508)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py (support communication with List[Any]) | pass test

* [engine/schedule] use p2p_v2 to reconstruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with the CUDA RPC framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatching data in work_list to ensure steady 1F1B

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/pipeline_process_group] finish PipelineProcessGroup to manage local and global ranks in TP, DP and PP

* [pipeline/pipeline_process_group] remove comment

* [pipeline/pipeline_process_group] remove comment

* [pipeline/pipeline_process_group] skip process group test

* [pipeline/pipeline_process_group] remove test named function
---
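Usage sketch (illustrative, not part of the commit message): every process
constructs the same PipelineProcessGroup; `rank` is assumed to come from the
launcher, and the values below assume a hypothetical 4-process run with 2-way
data parallelism.

    from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup

    # constructing the group also sets up the RPC workers for this stage
    pg = PipelineProcessGroup(rank=rank, world_size=4, dp_degree=2,
                              tp_degree=1, num_worker_threads=1, device='cuda')
    if not pg.is_last_stage():
        peer = pg.get_next_pp_rank()    # pp_rank of the following stage
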
 colossalai/pipeline/pipeline_process_group.py | 138 +++++++++++++++++++
 tests/test_pipeline/rpc_test_utils.py         |  22 ++-
 .../test_pipeline_process_group.py            |  43 ++++++
 3 files changed, 202 insertions(+), 1 deletion(-)
 create mode 100644 colossalai/pipeline/pipeline_process_group.py
 create mode 100644 tests/test_pipeline/test_pipeline_process_group.py

diff --git a/colossalai/pipeline/pipeline_process_group.py b/colossalai/pipeline/pipeline_process_group.py
new file mode 100644
index 000000000..d6fe47bc4
--- /dev/null
+++ b/colossalai/pipeline/pipeline_process_group.py
@@ -0,0 +1,138 @@
+from typing import List, Dict, Tuple
+import os
+
+from torch.distributed import rpc
+import torch.distributed as dist
+
+from colossalai.tensor import ProcessGroup
+
+
+class PipelineProcessGroup:
+    # TODO: flexible API for DP size and TP size
+    # In the future design, dp_degree and tp_degree should be removed
+    def __init__(self,
+                 rank: int,
+                 world_size: int,
+                 dp_degree: int = 1,
+                 tp_degree: int = 1,
+                 num_worker_threads: int = 1,
+                 device: str = "cuda") -> None:
+        device_mesh_size = dp_degree * tp_degree
+        assert world_size % device_mesh_size == 0, "world_size must be a multiple of dp_degree * tp_degree"
+        self._num_worker_threads = num_worker_threads
+
+        self._device_mesh_size = device_mesh_size
+        self._rank = rank
+        self._world_size = world_size
+        self._dp_degree = dp_degree
+        self._tp_degree = tp_degree
+        self.device = device
+        self._stage_num = world_size // device_mesh_size
+        self._pp_rank = rank // device_mesh_size
+        self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)]
+        self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)]
+
+        # initialize the pp process group (rpc)
+        self._initialize_pp_process_group()
+
+        # initialize tp/dp process groups
+        self._initialize_tp_dp_process_group()
+
+        # status
+        self._is_first_pp_rank = self._pp_rank == 0
+        self._is_last_pp_rank = self._pp_rank == self._stage_num - 1
+
+    def _initialize_process_group(self):
+        stage_num = self.get_stage_num()
+        if stage_num == 1:
+            return
+        device = self.device
+        world_size = self.get_world_size()
+        rank = self.get_global_rank()
+        backend = 'nccl' if device == 'cuda' else 'gloo'
+        dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')
+
+    def _initialize_pp_process_group(self) -> None:
+        rank = self.get_global_rank()
+        world_size = self.get_world_size()
+
+        # build rpc connection
+        options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)
+
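+        # Pair our device `rank` with device `pp_rank` on the peer named
+        # f'work{pp_rank}' so TensorPipe can move CUDA tensors directly over
+        # RPC (e.g. rank 0 with _pp_ranks == [0, 2] maps 'work0' and 'work2').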
+        for pp_rank in self._pp_ranks:
+            options.set_device_map(f'work{pp_rank}', {rank: pp_rank})
+
+        rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)
+
+    def _initialize_tp_dp_process_group(self) -> None:
+        rank = self.get_global_rank()
+        local_stage_ranks = self.get_local_stage_global_ranks()
+        dp_degree = self.get_dp_degree()
+        tp_degree = self.get_tp_degree()
+        self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree)
+
+    def get_global_rank(self):
+        return self._rank
+
+    def get_world_size(self):
+        return self._world_size
+
+    def get_dp_degree(self) -> int:
+        return self._dp_degree
+
+    def get_tp_degree(self) -> int:
+        return self._tp_degree
+
+    def get_local_device_mesh_size(self) -> int:
+        return self._device_mesh_size
+
+    def get_device_mesh_num(self) -> int:
+        pass
+
+    def get_stage_num(self) -> int:
+        return self._stage_num
+
+    def is_first_stage(self) -> bool:
+        return self._is_first_pp_rank
+
+    def is_last_stage(self) -> bool:
+        return self._is_last_pp_rank
+
+    def check_pp_rank_valid(self, pp_rank: int) -> bool:
+        return -1 < pp_rank < self._stage_num
+
+    def get_local_pp_rank(self) -> int:
+        return self._pp_rank
+
+    def get_prev_pp_rank(self) -> int:
+        prev_pp_rank = self._pp_rank - 1
+        if not self.check_pp_rank_valid(prev_pp_rank):
+            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!")
+        return prev_pp_rank
+
+    def get_next_pp_rank(self) -> int:
+        next_pp_rank = self._pp_rank + 1
+        if not self.check_pp_rank_valid(next_pp_rank):
+            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!")
+        return next_pp_rank
+
+    def get_local_stage_global_ranks(self) -> List[int]:
+        return self._local_stage_ranks
+
+    def local_dp_rank(self) -> int:
+        return self._tp_dp_process_group.dp_local_rank()
+
+    def local_tp_rank(self) -> int:
+        return self._tp_dp_process_group.tp_local_rank()
+
+    def get_pp_global_ranks(self) -> List[int]:
+        return self._pp_ranks
+
+    def get_dp_global_ranks(self):
+        pass
+
+    def get_tp_global_ranks(self):
+        pass
diff --git a/tests/test_pipeline/rpc_test_utils.py b/tests/test_pipeline/rpc_test_utils.py
index bcfeb1760..cb65f3eec 100644
--- a/tests/test_pipeline/rpc_test_utils.py
+++ b/tests/test_pipeline/rpc_test_utils.py
@@ -1,13 +1,17 @@
 import os
 import argparse
+import warnings
 
 import torch
 from torch import nn
 import torch.multiprocessing as mp
 import torch.distributed.rpc as rpc
 from torch.optim import SGD, Adam, RMSprop, Optimizer
+from torch._C._distributed_rpc import _is_current_rpc_agent_set
 from colorama import Back, Style
 
+rpc_is_initialized = _is_current_rpc_agent_set
+
 
 def color_debug(text, prefix=' ', color='blue'):
     color = color.upper()
@@ -52,6 +56,19 @@ def parse_args():
     return parser.parse_args()
 
 
+def pg_parse_args():
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--world_size', type=int, default=4)
+    parser.add_argument('--dp_degree', type=int, default=2)
+    parser.add_argument('--tp_degree', type=int, default=1)
+    parser.add_argument('--chunk', type=int, default=1)
+    parser.add_argument('--num_worker_threads', type=int, default=128)
+    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
+    parser.add_argument('--master_addr', type=str, default='localhost')
+    parser.add_argument('--master_port', type=str, default='29020')
+    return parser.parse_args()
+
+
 def run_worker(rank, args, master_func):
     os.environ['MASTER_ADDR'] = args.master_addr
     os.environ['MASTER_PORT'] = args.master_port
@@ -71,7 +88,10 @@ def run_worker(rank, args, master_func):
     if rank == 0:
         master_func(args)
     # barrier here
-    rpc.shutdown()
+    if rpc_is_initialized():
+        rpc.shutdown()
+    else:
+        warnings.warn("RPC has not been initialized")
 
 
 def rpc_run(args, master_func):
diff --git a/tests/test_pipeline/test_pipeline_process_group.py b/tests/test_pipeline/test_pipeline_process_group.py
new file mode 100644
index 000000000..c0aff8c10
--- /dev/null
+++ b/tests/test_pipeline/test_pipeline_process_group.py
@@ -0,0 +1,43 @@
+import os
+
+import torch.distributed.rpc as rpc
+import torch.multiprocessing as mp
+import pytest
+
+from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
+from colossalai.initialize import launch
+from colossalai.logging import disable_existing_loggers
+from rpc_test_utils import pg_parse_args, rpc_is_initialized
+
+
+def run_worker(rank, args):
+    os.environ['MASTER_ADDR'] = args.master_addr
+    os.environ['MASTER_PORT'] = args.master_port
+
+    device = args.device
+    world_size = args.world_size
+    dp_degree = args.dp_degree
+    tp_degree = args.tp_degree
+    num_worker_threads = args.num_worker_threads
+    host = args.master_addr
+    port = args.master_port
+    backend = 'nccl' if device == 'cuda' else 'gloo'
+
+    disable_existing_loggers()
+    launch(dict(), rank, world_size, host, int(port), backend, verbose=False)
+
+    pg = PipelineProcessGroup(rank=rank,
+                              world_size=world_size,
+                              dp_degree=dp_degree,
+                              tp_degree=tp_degree,
+                              num_worker_threads=num_worker_threads,
+                              device=device)
+
+    if rpc_is_initialized():
+        rpc.shutdown()
+
+
+if __name__ == "__main__":
+    args = pg_parse_args()
+    world_size = args.world_size
+    mp.spawn(run_worker, args=(args,), nprocs=world_size)
\ No newline at end of file
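
A worked example of the rank arithmetic above, using the defaults from
pg_parse_args() (world_size=4, dp_degree=2, tp_degree=1); standalone Python,
runnable without the patch:

    world_size, dp_degree, tp_degree = 4, 2, 1
    device_mesh_size = dp_degree * tp_degree      # 2 ranks per device mesh
    stage_num = world_size // device_mesh_size    # 2 pipeline stages
    for rank in range(world_size):
        pp_rank = rank // device_mesh_size
        pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size
                    for i in range(stage_num)]
        print(rank, pp_rank, pp_ranks)
    # rank 0 -> pp_rank 0, pp_ranks [0, 2]; rank 3 -> pp_rank 1, pp_ranks [1, 3]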