diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index aad9e70..1bd3243 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -847,9 +847,11 @@ class HybridZeroOptimizer(BaseOptimizer): def broadcast_params(self, disable_overlap=False): handles = [] - assert all( - isinstance(value, list) and not value for value in self._param_bcast_sync_handler._bcast_handles.values() - ) + if self._overlap_sync_param: + assert all( + isinstance(value, list) and not value + for value in self._param_bcast_sync_handler._bcast_handles.values() + ) for group_id in range(self.num_param_groups): for rank in range(self._zero_world_size[group_id]): # The following operations are performed only on the rank to which parameters are assigned.