mirror of https://github.com/hpcaitech/ColossalAI
fix bug for mefture (#5299)
parent
f7e3f82a7e
commit
ddf879e2db
|
@ -16,7 +16,10 @@ import torch
|
|||
|
||||
|
||||
def unwrap(model):
|
||||
return model.unwrap().module
|
||||
if hasattr(model, "module"):
|
||||
return unwrap_model(model.module)
|
||||
else:
|
||||
return model
|
||||
|
||||
|
||||
def neftune_post_forward_hook(module, input, output):
|
||||
|
|
Loading…
Reference in New Issue