[pipeline/pytree] add pytree to process args and kwargs | provide `data_process_func` to process args and kwargs after forward (#1642)

* [pipeline/tuning] improve dispatch performance both time and space cost

* [pipeline/converge] add interface for testing convergence

* [NFC] polish colossalai/utils/multi_tensor_apply/multi_tensor_apply.py code style

* Update PipelineBase.py

* [pipeline/chimera] reconstruct PipelineBase and Worker to support more flexible custom schedules | finish Chimera

* [pipeline/chimera] test chimera | fix bug of initializing

* [pipeline/pytree] add pytree to process args and kwargs | provide `data_process_func` to process args and kwargs after forward
Kirigaya Kazuto 2022-09-29 10:58:58 +08:00 committed by GitHub
parent c27e701cb2
commit 9708638ded
5 changed files with 247 additions and 126 deletions
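
Before the file diffs, a minimal sketch of the kind of hook this commit enables; the function body is illustrative and not part of the change. Per the worker code below, `data_process_func` must be a plain global function (not a lambda or closure) taking exactly two parameters, because the worker binds `pp_rank` with `functools.partial`, and it must return an `(args, kwargs)` pair for the stage's module partition. Treating `pp_rank == 0` as the first stage is an assumption here.

# Hypothetical data_process_func: unpack whatever the worker hands over into (args, kwargs).
def example_data_process_func(pp_rank, args_kwargs):
    if pp_rank == 0:
        # first stage: the worker passes (args, kwargs) built from the raw microbatch
        args, kwargs = args_kwargs
    else:
        # later stages: the worker passes the flattened outputs of the previous stage
        args, kwargs = args_kwargs, {}
    return args, kwargs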

View File

@@ -1,3 +1,4 @@
 from ._pipeline_schedule import FillDrainPipelineEngine, OneFOneBPipelineEngine, ChimeraPipelineEngine
+from .utils import pytree_map

-__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine']
+__all__ = ['FillDrainPipelineEngine', 'OneFOneBPipelineEngine', 'ChimeraPipelineEngine', 'pytree_map']

View File

@ -1,9 +1,11 @@
import threading import threading
from enum import Enum from enum import Enum
from typing import List, Any, Tuple, Dict, Callable from typing import List, Any, Tuple, Dict, Callable
from functools import partial
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
import sys import sys
import os import os
import inspect
import torch import torch
from torch import nn from torch import nn
@@ -12,57 +14,10 @@ from torch.futures import Future
 from torch._C._distributed_rpc import PyRRef
 from torch import autograd
 from torch import optim
-from tqdm import tqdm
-from time import time
-from colorama import Back, Style
+from colossalai.pipeline.pipeline_process_group import ppg
+from colossalai.pipeline.rpc.utils import (color_debug, tensor_shape_list, get_batch_lengths, split_batch, type_detail,
+                                            pytree_map, get_real_args_kwargs, use_color_debug)
-
-# config for debug and test
-use_color_debug = True
-
-# TODO:
-# 1. adjust to args and kwargs (pytree)
-
-
-def color_debug(text, prefix=' ', color='blue'):
-    if use_color_debug:
-        color = color.upper()
-        print(getattr(Back, color), prefix, Style.RESET_ALL, text)
-
-
-def tensor_shape_list(tensors):
-    if tensors is None:
-        return None
-    if isinstance(tensors, (int, float)):
-        return tensors
-    if isinstance(tensors, torch.Tensor):
-        return tensors.shape
-    shapes = []
-    for t in tensors:
-        if hasattr(t, 'shape'):
-            shapes.append(t.shape)
-        else:
-            shapes.append('non tensor')
-    return shapes
-
-
-def get_real_args(args):
-    if isinstance(args, torch.Tensor):
-        return args
-    elif isinstance(args, list):
-        real_args = []
-        for arg in args:
-            if isinstance(arg, Future):
-                value = arg.wait()
-            else:
-                value = arg
-            if isinstance(value, list):
-                real_args.extend(value)
-            else:
-                real_args.append(value)
-        return real_args
-    else:
-        raise TypeError(f"Expect receive tensor or list, but receive {type(args)}")


 class Phase(Enum):
@@ -100,9 +55,7 @@ class WorkItem:
     kwargs: Dict[str, Any]
     output: Future
     microbatch_id: int
     refcount: int
     batch_id: int
     num_microbatches: int
     forward_only: bool
@@ -123,14 +76,16 @@ class WorkItem:
 class BackwardCache:
-    __slots__ = ('checkpoint', 'stage_inputs', 'stage_outputs')
+    __slots__ = ('checkpoint', 'stage_input_args', 'stage_input_kwargs', 'stage_outputs')
     checkpoint: bool
-    stage_inputs: Tuple[Any]
+    stage_input_args: Tuple[Any]
+    stage_input_kwargs: Dict[Any, Any]
     stage_outputs: Tuple[Any]

     def __init__(self,
-                 stage_inputs: List[torch.Tensor],
-                 stage_outputs: List[torch.Tensor] = None,
+                 stage_input_args: Tuple[Any],
+                 stage_input_kwargs: Dict[Any, Any] = None,
+                 stage_outputs: Tuple[Any] = None,
                  checkpoint: bool = False) -> None:
         for arg_name in self.__slots__:
             setattr(self, arg_name, locals()[arg_name])
@@ -147,13 +102,18 @@ class WorkerBase(ABC):
                  device: str,
                  criterion: Callable = None,
                  metric: Callable = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 data_process_func: Callable = None) -> None:
         super().__init__()
         self.pp_rank = pp_rank
         self.actual_stage_num = actual_stage_num
         self.num_microbatches = num_microbatches
         self.checkpoint = checkpoint
+
+        if data_process_func is not None:
+            self.data_process_func = partial(data_process_func, pp_rank)
+
         self.device = device
         self._initialize_outstanding_range()
@@ -260,18 +220,39 @@ class WorkerBase(ABC):
             self.partition_condition_lock.wait_for(lambda: hasattr(self, 'module_partition'))
             return self.module_partition.state_dict()

+    def _make_args_kwargs(self, microbatch):
+        if isinstance(microbatch, dict):
+            return [], microbatch
+        elif isinstance(microbatch, torch.Tensor):
+            return [microbatch], {}
+        elif isinstance(microbatch, (tuple, list)):
+            args = []
+            kwargs = {}
+            for arg in microbatch:
+                if isinstance(arg, dict):
+                    kwargs.update(arg)
+                else:
+                    args.append(arg)
+            return args, kwargs
+        else:
+            raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")
+
     # just for first pp_rank
     def set_input(self, microbatch_id: int, microbatch: Tuple[Any], forward_only: bool):
         assert self.consumer_stage_ids is not None
         key = UniqueKey(microbatch_id, Phase.FORWARD)
         output = self._get_future_by_device()
-        args = [microbatch] if isinstance(microbatch, torch.Tensor) else microbatch
-        work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, {}, output, microbatch_id, None, self.num_microbatches,
-                             forward_only)
+
+        # make args and kwargs
+        args, kwargs = self._make_args_kwargs(microbatch)
+
+        work_item = WorkItem(self.pp_rank, Phase.FORWARD, args, kwargs, output, microbatch_id, None,
+                             self.num_microbatches, forward_only)
         with self.work_list_condition_lock:
             self.work_list[key] = work_item
-            color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}', 'data dispatch',
-                        'magenta')
+            if use_color_debug:
+                color_debug(f'rank {self.pp_rank} receive data from dataloader {self._get_store_len()}',
+                            'data dispatch', 'magenta')
             self.work_list_condition_lock.notify_all()

     # just for last pp_rank
@@ -287,11 +268,12 @@ class WorkerBase(ABC):
         key = UniqueKey(microbatch_id, Phase.BACKWARD)
         output = self._get_future_by_device()

-        grad_wrt_loss = torch.tensor(1, device=self.device)
+        grad_wrt_loss = None

         work_item = WorkItem(self.pp_rank, Phase.BACKWARD, grad_wrt_loss, {}, output, microbatch_id, None,
                              self.num_microbatches, False)

-        color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')
+        if use_color_debug:
+            color_debug(f'rank {self.pp_rank} propose backward', 'data dispatch', 'magenta')

         self.work_list[key] = work_item
@@ -315,8 +297,9 @@ class WorkerBase(ABC):
             producer_worker_rref = self.pp_rank_to_worker_rref[producer_stage_id]
             subscribe_forward_futures[i] = producer_worker_rref.rpc_async().get_output_by_key(producer_output_key)

-        color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer', 'data dispatch',
-                    'magenta')
+        if use_color_debug:
+            color_debug(f'rank {self.pp_rank} get {len(subscribe_forward_futures)} futs from its producer',
+                        'data dispatch', 'magenta')

         work_item_from_producer = WorkItem(stage_id, Phase.FORWARD, subscribe_forward_futures, {}, output,
                                            microbatch_id, None, self.num_microbatches, forward_only)
@@ -327,6 +310,7 @@ class WorkerBase(ABC):
             key = UniqueKey(microbatch_id, Phase.FORWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_producer
+            if use_color_debug:
                 color_debug(
                     f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_producer.phase} data: {tensor_shape_list(work_item_from_producer.args)}',
                     'data dispatch', 'magenta')
@@ -344,6 +328,7 @@ class WorkerBase(ABC):
         subscribe_backward_futures: List[Future] = [None] * consumer_num
         output = self._get_future_by_device()

+        if use_color_debug:
            color_debug(f'rank {self.pp_rank} get {len(subscribe_backward_futures)} futs from its consumer',
                        'data dispatch', 'magenta')
@@ -364,6 +349,7 @@ class WorkerBase(ABC):
             key = UniqueKey(microbatch_id, Phase.BACKWARD)
             assert key not in self.work_list
             self.work_list[key] = work_item_from_consumer
+            if use_color_debug:
                color_debug(
                    f'rank_{self.pp_rank} load a new task to its work_list {key} {work_item_from_consumer.phase} data: {tensor_shape_list(work_item_from_consumer.args)}',
                    'data dispatch', 'magenta')
@@ -398,12 +384,23 @@ class WorkerBase(ABC):
     def is_last_stage(self):
         return self.pp_rank == self.actual_stage_num - 1

+    def _default_data_process_func(self, args_kwargs):
+        if self.is_first_stage():
+            args = args_kwargs[0]
+            kwargs = args_kwargs[1]
+        else:
+            args = args_kwargs
+            kwargs = {}
+
+        return args, kwargs
+
     def _consume_work_item_by_phase(self, work_item: WorkItem):
         phase = work_item.phase
         args = work_item.args
         kwargs = work_item.kwargs
         microbatch_id = work_item.microbatch_id
         forward_only = work_item.forward_only
+        data_process_func = getattr(self, 'data_process_func', self._default_data_process_func)
+
         consume_result = None
         is_first_stage = self.is_first_stage()
@@ -420,18 +417,31 @@ class WorkerBase(ABC):
                 for stage_id in self.consumer_stage_ids:
                     consumer_worker_rref = self.pp_rank_to_worker_rref[stage_id]
                     consumer_worker_rref.remote().subscribe_producer(microbatch_id, forward_only)
-            self.forward_times += 1

+            # sustain pipeline context
+            self.forward_times += 1
             if not forward_only:
                 self.outstanding += 1
-            args = get_real_args(args)

-            # last stage doesn't need to do checkpoint, for it will do backward instantly
+            # parse and integrate args and kwargs
+            if is_first_stage:
+                args = get_real_args_kwargs(args)
+                kwargs = get_real_args_kwargs(kwargs)
+                args_kwargs = (args, kwargs)
+            else:
+                args_kwargs = get_real_args_kwargs(args)
+
+            args, kwargs = data_process_func(args_kwargs)
+
+            stage_outputs = None
+            stage_input_args = args
+            stage_input_kwargs = kwargs
+            use_checkpoint = None
+
             if forward_only:
                 with torch.no_grad():
                     consume_result = self.module_partition(*args, **kwargs)

+                # TODO : integrate output list
                 if is_last_stage and self.criterion:
                     with self.label_lock:
                         self.label_lock.wait_for(lambda: microbatch_id in self.microbatch_id_to_labels)
@@ -445,15 +455,18 @@ class WorkerBase(ABC):
                         metric_result = None
                     consume_result = [loss.item(), metric_result]

-                stage_outputs = None
-                stage_inputs = None
-                use_checkpoint = None
+                # last stage doesn't need to do checkpoint, for it will do backward instantly
+                stage_input_args = None
+                stage_input_kwargs = None
+                stage_outputs = consume_result

             elif self.checkpoint and not is_last_stage:
                 with torch.no_grad():
                     consume_result = self.module_partition(*args, **kwargs)

-                stage_outputs = None
-                stage_inputs = args
+                stage_outputs = consume_result
                 use_checkpoint = True

             else:
                 consume_result = self.module_partition(*args, **kwargs)
                 # print(f'model{self.pp_rank + 1}(param_sum: {sum([p.sum().item() for p in self.module_partition.parameters()])}) input sum: {args[0].sum().item()} forward output sum: {consume_result.sum().item()}', )
@@ -475,17 +488,14 @@ class WorkerBase(ABC):
                     loss = consume_result

                 stage_outputs = loss
-                stage_inputs = args
                 use_checkpoint = False

             if not forward_only:
-                self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_inputs,
+                self.microbatch_id_to_backward_cache[microbatch_id] = BackwardCache(stage_input_args,
+                                                                                    stage_input_kwargs,
                                                                                     stage_outputs,
                                                                                     checkpoint=use_checkpoint)
-
-            consume_result = [consume_result] if isinstance(consume_result,
-                                                            (torch.Tensor, int, float)) else consume_result

             # if not forward_only, do the backward
             if not forward_only:
                 if is_last_stage:    # if it is the last stage, trigger backward automatic
@@ -504,23 +514,34 @@ class WorkerBase(ABC):
             backward_cache = self.microbatch_id_to_backward_cache.pop(microbatch_id)

             stage_outputs = backward_cache.stage_outputs
-            stage_inputs = backward_cache.stage_inputs
+            stage_input_args = backward_cache.stage_input_args
+            stage_input_kwargs = backward_cache.stage_input_kwargs
             use_checkpoint = backward_cache.checkpoint

             if use_checkpoint:
-                stage_outputs = [self.module_partition(*stage_inputs)]
+                stage_outputs = [self.module_partition(*stage_input_args, **stage_input_kwargs)]
+
+            # take tensor only (for only tensor can do backward)
+            stage_outputs_tensors = []
+            pytree_map(stage_outputs, stage_outputs_tensors.append, process_types=torch.Tensor)

             # overlap recompute and future.wait
-            grad_tensors = get_real_args(args)
-
-            autograd.backward(stage_outputs, grad_tensors=grad_tensors)
+            grad_tensors = get_real_args_kwargs(args)
+
+            # print('rank', self.pp_rank, tensor_shape_list(stage_outputs_tensors), tensor_shape_list(grad_tensors))
+            autograd.backward(stage_outputs_tensors, grad_tensors=grad_tensors)

             # collect grad of input tensor
+            # there is a hypothesis that node in kwargs cann't be an non-leaf node in graph
+            # so we don't need to save the grad of node in kwargs.
             consume_result = []
             if not is_first_stage:
-                for input_node in stage_inputs:
-                    if isinstance(input_node, torch.Tensor):
-                        consume_result.append(input_node.grad)
+                pytree_map(stage_input_args, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
+                pytree_map(stage_input_kwargs, lambda x: consume_result.append(x.grad), process_types=torch.Tensor)
+                # for input_node in stage_input_args:
+                #     if isinstance(input_node, torch.Tensor):
+                #         consume_result.append(input_node.grad)

         else:
             raise TypeError(f"Unknown phase appears in _consume_work_item_by_phase {phase}")
@@ -562,6 +583,7 @@ class WorkerBase(ABC):
     def _work_loop(self):
         # for init
         self._get_producer_consumer()
+        torch.cuda.set_device(ppg.get_local_pp_rank())

         # main loop
         while True:
@@ -571,6 +593,7 @@ class WorkerBase(ABC):
             with self.work_list_condition_lock:
                 work_item = self.work_list.pop(work_item_key)

+            if use_color_debug:
                color_debug(
                    f'rank {self.pp_rank} get a key : {work_item_key} work_item args: {tensor_shape_list(work_item.args)} {self._get_store_len()}',
                    'work loop', 'green')
@@ -582,6 +605,7 @@ class WorkerBase(ABC):
             consume_result = self._consume_work_item_by_phase(work_item)

+            if use_color_debug:
                color_debug(
                    f'rank_{self.pp_rank} [{work_item.phase}] finish consuming, result is {tensor_shape_list(consume_result)} {self._get_store_len()} | {self.work_list.keys()} | {self.output_list.keys()}',
                    'work loop', 'green')
@@ -621,7 +645,8 @@ class PipelineEngineBase(ABC, nn.Module):
                  chunk: int = 1,
                  criterion: Callable = None,
                  metric: Callable = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 data_process_func: Callable = None) -> None:
         super().__init__()
         self.worker_type = worker_type
         self.partition_fn: Callable = partition_fn
@@ -633,6 +658,7 @@ class PipelineEngineBase(ABC, nn.Module):
         self.use_1F1B = use_1F1B
         self.stage_num = stage_num
         self.checkpoint = checkpoint
+        self.data_process_func = data_process_func

         self.pp_rank_to_worker_rref: Dict[int, PyRRef] = dict()
@@ -644,9 +670,21 @@ class PipelineEngineBase(ABC, nn.Module):
         self._init_worker()

     def _check_argument(self) -> None:
+        # make virtual stage num
         self.virtual_stage_num = self.stage_num * self.chunk
         assert self.stage_num <= torch.cuda.device_count(), "stage_num must be smaller than device count!"

+        # check data_process_func
+        data_process_func = self.data_process_func
+        if data_process_func is not None:
+            assert callable(data_process_func), "data_process_func must be a function"
+            assert '<locals>' not in data_process_func.__repr__(), "data_process_func must be a global function"
+            assert '<lambda>' not in data_process_func.__repr__(), "data_process_func cannot be a lambda expression"
+            sig = inspect.signature(data_process_func)
+            assert len(
+                sig.parameters
+            ) == 2, f"length of data_process_func' arguments must be 2, receive {len(sig.parameters)} arguments instead"
+
     def _get_actual_stage_num(self) -> int:
         return self.stage_num if self.chunk == 1 else self.virtual_stage_num
@@ -682,6 +720,7 @@ class PipelineEngineBase(ABC, nn.Module):
         metric = self.metric
         partition_fn = self.partition_fn
         chunk = self.chunk
+        data_process_func = self.data_process_func

         for pp_rank in range(len(self.pp_rank_to_rpc_worker_id)):
             partition_id = self.pp_rank_to_module_partition_id[pp_rank]
@@ -693,7 +732,7 @@ class PipelineEngineBase(ABC, nn.Module):
                                                               worker_type,
                                                               args=(partition_fn, partition_args, pp_rank,
                                                                     actual_stage_num, num_microbatches, device,
-                                                                    criterion, metric, checkpoint))
+                                                                    criterion, metric, checkpoint, data_process_func))

         # let each worker know global worker rref (include itself)
         sync_futs = []
@@ -779,20 +818,25 @@ class PipelineEngineBase(ABC, nn.Module):
             worker_forward_result = [None] * self.num_microbatches
             for microbatch_id in range(self.num_microbatches):
                 ret = ret_future[pp_rank][microbatch_id].wait()
+                # TODO : more stable format
+                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                 worker_forward_result[microbatch_id] = ret

            worker_forward_result = list(zip(*worker_forward_result))
            forward_result.extend(worker_forward_result)

        return forward_result

    def forward_backward(self, batch: torch.Tensor, labels: torch.Tensor = None, forward_only: bool = False):
-        if labels is not None:
-            assert len(batch) == len(labels)
+        batch_lengths = get_batch_lengths(batch)

-        if not forward_only:
-            assert hasattr(self, 'optimizer_class')
+        if labels is not None and not forward_only:
+            assert hasattr(
+                self, 'optimizer_class'), "call `initialize_optimizer` to initialize optimizer before forward_backward"

         num_microbatches = self.num_microbatches
-        microbatch_size = len(batch) // num_microbatches
+        microbatch_size = batch_lengths[0] // num_microbatches
+        device = self.device

         # If Chimera mode is used, then rank of down pipeline is excluded from 'input_pp_ranks' or 'output_pp_ranks'
         input_pp_ranks = self.get_input_pp_ranks()
@@ -805,16 +849,17 @@ class PipelineEngineBase(ABC, nn.Module):
             # control data input speed
             # to prevent exceed of wait limitations
             self._consume_constraint(microbatch_id, forward_only, input_pp_ranks, output_pp_ranks, ret_future)
+            batch_start = microbatch_size * microbatch_id
+            batch_end = batch_start + microbatch_size

             # set input
-            microbatch = batch[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
-            microbatch = microbatch.cuda()
+            microbatch = split_batch(batch, batch_start, batch_end, device)
             self._set_input(input_pp_ranks, microbatch_id, microbatch, forward_only)

             # set labels
             if labels is not None:
-                microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
-                microlabels = microlabels.cuda()
+                # microlabels = labels[microbatch_size * microbatch_id:microbatch_size * (microbatch_id + 1)]
+                microlabels = split_batch(labels, batch_start, batch_end, device)
                 self._set_labels(output_pp_ranks, microbatch_id, microlabels)

             # get data asynchronously
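
The standalone snippet below is not part of the diff; it mirrors the `_make_args_kwargs` dispatch added above with a local copy (here named `make_args_kwargs`) to show what each microbatch shape is turned into. The sample tensors are illustrative.

import torch

def make_args_kwargs(microbatch):
    # same dispatch as WorkerBase._make_args_kwargs in the hunk above
    if isinstance(microbatch, dict):
        return [], microbatch
    elif isinstance(microbatch, torch.Tensor):
        return [microbatch], {}
    elif isinstance(microbatch, (tuple, list)):
        args, kwargs = [], {}
        for arg in microbatch:
            if isinstance(arg, dict):
                kwargs.update(arg)
            else:
                args.append(arg)
        return args, kwargs
    raise TypeError(f"Input batch can be only dict, list, tuple or tensor, but receive {type(microbatch)}")

x = torch.randn(4, 8)
print(make_args_kwargs(x))                               # -> ([x], {})
print(make_args_kwargs({'input_ids': x}))                # -> ([], {'input_ids': x})
print(make_args_kwargs((x, {'attention_mask': x > 0})))  # -> ([x], {'attention_mask': ...})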

View File

@@ -44,7 +44,8 @@ class FillDrainPipelineEngine(PipelineEngineBase):
                  chunk: int = 1,
                  criterion: Callable = None,
                  metric: Callable = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 data_process_func: Callable = None) -> None:

         if chunk > 1:
             assert num_microbatches % stage_num == 0, \
@@ -52,7 +53,7 @@ class FillDrainPipelineEngine(PipelineEngineBase):
         use_1F1B = False

         super().__init__(FillDrainWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint)
+                         metric, checkpoint, data_process_func)


 class OneFOneBWorker(WorkerBase):
@@ -103,7 +104,8 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
                  chunk: int = 1,
                  criterion: Callable = None,
                  metric: Callable = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 data_process_func: Callable = None) -> None:

         if chunk > 1:
             assert num_microbatches % stage_num == 0, \
@@ -112,7 +114,7 @@ class OneFOneBPipelineEngine(PipelineEngineBase):
         use_1F1B = True

         super().__init__(OneFOneBWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint)
+                         metric, checkpoint, data_process_func)


 class ChimeraWorker(WorkerBase):
@@ -227,9 +229,9 @@ class ChimeraWorker(WorkerBase):
         if step_index == 1:
             ppg.chimera_step_lock.acquire()

-        print(f'rank_{self.pp_rank} before all reduce')
+        # print(f'rank_{self.pp_rank} before all reduce')
         dist.all_reduce_coalesced(grads, group=all_reduce_group, async_op=False)
-        print(f'rank_{self.pp_rank} after all reduce')
+        # print(f'rank_{self.pp_rank} after all reduce')

         if step_index == 0:
             ppg.chimera_step_lock.release()
@@ -244,7 +246,8 @@ class ChimeraPipelineEngine(PipelineEngineBase):
                  device: str,
                  criterion: Callable = None,
                  metric: Callable = None,
-                 checkpoint: bool = False) -> None:
+                 checkpoint: bool = False,
+                 data_process_func: Callable = None) -> None:

         assert num_microbatches % stage_num == 0, \
             "In Chimera, num_microbatches must be the multiply of stage_num!"
@@ -252,7 +255,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
         chunk = 1

         super().__init__(ChimeraWorker, partition_fn, stage_num, num_microbatches, device, use_1F1B, chunk, criterion,
-                         metric, checkpoint)
+                         metric, checkpoint, data_process_func)

     def _consume_constraint(self, microbatch_id: int, forward_only: bool, ret_future: Dict[PyRRef, List[Future]],
                             input_pp_ranks: List[PyRRef], output_pp_ranks: List[PyRRef]):
@@ -330,6 +333,7 @@ class ChimeraPipelineEngine(PipelineEngineBase):
             for microbatch_id in range(self.num_microbatches):
                 offset = (microbatch_id % 2) * stage_num
                 ret = ret_future[pp_rank + offset][microbatch_id].wait()
+                ret = [ret] if isinstance(ret, torch.Tensor) else ret
                 worker_forward_result[microbatch_id] = ret

             worker_forward_result = list(zip(*worker_forward_result))
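
All three engines above forward `data_process_func` to `PipelineEngineBase`, whose `_check_argument` (previous file) rejects lambdas and locally defined functions. The runnable sketch below reproduces that check with a hypothetical helper `validate_data_process_func`; the rationale about RPC worker serialization is an inference, not stated in the commit.

import inspect

def validate_data_process_func(fn):
    # mirrors the assertions in PipelineEngineBase._check_argument; lambdas and closures
    # presumably cannot be shipped to remote RPC workers, so only global functions pass
    assert callable(fn), "data_process_func must be a function"
    assert '<locals>' not in repr(fn), "data_process_func must be a global function"
    assert '<lambda>' not in repr(fn), "data_process_func cannot be a lambda expression"
    assert len(inspect.signature(fn).parameters) == 2, \
        "data_process_func must take exactly two arguments (pp_rank, args_kwargs)"

def my_hook(pp_rank, args_kwargs):
    return args_kwargs

validate_data_process_func(my_hook)             # passes
# validate_data_process_func(lambda r, x: x)    # would fail: lambda is rejected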

View File

@@ -0,0 +1,74 @@
+from typing import List, Any, Tuple, Dict, Callable, Type, Union
+
+import torch
+from torch.futures import Future
+from colorama import Back, Style
+
+# config for debug and test
+use_color_debug = False
+
+
+def color_debug(text, prefix=' ', color='blue'):
+    color = color.upper()
+    print(getattr(Back, color), prefix, Style.RESET_ALL, text)
+
+
+def pytree_map(obj: Any, fn: Callable, process_types: Union[Type, Tuple[Type]] = (), map_all: bool = False) -> Any:
+    """process object recursively, like pytree
+
+    Args:
+        obj (:class:`Any`): object to process
+        fn (:class:`Callable`): a function to process subobject in obj
+        process_types (:class:`type | tuple[type]`): types to determine the type to process
+
+    Returns:
+        :class:`Any`: returns have the same structure of `obj` and type in process_types after map of `fn`
+    """
+    if isinstance(obj, dict):
+        return {k: pytree_map(obj[k], fn, process_types, map_all) for k in obj}
+    elif isinstance(obj, tuple):
+        return tuple(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, list):
+        return list(pytree_map(o, fn, process_types, map_all) for o in obj)
+    elif isinstance(obj, process_types):
+        return fn(obj)
+    else:
+        return fn(obj) if map_all else obj
+
+
+def tensor_shape_list(obj):
+    return pytree_map(obj, fn=lambda x: x.shape, process_types=torch.Tensor)
+
+
+def get_batch_lengths(batch):
+    lengths = []
+    pytree_map(batch, fn=lambda x: lengths.append(len(x)), process_types=torch.Tensor)
+    return lengths
+
+
+def split_batch(batch: Any, start, stop, device: str):
+    if device == 'cuda':
+        fn = lambda x: x[start:stop].cuda()
+    else:
+        fn = lambda x: x[start:stop]
+    return pytree_map(batch, fn=fn, process_types=torch.Tensor)
+
+
+def type_detail(obj):
+    return pytree_map(obj, lambda x: type(x), map_all=True)
+
+
+def get_real_args_kwargs(args_or_kwargs):
+    args_or_kwargs = pytree_map(args_or_kwargs, fn=lambda x: x.wait(), process_types=Future)
+    # TODO : combine producer and consumer
+    # by default, merge all args in the output args or kwargs
+    if args_or_kwargs is not None:
+        if isinstance(args_or_kwargs, dict):
+            pass
+        else:
+            flatten_args = []
+            pytree_map(args_or_kwargs, fn=lambda x: flatten_args.append(x), map_all=True)
+            args_or_kwargs = flatten_args
+
+    return args_or_kwargs
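
A quick usage sketch for the helpers added in this new file (not part of the diff). The import path follows the `colossalai.pipeline.rpc.utils` module referenced earlier in the diff; the nested sample batch is illustrative.

import torch
from colossalai.pipeline.rpc.utils import tensor_shape_list, get_batch_lengths, split_batch, type_detail

batch = {'input_ids': torch.randn(8, 4), 'masks': [torch.ones(8, 4), torch.zeros(8, 4)]}

print(tensor_shape_list(batch))    # same nesting, every tensor replaced by its .shape
print(get_batch_lengths(batch))    # [8, 8, 8] -- one length per tensor leaf
print(type_detail(batch))          # same nesting, every leaf replaced by its type

micro = split_batch(batch, 0, 2, device='cpu')    # slice the first two samples of every tensor leaf
print(get_batch_lengths(micro))    # [2, 2, 2]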

View File

@@ -22,10 +22,9 @@ def run_master(args):
     epoch = args.epoch
     device = args.device
-    stage_num = 4
+    stage_num = args.world_size
     chunk = 1
-    num_microbatches = 4
-    actual_stage_num = 4
+    num_microbatches = args.num_microbatches
     use_checkpoint = False
     sample_num = 1024
@@ -78,6 +77,4 @@ def run_master(args):
 if __name__ == "__main__":
     args = parse_args()
-    args.world_size = 4
-    args.num_microbatches = 4
     rpc_run(args, run_master)