From bb5f5289e0e4e316eec19fe0d86f73a065d1e37a Mon Sep 17 00:00:00 2001 From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com> Date: Mon, 22 Aug 2022 10:50:51 +0800 Subject: [PATCH] [pipeline/rpc] implement a demo for PP with cuda rpc framework (#1470) * support p2p communication with any type of object | pass test * reconstruct pipeline schedule with p2p_v2.py(support communication with List[Any]) | pass test * [engin/schedule] use p2p_v2 to recontruct pipeline_schedule * [pipeline/rpc] implement a demo for PP with cuda rpc framework * Delete p2p_v2.py * Delete _pipeline_schedule_v2.py * Delete test_object_list_p2p_v2.py * Delete test_boardcast_send_recv_v2.py * Delete test_cifar_with_data_pipeline_tensor_v2.py --- colossalai/pipeline/rpc/PipelineBase.py | 580 ++++++++++++++++++ tests/test_pipeline/test_cuda_rpc_pipeline.py | 102 +++ 2 files changed, 682 insertions(+) create mode 100644 colossalai/pipeline/rpc/PipelineBase.py create mode 100644 tests/test_pipeline/test_cuda_rpc_pipeline.py diff --git a/colossalai/pipeline/rpc/PipelineBase.py b/colossalai/pipeline/rpc/PipelineBase.py new file mode 100644 index 000000000..36a69499e --- /dev/null +++ b/colossalai/pipeline/rpc/PipelineBase.py @@ -0,0 +1,580 @@ +import threading +from enum import Enum +from typing import List, Any, Tuple, Dict +from abc import ABC, abstractmethod + +import torch +from torch import nn +import torch.distributed.rpc as rpc +from torch.futures import Future +from torch._C._distributed_rpc import PyRRef +from torch import autograd +from tqdm import tqdm + +from colorama import Back, Style + +# config for debug and test +use_color_debug = False +use_progress = False + +# TODO: +# 1. design a unique_key without node.name (Maybe I can use combination of microbatch_id and stage_id) +# 2. use waiting list to contain the uncomplete WorkItem +# 3. think about the representation of the order of args and kwargs + + +def color_debug(text, prefix=' ', color='blue'): + if use_color_debug: + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + + +def tensor_shape_list(tensors): + if isinstance(tensors, torch.Tensor): + return tensors.shape + shapes = [] + for t in tensors: + if hasattr(t, 'shape'): + shapes.append(t.shape) + else: + shapes.append('non tensor') + return shapes + + +class Phase(Enum): + FORWARD = 0 + BACKWARD = 1 + ACCUM_GRAD = 2 + SYNC = 3 + + +class UniqueKey: + __slots__ = ('microbatch_id', 'phase') + microbatch_id: int + phase: Phase + + def __init__(self, microbatch_id, phase) -> None: + self.microbatch_id = microbatch_id + self.phase = phase + + def __eq__(self, __o: object) -> bool: + return (self.microbatch_id == __o.microbatch_id) and (self.phase == __o.phase) + + def __hash__(self) -> int: + return tuple.__hash__((self.microbatch_id, self.phase)) + + def __repr__(self) -> str: + return f'Key(microbatch_id={self.microbatch_id}, phase={self.phase})' + + +class WorkItem: + __slots__ = ('stage_id', 'phase', 'args', 'kwargs', 'output', 'refcount', 'microbatch_id', 'batch_id', + 'num_microbatches') + + stage_id: int + phase: Phase + args: Tuple[Any] + kwargs: Dict[str, Any] + output: Future + microbatch_id: int + + refcount: int + + batch_id: int + num_microbatches: int + + def __init__(self, + stage_id, + phase, + args, + kwargs, + output, + microbatch_id, + batch_id, + num_microbatches, + refcount=0) -> None: + for attr_name in self.__slots__: + setattr(self, attr_name, locals()[attr_name]) + + +class BackwardCache: + __slots__ = ('checkpoint', 'stage_inputs', 'stage_outputs') + checkpoint: bool + stage_inputs: Tuple[Any] + stage_outputs: Tuple[Any] + + def __init__(self, + stage_inputs: List[torch.Tensor], + stage_outputs: List[torch.Tensor] = None, + checkpoint: bool = False) -> None: + for arg_name in self.__slots__: + setattr(self, arg_name, locals()[arg_name]) + + +class RemoteExecutor: + + def __init__(self) -> None: + pass + + +class RemoteOptimizer: + + def __init__(self) -> None: + pass + + +class Worker: + + def __init__(self, + cur_rank_module: nn.Module, + rank: int, + world_size: int, + num_microbatches: int, + max_outstanding: int, + device: str, + checkpoint: bool = False) -> None: + super().__init__() + self.rank = rank + self.world_size = world_size + self.num_microbatches = num_microbatches + self.max_outstanding = max_outstanding + self.outstanding = 0 + self.checkpoint = checkpoint + + if device == 'cuda': + device = f'cuda:{rank}' + self.device = device + + self.future_devices = None if device is None or device == 'cpu' else [device] + + self.stage_to_worker_rref: Dict[int, PyRRef] = None + self.producer_stage_ids: List[int] = None + self.consumer_stage_ids: List[int] = None + + # module + self.cur_rank_module = cur_rank_module.to(device) + + self.microbatch_id_to_backward_cache: Dict[int, BackwardCache] = dict() + + self.work_list: Dict[UniqueKey, WorkItem] = dict() + self.output_list: Dict[UniqueKey, WorkItem] = dict() + + # Why must a Lock instead of RLock ? + # Because RLock cannot be pickled + self.work_list_condition_lock = threading.Condition(threading.Lock()) + self.output_list_condition_lock = threading.Condition(threading.Lock()) + + self.main_loop_thread = threading.Thread(target=self._work_loop, name=f'rank_{rank}', daemon=True) + self.main_loop_thread.start() + + def _get_future_by_device(self): + return torch.futures.Future(devices=None if self.device in (None, 'cpu') else [self.device]) + + def sync_global_worker_rrefs(self, stage_to_worker_rref: Dict[int, PyRRef]) -> None: + assert self.stage_to_worker_rref is None, f"in rank {self.rank}, worker has sync global workers rrefs" + assert stage_to_worker_rref is not None, "stage_to_workers must be a dict instead of None" + self.stage_to_worker_rref = stage_to_worker_rref + + def get_output_by_key(self, key: UniqueKey) -> Any: + with self.output_list_condition_lock: + while key not in self.output_list: + self.output_list_condition_lock.wait() + + output_work_item = self.output_list[key] + + output = output_work_item.output.wait() + # color_debug(f'rank {self.rank}, output {type(output)}', 'get output', 'red') + output_work_item.refcount += 1 + + # all consumers have been satisfied, the work_item can be released + with self.output_list_condition_lock: + if output_work_item.refcount == len(self.consumer_stage_ids): + self.output_list.pop(key) + + return output + + # just for first rank + # TODO : input is args kwargs + def set_input(self, microbatch_id: int, microbatch: Tuple[Any]): + with self.work_list_condition_lock: + assert self.consumer_stage_ids is not None + consumer_num = len(self.consumer_stage_ids) + key = UniqueKey(microbatch_id, Phase.FORWARD) + output = self._get_future_by_device() + args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch + + work_item = WorkItem(self.rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches, + consumer_num) + self.work_list[key] = work_item + + color_debug(f'rank {self.rank} receive data from dataloader', 'data dispatch', 'magenta') + + self.work_list_condition_lock.notify_all() + + # just for last rank + # TODO : write a function to add gradient to work_list and see if there is contradictory + def _begin_backward(self, microbatch_id: int): + with self.work_list_condition_lock: + assert self.producer_stage_ids is not None + producer_num = len(self.producer_stage_ids) + key = UniqueKey(microbatch_id, Phase.BACKWARD) + output = self._get_future_by_device() + grad_wrt_loss = torch.tensor(1, device=self.device) + + work_item = WorkItem(self.rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, + self.num_microbatches, producer_num) + + color_debug(f'rank {self.rank} propose backward', 'data dispatch', 'magenta') + + self.work_list[key] = work_item + self.work_list_condition_lock.notify_all() + + def subscribe_producer(self, microbatch_id: int): + """ + You should call this function asynchronously + """ + assert self.producer_stage_ids is not None + producer_num = len(self.producer_stage_ids) + consumer_num = len(self.consumer_stage_ids) + assert producer_num > 0, "only stage that has producers can subscribe producers" + + stage_id = self.rank + + subscribe_forward_futures: List[Future] = [None] * producer_num + output = self._get_future_by_device() + + for i in range(producer_num): + producer_stage_id = self.producer_stage_ids[i] + producer_output_key = UniqueKey(microbatch_id, Phase.FORWARD) + producer_worker_rref = self.stage_to_worker_rref[producer_stage_id] + subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) + + color_debug(f'rank {self.rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch', + 'magenta') + + args = [] + for i in range(producer_num): + producer_args = subscribe_forward_futures[i].wait() + args.extend(producer_args) + + # TODO : not only args + work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, args, {}, output, microbatch_id, None, + self.num_microbatches, consumer_num) + + color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') + # add work_item to work_list + with self.work_list_condition_lock: + key = UniqueKey(microbatch_id, Phase.FORWARD) + assert key not in self.work_list + self.work_list[key] = work_item_from_producer + color_debug( + f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', + 'data dispatch', 'magenta') + self.work_list_condition_lock.notify_all() + + def subscribe_consumer(self, microbatch_id: int): + """ + You should call this function asynchronously + """ + assert self.producer_stage_ids is not None + producer_num = len(self.producer_stage_ids) + consumer_num = len(self.consumer_stage_ids) + assert consumer_num > 0, "only stage that has consumers can subscribe comsumers" + + # TODO : is this right? + stage_id = self.rank + + subscribe_backward_futures: List[Future] = [None] * consumer_num + output = self._get_future_by_device() + + color_debug(f'rank {self.rank} get {len(subscribe_backward_futures)} futs from its consumer', 'data dispatch', + 'magenta') + + for i in range(consumer_num): + consumer_stage_id = self.consumer_stage_ids[i] + consumer_output_key = UniqueKey(microbatch_id, Phase.BACKWARD) + consumer_worker_rref = self.stage_to_worker_rref[consumer_stage_id] + subscribe_backward_futures[i] = consumer_worker_rref.rpc_async().get_output_by_key(consumer_output_key) + + args = [] + for i in range(consumer_num): + consumer_args = subscribe_backward_futures[i].wait() + args.extend(consumer_args) + + # flatten args + work_item_from_consumer = WorkItem(stage_id, Phase.BACKWARD, args, {}, output, microbatch_id, None, + self.num_microbatches, producer_num) + + color_debug(f'rank {self.rank} get value {tensor_shape_list(args)} from fut', 'data dispatch', 'magenta') + + # add work_item to work_list + with self.work_list_condition_lock: + key = UniqueKey(microbatch_id, Phase.BACKWARD) + assert key not in self.work_list + self.work_list[key] = work_item_from_consumer + color_debug( + f'rank_{self.rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', + 'data dispatch', 'magenta') + self.work_list_condition_lock.notify_all() + + # TODO : fit in any type of partition of network + def _get_producer_consumer(self) -> None: + rank = self.rank + assert self.producer_stage_ids is None, f"all the producers of rank {rank} has been subscribed" + assert self.consumer_stage_ids is None, f"all the consumers of rank {rank} has been subscribed" + + # should be aranged in order, the order of the input of current forward + self.producer_stage_ids = [] + self.consumer_stage_ids = [] + + # Just for demo + prev_rank = rank - 1 + next_rank = rank + 1 + if prev_rank >= 0: + self.producer_stage_ids.append(prev_rank) + if next_rank <= self.world_size - 1: + self.consumer_stage_ids.append(next_rank) + + def _skip_forward(self, work_item_phase: Phase) -> bool: + if work_item_phase == Phase.FORWARD and \ + self.max_outstanding is not None and \ + self.outstanding >= self.max_outstanding: + return True + return False + + def _get_work_item_key(self) -> UniqueKey: + with self.work_list_condition_lock: + while len(self.work_list) == 0: + self.work_list_condition_lock.wait() + + # execute backward first (if backward phase in work_list) + + select_work_list_key = None + for key in self.work_list: + work_item = self.work_list[key] + + if work_item.phase == Phase.BACKWARD: + return key + + if self._skip_forward(work_item.phase): + continue + else: + select_work_list_key = key + return select_work_list_key + + def _consume_work_item_by_phase(self, work_item: WorkItem): + phase = work_item.phase + args = work_item.args + kwargs = work_item.kwargs + microbatch_id = work_item.microbatch_id + consume_result = None + + # color_debug(f'rank_{self.rank} enter consume', 'consume', 'blue') + + if phase == Phase.FORWARD: + self.outstanding += 1 + + # TODO : more elegant ? + for i in range(len(args)): + arg_obj = args[i] + if isinstance(arg_obj, torch.Tensor) and not arg_obj.requires_grad: + args[i] = arg_obj.requires_grad_() + + # TODO : use process manager to acquire rank info later + is_last_stage = len(self.consumer_stage_ids) == 0 + + if self.checkpoint and not is_last_stage: + with torch.no_grad(): + consume_result = self.cur_rank_module(*args, **kwargs) + stage_outputs = None + stage_inputs = args + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, + stage_outputs, + checkpoint=True) + else: + # TODO : replace with *args, **kwargs and ensure the consume_result is a tuple + consume_result = self.cur_rank_module(*args, **kwargs) + stage_outputs = consume_result + stage_inputs = args + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, + stage_outputs, + checkpoint=False) + + consume_result = [consume_result] if isinstance(consume_result, torch.Tensor) else consume_result + + # if it is the last stage, trigger backward automatic + if is_last_stage: + self._begin_backward(microbatch_id) + + elif phase == Phase.BACKWARD: + self.outstanding -= 1 + assert microbatch_id in self.microbatch_id_to_backward_cache, f"microbatch_id {microbatch_id} not in backward cache" + backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) + + stage_outputs = backward_cache.stage_outputs + stage_inputs = backward_cache.stage_inputs + grad_tensors = args + + # color_debug(f'rank_{self.rank} before backward', 'consume', 'yellow') + + if self.checkpoint: + stage_outputs = [self.cur_rank_module(*stage_inputs)] + + autograd.backward(stage_outputs, grad_tensors=grad_tensors) + + # color_debug(f'rank_{self.rank} after backward', 'consume', 'yellow') + + # collect grad of input tensor + consume_result = [] + for input_node in stage_inputs: + if isinstance(input_node, torch.Tensor): + consume_result.append(input_node.grad) + + elif phase == Phase.SYNC: + pass + else: + raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") + + return consume_result + + # do the main loop to consume ready_list + def _work_loop(self): + # for init + self._get_producer_consumer() + + # main loop + while True: + work_item_key = self._get_work_item_key() + if work_item_key is None: + continue + + # move current work item to output_list to activate subscribe in advance + with self.work_list_condition_lock: + work_item = self.work_list.pop(work_item_key) + + color_debug( + f'rank {self.rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)}', + 'work loop', 'green') + + with self.output_list_condition_lock: + # assert work_item_key not in self.output_list + self.output_list[work_item_key] = work_item + self.output_list_condition_lock.notify_all() + + consume_result = self._consume_work_item_by_phase(work_item) + + color_debug( + f'rank_{self.rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)}', + 'work loop', 'green') + # if work_item.stage_id == 1 and work_item.phase == Phase.BACKWARD: + # from time import sleep + # sleep(5) + work_item.output.set_result(consume_result) + + +# TODO +# 1. chunk +# 2. checkpoint +class PipelineEngineBase(ABC, nn.Module): + + def __init__(self, + module_partitions, + chunk, + world_size, + num_microbatches, + device: str, + max_outstanding=None, + use_interleave: bool = False, + checkpoint: bool = False) -> None: + super().__init__() + self.module_partitions: List[nn.Module] = module_partitions + self.chunk = chunk + self.num_microbatches = num_microbatches + self.device = device + self.max_outstanding = max_outstanding + self.world_size = world_size + self.checkpoint = checkpoint + self.use_interleave = use_interleave + + self.stage_to_worker_rref: Dict[int, PyRRef] = dict() + self._init_worker() + + def _init_worker(self): + world_size = self.world_size + max_outstanding = self.max_outstanding + checkpoint = self.checkpoint + num_microbatches = self.num_microbatches + device = self.device + + # TODO : world size is correct ? + for rank in range(world_size): + cur_rank_module = self.module_partitions[rank] + self.stage_to_worker_rref[rank] = rpc.remote(rank, + Worker, + args=(cur_rank_module, rank, world_size, num_microbatches, + max_outstanding, device, checkpoint)) + + # let each worker know global worker rref (include itself) + for rank in range(world_size): + self.stage_to_worker_rref[rank].rpc_sync().sync_global_worker_rrefs(self.stage_to_worker_rref) + + @abstractmethod + def forward_backward(self): + pass + + +class FillDrainPipelineEngine(PipelineEngineBase): + + def __init__(self, + module_partitions, + chunk, + world_size, + num_microbatches, + device: str, + max_outstanding=None, + use_interleave: bool = False, + checkpoint: bool = False) -> None: + super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding, + use_interleave, checkpoint) + + # TODO : adjust to args and kwargs + def forward_backward(self, batch: torch.Tensor): + first_stage_worker = self.stage_to_worker_rref[0] + microbatch_size = len(batch) // self.num_microbatches + + microbatch_iter = range(self.num_microbatches) + if use_progress: + microbatch_iter = tqdm(microbatch_iter) + + for microbatch_id in microbatch_iter: + microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] + + # forward subscribe asynchronously + for rank in range(1, self.world_size, 1): + worker_rref = self.stage_to_worker_rref[rank] + worker_rref.rpc_async().subscribe_producer(microbatch_id) + + # backward subscribe asynchronously + for rank in range(self.world_size - 2, -1, -1): + worker_rref = self.stage_to_worker_rref[rank] + worker_rref.rpc_async().subscribe_consumer(microbatch_id) + + # run one microbatch + first_stage_worker.rpc_sync().set_input(microbatch_id, microbatch) + + +class OneFOneBPipelineEngine(FillDrainPipelineEngine): + + def __init__(self, + module_partitions, + chunk, + world_size, + num_microbatches, + device: str, + max_outstanding=None, + use_interleave: bool = False, + checkpoint: bool = False) -> None: + if max_outstanding is None: + max_outstanding = world_size + super().__init__(module_partitions, chunk, world_size, num_microbatches, device, max_outstanding, + use_interleave, checkpoint) diff --git a/tests/test_pipeline/test_cuda_rpc_pipeline.py b/tests/test_pipeline/test_cuda_rpc_pipeline.py new file mode 100644 index 000000000..66b76d136 --- /dev/null +++ b/tests/test_pipeline/test_cuda_rpc_pipeline.py @@ -0,0 +1,102 @@ +import os +import argparse + +import torch +from torch import nn +import torch.multiprocessing as mp +import torch.distributed.rpc as rpc + +from colossalai.pipeline.rpc.PipelineBase import FillDrainPipelineEngine, OneFOneBPipelineEngine + + +class TestModel(nn.Module): + + def __init__(self, rank, world_size, feat_num, h) -> None: + super().__init__() + self.rank = rank + self.is_last_rank = rank == world_size - 1 + self.linear_name = f'linear_{rank}' + if rank == 0: + setattr(self, self.linear_name, nn.Linear(feat_num, h)) + elif rank == world_size - 1: + setattr(self, self.linear_name, nn.Linear(h, 1)) + else: + setattr(self, self.linear_name, nn.Linear(h, h)) + + def forward(self, x) -> torch.Tensor: + linear: nn.Module = getattr(self, self.linear_name) + out: torch.Tensor = linear(x) + + if self.is_last_rank: + out = out.sum() + return out + + +def run_main(args): + torch.manual_seed(100) + + sample_num = 128 + feat_num = 10000 + h = 10000 + device = args.device + world_size = args.world_size + batch_size = 128 + assert sample_num % batch_size == 0 + batch_num = sample_num // batch_size + num_microbatches = world_size + + input_sample = torch.randn((sample_num, feat_num), device=device) + + module_partitions = [TestModel(rank, world_size, feat_num, h) for rank in range(world_size)] + + engine = OneFOneBPipelineEngine(module_partitions=module_partitions, + chunk=1, + world_size=world_size, + num_microbatches=num_microbatches, + device=args.device, + max_outstanding=world_size, + use_interleave=False, + checkpoint=False) + + for i in range(batch_num): + batch = input_sample[i * batch_size:(i + 1) * batch_size] + engine.forward_backward(batch) + + +def run_worker(rank, args): + os.environ['MASTER_ADDR'] = args.master_addr + os.environ['MASTER_PORT'] = args.master_port + + # config rpc + # if cuda is used, set_device_map is a must is configured + # for cuda is not supported in torch rpc by default + options = rpc.TensorPipeRpcBackendOptions(num_worker_threads=args.num_worker_threads) + + world_size = args.world_size + for rank_idx in range(world_size): + options.set_device_map(f'work{rank_idx}', {rank: rank_idx}) + + rpc.init_rpc(name=f'work{rank}', rank=rank, world_size=world_size, rpc_backend_options=options) + + # in rpc mode, only rank 0 is needed to be coded + if rank == 0: + run_main(args) + # barrier here + rpc.shutdown() + + +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument('--world_size', type=int, default=2) + parser.add_argument('--device', type=str, default='cuda') + parser.add_argument('--master_addr', type=str, default='localhost') + parser.add_argument('--master_port', type=str, default='29020') + parser.add_argument('--num_worker_threads', type=str, default=128) + return parser.parse_args() + + +if __name__ == "__main__": + args = parse_args() + world_size = args.world_size + assert args.device in ['cpu', 'cuda'], "device must be cpu or cuda!" + mp.spawn(run_worker, args=(args,), nprocs=world_size)