mirror of https://github.com/InternLM/InternLM
no overlap for save ckpt
parent
9bf24d9768
commit
d3ca22cf3d
|
@ -36,8 +36,7 @@ def empty_cache_and_diag(batch_count, interval=50):
|
||||||
if interval <= 0:
|
if interval <= 0:
|
||||||
interval = 50
|
interval = 50
|
||||||
|
|
||||||
if not gpc.config.hybrid_zero_optimizer.overlap_sync_param:
|
cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
|
||||||
cuda_memory_analyze(batch_count, batch_count % int(interval) == 0 or batch_count <= 5)
|
|
||||||
|
|
||||||
if batch_count % int(interval) == 0:
|
if batch_count % int(interval) == 0:
|
||||||
# there is no need to do diag on the first batch
|
# there is no need to do diag on the first batch
|
||||||
|
@ -302,7 +301,6 @@ def warmup_process_group():
|
||||||
|
|
||||||
def cuda_memory_analyze(step=0, print_mm_suage=False):
|
def cuda_memory_analyze(step=0, print_mm_suage=False):
|
||||||
global n_caching_allocator_flushes
|
global n_caching_allocator_flushes
|
||||||
torch.cuda.synchronize()
|
|
||||||
|
|
||||||
g_rank = gpc.get_global_rank()
|
g_rank = gpc.get_global_rank()
|
||||||
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
|
||||||
|
|
Loading…
Reference in New Issue