[hotfix] fix zero ddp warmup check (#2545)

pull/2492/head
ver217 2023-02-02 16:42:38 +08:00 committed by GitHub
parent fa3d66feb9
commit 5b1854309a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 6 additions and 1 deletions

View File

@ -58,6 +58,10 @@ class GeminiManager:
self._evict_time = 0
self._comp_cuda_demand_time = 0
@property
def need_warmup(self) -> bool:
return self.policy_name in ('auto', 'const')
def is_warmup(self):
return self._warmup

View File

@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP):
# check whether we are in a inference mode
grad_flag = torch.is_grad_enabled()
if not grad_flag:
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
), "You should run a completed iteration as your warmup iter"
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
self.module.zero_grad(set_to_none=True)