fix _post_init_method of zero init ctx (#847)

pull/848/head
ver217 2022-04-24 14:16:50 +08:00 committed by GitHub
parent 2a0a427e04
commit 0f7ed8c192
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
torch.set_rng_state(self.cpu_rng_state) torch.set_rng_state(self.cpu_rng_state)
torch.cuda.set_rng_state(self.cuda_rng_state) torch.cuda.set_rng_state(self.cuda_rng_state)
def _post_init_method(self, module: torch.nn.Module): def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
""" """
The function to call at the end of the constructor of each module. The function to call at the end of the constructor of each module.
NOTE() The module may be passed to this function multiple times. NOTE() The module may be passed to this function multiple times.