mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
220 lines
8.0 KiB
220 lines
8.0 KiB
import torch
|
|
import functools
|
|
from colossalai.utils.model.utils import _substitute_init_recursively, InsertPostInitMethodToModuleSubClasses, call_to_str
|
|
from colossalai.builder.pipeline import partition_uniform, partition_balanced
|
|
from colossalai.core import global_context as gpc
|
|
from colossalai.tensor import ColoTensor
|
|
|
|
|
|
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    Context that records how every ``torch.nn.Module`` is constructed inside it,
    so the model can later be partitioned and rebuilt per pipeline stage.

    For each module constructed in the context, ``_post_init_method`` stores a
    ``LayerSpec`` (module class plus constructor arguments). ``to_layer_list``
    turns the recorded specs into an ordered layer list, and ``partition``
    builds only the layers assigned to a given pipeline rank.
    """

    def __init__(self):
        super().__init__()
        # id(module) -> LayerSpec for every module constructed inside the context.
        self._layer_spec_dict = {}
        # Direct children of the most recently constructed module (normally the
        # root model, since its constructor finishes last).
        self._root_children = None
        # The most recently constructed module.
        self._model = None
        # Ordered LayerSpecs produced by ``to_layer_list``.
        self._layer_spec_list = []
        # Key ("first" or a LayerSpec) -> list of callables attached to that position.
        self._func_dict = {}
        # "uniform", "balanced", or a dict mapping rank -> iterable of (start, end) ranges.
        self._policy = "balanced"

    @property
    def policy(self):
        """The current partition policy (see ``load_policy``)."""
        return self._policy

    @property
    def layers_count(self):
        """Number of LayerSpecs in the layer list built by ``to_layer_list``."""
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        """Number of keys ("first" or a LayerSpec) that have attached funcs."""
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function when entering the context.

        Saves the CPU RNG state (and the CUDA RNG state when CUDA is available)
        so that ``_post_context_exec`` can restore them on exit.
        """
        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        # Guarded so the context also works on CPU-only machines.
        self.cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    def _post_context_exec(self):
        """
        The callback function when exiting the context.

        Restores the RNG states saved by ``_pre_context_exec``.
        """
        # reset rng states
        torch.set_rng_state(self.cpu_rng_state)
        if self.cuda_rng_state is not None:
            torch.cuda.set_rng_state(self.cuda_rng_state)

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.

        Records a ``LayerSpec`` for ``module``. Module-typed positional
        arguments are replaced by their own recorded specs. The module's
        non-``ColoTensor`` parameters are replaced by ``ColoTensor`` objects
        created with ``save_payload=False``.

        NOTE: The module may be passed to this function multiple times; the
        latest call overwrites the recorded spec.
        """
        module_id = id(module)
        # Module arguments are stored as their LayerSpec so they can be rebuilt later.
        # (lyl)TODO: analyse kwargs as well
        modified_args = tuple(
            self._layer_spec_dict[id(obj)] if isinstance(obj, torch.nn.Module) else obj for obj in args)

        # Materialize once: ``module.children()`` is a one-shot generator, and
        # ``to_layer_list`` may need to iterate the children more than once.
        children = list(module.children())
        self._root_children = children
        self._model = module

        layer_spec = LayerSpec(module.__class__, *modified_args, **kwargs)
        layer_spec.set_children(children)
        self._layer_spec_dict[module_id] = layer_spec

        # Collect first, then mutate: deleting attributes while iterating
        # ``named_parameters`` would mutate the parameter dict during iteration.
        name_list = [(name, param) for name, param in module.named_parameters() if not isinstance(param, ColoTensor)]
        for name, param in name_list:
            delattr(module, name)
            setattr(module, name, ColoTensor.init_from_torch_tensor(tensor=param, save_payload=False))

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with execution sequence given by user.

        If ``exec_seq`` is None, the module initialization order is taken as the
        execution order. The root model's direct children are used, and children
        whose class is exactly ``ModuleList`` or ``Sequential`` are flattened one
        level.

        Otherwise ``exec_seq`` is a sequence whose ``str`` items are submodule
        names as produced by ``named_modules()``. Its other items are funcs
        attached to the preceding named module, or to the ``"first"`` key if no
        module name precedes them.

        Each call rebuilds the layer list and func dict from scratch, so calling
        it again does not duplicate entries.
        """
        self._layer_spec_list = []
        self._func_dict = {}

        if exec_seq is None:
            # If the user does not provide the execution sequence, use the
            # initialization order as the execution order.
            flatten_types = (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential)
            for child in self._root_children:
                layer_spec = self._layer_spec_dict[id(child)]
                if layer_spec.typename in flatten_types:
                    for child_in_container in layer_spec.children:
                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
                else:
                    self._layer_spec_list.append(layer_spec)
        else:
            # Build the name -> module map once instead of once per element.
            name_to_module = dict(self._model.named_modules()) if self._model is not None else {}
            func_key = "first"
            for element in exec_seq:
                if isinstance(element, str):
                    layer_spec = self._layer_spec_dict[id(name_to_module[element])]
                    func_key = layer_spec
                    self._layer_spec_list.append(layer_spec)
                else:
                    self._func_dict.setdefault(func_key, []).append(element)

    def partition(self, num_chunks, pipeline_size, rank):
        """
        Build the partitioned model for ``rank`` with respect to the partition policy.

        The real module instances are built in this method, and only for the
        layers assigned to ``rank``.

        Args:
            num_chunks: forwarded to ``partition_uniform`` / ``partition_balanced``
                (presumably the number of model chunks per stage).
            pipeline_size: number of pipeline stages, forwarded to the same helpers.
            rank: index selecting this rank's ranges from the partition result.

        Returns:
            A ``PipelinableModel`` wrapping the built modules and their attached funcs.

        Raises:
            ValueError: if the policy is an unknown string or is neither a str nor a dict.
        """
        if isinstance(self._policy, str):
            if self._policy == "uniform":
                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            elif self._policy == "balanced":
                # Balancing by parameter count builds every layer once (see LayerSpec.count_params).
                param_counts = [layer_spec.count_params() for layer_spec in self._layer_spec_list]
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            else:
                raise ValueError("A string partition policy should be one of ['uniform', 'balanced'].")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:
            raise ValueError("A partition policy should be either a string or a dictionary.")

        layers_to_build = []
        for start, end in parts:
            layers_to_build.extend(self._layer_spec_list[start:end])

        func_dict_in_partition = {}
        module_list_in_partition = []
        # Funcs that run before any layer are only attached on rank 0.
        if rank == 0 and "first" in self._func_dict:
            func_dict_in_partition["first"] = self._func_dict["first"]
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            if layer in self._func_dict:
                # PipelinableModel looks funcs up by id() of the built module.
                func_dict_in_partition[id(module)] = self._func_dict[layer]
        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        pipeline_model = PipelinableModel(module_list_in_partition, func_dict_in_partition)

        return pipeline_model

    def load_policy(self, policy):
        """Set the partition policy: "uniform", "balanced", or a rank -> ranges dict."""
        self._policy = policy
|
|
|
|
|
|
class PipelinableModel(torch.nn.Module):
    """
    One pipeline partition: a ``ModuleList`` of built layers plus auxiliary
    callables run around them.

    ``func_dict`` may map ``"first"`` to what is applied to the input before any
    layer, and ``id(module)`` to what is applied right after that module. Each
    value is either a single callable or a list of callables run in order.
    """

    def __init__(self, module_list, func_dict):
        super().__init__()
        self._module_list = module_list
        self._func_dict = func_dict

    @staticmethod
    def _run_funcs(funcs, value):
        """Apply ``funcs`` (one callable, or a list of them in order) to ``value``."""
        if not isinstance(funcs, list):
            return funcs(value)
        for fn in funcs:
            value = fn(value)
        return value

    def forward(self, input_tensor):
        """Run the optional "first" funcs, then each layer followed by its attached funcs."""
        hooks = self._func_dict
        out = input_tensor
        if "first" in hooks:
            out = self._run_funcs(hooks["first"], out)
        for layer in self._module_list:
            out = layer(out)
            key = id(layer)
            if key in hooks:
                out = self._run_funcs(hooks[key], out)
        return out
|
|
|
|
|
|
class LayerSpec:
    """
    Deferred-construction recipe for a ``torch.nn.Module``.

    Stores the module class and its constructor arguments so the module can be
    instantiated later with ``build()``. Positional or keyword arguments that
    are themselves ``LayerSpec`` objects are built recursively first.
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        # Validate first. ``issubclass`` raises TypeError for non-class inputs, so
        # check ``isinstance(..., type)`` to always report the intended error.
        if not (isinstance(typename, type) and issubclass(typename, torch.nn.Module)):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        # Children recorded by the caller via ``set_children`` (None until set).
        self.children = None
        # Cached result of the last ``count_params`` call.
        self._param_count = 0

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        """Parameter count from the last ``count_params`` call (0 if never counted or reset)."""
        return self._param_count

    @staticmethod
    def _recover(obj):
        """Build ``obj`` if it is a ``LayerSpec``; otherwise return it unchanged."""
        return obj.build() if isinstance(obj, LayerSpec) else obj

    def build(self):
        """Build the stored specification and return a new module instance."""
        recovered_args = tuple(self._recover(obj) for obj in self.module_args)
        recovered_kwargs = {key: self._recover(value) for key, value in self.module_kwargs.items()}
        return self.typename(*recovered_args, **recovered_kwargs)

    def set_children(self, children):
        """
        Record the children of the module described by this spec.

        ``children`` may be any iterable, such as the generator returned by
        ``Module.children()``. It is materialized into a list so it can be
        iterated more than once. ``None`` is stored as-is.
        """
        self.children = None if children is None else list(children)

    def count_params(self):
        """Build the layer once, then cache and return its total number of parameter elements."""
        self._param_count = 0
        layer = self.build()
        self._param_count = sum(param.numel() for param in layer.parameters())
        return self._param_count

    def reset_param_count(self):
        """Reset the cached parameter count to 0."""
        self._param_count = 0
|