import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple, Union

import torch
import torch.nn as nn
from torch import Tensor
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

from colossalai.checkpoint_io import CheckpointIndexFile, CheckpointIO, GeneralCheckpointIO
from colossalai.checkpoint_io.utils import get_base_filenames, get_shard_filename, save_state_dict
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
from colossalai.utils import get_current_device
from colossalai.zero import GeminiDDP, zero_model_wrapper, zero_optim_wrapper
from colossalai.zero.gemini.memory_tracer import MemStats

from .dp_plugin_base import DPPluginBase

__all__ = ['GeminiPlugin']

class GeminiCheckpointIO(GeneralCheckpointIO):
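    """
    Checkpoint IO for models wrapped by ``GeminiPlugin``. State dicts are gathered
    across ranks by ``GeminiDDP`` (the gathering involves communication on every
    process), while the resulting files are written only on the master process.
    """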
    def __init__(self) -> None:
        super().__init__()
        self.coordinator = DistCoordinator()

    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
        """
        Load model from checkpoint with automatic unwrapping.
        """
        # the model should be unwrapped in self.load_model via ModelWrapper.unwrap
        return super().load_unsharded_model(model, checkpoint, strict=strict)

    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save model to checkpoint but only on master process.
        """
        # the model should be unwrapped in self.save_model via ModelWrapper.unwrap
        # as there is communication when getting the state dict, this must be called on all processes
        state_dict = model.state_dict(only_rank_0=True)
        if self.coordinator.is_master():
            save_state_dict(state_dict, checkpoint, use_safetensors)

    def save_unsharded_optimizer(self, optimizer: Optimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save optimizer to checkpoint but only on master process.
        """
        # TODO(ver217): optimizer state dict is sharded
        super().save_unsharded_optimizer(optimizer, checkpoint, gather_dtensor)

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)

    def save_sharded_model(self,
                           model: GeminiDDP,
                           checkpoint_path: str,
                           gather_dtensor: bool = False,
                           variant: Optional[str] = None,
                           max_shard_size: int = 1024,
                           use_safetensors: bool = False):
        """
        Save sharded model.
        """
        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True, dtype=torch.float32)

        weights_name, save_index_file = get_base_filenames(variant, use_safetensors)
        total_size = 0
        index_file = CheckpointIndexFile(checkpoint_path)
        for idx, shard_pair in enumerate(state_dict_shard):
            if not self.coordinator.is_master():
                continue
            shard = shard_pair[0]
            shard_file = get_shard_filename(weights_name, idx)
            total_size = total_size + shard_pair[1]
            for key in shard.keys():
                index_file.append_weight_map(key, shard_file)

            checkpoint_file_path = os.path.join(checkpoint_path, shard_file)
            save_state_dict(shard, checkpoint_file_path, use_safetensors)

        index_file.append_meta_data("total_size", total_size)
        index_file.write_index_file(save_index_file)
        logging.info(f"The model is split into checkpoint shards. "
                     f"You can find where each parameter has been saved in the "
                     f"index located at {save_index_file}.")
    def load_sharded_model(self,
                           model: GeminiDDP,
                           checkpoint_index_file: Path,
                           strict: bool = False,
                           use_safetensors: bool = False):
        """
        Load sharded model: the model is loaded from multiple checkpoint files.
        """
        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

class GeminiModel(ModelWrapper):

    def __init__(self, module: nn.Module, gemini_config: dict, verbose: bool = False) -> None:
        super().__init__(module)
        self.module = zero_model_wrapper(module, zero_stage=3, gemini_config=gemini_config, verbose=verbose)

    def unwrap(self):
        # as save/load state dict is coupled with the GeminiDDP, we only return the GeminiDDP model
        return self.module

class GeminiOptimizer(OptimizerWrapper):

    def __init__(self,
                 module: GeminiDDP,
                 optimizer: Optimizer,
                 zero_optim_config: dict,
                 optim_kwargs: dict,
                 verbose: bool = False) -> None:
        optimizer = zero_optim_wrapper(module,
                                       optimizer,
                                       optim_config=zero_optim_config,
                                       **optim_kwargs,
                                       verbose=verbose)
        super().__init__(optimizer)

    def backward(self, loss: Tensor, *args, **kwargs):
        self.optim.backward(loss)

    def clip_grad_by_norm(self,
                          max_norm: Union[float, int],
                          norm_type: Union[float, int] = 2,
                          error_if_nonfinite: bool = False,
                          *args,
                          **kwargs) -> Tensor:
        warnings.warn('Gemini controls grad clipping by itself, so you should not use clip_grad_by_norm')

    def clip_grad_by_value(self, clip_value: float, *args, **kwargs) -> None:
        raise NotImplementedError('Gemini does not support clip_grad_by_value')
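
# Note: gradient clipping for Gemini is configured through the ``max_norm``/``norm_type``
# arguments of ``GeminiPlugin`` below and is applied by the wrapped ZeRO optimizer itself;
# calling ``clip_grad_by_norm`` on the wrapper above only emits a warning, and
# ``clip_grad_by_value`` is not supported.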
class GeminiPlugin(DPPluginBase):
    """
    Plugin for Gemini.

    Example:
        >>> from colossalai.booster import Booster
        >>> from colossalai.booster.plugin import GeminiPlugin
        >>>
        >>> model, train_dataset, optimizer, criterion = ...
        >>> plugin = GeminiPlugin()
        >>> train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
        >>> booster = Booster(plugin=plugin)
        >>> model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)

    Args:
        device (torch.device): device to place the model. Defaults to the current device.
        placement_policy (str, optional): "cpu", "cuda", "auto". Defaults to "cpu".
        pin_memory (bool, optional): use pinned CPU memory. Defaults to False.
        force_outputs_fp32 (bool, optional): force outputs to be fp32. Defaults to False.
        strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
        search_range_mb (int, optional): chunk size searching range in MB. Defaults to 32.
        hidden_dim (int, optional): the hidden dimension of the DNN.
            Users can provide this argument to speed up searching.
            If users do not know this argument before training, it is ok. We will use a default value 1024.
        min_chunk_size_mb (float, optional): the minimum chunk size in MB.
            If the aggregate size of parameters is still smaller than the minimum chunk size,
            all parameters will be compacted into one small chunk.
        memstats (MemStats, optional): the memory statistics collected by a runtime memory tracer.
        gpu_margin_mem_ratio (float, optional): the ratio of GPU memory remaining after the first forward-backward
            which will be used by the hybrid CPU optimizer.
            This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
            Defaults to 0.0.
        initial_scale (float, optional): initial scale used by DynamicGradScaler. Defaults to 2**32.
        min_scale (float, optional): min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (float, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (float, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (int, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. Note that you should not call
            clip_grad_norm yourself when using ZeRO DDP; the ZeRO optimizer will take care of it.
        norm_type (float, optional): norm_type used for `clip_grad_norm`.
        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
    """
    def __init__(
        self,
        device: Optional[torch.device] = None,
        placement_policy: str = "cpu",
        pin_memory: bool = False,
        force_outputs_fp32: bool = False,
        strict_ddp_mode: bool = False,
        search_range_mb: int = 32,
        hidden_dim: Optional[int] = None,
        min_chunk_size_mb: float = 32,
        memstats: Optional[MemStats] = None,
        gpu_margin_mem_ratio: float = 0.0,
        initial_scale: float = 2**32,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        self.gemini_config = dict(
            device=(device or get_current_device()),
            placement_policy=placement_policy,
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
            search_range_mb=search_range_mb,
            hidden_dim=hidden_dim,
            min_chunk_size_mb=min_chunk_size_mb,
            memstats=memstats,
        )
        self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
        self.optim_kwargs = dict(initial_scale=initial_scale,
                                 growth_factor=growth_factor,
                                 backoff_factor=backoff_factor,
                                 growth_interval=growth_interval,
                                 hysteresis=hysteresis,
                                 min_scale=min_scale,
                                 max_scale=max_scale,
                                 max_norm=max_norm,
                                 norm_type=norm_type)
        self.verbose = verbose
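
    # The three dicts built above are consumed later in ``configure``:
    # ``gemini_config`` is forwarded to ``zero_model_wrapper`` when the model is
    # wrapped in ``GeminiModel``, while ``zero_optim_config`` and ``optim_kwargs``
    # are forwarded to ``zero_optim_wrapper`` when building ``GeminiOptimizer``.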
    def support_no_sync(self) -> bool:
        return False

    def control_precision(self) -> bool:
        return True

    def supported_precisions(self) -> List[str]:
        return ['fp16']

    def control_device(self) -> bool:
        return True

    def supported_devices(self) -> List[str]:
        return ['cuda']

    def configure(
        self,
        model: nn.Module,
        optimizer: Optimizer,
        criterion: Callable = None,
        dataloader: DataLoader = None,
        lr_scheduler: LRScheduler = None,
    ) -> Tuple[Union[nn.Module, OptimizerWrapper, LRScheduler, DataLoader]]:
        if not isinstance(model, ModelWrapper):
            # convert model to sync bn
            # FIXME(ver217): gemini does not support sync bn
            # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
            # This inconsistency of dtype will cause the error.
            # We have two possible solutions:
            # 1. keep batch norm always in fp32. This is hard for gemini, as it uses chunks.
            # 2. patch sync bn or write a new one. This is relatively easy, but we need to test it.
            # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

            # wrap the model with Gemini
            model = GeminiModel(model, self.gemini_config, self.verbose)

        if not isinstance(optimizer, OptimizerWrapper):
            optimizer = GeminiOptimizer(model.unwrap(), optimizer, self.zero_optim_config, self.optim_kwargs,
                                        self.verbose)

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        return GeminiCheckpointIO()

    def no_sync(self, model: nn.Module) -> Iterator[None]:
        raise NotImplementedError
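
# Illustrative end-to-end sketch (comments only). It assumes a typical training loop;
# ``inputs``/``labels`` and the model/criterion/dataloader are placeholders, and the
# argument order of ``booster.boost`` mirrors the GeminiPlugin docstring example above:
#
#   plugin = GeminiPlugin(placement_policy='auto', max_norm=1.0)
#   booster = Booster(plugin=plugin)
#   model, optimizer, dataloader, criterion = booster.boost(model, optimizer, dataloader, criterion)
#   for inputs, labels in dataloader:
#       loss = criterion(model(inputs), labels)
#       booster.backward(loss, optimizer)    # delegates to the wrapped optimizer's backward
#       optimizer.step()
#       optimizer.zero_grad()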