fix nan grad norm

pull/306/head
mwiacx 2023-09-18 11:16:52 +08:00
parent 8b4d983856
commit 10f01c4e08
1 changed file with 3 additions and 6 deletions

@@ -42,6 +42,7 @@ inf = math.inf
 logger = get_logger(__file__)
+@llm_timeout(seconds=30, func_name="_find_tensors_with_target_memory")
 def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: int) -> List[int]:
     tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
     approximate_thresholds = [0.01 * i for i in range(1, 100)]
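
The added `@llm_timeout(...)` decorator puts a 30-second ceiling on the threshold search so a pathological input cannot stall a rank. As a rough illustration of the pattern (not InternLM's actual `llm_timeout` implementation; the name `timeout_guard` and the SIGALRM mechanism are assumptions of this sketch), such a guard could look like:

import functools
import signal


def timeout_guard(seconds: int, func_name: str = ""):
    """Raise TimeoutError if the wrapped call runs longer than `seconds`.

    Sketch only: SIGALRM-based, so it works on Unix main threads; this is an
    assumed stand-in, not the real llm_timeout.
    """

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def _handler(signum, frame):
                raise TimeoutError(f"{func_name or func.__name__} exceeded {seconds}s")

            old_handler = signal.signal(signal.SIGALRM, _handler)
            signal.alarm(seconds)  # arm the timer
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # cancel any pending alarm
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator

Applied as `@timeout_guard(seconds=30, func_name="_find_tensors_with_target_memory")`, the search either returns within the budget or fails loudly instead of hanging.
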
@@ -174,7 +175,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         )  # 0: sender, 1: receiver
         self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
         self._fp32_flat_proxy_param_of_current_rank = None
-        self._memory_balance_comm_handle = None
         compensation_conf = {
             k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -579,14 +579,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._enable_memory_balance and self._memory_balance_role == 0:
             flat_proxy_grads = flatten(proxy_gradinets)
-            self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
+            dist.send(flat_proxy_grads, self._memory_balance_peer)
             # torch.cuda.synchronize()
         elif self._enable_memory_balance and self._memory_balance_role == 1:
             _shape = self._fp32_flat_proxy_param_of_current_rank.shape
             _device = self._fp32_flat_proxy_param_of_current_rank.device
             flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
-            self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
+            dist.recv(flat_proxy_gradient, self._memory_balance_peer)
             # torch.cuda.synchronize()
             self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
                 dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
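
This hunk is the substance of the fix: the proxy-gradient hand-off between the memory-balance peers moves from non-blocking `dist.isend`/`dist.irecv` to blocking `dist.send`/`dist.recv`. With the non-blocking calls, the receive buffer is only defined after `handle.wait()`, so any code that touches the gradient earlier (a gradient-norm reduction, for example) can read uninitialized memory and report NaN. A self-contained sketch of the two patterns, where the process-group setup and the helper names (`exchange_async`, `exchange_blocking`, `peer`) are assumptions of this illustration:

import torch
import torch.distributed as dist

# Assumes dist.init_process_group(...) has been called and `peer` is the rank
# on the other side of the memory-balance pair.


def exchange_async(flat_grads: torch.Tensor, peer: int, is_sender: bool) -> torch.Tensor:
    """Non-blocking pattern (before this commit)."""
    if is_sender:
        handle = dist.isend(flat_grads, peer)
        buf = flat_grads
    else:
        buf = torch.empty_like(flat_grads)
        handle = dist.irecv(buf, peer)
    # Without this wait, `buf` on the receiver may still be uninitialized,
    # and anything computed from it (e.g. a grad norm) can come out NaN.
    handle.wait()
    return buf


def exchange_blocking(flat_grads: torch.Tensor, peer: int, is_sender: bool) -> torch.Tensor:
    """Blocking pattern (after this commit): returns only once the transfer is done."""
    if is_sender:
        dist.send(flat_grads, peer)
        return flat_grads
    buf = torch.empty_like(flat_grads)
    dist.recv(buf, peer)
    return buf

The blocking form gives up any overlap of communication with the surrounding step logic, but it makes the ordering trivially correct, which is the trade-off this commit takes.
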
@@ -794,9 +794,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them the updated parameters.
         if self.has_params:
-            if self._enable_memory_balance:
-                self._memory_balance_comm_handle.wait()
             self.optim.step()
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
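
Because the receive is now blocking, the proxy gradient is guaranteed to be complete before `optim.step()` runs, so the handle bookkeeping removed here no longer has anything to wait on. An alternative fix would have been to keep `isend`/`irecv` and call `wait()` before the first use of the gradient (including any norm computation), preserving communication overlap at the cost of more careful handle tracking; the commit opts for the simpler blocking exchange.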