[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support more flexible pipeline
YuliangLiu0306 2022-06-21 14:40:50 +08:00 committed by GitHub
parent ccf3c58c89
commit 18091581c0
3 changed files with 69 additions and 40 deletions


@@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(model, ShardedModelV2):
             self.dtype = torch.half
             model = model.module
-        sig = inspect.signature(model.forward)
-        for p in sig.parameters.values():
-            assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
+        # sig = inspect.signature(model.forward)
+        # for p in sig.parameters.values():
+        #     assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

     @staticmethod
     def _call_engine(model, data):
@@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
                 stage_output = None
                 if 'stage_output' in data:
                     stage_output = data.pop('stage_output')
-                return model(stage_output, **data)
+                if stage_output is None:
+                    return model(**data)
+                elif isinstance(stage_output, torch.Tensor):
+                    return model(stage_output, **data)
+                elif isinstance(stage_output, (tuple, list)):
+                    return model(*stage_output, **data)
+                else:
+                    raise TypeError(
+                        f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
+                    )
             else:
                 raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")


@@ -1,7 +1,7 @@
 import torch
 import inspect
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
+from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
 from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.tensor import ColoParameter
 from .layer_sepc import LayerSpec
@@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
         self._front_func_dict = front_func_dict
         self._behind_func_dict = behind_func_dict

-    def forward(self, input_tensor, **kwargs):
+    def forward(self, *input_tensor, **kwargs):
         for module in self._module_list:
             if id(module) in self._front_func_dict:
@@ -224,36 +223,13 @@ class PipelinableModel(torch.nn.Module):
                 forward_func = module._forward
             else:
                 forward_func = module.forward
-            if input_tensor is None:
-                module_kwargs = build_kwargs_for_function(forward_func, kwargs)
-            else:
-                module_kwargs = build_kwargs_for_module(forward_func, kwargs)
-            if module_kwargs is not None and input_tensor is not None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(input_tensor, **module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            elif module_kwargs is not None and input_tensor is None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(**module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            else:
-                input_tensor = module(input_tensor)
+            module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
+            if input_tensor is None:
+                input_tensor = call_module(module, kwargs=module_kwargs)
+            elif isinstance(input_tensor, torch.Tensor):
+                input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
+            else:
+                input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)

             if id(module) in self._behind_func_dict:
                 input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
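The signature change from forward(self, input_tensor, **kwargs) to forward(self, *input_tensor, **kwargs) lets a stage accept any number of positional tensors and hand them on through call_module. A rough stand-in for that contract (TinyStage is illustrative only; the real wrapper iterates self._module_list and applies the front/behind hooks):

import torch
import torch.nn as nn

class TinyStage(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, *input_tensor, **kwargs):
        # One positional tensor is used directly; several are combined here
        # only to show that all of them arrive as positional arguments.
        x = input_tensor[0] if len(input_tensor) == 1 else torch.cat(input_tensor, dim=-1)
        return self.proj(x)

stage = TinyStage()
out_single = stage(torch.randn(2, 8))                      # one tensor
out_multi = stage(torch.randn(2, 4), torch.randn(2, 4))    # several tensors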


@@ -1,9 +1,12 @@
 import heapq
 import inspect
+import torch
 from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.utils import CheckpointModule
 from typing import List


 def _binary_partition(weights: List, start: int, end: int):
     """Returns the binary partition position of `weights`, given the start
     position `st` and the end position `ed`.
@@ -146,16 +149,23 @@ def partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts


-def build_kwargs_for_module(function, kw_dict):
+def build_kwargs_for_module(function, input_tensor, kw_dict):
     """
     Generally, the first argument of module.forward is an input tensor come from the previous layer.
     Therefore, we just filter the kwargs from second element of the dictionary.
     """
     sig = inspect.signature(function)
-    if len(sig.parameters) <= 1:
-        return None
+    if input_tensor is None:
+        kwargs_offset = 0
+    elif isinstance(input_tensor, torch.Tensor):
+        kwargs_offset = 1
+    else:
+        assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
+        kwargs_offset = len(input_tensor)
     args_name_list = list(sig.parameters.keys())
-    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
+    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
+    if len(kw_dict) == 0:
+        return None
     return kw_dict
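In other words, build_kwargs_for_module now skips as many leading parameters as input_tensor will occupy positionally before matching the remaining names against kw_dict. A quick worked example of that filtering (fake_forward and its argument names are invented for the illustration):

import inspect
import torch

def fake_forward(hidden, mask, dropout=0.1, layer_past=None):
    return hidden

input_tensor = (torch.randn(2, 4), torch.ones(2, 4))    # two positional tensors -> kwargs_offset = 2
kw_dict = {'hidden': None, 'dropout': 0.2, 'unused': 1}

sig = inspect.signature(fake_forward)
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())             # ['hidden', 'mask', 'dropout', 'layer_past']
filtered = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
print(filtered)                                          # {'dropout': 0.2}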
@@ -189,6 +199,17 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
             for k in kw_dict.keys():
                 kwargs[k] = rst
         return input_tensor
+    if isinstance(input_tensor, tuple):
+        assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
+        sig = inspect.signature(func)
+        func_args_num = len(sig.parameters)
+        assert func_args_num <= len(
+            input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
+        if func_args_num < len(input_tensor):
+            return func(*input_tensor[:func_args_num])
+        else:
+            return func(*input_tensor)
+    assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
     return func(input_tensor)
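The new tuple branch also trims the stage output when a front/behind hook function declares fewer positional parameters than the stage produced. Illustration with an invented hook (behind_hook is not from the repo):

import inspect
import torch

def behind_hook(logits):                       # the hook only needs the first output
    return logits.softmax(dim=-1)

input_tensor = (torch.randn(2, 5), torch.randn(2, 5))    # e.g. (logits, hidden_states)

func_args_num = len(inspect.signature(behind_hook).parameters)   # 1
assert func_args_num <= len(input_tensor)
result = behind_hook(*input_tensor[:func_args_num])      # surplus outputs are dropped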
@@ -205,3 +226,26 @@ def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
         input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)

     return input_tensor
+
+
+def call_module(module, args=None, kwargs=None):
+    if args is None:
+        args = ()
+    if kwargs is None:
+        kwargs = {}
+    if isinstance(module, CheckpointModule):
+        forward_func = module._forward
+    else:
+        forward_func = module.forward
+    sig = inspect.signature(forward_func)
+    param_nums = len(sig.parameters)
+    feed_nums = len(args) + len(kwargs)
+    args_needed_nums = param_nums - len(kwargs)
+    args_needed = args[:args_needed_nums]
+    if isinstance(module, CheckpointModule):
+        convert_kwargs_to_args = []
+        for v in kwargs.values():
+            convert_kwargs_to_args.append(v)
+        return module(*args_needed, *convert_kwargs_to_args)
+    else:
+        return module(*args_needed, **kwargs)
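Rough intuition for call_module: it reserves one parameter slot per keyword argument and trims the positional args to whatever slots remain, so a module is never handed more arguments than its forward can accept; for a CheckpointModule the kwargs are additionally flattened into positional arguments, presumably because the checkpointed forward path is invoked positionally. A minimal sketch of the trimming with a plain module (names invented for illustration):

import inspect
import torch
import torch.nn as nn

class Head(nn.Module):
    def forward(self, hidden, mask=None):
        return hidden if mask is None else hidden * mask

module = Head()
args = (torch.randn(2, 4), torch.randn(2, 4))    # one positional tensor too many
kwargs = {'mask': torch.ones(2, 4)}

param_nums = len(inspect.signature(module.forward).parameters)   # 2: hidden, mask
args_needed = args[:param_nums - len(kwargs)]                    # keeps only the 'hidden' slot
out = module(*args_needed, **kwargs)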