[zero] sharded model manages ophooks individually (#492)

ver217 2022-03-22 17:33:20 +08:00 committed by GitHub
parent c9023d4078
commit c4c02424f3
1 changed file with 6 additions and 2 deletions

@@ -89,8 +89,8 @@ class ShardedModelV2(nn.Module):
         self._iter_cnter = 0
         # Register hooks
-        register_ophooks_recursively(self.module,
-                                     [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)])
+        self._ophook_list = [ZeroHook(self.shard_strategy, self._memstats_collector, self.process_group)]
+        register_ophooks_recursively(self.module, self._ophook_list)
         self.param_hook_mgr = BaseParamHookMgr(list(self.module.parameters()))
         self.param_hook_mgr.register_backward_hooks(self._grad_post_backward_hook)
@@ -134,10 +134,14 @@ class ShardedModelV2(nn.Module):
     def backward(self, loss):
         loss.backward()
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     def backward_by_grad(self, tensor, grad):
         torch.autograd.backward(tensors=tensor, grad_tensors=grad)
         self._post_backward_operations()
+        for ophook in self._ophook_list:
+            ophook.post_iter()

     @torch.no_grad()
     def _post_backward_operations(self) -> None:
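In effect, ShardedModelV2 now keeps a handle to its op hooks in self._ophook_list rather than passing an anonymous list to register_ophooks_recursively, so backward() and backward_by_grad() can invoke post_iter() on every hook at the end of each iteration. Below is a minimal, self-contained sketch of that pattern; DummyOpHook and ToyShardedModel are hypothetical stand-ins for illustration, not ColossalAI classes.

    import torch
    import torch.nn as nn


    class DummyOpHook:
        # Hypothetical stand-in for ZeroHook; only the post_iter() part
        # of the hook interface matters for this sketch.
        def post_iter(self):
            # Reset any per-iteration bookkeeping (e.g. memory statistics).
            print("post_iter called")


    class ToyShardedModel(nn.Module):
        # Hypothetical stand-in for ShardedModelV2.
        def __init__(self, module: nn.Module):
            super().__init__()
            self.module = module
            # Keep the hook list on the model (instead of handing it off
            # anonymously) so backward() can reach the hooks later.
            self._ophook_list = [DummyOpHook()]

        def forward(self, x):
            return self.module(x)

        def backward(self, loss):
            loss.backward()
            # End-of-iteration cleanup, mirroring the commit above.
            for ophook in self._ophook_list:
                ophook.post_iter()


    model = ToyShardedModel(nn.Linear(4, 1))
    loss = model(torch.randn(2, 4)).sum()
    model.backward(loss)  # prints "post_iter called"

Keeping the list on the model is the design point: once the hooks are registered anonymously, there is no way to reach them again for per-iteration cleanup without retaining a reference.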