import contextlib
import functools
from contextlib import AbstractContextManager
from dataclasses import dataclass
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn

from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils.model.utils import InsertPostInitMethodToModuleSubClasses
from colossalai.zero.legacy.shard_utils import BaseShardStrategy
from colossalai.zero.legacy.sharded_model._utils import cast_tensor_to_bf16, cast_tensor_to_fp16
from colossalai.zero.legacy.sharded_model.sharded_model_v2 import ShardedModelV2
from colossalai.zero.legacy.sharded_param import ShardedParamV2


@dataclass
class ZeroContextConfig:
    """The configuration used to control zero context initialization.

    Args:
        target_device (torch.device): The device where param data reside after exiting the context.
        is_replicated (bool, optional): Whether the param is replicated across the data parallel group.
            Some parameters are not replicated, e.g. parameters in MoE experts. Defaults to True.
        shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
    """

    target_device: torch.device
    is_replicated: bool = True
    shard_param: bool = False

    def __post_init__(self):
        if self.shard_param:
            assert self.is_replicated, "Non-replicated parameters can't be sharded."

        if self.is_replicated and not self.shard_param:
            assert self.target_device.type == 'cuda', "Replicated no-shard parameters should be located on CUDA."
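
# A quick illustration of the constraints enforced in ``__post_init__`` (values are illustrative):
#   ZeroContextConfig(torch.device('cuda'), is_replicated=True, shard_param=True)     # OK: sharded params are replicated
#   ZeroContextConfig(torch.device('cuda'), is_replicated=False, shard_param=False)   # OK: e.g. MoE expert params
#   ZeroContextConfig(torch.device('cpu'), is_replicated=True, shard_param=False)     # AssertionError: must be on CUDA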


class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
    """A context to initialize model.

    1. Convert the model to fp16 (or bf16 if ``bf16=True``).
    2. The parameters of the module are adapted to type ``ShardedParamV2``.
    3. Shard the param and grad according to flags.

    Args:
        target_device (torch.device): The device where param data reside after exiting the context.
        shard_strategy (BaseShardStrategy): Shard strategy instance.
        seed (int, optional): Random seed for weight initialization. Defaults to ``2**10 - 1``.
        shard_param (bool, optional): Whether the param is sharded after exiting the context. Defaults to False.
        default_dtype (torch.dtype, optional): If it's not None, parameters will be initialized as ``default_dtype`` then converted to fp16.
        bf16 (bool, optional): If it's True, parameters will be initialized as ``torch.bfloat16``. Otherwise, parameters will be initialized as ``torch.float16``. Defaults to False.
        model_numel_tensor (torch.Tensor, optional): A tensor which will store the number of elements of the model. Defaults to ``torch.zeros(1, dtype=torch.long)``.
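
    Example:
        A minimal usage sketch (``Net`` stands for any user-defined ``nn.Module``;
        ``TensorShardStrategy`` is assumed to be importable from
        ``colossalai.zero.legacy.shard_utils``)::

            shard_strategy = TensorShardStrategy()
            with ZeroInitContext(target_device=torch.device('cuda', torch.cuda.current_device()),
                                 shard_strategy=shard_strategy,
                                 shard_param=True):
                model = Net()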
    """

    def __init__(self,
                 target_device: torch.device,
                 shard_strategy: BaseShardStrategy,
                 seed: int = 2**10 - 1,
                 shard_param: bool = False,
                 default_dtype: Optional[torch.dtype] = None,
                 bf16: bool = False,
                 model_numel_tensor: torch.Tensor = torch.zeros(1, dtype=torch.long)):

        super().__init__(default_dtype=default_dtype)
        self.shard_strategy = shard_strategy
        self.param_list = []
        self.model_numel_tensor = model_numel_tensor
        self.seed = seed
        self.bf16 = bf16
        self.dp_process_group = gpc.get_group(ParallelMode.DATA)

        self.config = ZeroContextConfig(target_device=target_device, is_replicated=True, shard_param=shard_param)

        ZeroContextMgr().current_context = self

        self.param_numel = {}
        self.top_module = None

    @property
    def target_device(self):
        return self.config.target_device

    @property
    def is_replicated(self):
        return self.config.is_replicated

    @property
    def shard_param(self):
        return self.config.shard_param

    @staticmethod
    def calc_fanin_fanout(tensor: torch.Tensor):
        """We use this function to substitute the fan-in and fan-out calculation in torch.nn.init.
        This can help us get the correct fan-in and fan-out for a sharded tensor.
        """
        assert isinstance(tensor, nn.Parameter), "Sharded tensor initialization is only allowed for parameters"

        # get the correct shape of the input tensor
        if not hasattr(tensor, 'colo_attr') or not tensor.colo_attr.param_is_sharded:
            tensor_shape = tensor.shape
        else:
            tensor_shape = tensor.colo_attr.sharded_data_tensor.origin_shape

        dimensions = len(tensor_shape)
        if dimensions < 2:
            raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")

        num_input_fmaps = tensor_shape[1]
        num_output_fmaps = tensor_shape[0]
        receptive_field_size = 1
        if dimensions > 2:
            # math.prod is not always available, accumulate the product manually
            # we could use functools.reduce but that is not supported by TorchScript
            for s in tensor_shape[2:]:
                receptive_field_size *= s
        fan_in = num_input_fmaps * receptive_field_size
        fan_out = num_output_fmaps * receptive_field_size

        return fan_in, fan_out

    def _pre_context_exec(self):
        """
        The callback function when entering the context.
        """
        self.logger = get_dist_logger("ZeroInitContext")

        # substitute fan-in and fan-out calculation
        self.nn_fanin_fanout = nn.init._calculate_fan_in_and_fan_out
        nn.init._calculate_fan_in_and_fan_out = self.calc_fanin_fanout

        self.module_load_from_state_dict = nn.Module._load_from_state_dict
        shard_strategy = self.shard_strategy if self.config.shard_param else None
        nn.Module._load_from_state_dict = functools.partialmethod(ShardedModelV2._colo_load_from_state_dict,
                                                                  shard_strategy=shard_strategy)
        self.module_state_dict = nn.Module.state_dict
        nn.Module.state_dict = functools.partialmethod(ShardedModelV2._colo_state_dict,
                                                       shard_strategy=shard_strategy,
                                                       state_dict_func=self.module_state_dict,
                                                       process_group=self.dp_process_group)

        # reserve rng states
        self.cpu_rng_state = torch.get_rng_state()
        self.cuda_rng_state = torch.cuda.get_rng_state()

        # set a new seed for initialization; since we initialize sharded tensors separately,
        # we don't want all processes to use the same seed,
        # otherwise all sharded tensors would be identical after init
        offset = self.seed + 1    # the offset keeps more bits set in the binary form of the per-rank seed
        torch.manual_seed(self.seed + offset * dist.get_rank())
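        # e.g. with the default seed of 1023, rank 0 seeds with 1023,
        # rank 1 with 1023 + 1024 = 2047, rank 2 with 1023 + 2048 = 3071, and so on.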

    def _post_context_exec(self):
        """The callback function when exiting the context.
        """
        # broadcast replicated no-shard parameters
        src_rank = gpc.get_ranks_in_group(ParallelMode.DATA)[0]
        for param in self.param_list:
            assert hasattr(param, 'colo_attr')
            if not param.colo_attr.param_is_sharded and param.colo_attr.is_replicated:
                dist.broadcast(tensor=param.data, src=src_rank, group=self.dp_process_group)
                param.colo_attr.set_data_none()

        del self.param_list

        nn.init._calculate_fan_in_and_fan_out = self.nn_fanin_fanout
        # restore the attribute that was patched in _pre_context_exec
        nn.Module._load_from_state_dict = self.module_load_from_state_dict
        nn.Module.state_dict = self.module_state_dict
        torch.set_rng_state(self.cpu_rng_state)
        torch.cuda.set_rng_state(self.cuda_rng_state)

        params = frozenset(self.top_module.parameters())
        for param in self.param_numel.keys():
            if param not in params:
                self.param_numel[param] = 0
        self.model_numel_tensor.fill_(sum(self.param_numel.values()))

    def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
        """
        The function to call at the end of the constructor of each module.
        NOTE: The module may be passed to this function multiple times.
        """
        self.top_module = module
        half_dtype = torch.float16 if not self.bf16 else torch.bfloat16

        def half_fn(t: torch.Tensor):
            return t.to(half_dtype) if t.is_floating_point() else t

        for param in module.parameters(recurse=False):
            # avoid adapting a param to ShardedParam twice
            if hasattr(param, 'colo_attr'):
                continue

            self.param_numel[param] = param.numel()

            # convert parameters to half precision
            param_half = half_fn(param)
            param.data = param_half
            if param.grad is not None:
                grad_half = half_fn(param.grad)
                param.grad.data = grad_half

            # move torch parameters to the target device
            target_device = self.target_device
            param.data = param.data.to(target_device)
            if param.grad is not None:
                param.grad = param.grad.to(target_device)

            param.colo_attr = ShardedParamV2(param, set_data_none=True)

            if self.shard_param:
                self.shard_strategy.shard([param.colo_attr.sharded_data_tensor], self.dp_process_group)

            param.data = param.colo_attr.data_payload    # set param.data to payload

            # mark whether the param is replicated
            param.colo_attr.is_replicated = self.is_replicated

            # mark whether the param should stay unsharded;
            # if True, the param behaves as in ZeRO stage 2
            param.colo_attr.keep_not_shard = not self.shard_param

            self.param_list.append(param)

        # buffers must be cast as well: with BatchNorm, for example,
        # buffers may still be on CPU and in fp32
        cast_fn = cast_tensor_to_fp16 if not self.bf16 else cast_tensor_to_bf16
        for buffer in module.buffers(recurse=False):
            buffer.data = buffer.data.to(device=torch.cuda.current_device())
            buffer.data = cast_fn(buffer.data)


class ZeroContextMgr(metaclass=SingletonMeta):
    current_context: Optional[ZeroInitContext] = None

    @contextlib.contextmanager
    def hijack_context_config(self, **kwargs):
        if self.current_context is None:
            yield
        else:
            old_config = self.current_context.config
            self.current_context.config = ZeroContextConfig(**kwargs)
            yield
            self.current_context.config = old_config


def no_shard_zero_context(is_replicated: bool = True) -> AbstractContextManager:
    return ZeroContextMgr().hijack_context_config(target_device=torch.device('cuda', torch.cuda.current_device()),
                                                  is_replicated=is_replicated,
                                                  shard_param=False)


def no_shard_zero_decrator(is_replicated: bool = True):

    def _wrapper(init_func):

        def _no_shard(*args, **kwargs):
            with no_shard_zero_context(is_replicated):
                ret = init_func(*args, **kwargs)
            return ret

        return _no_shard

    return _wrapper
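

# A minimal usage sketch (the ``Expert`` module below is hypothetical): applying the decorator to a
# module's ``__init__`` keeps the parameters created inside it unsharded (and optionally
# non-replicated), which is how modules such as MoE experts can opt out of sharding while a
# ``ZeroInitContext`` is active:
#
#     class Expert(nn.Module):
#
#         @no_shard_zero_decrator(is_replicated=False)
#         def __init__(self, hidden_size: int):
#             super().__init__()
#             self.dense = nn.Linear(hidden_size, hidden_size)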