InternLM/internlm/utils/timeout.py

114 lines
3.0 KiB
Python

import datetime
import os
import signal
import socket
import traceback
from functools import wraps
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
class Timeout:
"""Timer to execute code
Adapted from https://github.com/reasoning-machines/pal
Args:
seconds (float): The maximum seconds to execute code
error_message (str)
"""
def __init__(self, seconds=1, error_message="Timeout"):
self.seconds = seconds
self.error_message = error_message
def timeout_handler(self, signum, frame):
raise TimeoutError(self.error_message)
def __enter__(self):
signal.signal(signal.SIGALRM, self.timeout_handler)
signal.alarm(self.seconds)
def __exit__(self, error_type, value, traceback):
signal.alarm(0)
ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)
timeout_threshold_dict = {
"initialize_distributed_env": 120,
"nopp_forward_backward_step": 360,
"initialize_model": 10,
"initialize_optimizer": 20,
"optim_step": 30,
"get_train_data_loader": 600,
"get_validation_data_loader": 60,
"load_new_batch": 10,
"record_current_batch_training_metrics": 10,
"save_checkpoint": 1200,
"interleaved_forward_backward_step": 600,
"nointerleaved_forward_backward_step": 600,
}
if ENABLE_TIMEOUT is not None:
os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
else:
timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
def try_get_gpc_rank():
try:
from internlm.core.context import global_context as gpc
rank = gpc.get_global_rank()
except: # noqa # pylint: disable=bare-except
rank = "unknown"
return f"host-{socket.gethostname()}-rank-{rank}"
def llm_timeout(seconds=0, func_name=None):
"""timeout decorator, Note that this decorator cannot be reentrant,
otherwise the signal will be reset.
Args:
seconds (int, optional): timeout threshold. Defaults to 300.
func_name (str, optional): the func who is been waited to timeout.
"""
def decorator(func):
nonlocal func_name
if func_name is None:
func_name = func.__name__
@wraps(func)
def wrapper(*args, **kwargs):
def _handle_timeout(signum, frame):
raise TimeoutError
nonlocal seconds
seconds = timeout_threshold_dict.get(func_name, seconds)
if seconds > 0:
signal.signal(signal.SIGALRM, _handle_timeout)
signal.alarm(seconds)
try:
result = func(*args, **kwargs)
except TimeoutError as e:
logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}")
raise e
finally:
signal.alarm(0)
return result
return wrapper
return decorator