fix sync condition (#6000)

pull/6003/head
Tong Li 2024-08-14 11:22:39 +08:00 committed by GitHub
parent ed97d3a5d3
commit ceb1e262e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 4 additions and 2 deletions

View File

@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase):
)
# run with gradients accumulation
if model.require_grad_sync == False or (
isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
if (
model.require_grad_sync == False
or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
or not torch.is_grad_enabled()
):
return outputs