mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix zero ddp warmup check (#2545)
parent
fa3d66feb9
commit
5b1854309a
|
@ -58,6 +58,10 @@ class GeminiManager:
|
|||
self._evict_time = 0
|
||||
self._comp_cuda_demand_time = 0
|
||||
|
||||
@property
|
||||
def need_warmup(self) -> bool:
|
||||
return self.policy_name in ('auto', 'const')
|
||||
|
||||
def is_warmup(self):
|
||||
return self._warmup
|
||||
|
||||
|
|
|
@ -269,7 +269,8 @@ class ZeroDDP(ColoDDP):
|
|||
# check whether we are in an inference mode
|
||||
grad_flag = torch.is_grad_enabled()
|
||||
if not grad_flag:
|
||||
assert not self.gemini_manager.is_warmup(), "You should run a completed iteration as your warmup iter"
|
||||
assert not self.gemini_manager.need_warmup or not self.gemini_manager.is_warmup(
|
||||
), "You should run a completed iteration as your warmup iter"
|
||||
|
||||
args, kwargs = _cast_float(args, torch.half), _cast_float(kwargs, torch.half)
|
||||
self.module.zero_grad(set_to_none=True)
|
||||
|
|
Loading…
Reference in New Issue