diag all reduce

pull/563/head
lijiaxing 2023-12-28 11:06:38 +08:00
parent d418eba094
commit acff1a00c9
2 changed files with 63 additions and 0 deletions

View File

@ -181,3 +181,11 @@ monitor = dict(
# metric_dtype can be "fp32" or other string
# only when set to "fp32" will use fp32 to calc in metrics
# metric_dtype = "fp32"
diag_all_reduce = dict(
enable_diag=False,
average_val=0.5, # threshold for average time
range_val=0.05, # threshold for the difference between max_time and min_time
skip_first=0, # the number of the first tracebacks to skip
skip_last=0, # the number of the last tracebacks to skip
)

View File

@ -20,10 +20,15 @@ try:
except ImportError:
GPUtil, psutil = None, None
import re
import traceback
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.utils.common import get_current_device
original_all_reduce = dist.all_reduce
logger = get_logger(__file__)
@ -338,3 +343,53 @@ def cuda_memory_analyze(step=0, print_mm_suage=False):
"that all ranks flush their caches at the same time"
)
n_caching_allocator_flushes = alloc_retries
def get_traceback_list():
pattern = r"file ([^,]+), line (\d+)"
traceback_list = list(traceback.extract_stack())
result = []
for item in traceback_list:
item = str(item)
match = re.search(pattern, item)
if match:
file_path = match.group(1)
line_number = match.group(2)
result.append(f"{file_path}, line {line_number}")
return result
def diag_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
import time
if "diag_all_reduce" in gpc.config and gpc.config.diag_all_reduce.enable_diag:
diag_config = gpc.config.diag_all_reduce
start_wait = time.time()
handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
dist.barrier()
wait_time = (gpc.get_global_rank(), time.time() - start_wait)
object_gather_list = [None for _ in range(dist.get_world_size(group))]
dist.all_gather_object(object_gather_list, wait_time, group)
if dist.get_rank(group) == 0:
sort_list = sorted(object_gather_list, key=lambda x: x[1])
times = [tup[1] for tup in sort_list]
average_val = sum(times) / len(times)
range_val = sort_list[-1][1] - sort_list[0][1]
if average_val > diag_config.average_val or range_val > diag_config.range_val:
skip_first = diag_config.skip_first
skip_last = diag_config.skip_last
traceback_list = get_traceback_list()[skip_first : -skip_last if skip_last > 0 else None]
result = {
"rank_time_sort (global_rank, )": sort_list,
"traceback": traceback_list,
}
logger.warning(result)
else:
handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
return handle
dist.all_reduce = diag_all_reduce