From 18091581c0b5680afa8220f5242f19c67c7ec469 Mon Sep 17 00:00:00 2001
From: YuliangLiu0306 <72588413+YuliangLiu0306@users.noreply.github.com>
Date: Tue, 21 Jun 2022 14:40:50 +0800
Subject: [PATCH] [pipeline]support more flexible pipeline (#1138)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4500af6a9220ef7fe4d3c7b1daebd4c.

* [pipeline]support more flexible pipeline
---
 .../engine/schedule/_pipeline_schedule.py | 17 ++++--
 colossalai/pipeline/pipelinable.py        | 38 +++----------
 colossalai/pipeline/utils.py              | 54 +++++++++++++++++--
 3 files changed, 69 insertions(+), 40 deletions(-)

diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py
index 6f1d755b8..6e865ae8f 100644
--- a/colossalai/engine/schedule/_pipeline_schedule.py
+++ b/colossalai/engine/schedule/_pipeline_schedule.py
@@ -165,9 +165,9 @@ class PipelineSchedule(BaseSchedule):
         if isinstance(model, ShardedModelV2):
             self.dtype = torch.half
             model = model.module
-        sig = inspect.signature(model.forward)
-        for p in sig.parameters.values():
-            assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
+        # sig = inspect.signature(model.forward)
+        # for p in sig.parameters.values():
+        #     assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported'
 
     @staticmethod
     def _call_engine(model, data):
@@ -180,7 +180,16 @@ class PipelineSchedule(BaseSchedule):
             stage_output = None
             if 'stage_output' in data:
                 stage_output = data.pop('stage_output')
-            return model(stage_output, **data)
+            if stage_output is None:
+                return model(**data)
+            elif isinstance(stage_output, torch.Tensor):
+                return model(stage_output, **data)
+            elif isinstance(stage_output, (tuple, list)):
+                return model(*stage_output, **data)
+            else:
+                raise TypeError(
+                    f"Expected stage_output to be of type torch.Tensor, list, or tuple, but got {type(stage_output)}"
+                )
         else:
             raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
 
diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py
index d7db77c9d..9f3c7cf13 100644
--- a/colossalai/pipeline/pipelinable.py
+++ b/colossalai/pipeline/pipelinable.py
@@ -1,7 +1,7 @@
 import torch
 import inspect
 from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
-from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
+from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
 from colossalai.nn.layer.utils import CheckpointModule
 from colossalai.tensor import ColoParameter
 from .layer_sepc import LayerSpec
@@ -213,8 +213,7 @@ class PipelinableModel(torch.nn.Module):
         self._front_func_dict = front_func_dict
         self._behind_func_dict = behind_func_dict
 
-    def forward(self, input_tensor, **kwargs):
-
+    def forward(self, *input_tensor, **kwargs):
         for module in self._module_list:
 
             if id(module) in self._front_func_dict:
@@ -223,37 +222,14 @@ class PipelinableModel(torch.nn.Module):
             if isinstance(module, CheckpointModule):
                 forward_func = module._forward
             else:
                 forward_func = module.forward
+            module_kwargs = build_kwargs_for_module(forward_func, input_tensor, kwargs)
             if input_tensor is None:
-                module_kwargs = build_kwargs_for_function(forward_func, kwargs)
+                input_tensor = call_module(module, kwargs=module_kwargs)
+            elif isinstance(input_tensor, torch.Tensor):
+                input_tensor = call_module(module, args=(input_tensor,), kwargs=module_kwargs)
             else:
-                module_kwargs = build_kwargs_for_module(forward_func, kwargs)
-            if module_kwargs is not None and input_tensor is not None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(input_tensor, **module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            elif module_kwargs is not None and input_tensor is None:
-                if isinstance(module, CheckpointModule):
-                    convert_kwargs_to_args = []
-                    for v in module_kwargs.values():
-                        convert_kwargs_to_args.append(v)
-                    rst = module(input_tensor, *convert_kwargs_to_args)
-                else:
-                    rst = module(**module_kwargs)
-                if isinstance(rst, tuple):
-                    input_tensor = rst[0]
-                else:
-                    input_tensor = rst
-            else:
-                input_tensor = module(input_tensor)
+                input_tensor = call_module(module, args=input_tensor, kwargs=module_kwargs)
 
             if id(module) in self._behind_func_dict:
                 input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs)
diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py
index 6d1ea73d5..7029ab215 100644
--- a/colossalai/pipeline/utils.py
+++ b/colossalai/pipeline/utils.py
@@ -1,9 +1,12 @@
 import heapq
 import inspect
+import torch
 
 from colossalai.logging import get_dist_logger
+from colossalai.nn.layer.utils import CheckpointModule
 from typing import List
 
+
 def _binary_partition(weights: List, start: int, end: int):
     """Returns the binary partition position of `weights`, given the start
     position `st` and the end position `ed`.
@@ -146,16 +149,23 @@ def partition_balanced(weights, pipeline_parallel_size, num_chunks):
     return parts
 
 
-def build_kwargs_for_module(function, kw_dict):
+def build_kwargs_for_module(function, input_tensor, kw_dict):
     """
     Generally, the first argument of module.forward is an input tensor come from the previous layer.
     Therefore, we just filter the kwargs from second element of the dictionary.
     """
     sig = inspect.signature(function)
-    if len(sig.parameters) <= 1:
-        return None
+    if input_tensor is None:
+        kwargs_offset = 0
+    elif isinstance(input_tensor, torch.Tensor):
+        kwargs_offset = 1
+    else:
+        assert isinstance(input_tensor, tuple), f'input_tensor should be a torch.Tensor or a tuple object.'
+        kwargs_offset = len(input_tensor)
     args_name_list = list(sig.parameters.keys())
-    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]}
+    kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[kwargs_offset:]}
+    if len(kw_dict) == 0:
+        return None
     return kw_dict
 
 
@@ -189,6 +199,17 @@ def exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs):
             for k in kw_dict.keys():
                 kwargs[k] = rst
         return input_tensor
+    if isinstance(input_tensor, tuple):
+        assert len(input_tensor) > 0, f'input_tensor should not be empty, when kw_dict is None.'
+        sig = inspect.signature(func)
+        func_args_num = len(sig.parameters)
+        assert func_args_num <= len(
+            input_tensor), f'func requires {func_args_num} arguments, but input_tensors only have {len(input_tensor)}.'
+        if func_args_num < len(input_tensor):
+            return func(*input_tensor[:func_args_num])
+        else:
+            return func(*input_tensor)
+    assert isinstance(input_tensor, torch.Tensor), 'input_tensor should be a type of torch.Tensor or tuple.'
 
     return func(input_tensor)
 
 
@@ -204,4 +225,27 @@ def exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs):
         f_kwargs = build_kwargs_for_function(funcs_to_exec, kwargs)
         input_tensor = exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs)
 
-    return input_tensor
\ No newline at end of file
+    return input_tensor
+
+
+def call_module(module, args=None, kwargs=None):
+    if args is None:
+        args = ()
+    if kwargs is None:
+        kwargs = {}
+    if isinstance(module, CheckpointModule):
+        forward_func = module._forward
+    else:
+        forward_func = module.forward
+    sig = inspect.signature(forward_func)
+    param_nums = len(sig.parameters)
+    feed_nums = len(args) + len(kwargs)
+    args_needed_nums = param_nums - len(kwargs)
+    args_needed = args[:args_needed_nums]
+    if isinstance(module, CheckpointModule):
+        convert_kwargs_to_args = []
+        for v in kwargs.values():
+            convert_kwargs_to_args.append(v)
+        return module(*args_needed, *convert_kwargs_to_args)
+    else:
+        return module(*args_needed, **kwargs)
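
Usage sketch (illustration only, not part of the diff above): a minimal example of how the new call_module helper is expected to dispatch stage outputs once this commit is applied. Stage1 and Stage2 are hypothetical modules invented for this note; call_module is the helper added to colossalai/pipeline/utils.py by this patch.

import torch
import torch.nn as nn

from colossalai.pipeline.utils import call_module  # helper introduced by this patch


class Stage1(nn.Module):
    # Hypothetical stage that returns a tuple, e.g. (hidden states, mask).
    def forward(self, x):
        return x * 2, torch.ones_like(x)


class Stage2(nn.Module):
    # Hypothetical stage that accepts a single positional input.
    def forward(self, x):
        return x + 1


stage1, stage2 = Stage1(), Stage2()
out = call_module(stage1, args=(torch.ones(2, 2),))  # out is a tuple of two tensors
# call_module inspects Stage2.forward and only feeds as many positional arguments
# as the signature accepts (one), mirroring the new tuple branch in
# PipelinableModel.forward for stage outputs that are tuples.
out = call_module(stage2, args=out)
print(out.shape)  # torch.Size([2, 2])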