mirror of https://github.com/InternLM/InternLM
diag all reduce
parent
d418eba094
commit
acff1a00c9
|
@ -181,3 +181,11 @@ monitor = dict(
|
|||
# metric_dtype can be "fp32" or other string
|
||||
# only when set to "fp32" will use fp32 to calc in metrics
|
||||
# metric_dtype = "fp32"
|
||||
|
||||
diag_all_reduce = dict(
|
||||
enable_diag=False,
|
||||
average_val=0.5, # threshold for average time
|
||||
range_val=0.05, # threshold for the difference between max_time and min_time
|
||||
skip_first=0, # the number of the first tracebacks to skip
|
||||
skip_last=0, # the number of the last tracebacks to skip
|
||||
)
|
||||
|
|
|
@ -20,10 +20,15 @@ try:
|
|||
except ImportError:
|
||||
GPUtil, psutil = None, None
|
||||
|
||||
import re
|
||||
import traceback
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.utils.common import get_current_device
|
||||
|
||||
original_all_reduce = dist.all_reduce
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
|
@ -338,3 +343,53 @@ def cuda_memory_analyze(step=0, print_mm_suage=False):
|
|||
"that all ranks flush their caches at the same time"
|
||||
)
|
||||
n_caching_allocator_flushes = alloc_retries
|
||||
|
||||
|
||||
def get_traceback_list():
|
||||
pattern = r"file ([^,]+), line (\d+)"
|
||||
traceback_list = list(traceback.extract_stack())
|
||||
result = []
|
||||
for item in traceback_list:
|
||||
item = str(item)
|
||||
match = re.search(pattern, item)
|
||||
if match:
|
||||
file_path = match.group(1)
|
||||
line_number = match.group(2)
|
||||
result.append(f"{file_path}, line {line_number}")
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def diag_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
|
||||
import time
|
||||
|
||||
if "diag_all_reduce" in gpc.config and gpc.config.diag_all_reduce.enable_diag:
|
||||
diag_config = gpc.config.diag_all_reduce
|
||||
start_wait = time.time()
|
||||
handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
|
||||
dist.barrier()
|
||||
wait_time = (gpc.get_global_rank(), time.time() - start_wait)
|
||||
|
||||
object_gather_list = [None for _ in range(dist.get_world_size(group))]
|
||||
dist.all_gather_object(object_gather_list, wait_time, group)
|
||||
if dist.get_rank(group) == 0:
|
||||
sort_list = sorted(object_gather_list, key=lambda x: x[1])
|
||||
times = [tup[1] for tup in sort_list]
|
||||
average_val = sum(times) / len(times)
|
||||
range_val = sort_list[-1][1] - sort_list[0][1]
|
||||
if average_val > diag_config.average_val or range_val > diag_config.range_val:
|
||||
skip_first = diag_config.skip_first
|
||||
skip_last = diag_config.skip_last
|
||||
traceback_list = get_traceback_list()[skip_first : -skip_last if skip_last > 0 else None]
|
||||
result = {
|
||||
"rank_time_sort (global_rank, )": sort_list,
|
||||
"traceback": traceback_list,
|
||||
}
|
||||
logger.warning(result)
|
||||
else:
|
||||
handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
|
||||
|
||||
return handle
|
||||
|
||||
|
||||
dist.all_reduce = diag_all_reduce
|
||||
|
|
Loading…
Reference in New Issue