From 10f01c4e08935b49c930529211d5ec9fced8f270 Mon Sep 17 00:00:00 2001
From: mwiacx <759046501@qq.com>
Date: Mon, 18 Sep 2023 11:16:52 +0800
Subject: [PATCH] fix nan grad norm

---
 internlm/solver/optimizer/hybrid_zero_optim.py | 9 +++------
 1 file changed, 3 insertions(+), 6 deletions(-)

diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index d94bce9..ee42b4a 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -42,6 +42,7 @@
 inf = math.inf
 logger = get_logger(__file__)
 
+@llm_timeout(seconds=30, func_name="_find_tensors_with_target_memory")
 def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: int) -> List[int]:
     tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
     approximate_thresholds = [0.01 * i for i in range(1, 100)]
@@ -174,7 +175,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         )  # 0: sender, 1: receiver
         self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
         self._fp32_flat_proxy_param_of_current_rank = None
-        self._memory_balance_comm_handle = None
 
         compensation_conf = {
             k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -579,14 +579,14 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         if self._enable_memory_balance and self._memory_balance_role == 0:
             flat_proxy_grads = flatten(proxy_gradients)
-            self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
+            dist.send(flat_proxy_grads, self._memory_balance_peer)
             # torch.cuda.synchronize()
         elif self._enable_memory_balance and self._memory_balance_role == 1:
             _shape = self._fp32_flat_proxy_param_of_current_rank.shape
             _device = self._fp32_flat_proxy_param_of_current_rank.device
             flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
-            self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
+            dist.recv(flat_proxy_gradient, self._memory_balance_peer)
             # torch.cuda.synchronize()
             self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
                 dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
@@ -794,9 +794,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them the updated parameters.
         if self.has_params:
-            if self._enable_memory_balance:
-                self._memory_balance_comm_handle.wait()
-
             self.optim.step()
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
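
Note (reviewer sketch, not part of the patch): the NaN grad norm most likely
came from consuming the irecv buffer before the transfer had completed. The
receiver allocates the buffer with torch.empty -- i.e. uninitialized memory,
which can contain NaN bit patterns -- and copies it into .grad immediately
after posting the irecv, while the matching wait() only ran later, inside
step(). Switching to blocking dist.send/dist.recv guarantees the buffer is
fully written before anyone reads it. Below is a minimal two-rank sketch of
the race, assuming the gloo backend and a torchrun launch; the names are
illustrative and do not come from hybrid_zero_optim.py.

import torch
import torch.distributed as dist

def main():
    # torchrun sets MASTER_ADDR/MASTER_PORT, RANK and WORLD_SIZE for us
    dist.init_process_group(backend="gloo")
    rank = dist.get_rank()

    if rank == 0:
        payload = torch.ones(4)
        dist.send(payload, dst=1)        # blocking: safe by construction
    else:
        buf = torch.empty(4)             # uninitialized: may already hold NaN/garbage
        handle = dist.irecv(buf, src=0)  # async: returns before the data arrives
        # BUG (the pre-patch pattern): reading `buf` here, e.g. buf.norm() or
        # buf.to(torch.float32), can observe the uninitialized contents and
        # propagate NaN into the gradient norm.
        handle.wait()                    # only after wait() is `buf` guaranteed valid
        print(buf)                       # tensor([1., 1., 1., 1.])

    dist.destroy_process_group()

if __name__ == "__main__":
    main()

Run it with: torchrun --nproc_per_node=2 repro.py. With the patch applied,
both peers use the blocking calls, so no request handle has to be kept
around, which is why _memory_balance_comm_handle can be deleted outright.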