fix token grad norm with tp (#547)

pull/550/head
jiaopenglong 2023-12-18 18:33:28 +08:00 committed by GitHub
parent 513ebb9c3a
commit de53b17506
1 changed file with 7 additions and 3 deletions

@@ -242,7 +242,7 @@ def reduce_grads(gradients, parameters, fine_grained=False, only_output=False):
         elif only_output:
             param_name = p.param_name if hasattr(p, "param_name") else "unknown-padding"
             if (
-                gpc.config.model["vocab_size"] == g.shape[0]
+                gpc.config.model["vocab_size"] == g.shape[0] * gpc.get_world_size(ParallelMode.TENSOR)
                 and gpc.config.model["hidden_size"] == g.shape[1]
                 and "embedding" not in param_name.lower()
             ):
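
Note on this hunk: under tensor parallelism the output (head) weight is sharded along the vocab dimension, so each rank's gradient holds only vocab_size / tp_world_size rows; multiplying the local row count by the TP world size restores the comparison against the full vocab size. A minimal sketch of that check, assuming vocab-parallel sharding; the function name and plain-int arguments are illustrative, not the repository's API:

```python
# Minimal sketch (assumption: vocab-parallel sharding of the head weight);
# names below are illustrative, not the repository's API.
def looks_like_output_weight(grad_shape, vocab_size, hidden_size, tp_world_size, param_name=""):
    local_rows, cols = grad_shape
    return (
        vocab_size == local_rows * tp_world_size  # recover the full vocab dimension
        and hidden_size == cols
        and "embedding" not in param_name.lower()
    )

# Example: vocab 32000 sharded over 4 TP ranks -> local head grad shape (8000, 4096).
assert looks_like_output_weight((8000, 4096), 32000, 4096, 4, "head.weight")
```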
@@ -393,9 +393,13 @@ def compute_vocab_grad_norm(
     if param_grads:
         for grad in param_grads:
             # get grad norm of each vocab
-            for i in range(vocab_size):
+            vocab_slice_size = grad.shape[0]
+            local_tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
+            for i in range(vocab_slice_size):
                 cur_vocab_grad_norm = get_norm([grad[i, :]], norm_type, enable_cuda_kernels)[0]
-                vocab_grad_norm[i] += get_tensor_norm(cur_vocab_grad_norm, move_to_cuda=True)
+                vocab_grad_norm[i + vocab_slice_size * local_tp_rank] += get_tensor_norm(
+                    cur_vocab_grad_norm, move_to_cuda=True
+                )
 
     if last_stage is False:
         return vocab_grad_norm
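
Note on this hunk: each TP rank holds only a contiguous slice of the vocab dimension, so the per-vocab norm must be accumulated at the global vocab index i + vocab_slice_size * local_tp_rank rather than the local row index i. A minimal sketch of that mapping; the standalone function and example sizes are illustrative assumptions, not taken from the repository:

```python
# Minimal sketch of the local-row -> global-vocab-id mapping used above;
# the function name and example sizes are illustrative assumptions.
def global_vocab_index(local_row: int, tp_rank: int, vocab_slice_size: int) -> int:
    """Each TP rank holds a contiguous vocab slice, so the global index is the
    local row offset plus the rank's slice offset."""
    return local_row + vocab_slice_size * tp_rank

# Example: vocab of 8 split across 2 TP ranks (slice size 4):
# row 1 on rank 1 accumulates into vocab_grad_norm[5].
assert global_vocab_index(1, tp_rank=1, vocab_slice_size=4) == 5
```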