mirror of https://github.com/InternLM/InternLM
fix(optimizer/hybrid_zero_optim.py): fix the bucket-full check condition under reduce-scatter overlap
parent
10b5056e1e
commit
4851291356
|
@@ -404,7 +404,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
||||||
# check if the bucket is full
|
# check if the bucket is full
|
||||||
# if full, will reduce the grads already in the bucket
|
# if full, will reduce the grads already in the bucket
|
||||||
# after reduction, the bucket will be empty
|
# after reduction, the bucket will be empty
|
||||||
if current_bucket.num_elements_in_bucket(reduce_rank) >= self._reduce_bucket_size:
|
if current_bucket.num_elements_in_bucket(reduce_rank) + param_size > self._reduce_bucket_size:
|
||||||
self._accum_grads_store_in_bucket(current_bucket, reduce_rank)
|
self._accum_grads_store_in_bucket(current_bucket, reduce_rank)
|
||||||
|
|
||||||
# otherwise, add the parameter into bucket.
|
# otherwise, add the parameter into bucket.
|
||||||
|
|
|
@@ -600,12 +600,12 @@ def record_current_batch_training_metrics(
|
||||||
step_count=batch_count,
|
step_count=batch_count,
|
||||||
cur_step_loss=loss.item(),
|
cur_step_loss=loss.item(),
|
||||||
)
|
)
|
||||||
|
|
||||||
loss_list.append(loss.item())
|
loss_list.append(loss.item())
|
||||||
if batch_count >= 5:
|
if batch_count >= 5:
|
||||||
tgs_list.append(tgs_origin)
|
tgs_list.append(tgs_origin)
|
||||||
tflops_list.append(tflops)
|
tflops_list.append(tflops)
|
||||||
tflops_list_2.append(tflops_2)
|
tflops_list_2.append(tflops_2)
|
||||||
|
|
||||||
if batch_count == gpc.config.data.total_steps - 1:
|
if batch_count == gpc.config.data.total_steps - 1:
|
||||||
print(tgs_list, flush=True)
|
print(tgs_list, flush=True)
|
||||||
avg_tgs = sum(tgs_list) / len(tgs_list)
|
avg_tgs = sum(tgs_list) / len(tgs_list)
|
||||||
|
|
Loading…
Reference in New Issue