mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
261 lines
11 KiB
261 lines
11 KiB
3 years ago
|
import torch
|
||
3 years ago
|
import inspect
|
||
3 years ago
|
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
|
||
|
from .utils import partition_uniform, partition_balanced, build_kwargs_for_function, build_kwargs_for_module, exec_func_with_kwargs, exec_funcs_with_kwargs
|
||
3 years ago
|
from colossalai.nn.layer.utils import CheckpointModule
|
||
3 years ago
|
from colossalai.tensor import ColoParameter
|
||
|
from .layer_sepc import LayerSpec
|
||
3 years ago
|
|
||
|
|
||
|
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context manager to split the model into pipeline stages.

    While the context is active, ``_post_init_method`` is called at the end of
    the constructor of each module and records a :class:`LayerSpec` for it.
    After the model has been constructed, :meth:`to_layer_list` flattens the
    recorded specs into an execution sequence, and :meth:`partition` builds the
    modules assigned to one pipeline rank.

    Args:
        policy: The partition policy used by :meth:`partition`. Either the
            string ``"uniform"`` or ``"balanced"`` (default), or a dict mapping
            a rank to an iterable of ``(start, end)`` index ranges into the
            layer list.
    """

    def __init__(self, policy: str = "balanced"):
        super().__init__()
        # id(module) -> LayerSpec for every module constructed inside the context
        self._layer_spec_dict = {}
        # children of the most recently constructed module (ends up being the outermost model)
        self._root_children = None
        # the most recently constructed module (ends up being the outermost model)
        self._model = None
        # ordered layer specs produced by ``to_layer_list``
        self._layer_spec_list = []
        # (layer_spec, "front" | "behind") -> list of functions run around that layer
        self._func_dict = {}
        self._policy = policy

    @property
    def policy(self):
        """The partition policy used by :meth:`partition`."""
        return self._policy

    @policy.setter
    def policy(self, policy: str):
        self._policy = policy

    @property
    def layers_count(self):
        """Number of layer specs collected by :meth:`to_layer_list`."""
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        """Number of ``(layer_spec, position)`` slots that have functions attached."""
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function when entering the context.

        Snapshots the RNG states; :meth:`_post_context_exec` restores them, so
        the random stream after the context matches the one before entering it.
        """
        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        # torch.cuda.get_rng_state() raises when CUDA is unavailable, so only
        # snapshot the CUDA generator when there is one.
        self.cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    def _post_context_exec(self):
        """
        The callback function when exiting the context.

        Restores the RNG states captured by :meth:`_pre_context_exec`.
        """
        # reset rng states
        torch.set_rng_state(self.cpu_rng_state)
        if self.cuda_rng_state is not None:
            torch.cuda.set_rng_state(self.cuda_rng_state)

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.

        Records a :class:`LayerSpec` for ``module`` and converts its
        parameters to :class:`ColoParameter`.

        NOTE() The module may be passed to this function multiple times.

        Args:
            module: The module whose constructor just finished.
            *args: Positional arguments the module was constructed with.
            **kwargs: Keyword arguments the module was constructed with.

        Raises:
            KeyError: if a module passed as an argument was not itself
                recorded by this context.
        """
        # Replace every torch Module argument with its layer spec for storage
        # purposes, so the spec can later rebuild this module.
        modified_args = [
            self._layer_spec_dict[id(arg)] if isinstance(arg, torch.nn.Module) else arg
            for arg in args
        ]
        # (lyl)TODO: analyse ColoTensor as well
        modified_kwargs = {
            k: self._layer_spec_dict[id(v)] if isinstance(v, torch.nn.Module) else v
            for k, v in kwargs.items()
        }

        # Keep track of the module children. As torch.nn.Module.__init__ is called
        # from inner module to outer module, the final value of self._model will be
        # the outermost model, e.g. the ``ResNet`` object for torchvision.models.resnet18.
        self._root_children = list(module.children())
        self._model = module

        # store the children to keep the module hierarchy
        layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
        layer_spec.set_children(module.children())

        # store the layer spec in this context
        self._layer_spec_dict[id(module)] = layer_spec

        # Convert all torch.nn.Parameter to colossalai.tensor.ColoParameter.
        # Collect first, then mutate, so the module is not modified while iterating.
        to_convert = [(name, param)
                      for name, param in module.named_parameters()
                      if not isinstance(param, ColoParameter)]
        for name, param in to_convert:
            # named_parameters() recurses into submodules and yields dotted names
            # (e.g. "block.weight") for not-yet-converted parameters of children.
            # delattr/setattr do not resolve dotted paths, so find the owning module.
            owner_path, _, param_name = name.rpartition('.')
            owner = module
            if owner_path:
                for attr in owner_path.split('.'):
                    owner = getattr(owner, attr)
            delattr(owner, param_name)
            setattr(owner, param_name,
                    ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with execution sequence given by user.

        If ``exec_seq`` is None, the module initialization order is taken as the
        execution order. Children of the outermost model that are a
        ``ModuleList`` or ``Sequential`` are flattened one level into their members.

        Args:
            exec_seq: Optional execution sequence. Each element is one of:
                * a module name (a key of ``self._model.named_modules()``);
                * a ``(func, "front")`` tuple: ``func`` runs before the next
                  module name in the sequence;
                * any other ``(func, ...)`` tuple or a bare function: runs after
                  the most recent module name preceding it.

        Raises:
            AssertionError: if a module name in ``exec_seq`` does not exist.
            ValueError: if a function meant to run after a module appears
                before any module name in ``exec_seq``.
        """
        if exec_seq is None:
            # if user do not provide the model executing sequence, we use the initialization order as the executing order.
            for child in self._root_children:
                layer_spec = self._layer_spec_dict[id(child)]
                if layer_spec.typename in (torch.nn.modules.container.ModuleList,
                                           torch.nn.modules.container.Sequential):
                    for child_in_container in layer_spec.children:
                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
                else:
                    self._layer_spec_list.append(layer_spec)
            return

        front_funcs_list = []
        named_modules = dict(self._model.named_modules())
        # key under which subsequent "behind" functions are stored; set by each module name
        func_key = None
        for element in exec_seq:
            if isinstance(element, str):
                assert element in named_modules, f'Found invalid module name {element}, please check if you spell the module name correctly.'

                # get the layer spec based on the module ID
                module = named_modules[element]
                layer_spec = self._layer_spec_dict[id(module)]

                # attach functions which should be executed before this module
                if front_funcs_list:
                    self._func_dict.setdefault((layer_spec, "front"), []).extend(front_funcs_list)
                    front_funcs_list = []

                func_key = (layer_spec, "behind")
                self._layer_spec_list.append(layer_spec)
            elif isinstance(element, tuple) and element[1] == "front":
                front_funcs_list.append(element[0])
            else:
                if func_key is None:
                    raise ValueError("Found a function in exec_seq that should run after a module, "
                                     "but no module name precedes it.")
                func = element[0] if isinstance(element, tuple) else element
                self._func_dict.setdefault(func_key, []).append(func)
        # NOTE(review): "front" functions not followed by any module name are discarded here.

    def partition(self, num_chunks, pipeline_size, rank):
        """
        Build the partition of the model owned by ``rank`` according to the partition policy.

        The real module instances are built in this method.

        Args:
            num_chunks: Forwarded to ``partition_uniform`` / ``partition_balanced``.
            pipeline_size: Forwarded to ``partition_uniform`` / ``partition_balanced``.
            rank: Index selecting this rank's ``(start, end)`` ranges from the policy result.

        Returns:
            A :class:`PipelinableModel` wrapping the built modules of this partition
            together with their "front" and "behind" functions.

        Raises:
            ValueError: if the policy is an unknown string, or neither a string nor a dict.
        """
        if isinstance(self._policy, str):
            if self._policy == "uniform":
                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            elif self._policy == "balanced":
                param_counts = [layer_spec.count_params() for layer_spec in self._layer_spec_list]
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            else:
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:
            raise ValueError("A partition policy should be either a string or a dictionary.")

        layers_to_build = []
        for start, end in parts:
            layers_to_build += self._layer_spec_list[start:end]

        behind_func_dict_in_partition = {}
        front_func_dict_in_partition = {}
        module_list_in_partition = []
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            # A layer can have both "front" and "behind" functions; register each
            # independently so neither set is dropped.
            if (layer, "front") in self._func_dict:
                front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
            if (layer, "behind") in self._func_dict:
                behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]

        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
                                          behind_func_dict_in_partition)

        return pipeline_model
|
||
|
|
||
3 years ago
|
|
||
3 years ago
|
class PipelinableModel(torch.nn.Module):
    """
    One pipeline partition, as built by ``PipelinableContext.partition``.

    Runs the contained modules in order, feeding each module's result into the
    next, and executes the functions registered to run before ("front") or
    after ("behind") a given module.

    Args:
        module_list: The modules of this partition, in execution order.
        front_func_dict: Maps ``id(module)`` to the functions run before that module.
        behind_func_dict: Maps ``id(module)`` to the functions run after that module.
    """

    def __init__(self, module_list, front_func_dict, behind_func_dict):
        super().__init__()
        self._module_list = module_list
        self._front_func_dict = front_func_dict
        self._behind_func_dict = behind_func_dict

    def forward(self, input_tensor, **kwargs):
        """
        Run every module of the partition in sequence.

        Args:
            input_tensor: Input of the first module; may be ``None``.
            **kwargs: Extra keyword arguments. Each module receives the subset
                selected by ``build_kwargs_for_function`` (when the current input
                is ``None``) or ``build_kwargs_for_module`` (otherwise).

        Returns:
            The value produced by the last module (after its "behind" functions).
            When a module is called with extra keyword arguments and returns a
            tuple, only the tuple's first element is carried forward.
        """
        for layer in self._module_list:
            layer_key = id(layer)

            if layer_key in self._front_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._front_func_dict, layer_key, input_tensor, kwargs)

            uses_checkpoint = isinstance(layer, CheckpointModule)
            # For a CheckpointModule, kwargs are matched against ``_forward`` instead of ``forward``.
            target_fn = layer._forward if uses_checkpoint else layer.forward

            if input_tensor is None:
                extra_kwargs = build_kwargs_for_function(target_fn, kwargs)
            else:
                extra_kwargs = build_kwargs_for_module(target_fn, kwargs)

            if extra_kwargs is None:
                # nothing extra applies: plain call, output carried forward unchanged
                input_tensor = layer(input_tensor)
            else:
                if uses_checkpoint:
                    # The selected kwargs are passed positionally, in ``extra_kwargs`` order,
                    # after ``input_tensor`` -- even when ``input_tensor`` is None.
                    # NOTE(review): the leading None in that case looks suspicious; confirm
                    # against the CheckpointModule ``_forward`` signatures in use.
                    output = layer(input_tensor, *extra_kwargs.values())
                elif input_tensor is None:
                    output = layer(**extra_kwargs)
                else:
                    output = layer(input_tensor, **extra_kwargs)
                input_tensor = output[0] if isinstance(output, tuple) else output

            if layer_key in self._behind_func_dict:
                input_tensor = exec_funcs_with_kwargs(self._behind_func_dict, layer_key, input_tensor, kwargs)

        return input_tensor
|
||
|
|
||
|
|
||
|
|