mirror of https://github.com/hpcaitech/ColossalAI
fix _post_init_method of zero init ctx (#847)
parent
2a0a427e04
commit
0f7ed8c192
|
@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
torch.set_rng_state(self.cpu_rng_state)
|
||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module):
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
The function to call at the end of the constructor of each module.
|
||||
NOTE() The module may be passed to this function multiple times.
|
||||
|
|
Loading…
Reference in New Issue