mirror of https://github.com/hpcaitech/ColossalAI

[pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

  This reverts commit df7e6506d4.

* [pipeline]support more flexible pipeline

pull/1145/head
parent ccf3c58c89
commit 18091581c0
@@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(model, ShardedModelV2):
             self.dtype = torch.half
             model = model.module
-        sig = inspect.signature(model.forward)
-        for p in sig.parameters.values():
-            assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
+        # sig = inspect.signature(model.forward)
+        # for p in sig.parameters.values():
+        #     assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'

     @staticmethod
     def _call_engine(model, data):
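Note: the hunk above comments out the guard that rejected any model.forward declaring *args; the rest of the commit adds the machinery to feed positional stage outputs through. A minimal sketch of what that guard used to detect, using a hypothetical forward (not from this PR):

import inspect

def forward(x, *args, mask=None):
    # hypothetical forward that also accepts extra positional inputs
    return x

for p in inspect.signature(forward).parameters.values():
    # 'args' has kind VAR_POSITIONAL -- exactly what the old assert forbade
    print(p.name, p.kind == inspect.Parameter.VAR_POSITIONAL)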
@@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
+            stage_output = None
+            if 'stage_output' in data:
+                stage_output = data.pop('stage_output')
+            if stage_output is None:
+                return model(**data)
+            elif isinstance(stage_output, torch.Tensor):
+                return model(stage_output, **data)
+            elif isinstance(stage_output, (tuple, list)):
+                return model(*stage_output, **data)
+            else:
+                raise TypeError(
+                    f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
+                )
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
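In effect, _call_engine now pops an optional 'stage_output' entry from the data dict and passes it to the model positionally, so a stage can receive both the previous stage's tensors and its own batch kwargs. A standalone sketch of that dispatch, with a toy callable standing in for the model (names here are illustrative, not the PR's API):

import torch

def dispatch(model, data: dict):
    # mirrors the new dict branch of _call_engine
    stage_output = data.pop('stage_output', None)
    if stage_output is None:
        return model(**data)
    if isinstance(stage_output, torch.Tensor):
        return model(stage_output, **data)
    if isinstance(stage_output, (tuple, list)):
        return model(*stage_output, **data)
    raise TypeError(f"got {type(stage_output)}")

toy = lambda x, scale=1.0: x * scale
print(dispatch(toy, {'stage_output': torch.ones(2), 'scale': 3.0}))  # tensor([3., 3.])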
@@ -1,7 +1,7 @@
 import torch
 import inspect
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
+from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
 from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.tensor import ColoParameter
 from .layer_sepc import LayerSpec
@@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
         self._front_func_dict = front_func_dict
         self._behind_func_dict = behind_func_dict

-    def forward(self, input_tensor, **kwargs):
-
+    def forward(self, *input_tensor, **kwargs):
         for module in self._module_list:

             if id(module) in self._front_func_dict:
@@ -224,36 +223,13 @@ class PipelinableModel(torch.nn.Module):
                 forward_func = module._forward
             else:
                 forward_func = module.forward
-            if input_tensor is None:
-                module_kwargs = build_kwargs_for_function(forward_func, kwargs)
-            else:
-                module_kwargs = build_kwargs_for_module(forward_func, kwargs)
-            if module_kwargs is not None and input_tensor is not None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(input_tensor, **module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            elif module_kwargs is not None and input_tensor is None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(**module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            else:
-                input_tensor = module(input_tensor)
+            module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
+            if input_tensor is None:
+                input_tensor = call_module(module, kwargs=module_kwargs)
+            elif isinstance(input_tensor, torch.Tensor):
+                input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
+            else:
+                input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)

             if id(module) in self._behind_func_dict:
                 input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
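This is the core simplification of the PR: the old four-way branch (tensor vs. None input, kwargs vs. no kwargs, checkpointed vs. plain module) collapses into a single call_module helper, and forward takes *input_tensor so a stage can consume several tensors at once. A toy version of the new loop, with stand-in modules rather than the PR's classes:

import torch
import torch.nn as nn

def run_stage(module_list, *input_tensor):
    out = input_tensor
    for module in module_list:
        # one uniform positional call replaces the old if/elif ladder
        out = module(*out)
        if not isinstance(out, tuple):
            out = (out,)
    return out[0]

stage = [nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2)]
print(run_stage(stage, torch.randn(3, 4)).shape)  # torch.Size([3, 2])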
@@ -1,9 +1,12 @@
 import heapq
+import inspect
+import torch

 from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.utils import CheckpointModule
 from typing import List


 def _binary_partition(weights: List, start: int, end: int):
     """Returns the binary partition position of `weights`, given the start
     position `st` and the end position `ed`.
|
|||
return parts
|
||||
|
||||
|
||||
def build_kwargs_for_module(function, kw_dict):
|
||||
def build_kwargs_for_module(function, input_tensor, kw_dict):
|
||||
"""
|
||||
Generally, the first argument of module.forward is an input tensor come from the previous layer.
|
||||
Therefore, we just filter the kwargs from second element of the dictionary.
|
||||
"""
|
||||
sig = inspect.signature(function)
|
||||
if len(sig.parameters) <= 1:
|
||||
return None
|
||||
if input_tensor is None:
|
||||
kwargs_offset = 0
|
||||
elif isinstance(input_tensor, torch.Tensor):
|
||||
kwargs_offset = 1
|
||||
else:
|
||||
assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
|
||||
kwargs_offset = len(input_tensor)
|
||||
args_name_list = list(sig.parameters.keys())
|
||||
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
|
||||
kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
|
||||
if len(kw_dict) == 0:
|
||||
return None
|
||||
return kw_dict
|
||||
|
||||
|
||||
|
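build_kwargs_for_module now inspects input_tensor to decide how many leading positional parameters are already filled (0 for None, 1 for a tensor, len(...) for a tuple) before filtering kwargs by parameter name. A small illustration of the offset filter; the forward signature here is made up:

import inspect

def forward(hidden, attention_mask=None, label=None):
    pass

names = list(inspect.signature(forward).parameters.keys())
kwargs_offset = 1                       # a single input tensor occupies 'hidden'
kw = {'attention_mask': 0, 'hidden': 1, 'extra': 2}
kw = {k: v for k, v in kw.items() if k in names[kwargs_offset:]}
print(kw)                               # {'attention_mask': 0}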
@@ -189,6 +199,17 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
         for k in kw_dict.keys():
             kwargs[k] = rst
         return input_tensor
+    if isinstance(input_tensor, tuple):
+        assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
+        sig = inspect.signature(func)
+        func_args_num = len(sig.parameters)
+        assert func_args_num <= len(
+            input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
+        if func_args_num < len(input_tensor):
+            return func(*input_tensor[:func_args_num])
+        else:
+            return func(*input_tensor)
+    assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
     return func(input_tensor)
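For tuple inputs with no kw_dict, exec_func_with_kwargs now trims the tuple to the function's arity instead of erroring out when a stage produces more outputs than a hook consumes. A minimal sketch of that slicing, with a hypothetical hook function:

import inspect

def hook(x, y):                         # hypothetical hook taking two of the outputs
    return x + y

input_tensor = (1, 2, 3)                # previous stage produced three outputs
n = len(inspect.signature(hook).parameters)
print(hook(*input_tensor[:n]) if n < len(input_tensor) else hook(*input_tensor))  # 3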
@@ -205,3 +226,26 @@ def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
         input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)

     return input_tensor
+
+
+def call_module(module, args=None, kwargs=None):
+    if args is None:
+        args = ()
+    if kwargs is None:
+        kwargs = {}
+    if isinstance(module, CheckpointModule):
+        forward_func = module._forward
+    else:
+        forward_func = module.forward
+    sig = inspect.signature(forward_func)
+    param_nums = len(sig.parameters)
+    feed_nums = len(args) + len(kwargs)
+    args_needed_nums = param_nums - len(kwargs)
+    args_needed = args[:args_needed_nums]
+    if isinstance(module, CheckpointModule):
+        convert_kwargs_to_args = []
+        for v in kwargs.values():
+            convert_kwargs_to_args.append(v)
+        return module(*args_needed, *convert_kwargs_to_args)
+    else:
+        return module(*args_needed, **kwargs)
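call_module unifies every invocation: it trims surplus positional args so that args plus kwargs exactly cover the forward signature, and flattens kwargs into positionals for CheckpointModule, whose _forward only accepts positional inputs. A hedged usage sketch of the same trimming with a plain nn.Module (no CheckpointModule involved; call_module_sketch is a simplified stand-in, not the PR's function):

import inspect
import torch
import torch.nn as nn

class Toy(nn.Module):
    def forward(self, x, scale=1.0):
        return x * scale

def call_module_sketch(module, args=(), kwargs=None):
    kwargs = kwargs or {}
    param_nums = len(inspect.signature(module.forward).parameters)
    args_needed = args[:param_nums - len(kwargs)]   # drop surplus positionals
    return module(*args_needed, **kwargs)

m = Toy()
surplus = torch.zeros(2)                            # extra tensor the module does not take
print(call_module_sketch(m, (torch.ones(2), surplus), {'scale': 2.0}))  # tensor([2., 2.])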