diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py index 9d6849daa..ebdde83b4 100644 --- a/colossalai/zero/gemini/gemini_ddp.py +++ b/colossalai/zero/gemini/gemini_ddp.py @@ -450,6 +450,7 @@ class GeminiDDP(ModelWrapper): chunk_manager.move_chunk(grad_chunk, grads_device[p], force_copy=True) if not (master_weights) or (enable_gradient_accumulation): chunk_manager.move_chunk(chunk, grads_device[p], force_copy=True) + return empty_grad def zero_grad(self, set_to_none: bool = False) -> None: self.module.zero_grad(set_to_none=True)