2023-09-11 08:24:28 +00:00
import torch
2023-09-18 08:31:06 +00:00
from colossalai . legacy . context import ParallelMode
from colossalai . legacy . core import global_context as gpc
2023-09-11 08:24:28 +00:00
from colossalai . legacy . nn . layer . utils import CheckpointModule
from colossalai . tensor import ColoParameter
from colossalai . utils . model . utils import InsertPostInitMethodToModuleSubClasses
2022-10-25 08:48:48 +00:00
from . layer_spec import LayerSpec
2023-09-11 08:24:28 +00:00
from . utils import (
build_kwargs_for_module ,
call_module ,
customized_partition ,
exec_funcs_with_kwargs ,
partition_balanced ,
partition_uniform ,
)
2022-04-24 05:03:12 +00:00
class PipelinableContext(InsertPostInitMethodToModuleSubClasses):
    """
    A context manager to split the model into pipeline stages.

    While the context is active, every module constructed inside it is recorded as a
    ``LayerSpec`` (its class plus constructor arguments, keyed by ``id(module)``), and its
    plain parameters are converted to ``ColoParameter``. ``to_layer_list`` then flattens the
    outermost model into an ordered list of layer specs (plus optional functions to run
    before/after layers), and ``partition`` builds only the layers assigned to one rank.

    Args:
        policy: partition policy. Either one of the strings ``"uniform"``, ``"balanced"``
            or ``"customized"``, or a dict mapping a rank to its list of ``(start, end)``
            layer-index ranges (see ``partition``).
    """

    def __init__(self, policy: str = "balanced"):
        super().__init__()
        # id(module) -> LayerSpec for every module seen by _post_init_method
        self._layer_spec_dict = {}
        # direct children of the most recently finished (i.e. outermost) module
        self._root_children = None
        # the most recently finished module; ends up being the outermost model
        self._model = None
        # ordered layer specs produced by ``to_layer_list``
        self._layer_spec_list = []
        # (LayerSpec, "front" | "behind") -> list of functions to run around that layer
        self._func_dict = {}
        # execution sequence given to ``to_layer_list``; required by the "customized" policy
        self._exec_seq = None
        self._policy = policy

    @property
    def policy(self):
        """The partition policy (a policy name string or a rank -> ranges dict)."""
        return self._policy

    @policy.setter
    def policy(self, policy: str):
        self._policy = policy

    @property
    def layers_count(self):
        """Number of layer specs produced by ``to_layer_list``."""
        return len(self._layer_spec_list)

    @property
    def funcs_count(self):
        """Number of (layer, "front"/"behind") keys that have functions attached."""
        return len(self._func_dict)

    def _pre_context_exec(self):
        """
        The callback function when entering the context.

        Saves the CPU RNG state and, when CUDA is available, the CUDA RNG state, so that
        ``_post_context_exec`` can restore them. (Presumably so that the RNG stream seen
        after the context is unaffected by the module initialization done inside it.)
        """
        self.cpu_rng_state = torch.get_rng_state()
        # Guarded so the context can also be entered on hosts without CUDA.
        self.cuda_rng_state = torch.cuda.get_rng_state() if torch.cuda.is_available() else None

    def _post_context_exec(self):
        """
        The callback function when exiting the context.

        Restores the RNG states saved by ``_pre_context_exec``.
        """
        torch.set_rng_state(self.cpu_rng_state)
        if self.cuda_rng_state is not None:
            torch.cuda.set_rng_state(self.cuda_rng_state)

    def _to_spec_if_known(self, value):
        """Return the recorded LayerSpec if ``value`` is a module seen in this context, else ``value``."""
        # A module built inside the context is replaced by its layer spec, so the real build
        # re-creates it with the correct init method. A module built outside the context
        # (e.g. an argument of the root module) has no spec and is kept as the instance itself.
        if isinstance(value, torch.nn.Module):
            return self._layer_spec_dict.get(id(value), value)
        return value

    @staticmethod
    def _convert_params_to_colo(module):
        """Replace every parameter of ``module`` that is not yet a ColoParameter with one."""
        # Collect first: re-registering parameters while iterating named_parameters()
        # would mutate the parameter dicts being iterated.
        to_convert = [
            (name, param) for name, param in module.named_parameters() if not isinstance(param, ColoParameter)
        ]
        for name, param in to_convert:
            # named_parameters() recurses into children, so ``name`` may be dotted
            # (e.g. "child.weight" for a child built outside the context). Parameters must be
            # re-registered on the submodule that owns them; register_parameter rejects
            # names containing ".".
            owner_name, _, attr = name.rpartition(".")
            owner = module.get_submodule(owner_name) if owner_name else module
            if hasattr(owner, attr):
                delattr(owner, attr)
            setattr(owner, attr, ColoParameter.from_torch_tensor(tensor=param.data, requires_grad=param.requires_grad))

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE: the module may be passed to this function multiple times.

        Records a LayerSpec for ``module`` from its class and constructor arguments, marks it
        as the current root model, and converts its parameters to ColoParameter.

        Args:
            module: the module whose constructor just finished.
            *args, **kwargs: the arguments that were passed to the module's constructor.
        """
        # Modules passed as constructor arguments (positional or keyword) are swapped for
        # their layer specs when known, so the spec can rebuild them later.
        modified_args = [self._to_spec_if_known(arg) for arg in args]
        modified_kwargs = {k: self._to_spec_if_known(v) for k, v in kwargs.items()}
        # (lyl)TODO: analyze ColoTensor as well

        # torch.nn.Module.__init__ is called from inner module to outer module, so the final
        # value of self._model is the outermost model (e.g. the ``ResNet`` object for
        # torchvision.models.resnet18).
        self._root_children = list(module.children())
        self._model = module

        # store the children to keep the module hierarchy
        layer_spec = LayerSpec(module.__class__, *modified_args, **modified_kwargs)
        layer_spec.set_children(module.children())
        self._layer_spec_dict[id(module)] = layer_spec

        self._convert_params_to_colo(module)

    def to_layer_list(self, exec_seq=None):
        """
        Create a layer spec list and func list with the execution sequence given by the user.
        If exec_seq is None, the module initialization order is taken as the execution order.

        Args:
            exec_seq: optional sequence whose elements are either
                - a module name as produced by ``named_modules()`` of the root model,
                - the string ``"SPLIT_NODE"`` (skipped here; the whole sequence is later
                  handed to ``customized_partition`` by the "customized" policy),
                - ``(func, "front")``: run ``func`` before the next named module,
                - ``(func, <anything else>)`` or a bare ``func``: run it after the previous
                  named module.

        Raises:
            RuntimeError: if no module was constructed inside the context.
            AssertionError: if a module name in ``exec_seq`` does not exist.
            ValueError: if a "behind" function precedes every module name, or "front"
                functions are not followed by any module name.
        """
        if self._model is None:
            raise RuntimeError("No module was constructed inside the PipelinableContext.")
        self._exec_seq = exec_seq
        if exec_seq is None:
            self._layer_list_from_init_order()
        else:
            self._layer_list_from_exec_seq(exec_seq)

    def _layer_list_from_init_order(self):
        """Append the root's children to the layer list, expanding ModuleList/Sequential one level."""
        container_types = (torch.nn.modules.container.ModuleList, torch.nn.modules.container.Sequential)
        for child in self._root_children:
            layer_spec = self._layer_spec_dict[id(child)]
            if layer_spec.typename in container_types:
                # expand containers so their inner layers can be placed on different stages
                for child_in_container in layer_spec.children:
                    self._layer_spec_list.append(self._layer_spec_dict[id(child_in_container)])
            else:
                self._layer_spec_list.append(layer_spec)

    def _layer_list_from_exec_seq(self, exec_seq):
        """Build the layer list and the front/behind function table from a user execution sequence."""
        named_modules = dict(self._model.named_modules())
        front_funcs_list = []
        # key under which "behind" functions are stored; set once a module name is seen
        func_key = None
        for element in exec_seq:
            if isinstance(element, str):
                if element == "SPLIT_NODE":
                    continue
                assert (
                    element in named_modules
                ), f"Found invalid module name {element}, please check if you spell the module name correctly."
                # get the layer spec based on the module ID
                layer_spec = self._layer_spec_dict[id(named_modules[element])]
                # attach functions that should be executed before this module
                if front_funcs_list:
                    self._func_dict.setdefault((layer_spec, "front"), []).extend(front_funcs_list)
                    front_funcs_list = []
                func_key = (layer_spec, "behind")
                self._layer_spec_list.append(layer_spec)
            elif isinstance(element, tuple) and element[1] == "front":
                front_funcs_list.append(element[0])
            else:
                if func_key is None:
                    raise ValueError(
                        f"Function {element!r} in exec_seq appears before any module name; "
                        "mark it as (func, 'front') or place it after a module."
                    )
                func = element[0] if isinstance(element, tuple) else element
                self._func_dict.setdefault(func_key, []).append(func)
        if front_funcs_list:
            raise ValueError("exec_seq ends with 'front' functions that are not followed by any module name.")

    def _get_parts(self, num_chunks, pipeline_size, rank):
        """Return the (start, end) layer-index ranges assigned to ``rank`` by the current policy."""
        if isinstance(self._policy, str):
            if self._policy == "uniform":
                return partition_uniform(len(self._layer_spec_list), pipeline_size, num_chunks)[rank]
            if self._policy == "balanced":
                param_counts = [layer_spec.count_params() for layer_spec in self._layer_spec_list]
                return partition_balanced(param_counts, pipeline_size, num_chunks)[rank]
            if self._policy == "customized":
                assert (
                    self._exec_seq is not None
                ), "An explicit exec_seq must be defined by user in customized policy mode."
                self.customized_parts = customized_partition(self._exec_seq)
                assert len(self.customized_parts) == gpc.get_world_size(
                    ParallelMode.PIPELINE
                ), f"World size is {gpc.get_world_size(ParallelMode.PIPELINE)}, but the number of partitions is {len(self.customized_parts)}"
                return self.customized_parts[rank]
            raise ValueError("A string partition policy should be one of ['uniform', 'balanced', 'customized'].")
        if isinstance(self._policy, dict):
            return self._policy[rank]
        raise ValueError("A partition policy should be either a string or a dictionary.")

    def partition(self, num_chunks, pipeline_size, rank):
        """
        Build the partitioned model with respect to the partition policy.
        The real module instances are built in this method.

        Args:
            num_chunks: number of model chunks per stage (forwarded to the partition helpers).
            pipeline_size: number of pipeline stages (forwarded to the partition helpers).
            rank: index selecting this stage's entry in the partition result.

        Returns:
            PipelinableModel: the built layers of this partition together with the
            front/behind functions registered for them.

        Raises:
            ValueError: if the policy is an unknown string or neither a string nor a dict.
        """
        parts = self._get_parts(num_chunks, pipeline_size, rank)
        layers_to_build = []
        for start, end in parts:
            layers_to_build += self._layer_spec_list[start:end]

        front_func_dict_in_partition = {}
        behind_func_dict_in_partition = {}
        module_list_in_partition = []
        for layer in layers_to_build:
            module = layer.build()
            module_list_in_partition.append(module)
            # A layer may have both front and behind functions; register them independently.
            if (layer, "front") in self._func_dict:
                front_func_dict_in_partition[id(module)] = self._func_dict[(layer, "front")]
            if (layer, "behind") in self._func_dict:
                behind_func_dict_in_partition[id(module)] = self._func_dict[(layer, "behind")]

        module_list_in_partition = torch.nn.ModuleList(module_list_in_partition)
        return PipelinableModel(module_list_in_partition, front_func_dict_in_partition, behind_func_dict_in_partition)
2022-05-11 01:23:58 +00:00
2022-04-24 05:03:12 +00:00
class PipelinableModel(torch.nn.Module):
    """
    The layers of one pipeline partition, executed sequentially.

    Args:
        module_list: the built layers of this partition, run in list order.
        front_func_dict: id(layer) -> functions executed before that layer.
        behind_func_dict: id(layer) -> functions executed after that layer.
    """

    def __init__(self, module_list, front_func_dict, behind_func_dict):
        super().__init__()
        self._module_list = module_list
        self._front_func_dict = front_func_dict
        self._behind_func_dict = behind_func_dict

    def forward(self, *input_tensor, **kwargs):
        """
        Feed the inputs through every layer in order and return the final output.

        ``kwargs`` are handed to the registered front/behind functions and used to build
        the keyword arguments for each layer. Between layers, a ``None`` value means the
        next layer is called with keyword arguments only, a single Tensor is passed as the
        sole positional argument, and any other value is passed as ``args`` as-is.
        """
        hidden = input_tensor
        for layer in self._module_list:
            layer_key = id(layer)

            if layer_key in self._front_func_dict:
                hidden = exec_funcs_with_kwargs(self._front_func_dict, layer_key, hidden, kwargs)

            # CheckpointModule keeps its real computation in ``_forward``; inspect that signature.
            target_fn = layer._forward if isinstance(layer, CheckpointModule) else layer.forward
            layer_kwargs = build_kwargs_for_module(target_fn, hidden, kwargs)

            if hidden is None:
                hidden = call_module(layer, kwargs=layer_kwargs)
            else:
                positional = (hidden,) if isinstance(hidden, torch.Tensor) else hidden
                hidden = call_module(layer, args=positional, kwargs=layer_kwargs)

            if layer_key in self._behind_func_dict:
                hidden = exec_funcs_with_kwargs(self._behind_func_dict, layer_key, hidden, kwargs)
        return hidden