From ddf879e2db66bbd76623ac8c0752303f891b2593 Mon Sep 17 00:00:00 2001 From: Desperado-Jia <502205863@qq.com> Date: Mon, 22 Jan 2024 22:17:54 +0800 Subject: [PATCH] fix bug for mefture (#5299) --- .../Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py index 079faaace..9f6c9c1cc 100644 --- a/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py +++ b/applications/Colossal-LLaMA-2/colossal_llama2/utils/neftune_patch.py @@ -16,7 +16,10 @@ import torch def unwrap(model): - return model.unwrap().module + if hasattr(model, "module"): + return unwrap_model(model.module) + else: + return model def neftune_post_forward_hook(module, input, output):