[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]support more flexible pipeline
pull/1145/head
YuliangLiu0306 2022-06-21 14:40:50 +08:00 committed by GitHub
parent ccf3c58c89
commit 18091581c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 69 additions and 40 deletions

View File

@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
if isinstance(model, ShardedModelV2):
self.dtype = torch.half
model = model.module
sig = inspect.signature(model.forward)
for p in sig.parameters.values():
assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
# sig = inspect.signature(model.forward)
# for p in sig.parameters.values():
# assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
@staticmethod
def _call_engine(model, data):
@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
stage_output = None
if 'stage_output' in data:
stage_output = data.pop('stage_output')
if stage_output is None:
return model(**data)
elif isinstance(stage_output, torch.Tensor):
return model(stage_output, **data)
elif isinstance(stage_output, (tuple, list)):
return model(*stage_output, **data)
else:
raise TypeError(
f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
)
else:
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

View File

@ -1,7 +1,7 @@
import torch
import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from .layer_sepc import LayerSpec
@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
self._front_func_dict = front_func_dict
self._behind_func_dict = behind_func_dict
def forward(self, input_tensor, **kwargs):
def forward(self, *input_tensor, **kwargs):
for module in self._module_list:
if id(module) in self._front_func_dict:
@ -224,36 +223,13 @@ class PipelinableModel(torch.nn.Module):
forward_func = module._forward
else:
forward_func = module.forward
module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
if input_tensor is None:
module_kwargs = build_kwargs_for_function(forward_func, kwargs)
input_tensor = call_module(module, kwargs=module_kwargs)
elif isinstance(input_tensor, torch.Tensor):
input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
else:
module_kwargs = build_kwargs_for_module(forward_func, kwargs)
if module_kwargs is not None and input_tensor is not None:
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in module_kwargs.values():
convert_kwargs_to_args.append(v)
rst = module(input_tensor, *convert_kwargs_to_args)
else:
rst = module(input_tensor, **module_kwargs)
if isinstance(rst, tuple):
input_tensor = rst[0]
else:
input_tensor = rst
elif module_kwargs is not None and input_tensor is None:
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in module_kwargs.values():
convert_kwargs_to_args.append(v)
rst = module(input_tensor, *convert_kwargs_to_args)
else:
rst = module(**module_kwargs)
if isinstance(rst, tuple):
input_tensor = rst[0]
else:
input_tensor = rst
else:
input_tensor = module(input_tensor)
input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)
if id(module) in self._behind_func_dict:
input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)

View File

@ -1,9 +1,12 @@
import heapq
import inspect
import torch
from colossalai.logging import get_dist_logger
from colossalai.nn.layer.utils import CheckpointModule
from typing import List
def _binary_partition(weights: List, start: int, end: int):
"""Returns the binary partition position of `weights`, given the start
position `st` and the end position `ed`.
@ -146,16 +149,23 @@ def partition_balanced(weights, pipeline_parallel_size, num_chunks):
return parts
def build_kwargs_for_module(function, kw_dict):
def build_kwargs_for_module(function, input_tensor, kw_dict):
"""
Generally, the first argument of module.forward is an input tensor come from the previous layer.
Therefore, we just filter the kwargs from second element of the dictionary.
"""
sig = inspect.signature(function)
if len(sig.parameters) <= 1:
return None
if input_tensor is None:
kwargs_offset = 0
elif isinstance(input_tensor, torch.Tensor):
kwargs_offset = 1
else:
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
kwargs_offset = len(input_tensor)
args_name_list = list(sig.parameters.keys())
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
if len(kw_dict) == 0:
return None
return kw_dict
@ -189,6 +199,17 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
for k in kw_dict.keys():
kwargs[k] = rst
return input_tensor
if isinstance(input_tensor, tuple):
assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
sig = inspect.signature(func)
func_args_num = len(sig.parameters)
assert func_args_num <= len(
input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
if func_args_num < len(input_tensor):
return func(*input_tensor[:func_args_num])
else:
return func(*input_tensor)
assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
return func(input_tensor)
@ -205,3 +226,26 @@ def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
return input_tensor
def call_module(module, args=None, kwargs=None):
if args is None:
args = ()
if kwargs is None:
kwargs = {}
if isinstance(module, CheckpointModule):
forward_func = module._forward
else:
forward_func = module.forward
sig = inspect.signature(forward_func)
param_nums = len(sig.parameters)
feed_nums = len(args) + len(kwargs)
args_needed_nums = param_nums - len(kwargs)
args_needed = args[:args_needed_nums]
if isinstance(module, CheckpointModule):
convert_kwargs_to_args = []
for v in kwargs.values():
convert_kwargs_to_args.append(v)
return module(*args_needed, *convert_kwargs_to_args)
else:
return module(*args_needed, **kwargs)