import torch
import inspect
from colossalai . utils . model . utils import InsertPostInitMethodToModuleSubClasses
from . utils import partition_uniform , partition_balanced , build_kwargs_for_function , \
build_kwargs_for_module , exec_func_with_kwargs , exec_funcs_with_kwargs , \
call_module , customized_partition
from colossalai . nn . layer . utils import CheckpointModule
from colossalai . tensor import ColoParameter
from colossalai . core import global_context as gpc
from colossalai . context import ParallelMode
from . layer_sepc import LayerSpec
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context manager to split the model into pipeline stages.

    Each module handed to ``_post_init_method`` is recorded as a ``LayerSpec``.
    ``to_layer_list`` then fixes the execution order of those specs, and
    ``partition`` builds the real modules that belong to one pipeline rank.
    """

    # NOTE(review): the string literals in this class (policy names, the
    # "front"/"behind" tags, the split marker and the messages) carry leading
    # and trailing spaces exactly as found in the source. That looks like an
    # extraction artifact -- confirm before relying on the padded values.

    def __init__(self, policy: str = " balanced "):
        super().__init__()
        # id(module) -> LayerSpec recorded for that module
        self._layer_spec_dict = {}
        # direct children of the most recently initialized module
        self._root_children = None
        # most recently initialized module (the outermost one in the end)
        self._model = None
        # layer specs in execution order, filled by ``to_layer_list``
        self._layer_spec_list = []
        # (layer_spec, " front " | " behind ") -> list of functions
        self._func_dict = {}
        self._policy = policy
        # Defined up front so that ``partition`` can report a missing
        # exec_seq through its assertion instead of an AttributeError.
        self._exec_seq = None

    @property
    def policy(self):
        """The partition policy: a policy name or a dict indexed by rank."""
        return self._policy

    @policy.setter
    def policy(self, policy: str):
        self._policy = policy

    @property
    def layers_count(self):
        """Number of layer specs collected by ``to_layer_list``."""
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        """Number of (layer spec, position) keys that have functions attached."""
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function when entering the context.

        Saves the CPU RNG state, and the CUDA RNG state when CUDA is
        available, so that ``_post_context_exec`` can restore them.
        """
        self.cpu_rng_state = torch.get_rng_state()
        # Querying the CUDA generator fails on hosts without CUDA, so skip it.
        self.cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    def _post_context_exec(self):
        """
        The callback function when exiting the context.

        Restores the RNG states saved by ``_pre_context_exec``.
        """
        torch.set_rng_state(self.cpu_rng_state)
        if self.cuda_rng_state is not None:
            torch.cuda.set_rng_state(self.cuda_rng_state)

    def _as_layer_spec(self, value):
        """Return the recorded layer spec for a known module, else ``value`` unchanged."""
        if isinstance(value, torch.nn.Module) and id(value) in self._layer_spec_dict:
            return self._layer_spec_dict[id(value)]
        return value

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.

        Records a ``LayerSpec`` for ``module`` and converts its parameters to
        ``ColoParameter``.

        NOTE: The module may be passed to this function multiple times.

        Args:
            module: the module whose constructor has just finished.
            *args: positional arguments the module was constructed with.
            **kwargs: keyword arguments the module was constructed with.
        """
        # If an nn.Module is an argument of a non-root module, convert it to
        # its layer spec, which makes sure the correct init method is used in
        # the real build. If it is an argument of the root module and was built
        # outside of the context, it has no layer spec, so the instance itself
        # is recorded. Keyword arguments are treated the same way as
        # positional ones.
        # (lyl)TODO: analyse ColoTensor as well
        modified_args = [self._as_layer_spec(arg) for arg in args]
        modified_kwargs = {key: self._as_layer_spec(value) for key, value in kwargs.items()}

        # Keep track of the module children. As torch.nn.Module.__init__ is
        # called from inner module to outer module, the final value of
        # self._model will be the outermost model, e.g. if the model is
        # torchvision.models.resnet18, the final value of self._model will be
        # the ``ResNet`` object.
        self._root_children = list(module.children())
        self._model = module

        # store the children to keep the module hierarchy
        layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
        layer_spec.set_children(module.children())

        # store the layer spec in this context
        self._layer_spec_dict[id(module)] = layer_spec

        # Convert all torch.nn.Parameter to colossalai.tensor.ColoParameter.
        # The pairs are collected first because the loop below replaces
        # parameters on the module that named_parameters() iterates over.
        pending = [(name, param)
                   for name, param in module.named_parameters()
                   if not isinstance(param, ColoParameter)]
        for name, param in pending:
            if hasattr(module, name):
                delattr(module, name)
            setattr(module, name,
                    ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with the execution sequence given by the user.
        If exec_seq is None, the module initializing order is taken as the execution order.

        Args:
            exec_seq: optional iterable. A string element is a module name as
                reported by ``named_modules()`` of the root model, or the split
                marker, which is skipped here. A tuple whose second item is the
                front tag registers its first item to run before the next
                module. Any other element is registered to run behind the most
                recently named module.

        Raises:
            ValueError: if a behind-function appears before any module name.
        """
        self._exec_seq = exec_seq
        if exec_seq is None:
            # If the user does not provide the model executing sequence, the
            # initialization order is used as the executing order. The
            # children of ModuleList / Sequential containers are listed
            # individually instead of the container itself.
            containers = (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential)
            for child in self._root_children:
                layer_spec = self._layer_spec_dict[id(child)]
                if layer_spec.typename in containers:
                    for child_in_container in layer_spec.children:
                        self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
                else:
                    self._layer_spec_list.append(layer_spec)
            return

        front_funcs_list = []
        # Key of the " behind " list of the most recently named module.
        func_key = None
        named_modules = dict(self._model.named_modules())
        for element in exec_seq:
            if isinstance(element, str):
                if element == ' SPLIT_NODE ':
                    continue
                assert element in named_modules, f' Found invalid module name {element} , please check if you spell the module name correctly. '

                # get the layer spec based on the module ID
                module = named_modules[element]
                layer_spec = self._layer_spec_dict[id(module)]

                # attach the functions which should be executed before this module
                if front_funcs_list:
                    self._func_dict.setdefault((layer_spec, " front "), []).extend(front_funcs_list)
                    front_funcs_list = []

                func_key = (layer_spec, " behind ")
                self._layer_spec_list.append(layer_spec)
            elif isinstance(element, tuple) and element[1] == " front ":
                # held back until the next module name shows up; front
                # functions that are never followed by a module are dropped
                front_funcs_list.append(element[0])
            else:
                if func_key is None:
                    raise ValueError(
                        "A function to be executed behind a module was found in exec_seq before any module name.")
                func = element[0] if isinstance(element, tuple) else element
                self._func_dict.setdefault(func_key, []).append(func)

    def partition(self, num_chunks, pipeline_size, rank):
        """
        The partitioned model is built with respect to the partition policy.
        The real module instances are built in this method.

        Args:
            num_chunks: forwarded to ``partition_uniform`` / ``partition_balanced``.
            pipeline_size: forwarded to ``partition_uniform`` / ``partition_balanced``.
            rank: index used to select this rank's parts from the partition result.

        Returns:
            A ``PipelinableModel`` wrapping the modules built for ``rank``.

        Raises:
            ValueError: if the policy is an unknown string, or is neither a
                string nor a dictionary.
        """
        if isinstance(self._policy, str):
            if self._policy == " uniform ":
                parts = partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            elif self._policy == " balanced ":
                param_counts = [layer_spec.count_params() for layer_spec in self._layer_spec_list]
                parts = partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            elif self._policy == " customized ":
                assert self._exec_seq is not None, f' An explicit exec_seq must be defined by user in customized policy mode. '
                self.customized_parts = customized_partition(self._exec_seq)
                assert len(self.customized_parts) == gpc.get_world_size(
                    ParallelMode.PIPELINE
                ), f' World size is {gpc.get_world_size(ParallelMode.PIPELINE)} , but the number of partions is {len(self.customized_parts)} '
                parts = self.customized_parts[rank]
            else:
                raise ValueError(" A string partition policy should be one of [ ' uniform ' , ' balanced ' , ' customized ' ]. ")
        elif isinstance(self._policy, dict):
            parts = self._policy[rank]
        else:
            raise ValueError(" A partition policy should be either a string or a dictionary. ")

        # each part is a (start, end) slice into the ordered layer spec list
        layers_to_build = []
        for start, end in parts:
            layers_to_build.extend(self._layer_spec_list[start:end])

        behind_func_dict_in_partition = {}
        front_func_dict_in_partition = {}
        module_list_in_partition = []
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            # The functions are re-keyed by id(module), which is what
            # PipelinableModel.forward looks up. A layer may have both front
            # and behind functions, so the two lookups are independent.
            if (layer, " front ") in self._func_dict:
                front_func_dict_in_partition[id(module)] = self._func_dict[(layer, " front ")]
            if (layer, " behind ") in self._func_dict:
                behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, " behind ")]

        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        pipeline_model = PipelinableModel(module_list_in_partition, front_func_dict_in_partition,
                                          behind_func_dict_in_partition)
        return pipeline_model
class PipelinableModel(torch.nn.Module):
    """
    Run the modules of one pipeline partition one after another.

    Args:
        module_list: iterable of modules, executed in iteration order.
        front_func_dict: maps ``id(module)`` to the functions executed before that module.
        behind_func_dict: maps ``id(module)`` to the functions executed after that module.
    """

    def __init__(self, module_list, front_func_dict, behind_func_dict):
        super().__init__()
        self._module_list = module_list
        self._front_func_dict = front_func_dict
        self._behind_func_dict = behind_func_dict

    def forward(self, *input_tensor, **kwargs):
        """
        Thread the inputs through every module and its attached functions.

        The running value starts as the tuple of positional inputs and is
        replaced by the result of each step. ``kwargs`` is handed unchanged to
        every helper call.

        Returns:
            The value produced by the last step, or the tuple of positional
            inputs when the module list is empty.
        """
        data = input_tensor
        for stage in self._module_list:
            stage_id = id(stage)

            if stage_id in self._front_func_dict:
                data = exec_funcs_with_kwargs(self._front_func_dict, stage_id, data, kwargs)

            # For a CheckpointModule the keyword arguments are derived from
            # ``_forward`` instead of ``forward``.
            # NOTE(review): presumably ``_forward`` is the un-checkpointed
            # forward -- confirm against CheckpointModule.
            target = stage._forward if isinstance(stage, CheckpointModule) else stage.forward
            module_kwargs = build_kwargs_for_module(target, data, kwargs)

            if data is None:
                # nothing to pass positionally
                data = call_module(stage, kwargs=module_kwargs)
            elif isinstance(data, torch.Tensor):
                # a single tensor is wrapped into a one-element tuple
                data = call_module(stage, args=(data,), kwargs=module_kwargs)
            else:
                data = call_module(stage, args=data, kwargs=module_kwargs)

            if stage_id in self._behind_func_dict:
                data = exec_funcs_with_kwargs(self._behind_func_dict, stage_id, data, kwargs)

        return data