diff --git a/colossalai/engine/schedule/_pipeline_schedule.py b/colossalai/engine/schedule/_pipeline_schedule.py index 43cf61a65..8c408b436 100644 --- a/colossalai/engine/schedule/_pipeline_schedule.py +++ b/colossalai/engine/schedule/_pipeline_schedule.py @@ -119,9 +119,12 @@ class PipelineSchedule(BaseSchedule): def pre_processing(self, engine): # TODO: remove this after testing new zero with pipeline parallelism model = engine.model - if isinstance(model, (NaiveAMPModel, ShardedModelV2)): + if isinstance(model, NaiveAMPModel): self.dtype = torch.half model = model.model + if isinstance(model, ShardedModelV2): + self.dtype = torch.half + model = model.module sig = inspect.signature(model.forward) for p in sig.parameters.values(): assert p.kind != inspect.Parameter.VAR_POSITIONAL, '*args is not supported' @@ -135,6 +138,12 @@ class PipelineSchedule(BaseSchedule): else: sig = inspect.signature(model.forward) if isinstance(batch_data, torch.Tensor): + for p in sig.parameters.values(): + if p.kind == inspect.Parameter.VAR_KEYWORD: + if input_tensor is None: + return model(batch_data) + else: + return model(input_tensor) if input_tensor is None: return model(batch_data) elif len(sig.parameters) > 1: @@ -148,7 +157,7 @@ class PipelineSchedule(BaseSchedule): filter_batch = False if filter_batch: batch_data = {k: v for k, v in batch_data.items() if k in sig.parameters} - if input_tensor is None: + if input_tensor is None and filter_batch: return model(**batch_data) else: return model(input_tensor, **batch_data) diff --git a/colossalai/utils/model/pipelinable.py b/colossalai/utils/model/pipelinable.py index 379430366..82323ef18 100644 --- a/colossalai/utils/model/pipelinable.py +++ b/colossalai/utils/model/pipelinable.py @@ -1,8 +1,11 @@ import torch import functools +import inspect +from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str from colossalai.builder.pipeline import partition_uniform, partition_balanced from colossalai.core import global_context as gpc +from colossalai.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoTensor @@ -58,11 +61,18 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): if issubclass(obj.__class__, torch.nn.modules.module.Module): obj = self._layer_spec_dict[id(obj)] modified_args.append(obj) - # (lyl)TODO: analyse kwargs as well + + modified_kwargs = {} + for k, v in kwargs.items(): + if issubclass(v.__class__, torch.nn.modules.module.Module): + v = self._layer_spec_dict[id(v)] + # (lyl)TODO: analyse ColoTensor as well + modified_kwargs[k] = v + modified_args = tuple(modified_args) self._root_children = list(module.children()) self._model = module - layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs) + layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs) layer_spec.set_children(module.children()) self._layer_spec_dict[module_id] = layer_spec name_list = [] @@ -82,28 +92,49 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): """ if exec_seq is None: #if user do not provide the model executing sequence, we use the initialization order as the executing order. + children_name = [] for child in self._root_children: layer_spec = self._layer_spec_dict[id(child)] if layer_spec.typename in (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) + for name, module in self._model.named_modules(): + if id(module) == id(child_in_container): + children_name.append(name) + break else: self._layer_spec_list.append(layer_spec) + for name, module in self._model.named_modules(): + if id(module) == id(child): + children_name.append(name) + break else: - func_key = "first" + front_funcs_list = [] for index, element in enumerate(exec_seq): if isinstance(element, str): module = dict(self._model.named_modules())[element] layer_spec = self._layer_spec_dict[id(module)] - func_key = layer_spec + if len(front_funcs_list) != 0: + func_key = (layer_spec, "front") + if func_key not in self._func_dict: + self._func_dict[func_key] = [] + for f in front_funcs_list: + self._func_dict[func_key].append(f) + front_funcs_list = [] + func_key = (layer_spec, "behind") self._layer_spec_list.append(layer_spec) + elif isinstance(element, tuple) and element[1] == "front": + front_funcs_list.append(element[0]) else: if func_key not in self._func_dict: self._func_dict[func_key] = [] - self._func_dict[func_key].append(element) + if isinstance(element, tuple): + self._func_dict[func_key].append(element[0]) + else: + self._func_dict[func_key].append(element) def partition(self, num_chunks, pipeline_size, rank): """ @@ -128,17 +159,19 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): layers_to_build = [] for start, end in parts: layers_to_build += self._layer_spec_list[start:end] - func_dict_in_partition = {} + behind_func_dict_in_partition = {} + front_func_dict_in_partition = {} module_list_in_partition = [] - if rank == 0 and "first" in self._func_dict: - func_dict_in_partition["first"] = self._func_dict["first"] for layer in layers_to_build: module = layer.build() module_list_in_partition.append(module) - if layer in self._func_dict: - func_dict_in_partition[id(module)] = self._func_dict[layer] + if (layer, "front") in self._func_dict: + front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")] + elif (layer, "behind") in self._func_dict: + behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) - pipeline_model = PipelinableModel(module_list_in_partition, func_dict_in_partition) + pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, + behind_func_dict_in_partition) return pipeline_model @@ -146,31 +179,119 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): self._policy = policy +def _build_kwargs_for_module(function, kw_dict): + """ + Generally, the first argument of module.forward is an input tensor come from the previous layer. + Therefore, we just filter the kwargs from second element of the dictionary. + """ + sig = inspect.signature(function) + if len(sig.parameters) <= 1: + return None + args_name_list = list(sig.parameters.keys()) + kw_dict = {k: v for k, v in kw_dict.items() if k in args_name_list[1:]} + return kw_dict + + +def _build_kwargs_for_function(function, kw_dict): + sig = inspect.signature(function) + kw_dict = {k: v for k, v in kw_dict.items() if k in sig.parameters} + if len(kw_dict) == 0: + return None + return kw_dict + + +def _exec_func_with_kwargs(func, kw_dict, input_tensor, kwargs): + """ + We suppose the callable object passed to to_layer_list method in two purpose: + a. use the callable object to modify input tensor, such as \ + lambda x: torch.flatten(x, 1) + b. use the callable object to modify kwargs value, such as \ + def foo(attention_mask=None): + if attention_mask is not None: + batch_size = input_ids.shape[0] + attention_mask = attention_mask.view(batch_size, -1) + return attention_mask + """ + + if kw_dict is not None: + rst = func(**kw_dict) + if isinstance(rst, tuple): + for i, k in enumerate(kw_dict.keys()): + kwargs[k] = rst[i] + else: + for k in kw_dict.keys(): + kwargs[k] = rst + return input_tensor + return func(input_tensor) + + +def _exec_funcs_with_kwargs(func_dict, func_key, input_tensor, kwargs): + + assert func_key in func_dict, f"{func_key} is not in the function_dict." + funcs_to_exec = func_dict[func_key] + if isinstance(funcs_to_exec, list): + for f in funcs_to_exec: + f_kwargs = _build_kwargs_for_function(f, kwargs) + input_tensor = _exec_func_with_kwargs(f, f_kwargs, input_tensor, kwargs) + else: + f_kwargs = _build_kwargs_for_function(funcs_to_exec, kwargs) + input_tensor = _exec_func_with_kwargs(funcs_to_exec, f_kwargs, input_tensor, kwargs) + + return input_tensor + + class PipelinableModel(torch.nn.Module): - def __init__(self, module_list, func_dict): + def __init__(self, module_list, front_func_dict, behind_func_dict): super().__init__() self._module_list = module_list - self._func_dict = func_dict - - def forward(self, input_tensor): - if "first" in self._func_dict: - funcs = self._func_dict["first"] - if isinstance(funcs, list): - for f in funcs: - input_tensor = f(input_tensor) - else: - input_tensor = funcs(input_tensor) + self._front_func_dict = front_func_dict + self._behind_func_dict = behind_func_dict + + def forward(self, input_tensor, **kwargs): for module in self._module_list: - input_tensor = module(input_tensor) - if id(module) in self._func_dict: - funcs = self._func_dict[id(module)] - if isinstance(funcs, list): - for f in funcs: - input_tensor = f(input_tensor) + + if id(module) in self._front_func_dict: + input_tensor = _exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) + + if isinstance(module, CheckpointModule): + forward_func = module._forward + else: + forward_func = module.forward + if input_tensor is None: + module_kwargs = _build_kwargs_for_function(forward_func, kwargs) + else: + module_kwargs = _build_kwargs_for_module(forward_func, kwargs) + if module_kwargs is not None and input_tensor is not None: + if isinstance(module, CheckpointModule): + convert_kwargs_to_args = [] + for v in module_kwargs.values(): + convert_kwargs_to_args.append(v) + rst = module(input_tensor, *convert_kwargs_to_args) + else: + rst = module(input_tensor, **module_kwargs) + if isinstance(rst, tuple): + input_tensor = rst[0] + else: + input_tensor = rst + elif module_kwargs is not None and input_tensor is None: + if isinstance(module, CheckpointModule): + convert_kwargs_to_args = [] + for v in module_kwargs.values(): + convert_kwargs_to_args.append(v) + rst = module(input_tensor, *convert_kwargs_to_args) else: - input_tensor = funcs(input_tensor) + rst = module(**module_kwargs) + if isinstance(rst, tuple): + input_tensor = rst[0] + else: + input_tensor = rst + else: + input_tensor = module(input_tensor) + + if id(module) in self._behind_func_dict: + input_tensor = _exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) return input_tensor @@ -203,7 +324,14 @@ class LayerSpec: obj = obj.build() recovered_args.append(obj) recovered_args = tuple(recovered_args) - return self.typename(*recovered_args, **self.module_kwargs) + + recovered_kwargs = {} + for k, v in self.module_kwargs.items(): + if isinstance(v, LayerSpec): + v = v.build() + recovered_kwargs[k] = v + + return self.typename(*recovered_args, **recovered_kwargs) def set_children(self, children): self.children = children