mirror of https://github.com/hpcaitech/ColossalAI
[pipeline/pytree] add pytree to process args and kwargs | provide `data_process_func` to process args and kwargs after forward (#1642)
* [pipeline/tuning] improve dispatch performance both time and space cost * [pipeline/converge] add interface for testing convergence * [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style * Update PipelineBase.py * [pipeline/chimera] reconstruct PipelineBase and Worker to support more feasible custom schedule | finish Chimera * [pipeline/chimera] test chimera | fix bug of initializing * [pipeline/pytree] add pytree to process args and kwargs | provide to process args and kwargs after forwardpull/1669/head
parent
c27e701cb2
commit
9708638ded
|
@ -1,3 +1,4 @@
|
|||
from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
|
||||
from .utils import pytree_map
|
||||
|
||||
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']
|
||||
__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']
|
|
@ -1,9 +1,11 @@
|
|||
import threading
|
||||
from enum import Enum
|
||||
from typing import List, Any, Tuple, Dict, Callable
|
||||
from functools import partial
|
||||
from abc import ABC, abstractmethod
|
||||
import sys
|
||||
import os
|
||||
import inspect
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -12,57 +14,10 @@ from torch.futures import Future
|
|||
from torch._C._distributed_rpc import PyRRef
|
||||
from torch import autograd
|
||||
from torch import optim
|
||||
from tqdm import tqdm
|
||||
from time import time
|
||||
|
||||
from colorama import Back, Style
|
||||
|
||||
# config for debug and test
|
||||
use_color_debug = True
|
||||
|
||||
# TODO:
|
||||
# 1. adjust to args and kwargs (pytree)
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
if use_color_debug:
|
||||
color = color.upper()
|
||||
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
|
||||
|
||||
|
||||
def tensor_shape_list(tensors):
|
||||
if tensors is None:
|
||||
return None
|
||||
if isinstance(tensors, (int, float)):
|
||||
return tensors
|
||||
if isinstance(tensors, torch.Tensor):
|
||||
return tensors.shape
|
||||
shapes = []
|
||||
for t in tensors:
|
||||
if hasattr(t, 'shape'):
|
||||
shapes.append(t.shape)
|
||||
else:
|
||||
shapes.append('non tensor')
|
||||
return shapes
|
||||
|
||||
|
||||
def get_real_args(args):
|
||||
if isinstance(args, torch.Tensor):
|
||||
return args
|
||||
elif isinstance(args, list):
|
||||
real_args = []
|
||||
for arg in args:
|
||||
if isinstance(arg, Future):
|
||||
value = arg.wait()
|
||||
else:
|
||||
value = arg
|
||||
if isinstance(value, list):
|
||||
real_args.extend(value)
|
||||
else:
|
||||
real_args.append(value)
|
||||
return real_args
|
||||
else:
|
||||
raise TypeError(f"Expect receive tensor or list, but receive {type(args)}")
|
||||
from colossalai.pipeline.pipeline_process_group import ppg
|
||||
from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
|
||||
pytree_map, get_real_args_kwargs, use_color_debug)
|
||||
|
||||
|
||||
class Phase(Enum):
|
||||
|
@ -100,9 +55,7 @@ class WorkItem:
|
|||
kwargs: Dict[str, Any]
|
||||
output: Future
|
||||
microbatch_id: int
|
||||
|
||||
refcount: int
|
||||
|
||||
batch_id: int
|
||||
num_microbatches: int
|
||||
forward_only: bool
|
||||
|
@ -123,14 +76,16 @@ class WorkItem:
|
|||
|
||||
|
||||
class BackwardCache:
|
||||
__slots__ = ('checkpoint', 'stage_inputs', 'stage_outputs')
|
||||
__slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
|
||||
checkpoint: bool
|
||||
stage_inputs: Tuple[Any]
|
||||
stage_input_args: Tuple[Any]
|
||||
stage_input_kwargs: Dict[Any, Any]
|
||||
stage_outputs: Tuple[Any]
|
||||
|
||||
def __init__(self,
|
||||
stage_inputs: List[torch.Tensor],
|
||||
stage_outputs: List[torch.Tensor] = None,
|
||||
stage_input_args: Tuple[Any],
|
||||
stage_input_kwargs: Dict[Any, Any] = None,
|
||||
stage_outputs: Tuple[Any] = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
for arg_name in self.__slots__:
|
||||
setattr(self, arg_name, locals()[arg_name])
|
||||
|
@ -147,13 +102,18 @@ class WorkerBase(ABC):
|
|||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.pp_rank = pp_rank
|
||||
self.actual_stage_num = actual_stage_num
|
||||
self.num_microbatches = num_microbatches
|
||||
self.checkpoint = checkpoint
|
||||
|
||||
if data_process_func is not None:
|
||||
self.data_process_func = partial(data_process_func, pp_rank)
|
||||
|
||||
self.device = device
|
||||
self._initialize_outstanding_range()
|
||||
|
||||
|
@ -260,18 +220,39 @@ class WorkerBase(ABC):
|
|||
self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
|
||||
return self.module_partition.state_dict()
|
||||
|
||||
def _make_args_kwargs(self, microbatch):
|
||||
if isinstance(microbatch, dict):
|
||||
return [], microbatch
|
||||
elif isinstance(microbatch, torch.Tensor):
|
||||
return [microbatch], {}
|
||||
elif isinstance(microbatch, (tuple, list)):
|
||||
args = []
|
||||
kwargs = {}
|
||||
for arg in microbatch:
|
||||
if isinstance(arg, dict):
|
||||
kwargs.update(arg)
|
||||
else:
|
||||
args.append(arg)
|
||||
return args, kwargs
|
||||
else:
|
||||
raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
|
||||
|
||||
# just for first pp_rank
|
||||
def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
|
||||
assert self.consumer_stage_ids is not None
|
||||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
output = self._get_future_by_device()
|
||||
args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
|
||||
forward_only)
|
||||
|
||||
# make args and kwargs
|
||||
args, kwargs = self._make_args_kwargs(microbatch)
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
|
||||
self.num_microbatches, forward_only)
|
||||
with self.work_list_condition_lock:
|
||||
self.work_list[key] = work_item
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
|
||||
'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
# just for last pp_rank
|
||||
|
@ -287,12 +268,13 @@ class WorkerBase(ABC):
|
|||
|
||||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
output = self._get_future_by_device()
|
||||
grad_wrt_loss = torch.tensor(1, device=self.device)
|
||||
grad_wrt_loss = None
|
||||
|
||||
work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
|
||||
self.num_microbatches, False)
|
||||
|
||||
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
|
||||
|
||||
self.work_list[key] = work_item
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
@ -315,8 +297,9 @@ class WorkerBase(ABC):
|
|||
producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
|
||||
subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)
|
||||
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
|
||||
'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
|
||||
'data dispatch', 'magenta')
|
||||
|
||||
work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
|
||||
microbatch_id, None, self.num_microbatches, forward_only)
|
||||
|
@ -327,9 +310,10 @@ class WorkerBase(ABC):
|
|||
key = UniqueKey(microbatch_id, Phase.FORWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_producer
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def subscribe_consumer(self, microbatch_id: int):
|
||||
|
@ -344,8 +328,9 @@ class WorkerBase(ABC):
|
|||
subscribe_backward_futures: List[Future] = [None] * consumer_num
|
||||
output = self._get_future_by_device()
|
||||
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
|
||||
'data dispatch', 'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
|
||||
'data dispatch', 'magenta')
|
||||
|
||||
for i in range(consumer_num):
|
||||
consumer_stage_id = self.consumer_stage_ids[i]
|
||||
|
@ -364,9 +349,10 @@ class WorkerBase(ABC):
|
|||
key = UniqueKey(microbatch_id, Phase.BACKWARD)
|
||||
assert key not in self.work_list
|
||||
self.work_list[key] = work_item_from_consumer
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
|
||||
'data dispatch', 'magenta')
|
||||
self.work_list_condition_lock.notify_all()
|
||||
|
||||
def _get_producer_consumer(self) -> None:
|
||||
|
@ -398,12 +384,23 @@ class WorkerBase(ABC):
|
|||
def is_last_stage(self):
|
||||
return self.pp_rank == self.actual_stage_num - 1
|
||||
|
||||
def _default_data_process_func(self, args_kwargs):
|
||||
if self.is_first_stage():
|
||||
args = args_kwargs[0]
|
||||
kwargs = args_kwargs[1]
|
||||
else:
|
||||
args = args_kwargs
|
||||
kwargs = {}
|
||||
|
||||
return args, kwargs
|
||||
|
||||
def _consume_work_item_by_phase(self, work_item: WorkItem):
|
||||
phase = work_item.phase
|
||||
args = work_item.args
|
||||
kwargs = work_item.kwargs
|
||||
microbatch_id = work_item.microbatch_id
|
||||
forward_only = work_item.forward_only
|
||||
data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
|
||||
consume_result = None
|
||||
|
||||
is_first_stage = self.is_first_stage()
|
||||
|
@ -420,18 +417,31 @@ class WorkerBase(ABC):
|
|||
for stage_id in self.consumer_stage_ids:
|
||||
consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
|
||||
consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
|
||||
self.forward_times += 1
|
||||
|
||||
# sustain pipeline context
|
||||
self.forward_times += 1
|
||||
if not forward_only:
|
||||
self.outstanding += 1
|
||||
args = get_real_args(args)
|
||||
|
||||
# last stage doesn't need to do checkpoint, for it will do backward instantly
|
||||
# parse and integrate args and kwargs
|
||||
if is_first_stage:
|
||||
args = get_real_args_kwargs(args)
|
||||
kwargs = get_real_args_kwargs(kwargs)
|
||||
args_kwargs = (args, kwargs)
|
||||
else:
|
||||
args_kwargs = get_real_args_kwargs(args)
|
||||
|
||||
args, kwargs = data_process_func(args_kwargs)
|
||||
|
||||
stage_outputs = None
|
||||
stage_input_args = args
|
||||
stage_input_kwargs = kwargs
|
||||
use_checkpoint = None
|
||||
|
||||
if forward_only:
|
||||
with torch.no_grad():
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
|
||||
# TODO : integrate output list
|
||||
if is_last_stage and self.criterion:
|
||||
with self.label_lock:
|
||||
self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
|
||||
|
@ -445,15 +455,18 @@ class WorkerBase(ABC):
|
|||
metric_result = None
|
||||
consume_result = [loss.item(), metric_result]
|
||||
|
||||
stage_outputs = None
|
||||
stage_inputs = None
|
||||
use_checkpoint = None
|
||||
# last stage doesn't need to do checkpoint, for it will do backward instantly
|
||||
stage_input_args = None
|
||||
stage_input_kwargs = None
|
||||
stage_outputs = consume_result
|
||||
|
||||
elif self.checkpoint and not is_last_stage:
|
||||
with torch.no_grad():
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
stage_outputs = None
|
||||
stage_inputs = args
|
||||
|
||||
stage_outputs = consume_result
|
||||
use_checkpoint = True
|
||||
|
||||
else:
|
||||
consume_result = self.module_partition(*args, **kwargs)
|
||||
# print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
|
||||
|
@ -475,17 +488,14 @@ class WorkerBase(ABC):
|
|||
loss = consume_result
|
||||
|
||||
stage_outputs = loss
|
||||
stage_inputs = args
|
||||
use_checkpoint = False
|
||||
|
||||
if not forward_only:
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
|
||||
self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
|
||||
stage_input_kwargs,
|
||||
stage_outputs,
|
||||
checkpoint=use_checkpoint)
|
||||
|
||||
consume_result = [consume_result] if isinstance(consume_result,
|
||||
(torch.Tensor, int, float)) else consume_result
|
||||
|
||||
# if not forward_only, do the backward
|
||||
if not forward_only:
|
||||
if is_last_stage: # if it is the last stage, trigger backward automatic
|
||||
|
@ -504,23 +514,34 @@ class WorkerBase(ABC):
|
|||
backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)
|
||||
|
||||
stage_outputs = backward_cache.stage_outputs
|
||||
stage_inputs = backward_cache.stage_inputs
|
||||
stage_input_args = backward_cache.stage_input_args
|
||||
stage_input_kwargs = backward_cache.stage_input_kwargs
|
||||
use_checkpoint = backward_cache.checkpoint
|
||||
|
||||
if use_checkpoint:
|
||||
stage_outputs = [self.module_partition(*stage_inputs)]
|
||||
stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
|
||||
|
||||
# take tensor only (for only tensor can do backward)
|
||||
stage_outputs_tensors = []
|
||||
pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)
|
||||
|
||||
# overlap recompute and future.wait
|
||||
grad_tensors = get_real_args(args)
|
||||
grad_tensors = get_real_args_kwargs(args)
|
||||
|
||||
autograd.backward(stage_outputs, grad_tensors=grad_tensors)
|
||||
# print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
|
||||
autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)
|
||||
|
||||
# collect grad of input tensor
|
||||
# there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
|
||||
# so we don't need to save the grad of node in kwargs.
|
||||
consume_result = []
|
||||
if not is_first_stage:
|
||||
for input_node in stage_inputs:
|
||||
if isinstance(input_node, torch.Tensor):
|
||||
consume_result.append(input_node.grad)
|
||||
pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
|
||||
|
||||
# for input_node in stage_input_args:
|
||||
# if isinstance(input_node, torch.Tensor):
|
||||
# consume_result.append(input_node.grad)
|
||||
|
||||
else:
|
||||
raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
|
||||
|
@ -562,6 +583,7 @@ class WorkerBase(ABC):
|
|||
def _work_loop(self):
|
||||
# for init
|
||||
self._get_producer_consumer()
|
||||
torch.cuda.set_device(ppg.get_local_pp_rank())
|
||||
|
||||
# main loop
|
||||
while True:
|
||||
|
@ -571,9 +593,10 @@ class WorkerBase(ABC):
|
|||
with self.work_list_condition_lock:
|
||||
work_item = self.work_list.pop(work_item_key)
|
||||
|
||||
color_debug(
|
||||
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
|
||||
'work loop', 'green')
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
|
||||
'work loop', 'green')
|
||||
|
||||
with self.output_list_condition_lock:
|
||||
# assert work_item_key not in self.output_list
|
||||
|
@ -582,9 +605,10 @@ class WorkerBase(ABC):
|
|||
|
||||
consume_result = self._consume_work_item_by_phase(work_item)
|
||||
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
|
||||
'work loop', 'green')
|
||||
if use_color_debug:
|
||||
color_debug(
|
||||
f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
|
||||
'work loop', 'green')
|
||||
|
||||
work_item.output.set_result(consume_result)
|
||||
|
||||
|
@ -621,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
super().__init__()
|
||||
self.worker_type = worker_type
|
||||
self.partition_fn: Callable = partition_fn
|
||||
|
@ -633,6 +658,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
self.use_1F1B = use_1F1B
|
||||
self.stage_num = stage_num
|
||||
self.checkpoint = checkpoint
|
||||
self.data_process_func = data_process_func
|
||||
|
||||
self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
|
||||
|
||||
|
@ -644,9 +670,21 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
self._init_worker()
|
||||
|
||||
def _check_argument(self) -> None:
|
||||
# make virtual stage num
|
||||
self.virtual_stage_num = self.stage_num * self.chunk
|
||||
assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"
|
||||
|
||||
# check data_process_func
|
||||
data_process_func = self.data_process_func
|
||||
if data_process_func is not None:
|
||||
assert callable(data_process_func), "data_process_func must be a function"
|
||||
assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
|
||||
assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(
|
||||
sig.parameters
|
||||
) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
|
||||
|
||||
def _get_actual_stage_num(self) -> int:
|
||||
return self.stage_num if self.chunk == 1 else self.virtual_stage_num
|
||||
|
||||
|
@ -682,6 +720,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
metric = self.metric
|
||||
partition_fn = self.partition_fn
|
||||
chunk = self.chunk
|
||||
data_process_func = self.data_process_func
|
||||
|
||||
for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
|
||||
partition_id = self.pp_rank_to_module_partition_id[pp_rank]
|
||||
|
@ -693,7 +732,7 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
worker_type,
|
||||
args=(partition_fn, partition_args, pp_rank,
|
||||
actual_stage_num, num_microbatches, device,
|
||||
criterion, metric, checkpoint))
|
||||
criterion, metric, checkpoint, data_process_func))
|
||||
|
||||
# let each worker know global worker rref (include itself)
|
||||
sync_futs = []
|
||||
|
@ -779,20 +818,25 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
worker_forward_result = [None] * self.num_microbatches
|
||||
for microbatch_id in range(self.num_microbatches):
|
||||
ret = ret_future[pp_rank][microbatch_id].wait()
|
||||
# TODO : more stable format
|
||||
ret = [ret] if isinstance(ret, torch.Tensor) else ret
|
||||
worker_forward_result[microbatch_id] = ret
|
||||
|
||||
worker_forward_result = list(zip(*worker_forward_result))
|
||||
forward_result.extend(worker_forward_result)
|
||||
|
||||
return forward_result
|
||||
|
||||
def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
|
||||
if labels is not None:
|
||||
assert len(batch) == len(labels)
|
||||
if not forward_only:
|
||||
assert hasattr(self, 'optimizer_class')
|
||||
batch_lengths = get_batch_lengths(batch)
|
||||
|
||||
if labels is not None and not forward_only:
|
||||
assert hasattr(
|
||||
self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"
|
||||
|
||||
num_microbatches = self.num_microbatches
|
||||
microbatch_size = len(batch) // num_microbatches
|
||||
microbatch_size = batch_lengths[0] // num_microbatches
|
||||
device = self.device
|
||||
|
||||
# If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
|
||||
input_pp_ranks = self.get_input_pp_ranks()
|
||||
|
@ -805,16 +849,17 @@ class PipelineEngineBase(ABC, nn.Module):
|
|||
# control data input speed
|
||||
# to prevent exceed of wait limitations
|
||||
self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
|
||||
batch_start = microbatch_size * microbatch_id
|
||||
batch_end = batch_start + microbatch_size
|
||||
|
||||
# set input
|
||||
microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
microbatch = microbatch.cuda()
|
||||
microbatch = split_batch(batch, batch_start, batch_end, device)
|
||||
self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)
|
||||
|
||||
# set labels
|
||||
if labels is not None:
|
||||
microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
microlabels = microlabels.cuda()
|
||||
# microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
|
||||
microlabels = split_batch(labels, batch_start, batch_end, device)
|
||||
self._set_labels(output_pp_ranks, microbatch_id, microlabels)
|
||||
|
||||
# get data asynchronously
|
||||
|
|
|
@ -44,7 +44,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
|
|||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
|
@ -52,7 +53,7 @@ class FillDrainPipelineEngine(PipelineEngineBase):
|
|||
use_1F1B = False
|
||||
|
||||
super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint)
|
||||
metric, checkpoint, data_process_func)
|
||||
|
||||
|
||||
class OneFOneBWorker(WorkerBase):
|
||||
|
@ -103,7 +104,8 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
|
|||
chunk: int = 1,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
if chunk > 1:
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
|
@ -112,7 +114,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
|
|||
use_1F1B = True
|
||||
|
||||
super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint)
|
||||
metric, checkpoint, data_process_func)
|
||||
|
||||
|
||||
class ChimeraWorker(WorkerBase):
|
||||
|
@ -227,9 +229,9 @@ class ChimeraWorker(WorkerBase):
|
|||
if step_index == 1:
|
||||
ppg.chimera_step_lock.acquire()
|
||||
|
||||
print(f'rank_{self.pp_rank} before all reduce')
|
||||
# print(f'rank_{self.pp_rank} before all reduce')
|
||||
dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
|
||||
print(f'rank_{self.pp_rank} after all reduce')
|
||||
# print(f'rank_{self.pp_rank} after all reduce')
|
||||
|
||||
if step_index == 0:
|
||||
ppg.chimera_step_lock.release()
|
||||
|
@ -244,7 +246,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||
device: str,
|
||||
criterion: Callable = None,
|
||||
metric: Callable = None,
|
||||
checkpoint: bool = False) -> None:
|
||||
checkpoint: bool = False,
|
||||
data_process_func: Callable = None) -> None:
|
||||
|
||||
assert num_microbatches % stage_num == 0, \
|
||||
"In Chimera, num_microbatches must be the multiply of stage_num!"
|
||||
|
@ -252,7 +255,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||
chunk = 1
|
||||
|
||||
super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
|
||||
metric, checkpoint)
|
||||
metric, checkpoint, data_process_func)
|
||||
|
||||
def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
|
||||
input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
|
||||
|
@ -330,6 +333,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
|
|||
for microbatch_id in range(self.num_microbatches):
|
||||
offset = (microbatch_id % 2) * stage_num
|
||||
ret = ret_future[pp_rank + offset][microbatch_id].wait()
|
||||
ret = [ret] if isinstance(ret, torch.Tensor) else ret
|
||||
worker_forward_result[microbatch_id] = ret
|
||||
|
||||
worker_forward_result = list(zip(*worker_forward_result))
|
||||
|
|
|
@ -0,0 +1,74 @@
|
|||
from typing import List, Any, Tuple, Dict, Callable, Type, Union
|
||||
|
||||
import torch
|
||||
from torch.futures import Future
|
||||
|
||||
from colorama import Back, Style
|
||||
|
||||
# config for debug and test
|
||||
use_color_debug = False
|
||||
|
||||
|
||||
def color_debug(text, prefix=' ', color='blue'):
|
||||
color = color.upper()
|
||||
print(getattr(Back, color), prefix, Style.RESET_ALL, text)
|
||||
|
||||
|
||||
def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
|
||||
"""process object recursively, like pytree
|
||||
|
||||
Args:
|
||||
obj (:class:`Any`): object to process
|
||||
fn (:class:`Callable`): a function to process subobject in obj
|
||||
process_types(:class: `type | tuple[type]`): types to determine the type to process
|
||||
|
||||
Returns:
|
||||
:class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
|
||||
elif isinstance(obj, tuple):
|
||||
return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
|
||||
elif isinstance(obj, list):
|
||||
return list(pytree_map(o, fn, process_types, map_all) for o in obj)
|
||||
elif isinstance(obj, process_types):
|
||||
return fn(obj)
|
||||
else:
|
||||
return fn(obj) if map_all else obj
|
||||
|
||||
|
||||
def tensor_shape_list(obj):
|
||||
return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)
|
||||
|
||||
|
||||
def get_batch_lengths(batch):
|
||||
lengths = []
|
||||
pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
|
||||
return lengths
|
||||
|
||||
|
||||
def split_batch(batch: Any, start, stop, device: str):
|
||||
if device == 'cuda':
|
||||
fn = lambda x: x[start:stop].cuda()
|
||||
else:
|
||||
fn = lambda x: x[start:stop]
|
||||
return pytree_map(batch, fn=fn, process_types=torch.Tensor)
|
||||
|
||||
|
||||
def type_detail(obj):
|
||||
return pytree_map(obj, lambda x: type(x), map_all=True)
|
||||
|
||||
|
||||
def get_real_args_kwargs(args_or_kwargs):
|
||||
args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
|
||||
# TODO : combine producer and consumer
|
||||
# by default, merge all args in the output args or kwargs
|
||||
if args_or_kwargs is not None:
|
||||
if isinstance(args_or_kwargs, dict):
|
||||
pass
|
||||
else:
|
||||
flatten_args = []
|
||||
pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
|
||||
args_or_kwargs = flatten_args
|
||||
|
||||
return args_or_kwargs
|
|
@ -22,10 +22,9 @@ def run_master(args):
|
|||
|
||||
epoch = args.epoch
|
||||
device = args.device
|
||||
stage_num = 4
|
||||
stage_num = args.world_size
|
||||
chunk = 1
|
||||
num_microbatches = 4
|
||||
actual_stage_num = 4
|
||||
num_microbatches = args.num_microbatches
|
||||
use_checkpoint = False
|
||||
|
||||
sample_num = 1024
|
||||
|
@ -78,6 +77,4 @@ def run_master(args):
|
|||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
args.world_size = 4
|
||||
args.num_microbatches = 4
|
||||
rpc_run(args, run_master)
|
||||
|
|
Loading…
Reference in New Issue