mirror of https://github.com/hpcaitech/ColossalAI
fix bug for mefture (#5299)
parent
f7e3f82a7e
commit
ddf879e2db
|
@ -16,7 +16,10 @@ import torch
|
||||||
|
|
||||||
|
|
||||||
def unwrap(model):
|
def unwrap(model):
|
||||||
return model.unwrap().module
|
if hasattr(model, "module"):
|
||||||
|
return unwrap_model(model.module)
|
||||||
|
else:
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
def neftune_post_forward_hook(module, input, output):
|
def neftune_post_forward_hook(module, input, output):
|
||||||
|
|
Loading…
Reference in New Issue