mirror of https://github.com/hpcaitech/ColossalAI
[pipeline] supported more flexible dataflow control for pipeline parallel training (#1108)
* [pipeline] supported more flexible dataflow control for pipeline parallel training
* polish code
* polish code
* polish code
parent 53297330c0
commit 6f82ac9bcb
@@ -18,13 +18,12 @@ class BaseSchedule(ABC):
     control of FP16 in class schedule.

     Args:
-        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
-            and it will be executed in load_batch.
+        data_process_func (Callable, optional): The preprocessing function which receives a batch of data and arranges them into data and label.
     """

-    def __init__(self, batch_data_process_func: Callable = None):
+    def __init__(self, data_process_func: Callable = None):
         self.logger = get_dist_logger()
-        self.batch_data_process_func = batch_data_process_func
+        self.data_process_func = data_process_func

     @staticmethod
     def _move_tensor(element):
@@ -34,16 +33,24 @@ class BaseSchedule(ABC):
         return element

     def _move_to_device(self, data):
-        if isinstance(data, dict):
+        if isinstance(data, torch.Tensor):
+            data = data.to(get_current_device())
+        elif isinstance(data, (list, tuple)):
+            data = [self._move_tensor(v) for v in data]
+        elif isinstance(data, dict):
             data = {k: self._move_tensor(v) for k, v in data.items()}
         else:
-            data = self._move_tensor(data)
+            raise TypeError(
+                f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
         return data

-    @staticmethod
-    def _check_sanity(data, tag: str):
-        assert isinstance(data, (torch.Tensor, dict)), \
-            f'{tag} must be torch.Tensor or dict'
+    def _get_batch_size(self, data):
+        if isinstance(data, torch.Tensor):
+            return data.size(0)
+        elif isinstance(data, (list, tuple)):
+            return data[0].size(0)
+        elif isinstance(data, dict):
+            return data[next(iter(data.keys()))].size(0)

     def load_batch(self, data_iter, to_gpu=True):
         """Loads a batch from data iterator. It returns the data and labels which are
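Note: the new `_get_batch_size` helper derives the batch size from whichever container the dataloader yields. A minimal standalone sketch of the dispatch (toy tensors, not part of the diff):

```python
import torch

def get_batch_size(data):
    # mirrors BaseSchedule._get_batch_size: tensor, sequence, or dict input
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, (list, tuple)):
        return data[0].size(0)
    elif isinstance(data, dict):
        return data[next(iter(data.keys()))].size(0)

assert get_batch_size(torch.zeros(8, 4)) == 8
assert get_batch_size((torch.zeros(8, 4), torch.zeros(8))) == 8
assert get_batch_size({"input_ids": torch.zeros(8, 4)}) == 8
```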
@@ -60,19 +67,10 @@ class BaseSchedule(ABC):
             raise RuntimeError('Dataloader is not defined.')
         batch_data = next(data_iter)

-        if self.batch_data_process_func:
-            data, label = self.batch_data_process_func(batch_data)
-        else:
-            data, label = batch_data
-        self._check_sanity(data, 'data')
-        self._check_sanity(label, 'label')
-        if isinstance(data, torch.Tensor):
-            self.batch_size = data.size(0)
-        else:
-            self.batch_size = next(iter(data.values())).size(0)
         if to_gpu:
-            return self._move_to_device(data), self._move_to_device(label)
-        return data, label
+            batch_data = self._move_to_device(batch_data)
+        self.batch_size = self._get_batch_size(batch_data)
+        return batch_data

     def pre_processing(self, engine):
         """To perform actions before running the schedule.
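Note: `load_batch` now returns the dataloader output whole instead of a pre-split `(data, label)` pair; the split is deferred to `data_process_func` or the schedule defaults. A toy sketch of the new contract (not part of the diff):

```python
import torch

# a stand-in for a dataloader iterator yielding (data, label) tuples
data_iter = iter([(torch.randn(4, 8), torch.randint(0, 2, (4,)))])

batch_data = next(data_iter)   # load_batch now returns this object whole
data, label = batch_data       # the split happens later in the schedule
assert data.size(0) == label.size(0) == 4
```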
@@ -101,8 +99,13 @@ class BaseSchedule(ABC):
     def _call_engine(engine, inputs):
         if isinstance(inputs, torch.Tensor):
             return engine(inputs)
-        else:
+        elif isinstance(inputs, (list, tuple)):
             return engine(*inputs)
+        elif isinstance(inputs, dict):
+            return engine(**inputs)
+        else:
+            raise TypeError(
+                f"Expected engine inputs to be of type torch.Tensor, list, tuple, or dict, but got {type(inputs)}")

     @staticmethod
     def _call_engine_criterion(engine, outputs, labels):
@@ -112,6 +115,17 @@ class BaseSchedule(ABC):
         if isinstance(outputs, torch.Tensor):
             outputs = (outputs,)
         if isinstance(labels, torch.Tensor):
-            return engine.criterion(*outputs, labels)
-        else:
-            return engine.criterion(*outputs, **labels)
+            labels = (labels,)
+
+        if isinstance(outputs, (tuple, list)) and isinstance(labels, (tuple, list)):
+            return engine.criterion(*outputs, *labels)
+        elif isinstance(outputs, (tuple, list)) and isinstance(labels, dict):
+            return engine.criterion(*outputs, **labels)
+        elif isinstance(outputs, dict) and isinstance(labels, dict):
+            return engine.criterion(**outputs, **labels)
+        elif isinstance(outputs, dict) and isinstance(labels, (list, tuple)):
+            raise ValueError(f"Expected labels to be a dict when the model outputs are dict, but got {type(labels)}")
+        else:
+            raise TypeError(f"Expected model outputs and labels to be of type torch.Tensor "
+                            f"(which is auto-converted to tuple), list, tuple, or dict, "
+                            f"but got {type(outputs)} (model outputs) and {type(labels)} (labels)")
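Note: the dispatch above generalizes both model and criterion calls to tensor, sequence, and dict inputs. A standalone sketch of the calling convention (plain functions standing in for the engine's model and criterion; not part of the diff):

```python
import torch
import torch.nn.functional as F

def model(x, mask):                  # stands in for the engine's model forward
    return x * mask

def criterion(output, target):       # stands in for engine.criterion
    return F.mse_loss(output, target)

inputs = {"x": torch.ones(4), "mask": torch.zeros(4)}
outputs = model(**inputs)            # dict -> keyword arguments, as in _call_engine

outputs = (outputs,) if isinstance(outputs, torch.Tensor) else outputs
labels = (torch.zeros(4),)
loss = criterion(*outputs, *labels)  # tuples -> positional, as in _call_engine_criterion
```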
@@ -4,9 +4,10 @@
 from typing import Iterable

 import torch

+import inspect
 from ._base_schedule import BaseSchedule
 from colossalai.utils import conditional_context
+from typing import Callable


 class NonPipelineSchedule(BaseSchedule):
@@ -16,10 +17,32 @@ class NonPipelineSchedule(BaseSchedule):
     to update the parameters if it is in training mode.

     Args:
-        batch_data_process_func (Callable, optional): The preprocessing function which receives a batch of data,
-            and it will be executed in load_batch.
+        data_process_func (Callable, optional): The preprocessing function which receives a batch of data
+            and returns a tuple in the form of (data, label).
+
+    Example:
+        # this shows an example of customized data_process_func
+        def data_process_func(dataloader_output):
+            item1, item2, item3 = dataloader_output
+            data = (item1, item2)
+            label = item3
+            return data, label
     """

+    def __init__(self, data_process_func: Callable = None):
+        # check that the data process func of the non-pipeline schedule
+        # only takes in one parameter, which is the batch data
+        if data_process_func:
+            sig = inspect.signature(data_process_func)
+            assert len(sig.parameters) == 1, \
+                'The data_process_func only takes in one parameter for NonPipelineSchedule, ' \
+                'which is a tuple of tensors for the current batch, ' \
+                'i.e. data_process_func(dataloader_output).'
+
+        super().__init__(data_process_func)
+
     def forward_backward_step(self,
                               engine,
                               data_iter: Iterable,
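Note: the `__init__` guard relies on `inspect.signature` to count declared parameters. A small sketch of what passes and what would trip the assert (hypothetical functions, not part of the diff):

```python
import inspect

def one_param(dataloader_output):                  # accepted: exactly one parameter
    data, label = dataloader_output
    return data, label

def two_params(stage_output, dataloader_output):   # would trip the assert above
    return stage_output, dataloader_output

assert len(inspect.signature(one_param).parameters) == 1
assert len(inspect.signature(two_params).parameters) == 2
```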
@@ -42,7 +65,14 @@ class NonPipelineSchedule(BaseSchedule):
         """
         assert forward_only or return_loss, \
             "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
-        data, label = self.load_batch(data_iter)
+        batch_data = self.load_batch(data_iter)
+
+        if self.data_process_func:
+            data, label = self.data_process_func(batch_data)
+        else:
+            # if no data_process_func is given,
+            # we regard the batch data as a simple tuple of (data, label)
+            data, label = batch_data

         # forward
         with conditional_context(torch.no_grad(), enable=forward_only):
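Note: putting the non-pipeline pieces together, a hypothetical wiring sketch (the import path and dataloader format are assumptions, not part of the diff):

```python
from colossalai.engine.schedule import NonPipelineSchedule

# assume the dataloader yields (img, mask, label); keep img and mask as model inputs
def data_process_func(dataloader_output):
    img, mask, label = dataloader_output
    return (img, mask), label

schedule = NonPipelineSchedule(data_process_func=data_process_func)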
@@ -68,19 +68,41 @@ class PipelineSchedule(BaseSchedule):

     Args:
         num_microbatches (int): The number of microbatches.
-        batch_data_process_func (Callable, optional):
+        data_process_func (Callable, optional):
             The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
             If set to `True`, communication will be reduced over pipeline when using 1D tensor parallelization.
+
+    Example:
+
+        # this shows an example of customized data_process_func
+        def data_process_func(stage_output, dataloader_output):
+            output1, output2 = stage_output
+            item1, item2, item3 = dataloader_output
+
+            # assume item2 is not needed
+            data = (output1, output2, item1)
+            label = item3
+            return data, label
+
     """

     def __init__(self,
                  num_microbatches,
-                 batch_data_process_func: Callable = None,
+                 data_process_func: Callable = None,
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
-        super().__init__(batch_data_process_func=batch_data_process_func)
+        # we need to make sure that the signature of the data_process_func is valid
+        if data_process_func:
+            sig = inspect.signature(data_process_func)
+            assert len(sig.parameters) == 2, \
+                'The data_process_func only takes in two parameters for PipelineSchedule, ' \
+                'which are the tensors passed by the previous pipeline stage and the dataloader output from this stage, ' \
+                'i.e. data_process_func(stage_output, dataloader_output).'
+
+        super().__init__(data_process_func=data_process_func)

         assert num_microbatches > 0, f'expected num_microbatches to be larger than 0, but got {num_microbatches}'
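Note: a hypothetical construction sketch for the pipeline case; the two-parameter contract lets the same function serve the first and later stages (the import path and `num_microbatches=4` are assumptions, not part of the diff):

```python
from colossalai.engine.schedule import PipelineSchedule

def data_process_func(stage_output, dataloader_output):
    item1, item2, item3 = dataloader_output
    if stage_output is None:
        # first stage: feed dataloader items directly
        data = (item1, item2)
    else:
        # later stages: feed upstream activations plus any extra items,
        # assuming the previous stage passes a tuple of tensors
        data = (*stage_output, item1)
    return data, item3

schedule = PipelineSchedule(num_microbatches=4, data_process_func=data_process_func)
```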
@@ -99,29 +121,32 @@ class PipelineSchedule(BaseSchedule):
         self.scatter_gather_tensors = scatter_gather_tensors
         self._logger = get_dist_logger()

+        # cache for the batch data
+        self.batch_data = None
+
     def load_batch(self, data_iter):
         # Pipeline schedule just puts data in memory
-        self.batch_data, self.batch_label = super().load_batch(data_iter, to_gpu=False)
+        batch_data = super().load_batch(data_iter, to_gpu=False)
         self.microbatch_offset = 0
-        if isinstance(self.batch_data, torch.Tensor):
-            batch_size = self.batch_data.size(0)
-        else:
-            batch_size = next(iter(self.batch_data.values())).size(0)
-        assert batch_size % self.num_microbatches == 0, \
+        assert self.batch_size % self.num_microbatches == 0, \
             "Batch size should be divisible by the number of microbatches"
-        self.microbatch_size = batch_size // self.num_microbatches
+        self.microbatch_size = self.batch_size // self.num_microbatches
+        self.batch_data = batch_data

     def _get_data_slice(self, data, offset):
         if isinstance(data, torch.Tensor):
             return data[offset:offset + self.microbatch_size]
+        elif isinstance(data, (list, tuple)):
+            return [val[offset:offset + self.microbatch_size] for val in data]
         elif isinstance(data, dict):
             return {k: v[offset:offset + self.microbatch_size] for k, v in data.items()}
+        else:
+            raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

     def load_micro_batch(self):
-        data = self._get_data_slice(self.batch_data, self.microbatch_offset)
-        label = self._get_data_slice(self.batch_label, self.microbatch_offset)
+        micro_batch_data = self._get_data_slice(self.batch_data, self.microbatch_offset)
         self.microbatch_offset += self.microbatch_size
-        return self._move_to_device(data), self._move_to_device(label)
+        return self._move_to_device(micro_batch_data)

     def pre_processing(self, engine):
         # TODO: remove this after testing new zero with pipeline parallelism
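Note: micro-batches are contiguous slices of the cached batch. A toy sketch of the slicing (shapes are illustrative, not part of the diff):

```python
import torch

batch = {"x": torch.randn(8, 16), "y": torch.randn(8)}
num_microbatches = 4
microbatch_size = 8 // num_microbatches   # = 2

def get_data_slice(data, offset, size):
    # mirrors the dict branch of PipelineSchedule._get_data_slice
    return {k: v[offset:offset + size] for k, v in data.items()}

micro0 = get_data_slice(batch, 0, microbatch_size)
assert micro0["x"].shape == (2, 16) and micro0["y"].shape == (2,)
```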
@@ -137,45 +162,78 @@ class PipelineSchedule(BaseSchedule):
             assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

     @staticmethod
-    def _call_engine(model, input_obj, batch_data):
-        if isinstance(model, NaiveAMPModel):
-            sig = inspect.signature(model.model.forward)
-        elif hasattr(model, 'colo_attr'):
-            sig = inspect.signature(model.module.forward)
-        else:
-            sig = inspect.signature(model.forward)
-        if isinstance(batch_data, torch.Tensor):
-            for p in sig.parameters.values():
-                if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    if input_obj is None:
-                        return model(batch_data)
-                    else:
-                        return model(input_obj)
-            if input_obj is None:
-                return model(batch_data)
-            elif isinstance(input_obj, torch.Tensor):
-                if len(sig.parameters) > 1:
-                    return model(input_obj, batch_data)
-                else:
-                    return model(input_obj)
-            else:
-                if len(sig.parameters) > len(input_obj):
-                    return model(*input_obj, batch_data)
-                else:
-                    return model(*input_obj)
-        else:
-            filter_batch = True
-            for p in sig.parameters.values():
-                if p.kind == inspect.Parameter.VAR_KEYWORD:
-                    filter_batch = False
-            if filter_batch:
-                batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters}
-            if input_obj is None and filter_batch:
-                return model(**batch_data)
-            elif isinstance(input_obj, torch.Tensor) or input_obj is None:
-                return model(input_obj, **batch_data)
-            else:
-                return model(*input_obj, **batch_data)
+    def _call_engine(model, data):
+        if data is not None:
+            if isinstance(data, torch.Tensor):
+                return model(data)
+            elif isinstance(data, (list, tuple)):
+                return model(*data)
+            elif isinstance(data, dict):
+                return model(**data)
+            else:
+                raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
+
+    def _get_actual_forward_func(self, module):
+        if isinstance(module, NaiveAMPModel):
+            sig = inspect.signature(module.model.forward)
+        elif hasattr(module, 'colo_attr'):
+            sig = inspect.signature(module.module.forward)
+        else:
+            sig = inspect.signature(module.forward)
+        return sig
+
+    def _get_data_label_for_current_step(self, stage_output, micro_batch_data, criterion, model):
+        if self.data_process_func:
+            # use customized function to get data and label
+            data, label = self.data_process_func(stage_output, micro_batch_data)
+        else:
+            if isinstance(micro_batch_data, (tuple, list)):
+                if gpc.is_first_rank(ParallelMode.PIPELINE):
+                    # for the first stage, we use the data from the
+                    # dataloader output by default
+                    data, label = micro_batch_data
+                else:
+                    # for non-first stages, we use the output passed
+                    # by the previous stage as the model input
+                    data = stage_output
+                    _, label = micro_batch_data
+            elif isinstance(micro_batch_data, dict):
+                args = []
+                data = {}
+                label = {}
+
+                # we feed the stage output to args first
+                # then map each arg in args to its param name
+                if stage_output is not None:
+                    if isinstance(stage_output, torch.Tensor):
+                        args.append(stage_output)
+                    elif isinstance(stage_output, (list, tuple)):
+                        args.extend(stage_output)
+                    else:
+                        raise TypeError(
+                            f"Expected the values passed from previous pipeline stage to be torch.Tensor, list or tuple, but got {type(stage_output)}"
+                        )
+
+                # get all parameter names for the forward function of the model
+                fwd_sig = self._get_actual_forward_func(model)
+                fwd_sig_param_name = [p.name for p in fwd_sig.parameters.values()]
+
+                # build the kwargs for the forward function
+                for idx, param_name in enumerate(fwd_sig_param_name):
+                    if idx < len(args):
+                        data[param_name] = args[idx]
+                    else:
+                        if param_name in micro_batch_data:
+                            data[param_name] = micro_batch_data[param_name]
+
+                # get the tensors for the loss function
+                loss_sig = inspect.signature(criterion)
+                loss_sig_param_name = [p.name for p in loss_sig.parameters.values()]
+
+                for param_name in loss_sig_param_name:
+                    if param_name in micro_batch_data:
+                        label[param_name] = micro_batch_data[param_name]
+
+        return data, label

     def _forward_step(self, engine, input_obj, return_tensors, return_output_label=True, accum_loss=None):
         """Forward step for passed-in model. If it is the first stage, the input tensor
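Note: for the dict case, the stage output is matched positionally against the forward signature, and the remaining parameters are filled from the micro-batch by name. A standalone sketch with a hypothetical `forward(hidden_states, attention_mask)` signature (not part of the diff):

```python
import inspect
import torch

def forward(hidden_states, attention_mask):   # hypothetical stage forward
    return hidden_states * attention_mask

stage_output = torch.randn(2, 16)             # tensor passed from the previous stage
micro_batch_data = {"attention_mask": torch.ones(2, 16), "labels": torch.zeros(2)}

args = [stage_output]                         # stage output is consumed positionally
data = {}
for idx, name in enumerate(inspect.signature(forward).parameters):
    if idx < len(args):
        data[name] = args[idx]                # hidden_states <- stage output
    elif name in micro_batch_data:
        data[name] = micro_batch_data[name]   # attention_mask <- micro batch, by name

assert set(data) == {"hidden_states", "attention_mask"}
out = forward(**data)
```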
@@ -191,8 +249,11 @@ class PipelineSchedule(BaseSchedule):
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
-        data, label = self.load_micro_batch()
-        output_obj = self._call_engine(engine.model, input_obj, data)
+        micro_batch_data = self.load_micro_batch()
+
+        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion, engine.model)
+
+        output_obj = self._call_engine(engine.model, data)

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             if return_output_label:
@@ -399,7 +460,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
     def __init__(self,
                  num_microbatches: int,
                  num_model_chunks: int,
-                 batch_data_process_func: Callable = None,
+                 data_process_func: Callable = None,
                  tensor_shape: Union[torch.Size, List[int], Tuple[int]] = None,
                  scatter_gather_tensors: bool = False):
         """A helper schedule class for pipeline parallelism running environment.
@@ -409,7 +470,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
     Args:
         num_microbatches (int): The number of microbatches.
         num_model_chunks (int): The number of model chunks.
-        batch_data_process_func (Callable, optional):
+        data_process_func (Callable, optional):
             The preprocessing function which receives a batch of data, and it will be executed in `load_batch`.
         tensor_shape (torch.Size, optional): Specified shape in pipeline communication.
         scatter_gather_tensors (bool, optional):
@@ -420,7 +481,7 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         assert isinstance(num_model_chunks, int) and num_model_chunks > 0, \
             f'expected num_model_chunks to be an integer and larger than 0, but got {num_model_chunks}'
         super().__init__(num_microbatches,
-                         batch_data_process_func=batch_data_process_func,
+                         data_process_func=data_process_func,
                          tensor_shape=tensor_shape,
                          scatter_gather_tensors=scatter_gather_tensors)
         gpc.set_virtual_pipeline_parallel_size(num_model_chunks)
@@ -446,9 +507,8 @@ class InterleavedPipelineSchedule(PipelineSchedule):

     def load_micro_batch(self, model_chunk_id):
         data = self._get_data_slice(self.batch_data, self.microbatch_offset[model_chunk_id])
-        label = self._get_data_slice(self.batch_label, self.microbatch_offset[model_chunk_id])
         self.microbatch_offset[model_chunk_id] += self.microbatch_size
-        return self._move_to_device(data), self._move_to_device(label)
+        return self._move_to_device(data)

     def _forward_step(self,
                       engine,
@@ -471,8 +531,11 @@ class InterleavedPipelineSchedule(PipelineSchedule):
         Returns:
             Union[:class:`torch.Tensor`, List[:class:`torch.Tensor`]]: output or the loss value of the current pipeline stage.
         """
-        data, label = self.load_micro_batch(model_chunk_id)
-        output_obj = self._call_engine(engine.model[model_chunk_id], input_obj, data)
+        micro_batch_data = self.load_micro_batch(model_chunk_id)
+        data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, engine.criterion,
+                                                            engine.model[model_chunk_id])
+
+        output_obj = self._call_engine(engine.model[model_chunk_id], data)

         if gpc.is_pipeline_last_stage():
             if return_output_label:
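Note: the interleaved schedule inherits the same renamed keyword. A hypothetical construction sketch (import path and values are placeholders, not part of the diff):

```python
from colossalai.engine.schedule import InterleavedPipelineSchedule

def data_process_func(stage_output, dataloader_output):
    # first stage receives stage_output=None; later stages receive upstream tensors
    inputs, label = dataloader_output
    data = inputs if stage_output is None else stage_output
    return data, label

schedule = InterleavedPipelineSchedule(num_microbatches=8,
                                       num_model_chunks=2,
                                       data_process_func=data_process_func)
```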