You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/pipeline/pipelinable.py

261 lines
11 KiB

import torch
import inspect
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
from colossalai.nn.layer.utils import CheckpointModule
from colossalai.tensor import ColoParameter
from .layer_sepc import LayerSpec
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context manager to split the model into pipeline stages.

    While the context is active, every module constructor is followed by a call
    to :meth:`_post_init_method`, which records a ``LayerSpec`` describing how the
    module was built. :meth:`to_layer_list` then fixes the execution order of the
    recorded layers and :meth:`partition` builds the layers that belong to one
    pipeline rank.
    """

    def __init__(self, policy: str = "balanced"):
        """
        Args:
            policy: partition policy. :meth:`partition` accepts the strings
                ``"uniform"`` and ``"balanced"``, or a dictionary indexed by rank.
        """
        super().__init__()
        # id(module) -> LayerSpec of that module
        self._layer_spec_dict = {}
        # direct children of the most recently initialised module
        self._root_children = None
        # the most recently initialised module (the outermost model once construction ends)
        self._model = None
        # layer specs in execution order, filled by to_layer_list()
        self._layer_spec_list = []
        # (layer_spec, "front" | "behind") -> list of functions to run around that layer
        self._func_dict = {}
        self._policy = policy

    @property
    def policy(self):
        """The partition policy used by :meth:`partition`."""
        return self._policy

    @policy.setter
    def policy(self, policy: str):
        self._policy = policy

    @property
    def layers_count(self):
        """Number of layer specs collected by :meth:`to_layer_list`."""
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        """Number of (layer, position) keys that have functions attached."""
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function when entering the context.

        Saves the CPU RNG state, and the CUDA RNG state when CUDA is available,
        so that :meth:`_post_context_exec` can restore them.
        """
        # preserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        # querying the CUDA RNG state fails on hosts without CUDA, so only save it when usable
        self.cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    def _post_context_exec(self):
        """
        The callback function when exiting context.

        Restores the RNG states saved by :meth:`_pre_context_exec`.
        """
        # reset rng states
        torch.set_rng_state(self.cpu_rng_state)
        if self.cuda_rng_state is not None:
            torch.cuda.set_rng_state(self.cuda_rng_state)

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE() The module may be passed to this function multiple times.

        Args:
            module: the module whose constructor has just finished.
            *args: positional arguments the module was constructed with.
            **kwargs: keyword arguments the module was constructed with.

        Raises:
            KeyError: if a constructor argument is a ``torch.nn.Module`` that has
                no layer spec recorded in this context.
        """
        # iterate over the positional arguments
        # to check if an argument is a torch Module
        # if found any torch Module, replace it with its layer spec
        # for storage purpose
        modified_args = []
        for arg in args:
            if isinstance(arg, torch.nn.Module):
                arg = self._layer_spec_dict[id(arg)]
            modified_args.append(arg)

        # do the same for the keyword arguments
        modified_kwargs = {}
        for k, v in kwargs.items():
            if isinstance(v, torch.nn.Module):
                v = self._layer_spec_dict[id(v)]
            # (lyl)TODO: analyse ColoTensor as well
            modified_kwargs[k] = v

        # keep track of the module children
        # as torch.nn.Module.__init__ is called from inner module to outer module,
        # the final value of self._model will be the outermost model
        # e.g. if the model is torchvision.models.resnet18, then the final value of self._model
        # will be the ``ResNet`` object.
        self._root_children = list(module.children())
        self._model = module

        # store the children to keep the module hierarchy
        layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
        layer_spec.set_children(module.children())

        # store the layer spec in this context
        self._layer_spec_dict[id(module)] = layer_spec

        # convert all torch.nn.Parameter to colossalai.tensor.ColoParameter
        # collect first, then replace: attributes cannot be swapped while named_parameters() is iterating
        name_list = []
        for name, param in module.named_parameters():
            if isinstance(param, ColoParameter):
                continue
            name_list.append((name, param))

        for name, param in name_list:
            delattr(module, name)
            setattr(module, name, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with execution sequence given by user.
        If exec_seq is None, we will take the module initializing order as execution order.

        Args:
            exec_seq: optional iterable describing the execution order. Each element is one of:
                - a ``str``: the name of a module as reported by ``named_modules()`` of the model;
                - a tuple ``(func, "front")``: ``func`` runs before the next named module;
                - any other tuple ``(func, ...)`` or a bare object: the function runs behind
                  the most recently named module.

        Raises:
            AssertionError: if a string element is not a module name of the model.
            ValueError: if a "behind" function appears before any module name.
        """
        if exec_seq is None:
            # if user do not provide the model executing sequence,
            # we use the initialization order as the executing order.
            container_types = (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential)
            for child in self._root_children:
                layer_spec = self._layer_spec_dict[id(child)]
                if layer_spec.typename in container_types:
                    # containers are flattened one level: each of their children becomes a layer
                    for child_in_container in layer_spec.children:
                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
                else:
                    self._layer_spec_list.append(layer_spec)
            return

        pending_front_funcs = []
        # key of the "behind" function list of the most recently named module;
        # stays None until the first module name has been seen
        behind_key = None
        named_modules = dict(self._model.named_modules())
        for element in exec_seq:
            if isinstance(element, str):
                assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'

                # get the layer spec based on the module ID
                layer_spec = self._layer_spec_dict[id(named_modules[element])]

                # attach the functions which should be executed before this module
                if pending_front_funcs:
                    self._func_dict.setdefault((layer_spec, "front"), []).extend(pending_front_funcs)
                    pending_front_funcs = []

                behind_key = (layer_spec, "behind")
                self._layer_spec_list.append(layer_spec)
            elif isinstance(element, tuple) and element[1] == "front":
                pending_front_funcs.append(element[0])
            else:
                if behind_key is None:
                    raise ValueError("A function in exec_seq that runs behind a module must come after a module name; "
                                     "use (func, 'front') to run a function before the first module.")
                func = element[0] if isinstance(element, tuple) else element
                self._func_dict.setdefault(behind_key, []).append(func)
        # NOTE(review): "front" functions that are not followed by any module name are
        # discarded here, as before — confirm whether that should be reported to the user.

    def partition(self, num_chunks, pipeline_size, rank):
        """
        Partitioned model will be built with respect to the partition policy.
        The real module instance will be built in this method.

        Args:
            num_chunks: forwarded to the partition helper as its last argument.
            pipeline_size: forwarded to the partition helper as its second argument.
            rank: index used to pick this rank's parts from the partition result
                (or from the policy itself when the policy is a dictionary).

        Returns:
            PipelinableModel: the layers of this rank together with the front and
            behind functions attached to them.

        Raises:
            ValueError: if the policy is an unknown string, or neither a string nor a dictionary.
        """
        if isinstance(self._policy, str):
            if self._policy == "uniform":
                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            elif self._policy == "balanced":
                param_counts = [layer_spec.count_params() for layer_spec in self._layer_spec_list]
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            else:
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:
            raise ValueError("A partition policy should be either a string or a dictionary.")

        # parts is a sequence of (start, end) index pairs into the layer spec list
        layers_to_build = []
        for start, end in parts:
            layers_to_build.extend(self._layer_spec_list[start:end])

        behind_func_dict_in_partition = {}
        front_func_dict_in_partition = {}
        module_list_in_partition = []
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            # a layer may have both front and behind functions, so the two lookups are independent
            if (layer, "front") in self._func_dict:
                front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
            if (layer, "behind") in self._func_dict:
                behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]

        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
                                          behind_func_dict_in_partition)
        return pipeline_model
class PipelinableModel(torch.nn.Module):
    """
    A sequence of modules run one after another as a single pipeline stage.

    Functions registered in ``front_func_dict`` / ``behind_func_dict`` under
    ``id(module)`` are executed right before / right after that module.
    """

    def __init__(self, module_list, front_func_dict, behind_func_dict):
        """
        Args:
            module_list: iterable of modules executed in order by :meth:`forward`.
            front_func_dict: maps ``id(module)`` to the functions to run before that module.
            behind_func_dict: maps ``id(module)`` to the functions to run after that module.
        """
        super().__init__()
        self._module_list = module_list
        self._front_func_dict = front_func_dict
        self._behind_func_dict = behind_func_dict

    def forward(self, input_tensor, **kwargs):
        """
        Run every module of this stage in order, threading the result through.

        Args:
            input_tensor: the input of the first module; may be ``None``, in which
                case the module is called with arguments taken from ``kwargs`` only.
            **kwargs: extra arguments; the subset handed to each module is chosen by
                ``build_kwargs_for_function`` / ``build_kwargs_for_module``.

        Returns:
            The output of the last module (after its behind functions, if any), or
            ``input_tensor`` unchanged when the stage holds no modules.
        """
        for module in self._module_list:
            module_id = id(module)

            if module_id in self._front_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._front_func_dict, module_id, input_tensor, kwargs)

            is_checkpoint = isinstance(module, CheckpointModule)
            # the signature of a CheckpointModule is taken from ``_forward`` rather than ``forward``
            forward_func = module._forward if is_checkpoint else module.forward

            if input_tensor is None:
                module_kwargs = build_kwargs_for_function(forward_func, kwargs)
            else:
                module_kwargs = build_kwargs_for_module(forward_func, kwargs)

            if module_kwargs is None:
                # no extra arguments apply to this module; its output is passed on as is
                input_tensor = module(input_tensor)
            else:
                # Only pass the input positionally when there is one. Previously the
                # CheckpointModule path passed a literal ``None`` as the first positional
                # argument when there was no input, unlike the plain-module path.
                leading_args = () if input_tensor is None else (input_tensor,)
                if is_checkpoint:
                    # CheckpointModule receives the selected kwargs as positional arguments,
                    # in the iteration order of ``module_kwargs``
                    # (presumably because its checkpoint wrapper does not forward keywords — TODO confirm)
                    rst = module(*leading_args, *module_kwargs.values())
                else:
                    rst = module(*leading_args, **module_kwargs)

                # only the first element of a tuple output is handed to the next module
                input_tensor = rst[0] if isinstance(rst, tuple) else rst

            if module_id in self._behind_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, module_id, input_tensor, kwargs)

        return input_tensor