fix(train.py): fix overflow grad norm error (#230)

pull/231/head
huangting4201 2023-08-24 17:46:27 +08:00 committed by GitHub
parent 2acb278e1f
commit 29dd401071
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 2 additions and 2 deletions

View File

@@ -585,7 +585,7 @@ class HybridZeroOptimizer(BaseOptimizer):
)
self._grad_store._averaged_gradients = dict()
self.zero_grad()
return False, None
return False, norms
# copy the grad of fp16 param to fp32 param
single_grad_partition_groups = []

View File

@@ -235,7 +235,7 @@ def main(args):
train_state.step_count += 1
else:
train_state.inf_nan_skip_batches += 1 # record the amount of updating parameters unsuccessfully.
if -99.0 in grad_norm_groups and gpc.is_rank_for_log(): # -99.0 encodes a specific failure case
if -1 in grad_norm_groups and gpc.is_rank_for_log(): # -1 encodes a specific failure case
logger.warning(f"Warning: skip parameter update at step {batch_count}.")
send_alert_message(
address=gpc.config.alert_address,