mirror of https://github.com/InternLM/InternLM
fix(train.py): fix overflow grad norm error (#230)
parent 2acb278e1f
commit 29dd401071
@@ -585,7 +585,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                 )
             self._grad_store._averaged_gradients = dict()
             self.zero_grad()
-            return False, None
+            return False, norms
 
         # copy the grad of fp16 param to fp32 param
         single_grad_partition_groups = []
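For readability, here is a minimal sketch of the return-value contract this hunk changes; the helper name below is hypothetical, and only the (success, norms) tuple shape comes from the diff:

# Hypothetical stand-in for the overflow branch of HybridZeroOptimizer.step().
# Before the fix the second element was None; after the fix it carries the
# per-group norm values, so the caller can still inspect them on a skipped step.
def _step_on_overflow(norms):
    return False, norms  # was: return False, None

success_update, grad_norm_groups = _step_on_overflow([-1, 2.5])
assert success_update is False
assert grad_norm_groups is not None  # the caller can now run membership checks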
train.py (2 changed lines)
@@ -235,7 +235,7 @@ def main(args):
             train_state.step_count += 1
         else:
             train_state.inf_nan_skip_batches += 1  # record the amount of updating parameters unsuccessfully.
-            if -99.0 in grad_norm_groups and gpc.is_rank_for_log():  # -99.0 encodes a specific failure case
+            if -1 in grad_norm_groups and gpc.is_rank_for_log():  # -1 encodes a specific failure case
                 logger.warning(f"Warning: skip parameter update at step {batch_count}.")
                 send_alert_message(
                     address=gpc.config.alert_address,
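For context, a minimal sketch (values are illustrative) of the failure the two hunks fix together, and of the sentinel check after the fix:

# Before the fix the optimizer returned (False, None) on overflow, so the
# caller's membership test raised a TypeError instead of skipping cleanly.
grad_norm_groups = None
try:
    _ = -99.0 in grad_norm_groups
except TypeError as exc:
    print(exc)  # argument of type 'NoneType' is not iterable

# After the fix the optimizer returns the norm list and train.py checks for
# the -1 sentinel that encodes the failure case, so the step is skipped
# without crashing.
grad_norm_groups = [-1, 2.5]
if -1 in grad_norm_groups:
    print("skip parameter update for this step")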