import torch import inspect from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs from colossalai.nn.layer.utils import CheckpointModule from colossalai.tensor import ColoParameter from .layer_sepc import LayerSpec class PipelinableContext(InsertPostInitMethodToModuleSubClasses): """ A context manager to split the model into pipeline stages. """ def __init__(self, policy: str="balanced"): super().__init__() self._layer_spec_dict = {} self._root_children = None self._model = None self._layer_spec_list = [] self._func_dict = {} self._policy = policy @property def policy(self): return self._policy @policy.setter def policy(self, policy: str): self._policy = policy @property def layers_count(self): return len(self._layer_spec_list) @property def funcs_count(self): return len(self._func_dict) def _pre_context_exec(self): """ The Callback function when entering the context """ # reserve rng states self.cpu_rng_state = torch.get_rng_state() self.cuda_rng_state = torch.cuda.get_rng_state() def _post_context_exec(self): """ The callback function when exiting context. """ # reset rng states torch.set_rng_state(self.cpu_rng_state) torch.cuda.set_rng_state(self.cuda_rng_state) def _post_init_method(self, module: torch.nn.Module, *args, **kwargs): """ The function to call at the end of the constructor of each module. NOTE() The module may be passed to this function multiple times. """ # iterate over the positional arguments # to check if an argument is a torch Module # if found any torch Module, replace it with its layer spec # for storage purpose modified_args = [] for arg in args: if isinstance(arg, torch.nn.Module): arg = self._layer_spec_dict[id(arg)] modified_args.append(arg) # to the same for the keyword arguments modified_kwargs = {} for k, v in kwargs.items(): if isinstance(v, torch.nn.Module): v = self._layer_spec_dict[id(v)] # (lyl)TODO: analyse ColoTensor as well modified_kwargs[k] = v # keep track of the module children # as torch.nn.Module.__init__ is called from inner module to outer module, # the final value of self._model will be the outermost model # e.g. if the model is torchvision.models.resnet18, then the final value of self._model # will be the ``ResNet`` object. self._root_children = list(module.children()) self._model = module # store the children to keep the module hierarchy layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs) layer_spec.set_children(module.children()) # store the layer spec in this context module_id = id(module) self._layer_spec_dict[module_id] = layer_spec # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter name_list = [] for name, param in module.named_parameters(): if isinstance(param, ColoParameter): continue name_list.append((name, param)) for name, param in name_list: delattr(module, name) setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad)) def to_layer_list(self, exec_seq=None): """ Create a layer spec list and func list with execution sequence given by user. If exec_seq is None, we will take the module initizing order as execution order. """ if exec_seq is None: # if user do not provide the model executing sequence, we use the initialization order as the executing order. children_name = [] for child in self._root_children: layer_spec = self._layer_spec_dict[id(child)] if layer_spec.typename in (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential): for child_in_container in layer_spec.children: self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)]) for name, module in self._model.named_modules(): if id(module) == id(child_in_container): children_name.append(name) break else: self._layer_spec_list.append(layer_spec) for name, module in self._model.named_modules(): if id(module) == id(child): children_name.append(name) break else: front_funcs_list = [] named_modules = dict(self._model.named_modules()) for index, element in enumerate(exec_seq): if isinstance(element, str): assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.' # get the layer spec based on the module ID module = named_modules[element] layer_spec = self._layer_spec_dict[id(module)] # check whether there are functions which should be executed before this module if len(front_funcs_list) != 0: func_key = (layer_spec, "front") if func_key not in self._func_dict: self._func_dict[func_key] = [] for f in front_funcs_list: self._func_dict[func_key].append(f) front_funcs_list = [] func_key = (layer_spec, "behind") self._layer_spec_list.append(layer_spec) elif isinstance(element, tuple) and element[1] == "front": front_funcs_list.append(element[0]) else: if func_key not in self._func_dict: self._func_dict[func_key] = [] if isinstance(element, tuple): self._func_dict[func_key].append(element[0]) else: self._func_dict[func_key].append(element) def partition(self, num_chunks, pipeline_size, rank): """ Partitioned model will be built respect to partion policy. The real module instance will be built in this method. """ if isinstance(self._policy, str): if self._policy == "uniform": parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank] elif self._policy == "balanced": param_counts = [] for layer_spec in self._layer_spec_list: param_counts.append(layer_spec.count_params()) parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank] else: raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].") elif isinstance(self._policy, dict): parts = self._policy[rank] else: raise ValueError("A partition policy should be either a string or a dictionary.") layers_to_build = [] for start, end in parts: layers_to_build += self._layer_spec_list[start:end] behind_func_dict_in_partition = {} front_func_dict_in_partition = {} module_list_in_partition = [] for layer in layers_to_build: module = layer.build() module_list_in_partition.append(module) if (layer, "front") in self._func_dict: front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")] elif (layer, "behind") in self._func_dict: behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")] module_list_in_partition = torch.nn.ModuleList(module_list_in_partition) pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition) return pipeline_model class PipelinableModel(torch.nn.Module): def __init__(self, module_list, front_func_dict, behind_func_dict): super().__init__() self._module_list = module_list self._front_func_dict = front_func_dict self._behind_func_dict = behind_func_dict def forward(self, input_tensor, **kwargs): for module in self._module_list: if id(module) in self._front_func_dict: input_tensor = exec_funcs_with_kwargs(self._front_func_dict, id(module), input_tensor, kwargs) if isinstance(module, CheckpointModule): forward_func = module._forward else: forward_func = module.forward if input_tensor is None: module_kwargs = build_kwargs_for_function(forward_func, kwargs) else: module_kwargs = build_kwargs_for_module(forward_func, kwargs) if module_kwargs is not None and input_tensor is not None: if isinstance(module, CheckpointModule): convert_kwargs_to_args = [] for v in module_kwargs.values(): convert_kwargs_to_args.append(v) rst = module(input_tensor, *convert_kwargs_to_args) else: rst = module(input_tensor, **module_kwargs) if isinstance(rst, tuple): input_tensor = rst[0] else: input_tensor = rst elif module_kwargs is not None and input_tensor is None: if isinstance(module, CheckpointModule): convert_kwargs_to_args = [] for v in module_kwargs.values(): convert_kwargs_to_args.append(v) rst = module(input_tensor, *convert_kwargs_to_args) else: rst = module(**module_kwargs) if isinstance(rst, tuple): input_tensor = rst[0] else: input_tensor = rst else: input_tensor = module(input_tensor) if id(module) in self._behind_func_dict: input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, id(module), input_tensor, kwargs) return input_tensor