[hotfix]different overflow status lead to communication stuck. (#1175)

* [CLI] add CLI launcher

* Revert "[CLI] add CLI launcher"

This reverts commit df7e6506d4.

* [hotfix]fix some bugs caused by refactored schedule.

* [hotfix]different overflow statu llead to communication stuck.
pull/1176/head
YuliangLiu0306 2022-06-27 09:53:57 +08:00 committed by GitHub
parent aa7bef73d4
commit e27645376d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 35 additions and 16 deletions

View File

@ -258,24 +258,25 @@ class FP16Optimizer(Optimizer):
overflow = self._check_overflow()
self._grad_scaler.update(overflow)
if overflow:
self.zero_grad()
return False, None
# Clip the main gradients.
grad_norm = None
if self._clip_grad_max_norm > 0.0:
grad_norm = self.clip_grad_norm(self._clip_grad_max_norm)
# Step the optimizer.
self._optimizer.step()
if not overflow:
# Step the optimizer.
self._optimizer.step()
# Update params from main params.
self._update_fp16_param_from_fp32_param()
# Update params from main params.
self._update_fp16_param_from_fp32_param()
# Successful update.
return True, grad_norm
# Successful update.
return True, grad_norm
else:
return False, None
def backward(self, loss):
"""Execute backward pass.

View File

@ -57,10 +57,14 @@ def process_object_to_send(object_send, scatter_gather_tensors):
if send_split:
object_send = split_tensor_into_1d_equal_chunks(object_send)
return object_send
object_send_list = []
for tensor_send in object_send:
send_split = _get_tensor_shape(tensor_send.shape, scatter_gather_tensors)[1]
if send_split:
tensor_send = split_tensor_into_1d_equal_chunks(tensor_send)
object_send_list.append(split_tensor_into_1d_equal_chunks(tensor_send))
object_send = tuple(object_send_list)
return object_send
@ -161,15 +165,17 @@ def _communicate(object_send_next: Union[torch.Tensor, List[torch.Tensor]] = Non
if isinstance(tensor_recv_prev, torch.Tensor):
tensor_recv_prev = gather_split_1d_tensor(tensor_recv_prev).view(recv_prev_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_prev, recv_prev_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
for index in range(len(tensor_recv_prev)):
tensor_recv_prev[index] = gather_split_1d_tensor(tensor_recv_prev[index]).view(
recv_prev_shape[index]).requires_grad_()
if recv_next and recv_next_split:
if isinstance(tensor_recv_next, torch.Tensor):
tensor_recv_next = gather_split_1d_tensor(tensor_recv_next).view(recv_next_shape).requires_grad_()
else:
for tensor_recv, tensor_shape in zip(tensor_recv_next, recv_next_shape):
tensor_recv = gather_split_1d_tensor(tensor_recv).view(tensor_shape).requires_grad_()
for index in range(len(tensor_recv_next)):
tensor_recv_next[index] = gather_split_1d_tensor(tensor_recv_next[index]).view(
recv_next_shape[index]).requires_grad_()
return tensor_recv_prev, tensor_recv_next

View File

@ -151,6 +151,14 @@ def _move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.T
return norm
def _get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
if isinstance(norm, float):
norm = torch.Tensor([norm])
if move_to_cuda:
norm = norm.to(torch.cuda.current_device())
return norm
# ======== Gradient Clipping =========
@ -192,14 +200,15 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
params.append(param)
if len(params) == 0:
return 0.0
enable_cuda_kernels = False
else:
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
# Norm parameters.
max_norm = float(max_norm)
norm_type = float(norm_type)
# Parameters can be on CPU or CUDA
# If parameters are on CPU, disable CUDA kernerls
enable_cuda_kernels = params[0].grad.device.type == 'cuda'
# Calculate norm.
if norm_type == inf:
@ -238,7 +247,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
tensor_parallel_norm = _calc_lp(tensor_parallel_grads, norm_type)
no_tensor_parallel_norm = _calc_lp(no_tensor_parallel_grads, norm_type)
zero_sharded_norm = _calc_lp(zero_sharded_grads, norm_type)
# If norm is type of float, then we convert them into torch.Tensor.
tensor_parallel_norm = _get_tensor_norm(tensor_parallel_norm, enable_cuda_kernels)
no_tensor_parallel_norm = _get_tensor_norm(no_tensor_parallel_norm, enable_cuda_kernels)
zero_sharded_norm = _get_tensor_norm(zero_sharded_norm, enable_cuda_kernels)
# If grads are on CPU, the norms is also on CPU. Cast them to CUDA tensors
if not enable_cuda_kernels:
tensor_parallel_norm = _move_norm_to_cuda(tensor_parallel_norm)