2023-07-21 06:39:01 +00:00
import gc
2023-05-05 11:36:10 +00:00
import logging
import os
2023-05-05 06:37:21 +00:00
from pathlib import Path
2023-08-24 01:29:25 +00:00
from typing import Callable , Iterator , List , Optional , Tuple
2023-03-31 08:06:13 +00:00
import torch
import torch . nn as nn
from torch . optim import Optimizer
from torch . optim . lr_scheduler import _LRScheduler as LRScheduler
from torch . utils . data import DataLoader
2023-05-05 11:36:10 +00:00
from colossalai . checkpoint_io import CheckpointIndexFile , CheckpointIO , GeneralCheckpointIO
2023-07-21 06:39:01 +00:00
from colossalai . checkpoint_io . utils import (
get_model_base_filenames ,
get_optimizer_base_filenames ,
load_shard_state_dict ,
2023-09-04 15:25:01 +00:00
save_config_file ,
2023-07-21 06:39:01 +00:00
save_state_dict ,
save_state_dict_shards ,
)
2023-03-31 08:06:13 +00:00
from colossalai . cluster import DistCoordinator
from colossalai . interface import ModelWrapper , OptimizerWrapper
from colossalai . utils import get_current_device
2023-08-24 01:29:25 +00:00
from colossalai . zero import GeminiDDP , GeminiOptimizer
2023-04-04 05:48:16 +00:00
from colossalai . zero . gemini . memory_tracer import MemStats
2023-03-31 08:06:13 +00:00
2023-05-05 11:36:10 +00:00
from . dp_plugin_base import DPPluginBase
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
# Public API of this module. The entry must match the class name exactly
# (no surrounding whitespace), otherwise `from <module> import *` fails with
# AttributeError because the listed name does not exist in the module.
__all__ = ["GeminiPlugin"]
2023-03-31 08:06:13 +00:00
2023-09-19 06:20:26 +00:00
# User-facing precision name -> torch dtype. GeminiPlugin looks the chosen
# precision up here and forwards the dtype as `mixed_precision`.
PRECISION_STR_TO_DTYPE = {"fp16": torch.float16, "bf16": torch.bfloat16}

# Accepted precision names, in the insertion order of the mapping above.
SUPPORTED_PRECISION = list(PRECISION_STR_TO_DTYPE)
2023-06-05 07:58:31 +00:00
2023-03-31 08:06:13 +00:00
class GeminiCheckpointIO(GeneralCheckpointIO):
    """
    Checkpoint IO for models wrapped in ``GeminiDDP`` and optimizers wrapped in ``GeminiOptimizer``.

    State dicts are requested on every process, because gathering them involves communication,
    while files are written only by the master process as reported by ``DistCoordinator``.
    """

    def __init__(self) -> None:
        super().__init__()
        # Used to decide which process is allowed to write checkpoint files.
        self.coordinator = DistCoordinator()

    def save_unsharded_model(self, model: GeminiDDP, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
        """
        Save unsharded model to checkpoint but only on master process.
        As there is communication when getting state dict, model.state_dict() must be called on all processes.

        Args:
            model (GeminiDDP): the boosted model to save.
            checkpoint (str): destination passed to ``save_state_dict``.
            gather_dtensor (bool): not used by this implementation.
            use_safetensors (bool): forwarded to ``save_state_dict``.

        Raises:
            AssertionError: if ``model`` is not a ``GeminiDDP`` instance.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"
        # Must run on every rank; only the write below is restricted to master.
        state_dict = model.state_dict(only_rank_0=True)
        if self.coordinator.is_master():
            save_state_dict(state_dict, checkpoint, use_safetensors)

    def load_unsharded_model(self, model: GeminiDDP, checkpoint: str, strict: bool = True):
        """
        Load an unsharded model checkpoint by delegating to ``GeneralCheckpointIO.load_unsharded_model``.

        Args:
            model (GeminiDDP): the boosted model to load into.
            checkpoint (str): checkpoint location, forwarded unchanged.
            strict (bool): forwarded unchanged. Defaults to True.

        Raises:
            AssertionError: if ``model`` is not a ``GeminiDDP`` instance.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
        super().load_unsharded_model(model, checkpoint, strict=strict)

    def save_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str, gather_dtensor: bool):
        """
        Save unsharded optimizer state dict to checkpoint.
        After calling optimizer.state_dict(), the complete optimizer states will be collected on master rank.
        As there is communication when getting state dict, optimizer.state_dict() must be called on all processes.
        The saving process will only be executed by master rank.

        Args:
            optimizer (GeminiOptimizer): the boosted optimizer to save.
            checkpoint (str): destination passed to ``save_state_dict``.
            gather_dtensor (bool): not used by this implementation.

        Raises:
            AssertionError: if ``optimizer`` is not a ``GeminiOptimizer`` instance.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"
        state_dict = optimizer.state_dict()
        if self.coordinator.is_master():
            # Optimizer states are always written without safetensors.
            save_state_dict(state_dict, checkpoint, use_safetensors=False)

    def load_unsharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint: str):
        """
        Loading unsharded optimizer from checkpoint file.
        For each process, only loading optimizer states of parameters it controls.

        Raises:
            AssertionError: if ``optimizer`` is not a ``GeminiOptimizer`` instance.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"
        super().load_unsharded_optimizer(optimizer, checkpoint)

    def save_sharded_model(
        self,
        model: GeminiDDP,
        checkpoint_path: str,
        gather_dtensor: bool = False,
        prefix: Optional[str] = None,
        max_shard_size: int = 1024,
        use_safetensors: bool = False,
    ):
        """
        Save sharded model.
        As there is communication when getting state dict, model.state_dict_shard() must be called on all processes.

        If ``checkpoint_path`` is an existing file, an error is logged and nothing is saved.

        Args:
            model (GeminiDDP): the boosted model to save.
            checkpoint_path (str): directory receiving the shards; created if missing.
            gather_dtensor (bool): not used by this implementation.
            prefix (Optional[str]): forwarded to ``get_model_base_filenames``.
            max_shard_size (int): forwarded to ``model.state_dict_shard``. Defaults to 1024.
            use_safetensors (bool): selects the file naming and the shard format. Defaults to False.

        Raises:
            AssertionError: if ``model`` is not a ``GeminiDDP`` instance.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before saving!"

        if os.path.isfile(checkpoint_path):
            logging.error(f"Provided path ({checkpoint_path}) should be a directory, not a file")
            return

        Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

        state_dict_shard = model.state_dict_shard(max_shard_size=max_shard_size, only_rank_0=True)
        weights_name, save_index_file = get_model_base_filenames(prefix, use_safetensors)
        index_file = CheckpointIndexFile(checkpoint_path)

        # Save shards of model states.
        is_master = self.coordinator.is_master()
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint_path,
            index_file=index_file,
            base_filename=weights_name,
            is_master=is_master,
            use_safetensors=use_safetensors,
        )

        # only save the index file on the master rank
        if is_master:
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            save_config_file(model.unwrap(), checkpoint_path)
            logging.info(
                f"The model is split into checkpoint shards. "
                f"You can find where each parameters has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_model(
        self, model: GeminiDDP, checkpoint_index_file: Path, strict: bool = False, use_safetensors: bool = False
    ):
        """
        Load shard model, load model from multiple files.

        Delegates to ``GeneralCheckpointIO.load_sharded_model`` with ``load_sub_module=False``.

        Raises:
            AssertionError: if ``model`` is not a ``GeminiDDP`` instance.
        """
        assert isinstance(model, GeminiDDP), "Please boost the model before loading!"
        return super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module=False)

    def save_sharded_optimizer(
        self, optimizer: GeminiOptimizer, checkpoint: Path, gather_dtensor: bool, prefix: str, size_per_shard: int
    ):
        """
        Save sharded optimizer state dict to checkpoint folder.
        As there is communication when getting state dict, this must be called on all processes.

        If ``checkpoint`` is an existing file, an error is logged and nothing is saved.

        Args:
            optimizer (GeminiOptimizer): the boosted optimizer to save.
            checkpoint (Path): directory receiving the shards; created if missing.
            gather_dtensor (bool): not used by this implementation.
            prefix (str): forwarded to ``get_optimizer_base_filenames`` and ``optimizer.state_shard``.
            size_per_shard (int): forwarded to ``optimizer.state_shard`` as ``max_shard_size``.

        Raises:
            AssertionError: if ``optimizer`` is not a ``GeminiOptimizer`` instance.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before saving!"

        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return

        Path(checkpoint).mkdir(parents=True, exist_ok=True)

        # Preparing file paths and index file.
        states_name, save_index_file, param_group_file = get_optimizer_base_filenames(prefix)
        index_file = CheckpointIndexFile(checkpoint)
        is_master = self.coordinator.is_master()

        # Store the information of param groups to param_group_file.
        index_file.append_meta_data("param_groups", param_group_file)
        group_file_path = os.path.join(checkpoint, param_group_file)
        # Kept on every rank in case the call is collective -- TODO confirm.
        param_groups = optimizer.get_param_groups_for_saving()
        if is_master:
            # A single writer avoids all ranks racing to write the same file.
            torch.save(param_groups, group_file_path)

        # States are broken into shards within max_shard_size.
        state_dict_shard = optimizer.state_shard(prefix=prefix, max_shard_size=size_per_shard, only_rank_0=True)

        # Save shards of optimizer states.
        total_size = save_state_dict_shards(
            sharded_state_dict=state_dict_shard,
            checkpoint=checkpoint,
            index_file=index_file,
            base_filename=states_name,
            is_master=is_master,
            use_safetensors=False,
        )

        # Wrap up index file. Only save it on master rank.
        if is_master:
            index_file.append_meta_data("total_size", total_size)
            index_file.write_index_file(save_index_file)
            logging.info(
                f"The optimizer is going to be split to checkpoint shards. "
                f"You can find where each parameters has been saved in the "
                f"index located at {save_index_file}."
            )

    def load_sharded_optimizer(self, optimizer: GeminiOptimizer, checkpoint_index_file: Path, prefix: str):
        """
        Loading sharded optimizer from checkpoint folder, with index file given.
        For each process, only loading optimizer states of parameters it controls.

        Args:
            optimizer (GeminiOptimizer): the boosted optimizer to load into.
            checkpoint_index_file (Path): path of the optimizer checkpoint index file.
            prefix (str): not used by this implementation.

        Raises:
            AssertionError: if ``optimizer`` is not a ``GeminiOptimizer`` instance.
            FileNotFoundError: if ``checkpoint_index_file`` is not an existing file.
            RuntimeError: if the index file does not reference a param group file.
        """
        assert isinstance(optimizer, GeminiOptimizer), "Please boost the optimizer before loading!"

        if not os.path.isfile(checkpoint_index_file):
            # Fail here instead of continuing into index parsing with a path already known to be invalid.
            message = f"Provided path ({checkpoint_index_file}) should be a file"
            logging.error(message)
            raise FileNotFoundError(message)

        # Read checkpoint index file.
        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)

        # Load param_groups.
        param_group_path = ckpt_index_file.get_param_group_filename()
        if param_group_path is None:
            raise RuntimeError(
                f"Invalid index file path {checkpoint_index_file} for an optimizer. "
                "Lacking param group file under current directory."
            )
        # NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
        saved_param_groups = torch.load(param_group_path)
        optimizer.load_param_groups(saved_param_groups)

        checkpoint_files, _ = ckpt_index_file.get_checkpoint_filenames()

        # Load optimizer states from shard files under checkpoint path.
        # For each file, only load the states managed by current process.
        for shard_file in checkpoint_files:
            state_dict_shard = load_shard_state_dict(Path(shard_file), use_safetensors=False)
            optimizer.load_param_states(state_dict_shard)
            # Drop the shard before reading the next one to keep peak memory down.
            del state_dict_shard
            gc.collect()

        optimizer.optimizer_loading_epilogue()

    def save_lr_scheduler(self, lr_scheduler: LRScheduler, checkpoint: str):
        """
        Save lr scheduler to checkpoint but only on master process.
        """
        if self.coordinator.is_master():
            super().save_lr_scheduler(lr_scheduler, checkpoint)
2023-07-07 08:33:06 +00:00
2023-05-05 11:36:10 +00:00
class GeminiPlugin(DPPluginBase):
    """
    Plugin for Gemini.

    ```python
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin

    model, train_dataset, optimizer, criterion = ...
    plugin = GeminiPlugin()

    train_dataloader = plugin.prepare_dataloader(train_dataset, batch_size=8)
    booster = Booster(plugin=plugin)
    model, optimizer, train_dataloader, criterion = booster.boost(model, optimizer, train_dataloader, criterion)
    ```

    Args:
        chunk_config_dict (dict, optional): chunk configuration dictionary.
        chunk_init_device (torch.device, optional): device to initialize the chunk.
            When not given, the device returned by ``get_current_device()`` is used.
        placement_policy (str, optional): "static" and "auto". Defaults to "static".
        enable_gradient_accumulation (bool, optional): Whether to enable gradient accumulation. When set to True, gradient will be stored after doing backward pass. Defaults to False.
        shard_param_frac (float, optional): fraction of parameters to be sharded. Only for "static" placement.
            If `shard_param_frac` is 1.0, it's equal to zero-3. If `shard_param_frac` is 0.0, it's equal to zero-2. Defaults to 1.0.
        offload_optim_frac (float, optional): fraction of optimizer states to be offloaded. Only for "static" placement.
            If `shard_param_frac` is 1.0 and `offload_optim_frac` is 0.0, it's equal to old "cuda" placement. Defaults to 0.0.
        offload_param_frac (float, optional): fraction of parameters to be offloaded. Only for "static" placement.
            For efficiency, this argument is useful only when `shard_param_frac` is 1.0 and `offload_optim_frac` is 1.0.
            If `shard_param_frac` is 1.0, `offload_optim_frac` is 1.0 and `offload_param_frac` is 1.0, it's equal to old "cpu" placement.
            When using static placement, we recommend users to tune `shard_param_frac` first and then `offload_optim_frac`.
            Defaults to 0.0.
        warmup_non_model_data_ratio (float, optional): ratio of expected non-model data memory during warmup. Only for "auto" placement. Defaults to 0.8.
        steady_cuda_cap_ratio (float, optional): ratio of allowed cuda capacity for model data during steady state. Only for "auto" placement. Defaults to 0.9.
        precision (str, optional): precision. Support 'fp16' and 'bf16'. Defaults to 'fp16'.
        master_weights (bool, optional): Whether to keep fp32 master parameter weights in optimizer. Defaults to True.
        pin_memory (bool, optional): use pin memory on CPU. Defaults to False.
        force_outputs_fp32 (bool, optional): force outputs are fp32. Defaults to False.
        strict_ddp_mode (bool, optional): use strict ddp mode (only use dp without other parallelism). Defaults to False.
        search_range_m (int, optional): chunk size searching range divided by 2^20. Defaults to 32.
        hidden_dim (int, optional): the hidden dimension of DNN.
            Users can provide this argument to speed up searching.
            If users do not know this argument before training, it is ok. We will use a default value 1024.
        min_chunk_size_m (float, optional): the minimum chunk size divided by 2^20.
            If the aggregate size of parameters is still smaller than the minimum chunk size,
            all parameters will be compacted into one small chunk. Defaults to 32.
        memstats (MemStats, optional): the memory statistics collector by a runtime memory tracer.
        gpu_margin_mem_ratio (float, optional): The ratio of GPU remaining memory (after the first forward-backward)
            which will be used when using hybrid CPU optimizer.
            This argument is meaningless when `placement_policy` of `GeminiManager` is not "auto".
            Defaults to 0.0.
        initial_scale (float, optional): Initial scale used by DynamicGradScaler. Defaults to 2**16.
        min_scale (float, optional): Min scale used by DynamicGradScaler. Defaults to 1.
        growth_factor (float, optional): growth_factor used by DynamicGradScaler. Defaults to 2.
        backoff_factor (float, optional): backoff_factor used by DynamicGradScaler. Defaults to 0.5.
        growth_interval (int, optional): growth_interval used by DynamicGradScaler. Defaults to 1000.
        hysteresis (int, optional): hysteresis used by DynamicGradScaler. Defaults to 2.
        max_scale (float, optional): max_scale used by DynamicGradScaler. Defaults to 2**32.
        max_norm (float, optional): max_norm used for `clip_grad_norm`. You should notice that you shall not do
            clip_grad_norm by yourself when using ZeRO DDP. The ZeRO optimizer will take care of clip_grad_norm.
            Defaults to 0.0.
        norm_type (float, optional): norm_type used for `clip_grad_norm`. Defaults to 2.0.
        verbose (bool, optional): verbose mode. Debug info including chunk search result will be printed. Defaults to False.
    """

    def __init__(
        self,
        chunk_config_dict: Optional[dict] = None,
        chunk_init_device: Optional[torch.device] = None,
        placement_policy: str = "static",
        enable_gradient_accumulation: bool = False,
        shard_param_frac: float = 1.0,  # only for static placement
        offload_optim_frac: float = 0.0,  # only for static placement
        offload_param_frac: float = 0.0,  # only for static placement
        warmup_non_model_data_ratio: float = 0.8,  # only for auto placement
        steady_cuda_cap_ratio: float = 0.9,  # only for auto placement
        precision: str = "fp16",
        master_weights: bool = True,
        pin_memory: bool = False,
        force_outputs_fp32: bool = False,
        strict_ddp_mode: bool = False,
        search_range_m: int = 32,
        hidden_dim: Optional[int] = None,
        min_chunk_size_m: float = 32,
        memstats: Optional[MemStats] = None,
        gpu_margin_mem_ratio: float = 0.0,
        initial_scale: float = 2**16,
        min_scale: float = 1,
        growth_factor: float = 2,
        backoff_factor: float = 0.5,
        growth_interval: int = 1000,
        hysteresis: int = 2,
        max_scale: float = 2**32,
        max_norm: float = 0.0,
        norm_type: float = 2.0,
        verbose: bool = False,
    ) -> None:
        super().__init__()
        assert precision in SUPPORTED_PRECISION, f"precision {precision} is not supported"

        # Keyword arguments forwarded to GeminiDDP when the model is wrapped in `configure`.
        self.gemini_config = dict(
            chunk_config_dict=chunk_config_dict,
            chunk_init_device=(chunk_init_device or get_current_device()),
            placement_policy=placement_policy,
            enable_gradient_accumulation=enable_gradient_accumulation,
            shard_param_frac=shard_param_frac,
            offload_optim_frac=offload_optim_frac,
            offload_param_frac=offload_param_frac,
            warmup_non_model_data_ratio=warmup_non_model_data_ratio,
            steady_cuda_cap_ratio=steady_cuda_cap_ratio,
            pin_memory=pin_memory,
            force_outputs_fp32=force_outputs_fp32,
            strict_ddp_mode=strict_ddp_mode,
            search_range_m=search_range_m,
            hidden_dim=hidden_dim,
            min_chunk_size_m=min_chunk_size_m,
            memstats=memstats,
            # The precision name is translated to the corresponding torch dtype.
            mixed_precision=PRECISION_STR_TO_DTYPE[precision],
            master_weights=master_weights,
        )
        # Both dicts below are forwarded to GeminiOptimizer in `configure`.
        self.zero_optim_config = dict(
            gpu_margin_mem_ratio=gpu_margin_mem_ratio,
        )
        self.optim_kwargs = dict(
            initial_scale=initial_scale,
            growth_factor=growth_factor,
            backoff_factor=backoff_factor,
            growth_interval=growth_interval,
            hysteresis=hysteresis,
            min_scale=min_scale,
            max_scale=max_scale,
            max_norm=max_norm,
            norm_type=norm_type,
        )
        self.verbose = verbose

    def support_no_sync(self) -> bool:
        """Return False: this plugin does not provide a ``no_sync`` context."""
        return False

    def control_precision(self) -> bool:
        """Return True: the plugin manages precision itself."""
        return True

    def supported_precisions(self) -> List[str]:
        """Return the accepted precision names.

        A fresh list is returned so that callers mutating the result cannot
        alter the module-level ``SUPPORTED_PRECISION`` used for validation.
        """
        return list(SUPPORTED_PRECISION)

    def control_device(self) -> bool:
        """Return True: the plugin manages device placement itself."""
        return True

    def supported_devices(self) -> List[str]:
        """Return the supported device types, currently only ``"cuda"``."""
        return ["cuda"]

    def configure(
        self,
        model: nn.Module,
        optimizer: Optional[Optimizer] = None,
        criterion: Optional[Callable] = None,
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        """Wrap the model and optimizer with Gemini.

        The model is wrapped in ``GeminiDDP`` unless it is already a ``ModelWrapper``.
        The optimizer, when given and not already an ``OptimizerWrapper``, is wrapped in
        ``GeminiOptimizer``. ``criterion``, ``dataloader`` and ``lr_scheduler`` are returned
        unchanged, so arguments passed as None come back as None.

        Returns:
            The tuple ``(model, optimizer, criterion, dataloader, lr_scheduler)``.
        """
        if not isinstance(model, ModelWrapper):
            # convert model to sync bn
            # FIXME(ver217): gemini does not support sync bn
            # In torch/nn/modules/_functions.py, line 22, ``mean, invstd = torch.batch_norm_stats(input, eps)`` will get fp32 mean and invstd even though the input is fp16.
            # This inconsistency of dtype will cause the error.
            # We have two possible solutions:
            # 1. keep batch norm always in fp32. This is hard for gemini, as it use chunks.
            # 2. patch sync bn or write a new on. This is relatively easy, but we need to test it.
            # model = nn.SyncBatchNorm.convert_sync_batchnorm(model, None)

            # wrap the model with Gemini
            model = GeminiDDP(model, **self.gemini_config, verbose=self.verbose)

        if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            optimizer = GeminiOptimizer(
                optimizer, model, **self.zero_optim_config, **self.optim_kwargs, verbose=self.verbose
            )

        return model, optimizer, criterion, dataloader, lr_scheduler

    def control_checkpoint_io(self) -> bool:
        """Return True: the plugin supplies its own checkpoint IO."""
        return True

    def get_checkpoint_io(self) -> CheckpointIO:
        """Return a new ``GeminiCheckpointIO`` instance."""
        return GeminiCheckpointIO()

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        """Not supported by this plugin (see ``support_no_sync``).

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError("GeminiPlugin does not support no_sync")