mirror of https://github.com/InternLM/InternLM

commit acff1a00c9
parent d418eba094

    diag all reduce
@@ -181,3 +181,11 @@ monitor = dict(
 # metric_dtype can be "fp32" or other string
 # only when set to "fp32" will use fp32 to calc in metrics
 # metric_dtype = "fp32"
+
+diag_all_reduce = dict(
+    enable_diag=False,
+    average_val=0.5,  # threshold for the average all_reduce wait time (seconds)
+    range_val=0.05,  # threshold for the spread between slowest and fastest rank (max_time - min_time)
+    skip_first=0,  # number of leading traceback frames to skip in the report
+    skip_last=0,  # number of trailing traceback frames to skip in the report
+)
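The diagnostic ships disabled (enable_diag=False). As a rough sketch of how a user might switch it on in a training config (the field names come from the hunk above; the threshold and skip values here are purely illustrative):

diag_all_reduce = dict(
    enable_diag=True,   # wrap dist.all_reduce with the diagnostic wrapper
    average_val=0.5,    # warn when the mean wait time across ranks exceeds 0.5 s
    range_val=0.05,     # warn when the max-min wait-time spread exceeds 0.05 s
    skip_first=5,       # e.g. drop the 5 outermost stack frames (launcher boilerplate)
    skip_last=2,        # e.g. drop the 2 innermost frames (the wrapper itself)
)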
@@ -20,10 +20,15 @@ try:
 except ImportError:
     GPUtil, psutil = None, None

+import re
+import traceback
+
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.utils.common import get_current_device

+original_all_reduce = dist.all_reduce
+
 logger = get_logger(__file__)

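The only functional addition in this hunk is capturing dist.all_reduce in original_all_reduce before the module rebinds dist.all_reduce at the bottom of the file. A minimal, self-contained sketch of that wrap-and-delegate pattern (the names here are illustrative and not part of the patch):

import torch.distributed as dist

_real_all_reduce = dist.all_reduce  # grab the reference before rebinding


def _wrapped_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
    # delegate to the saved reference; calling dist.all_reduce here would recurse forever
    return _real_all_reduce(tensor, op=op, group=group, async_op=async_op)


dist.all_reduce = _wrapped_all_reduce  # every later call goes through the wrapper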
@@ -338,3 +343,53 @@ def cuda_memory_analyze(step=0, print_mm_suage=False):
             "that all ranks flush their caches at the same time"
         )
         n_caching_allocator_flushes = alloc_retries
+
+
+def get_traceback_list():
+    pattern = r"file ([^,]+), line (\d+)"
+    traceback_list = list(traceback.extract_stack())
+    result = []
+    for item in traceback_list:
+        item = str(item)
+        match = re.search(pattern, item)
+        if match:
+            file_path = match.group(1)
+            line_number = match.group(2)
+            result.append(f"{file_path}, line {line_number}")
+
+    return result
+
+
+def diag_all_reduce(tensor, op=dist.ReduceOp.SUM, group=None, async_op=False):
+    import time
+
+    if "diag_all_reduce" in gpc.config and gpc.config.diag_all_reduce.enable_diag:
+        diag_config = gpc.config.diag_all_reduce
+        start_wait = time.time()
+        handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
+        dist.barrier()
+        wait_time = (gpc.get_global_rank(), time.time() - start_wait)
+
+        object_gather_list = [None for _ in range(dist.get_world_size(group))]
+        dist.all_gather_object(object_gather_list, wait_time, group)
+        if dist.get_rank(group) == 0:
+            sort_list = sorted(object_gather_list, key=lambda x: x[1])
+            times = [tup[1] for tup in sort_list]
+            average_val = sum(times) / len(times)
+            range_val = sort_list[-1][1] - sort_list[0][1]
+            if average_val > diag_config.average_val or range_val > diag_config.range_val:
+                skip_first = diag_config.skip_first
+                skip_last = diag_config.skip_last
+                traceback_list = get_traceback_list()[skip_first : -skip_last if skip_last > 0 else None]
+                result = {
+                    "rank_time_sort (global_rank, )": sort_list,
+                    "traceback": traceback_list,
+                }
+                logger.warning(result)
+    else:
+        handle = original_all_reduce(tensor, op=op, group=group, async_op=async_op)
+
+    return handle
+
+
+dist.all_reduce = diag_all_reduce
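get_traceback_list relies on the repr of traceback.FrameSummary, which reads "<FrameSummary file /path/to/mod.py, line 42 in func>", so the regex recovers the file and line of every frame currently on the stack. A small standalone sketch of the same parsing (illustrative, not part of the commit):

import re
import traceback


def where_am_i():
    pattern = r"file ([^,]+), line (\d+)"
    frames = []
    for frame in traceback.extract_stack():
        # str(FrameSummary) -> "<FrameSummary file /path/mod.py, line 42 in func>"
        match = re.search(pattern, str(frame))
        if match:
            frames.append(f"{match.group(1)}, line {match.group(2)}")
    return frames


if __name__ == "__main__":
    print("\n".join(where_am_i()))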
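On rank 0, the wrapper flags a call whenever the mean wait time or the max-min spread across ranks exceeds the configured thresholds. The same check applied to made-up timings, using the default thresholds from the config hunk (average_val=0.5, range_val=0.05), just to make the arithmetic concrete:

# (global_rank, seconds spent in all_reduce plus barrier); fabricated example values
gathered = [(0, 0.031), (1, 0.034), (2, 0.620), (3, 0.033)]

sort_list = sorted(gathered, key=lambda x: x[1])
times = [t for _, t in sort_list]
average_val = sum(times) / len(times)           # 0.1795 s, under the 0.5 s average threshold
range_val = sort_list[-1][1] - sort_list[0][1]  # 0.589 s, over the 0.05 s spread threshold
if average_val > 0.5 or range_val > 0.05:
    print("uneven all_reduce wait times detected:", sort_list)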