mirror of https://github.com/hpcaitech/ColossalAI
fix _post_init_method of zero init ctx (#847)
parent
2a0a427e04
commit
0f7ed8c192
|
@ -155,7 +155,7 @@ class ZeroInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
torch.set_rng_state(self.cpu_rng_state)
|
torch.set_rng_state(self.cpu_rng_state)
|
||||||
torch.cuda.set_rng_state(self.cuda_rng_state)
|
torch.cuda.set_rng_state(self.cuda_rng_state)
|
||||||
|
|
||||||
def _post_init_method(self, module: torch.nn.Module):
|
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||||
"""
|
"""
|
||||||
The function to call at the end of the constructor of each module.
|
The function to call at the end of the constructor of each module.
|
||||||
NOTE() The module may be passed to this function multiple times.
|
NOTE() The module may be passed to this function multiple times.
|
||||||
|
|
Loading…
Reference in New Issue