[pipeline] supported more flexible dataflow control for pipeline parallel training (#1108)

* [pipeline] supported more flexible dataflow control for pipeline parallel training

* polish code

* polish code

* polish code
Frank Lee 2022-06-15 10:41:28 +08:00 committed by GitHub
parent 53297330c0
commit 6f82ac9bcb
3 changed files with 194 additions and 87 deletions
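The change below renames batch_data_process_func to data_process_func and gives it a schedule-specific signature: one argument (the dataloader output) for NonPipelineSchedule, and two arguments (the previous stage's output plus the dataloader output) for the pipeline schedules. A minimal sketch of what a caller might now write, assuming a dataloader that yields (tokens, mask, target) tuples; apart from the class and parameter names taken from the diff, every name here is illustrative, and constructing the schedules assumes colossalai.launch has already set up the parallel context.

import torch
from colossalai.engine.schedule import NonPipelineSchedule, PipelineSchedule

# one-argument hook for NonPipelineSchedule: split the raw batch into (data, label)
def split_batch(dataloader_output):
    tokens, mask, target = dataloader_output
    return (tokens, mask), target

# two-argument hook for PipelineSchedule: also sees the tensors from the previous stage
# (the first stage typically receives None as the stage output)
def split_stage_batch(stage_output, dataloader_output):
    tokens, mask, target = dataloader_output
    data = (tokens, mask) if stage_output is None else (stage_output, mask)
    return data, target

schedule = NonPipelineSchedule(data_process_func=split_batch)
pipe_schedule = PipelineSchedule(num_microbatches=4, data_process_func=split_stage_batch)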


@@ -18,13 +18,12 @@ class BaseSchedule(ABC):
control of FP16 in class schedule.
Args:
batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
and it will be executed in load_batch.
data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges them into data and label.
"""
def __init__(self, batch_data_process_func: Callable = None):
def __init__(self, data_process_func: Callable = None):
self.logger = get_dist_logger()
self.batch_data_process_func = batch_data_process_func
self.data_process_func = data_process_func
@staticmethod
def _move_tensor(element):
@@ -34,16 +33,24 @@ class BaseSchedule(ABC):
return element
def _move_to_device(self, data):
if isinstance(data, dict):
if isinstance(data, torch.Tensor):
data = data.to(get_current_device())
elif isinstance(data, (list, tuple)):
data = [self._move_tensor(v) for v in data]
elif isinstance(data, dict):
data = {k: self._move_tensor(v) for k, v in data.items()}
else:
data = self._move_tensor(data)
raise TypeError(
f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
return data
@staticmethod
def _check_sanity(data, tag: str):
assert isinstance(data, (torch.Tensor, dict)), \
f'{tag} must be torch.Tensor or dict'
def _get_batch_size(self, data):
if isinstance(data, torch.Tensor):
return data.size(0)
elif isinstance(data, (list, tuple)):
return data[0].size(0)
elif isinstance(data, dict):
return next(iter(data.values())).size(0)
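The new _get_batch_size helper reads the batch size off whichever container the dataloader produced. A standalone sketch of the same dispatch, with illustrative tensors:

import torch

def get_batch_size(batch):
    # first dimension of the (first) tensor, whatever the container type
    if isinstance(batch, torch.Tensor):
        return batch.size(0)
    if isinstance(batch, (list, tuple)):
        return batch[0].size(0)
    if isinstance(batch, dict):
        return next(iter(batch.values())).size(0)
    raise TypeError(f"unsupported batch type {type(batch)}")

print(get_batch_size(torch.zeros(4, 8)))                     # 4
print(get_batch_size((torch.zeros(4, 8), torch.zeros(4))))   # 4
print(get_batch_size({"input_ids": torch.zeros(4, 8)}))      # 4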
def load_batch(self, data_iter, to_gpu=True):
"""Loads a batch from data iterator. It returns the data and labels which are
@@ -60,19 +67,10 @@ class BaseSchedule(ABC):
raise RuntimeError('Dataloader is not defined.')
batch_data = next(data_iter)
if self.batch_data_process_func:
data, label = self.batch_data_process_func(batch_data)
else:
data, label = batch_data
self._check_sanity(data, 'data')
self._check_sanity(label, 'label')
if isinstance(data, torch.Tensor):
self.batch_size = data.size(0)
else:
self.batch_size = next(iter(data.values())).size(0)
if to_gpu:
return self._move_to_device(data), self._move_to_device(label)
return data, label
batch_data = self._move_to_device(batch_data)
self.batch_size = self._get_batch_size(batch_data)
return batch_data
def pre_processing(self, engine):
"""To perform actions before running the schedule.
@@ -101,8 +99,13 @@ class BaseSchedule(ABC):
def _call_engine(engine, inputs):
if isinstance(inputs, torch.Tensor):
return engine(inputs)
else:
elif isinstance(inputs, (list, tuple)):
return engine(*inputs)
elif isinstance(inputs, dict):
return engine(**inputs)
else:
raise TypeError(
f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")
@staticmethod
def _call_engine_criterion(engine, outputs, labels):
@@ -112,6 +115,17 @@ class BaseSchedule(ABC):
if isinstance(outputs, torch.Tensor):
outputs = (outputs,)
if isinstance(labels, torch.Tensor):
return engine.criterion(*outputs, labels)
else:
labels = (labels,)
if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
return engine.criterion(*outputs, *labels)
elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
return engine.criterion(*outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, dict):
return engine.criterion(**outputs, **labels)
elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
else:
raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor ' \
'(which is auto-converted to tuple), list, tuple, or dict, ' \
'but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
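The criterion dispatch above promotes bare tensors to 1-tuples, then splats tuples/lists positionally and dicts as keyword arguments. A self-contained sketch of the same logic outside the engine wrapper; the loss function and tensors are illustrative:

import torch

def call_criterion(criterion, outputs, labels):
    # mirrors BaseSchedule._call_engine_criterion: tensors become 1-tuples,
    # tuples/lists are splatted positionally, dicts as keyword arguments
    if isinstance(outputs, torch.Tensor):
        outputs = (outputs,)
    if isinstance(labels, torch.Tensor):
        labels = (labels,)
    if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
        return criterion(*outputs, *labels)
    if isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
        return criterion(*outputs, **labels)
    if isinstance(outputs, dict) and isinstance(labels, dict):
        return criterion(**outputs, **labels)
    raise TypeError("unsupported output/label container combination")

def loss_fn(logits, target):
    return torch.nn.functional.cross_entropy(logits, target)

logits = torch.randn(8, 10)
target = torch.randint(0, 10, (8,))
print(call_criterion(loss_fn, logits, {"target": target}))   # dict labels become keyword args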


@@ -4,9 +4,10 @@
from typing import Iterable
import torch
import inspect
from ._base_schedule import BaseSchedule
from colossalai.utils import conditional_context
from typing import Callable
class NonPipelineSchedule(BaseSchedule):
@@ -16,10 +17,32 @@ class NonPipelineSchedule(BaseSchedule):
to update the parameters if it is in training mode.
Args:
batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
and it will be executed in load_batch.
data_process_func (Callable, optional): The preprocessing function which receives a batch of data
and returns a tuple in the form of (data, label).
Example:
# this shows an example of customized data_process_func
def data_process_func(dataloader_output):
item1, item2, item3 = dataloader_output
data = (item1, item2)
label = item3
return data, label
"""
def __init__(self, data_process_func: Callable = None):
# check that non-pipeline schedule data process func only takes in one parameter
# which is the batch data
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 1, \
'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
'which is a tuple of tensors for the current batch, ' \
'i.e. data_process_func(dataloader_output).'
super().__init__(data_process_func)
def forward_backward_step(self,
engine,
data_iter: Iterable,
@@ -42,7 +65,14 @@ class NonPipelineSchedule(BaseSchedule):
"""
assert forward_only or return_loss, \
"The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
data, label = self.load_batch(data_iter)
batch_data = self.load_batch(data_iter)
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else:
# if no data_process_func is given,
# then we regard the batch data as a simple tuple of (data, label)
data, label = batch_data
# forward
with conditional_context(torch.no_grad(), enable=forward_only):
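The constructor now rejects hooks with the wrong arity up front instead of failing inside the training loop. A small sketch of the same inspect-based check, using a hypothetical validate helper:

import inspect

def validate_non_pipeline_hook(fn):
    # NonPipelineSchedule expects exactly one parameter: the dataloader output
    n_params = len(inspect.signature(fn).parameters)
    if n_params != 1:
        raise ValueError(f"expected data_process_func(dataloader_output), got {n_params} parameters")

def good_hook(dataloader_output):
    item1, item2, item3 = dataloader_output
    return (item1, item2), item3

validate_non_pipeline_hook(good_hook)            # passes
# validate_non_pipeline_hook(lambda a, b: a)     # would raise: two parameters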


@@ -68,19 +68,41 @@ class PipelineSchedule(BaseSchedule):
Args:
num_microbatches (int): The number of microbatches.
batch_data_process_func (Callable, optional):
data_process_func (Callable, optional):
The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
Example:
# this shows an example of customized data_process_func
def data_process_func(stage_output, dataloader_output):
output1, output2 = stage_output
item1, item2, item3 = dataloader_output
# assume item2 is not needed
data = (output1, output2, item1)
label = item3
return data, label
"""
def __init__(self,
num_microbatches,
batch_data_process_func: Callable = None,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
super().__init__(batch_data_process_func=batch_data_process_func)
# we need to make sure that the signature of the data_process_func is valid
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 2, \
'The data_process_func only takes in two parameters for PipelineSchedule, ' \
'which are the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
'i.e. data_process_func(stage_output, dataloader_output).'
super().__init__(data_process_func=data_process_func)
assert num_microbatches > 0, f'expected num_microbatches to be larger than 0, but got {num_microbatches}'
@@ -99,29 +121,32 @@ class PipelineSchedule(BaseSchedule):
self.scatter_gather_tensors = scatter_gather_tensors
self._logger = get_dist_logger()
# cache for the batch data
self.batch_data = None
def load_batch(self, data_iter):
# Pipeline schedule just puts data in memory
self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
batch_data = super().load_batch(data_iter, to_gpu=False)
self.microbatch_offset = 0
if isinstance(self.batch_data, torch.Tensor):
batch_size = self.batch_data.size(0)
else:
batch_size = next(iter(self.batch_data.values())).size(0)
assert batch_size % self.num_microbatches == 0, \
assert self.batch_size % self.num_microbatches == 0, \
"Batch size should divided by the number of microbatches"
self.microbatch_size = batch_size // self.num_microbatches
self.microbatch_size = self.batch_size // self.num_microbatches
self.batch_data = batch_data
def _get_data_slice(self, data, offset):
if isinstance(data, torch.Tensor):
return data[offset:offset + self.microbatch_size]
elif isinstance(data, (list, tuple)):
return [val[offset:offset + self.microbatch_size] for val in data]
elif isinstance(data, dict):
return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def load_micro_batch(self):
data = self._get_data_slice(self.batch_data, self.microbatch_offset)
label = self._get_data_slice(self.batch_label, self.microbatch_offset)
micro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset)
self.microbatch_offset += self.microbatch_size
return self._move_to_device(data), self._move_to_device(label)
return self._move_to_device(micro_batch_data)
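load_batch now caches the whole batch and load_micro_batch slices it with _get_data_slice, so the same offset arithmetic has to work for tensors, tuples/lists, and dicts. A standalone sketch of that slicing; the batch contents are illustrative:

import torch

def slice_micro_batches(batch, num_microbatches):
    # mirrors _get_data_slice: the same offset-based slice works for a tensor,
    # a tuple/list of tensors, or a dict of tensors
    if isinstance(batch, torch.Tensor):
        batch_size = batch.size(0)
    elif isinstance(batch, dict):
        batch_size = next(iter(batch.values())).size(0)
    else:
        batch_size = batch[0].size(0)
    assert batch_size % num_microbatches == 0, "batch size must be divisible by num_microbatches"
    micro_batch_size = batch_size // num_microbatches
    for offset in range(0, batch_size, micro_batch_size):
        if isinstance(batch, torch.Tensor):
            yield batch[offset:offset + micro_batch_size]
        elif isinstance(batch, (list, tuple)):
            yield [t[offset:offset + micro_batch_size] for t in batch]
        else:
            yield {k: v[offset:offset + micro_batch_size] for k, v in batch.items()}

batch = {"input_ids": torch.randn(8, 16), "labels": torch.randint(0, 2, (8,))}
for micro in slice_micro_batches(batch, num_microbatches=4):
    print(micro["input_ids"].shape)   # torch.Size([2, 16])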
def pre_processing(self, engine):
# TODO: remove this after testing new zero with pipeline parallelism
@@ -137,45 +162,78 @@ class PipelineSchedule(BaseSchedule):
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
@staticmethod
def _call_engine(model, input_obj, batch_data):
if isinstance(model, NaiveAMPModel):
sig = inspect.signature(model.model.forward)
elif hasattr(model, 'colo_attr'):
sig = inspect.signature(model.module.forward)
def _call_engine(model, data):
if data is not None:
if isinstance(data, torch.Tensor):
return model(data)
elif isinstance(data, (list, tuple)):
return model(*data)
elif isinstance(data, dict):
return model(**data)
else:
sig = inspect.signature(model.forward)
if isinstance(batch_data, torch.Tensor):
for p in sig.parameters.values():
if p.kind == inspect.Parameter.VAR_KEYWORD:
if input_obj is None:
return model(batch_data)
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def _get_actual_forward_func(self, module):
if isinstance(module, NaiveAMPModel):
sig = inspect.signature(module.model.forward)
elif hasattr(module, 'colo_attr'):
sig = inspect.signature(module.module.forward)
else:
return model(input_obj)
if input_obj is None:
return model(batch_data)
elif isinstance(input_obj, torch.Tensor):
if len(sig.parameters) > 1:
return model(input_obj, batch_data)
sig = inspect.signature(module.forward)
return sig
def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model):
if self.data_process_func:
# use customized function to get data and label
data, label = self.data_process_func(stage_output, micro_batch_data)
else:
return model(input_obj)
if isinstance(micro_batch_data, (tuple, list)):
if gpc.is_first_rank(ParallelMode.PIPELINE):
# for the first stage, we use the data from the
# dataloader output by default
data, label = micro_batch_data
else:
if len(sig.parameters) > len(input_obj):
return model(*input_obj, batch_data)
# for non-first stages, we use the output passed
# by the previous stage as the model input
data = stage_output
_, label = micro_batch_data
elif isinstance(micro_batch_data, dict):
args = []
data = {}
label = {}
# we feed the stage output to args first
# then map each arg in args to its param name
if stage_output is not None:
if isinstance(stage_output, torch.Tensor):
args.append(stage_output)
elif isinstance(stage_output, (list, tuple)):
args.extend(stage_output)
else:
return model(*input_obj)
raise TypeError(
f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(input_obj)}"
)
# get all parameter names for the forward function of the model
fwd_sig = self._get_actual_forward_func(model)
fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
# build the kwargs for the forward function
for idx, param_name in enumerate(fwd_sig_param_name):
if idx < len(args):
data[param_name] = args[idx]
else:
filter_batch = True
for p in sig.parameters.values():
if p.kind == inspect.Parameter.VAR_KEYWORD:
filter_batch = False
if filter_batch:
batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
if input_obj is None and filter_batch:
return model(**batch_data)
elif isinstance(input_obj, torch.Tensor) or input_obj is None:
return model(input_obj, **batch_data)
else:
return model(*input_obj, **batch_data)
if param_name in micro_batch_data:
data[param_name] = micro_batch_data[param_name]
# get the tensors for loss
loss_sig = inspect.signature(criterion)
loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
for param_name in loss_sig_param_name:
if param_name in micro_batch_data:
label[param_name] = micro_batch_data[param_name]
return data, label
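For dict batches, _get_data_label_for_current_step maps the previous stage's tensors onto the leading parameters of the model's forward signature and fills the remaining forward and criterion parameters by name from the micro-batch. A self-contained sketch of that mapping with a toy stage and loss; all names here are illustrative:

import inspect
import torch
import torch.nn as nn

class ToyStage(nn.Module):
    def forward(self, hidden, attention_mask):
        return hidden * attention_mask.unsqueeze(-1)

def toy_loss(output, labels):
    return ((output.mean(dim=(1, 2)) - labels) ** 2).mean()

def map_stage_io(stage_output, micro_batch, model, criterion):
    # tensors from the previous stage become positional arguments...
    args = []
    if isinstance(stage_output, torch.Tensor):
        args.append(stage_output)
    elif isinstance(stage_output, (list, tuple)):
        args.extend(stage_output)
    data, label = {}, {}
    # ...and are matched to the leading forward parameters by position;
    # the remaining forward parameters are looked up in the micro-batch by name
    fwd_names = [p.name for p in inspect.signature(model.forward).parameters.values()]
    for idx, name in enumerate(fwd_names):
        if idx < len(args):
            data[name] = args[idx]
        elif name in micro_batch:
            data[name] = micro_batch[name]
    # criterion parameters found in the micro-batch become the label dict
    for name in inspect.signature(criterion).parameters:
        if name in micro_batch:
            label[name] = micro_batch[name]
    return data, label

micro_batch = {"attention_mask": torch.ones(2, 4), "labels": torch.zeros(2)}
prev_output = torch.randn(2, 4, 8)    # pretend tensor received from the previous stage
data, label = map_stage_io(prev_output, micro_batch, ToyStage(), toy_loss)
print(toy_loss(ToyStage()(**data), **label))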
def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
"""Forward step for passed-in model. If it is the first stage, the input tensor
@@ -191,8 +249,11 @@ class PipelineSchedule(BaseSchedule):
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch()
output_obj = self._call_engine(engine.model, input_obj, data)
micro_batch_data = self.load_micro_batch()
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model)
output_obj = self._call_engine(engine.model, data)
if gpc.is_last_rank(ParallelMode.PIPELINE):
if return_output_label:
@@ -399,7 +460,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
def __init__(self,
num_microbatches: int,
num_model_chunks: int,
batch_data_process_func: Callable = None,
data_process_func: Callable = None,
tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
scatter_gather_tensors: bool = False):
"""A helper schedule class for pipeline parallelism running environment.
@@ -409,7 +470,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
Args:
num_microbatches (int): The number of microbatches.
num_model_chunks (int): The number of model chunks.
batch_data_process_func (Callable, optional):
data_process_func (Callable, optional):
The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
scatter_gather_tensors (bool, optional):
@@ -420,7 +481,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
super().__init__(num_microbatches,
batch_data_process_func=batch_data_process_func,
data_process_func=data_process_func,
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather_tensors)
gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
@@ -446,9 +507,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):
def load_micro_batch(self, model_chunk_id):
data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
self.microbatch_offset[model_chunk_id] += self.microbatch_size
return self._move_to_device(data), self._move_to_device(label)
return self._move_to_device(data)
def _forward_step(self,
engine,
@@ -471,8 +531,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
Returns:
Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
"""
data, label = self.load_micro_batch(model_chunk_id)
output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)
micro_batch_data = self.load_micro_batch(model_chunk_id)
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
engine.model[model_chunk_id])
output_obj = self._call_engine(engine.model[model_chunk_id], data)
if gpc.is_pipeline_last_stage():
if return_output_label:
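InterleavedPipelineSchedule forwards the same data_process_func down to PipelineSchedule, so the two-argument hook shown earlier works unchanged. A usage sketch, assuming colossalai has already been launched with a pipeline-parallel configuration; all values are illustrative:

from colossalai.engine.schedule import InterleavedPipelineSchedule

def data_process_func(stage_output, dataloader_output):
    tokens, mask, target = dataloader_output
    data = (tokens, mask) if stage_output is None else (stage_output, mask)
    return data, target

# requires an initialized pipeline-parallel context (colossalai.launch) to construct
schedule = InterleavedPipelineSchedule(num_microbatches=8,
                                       num_model_chunks=2,
                                       data_process_func=data_process_func,
                                       scatter_gather_tensors=True)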