2022-03-15 09:07:35 +00:00
|
|
|
from typing import Optional
|
|
|
|
|
2022-03-08 10:18:06 +00:00
|
|
|
import torch
|
|
|
|
from colossalai.registry import OPHOOKS
|
2022-03-10 06:08:58 +00:00
|
|
|
from colossalai.utils import get_current_device
|
2022-03-15 09:07:35 +00:00
|
|
|
from colossalai.utils.memory_tracer.memstats_collector import MemStatsCollector
|
|
|
|
from colossalai.utils.memory_tracer.model_data_memtracer import \
|
|
|
|
GLOBAL_MODEL_DATA_TRACER
|
2022-03-14 06:48:32 +00:00
|
|
|
from colossalai.zero.shard_utils import BaseShardStrategy
|
|
|
|
|
2022-03-08 10:18:06 +00:00
|
|
|
from ._base_ophook import BaseOpHook
|
|
|
|
|
|
|
|
|
|
|
|
@OPHOOKS.register_module
class ZeroHook(BaseOpHook):
    """
    A hook to process sharded param for ZeRO method.

    Around every hooked module call it:

    * before forward / backward: gathers the ``col_attr.data`` of all of the
      module's parameters through ``shard_strategy.gather``, moves any of them
      that are not on ``computing_device`` there, and binds each payload to
      ``param.data``;
    * after forward / backward: shards them again through
      ``shard_strategy.shard`` and calls ``col_attr.remove_torch_payload()``
      on every parameter.

    The pre-backward hook additionally stashes the gradient found on a
    parameter into ``col_attr.fp32_grad`` while ``col_attr.bwd_count`` is
    still 0, and increments ``col_attr.bwd_count`` for every parameter.

    Every parameter of a hooked module must carry a ``col_attr`` attribute
    (this is asserted).

    Args:
        shard_strategy (BaseShardStrategy): Strategy used to gather and shard
            the parameters' ``col_attr.data`` tensors.
        memstarts_collector (Optional[MemStatsCollector]): When not ``None``,
            its ``sample_memstats()`` is called at the end of the pre-forward
            and pre-backward hooks. Defaults to ``None``.
    """

    def __init__(self, shard_strategy: BaseShardStrategy, memstarts_collector: Optional[MemStatsCollector] = None):
        super().__init__()
        self.shard_strategy = shard_strategy
        # NOTE(jiaruifang) Now the computing device of FWD and BWD is always on GPU
        # NOTE(review): assumes get_current_device() returns a CUDA device index,
        # so that the f-string yields e.g. 'cuda:0' -- confirm.
        self.computing_device = torch.device(f'cuda:{get_current_device()}')
        # The "memstarts" spelling is part of the public constructor keyword,
        # so it is kept as-is for backward compatibility.
        self._memstarts_collector = memstarts_collector

    # ------------------------------------------------------------------ #
    # helpers shared by the forward and backward hooks
    # ------------------------------------------------------------------ #

    @staticmethod
    def _sharded_params(module: torch.nn.Module) -> list:
        """Return ``module``'s parameters (submodules included), asserting each has ``col_attr``."""
        # module.parameters() recurses into submodules by default.
        params = list(module.parameters())
        for param in params:
            assert hasattr(param, 'col_attr')
        return params

    def _gather_params(self, module: torch.nn.Module) -> list:
        """Gather the ``col_attr.data`` of all of ``module``'s params and return the params."""
        params = self._sharded_params(module)
        self.shard_strategy.gather([param.col_attr.data for param in params])
        return params

    def _load_payload(self, param: torch.nn.Parameter) -> None:
        """Ensure ``param``'s gathered data is on the computing device and bind it to ``param.data``."""
        if param.col_attr.data.device != self.computing_device:
            # NOTE(review): the return value of ``.to()`` is discarded, so this relies
            # on ``col_attr.data.to`` moving the payload in place -- confirm.
            param.col_attr.data.to(self.computing_device)
            # Only payloads moved here are registered with the tracer; nothing in
            # this class ever removes them from it.
            GLOBAL_MODEL_DATA_TRACER.add_tensor(param.col_attr.data.payload)
        param.data = param.col_attr.data.payload

    def _shard_and_release(self, module: torch.nn.Module) -> None:
        """Shard ``module``'s params again and drop their torch payloads."""
        params = self._sharded_params(module)
        self.shard_strategy.shard([param.col_attr.data for param in params])
        for param in params:
            param.col_attr.remove_torch_payload()

    @staticmethod
    def _stash_local_grad(param: torch.nn.Parameter) -> None:
        """Store local accumulated grad shard of ``param`` (no-op when ``param.grad`` is None)."""
        if param.grad is None:
            return
        if param.col_attr.bwd_count == 0:
            # We haven't stored local accumulated grad yet
            assert param.col_attr.fp32_grad is None
            param.col_attr.fp32_grad = param.grad.data
            param.grad = None
        else:
            # We have stored local accumulated grad
            # The grad here must be locally computed full grad in this backward pass
            assert param.grad.shape == param.col_attr.data.origin_shape

    def _sample_memstats(self) -> None:
        """Sample memory statistics if a collector was supplied."""
        if self._memstarts_collector is not None:
            self._memstarts_collector.sample_memstats()

    # ------------------------------------------------------------------ #
    # op-hook interface
    # ------------------------------------------------------------------ #

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Gather ``module``'s params and bind them to ``param.data`` before its forward."""
        for param in self._gather_params(module):
            self._load_payload(param)
        self._sample_memstats()

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Re-shard ``module``'s params and remove their torch payloads after its forward."""
        self._shard_and_release(module)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Gather ``module``'s params before its backward and stash any local grads.

        For each parameter, in order: load the gathered payload into
        ``param.data``, stash/check the grad, then increment
        ``col_attr.bwd_count``.
        """
        for param in self._gather_params(module):
            self._load_payload(param)
            self._stash_local_grad(param)
            param.col_attr.bwd_count += 1
        self._sample_memstats()

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Re-shard ``module``'s params and remove their torch payloads after its backward."""
        self._shard_and_release(module)

    def pre_iter(self):
        """No per-iteration setup is needed."""
        pass

    def post_iter(self):
        """No per-iteration teardown is needed."""
        pass
|