From 5b1854309a066f058f5a51c8adcbff1e51870c25 Mon Sep 17 00:00:00 2001
From: ver217
Date: Thu, 2 Feb 2023 16:42:38 +0800
Subject: [PATCH] [hotfix] fix zero ddp warmup check (#2545)

---
 colossalai/gemini/gemini_mgr.py         | 4 ++++
 colossalai/nn/parallel/data_parallel.py | 3 ++-
 2 files changed, 6 insertions(+), 1 deletion(-)

diff --git a/colossalai/gemini/gemini_mgr.py b/colossalai/gemini/gemini_mgr.py
index 08fc0cf92..72a5e4a7f 100644
--- a/colossalai/gemini/gemini_mgr.py
+++ b/colossalai/gemini/gemini_mgr.py
@@ -58,6 +58,10 @@ class GeminiManager:
         self._evict_time = 0
         self._comp_cuda_demand_time = 0
 
+    @property
+    def need_warmup(self) -> bool:
+        return self.policy_name in ('auto', 'const')
+
     def is_warmup(self):
         return self._warmup
 
diff --git a/colossalai/nn/parallel/data_parallel.py b/colossalai/nn/parallel/data_parallel.py
index a30416ab9..a313da59b 100644
--- a/colossalai/nn/parallel/data_parallel.py
+++ b/colossalai/nn/parallel/data_parallel.py
@@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP):
         # check whether we are in a inference mode
         grad_flag = torch.is_grad_enabled()
         if not grad_flag:
-            assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
+            assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
+            ), "You should run a completed iteration as your warmup iter"
 
         args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
         self.module.zero_grad(set_to_none=True)
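
The hunk above relaxes the inference-mode assertion so it only fires when the placement policy actually requires a warmup iteration ('auto' or 'const', per the new `need_warmup` property). Below is a minimal, self-contained sketch of that gating logic using stand-in class and function names (not the real ColossalAI API), just to illustrate why a policy such as 'cpu' no longer trips the assertion:

```python
# Illustrative sketch only; FakeGeminiManager and check_inference_allowed are
# hypothetical stand-ins that mirror the logic introduced by this patch.
class FakeGeminiManager:
    def __init__(self, policy_name: str, warmup: bool = True):
        self.policy_name = policy_name
        self._warmup = warmup

    @property
    def need_warmup(self) -> bool:
        # Matches the patched property: only 'auto' and 'const' policies
        # rely on statistics gathered during a warmup iteration.
        return self.policy_name in ('auto', 'const')

    def is_warmup(self) -> bool:
        return self._warmup


def check_inference_allowed(manager: FakeGeminiManager) -> None:
    # Mirrors the updated assertion in ZeroDDP.forward: inference is blocked
    # only if the policy needs a warmup iteration that has not completed yet.
    assert not manager.need_warmup or not manager.is_warmup(), \
        "You should run a completed iteration as your warmup iter"


check_inference_allowed(FakeGeminiManager('cpu'))                  # ok: policy never needs warmup
check_inference_allowed(FakeGeminiManager('auto', warmup=False))   # ok: warmup already completed
# check_inference_allowed(FakeGeminiManager('auto'))               # would raise: warmup still pending
```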