[pipeline/pytree] add pytree to process args and kwargs | provide `data_process_func` to process args and kwargs after forward (#1642)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing

* [pipeline/pytree] add pytree to process args and kwargs | provide  to process args and kwargs after forward
pull/1669/head
Kirigaya Kazuto 2022-09-29 10:58:58 +08:00 committed by GitHub
parent c27e701cb2
commit 9708638ded
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 247 additions and 126 deletions

View File

@ -1,3 +1,4 @@
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
from .utils import pytree_map
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']

View File

@ -1,9 +1,11 @@
import threading
from enum import Enum
from typing import List, Any, Tuple, Dict, Callable
from functools import partial
from abc import ABC, abstractmethod
import sys
import os
import inspect
import torch
from torch import nn
@ -12,57 +14,10 @@ from torch.futures import Future
from torch._C._distributed_rpc import PyRRef
from torch import autograd
from torch import optim
from tqdm import tqdm
from time import time
from colorama import Back, Style
# config for debug and test
use_color_debug = True
# TODO:
# 1. adjust to args and kwargs (pytree)
def color_debug(text, prefix=' ', color='blue'):
if use_color_debug:
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
def tensor_shape_list(tensors):
if tensors is None:
return None
if isinstance(tensors, (int, float)):
return tensors
if isinstance(tensors, torch.Tensor):
return tensors.shape
shapes = []
for t in tensors:
if hasattr(t, 'shape'):
shapes.append(t.shape)
else:
shapes.append('non tensor')
return shapes
def get_real_args(args):
if isinstance(args, torch.Tensor):
return args
elif isinstance(args, list):
real_args = []
for arg in args:
if isinstance(arg, Future):
value = arg.wait()
else:
value = arg
if isinstance(value, list):
real_args.extend(value)
else:
real_args.append(value)
return real_args
else:
raise TypeError(f"Expect receive tensor or list, but receive {type(args)}")
from colossalai.pipeline.pipeline_process_group import ppg
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
pytree_map, get_real_args_kwargs, use_color_debug)
class Phase(Enum):
@ -100,9 +55,7 @@ class WorkItem:
kwargs: Dict[str, Any]
output: Future
microbatch_id: int
refcount: int
batch_id: int
num_microbatches: int
forward_only: bool
@ -123,14 +76,16 @@ class WorkItem:
class BackwardCache:
__slots__ = ('checkpoint', 'stage_inputs', 'stage_outputs')
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
checkpoint: bool
stage_inputs: Tuple[Any]
stage_input_args: Tuple[Any]
stage_input_kwargs: Dict[Any, Any]
stage_outputs: Tuple[Any]
def __init__(self,
stage_inputs: List[torch.Tensor],
stage_outputs: List[torch.Tensor] = None,
stage_input_args: Tuple[Any],
stage_input_kwargs: Dict[Any, Any] = None,
stage_outputs: Tuple[Any] = None,
checkpoint: bool = False) -> None:
for arg_name in self.__slots__:
setattr(self, arg_name, locals()[arg_name])
@ -147,13 +102,18 @@ class WorkerBase(ABC):
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
super().__init__()
self.pp_rank = pp_rank
self.actual_stage_num = actual_stage_num
self.num_microbatches = num_microbatches
self.checkpoint = checkpoint
if data_process_func is not None:
self.data_process_func = partial(data_process_func, pp_rank)
self.device = device
self._initialize_outstanding_range()
@ -260,18 +220,39 @@ class WorkerBase(ABC):
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
return self.module_partition.state_dict()
def _make_args_kwargs(self, microbatch):
if isinstance(microbatch, dict):
return [], microbatch
elif isinstance(microbatch, torch.Tensor):
return [microbatch], {}
elif isinstance(microbatch, (tuple, list)):
args = []
kwargs = {}
for arg in microbatch:
if isinstance(arg, dict):
kwargs.update(arg)
else:
args.append(arg)
return args, kwargs
else:
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
# just for first pp_rank
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
assert self.consumer_stage_ids is not None
key = UniqueKey(microbatch_id, Phase.FORWARD)
output = self._get_future_by_device()
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
forward_only)
# make args and kwargs
args, kwargs = self._make_args_kwargs(microbatch)
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
self.num_microbatches, forward_only)
with self.work_list_condition_lock:
self.work_list[key] = work_item
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
'magenta')
if use_color_debug:
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
'data dispatch', 'magenta')
self.work_list_condition_lock.notify_all()
# just for last pp_rank
@ -287,11 +268,12 @@ class WorkerBase(ABC):
key = UniqueKey(microbatch_id, Phase.BACKWARD)
output = self._get_future_by_device()
grad_wrt_loss = torch.tensor(1, device=self.device)
grad_wrt_loss = None
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
self.num_microbatches, False)
if use_color_debug:
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
self.work_list[key] = work_item
@ -315,8 +297,9 @@ class WorkerBase(ABC):
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
'magenta')
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
'data dispatch', 'magenta')
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
microbatch_id, None, self.num_microbatches, forward_only)
@ -327,6 +310,7 @@ class WorkerBase(ABC):
key = UniqueKey(microbatch_id, Phase.FORWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_producer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
'data dispatch', 'magenta')
@ -344,6 +328,7 @@ class WorkerBase(ABC):
subscribe_backward_futures: List[Future] = [None] * consumer_num
output = self._get_future_by_device()
if use_color_debug:
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
'data dispatch', 'magenta')
@ -364,6 +349,7 @@ class WorkerBase(ABC):
key = UniqueKey(microbatch_id, Phase.BACKWARD)
assert key not in self.work_list
self.work_list[key] = work_item_from_consumer
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
'data dispatch', 'magenta')
@ -398,12 +384,23 @@ class WorkerBase(ABC):
def is_last_stage(self):
return self.pp_rank == self.actual_stage_num - 1
def _default_data_process_func(self, args_kwargs):
if self.is_first_stage():
args = args_kwargs[0]
kwargs = args_kwargs[1]
else:
args = args_kwargs
kwargs = {}
return args, kwargs
def _consume_work_item_by_phase(self, work_item: WorkItem):
phase = work_item.phase
args = work_item.args
kwargs = work_item.kwargs
microbatch_id = work_item.microbatch_id
forward_only = work_item.forward_only
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
consume_result = None
is_first_stage = self.is_first_stage()
@ -420,18 +417,31 @@ class WorkerBase(ABC):
for stage_id in self.consumer_stage_ids:
consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
self.forward_times += 1
# sustain pipeline context
self.forward_times += 1
if not forward_only:
self.outstanding += 1
args = get_real_args(args)
# last stage doesn't need to do checkpoint, for it will do backward instantly
# parse and integrate args and kwargs
if is_first_stage:
args = get_real_args_kwargs(args)
kwargs = get_real_args_kwargs(kwargs)
args_kwargs = (args, kwargs)
else:
args_kwargs = get_real_args_kwargs(args)
args, kwargs = data_process_func(args_kwargs)
stage_outputs = None
stage_input_args = args
stage_input_kwargs = kwargs
use_checkpoint = None
if forward_only:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
# TODO : integrate output list
if is_last_stage and self.criterion:
with self.label_lock:
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@ -445,15 +455,18 @@ class WorkerBase(ABC):
metric_result = None
consume_result = [loss.item(), metric_result]
stage_outputs = None
stage_inputs = None
use_checkpoint = None
# last stage doesn't need to do checkpoint, for it will do backward instantly
stage_input_args = None
stage_input_kwargs = None
stage_outputs = consume_result
elif self.checkpoint and not is_last_stage:
with torch.no_grad():
consume_result = self.module_partition(*args, **kwargs)
stage_outputs = None
stage_inputs = args
stage_outputs = consume_result
use_checkpoint = True
else:
consume_result = self.module_partition(*args, **kwargs)
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
@ -475,17 +488,14 @@ class WorkerBase(ABC):
loss = consume_result
stage_outputs = loss
stage_inputs = args
use_checkpoint = False
if not forward_only:
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
stage_input_kwargs,
stage_outputs,
checkpoint=use_checkpoint)
consume_result = [consume_result] if isinstance(consume_result,
(torch.Tensor, int, float)) else consume_result
# if not forward_only, do the backward
if not forward_only:
if is_last_stage: # if it is the last stage, trigger backward automatic
@ -504,23 +514,34 @@ class WorkerBase(ABC):
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
stage_outputs = backward_cache.stage_outputs
stage_inputs = backward_cache.stage_inputs
stage_input_args = backward_cache.stage_input_args
stage_input_kwargs = backward_cache.stage_input_kwargs
use_checkpoint = backward_cache.checkpoint
if use_checkpoint:
stage_outputs = [self.module_partition(*stage_inputs)]
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
# take tensor only (for only tensor can do backward)
stage_outputs_tensors = []
pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
# overlap recompute and future.wait
grad_tensors = get_real_args(args)
grad_tensors = get_real_args_kwargs(args)
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
# print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
# collect grad of input tensor
# there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
# so we don't need to save the grad of node in kwargs.
consume_result = []
if not is_first_stage:
for input_node in stage_inputs:
if isinstance(input_node, torch.Tensor):
consume_result.append(input_node.grad)
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
# for input_node in stage_input_args:
# if isinstance(input_node, torch.Tensor):
# consume_result.append(input_node.grad)
else:
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@ -562,6 +583,7 @@ class WorkerBase(ABC):
def _work_loop(self):
# for init
self._get_producer_consumer()
torch.cuda.set_device(ppg.get_local_pp_rank())
# main loop
while True:
@ -571,6 +593,7 @@ class WorkerBase(ABC):
with self.work_list_condition_lock:
work_item = self.work_list.pop(work_item_key)
if use_color_debug:
color_debug(
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
'work loop', 'green')
@ -582,6 +605,7 @@ class WorkerBase(ABC):
consume_result = self._consume_work_item_by_phase(work_item)
if use_color_debug:
color_debug(
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
'work loop', 'green')
@ -621,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
super().__init__()
self.worker_type = worker_type
self.partition_fn: Callable = partition_fn
@ -633,6 +658,7 @@ class PipelineEngineBase(ABC, nn.Module):
self.use_1F1B = use_1F1B
self.stage_num = stage_num
self.checkpoint = checkpoint
self.data_process_func = data_process_func
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@ -644,9 +670,21 @@ class PipelineEngineBase(ABC, nn.Module):
self._init_worker()
def _check_argument(self) -> None:
# make virtual stage num
self.virtual_stage_num = self.stage_num * self.chunk
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
# check data_process_func
data_process_func = self.data_process_func
if data_process_func is not None:
assert callable(data_process_func), "data_process_func must be a function"
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
sig = inspect.signature(data_process_func)
assert len(
sig.parameters
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
def _get_actual_stage_num(self) -> int:
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
@ -682,6 +720,7 @@ class PipelineEngineBase(ABC, nn.Module):
metric = self.metric
partition_fn = self.partition_fn
chunk = self.chunk
data_process_func = self.data_process_func
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
@ -693,7 +732,7 @@ class PipelineEngineBase(ABC, nn.Module):
worker_type,
args=(partition_fn, partition_args, pp_rank,
actual_stage_num, num_microbatches, device,
criterion, metric, checkpoint))
criterion, metric, checkpoint, data_process_func))
# let each worker know global worker rref (include itself)
sync_futs = []
@ -779,20 +818,25 @@ class PipelineEngineBase(ABC, nn.Module):
worker_forward_result = [None] * self.num_microbatches
for microbatch_id in range(self.num_microbatches):
ret = ret_future[pp_rank][microbatch_id].wait()
# TODO : more stable format
ret = [ret] if isinstance(ret, torch.Tensor) else ret
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))
forward_result.extend(worker_forward_result)
return forward_result
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
if labels is not None:
assert len(batch) == len(labels)
if not forward_only:
assert hasattr(self, 'optimizer_class')
batch_lengths = get_batch_lengths(batch)
if labels is not None and not forward_only:
assert hasattr(
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
num_microbatches = self.num_microbatches
microbatch_size = len(batch) // num_microbatches
microbatch_size = batch_lengths[0] // num_microbatches
device = self.device
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
input_pp_ranks = self.get_input_pp_ranks()
@ -805,16 +849,17 @@ class PipelineEngineBase(ABC, nn.Module):
# control data input speed
# to prevent exceed of wait limitations
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
batch_start = microbatch_size * microbatch_id
batch_end = batch_start + microbatch_size
# set input
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microbatch = microbatch.cuda()
microbatch = split_batch(batch, batch_start, batch_end, device)
self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
# set labels
if labels is not None:
microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microlabels = microlabels.cuda()
# microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
microlabels = split_batch(labels, batch_start, batch_end, device)
self._set_labels(output_pp_ranks, microbatch_id, microlabels)
# get data asynchronously

View File

@ -44,7 +44,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
@ -52,7 +53,7 @@ class FillDrainPipelineEngine(PipelineEngineBase):
use_1F1B = False
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
metric, checkpoint, data_process_func)
class OneFOneBWorker(WorkerBase):
@ -103,7 +104,8 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
chunk: int = 1,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
if chunk > 1:
assert num_microbatches % stage_num == 0, \
@ -112,7 +114,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
use_1F1B = True
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
metric, checkpoint, data_process_func)
class ChimeraWorker(WorkerBase):
@ -227,9 +229,9 @@ class ChimeraWorker(WorkerBase):
if step_index == 1:
ppg.chimera_step_lock.acquire()
print(f'rank_{self.pp_rank} before all reduce')
# print(f'rank_{self.pp_rank} before all reduce')
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
print(f'rank_{self.pp_rank} after all reduce')
# print(f'rank_{self.pp_rank} after all reduce')
if step_index == 0:
ppg.chimera_step_lock.release()
@ -244,7 +246,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
device: str,
criterion: Callable = None,
metric: Callable = None,
checkpoint: bool = False) -> None:
checkpoint: bool = False,
data_process_func: Callable = None) -> None:
assert num_microbatches % stage_num == 0, \
"In Chimera, num_microbatches must be the multiply of stage_num!"
@ -252,7 +255,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
chunk = 1
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
metric, checkpoint)
metric, checkpoint, data_process_func)
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
@ -330,6 +333,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
for microbatch_id in range(self.num_microbatches):
offset = (microbatch_id % 2) * stage_num
ret = ret_future[pp_rank + offset][microbatch_id].wait()
ret = [ret] if isinstance(ret, torch.Tensor) else ret
worker_forward_result[microbatch_id] = ret
worker_forward_result = list(zip(*worker_forward_result))

View File

@ -0,0 +1,74 @@
from typing import List, Any, Tuple, Dict, Callable, Type, Union
import torch
from torch.futures import Future
from colorama import Back, Style
# config for debug and test
use_color_debug = False
def color_debug(text, prefix=' ', color='blue'):
color = color.upper()
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
"""process object recursively, like pytree
Args:
obj (:class:`Any`): object to process
fn (:class:`Callable`): a function to process subobject in obj
process_types(:class: `type | tuple[type]`): types to determine the type to process
Returns:
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
"""
if isinstance(obj, dict):
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
elif isinstance(obj, tuple):
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, list):
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
elif isinstance(obj, process_types):
return fn(obj)
else:
return fn(obj) if map_all else obj
def tensor_shape_list(obj):
return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)
def get_batch_lengths(batch):
lengths = []
pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
return lengths
def split_batch(batch: Any, start, stop, device: str):
if device == 'cuda':
fn = lambda x: x[start:stop].cuda()
else:
fn = lambda x: x[start:stop]
return pytree_map(batch, fn=fn, process_types=torch.Tensor)
def type_detail(obj):
return pytree_map(obj, lambda x: type(x), map_all=True)
def get_real_args_kwargs(args_or_kwargs):
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
# TODO : combine producer and consumer
# by default, merge all args in the output args or kwargs
if args_or_kwargs is not None:
if isinstance(args_or_kwargs, dict):
pass
else:
flatten_args = []
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
args_or_kwargs = flatten_args
return args_or_kwargs

View File

@ -22,10 +22,9 @@ def run_master(args):
epoch = args.epoch
device = args.device
stage_num = 4
stage_num = args.world_size
chunk = 1
num_microbatches = 4
actual_stage_num = 4
num_microbatches = args.num_microbatches
use_checkpoint = False
sample_num = 1024
@ -78,6 +77,4 @@ def run_master(args):
if __name__ == "__main__":
args = parse_args()
args.world_size = 4
args.num_microbatches = 4
rpc_run(args, run_master)