From bc09b95f504ffd9155d5d57eb370dd82cc641a0b Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Sat, 18 Nov 2023 18:41:58 +0800 Subject: [PATCH] [example] fix llama example's loss error when using gemini plugin (#5060) fix llama example --- examples/language/llama2/finetune.py | 1 + 1 file changed, 1 insertion(+) diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 33aa1d33e..f7708b1a3 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -58,6 +58,7 @@ def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = Non def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor