diff --git a/colossalai/pipeline/pipelinable.py b/colossalai/pipeline/pipelinable.py index 9f3c7cf13..d55da0fce 100644 --- a/colossalai/pipeline/pipelinable.py +++ b/colossalai/pipeline/pipelinable.py @@ -1,9 +1,14 @@ import torch import inspect from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses -from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module + +from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \ + build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \ + call_module, customized_partition from colossalai.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter +from colossalai.core import global_context as gpc +from colossalai.context import ParallelMode from .layer_sepc import LayerSpec @@ -113,6 +118,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): Create a layer spec list and func list with execution sequence given by user. If exec_seq is None, we will take the module initizing order as execution order. """ + + self._exec_seq = exec_seq if exec_seq is None: # if user do not provide the model executing sequence, we use the initialization order as the executing order. children_name = [] @@ -138,6 +145,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): named_modules = dict(self._model.named_modules()) for index, element in enumerate(exec_seq): if isinstance(element, str): + if element == 'SPLIT_NODE': + continue assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' # get the layer spec based on the module ID @@ -178,8 +187,15 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses): for layer_spec in self._layer_spec_list: param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] + elif self._policy == "customized": + assert self._exec_seq is not None, f'An explicit exec_seq must be defined by user in customized policy mode.' + self.customized_parts = customized_partition(self._exec_seq) + assert len(self.customized_parts) == gpc.get_world_size( + ParallelMode.PIPELINE + ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partions is {len(self.customized_parts)}' + parts = self.customized_parts[rank] else: - raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].") + raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].") elif isinstance(self._policy, dict): parts = self._policy[rank] else: diff --git a/colossalai/pipeline/utils.py b/colossalai/pipeline/utils.py index 7029ab215..d4e759a63 100644 --- a/colossalai/pipeline/utils.py +++ b/colossalai/pipeline/utils.py @@ -249,3 +249,24 @@ def call_module(module, args=None, kwargs=None): return module(*args_needed, *convert_kwargs_to_args) else: return module(*args_needed, **kwargs) + +def customized_partition(exec_seq): + ''' + This function will analyze the exec_seq. In the exec_seq, users will use 'SPLIT_NODE' as an + annotation to note the partition point. + ''' + customized_parts = {} + start = 0 + stop = 0 + rank = 0 + for element in exec_seq: + if isinstance(element, str): + if element == 'SPLIT_NODE': + customized_parts[rank] = [(start, stop)] + start = stop + rank += 1 + else: + stop += 1 + customized_parts[rank] = [(start, stop)] + return customized_parts +