ColossalAI/colossalai/engine/ophooks/zero_hook.py

95 lines
3.9 KiB
Python
Raw Normal View History

2022-03-15 09:07:35 +00:00
from typing import Optional
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS
from colossalai.utils import get_current_device
2022-03-15 09:07:35 +00:00
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
2022-03-14 06:48:32 +00:00
from colossalai.zero.shard_utils import BaseShardStrategy
from colossalai.zero.sharded_param.tensorful_state import TensorState
2022-03-14 06:48:32 +00:00
from ._base_ophook import BaseOpHook
2022-04-01 01:22:33 +00:00
from colossalai.zero.shard_utils.tensor_utils import colo_model_data_tensor_move_inline
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.

    Before a submodule's forward/backward pass, the sharded data tensors of its
    direct parameters (``recurse=False``) are gathered through
    ``shard_strategy``, moved to ``computing_device`` and bound to
    ``param.data``. After the pass they are sharded again and the torch payload
    is removed via ``colo_attr.remove_torch_payload()``.

    Args:
        shard_strategy: strategy whose ``gather``/``shard`` methods are applied
            to each parameter's ``colo_attr.sharded_data_tensor``.
        memstarts_collector: optional ``MemStatsCollector``; when not ``None``,
            ``sample_memstats()`` is called in every pre-forward/pre-backward
            hook. (The misspelled name is kept so keyword callers keep working.)
        process_group: process group forwarded to ``gather``/``shard``.
    """

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstarts_collector: Optional[MemStatsCollector],
                 process_group: Optional[dist.ProcessGroup] = None):
        super().__init__()
        self.shard_strategy = shard_strategy
        self.process_group = process_group
        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        # NOTE(review): this assumes get_current_device() returns a CUDA device
        # index (giving e.g. 'cuda:0'); confirm against the installed
        # colossalai.utils version.
        self.computing_device = torch.device(f'cuda:{get_current_device()}')
        self._memstarts_collector = memstarts_collector

    @staticmethod
    def _zero_params(module: torch.nn.Module) -> list:
        """Return ``module``'s direct parameters, asserting each carries ``colo_attr``.

        Validation happens up front so a parameter without ``colo_attr`` fails
        before any other parameter's state has been touched.
        """
        params = list(module.parameters(recurse=False))
        for param in params:
            assert hasattr(param, 'colo_attr'), 'ZeroHook expects every parameter to have a colo_attr'
        return params

    def _gather_for_compute(self, module: torch.nn.Module) -> None:
        """Gather ``module``'s sharded params, bind them on the computing device, mark them COMPUTE.

        Shared by ``pre_fwd_exec`` and ``pre_bwd_exec``, which perform identical work.
        """
        params = self._zero_params(module)
        self.shard_strategy.gather([p.colo_attr.sharded_data_tensor for p in params], self.process_group)
        for param in params:
            colo_model_data_tensor_move_inline(param.colo_attr.sharded_data_tensor, self.computing_device)
            # Point the torch parameter at the gathered payload so the module computes with it.
            param.data = param.colo_attr.sharded_data_tensor.payload
        # Memory is sampled after the payloads are bound and before the COMPUTE transition.
        if self._memstarts_collector is not None:
            self._memstarts_collector.sample_memstats()
        for param in params:
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

    def _shard_after_compute(self, module: torch.nn.Module, hold_state: TensorState) -> None:
        """Move ``module``'s params to ``hold_state``, shard them, and remove the torch payload.

        Shared by ``post_fwd_exec`` (HOLD_AFTER_FWD) and ``post_bwd_exec`` (HOLD_AFTER_BWD).
        """
        params = self._zero_params(module)
        # The state transition happens before sharding.
        for param in params:
            param.colo_attr.sharded_data_tensor.trans_state(hold_state)
        self.shard_strategy.shard([p.colo_attr.sharded_data_tensor for p in params], self.process_group)
        for param in params:
            # Presumably releases the full tensor bound to param.data; see colo_attr's implementation.
            param.colo_attr.remove_torch_payload()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Gather and bind ``module``'s parameters before its forward pass."""
        self._gather_for_compute(module)

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Mark params HOLD_AFTER_FWD, shard them and remove the torch payload after forward."""
        self._shard_after_compute(module, TensorState.HOLD_AFTER_FWD)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Gather and bind ``module``'s parameters before its backward pass."""
        self._gather_for_compute(module)

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Mark params HOLD_AFTER_BWD, shard them and remove the torch payload after backward."""
        self._shard_after_compute(module, TensorState.HOLD_AFTER_BWD)

    def pre_iter(self):
        pass

    def post_iter(self):
        pass