mirror of https://github.com/InternLM/InternLM
fix(optimizer/hybrid_zero_optim.py): fix bucket size full judge condition when reduce scatter overlap
parent
10b5056e1e
commit
4851291356
|
@ -404,7 +404,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# check if the bucket is full
|
||||
# if full, will reduce the grads already in the bucket
|
||||
# after reduction, the bucket will be empty
|
||||
if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
|
||||
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||
self._accum_grads_store_in_bucket(current_bucket, reduce_rank)
|
||||
|
||||
# otherwise, add the parameter into bucket.
|
||||
|
|
|
@ -600,12 +600,12 @@ def record_current_batch_training_metrics(
|
|||
step_count=batch_count,
|
||||
cur_step_loss=loss.item(),
|
||||
)
|
||||
|
||||
loss_list.append(loss.item())
|
||||
if batch_count >= 5:
|
||||
tgs_list.append(tgs_origin)
|
||||
tflops_list.append(tflops)
|
||||
tflops_list_2.append(tflops_2)
|
||||
|
||||
if batch_count == gpc.config.data.total_steps - 1:
|
||||
print(tgs_list, flush=True)
|
||||
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||
|
@ -627,5 +627,5 @@ def record_current_batch_training_metrics(
|
|||
if abs(tf - avg_tflops_2) > 10:
|
||||
tflops_list_2.remove(tf)
|
||||
print(f"avg_tflops_2: {sum(tflops_list_2)/len(tflops_list_2)}", flush=True)
|
||||
|
||||
|
||||
print("loss: ", loss_list, flush=True)
|
||||
|
|
Loading…
Reference in New Issue