2022-03-15 09:07:35 +00:00
from typing import Optional
2022-03-08 10:18:06 +00:00
import torch
2022-03-18 08:18:31 +00:00
import torch . distributed as dist
2022-04-20 03:29:48 +00:00
from colossalai . logging import get_dist_logger
2022-03-08 10:18:06 +00:00
from colossalai . registry import OPHOOKS
2022-04-11 15:13:02 +00:00
2022-03-10 06:08:58 +00:00
from colossalai . utils import get_current_device
2022-04-11 15:13:02 +00:00
2022-03-14 06:48:32 +00:00
from colossalai . zero . shard_utils import BaseShardStrategy
2022-07-14 05:44:26 +00:00
from colossalai . gemini . ophooks import BaseOpHook
2022-03-08 10:18:06 +00:00
2022-04-19 02:13:08 +00:00
from colossalai . gemini . stateful_tensor_mgr import StatefulTensorMgr
from colossalai . gemini . memory_tracer import MemStatsCollector
2022-04-24 05:08:48 +00:00
from colossalai . gemini . stateful_tensor import TensorState
2022-04-19 02:13:08 +00:00
2022-03-08 10:18:06 +00:00
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.

    Before a module runs forward/backward, its direct (non-recursive)
    parameters are moved to ``TensorState.COMPUTE``, gathered through the
    shard strategy when ``module.param_is_sharded`` is set, and
    ``param.data`` is pointed at ``param.colo_attr.data_payload``. After
    forward/backward, the parameters are moved to ``HOLD_AFTER_FWD`` /
    ``HOLD_AFTER_BWD``, re-sharded, and their torch payload is dropped via
    ``colo_attr.set_data_none()``.

    Args:
        shard_strategy: strategy whose ``gather``/``shard`` are applied to
            the parameters' ``sharded_data_tensor`` objects.
        memstarts_collector: optional collector; when present, overall and
            model data statistics are sampled in ``adjust_module_data``.
        stateful_tensor_mgr: optional manager; when present, its
            ``adjust_layout`` runs before compute and ``finish_iter`` runs in
            ``post_iter``.
        process_group: process group forwarded to the shard strategy.
    """

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstarts_collector: Optional[MemStatsCollector] = None,
                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                 process_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.logger = get_dist_logger("ZeROHook")
        self.shard_strategy = shard_strategy
        self.process_group = process_group

        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        self.computing_device = get_current_device()

        self._memstarts_collector = memstarts_collector
        self._stateful_tensor_mgr = stateful_tensor_mgr

    def _sharded_data_tensors(self, module: torch.nn.Module) -> list:
        """Return the ``sharded_data_tensor`` of each direct parameter of ``module``.

        Raises:
            AssertionError: if a parameter lacks the ``colo_attr`` attribute.
        """
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'colo_attr')
            tensor_list.append(param.colo_attr.sharded_data_tensor)
        return tensor_list

    def gather_parameters(self, module: torch.nn.Module) -> None:
        """Gather the sharded direct parameters of ``module``.

        No-op when ``module.param_is_sharded`` is falsy.
        """
        if module.param_is_sharded:
            self.shard_strategy.gather(self._sharded_data_tensors(module), self.process_group)

    def shard_parameters(self, module: torch.nn.Module) -> None:
        """Shard the gathered direct parameters of ``module``.

        No-op when ``module.param_is_sharded`` is falsy.
        """
        if module.param_is_sharded:
            self.shard_strategy.shard(self._sharded_data_tensors(module), self.process_group)

    def adjust_module_data(self, module: torch.nn.Module) -> None:
        """Move ``module``'s direct parameters to COMPUTE state and rebalance tensor placement.

        Statistics are sampled before and after when a memory-stats collector
        was supplied.
        """
        # record overall data statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_overall_data()

        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

        # Adjust stateful tensors to get enough CUDA memory. The manager is
        # optional in __init__, so guard it like post_iter does; without one,
        # no relocation happens and the pre-exec CUDA assertion reports any
        # misplaced parameter.
        if self._stateful_tensor_mgr:
            self._stateful_tensor_mgr.adjust_layout()

        # record model data statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_model_data()

    def _prepare_params_for_compute(self, module: torch.nn.Module, stage: str) -> None:
        """Shared pre-forward/pre-backward logic.

        Adjusts placement, gathers the parameters, and exposes each payload as
        ``param.data``. ``stage`` prefixes the assertion message, e.g. "PRE FWD".

        Raises:
            AssertionError: if a parameter's data is not on a CUDA device.
        """
        self.adjust_module_data(module)
        self.gather_parameters(module)
        for param in module.parameters(recurse=False):
            param.data = param.colo_attr.data_payload
            assert param.data.device.type == 'cuda', f"{stage} param.data must be on CUDA"

    def _release_params_after_compute(self, module: torch.nn.Module, hold_state: TensorState) -> None:
        """Shared post-forward/post-backward logic.

        Moves the parameters to ``hold_state``, re-shards them, and drops the
        torch payload.
        """
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(hold_state)

        self.shard_parameters(module)

        # remove torch payload
        for param in module.parameters(recurse=False):
            param.colo_attr.set_data_none()

    def pre_fwd_exec(self, module: torch.nn.Module, *args) -> None:
        """Materialize ``module``'s direct parameters on CUDA before its forward pass."""
        self._prepare_params_for_compute(module, "PRE FWD")

    def post_fwd_exec(self, module: torch.nn.Module, *args) -> None:
        """Move parameters to HOLD_AFTER_FWD, re-shard them, and drop their payload."""
        self._release_params_after_compute(module, TensorState.HOLD_AFTER_FWD)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output) -> None:
        """Materialize ``module``'s direct parameters on CUDA before its backward pass."""
        self._prepare_params_for_compute(module, "PRE BWD")

    def post_bwd_exec(self, module: torch.nn.Module, input) -> None:
        """Move parameters to HOLD_AFTER_BWD, re-shard them, and drop their payload."""
        self._release_params_after_compute(module, TensorState.HOLD_AFTER_BWD)

    def pre_iter(self) -> None:
        pass

    def post_iter(self) -> None:
        """Log the manager's per-iteration data-movement stats on rank 0, then call its ``finish_iter``."""
        if self._stateful_tensor_mgr:
            self.logger.debug(
                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB, get layout info time: {self._stateful_tensor_mgr._layout_time}, evict cpu time: {self._stateful_tensor_mgr._evict_time}",
                ranks=[0])
            self._stateful_tensor_mgr.finish_iter()