[pipeline]add customized policy (#1139)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [pipeline]add customized policy
YuliangLiu0306 2022-06-21 15:23:41 +08:00 committed by GitHub
parent d1918304bb
commit 70dd88e2ee
2 changed files with 39 additions and 2 deletions
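For context, a rough sketch of how the new policy is meant to be driven from user code. The toy MLP, the launch setup, the import path, and the exact to_layer_list / policy / partition calls are assumptions based on the surrounding PipelinableContext API; only 'SPLIT_NODE' and the "customized" policy string come from this commit.

import torch.nn as nn
from colossalai.pipeline.pipelinable import PipelinableContext   # import path assumed

class MLP(nn.Module):    # hypothetical toy model
    def __init__(self):
        super().__init__()
        self.proj1 = nn.Linear(16, 64)
        self.act = nn.ReLU()
        self.proj2 = nn.Linear(64, 4)

    def forward(self, x):
        return self.proj2(self.act(self.proj1(x)))

# Assumes colossalai.launch has already created a pipeline-parallel group of size 2.
pipelinable = PipelinableContext()
with pipelinable:
    model = MLP()

# 'SPLIT_NODE' marks the stage boundary: proj1 and act go to pipeline rank 0, proj2 to rank 1.
pipelinable.to_layer_list(['proj1', 'act', 'SPLIT_NODE', 'proj2'])
pipelinable.policy = "customized"
stage_model = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)  # rank of the current pipeline stage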

File 1 of 2:

@@ -1,9 +1,14 @@
import torch
import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, call_module
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, \
    build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs, \
    call_module, customized_partition
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from .layer_sepc import LayerSpec
@@ -113,6 +118,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
        Create a layer spec list and func list with the execution sequence given by the user.
        If exec_seq is None, we will take the module initializing order as the execution order.
        """
        self._exec_seq = exec_seq
        if exec_seq is None:
            # if the user does not provide a model execution sequence, we use the initialization order as the execution order.
            children_name = []
@@ -138,6 +145,8 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
            named_modules = dict(self._model.named_modules())
            for index, element in enumerate(exec_seq):
                if isinstance(element, str):
                    if element == 'SPLIT_NODE':
                        continue
                    assert element in named_modules, f'Found invalid module name {element}, please check if you spelled the module name correctly.'
                    # get the layer spec based on the module ID
@@ -178,8 +187,15 @@ class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
                for layer_spec in self._layer_spec_list:
                    param_counts.append(layer_spec.count_params())
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            elif self._policy == "customized":
                assert self._exec_seq is not None, f'An explicit exec_seq must be defined by the user in customized policy mode.'
                self.customized_parts = customized_partition(self._exec_seq)
                assert len(self.customized_parts) == gpc.get_world_size(
                    ParallelMode.PIPELINE
                ), f'World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}'
                parts = self.customized_parts[rank]
            else:
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:

File 2 of 2:

@@ -249,3 +249,24 @@ def call_module(module, args=None, kwargs=None):
        return module(*args_needed, *convert_kwargs_to_args)
    else:
        return module(*args_needed, **kwargs)


def customized_partition(exec_seq):
    '''
    This function analyzes the exec_seq. In the exec_seq, users use 'SPLIT_NODE' as an
    annotation to mark the partition points.
    '''
    customized_parts = {}
    start = 0
    stop = 0
    rank = 0
    for element in exec_seq:
        if isinstance(element, str):
            if element == 'SPLIT_NODE':
                customized_parts[rank] = [(start, stop)]
                start = stop
                rank += 1
            else:
                stop += 1
    customized_parts[rank] = [(start, stop)]
    return customized_parts
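To make the bookkeeping above concrete, a small worked example (the module names are placeholders; the indices count the non-'SPLIT_NODE' string entries of exec_seq, i.e. positions in the layer spec list):

# One 'SPLIT_NODE' annotation yields two parts, one per pipeline rank.
exec_seq = ['embed', 'layer1', 'layer2', 'SPLIT_NODE', 'layer3', 'head']
print(customized_partition(exec_seq))
# -> {0: [(0, 3)], 1: [(3, 5)]}
# Rank 0 owns indices 0..2 and rank 1 owns indices 3..4; partition() then checks that
# the number of parts equals gpc.get_world_size(ParallelMode.PIPELINE).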