diff --git a/colossalai/utils/model/colo_init_context.py b/colossalai/utils/model/colo_init_context.py
index d32cc58d4..429bf2175 100644
--- a/colossalai/utils/model/colo_init_context.py
+++ b/colossalai/utils/model/colo_init_context.py
@@ -12,17 +12,6 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
         super().__init__()
         self._lazy_memory_allocate = lazy_memory_allocate
 
-    def _pre_context_exec(self):
-        """
-        The Callback function when entering the context
-        """
-        pass
-
-    def _post_context_exec(self):
-        """The callback function when exiting context.
-        """
-        pass
-
     def _post_init_method(self, module: torch.nn.Module):
         """
         The function to call at the end of the constructor of each module.
diff --git a/colossalai/utils/model/pipelinable.py b/colossalai/utils/model/pipelinable.py
new file mode 100644
index 000000000..ba5bbddb3
--- /dev/null
+++ b/colossalai/utils/model/pipelinable.py
@@ -0,0 +1,211 @@
+import torch
+import functools
+from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
+from colossalai.builder.pipeline import partition_uniform, partition_balanced
+from colossalai.core import global_context as gpc
+
+
+class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
+
+    def __init__(self):
+        super().__init__()
+        self._layer_spec_dict = {}
+        self._root_children = None
+        self._model = None
+        self._layer_spec_list = []
+        self._func_dict = {}
+        self._policy = "balanced"
+
+    @property
+    def policy(self):
+        return self._policy
+
+    @property
+    def layers_count(self):
+        return len(self._layer_spec_list)
+
+    @property
+    def funcs_count(self):
+        return len(self._func_dict)
+
+    def _pre_context_exec(self):
+        """
+        The callback function called when entering the context.
+        """
+
+        # preserve the RNG states so they can be restored on exit
+        self.cpu_rng_state = torch.get_rng_state()
+        self.cuda_rng_state = torch.cuda.get_rng_state()
+
+    def _post_context_exec(self):
+        """
+        The callback function called when exiting the context.
+        """
+
+        # restore the RNG states recorded on entry
+        torch.set_rng_state(self.cpu_rng_state)
+        torch.cuda.set_rng_state(self.cuda_rng_state)
+
+    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
+        """
+        The function to call at the end of the constructor of each module.
+        NOTE: the module may be passed to this function multiple times.
+        """
+        module_id = id(module)
+        modified_args = []
+        for obj in args:
+            if issubclass(obj.__class__, torch.nn.modules.module.Module):
+                obj = self._layer_spec_dict[id(obj)]
+            modified_args.append(obj)
+        # (lyl)TODO: analyse kwargs as well
+        modified_args = tuple(modified_args)
+        self._root_children = list(module.children())
+        self._model = module
+        layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs)
+        layer_spec.set_children(module.children())
+        self._layer_spec_dict[module_id] = layer_spec
+        # replace real parameters with 1x1 placeholders so that tracing
+        # inside the context does not allocate full weight memory
+        for param in module.parameters(recurse=False):
+            param.data = torch.rand(1, 1)
+
+    def to_layer_list(self, exec_seq=None):
+        """
+        Create a layer spec list and func list from the execution sequence given by the user.
+        If exec_seq is None, the module initialization order is taken as the execution order.
+        """
+        if exec_seq is None:
+            # The user did not provide an execution sequence, so the initialization order is used.
+            for child in self._root_children:
+                layer_spec = self._layer_spec_dict[id(child)]
+                if layer_spec.typename in (torch.nn.modules.container.ModuleList,
+                                           torch.nn.modules.container.Sequential):
+                    for child_in_container in layer_spec.children:
+                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
+
+                else:
+                    self._layer_spec_list.append(layer_spec)
+
+        else:
+            func_key = "first"
+            for index, element in enumerate(exec_seq):
+                if isinstance(element, str):
+                    module = dict(self._model.named_modules())[element]
+                    layer_spec = self._layer_spec_dict[id(module)]
+                    func_key = layer_spec
+                    self._layer_spec_list.append(layer_spec)
+                else:
+                    if func_key not in self._func_dict:
+                        self._func_dict[func_key] = []
+                    self._func_dict[func_key].append(element)
+
+    def partition(self, num_chunks, pipeline_size, rank):
+        """
+        Build the partitioned model with respect to the partition policy.
+        The real module instances are instantiated in this method.
+        """
+        if isinstance(self._policy, str):
+            if self._policy == "uniform":
+                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
+            elif self._policy == "balanced":
+                param_counts = []
+                for layer_spec in self._layer_spec_list:
+                    param_counts.append(layer_spec.count_params())
+                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
+            else:
+                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
+        elif isinstance(self._policy, dict):
+            parts = self._policy[rank]
+        else:
+            raise ValueError("A partition policy should be either a string or a dictionary.")
+
+        layers_to_build = []
+        for start, end in parts:
+            layers_to_build += self._layer_spec_list[start:end]
+        func_dict_in_partition = {}
+        module_list_in_partition = []
+        if rank == 0 and "first" in self._func_dict:
+            func_dict_in_partition["first"] = self._func_dict["first"]
+        for layer in layers_to_build:
+            module = layer.build()
+            module_list_in_partition.append(module)
+            if layer in self._func_dict:
+                func_dict_in_partition[id(module)] = self._func_dict[layer]
+        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
+        pipeline_model = PipelinableModel(module_list_in_partition, func_dict_in_partition)
+
+        return pipeline_model
+
+    def load_policy(self, policy):
+        self._policy = policy
+
+
+class PipelinableModel(torch.nn.Module):
+
+    def __init__(self, module_list, func_dict):
+        super().__init__()
+        self._module_list = module_list
+        self._func_dict = func_dict
+
+    def forward(self, input_tensor):
+        if "first" in self._func_dict:
+            funcs = self._func_dict["first"]
+            if isinstance(funcs, list):
+                for f in funcs:
+                    input_tensor = f(input_tensor)
+            else:
+                input_tensor = funcs(input_tensor)
+
+        for module in self._module_list:
+            input_tensor = module(input_tensor)
+            if id(module) in self._func_dict:
+                funcs = self._func_dict[id(module)]
+                if isinstance(funcs, list):
+                    for f in funcs:
+                        input_tensor = f(input_tensor)
+                else:
+                    input_tensor = funcs(input_tensor)
+
+        return input_tensor
+
+
+class LayerSpec:
+
+    def __init__(self, typename, *module_args, **module_kwargs):
+        self.typename = typename
+        self.module_args = module_args
+        self.module_kwargs = module_kwargs
+        self.children = None
+        self._param_count = 0
+
+        if not issubclass(typename, torch.nn.Module):
+            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
+
+    def __repr__(self):
+        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)
+
+    @property
+    def param_count(self):
+        return self._param_count
+
+    def build(self):
+        """Build the stored specification."""
+
+        recovered_args = []
+        for obj in self.module_args:
+            if isinstance(obj, LayerSpec):
+                obj = obj.build()
+            recovered_args.append(obj)
+        recovered_args = tuple(recovered_args)
+        return self.typename(*recovered_args, **self.module_kwargs)
+
+    def set_children(self, children):
+        self.children = children
+
+    def count_params(self):
+        self._param_count = 0
+        layer = self.build()
+        for param in layer.parameters():
+            self._param_count += param.numel()
+        return self._param_count
+
+    def reset_param_count(self):
+        self._param_count = 0
diff --git a/colossalai/utils/model/utils.py b/colossalai/utils/model/utils.py
index ced1365ec..50a75a110 100644
--- a/colossalai/utils/model/utils.py
+++ b/colossalai/utils/model/utils.py
@@ -9,6 +9,28 @@ def _substitute_init_recursively(cls, func):
         func(subcls)
 
 
+def call_to_str(base, *args, **kwargs):
+    """Construct a string representation of a call.
+
+    Args:
+        base (str): name of the call
+        args (tuple, optional): args to ``base``
+        kwargs (dict, optional): kwargs supplied to ``base``
+
+    Returns:
+        str: A string representation of base(*args, **kwargs)
+    """
+    name = f'{base}('
+    if args:
+        name += ', '.join(repr(arg) for arg in args)
+        if kwargs:
+            name += ', '
+    if kwargs:
+        name += ', '.join(f'{key}={repr(arg)}' for key, arg in kwargs.items())
+    name += ')'
+    return name
+
+
 class InsertPostInitMethodToModuleSubClasses(object):
 
     def __init__(self, default_dtype: Optional[torch.dtype] = None):
@@ -28,7 +50,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
 
         @functools.wraps(f)
         def wrapper(module: torch.nn.Module, *args, **kwargs):
             f(module, *args, **kwargs)
-            self._post_init_method(module)
+            self._post_init_method(module, *args, **kwargs)
 
         return wrapper
@@ -71,7 +93,7 @@ class InsertPostInitMethodToModuleSubClasses(object):
         return False
 
     # To be implemented by inheriting classes
-    def _post_init_method(self, module):
+    def _post_init_method(self, module, *args, **kwargs):
         pass
 
     def _pre_context_exec(self):
diff --git a/tests/test_utils/test_pipelinable.py b/tests/test_utils/test_pipelinable.py
new file mode 100644
index 000000000..2be3b264c
--- /dev/null
+++ b/tests/test_utils/test_pipelinable.py
@@ -0,0 +1,64 @@
+import os.path as osp
+
+import pytest
+import torch
+import torch.multiprocessing as mp
+
+from colossalai.utils.model.pipelinable import PipelinableContext
+
+from functools import partial
+from colossalai.utils import free_port
+from colossalai.testing import rerun_on_exception
+
+NUM_CHUNKS = 1
+PIPELINE_SIZE = 2
+
+
+class MLP(torch.nn.Module):
+
+    def __init__(self, dim: int = 256):
+        super().__init__()
+        intermediate_dim = dim * 4
+        self.dense_1 = torch.nn.Linear(dim, intermediate_dim)
+        self.activation = torch.nn.GELU()
+        self.dense_2 = torch.nn.Linear(intermediate_dim, dim)
+        self.dropout = torch.nn.Dropout(0.1)
+
+    def forward(self, x):
+        x = self.dense_1(x)
+        x = self.activation(x)
+        x = self.dense_2(x)
+        x = self.dropout(x)
+        return x
+
+
+def run_pipelinable(rank):
+    pipelinable = PipelinableContext()
+    with pipelinable:
+        model = MLP()
+
+    assert pipelinable.policy == "balanced"
+    pipelinable.load_policy("uniform")
+    assert pipelinable.policy == "uniform"
+    pipelinable.to_layer_list()
+
+    assert pipelinable.layers_count == len(list(model.children()))
+
+    pipeline_model_part_0 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 0)
+    assert isinstance(pipeline_model_part_0, torch.nn.Module)
+    pipeline_model_part_1 = pipelinable.partition(NUM_CHUNKS, PIPELINE_SIZE, 1)
+    assert isinstance(pipeline_model_part_1, torch.nn.Module)
+
+    layers_count_in_part_0 = len(list(pipeline_model_part_0._module_list))
+    layers_count_in_part_1 = len(list(pipeline_model_part_1._module_list))
+
+    assert layers_count_in_part_0 + layers_count_in_part_1 == pipelinable.layers_count
+
+
+@rerun_on_exception(exception_type=mp.ProcessRaisedException, pattern=".*Address already in use.*")
+def test_pipelinable():
+    mp.spawn(run_pipelinable, nprocs=1)
+
+
+if __name__ == '__main__':
+    test_pipelinable()
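For reference, a minimal usage sketch (not part of the diff) of how the new PipelinableContext is driven end to end, mirroring tests/test_utils/test_pipelinable.py. The MLP model and the stage count of 2 are assumptions taken from the test above; any torch.nn.Module built inside the context would work the same way. Note that the context snapshots the CUDA RNG state on entry, so a CUDA-capable environment is assumed.

    import torch
    from colossalai.utils.model.pipelinable import PipelinableContext

    # Tracing phase: modules constructed inside the context are recorded
    # as LayerSpecs, and their parameters are swapped for 1x1 placeholders,
    # so no real weight memory is allocated yet.
    pipelinable = PipelinableContext()
    with pipelinable:
        model = MLP()    # assumed: the MLP defined in the test above

    # Take the initialization order as the execution order (pass exec_seq
    # to override it), then choose a partition policy: "uniform",
    # "balanced" (the default), or a {rank: parts} dict.
    pipelinable.to_layer_list()
    pipelinable.load_policy("uniform")

    # Build the real modules belonging to one pipeline stage; rank selects
    # the stage, so each process materializes only its own slice of layers.
    stage_module = pipelinable.partition(num_chunks=1, pipeline_size=2, rank=0)
    assert isinstance(stage_module, torch.nn.Module)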