fix _post_init_method of zero init ctx (#847)

pull/848/head
ver217 2022-04-24 14:16:50 +08:00 committed by GitHub
parent 2a0a427e04
commit 0f7ed8c192
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state)
def _post_init_method(self, module: torch.nn.Module):
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
"""
The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times.