You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/gemini/ophooks/_shard_param_ophook.py

48 lines
1.3 KiB

import torch
from colossalai.registry import OPHOOKS
from . import BaseOpHook
@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded params before and after FWD and BWD operator executing.

    Before an operator runs (forward or backward), ``ca_attr.gather()`` is
    called on every parameter of the module; after the operator runs,
    ``ca_attr.shard()`` is called instead. In both cases ``param.data`` is then
    rebound to whatever ``ca_attr.payload()`` returns.

    Every parameter of a hooked module must carry a ``ca_attr`` attribute;
    an ``AssertionError`` is raised otherwise.
    """

    def __init__(self):
        super().__init__()

    def niter(self):
        """Return ``self._niter``.

        NOTE(review): ``_niter`` is not assigned anywhere in this class;
        presumably it is maintained by ``BaseOpHook`` -- confirm.
        """
        return self._niter

    @staticmethod
    def _iter_ca_params(module: torch.nn.Module):
        """Lazily yield each parameter of ``module`` after checking it has ``ca_attr``."""
        for param in module.parameters():
            assert hasattr(param, 'ca_attr'), \
                "ShardParamHook requires every parameter to have a 'ca_attr' attribute"
            yield param

    def _gather_params(self, module: torch.nn.Module) -> None:
        """Call ``ca_attr.gather()`` on each parameter, then refresh ``param.data``."""
        for param in self._iter_ca_params(module):
            param.ca_attr.gather()
            # Rebind after gather() so param.data reflects the current payload.
            param.data = param.ca_attr.payload()

    def _shard_params(self, module: torch.nn.Module) -> None:
        """Call ``ca_attr.shard()`` on each parameter, then refresh ``param.data``."""
        for param in self._iter_ca_params(module):
            param.ca_attr.shard()
            # Rebind after shard() so param.data reflects the current payload.
            param.data = param.ca_attr.payload()

    def pre_fwd_exec(self, module: torch.nn.Module, *args):
        """Gather the module's parameters before its forward operator executes."""
        self._gather_params(module)

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        """Shard the module's parameters after its forward operator has executed."""
        self._shard_params(module)

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        """Gather the module's parameters before its backward operator executes.

        ``input`` and ``output`` are accepted for the hook interface but unused.
        """
        self._gather_params(module)

    def post_bwd_exec(self, module: torch.nn.Module, input):
        """Shard the module's parameters after its backward operator has executed.

        ``input`` is accepted for the hook interface but unused.
        """
        self._shard_params(module)

    def pre_iter(self):
        """No-op: this hook does nothing at the start of an iteration."""
        pass

    def post_iter(self):
        """No-op: this hook does nothing at the end of an iteration."""
        pass