[pipeline/pipeline_process_group] finish PipelineProcessGroup to manage local and global rank in TP, DP and PP (#1508)

* support p2p communication with any type of object | pass test

* reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test

* [engin/schedule] use p2p_v2 to recontruct pipeline_schedule

* [pipeline/rpc] implement a demo for PP with cuda rpc framework

* [pipeline/rpc] support interleaving | fix checkpoint bug | change logic when dispatch data in work_list to ensure steady 1F1B

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] implement distributed optimizer | test with assert_close

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/rpc] update outstanding mechanism | optimize dispatching strategy

* [pipeline/pipeline_process_group] finish PipelineProcessGroup to manage local and global rank in TP, DP and PP

* [pipeline/pipleline_process_group] remove comment

* [pipeline/pipleline_process_group] remove comment

* [pipeline/pipleline_process_group] skip process group test

* [pipeline/pipleline_process_group] remove test named function
pull/1533/head
Kirigaya Kazuto 2022-09-01 17:45:47 +08:00 committed by GitHub
parent 8a29ce5443
commit f1e1836218
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 199 additions and 1 deletions

View File

@ -0,0 +1,135 @@
from typing import List, Dict, Tuple
import os
from torch.distributed import rpc
import torch.distributed as dist
from colossalai.tensor import ProcessGroup
class PipelineProcessGroup:
    """Manages the local and global ranks of a worker across the TP, DP and PP
    parallel dimensions.

    The world is viewed as ``stage_num`` pipeline stages, each stage being a
    device mesh of size ``dp_degree * tp_degree``. Global ranks inside one
    stage are contiguous; ranks belonging to the same pipeline are spaced
    ``device_mesh_size`` apart.
    """

    # TODO : flexible API for DP size and TP size
    # In the future design mode, dp_degree and tp_degree should be removed
    def __init__(self,
                 rank: int,
                 world_size: int,
                 dp_degree: int = 1,
                 tp_degree: int = 1,
                 num_worker_threads: int = 1,
                 device: str = "cuda") -> None:
        device_mesh_size = dp_degree * tp_degree
        assert world_size % device_mesh_size == 0, "world_size must be the multiple of dp_degree * tp_degree !!!"
        self._num_worker_threads = num_worker_threads
        self._device_mesh_size = device_mesh_size
        self._rank = rank
        self._world_size = world_size
        self._dp_degree = dp_degree
        self._tp_degree = tp_degree
        self.device = device
        self._stage_num = world_size // device_mesh_size
        # pipeline-stage index of this rank
        self._pp_rank = rank // device_mesh_size
        # global ranks that share this rank's pipeline (one per stage)
        self._pp_ranks = [(rank % device_mesh_size) + i * device_mesh_size for i in range(self._stage_num)]
        # global ranks that share this rank's stage (the local device mesh)
        self._local_stage_ranks = [(rank // device_mesh_size * device_mesh_size) + i for i in range(device_mesh_size)]

        # pp_ranks
        self._initialize_pp_process_group()

        # initialise tp dp process groups
        self._initialize_tp_dp_process_group()

        # status
        self._is_first_pp_rank = self._pp_rank == 0
        self._is_last_pp_rank = self._pp_rank == self._stage_num - 1

    def _initialize_process_group(self):
        # NOTE(review): not called from __init__ — presumably kept for callers
        # that need the main collective group set up explicitly.
        stage_num = self.get_stage_num()
        if stage_num == 1:
            return
        device = self.device
        world_size = self.get_world_size()
        rank = self.get_global_rank()
        backend = 'nccl' if device == 'cuda' else 'gloo'
        dist.init_process_group(backend, world_size=world_size, rank=rank, group_name='main_group')

    def _initialize_pp_process_group(self) -> None:
        """Initialize torch RPC for P2P communication between pipeline stages."""
        rank = self.get_global_rank()
        world_size = self.get_world_size()

        # build rpc connection
        options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=self._num_worker_threads)

        # map this rank's devices to each peer in the same pipeline
        for pp_rank in self._pp_ranks:
            options.set_device_map(f'work{pp_rank}', {rank: pp_rank})

        rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options)

    def _initialize_tp_dp_process_group(self) -> None:
        """Create the TP/DP process group over the ranks of this local stage."""
        rank = self.get_global_rank()
        local_stage_ranks = self.get_local_stage_global_ranks()
        dp_degree = self.get_dp_degree()
        tp_degree = self.get_tp_degree()
        self._tp_dp_process_group = ProcessGroup(rank, local_stage_ranks, tp_degree, dp_degree)

    def get_global_rank(self) -> int:
        return self._rank

    def get_world_size(self) -> int:
        return self._world_size

    def get_dp_degree(self) -> int:
        return self._dp_degree

    def get_tp_degree(self) -> int:
        return self._tp_degree

    def get_local_device_mesh_size(self) -> int:
        return self._device_mesh_size

    def get_device_mesh_num(self) -> int:
        # TODO: not implemented yet
        pass

    def get_stage_num(self) -> int:
        return self._stage_num

    def is_first_stage(self) -> bool:
        return self._is_first_pp_rank

    def is_last_stage(self) -> bool:
        return self._is_last_pp_rank

    def check_pp_rank_valid(self, pp_rank: int) -> bool:
        """Return True iff ``pp_rank`` is a valid stage index in [0, stage_num)."""
        return -1 < pp_rank < self._stage_num

    def get_local_pp_rank(self) -> int:
        return self._pp_rank

    def get_prev_pp_rank(self) -> int:
        """Return the stage index before this one.

        Raises:
            ValueError: if this rank is on the first stage.
        """
        prev_pp_rank = self._pp_rank - 1
        if not self.check_pp_rank_valid(prev_pp_rank):
            # was `assert ValueError(...)`: asserting a (truthy) exception
            # instance never fails, so the error was silently swallowed
            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a previous stage!")
        return prev_pp_rank

    def get_next_pp_rank(self) -> int:
        """Return the stage index after this one.

        Raises:
            ValueError: if this rank is on the last stage.
        """
        next_pp_rank = self._pp_rank + 1
        if not self.check_pp_rank_valid(next_pp_rank):
            # was `assert ValueError(...)`: same silent-pass bug as above
            raise ValueError(f"current rank's pp_rank: {self._pp_rank} doesn't have a next stage!")
        return next_pp_rank

    def get_local_stage_global_ranks(self) -> List[int]:
        return self._local_stage_ranks

    def local_dp_rank(self) -> int:
        return self._tp_dp_process_group.dp_local_rank()

    def local_tp_rank(self) -> int:
        return self._tp_dp_process_group.tp_local_rank()

    def get_pp_global_ranks(self) -> List[int]:
        # annotation fixed: this returns the list of pipeline peer ranks
        return self._pp_ranks

    def get_dp_global_ranks(self):
        # TODO: not implemented yet
        pass

    def get_tp_global_ranks(self):
        # TODO: not implemented yet
        pass

View File

@ -1,13 +1,17 @@
import os import os
import argparse import argparse
import warnings
import torch import torch
from torch import nn from torch import nn
import torch.multiprocessing as mp import torch.multiprocessing as mp
import torch.distributed.rpc as rpc import torch.distributed.rpc as rpc
from torch.optim import SGD, Adam, RMSprop, Optimizer from torch.optim import SGD, Adam, RMSprop, Optimizer
from torch._C._distributed_rpc import _is_current_rpc_agent_set
from colorama import Back, Style from colorama import Back, Style
rpc_is_initialized = _is_current_rpc_agent_set
def color_debug(text, prefix=' ', color='blue'): def color_debug(text, prefix=' ', color='blue'):
color = color.upper() color = color.upper()
@ -52,6 +56,19 @@ def parse_args():
return parser.parse_args() return parser.parse_args()
def pg_parse_args(args=None):
    """Parse command-line options for the pipeline process-group test.

    Args:
        args: optional list of argument strings; defaults to ``sys.argv``
            when omitted (added for testability, backward compatible).

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--world_size', type=int, default=4)
    parser.add_argument('--dp_degree', type=int, default=2)
    parser.add_argument('--tp_degree', type=int, default=1)
    parser.add_argument('--chunk', type=int, default=1)
    # bug fix: was declared `type=str` with an int default (128), so passing
    # the flag explicitly produced a str while omitting it produced an int
    parser.add_argument('--num_worker_threads', type=int, default=128)
    parser.add_argument('--device', type=str, choices=['cpu', 'cuda'], default='cuda')
    parser.add_argument('--master_addr', type=str, default='localhost')
    parser.add_argument('--master_port', type=str, default='29020')
    return parser.parse_args(args)
def run_worker(rank, args, master_func): def run_worker(rank, args, master_func):
os.environ['MASTER_ADDR'] = args.master_addr os.environ['MASTER_ADDR'] = args.master_addr
os.environ['MASTER_PORT'] = args.master_port os.environ['MASTER_PORT'] = args.master_port
@ -71,7 +88,10 @@ def run_worker(rank, args, master_func):
if rank == 0: if rank == 0:
master_func(args) master_func(args)
# barrier here # barrier here
if rpc_is_initialized():
rpc.shutdown() rpc.shutdown()
else:
warnings.warn("RPC has not been initialized")
def rpc_run(args, master_func): def rpc_run(args, master_func):

View File

@ -0,0 +1,43 @@
import os
import torch.distributed.rpc as rpc
import torch.multiprocessing as mp
import pytest
from colossalai.pipeline.pipeline_process_group import PipelineProcessGroup
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from rpc_test_utils import pg_parse_args, rpc_is_initialized
def run_worker(rank, args):
    """Per-process entry point: bring up the distributed runtime, build a
    PipelineProcessGroup for this rank, then tear RPC down.

    `rank` is supplied by mp.spawn; `args` is the parsed CLI namespace.
    """
    # rendezvous address must be in the environment before any init call
    os.environ['MASTER_ADDR'] = args.master_addr
    os.environ['MASTER_PORT'] = args.master_port

    collective_backend = 'nccl' if args.device == 'cuda' else 'gloo'

    disable_existing_loggers()
    launch(dict(), rank, args.world_size, args.master_addr, int(args.master_port), collective_backend, verbose=False)

    pg = PipelineProcessGroup(rank=rank,
                              world_size=args.world_size,
                              dp_degree=args.dp_degree,
                              tp_degree=args.tp_degree,
                              num_worker_threads=args.num_worker_threads,
                              device=args.device)

    # PipelineProcessGroup.__init__ started RPC; shut it down if it came up
    if rpc_is_initialized():
        rpc.shutdown()
if __name__ == "__main__":
    # spawn one process per rank; mp.spawn passes the rank as first argument
    cli_args = pg_parse_args()
    mp.spawn(run_worker, args=(cli_args,), nprocs=cli_args.world_size)