fix nan grad norm
parent 8b4d983856
commit 10f01c4e08
@@ -42,6 +42,7 @@ inf = math.inf
 logger = get_logger(__file__)
 
 
+@llm_timeout(seconds=30, func_name="_find_tensors_with_target_memory")
 def _find_tensors_with_target_memory(tensors: List[torch.Tensor], target: int) -> List[int]:
     tensor_mems = [tensor.nelement() * tensor.element_size() for tensor in tensors]
     approximate_thresholds = [0.01 * i for i in range(1, 100)]
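Besides the communication change below, the commit wraps _find_tensors_with_target_memory in the @llm_timeout decorator so that a slow or stuck threshold search is aborted after 30 seconds instead of hanging the job. The snippet below only illustrates how such a decorator is commonly built (SIGALRM-based, Unix-only); llm_timeout_sketch and its error message are invented for the example and are not InternLM's actual llm_timeout implementation.

import functools
import signal


def llm_timeout_sketch(seconds: int = 30, func_name: str = ""):
    """Abort the wrapped call with TimeoutError if it runs longer than `seconds`."""

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def _handler(signum, frame):
                raise TimeoutError(f"{func_name or func.__name__} exceeded {seconds}s")

            old_handler = signal.signal(signal.SIGALRM, _handler)
            signal.alarm(seconds)          # start the countdown
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)            # cancel the pending alarm
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator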
@@ -174,7 +175,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         ) # 0: sender, 1: receiver
         self._memory_balance_peer = gpc.get_ranks_in_group(ParallelMode.PIPELINE)[_peer_local_rank]
         self._fp32_flat_proxy_param_of_current_rank = None
-        self._memory_balance_comm_handle = None
 
         compensation_conf = {
             k if k > 0 else gpc.get_world_size(ParallelMode.PIPELINE) + k: v
@@ -579,14 +579,14 @@ class HybridZeroOptimizer(BaseOptimizer):
         if self._enable_memory_balance and self._memory_balance_role == 0:
             flat_proxy_grads = flatten(proxy_gradinets)
 
-            self._memory_balance_comm_handle = dist.isend(flat_proxy_grads, self._memory_balance_peer)
+            dist.send(flat_proxy_grads, self._memory_balance_peer)
             # torch.cuda.synchronize()
         elif self._enable_memory_balance and self._enable_memory_balance == 1:
             _shape = self._fp32_flat_proxy_param_of_current_rank.shape
             _device = self._fp32_flat_proxy_param_of_current_rank.device
             flat_proxy_gradient = torch.empty(_shape, device=_device, dtype=self._dtype)
 
-            self._memory_balance_comm_handle = dist.irecv(flat_proxy_gradient, self._memory_balance_peer)
+            dist.recv(flat_proxy_gradient, self._memory_balance_peer)
             # torch.cuda.synchronize()
             self._fp32_flat_proxy_param_of_current_rank.grad = flat_proxy_gradient.to(
                 dtype=self._fp32_flat_proxy_param_of_current_rank.dtype
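This hunk is the heart of the fix: the flattened proxy gradients are now exchanged with blocking dist.send / dist.recv instead of non-blocking dist.isend / dist.irecv whose handles were stored for a later wait(). With the blocking calls the receive buffer is guaranteed to be fully populated before it is cast to fp32 and attached as .grad, which is presumably how the NaN gradient norms in the commit title arose when the buffer was consumed too early. The following is a minimal, self-contained sketch of the same point-to-point pattern; the 2-rank gloo group, tensor size, and port are assumptions made up for the example, not values taken from InternLM.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29531"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        # Sender role (role 0 in the diff): ship the flattened gradients.
        flat_grads = torch.randn(16)      # stand-in for flatten(proxy_gradinets)
        dist.send(flat_grads, dst=1)      # blocking: returns once the data is handed off
    else:
        # Receiver role (role 1 in the diff): fill a preallocated flat buffer.
        buf = torch.empty(16)
        dist.recv(buf, src=0)             # blocking: buf is fully written on return
        print(f"rank {rank} received grad norm {buf.norm():.4f}")  # safe to use immediately

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)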
@@ -794,9 +794,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # For those ranks that are not assigned parameters, we just wait for other ranks
         # to send them updated their own parameters.
         if self.has_params:
-            if self._enable_memory_balance:
-                self._memory_balance_comm_handle.wait()
-
             self.optim.step()
             # release the fp32 grad
             release_param_grad(self._fp32_flat_param_groups_of_current_rank.values())
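The final hunk removes the receiver-side handle.wait() that the non-blocking pattern required before self.optim.step(); with a blocking recv there is no handle left to wait on, so both the stored handle (earlier hunks) and this wait disappear. For contrast, this is roughly what the old pattern looked like on the receiver side; again a made-up 2-rank gloo example, not InternLM code.

import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def worker(rank: int, world_size: int):
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "29532"
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    if rank == 0:
        dist.send(torch.randn(16), dst=1)
    else:
        buf = torch.empty(16)
        handle = dist.irecv(buf, src=0)   # returns immediately; buf is not yet valid
        # ... other work could overlap here ...
        handle.wait()                     # required before buf may be read, e.g. before optim.step()
        print(f"rank {rank} norm after wait: {buf.norm():.4f}")

    dist.destroy_process_group()


if __name__ == "__main__":
    mp.spawn(worker, args=(2,), nprocs=2)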