Feat/optimizer (#194)

* feat(optimizer.py): reduce memory footprint and avoid _check_overflow call

* feat(optimizer.py): overlap compute norm with allreduce

* update var and function name

* update function compute norm (#197)

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>

* feat(optimizer/hybrid_zero_optim.py): overlap gradients last bucket allreduce and compute norm (#196)

* support gradients allreduce and compute norm overlap

* fix para set error

* remove timer cal_norm for testing

* feat(optimizer/hybrid_zero_optim.py): support group global norm

* format(lint): fix lint error

* feat(optimizer/store.py): update code based on comment

---------

Co-authored-by: ChenQiaoling00 <qiaoling_chen@u.nus.edu>
Co-authored-by: huangting4201 <1538303371@qq.com>
pull/199/head
Sun Peng 2023-08-15 18:55:10 +08:00 committed by GitHub
parent 4e8bd39d8f
commit ef851d16c6
7 changed files with 168 additions and 92 deletions

.gitignore (vendored)

@@ -115,6 +115,7 @@ venv.bak/
 *.pkl
 *.pkl.json
 *.log.json
+*.trace.json
 docs/modelzoo_statistics.md
 mmdet/.mim
 work_dirs/

[RMSNorm module]

@@ -7,23 +7,25 @@ from torch.nn import init
 from torch.nn.parameter import Parameter


-def manual_rms_norm(input, normalized_shape, weight, eps):
+def manual_rms_norm(my_input, normalized_shape, weight, eps):
     # layer norm should always be calculated in float32
     dims = tuple(i for i in range(-1, -len(normalized_shape) - 1, -1))
-    variance = input.to(torch.float32).pow(2).mean(dims, keepdim=True)
-    input = input * torch.rsqrt(variance + eps)
+    variance = my_input.to(torch.float32).pow(2).mean(dims, keepdim=True)
+    my_input = my_input * torch.rsqrt(variance + eps)

     if weight is None:
-        return input
+        return my_input

     # convert into half-precision if necessary
     if weight.dtype in [torch.float16, torch.bfloat16]:
-        input = input.to(weight.dtype)
+        my_input = my_input.to(weight.dtype)

-    return weight * input
+    return weight * my_input


 class RMSNormTorch(torch.nn.Module):
+    """A custom PyTorch module for RMS normalization."""
+
     def __init__(self, normalized_shape, eps=1e-5):
         super().__init__()

@@ -34,8 +36,8 @@ class RMSNormTorch(torch.nn.Module):
         self.weight = Parameter(torch.empty(*normalized_shape))
         self.reset_parameters()

-    def forward(self, input: torch.Tensor):
-        return manual_rms_norm(input, self.normalized_shape, self.weight, self.eps)
+    def forward(self, _input: torch.Tensor):
+        return manual_rms_norm(_input, self.normalized_shape, self.weight, self.eps)

     def reset_parameters(self):
         init.ones_(self.weight)
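For context, a minimal standalone sketch of what the renamed manual_rms_norm computes: square the activations in float32, take the mean, scale by the reciprocal root, then cast back for the learned weight. The helper below (rms_norm_reference) is illustrative only, normalizes just the last dimension for brevity, and is not the repository function.

import torch

def rms_norm_reference(x, weight, eps=1e-5):
    # compute the statistics in float32, as the patched manual_rms_norm does
    variance = x.to(torch.float32).pow(2).mean(dim=-1, keepdim=True)
    y = x * torch.rsqrt(variance + eps)
    # cast back to half precision before applying the learned scale
    if weight.dtype in (torch.float16, torch.bfloat16):
        y = y.to(weight.dtype)
    return weight * y

x = torch.randn(4, 8, dtype=torch.bfloat16)
w = torch.ones(8, dtype=torch.bfloat16)
print(rms_norm_reference(x, w).shape)  # torch.Size([4, 8])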

[model parallel linear utilities]

@@ -88,15 +88,17 @@ def gather_forward_split_backward(input_, parallel_mode, dim):
     return _GatherForwardSplitBackward.apply(input_, parallel_mode, dim)


-def linear_bias_wgrad_torch(input, grad_output, has_d_bias):
-    assert input.dtype == grad_output.dtype
-    grad_weight = torch.matmul(grad_output.t(), input)
+def linear_bias_wgrad_torch(my_input, grad_output, has_d_bias):
+    assert my_input.dtype == grad_output.dtype
+    grad_weight = torch.matmul(grad_output.t(), my_input)
     grad_bias = grad_output.sum(dim=0) if has_d_bias else None
     return grad_weight, grad_bias


 # adpated from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/ops/fused_dense.py
 class FusedDenseFuncTorch(FusedDenseFunc):
+    """A custom PyTorch module extending FusedDenseFunc."""
+
     @staticmethod
     @custom_bwd
     def backward(ctx, grad_output, *args):

@@ -173,8 +175,8 @@ class _SplitForwardGatherBackward(torch.autograd.Function):
     """

     @staticmethod
-    def symbolic(graph, input_):
-        return _split(input_)
+    def symbolic(input_):
+        return _split(input_, parallel_mode=None)

     @staticmethod
     def forward(ctx, input_, parallel_mode, dim):
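The renamed linear_bias_wgrad_torch computes a linear layer's weight and bias gradients by hand (dW = dY^T X, db = sum of dY over the batch dimension). A self-contained sanity check against autograd; linear_bias_wgrad_reference below is a rewrite for illustration, not the repository function.

import torch

def linear_bias_wgrad_reference(x, grad_output, has_d_bias):
    # same math as the renamed helper
    assert x.dtype == grad_output.dtype
    grad_weight = torch.matmul(grad_output.t(), x)
    grad_bias = grad_output.sum(dim=0) if has_d_bias else None
    return grad_weight, grad_bias

x = torch.randn(16, 32)
w = torch.randn(8, 32, requires_grad=True)
b = torch.zeros(8, requires_grad=True)
y = torch.nn.functional.linear(x, w, b)
y.backward(torch.ones_like(y))

dw, db = linear_bias_wgrad_reference(x, torch.ones_like(y), has_d_bias=True)
print(torch.allclose(dw, w.grad), torch.allclose(db, b.grad))  # True True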

optimizer/hybrid_zero_optim.py

@@ -178,6 +178,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 if len(params) != 0:
                     self._param_store.add_fp16_param_list_by_rank_group(rank, group_id, params)
                     for param in params:
+                        setattr(param, "group_id", group_id)
                         self._param_store.set_param_to_rank(param, rank)

             # move to cpu to make room to create the flat tensor

@@ -317,7 +318,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # if full, will reduce the grads already in the bucket
         # after reduction, the bucket will be empty
         if self._bucket_store.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
-            self._reduce_grads_stored_in_bucket(reduce_rank)
+            self._reduce_grads_stored_in_bucket(reduce_rank, last_bucket=False)

         # the param must not be reduced to ensure correctness
         is_param_reduced = self._param_store.is_param_reduced(param)

@@ -335,7 +336,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         self._bucket_store.add_grad(param.grad, reduce_rank)
         self._bucket_store.add_param(param, reduce_rank)

-    def _reduce_grads_stored_in_bucket(self, reduce_rank=None):
+    def _reduce_grads_stored_in_bucket(self, reduce_rank=None, last_bucket=False):
         # reduce grads
         self._reduce_grads_by_rank(
             reduce_rank=reduce_rank,
@@ -343,14 +344,6 @@ class HybridZeroOptimizer(BaseOptimizer):
             bucket_size=self._bucket_store.num_elements_in_bucket(reduce_rank),
         )

-        # use communication stream if overlapping
-        # communication with computation
-        if self._overlap_communication:
-            stream = self._comm_stream
-        else:
-            stream = torch.cuda.current_stream()
-
-        with torch.cuda.stream(stream):
         params_in_bucket = self._bucket_store.get_param(reduce_rank=reduce_rank)
         for param in params_in_bucket:

@@ -368,6 +361,11 @@ class HybridZeroOptimizer(BaseOptimizer):
             # update the flag
             self._param_store.set_param_reduction_state(param, True)

+            if self._param_store.belongs_to_current_rank(param):
+                self._param_store.add_reduced_param_for_compute_norm(param, last_bucket)
+            else:
+                self._param_store.add_previous_reduced_param(param)
+
         self._bucket_store.reset_by_rank(reduce_rank)

     def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
@@ -385,9 +383,9 @@ class HybridZeroOptimizer(BaseOptimizer):
     def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
         if self._overlap_communication:
-            torch.cuda.synchronize()
-            self._param_store.clear_grads_of_previous_reduced_params()
             stream = self._comm_stream
+            stream.synchronize()
+            self._param_store.clear_grads_of_previous_reduced_params()
         else:
             stream = torch.cuda.current_stream()

@@ -421,6 +419,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         reduction_states = self._param_store.get_param_reduction_states()
         for tensor, _ in reduction_states.items():
             reduction_states[tensor] = False
+        self._param_store.reset_reduced_data_for_compute_norm()

         # accumulate gradient
         avg_gradients = self._grad_store._averaged_gradients
@@ -469,6 +468,30 @@ class HybridZeroOptimizer(BaseOptimizer):
         # Gradients may not be fully synchronized here.

+    def _compute_norm_with_stage(
+        self,
+        group_id: int = 0,
+        last_bucket: bool = False,
+        last_stage: bool = False,
+        previous_norm=None,
+    ):
+        # compute norm for gradients that have been reduced
+        params, grads = self._param_store.get_reduced_param_for_compute_norm(group_id=group_id, last_bucket=last_bucket)
+        if len(params) == 0:
+            grads = [self.padding_grad]
+            params = [self.padding_tensor]
+
+        if self._clip_grad_norm > 0:
+            # this norm is before scaling, it will be very large
+            norm = compute_norm(
+                gradients=grads,
+                parameters=params,
+                last_stage=last_stage,
+                previous_norm=previous_norm,
+            )
+
+        return norm
+
     def step(self, closure=None):
         """Performs a single optimization step.

@@ -480,7 +503,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         """
         assert closure is None, "closure is not supported by step()"

-        timer("sync_grad").start()
         # if not overlapping communication (no reduction hook is attached)
         # we need to manually reduce these gradients
         if not self._overlap_communication:
@@ -490,54 +512,49 @@ class HybridZeroOptimizer(BaseOptimizer):
                     self._store_and_try_reduce_grads_by_bucket(param)

         # we need to reduce the gradients left in the communication bucket
-        self._reduce_grads_stored_in_bucket()
+        self._reduce_grads_stored_in_bucket(reduce_rank=None, last_bucket=True)
+
+        # compute norm for gradients in the before bucket
+        groups_norms = []
+        for group_id in range(self.num_param_groups):
+            groups_norms.append(self._compute_norm_with_stage(group_id=group_id))

         # clear reduced grads
         if self._overlap_communication:
-            torch.cuda.synchronize()
+            # grads in the last bucket is reduced
+            self._comm_stream.synchronize()
             self._param_store.clear_grads_of_previous_reduced_params()

+        # compute norm for gradients in the last bucket
+        total_norms = []
+        for group_id in range(self.num_param_groups):
+            total_norms.append(
+                self._compute_norm_with_stage(
+                    group_id=group_id, last_bucket=True, last_stage=True, previous_norm=groups_norms[group_id]
+                )
+            )
+
+        timer("sync_grad").start()
         self._sync_grad()
         timer("sync_grad").stop()

-        return self._step(closure=closure)
+        return self._step(closure=closure, norms=total_norms)

-    def _step(self, closure=None):
+    def _step(self, closure=None, norms=None):
         assert closure is None, "closure is not supported by step()"

         # check for overflow
-        found_inf = self._check_overflow()
+        found_inf = False
+        # if there is INF values in grades, compute_norm func would also returns -1
+        # thus, we try to avoid call _check_overflow here
+        # found_inf = self._check_overflow()

         # Because you may encounter inf when computing norm
-        timer("cal_norm").start()
-        norm_groups = []
-        for group_id in range(self.num_param_groups):
-            # compute norm
-            if self._zero_local_rank not in self.param_group_no_params_ranks[group_id]:
-                gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
-                parameters = self._param_store.get_fp16_params_by_rank_group(
-                    group_id=group_id, rank=self._zero_local_rank
-                )
-            else:
-                # in order to prevent collection communication from hanging,
-                # we need to involve rank that are not assigned parameters in compute_norm(),
-                # so we give them a fp16 vector of 0 values.
-                gradients = [self.padding_grad]
-                parameters = [self.padding_tensor]
-            if self._clip_grad_norm > 0:
-                # this norm is before scaling, it will be very large
-                norm_group = compute_norm(
-                    gradients=gradients,
-                    parameters=parameters,
-                )
-                if norm_group == -1:
-                    timer("cal_norm").stop()
-                    found_inf = True
-                    break
-                norm_groups.append(norm_group)
+        if -1 in norms:
+            found_inf = True

         loss_scale = float(self.loss_scale.item())  # backup
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             self.grad_scaler.update(found_inf)

         # update loss scale if overflow occurs
         if found_inf:
@@ -550,7 +567,6 @@ class HybridZeroOptimizer(BaseOptimizer):
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
-        global_norm = 0
         for group_id in range(self.num_param_groups):
             # compute norm
             # The following operations are performed only on the rank to which parameters are assigned.

@@ -559,13 +575,14 @@ class HybridZeroOptimizer(BaseOptimizer):
                 # create flat gradient for the flat fp32 params
                 gradients = self._grad_store.get_averaged_gradients_by_group(group_id)
+                with torch.no_grad():
                     flat_fp16_avg_grads = flatten(gradients)
                 self._grad_store.reset_average_gradients_by_group(group_id)
-                del gradients  # release cuda memory
+                gradients = None  # release cuda memory

                 dtype = self._fp32_flat_param_groups_of_current_rank[group_id].dtype
                 flat_fp32_avg_grads = flat_fp16_avg_grads.to(dtype)
-                del flat_fp16_avg_grads  # release cuda memory
+                flat_fp16_avg_grads = None  # release cuda memory

                 param_shape = self._fp32_flat_param_groups_of_current_rank[group_id].shape
                 assert (
@@ -578,15 +595,16 @@ class HybridZeroOptimizer(BaseOptimizer):
         # unscale and clip grads
         # get the global norm
+        global_norm_groups = []
         if self._clip_grad_norm > 0:
-            global_norm = sum(norm_groups) ** 0.5
+            for norm in norms:
+                global_norm_groups.append(norm**0.5)

         # the following operations are performed only on the rank to which parameters are assigned.
-        if not gpc.config.model.dtype is torch.float32:
+        if gpc.config.model.dtype is not torch.float32:
             if len(single_grad_partition_groups) != 0:
-                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm, loss_scale)
+                self._unscale_and_clip_grads(single_grad_partition_groups, global_norm_groups, loss_scale)

-        timer("cal_norm").stop()
         # update the parameters
         timer("step").start()

@@ -611,7 +629,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         timer("step").stop()

         # update gradients may not be needed here, because the sync_params function is used in initialization,
         # so synchronization is maintained
-        return True, global_norm / loss_scale
+        return True, [global_norm / loss_scale for global_norm in global_norm_groups]

     def broadcast_params(self, overlap=False):
         handles = []

@@ -655,18 +673,20 @@ class HybridZeroOptimizer(BaseOptimizer):
         return self._found_overflow.item() > 0

-    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm, loss_scale):
+    def _unscale_and_clip_grads(self, grad_groups_flat, total_norm_groups, loss_scale):
         # compute combined scale factor for this group
-        combined_scale = loss_scale
+        combined_scale_groups = []

         if self._clip_grad_norm > 0.0:
             # norm is in fact norm*scale
-            clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
-            if clip > 1.0:
-                combined_scale = clip * loss_scale
+            for group_id, total_norm in enumerate(total_norm_groups):
+                combined_scale_groups.append(loss_scale)
+                clip = ((total_norm / loss_scale) + 1e-6) / self._clip_grad_norm
+                if clip > 1.0:
+                    combined_scale_groups[group_id] = clip * loss_scale

-        for grad in grad_groups_flat:
-            grad.data.mul_(1.0 / combined_scale)
+        for group_id, grad in enumerate(grad_groups_flat):
+            grad.data.mul_(1.0 / combined_scale_groups[group_id])

     def clip_grad_norm(self, model, max_norm):
         # will conduct in the step()
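Taken together, the hunks above split the gradient-norm computation into two stages so it can overlap with the last bucket's all-reduce: per-group partial norms are computed from buckets that have already been reduced, and the final bucket is folded in afterwards before the root is taken. A minimal single-process sketch of that flow; the helper names below are made up for illustration and do not exist in the repository.

import torch

def partial_sq_norm(grads):
    # stage 1: sum of squared L2 norms for the grads whose all-reduce already finished
    return sum(g.float().norm(2) ** 2 for g in grads)

def finish_norm(previous, last_bucket_grads):
    # stage 2: fold in the last bucket; the real optimizer additionally
    # all-reduces this sum across ranks inside compute_norm before the sqrt
    return (previous + partial_sq_norm(last_bucket_grads)) ** 0.5

early_grads = [torch.randn(10), torch.randn(5)]  # reduced bucket by bucket
last_grads = [torch.randn(3)]                    # grads of the final bucket
previous = partial_sq_norm(early_grads)          # overlaps with the last all-reduce
print(finish_norm(previous, last_grads))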

optimizer/store.py

@@ -152,6 +152,11 @@ class ParameterStore(BaseStore):
         self._is_param_reduced = dict()
         self._reduced_param = []

+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def set_param_to_rank(self, tensor: Tensor, rank: int) -> None:
         """
         Set the mapping between parameter to rank, each parameter should be owned by a rank.

@@ -223,6 +228,35 @@
     def add_previous_reduced_param(self, tensor):
         self._reduced_param.append(tensor)

+    def add_reduced_param_for_compute_norm(self, param, last_bucket=False):
+        group_id = getattr(param, "group_id")
+        if last_bucket:
+            if group_id not in self._last_bucket_reduced_param:
+                self._last_bucket_reduced_param[group_id] = []
+                self._last_bucket_reduced_grad[group_id] = []
+
+            self._last_bucket_reduced_param[group_id].append(param)
+            self._last_bucket_reduced_grad[group_id].append(param.grad)
+        else:
+            if group_id not in self._former_bucket_reduced_param:
+                self._former_bucket_reduced_param[group_id] = []
+                self._former_bucket_reduced_grad[group_id] = []
+
+            self._former_bucket_reduced_param[group_id].append(param)
+            self._former_bucket_reduced_grad[group_id].append(param.grad)
+
+    def get_reduced_param_for_compute_norm(self, group_id=0, last_bucket=False):
+        if not last_bucket:
+            return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
+        else:
+            return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
+
+    def reset_reduced_data_for_compute_norm(self):
+        self._former_bucket_reduced_param = {}
+        self._last_bucket_reduced_param = {}
+        self._former_bucket_reduced_grad = {}
+        self._last_bucket_reduced_grad = {}
+
     def clear_grads_of_previous_reduced_params(self):
         if len(self._reduced_param) > 0:
             for param in self._reduced_param:
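The new ParameterStore fields are per-group lists keyed by the group_id attribute that the optimizer attaches to each parameter, split into "former buckets" and "last bucket" so _compute_norm_with_stage can fetch exactly the gradients whose reduction has finished. A single-process sketch of the same bookkeeping with stand-in module-level dicts (not the real class):

from collections import defaultdict

import torch

# stand-in containers, filed per (bucket kind, group_id)
former_params, former_grads = defaultdict(list), defaultdict(list)
last_params, last_grads = defaultdict(list), defaultdict(list)

def add_reduced_param_for_compute_norm(param, last_bucket=False):
    # use the group_id tag set via setattr(param, "group_id", group_id)
    group_id = getattr(param, "group_id")
    params, grads = (last_params, last_grads) if last_bucket else (former_params, former_grads)
    params[group_id].append(param)
    grads[group_id].append(param.grad)

p = torch.nn.Parameter(torch.randn(4))
p.grad = torch.ones(4)
p.group_id = 0  # tagged by the optimizer at construction time
add_reduced_param_for_compute_norm(p, last_bucket=False)
print(len(former_params[0]), len(last_params[0]))  # 1 0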

[optimizer norm utilities]

@@ -21,6 +21,7 @@ logger = get_logger(__file__)
 try:
     import amp_C
     from apex.multi_tensor_apply import multi_tensor_applier
+
     APEX_AVAILABLE = True
 except (ModuleNotFoundError, ImportError):
     logger.warn("The torch implementation for cal_l2norm is slower than apex. Please note this!")

@@ -162,6 +163,7 @@ def sync_param(flat_tensor, tensor_list):
     for p, q in zip(tensor_list, updated_params):
         p.data = q.data

+
 def multi_tensor_l2norm_torch(tensor_list, per_tensor):
     # Convert tensor_list elements to torch.float32
     tensor_list = [tensor.float() for tensor in tensor_list]

@@ -175,6 +177,7 @@ def multi_tensor_l2norm_torch(tensor_list, per_tensor):
     return l2_norm, per_tensor_norm

+
 def calc_l2_norm(grads):
     norm = 0.0
     if len(grads) > 0:

@@ -187,6 +190,7 @@ def calc_l2_norm(grads):
         norm, _ = multi_tensor_l2norm_torch(grads, False)
     return norm

+
 def calc_lp(grads, norm_type):
     norm = 0.0
     for grad in grads:

@@ -195,7 +199,7 @@ def calc_lp(grads, norm_type):
     return norm

-def compute_norm(gradients, parameters, norm_type=2):
+def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, norm_type=2):
     """Get the norm
     Arguments:
         gradients (Iterable[Tensor]): The gradient value.

@@ -215,6 +219,13 @@ def compute_norm(gradients, parameters, norm_type=2):
     if norm_type == inf:
         total_norm = max(g.data.abs().max() for g in gradients)
         total_norm_cuda = torch.FloatTensor([float(total_norm)], device=gradients[0].device)
+
+        if last_stage is False:
+            return total_norm_cuda
+
+        if previous_norm is not None:
+            total_norm_cuda = max(total_norm_cuda, previous_norm)
+
         # Take max across all model-parallel GPUs.
         if gpc.get_world_size(ParallelMode.MODEL) > 1:
             dist.all_reduce(total_norm_cuda, op=dist.ReduceOp.MAX, group=gpc.get_group(ParallelMode.MODEL))

@@ -261,6 +272,12 @@
             total_norm = tensor_parallel_norm

+        if last_stage is False:
+            return total_norm
+
+        if previous_norm is not None:
+            total_norm = total_norm + previous_norm
+
         # Sum across all model-parallel GPUs.
         if gpc.is_initialized(ParallelMode.MODEL):
             dist.all_reduce(total_norm, op=dist.ReduceOp.SUM, group=gpc.get_group(ParallelMode.MODEL))
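With last_stage/previous_norm, compute_norm now returns a partial accumulator until the final call, which folds in the earlier value before the cross-rank reduction; the caller then takes the root (norm**0.5 in _step). A single-process sketch of that contract for norm_type=2, with a made-up helper name and the distributed reduction only described in a comment:

import torch

def staged_sq_norm(grads, last_stage=False, previous_norm=None):
    # every call returns a sum of squares; only the caller takes the square root
    partial = sum(g.data.float().norm(2) ** 2 for g in grads)
    if not last_stage:
        return partial  # early buckets hand back a partial value
    if previous_norm is not None:
        partial = partial + previous_norm
    # the real compute_norm would dist.all_reduce(partial, SUM) across the
    # model-parallel group here before returning
    return partial

g_former = [torch.randn(8)]
g_last = [torch.randn(8)]
stage1 = staged_sq_norm(g_former)
total = staged_sq_norm(g_last, last_stage=True, previous_norm=stage1)
print(total ** 0.5)  # global 2-norm across both buckets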

[training entry script]

@@ -7,6 +7,7 @@ import traceback
 from functools import partial
 from typing import Iterable

+import numpy as np
 import torch
 import torch.distributed as dist
 from torch import nn

@@ -603,12 +604,12 @@ def main(args):
         trainer_result = trainer.step()
         assert trainer_result is not None

-        success_update, grad_norm = trainer_result
+        success_update, grad_norm_groups = trainer_result

         if success_update:  # update parameters successfully
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if grad_norm == -99.0 and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address, message=f"Warning: skip parameter update at step {batch_count}."

@@ -628,7 +629,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
-            grad_norm=grad_norm,
+            grad_norm=np.array(grad_norm_groups),
             metric=metric,
             update_panel=uniscale_logger is not None,
         )

@@ -668,7 +669,6 @@
         main(args)
     except Exception:
         logger.error(
-            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}",
-            exc_info=traceback.format_exc(),
+            f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
         )
         mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
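Since trainer.step() now returns one norm per parameter group, the skip check becomes a membership test and the logged value becomes an array. A toy illustration of the new handling; the numbers below are made up:

import numpy as np

grad_norm_groups = [1.73, 0.42]  # one entry per optimizer param group
if -99.0 in grad_norm_groups:
    print("skip parameter update this step")
else:
    print("grad_norm per group:", np.array(grad_norm_groups))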