2022-03-15 09:07:35 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2022-03-08 10:18:06 +00:00
|
|
|
import torch
|
2022-03-18 08:18:31 +00:00
|
|
|
import torch.distributed as dist
|
2022-04-20 03:29:48 +00:00
|
|
|
from colossalai.logging import get_dist_logger
|
2022-03-08 10:18:06 +00:00
|
|
|
from colossalai.registry import OPHOOKS
|
2022-04-11 15:13:02 +00:00
|
|
|
|
2022-03-10 06:08:58 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-04-11 15:13:02 +00:00
|
|
|
|
2022-03-14 06:48:32 +00:00
|
|
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
2022-04-11 15:13:02 +00:00
|
|
|
from colossalai.engine.ophooks import BaseOpHook
|
2022-03-08 10:18:06 +00:00
|
|
|
|
2022-04-19 02:13:08 +00:00
|
|
|
from colossalai.gemini.stateful_tensor_mgr import StatefulTensorMgr
|
|
|
|
from colossalai.gemini.memory_tracer import MemStatsCollector
|
2022-04-24 05:08:48 +00:00
|
|
|
from colossalai.gemini.stateful_tensor import TensorState
|
2022-04-19 02:13:08 +00:00
|
|
|
|
2022-03-08 10:18:06 +00:00
|
|
|
|
|
|
|
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.

    Around every module's forward and backward op it:

    * before the op: samples memory statistics (if a collector is given),
      moves the module's own parameters (``recurse=False``) to
      ``TensorState.COMPUTE``, lets the stateful tensor manager (if given)
      adjust the layout, gathers sharded payloads and points ``param.data``
      at ``param.colo_attr.data_payload``;
    * after the op: moves the parameters to the matching HOLD state,
      re-shards them and drops the torch payload via ``set_data_none()``.
    """

    def __init__(self,
                 shard_strategy: BaseShardStrategy,
                 memstarts_collector: Optional[MemStatsCollector] = None,
                 stateful_tensor_mgr: Optional[StatefulTensorMgr] = None,
                 process_group: Optional[dist.ProcessGroup] = None):
        """
        Args:
            shard_strategy: strategy whose ``gather``/``shard`` methods are
                applied to the parameters' ``sharded_data_tensor`` objects.
            memstarts_collector: optional collector; when set, overall and
                model data are sampled before each op.
            stateful_tensor_mgr: optional manager; when set, its
                ``adjust_layout()`` runs before each op, and ``post_iter``
                logs its CPU-GPU move volume and calls ``finish_iter()``.
            process_group: process group passed through to the shard strategy.
        """
        super().__init__()
        self.logger = get_dist_logger("ZeROHook")
        self.shard_strategy = shard_strategy
        self.process_group = process_group

        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        self.computing_device = get_current_device()

        self._memstarts_collector = memstarts_collector
        self._stateful_tensor_mgr = stateful_tensor_mgr

    def _collect_sharded_data_tensors(self, module: torch.nn.Module) -> list:
        """Return ``colo_attr.sharded_data_tensor`` of every direct parameter of ``module``."""
        tensor_list = []
        for param in module.parameters(recurse=False):
            assert hasattr(param, 'colo_attr')
            tensor_list.append(param.colo_attr.sharded_data_tensor)
        return tensor_list

    def gather_parameters(self, module: torch.nn.Module):
        """Gather the sharded parameters of ``module`` (no-op unless ``module.param_is_sharded``)."""
        if module.param_is_sharded:
            self.shard_strategy.gather(self._collect_sharded_data_tensors(module), self.process_group)

    def shard_parameters(self, module: torch.nn.Module):
        """Shard the gathered parameters of ``module`` (no-op unless ``module.param_is_sharded``)."""
        if module.param_is_sharded:
            self.shard_strategy.shard(self._collect_sharded_data_tensors(module), self.process_group)

    def adjust_module_data(self, module: torch.nn.Module):
        """Prepare ``module``'s own parameters for computation and record memory statistics."""
        # record overall data statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_overall_data()

        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(TensorState.COMPUTE)

        # adjust stateful tensor to get enough CUDA memory.
        # The manager is optional in __init__ (default None), so guard the call
        # instead of failing with AttributeError on None.
        if self._stateful_tensor_mgr is not None:
            self._stateful_tensor_mgr.adjust_layout()

        # record model data statistics
        if self._memstarts_collector:
            self._memstarts_collector.sample_model_data()

    def _load_payload_for_compute(self, module: torch.nn.Module, stage: str):
        """Shared pre-op path: adjust, gather, then expose the payload as ``param.data``.

        ``stage`` only labels the assertion message (e.g. "PRE FWD").
        """
        self.adjust_module_data(module)
        self.gather_parameters(module)
        for param in module.parameters(recurse=False):
            param.data = param.colo_attr.data_payload
            assert param.data.device.type == 'cuda', f"{stage} param.data must be on CUDA"

    def _release_payload_after_compute(self, module: torch.nn.Module, hold_state: TensorState):
        """Shared post-op path: mark ``hold_state``, re-shard, then drop the torch payload."""
        # change tensor state to the HOLD_AFTER_* state of this pass
        for param in module.parameters(recurse=False):
            param.colo_attr.sharded_data_tensor.trans_state(hold_state)

        self.shard_parameters(module)

        # remove torch payload
        for param in module.parameters(recurse=False):
            param.colo_attr.set_data_none()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        self._load_payload_for_compute(module, "PRE FWD")

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        self._release_payload_after_compute(module, TensorState.HOLD_AFTER_FWD)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        self._load_payload_for_compute(module, "PRE BWD")

    def post_bwd_exec(self, module: torch.nn.Module, input):
        self._release_payload_after_compute(module, TensorState.HOLD_AFTER_BWD)

    def pre_iter(self):
        pass

    def post_iter(self):
        """Log this iteration's CPU-GPU move volume (rank 0) and close the manager's iteration."""
        if self._stateful_tensor_mgr:
            self.logger.info(
                f"CPU-GPU data moving this iteration {self._stateful_tensor_mgr.cpu_gpu_move_volume/1e9} GB", ranks=[0])
            self._stateful_tensor_mgr.finish_iter()
|