diff --git a/examples/language/llama2/finetune.py b/examples/language/llama2/finetune.py index 33aa1d33e..f7708b1a3 100644 --- a/examples/language/llama2/finetune.py +++ b/examples/language/llama2/finetune.py @@ -58,6 +58,7 @@ def tokenize_batch_for_finetune(batch, tokenizer: Optional[LlamaTokenizer] = Non def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor: dist.all_reduce(tensor, op=dist.ReduceOp.SUM) + tensor = tensor.data tensor.div_(dist.get_world_size()) return tensor