From 9708638ded08b0afad9291641cb2869ab7b2fe15 Mon Sep 17 00:00:00 2001 From: Kirigaya Kazuto <59416203+LSTM-Kirigaya@users.noreply.github.com> Date: Thu, 29 Sep 2022 10:58:58 +0800 Subject: [PATCH] [pipeline/pytree] add pytree to process args and kwargs | provide `data_process_func` to process args and kwargs after forward (#1642) * [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing * [pipeline/pytree] add pytree to process args and kwargs | provide to process args and kwargs after forward --- colossalai/pipeline/rpc/__init__.py | 3 +- colossalai/pipeline/rpc/_pipeline_base.py | 269 ++++++++++-------- colossalai/pipeline/rpc/_pipeline_schedule.py | 20 +- colossalai/pipeline/rpc/utils.py | 74 +++++ tests/test_pipeline/test_cuda_rpc_chimera.py | 7 +- 5 files changed, 247 insertions(+), 126 deletions(-) create mode 100644 colossalai/pipeline/rpc/utils.py diff --git a/colossalai/pipeline/rpc/__init__.py b/colossalai/pipeline/rpc/__init__.py index 5e0726456..9d9e9d44f 100644 --- a/colossalai/pipeline/rpc/__init__.py +++ b/colossalai/pipeline/rpc/__init__.py @@ -1,3 +1,4 @@ from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine +from .utils import pytree_map -__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine'] \ No newline at end of file +__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map'] \ No newline at end of file diff --git a/colossalai/pipeline/rpc/_pipeline_base.py b/colossalai/pipeline/rpc/_pipeline_base.py index 58071dc26..16c2c95dc 100644 --- a/colossalai/pipeline/rpc/_pipeline_base.py +++ b/colossalai/pipeline/rpc/_pipeline_base.py @@ -1,9 +1,11 @@ import threading from enum import Enum from typing import List, Any, Tuple, Dict, Callable +from functools import partial from abc import ABC, abstractmethod import sys import os +import inspect import torch from torch import nn @@ -12,57 +14,10 @@ from torch.futures import Future from torch._C._distributed_rpc import PyRRef from torch import autograd from torch import optim -from tqdm import tqdm -from time import time -from colorama import Back, Style - -# config for debug and test -use_color_debug = True - -# TODO: -# 1. 
adjust to args and kwargs (pytree) - - -def color_debug(text, prefix=' ', color='blue'): - if use_color_debug: - color = color.upper() - print(getattr(Back, color), prefix, Style.RESET_ALL, text) - - -def tensor_shape_list(tensors): - if tensors is None: - return None - if isinstance(tensors, (int, float)): - return tensors - if isinstance(tensors, torch.Tensor): - return tensors.shape - shapes = [] - for t in tensors: - if hasattr(t, 'shape'): - shapes.append(t.shape) - else: - shapes.append('non tensor') - return shapes - - -def get_real_args(args): - if isinstance(args, torch.Tensor): - return args - elif isinstance(args, list): - real_args = [] - for arg in args: - if isinstance(arg, Future): - value = arg.wait() - else: - value = arg - if isinstance(value, list): - real_args.extend(value) - else: - real_args.append(value) - return real_args - else: - raise TypeError(f"Expect receive tensor or list, but receive {type(args)}") +from colossalai.pipeline.pipeline_process_group import ppg +from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail, + pytree_map, get_real_args_kwargs, use_color_debug) class Phase(Enum): @@ -100,9 +55,7 @@ class WorkItem: kwargs: Dict[str, Any] output: Future microbatch_id: int - refcount: int - batch_id: int num_microbatches: int forward_only: bool @@ -123,14 +76,16 @@ class WorkItem: class BackwardCache: - __slots__ = ('checkpoint', 'stage_inputs', 'stage_outputs') + __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs') checkpoint: bool - stage_inputs: Tuple[Any] + stage_input_args: Tuple[Any] + stage_input_kwargs: Dict[Any, Any] stage_outputs: Tuple[Any] def __init__(self, - stage_inputs: List[torch.Tensor], - stage_outputs: List[torch.Tensor] = None, + stage_input_args: Tuple[Any], + stage_input_kwargs: Dict[Any, Any] = None, + stage_outputs: Tuple[Any] = None, checkpoint: bool = False) -> None: for arg_name in self.__slots__: setattr(self, arg_name, locals()[arg_name]) @@ -147,13 +102,18 @@ class WorkerBase(ABC): device: str, criterion: Callable = None, metric: Callable = None, - checkpoint: bool = False) -> None: + checkpoint: bool = False, + data_process_func: Callable = None) -> None: super().__init__() self.pp_rank = pp_rank self.actual_stage_num = actual_stage_num self.num_microbatches = num_microbatches self.checkpoint = checkpoint + + if data_process_func is not None: + self.data_process_func = partial(data_process_func, pp_rank) + self.device = device self._initialize_outstanding_range() @@ -260,18 +220,39 @@ class WorkerBase(ABC): self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition')) return self.module_partition.state_dict() + def _make_args_kwargs(self, microbatch): + if isinstance(microbatch, dict): + return [], microbatch + elif isinstance(microbatch, torch.Tensor): + return [microbatch], {} + elif isinstance(microbatch, (tuple, list)): + args = [] + kwargs = {} + for arg in microbatch: + if isinstance(arg, dict): + kwargs.update(arg) + else: + args.append(arg) + return args, kwargs + else: + raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}") + # just for first pp_rank def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool): assert self.consumer_stage_ids is not None key = UniqueKey(microbatch_id, Phase.FORWARD) output = self._get_future_by_device() - args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch - work_item = 
WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches, - forward_only) + + # make args and kwargs + args, kwargs = self._make_args_kwargs(microbatch) + + work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None, + self.num_microbatches, forward_only) with self.work_list_condition_lock: self.work_list[key] = work_item - color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch', - 'magenta') + if use_color_debug: + color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', + 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() # just for last pp_rank @@ -287,12 +268,13 @@ class WorkerBase(ABC): key = UniqueKey(microbatch_id, Phase.BACKWARD) output = self._get_future_by_device() - grad_wrt_loss = torch.tensor(1, device=self.device) + grad_wrt_loss = None work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None, self.num_microbatches, False) - color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') + if use_color_debug: + color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta') self.work_list[key] = work_item self.work_list_condition_lock.notify_all() @@ -315,8 +297,9 @@ class WorkerBase(ABC): producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id] subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key) - color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch', - 'magenta') + if use_color_debug: + color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', + 'data dispatch', 'magenta') work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output, microbatch_id, None, self.num_microbatches, forward_only) @@ -327,9 +310,10 @@ class WorkerBase(ABC): key = UniqueKey(microbatch_id, Phase.FORWARD) assert key not in self.work_list self.work_list[key] = work_item_from_producer - color_debug( - f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', - 'data dispatch', 'magenta') + if use_color_debug: + color_debug( + f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}', + 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() def subscribe_consumer(self, microbatch_id: int): @@ -344,8 +328,9 @@ class WorkerBase(ABC): subscribe_backward_futures: List[Future] = [None] * consumer_num output = self._get_future_by_device() - color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer', - 'data dispatch', 'magenta') + if use_color_debug: + color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer', + 'data dispatch', 'magenta') for i in range(consumer_num): consumer_stage_id = self.consumer_stage_ids[i] @@ -364,9 +349,10 @@ class WorkerBase(ABC): key = UniqueKey(microbatch_id, Phase.BACKWARD) assert key not in self.work_list self.work_list[key] = work_item_from_consumer - color_debug( - f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', - 'data dispatch', 'magenta') + if use_color_debug: + 
color_debug( + f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}', + 'data dispatch', 'magenta') self.work_list_condition_lock.notify_all() def _get_producer_consumer(self) -> None: @@ -398,12 +384,23 @@ class WorkerBase(ABC): def is_last_stage(self): return self.pp_rank == self.actual_stage_num - 1 + def _default_data_process_func(self, args_kwargs): + if self.is_first_stage(): + args = args_kwargs[0] + kwargs = args_kwargs[1] + else: + args = args_kwargs + kwargs = {} + + return args, kwargs + def _consume_work_item_by_phase(self, work_item: WorkItem): phase = work_item.phase args = work_item.args kwargs = work_item.kwargs microbatch_id = work_item.microbatch_id forward_only = work_item.forward_only + data_process_func = getattr(self, 'data_process_func', self._default_data_process_func) consume_result = None is_first_stage = self.is_first_stage() @@ -420,18 +417,31 @@ class WorkerBase(ABC): for stage_id in self.consumer_stage_ids: consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id] consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only) - self.forward_times += 1 + # sustain pipeline context + self.forward_times += 1 if not forward_only: self.outstanding += 1 - args = get_real_args(args) - # last stage doesn't need to do checkpoint, for it will do backward instantly + # parse and integrate args and kwargs + if is_first_stage: + args = get_real_args_kwargs(args) + kwargs = get_real_args_kwargs(kwargs) + args_kwargs = (args, kwargs) + else: + args_kwargs = get_real_args_kwargs(args) + + args, kwargs = data_process_func(args_kwargs) + + stage_outputs = None + stage_input_args = args + stage_input_kwargs = kwargs + use_checkpoint = None + if forward_only: with torch.no_grad(): consume_result = self.module_partition(*args, **kwargs) - # TODO : integrate output list if is_last_stage and self.criterion: with self.label_lock: self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels) @@ -445,15 +455,18 @@ class WorkerBase(ABC): metric_result = None consume_result = [loss.item(), metric_result] - stage_outputs = None - stage_inputs = None - use_checkpoint = None + # last stage doesn't need to do checkpoint, for it will do backward instantly + stage_input_args = None + stage_input_kwargs = None + stage_outputs = consume_result + elif self.checkpoint and not is_last_stage: with torch.no_grad(): consume_result = self.module_partition(*args, **kwargs) - stage_outputs = None - stage_inputs = args + + stage_outputs = consume_result use_checkpoint = True + else: consume_result = self.module_partition(*args, **kwargs) # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', ) @@ -475,17 +488,14 @@ class WorkerBase(ABC): loss = consume_result stage_outputs = loss - stage_inputs = args use_checkpoint = False if not forward_only: - self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs, + self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args, + stage_input_kwargs, stage_outputs, checkpoint=use_checkpoint) - consume_result = [consume_result] if isinstance(consume_result, - (torch.Tensor, int, float)) else consume_result - # if not forward_only, do the backward if not forward_only: if is_last_stage: # if it is the last stage, trigger backward automatic @@ -504,23 
+514,34 @@ class WorkerBase(ABC): backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id) stage_outputs = backward_cache.stage_outputs - stage_inputs = backward_cache.stage_inputs + stage_input_args = backward_cache.stage_input_args + stage_input_kwargs = backward_cache.stage_input_kwargs use_checkpoint = backward_cache.checkpoint if use_checkpoint: - stage_outputs = [self.module_partition(*stage_inputs)] + stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)] + + # take tensor only (for only tensor can do backward) + stage_outputs_tensors = [] + pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor) # overlap recompute and future.wait - grad_tensors = get_real_args(args) + grad_tensors = get_real_args_kwargs(args) - autograd.backward(stage_outputs, grad_tensors=grad_tensors) + # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors)) + autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors) # collect grad of input tensor + # there is a hypothesis that node in kwargs cann't be an non-leaf node in graph + # so we don't need to save the grad of node in kwargs. consume_result = [] if not is_first_stage: - for input_node in stage_inputs: - if isinstance(input_node, torch.Tensor): - consume_result.append(input_node.grad) + pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) + pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor) + + # for input_node in stage_input_args: + # if isinstance(input_node, torch.Tensor): + # consume_result.append(input_node.grad) else: raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}") @@ -562,6 +583,7 @@ class WorkerBase(ABC): def _work_loop(self): # for init self._get_producer_consumer() + torch.cuda.set_device(ppg.get_local_pp_rank()) # main loop while True: @@ -571,9 +593,10 @@ class WorkerBase(ABC): with self.work_list_condition_lock: work_item = self.work_list.pop(work_item_key) - color_debug( - f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}', - 'work loop', 'green') + if use_color_debug: + color_debug( + f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}', + 'work loop', 'green') with self.output_list_condition_lock: # assert work_item_key not in self.output_list @@ -582,9 +605,10 @@ class WorkerBase(ABC): consume_result = self._consume_work_item_by_phase(work_item) - color_debug( - f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}', - 'work loop', 'green') + if use_color_debug: + color_debug( + f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}', + 'work loop', 'green') work_item.output.set_result(consume_result) @@ -621,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module): chunk: int = 1, criterion: Callable = None, metric: Callable = None, - checkpoint: bool = False) -> None: + checkpoint: bool = False, + data_process_func: Callable = None) -> None: super().__init__() self.worker_type = worker_type self.partition_fn: Callable = partition_fn @@ -633,6 +658,7 @@ class PipelineEngineBase(ABC, 
nn.Module): self.use_1F1B = use_1F1B self.stage_num = stage_num self.checkpoint = checkpoint + self.data_process_func = data_process_func self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict() @@ -644,9 +670,21 @@ class PipelineEngineBase(ABC, nn.Module): self._init_worker() def _check_argument(self) -> None: + # make virtual stage num self.virtual_stage_num = self.stage_num * self.chunk assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!" + # check data_process_func + data_process_func = self.data_process_func + if data_process_func is not None: + assert callable(data_process_func), "data_process_func must be a function" + assert '' not in data_process_func.__repr__(), "data_process_func must be a global function" + assert '' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression" + sig = inspect.signature(data_process_func) + assert len( + sig.parameters + ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead" + def _get_actual_stage_num(self) -> int: return self.stage_num if self.chunk == 1 else self.virtual_stage_num @@ -682,6 +720,7 @@ class PipelineEngineBase(ABC, nn.Module): metric = self.metric partition_fn = self.partition_fn chunk = self.chunk + data_process_func = self.data_process_func for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)): partition_id = self.pp_rank_to_module_partition_id[pp_rank] @@ -693,7 +732,7 @@ class PipelineEngineBase(ABC, nn.Module): worker_type, args=(partition_fn, partition_args, pp_rank, actual_stage_num, num_microbatches, device, - criterion, metric, checkpoint)) + criterion, metric, checkpoint, data_process_func)) # let each worker know global worker rref (include itself) sync_futs = [] @@ -779,20 +818,25 @@ class PipelineEngineBase(ABC, nn.Module): worker_forward_result = [None] * self.num_microbatches for microbatch_id in range(self.num_microbatches): ret = ret_future[pp_rank][microbatch_id].wait() + # TODO : more stable format + ret = [ret] if isinstance(ret, torch.Tensor) else ret worker_forward_result[microbatch_id] = ret + worker_forward_result = list(zip(*worker_forward_result)) forward_result.extend(worker_forward_result) return forward_result def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False): - if labels is not None: - assert len(batch) == len(labels) - if not forward_only: - assert hasattr(self, 'optimizer_class') + batch_lengths = get_batch_lengths(batch) + + if labels is not None and not forward_only: + assert hasattr( + self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward" num_microbatches = self.num_microbatches - microbatch_size = len(batch) // num_microbatches + microbatch_size = batch_lengths[0] // num_microbatches + device = self.device # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks' input_pp_ranks = self.get_input_pp_ranks() @@ -805,16 +849,17 @@ class PipelineEngineBase(ABC, nn.Module): # control data input speed # to prevent exceed of wait limitations self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future) + batch_start = microbatch_size * microbatch_id + batch_end = batch_start + microbatch_size # set input - microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] - microbatch = microbatch.cuda() + microbatch = split_batch(batch, batch_start, batch_end, 
device) self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only) # set labels if labels is not None: - microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] - microlabels = microlabels.cuda() + # microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)] + microlabels = split_batch(labels, batch_start, batch_end, device) self._set_labels(output_pp_ranks, microbatch_id, microlabels) # get data asynchronously diff --git a/colossalai/pipeline/rpc/_pipeline_schedule.py b/colossalai/pipeline/rpc/_pipeline_schedule.py index 523d2d807..6c4c39a73 100644 --- a/colossalai/pipeline/rpc/_pipeline_schedule.py +++ b/colossalai/pipeline/rpc/_pipeline_schedule.py @@ -44,7 +44,8 @@ class FillDrainPipelineEngine(PipelineEngineBase): chunk: int = 1, criterion: Callable = None, metric: Callable = None, - checkpoint: bool = False) -> None: + checkpoint: bool = False, + data_process_func: Callable = None) -> None: if chunk > 1: assert num_microbatches % stage_num == 0, \ @@ -52,7 +53,7 @@ class FillDrainPipelineEngine(PipelineEngineBase): use_1F1B = False super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint) + metric, checkpoint, data_process_func) class OneFOneBWorker(WorkerBase): @@ -103,7 +104,8 @@ class OneFOneBPipelineEngine(PipelineEngineBase): chunk: int = 1, criterion: Callable = None, metric: Callable = None, - checkpoint: bool = False) -> None: + checkpoint: bool = False, + data_process_func: Callable = None) -> None: if chunk > 1: assert num_microbatches % stage_num == 0, \ @@ -112,7 +114,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase): use_1F1B = True super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint) + metric, checkpoint, data_process_func) class ChimeraWorker(WorkerBase): @@ -227,9 +229,9 @@ class ChimeraWorker(WorkerBase): if step_index == 1: ppg.chimera_step_lock.acquire() - print(f'rank_{self.pp_rank} before all reduce') + # print(f'rank_{self.pp_rank} before all reduce') dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False) - print(f'rank_{self.pp_rank} after all reduce') + # print(f'rank_{self.pp_rank} after all reduce') if step_index == 0: ppg.chimera_step_lock.release() @@ -244,7 +246,8 @@ class ChimeraPipelineEngine(PipelineEngineBase): device: str, criterion: Callable = None, metric: Callable = None, - checkpoint: bool = False) -> None: + checkpoint: bool = False, + data_process_func: Callable = None) -> None: assert num_microbatches % stage_num == 0, \ "In Chimera, num_microbatches must be the multiply of stage_num!" 
@@ -252,7 +255,7 @@ class ChimeraPipelineEngine(PipelineEngineBase): chunk = 1 super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion, - metric, checkpoint) + metric, checkpoint, data_process_func) def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]], input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]): @@ -330,6 +333,7 @@ class ChimeraPipelineEngine(PipelineEngineBase): for microbatch_id in range(self.num_microbatches): offset = (microbatch_id % 2) * stage_num ret = ret_future[pp_rank + offset][microbatch_id].wait() + ret = [ret] if isinstance(ret, torch.Tensor) else ret worker_forward_result[microbatch_id] = ret worker_forward_result = list(zip(*worker_forward_result)) diff --git a/colossalai/pipeline/rpc/utils.py b/colossalai/pipeline/rpc/utils.py new file mode 100644 index 000000000..5badecedb --- /dev/null +++ b/colossalai/pipeline/rpc/utils.py @@ -0,0 +1,74 @@ +from typing import List, Any, Tuple, Dict, Callable, Type, Union + +import torch +from torch.futures import Future + +from colorama import Back, Style + +# config for debug and test +use_color_debug = False + + +def color_debug(text, prefix=' ', color='blue'): + color = color.upper() + print(getattr(Back, color), prefix, Style.RESET_ALL, text) + + +def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any: + """process object recursively, like pytree + + Args: + obj (:class:`Any`): object to process + fn (:class:`Callable`): a function to process subobject in obj + process_types(:class: `type | tuple[type]`): types to determine the type to process + + Returns: + :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn` + """ + if isinstance(obj, dict): + return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj} + elif isinstance(obj, tuple): + return tuple(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, list): + return list(pytree_map(o, fn, process_types, map_all) for o in obj) + elif isinstance(obj, process_types): + return fn(obj) + else: + return fn(obj) if map_all else obj + + +def tensor_shape_list(obj): + return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor) + + +def get_batch_lengths(batch): + lengths = [] + pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor) + return lengths + + +def split_batch(batch: Any, start, stop, device: str): + if device == 'cuda': + fn = lambda x: x[start:stop].cuda() + else: + fn = lambda x: x[start:stop] + return pytree_map(batch, fn=fn, process_types=torch.Tensor) + + +def type_detail(obj): + return pytree_map(obj, lambda x: type(x), map_all=True) + + +def get_real_args_kwargs(args_or_kwargs): + args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future) + # TODO : combine producer and consumer + # by default, merge all args in the output args or kwargs + if args_or_kwargs is not None: + if isinstance(args_or_kwargs, dict): + pass + else: + flatten_args = [] + pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True) + args_or_kwargs = flatten_args + + return args_or_kwargs diff --git a/tests/test_pipeline/test_cuda_rpc_chimera.py b/tests/test_pipeline/test_cuda_rpc_chimera.py index cf9e4114f..45ad8f828 100644 --- a/tests/test_pipeline/test_cuda_rpc_chimera.py +++ b/tests/test_pipeline/test_cuda_rpc_chimera.py @@ -22,10 +22,9 @@ def 
run_master(args): epoch = args.epoch device = args.device - stage_num = 4 + stage_num = args.world_size chunk = 1 - num_microbatches = 4 - actual_stage_num = 4 + num_microbatches = args.num_microbatches use_checkpoint = False sample_num = 1024 @@ -78,6 +77,4 @@ def run_master(args): if __name__ == "__main__": args = parse_args() - args.world_size = 4 - args.num_microbatches = 4 rpc_run(args, run_master)
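
The helpers added in colossalai/pipeline/rpc/utils.py treat arbitrarily nested containers of tensors uniformly. A minimal sketch of how pytree_map, get_batch_lengths and split_batch behave on a dict-style batch (the sample tensors below are made up for illustration):

import torch

from colossalai.pipeline.rpc.utils import get_batch_lengths, pytree_map, split_batch

batch = {'input_ids': torch.randn(8, 16), 'extra': [torch.randn(8, 4), 3.14]}

# apply fn only to objects of process_types, keeping the container structure
shapes = pytree_map(batch, fn=lambda x: x.shape, process_types=torch.Tensor)
# {'input_ids': torch.Size([8, 16]), 'extra': [torch.Size([8, 4]), 3.14]}

# collect the leading dimension of every tensor in the batch
print(get_batch_lengths(batch))    # [8, 8]

# slice every tensor to build a microbatch; non-tensor leaves pass through unchanged
microbatch = split_batch(batch, 0, 4, device='cpu')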
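
The new data_process_func hook runs inside each worker right before the partition's forward: the worker binds pp_rank with functools.partial and calls the hook with the packed inputs of that stage, expecting an (args, kwargs) pair back. The engine only accepts a module-level function with exactly two parameters (lambdas and local closures are rejected by _check_argument). A minimal sketch that mirrors _default_data_process_func; the function name is a placeholder and the elided engine arguments are left as-is:

def data_process_func(pp_rank, args_kwargs):
    # must be a global function taking exactly two parameters; the worker
    # wraps it as functools.partial(data_process_func, pp_rank)
    if pp_rank == 0:
        # first stage: an (args, kwargs) pair built from the input microbatch by _make_args_kwargs
        args, kwargs = args_kwargs
    else:
        # other stages: a flat list of tensors gathered from the producer stage
        args, kwargs = args_kwargs, {}
    return args, kwargs

# plugged in through the constructor argument added in this patch, e.g.
# OneFOneBPipelineEngine(partition_fn=..., stage_num=..., num_microbatches=...,
#                        device='cuda', data_process_func=data_process_func)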