mirror of https://github.com/hpcaitech/ColossalAI
[zero] sharded model manages ophooks individually (#492)
parent c9023d4078
commit c4c02424f3
@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
         self._iter_cnter = 0

         # Register hooks
-        register_ophooks_recursively(self.module,
-                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
+        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
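The hunk above replaces an anonymous inline hook list with a list stored on the model, so ShardedModelV2 keeps a named handle to its own ophooks. Below is a minimal sketch of that pattern, not ColossalAI's actual implementation: SimpleOpHook, register_ophooks_sketch, and ToySharded are hypothetical names, and only register_forward_hook is a real PyTorch API.

# Hedged sketch of "the model owns its ophook list" (hunk above).
# SimpleOpHook / register_ophooks_sketch / ToySharded are illustrative
# names, not ColossalAI APIs.
import torch.nn as nn


class SimpleOpHook:
    """Stand-in for an ophook such as ZeroHook."""

    def post_fwd_exec(self, module: nn.Module) -> None:
        # A real ophook would e.g. re-shard parameters here.
        pass


def register_ophooks_sketch(module: nn.Module, ophook_list) -> None:
    # Illustrative stand-in for register_ophooks_recursively: fire each
    # ophook after every submodule's forward pass.
    def make_hook(ophook):
        return lambda mod, inputs, output: ophook.post_fwd_exec(mod)

    for submodule in module.modules():
        for ophook in ophook_list:
            submodule.register_forward_hook(make_hook(ophook))


class ToySharded(nn.Module):

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        # As in this commit: keep a named handle to the hooks instead of
        # passing an anonymous list, so they can be reused after backward.
        self._ophook_list = [SimpleOpHook()]
        register_ophooks_sketch(self.module, self._ophook_list)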
@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
     def backward(self, loss):
         loss.backward()
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
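With the list stored on the model, both backward paths now explicitly call post_iter() on each ophook once the backward pass finishes. Here is a self-contained, runnable sketch of that flow under stated assumptions: CountingOpHook and TinyShardedModel are hypothetical illustration names, and _post_backward_operations is stubbed out where the real ShardedModelV2 reduces gradients.

# Standalone sketch of the post-backward flow added in this commit.
# CountingOpHook / TinyShardedModel are hypothetical names.
import torch
import torch.nn as nn


class CountingOpHook:
    """Counts iterations via post_iter(); a stand-in for ZeroHook."""

    def __init__(self):
        self.iterations = 0

    def post_iter(self) -> None:
        # Per-iteration bookkeeping, run once after each backward pass.
        self.iterations += 1


class TinyShardedModel(nn.Module):

    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module
        self._ophook_list = [CountingOpHook()]

    def backward(self, loss: torch.Tensor) -> None:
        loss.backward()
        self._post_backward_operations()
        # As in this commit: the model drives per-iteration hook cleanup.
        for ophook in self._ophook_list:
            ophook.post_iter()

    @torch.no_grad()
    def _post_backward_operations(self) -> None:
        pass  # gradient reduction etc. in the real ShardedModelV2


# Usage: one backward pass triggers exactly one post_iter per hook.
model = TinyShardedModel(nn.Linear(4, 1))
loss = model.module(torch.randn(2, 4)).sum()
model.backward(loss)
assert model._ophook_list[0].iterations == 1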