From acff1a00c927b2bbc2a3cc1bff149fd88b99ad8d Mon Sep 17 00:00:00 2001
From: lijiaxing
Date: Thu, 28 Dec 2023 11:06:38 +0800
Subject: [PATCH] diag all reduce

---
 configs/7B_sft.py         |  8 ++++++++
 internlm/utils/gputest.py | 55 ++++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index c0a9bc8..559d75a 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -181,3 +181,11 @@ monitor = dict(
 # metric_dtype can be "fp32" or other string
 # only when set to "fp32" will use fp32 to calc in metrics
 # metric_dtype = "fp32"
+
+diag_all_reduce = dict(
+    enable_diag=False,  # set to True to wrap dist.all_reduce with the timing diagnostic
+    average_val=0.5,  # threshold (seconds) for the average all_reduce wait time across ranks
+    range_val=0.05,  # threshold (seconds) for the gap between the slowest and fastest rank
+    skip_first=0,  # number of leading traceback frames to skip in the report
+    skip_last=0,  # number of trailing traceback frames to skip in the report
+)
diff --git a/internlm/utils/gputest.py b/internlm/utils/gputest.py
index 48ec0e3..90116b6 100644
--- a/internlm/utils/gputest.py
+++ b/internlm/utils/gputest.py
@@ -20,10 +20,15 @@ try:
 except ImportError:
     GPUtil, psutil = None, None
 
+import re
+import traceback
+
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.common import get_current_device
 
+original_all_reduce = dist.all_reduce
+
 logger = get_logger(__file__)
 
 
@@ -338,3 +343,53 @@ def cuda_memory_analyze(step=0, print_mm_suage=False):
             "that all ranks flush their caches at the same time"
         )
         n_caching_allocator_flushes = alloc_retries
+
+
+def get_traceback_list():
+    pattern = r"file ([^,]+), line (\d+)"
+    traceback_list = list(traceback.extract_stack())
+    result = []
+    for item in traceback_list:
+        item = str(item)
+        match = re.search(pattern, item)
+        if match:
+            file_path = match.group(1)
+            line_number = match.group(2)
+            result.append(f"{file_path}, line {line_number}")
+
+    return result
+
+
+def diag_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
+    import time
+
+    if "diag_all_reduce" in gpc.config and gpc.config.diag_all_reduce.enable_diag:
+        diag_config = gpc.config.diag_all_reduce
+        start_wait = time.time()
+        handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
+        dist.barrier()
+        wait_time = (gpc.get_global_rank(), time.time() - start_wait)
+
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(object_gather_list, wait_time, group)
+        if dist.get_rank(group) == 0:
+            sort_list = sorted(object_gather_list, key=lambda x: x[1])
+            times = [tup[1] for tup in sort_list]
+            average_val = sum(times) / len(times)
+            range_val = sort_list[-1][1] - sort_list[0][1]
+            if average_val > diag_config.average_val or range_val > diag_config.range_val:
+                skip_first = diag_config.skip_first
+                skip_last = diag_config.skip_last
+                traceback_list = get_traceback_list()[skip_first : -skip_last if skip_last > 0 else None]
+                result = {
+                    "rank_time_sort (global_rank, )": sort_list,
+                    "traceback": traceback_list,
+                }
+                logger.warning(result)
+    else:
+        handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
+
+    return handle
+
+
+dist.all_reduce = diag_all_reduce
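
Note on usage: importing internlm.utils.gputest is what installs the wrapper, since the
module-level assignment dist.all_reduce = diag_all_reduce runs at import time; every later
call to torch.distributed.all_reduce then goes through the timing path whenever
diag_all_reduce.enable_diag is True in the loaded config. Below is a minimal sketch of a
hypothetical call site (the demo_step helper is illustrative, not part of the patch, and it
assumes the normal InternLM launch has already initialized the process group and gpc.config):

    import torch
    import torch.distributed as dist

    import internlm.utils.gputest  # noqa: F401  -- importing installs dist.all_reduce = diag_all_reduce


    def demo_step():
        # Hypothetical helper, not part of the patch. Requires an initialized
        # process group and a loaded gpc.config that contains diag_all_reduce.
        x = torch.ones(1024, device="cuda")
        # Timed by the wrapper; rank 0 logs a warning with per-rank wait times
        # and a trimmed traceback if average_val or range_val is exceeded.
        dist.all_reduce(x)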