diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index d2933a4af..e5acdb051 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -1326,8 +1326,10 @@ class HybridParallelPlugin(PipelinePluginBase):
         )
 
         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
+            or not torch.is_grad_enabled()
         ):
             return outputs
 
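For context on the new guard: `torch.is_grad_enabled()` returns `False` inside `torch.no_grad()` or `torch.inference_mode()` blocks, so the early return now also fires during evaluation-style pipeline runs, where no gradients exist to synchronize. A minimal sketch of the combined condition follows; the helper name and assertion harness are illustrative, not part of the plugin:

```python
import torch

def should_skip_grad_sync(require_grad_sync: bool) -> bool:
    # Illustrative helper mirroring the patched condition in execute_pipeline:
    # skip the grad-sync path when sync is deferred (gradient accumulation)
    # OR when autograd is globally disabled (e.g. an eval pass).
    return require_grad_sync == False or not torch.is_grad_enabled()

assert should_skip_grad_sync(False)        # accumulation step: defer the sync
assert not should_skip_grad_sync(True)     # normal training step: run the sync
with torch.no_grad():                      # eval pass: nothing to synchronize
    assert should_skip_grad_sync(True)
```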