diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index 730244d..968489c 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -18,6 +18,7 @@ import torch.distributed as dist
 
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
 
 from . import process_group_initializer as pgroup_initializer
 from .process_group_initializer import ParallelMode
@@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
         """
         # initialize the default process group
         init_method = f"tcp://[{host}]:{port}"
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
+        dist.init_process_group(
+            rank=rank,
+            world_size=world_size,
+            backend=backend,
+            init_method=init_method,
+            timeout=LLM_NCCL_TIMEOUT,
+        )
 
         # None will give the default global process group for pytorch dist operations
         ranks = list(range(world_size))
         if use_cpu:
-            cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
+            cpu_group = (
+                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                if dist.get_backend() != "gloo"
+                else None
+            )
         else:
             cpu_group = None
         self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
diff --git a/internlm/core/context/process_group_initializer.py b/internlm/core/context/process_group_initializer.py
index 11b41c0..97e9ef0 100644
--- a/internlm/core/context/process_group_initializer.py
+++ b/internlm/core/context/process_group_initializer.py
@@ -9,6 +9,8 @@ from enum import Enum
 
 import torch.distributed as dist
 
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
+
 
 # parallel modes
 class ParallelMode(Enum):
@@ -109,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer):
 
         for i in range(self.rank_num_per_dp_group):
             ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -161,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer):
 
         for i in range(self.num_group):
             ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -221,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                 )
             )
             pipe_group_size = len(ranks)
-            pipe_group = dist.new_group(ranks)
+            pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else pipe_group
+                )
             else:
                 group_cpu = None
 
@@ -271,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
 
         for i in range(self.num_tensor_parallel_group):
            ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
@@ -327,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
                     i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
                     for k in range(self.zero1_parallel_size)
                 ]
-                group = dist.new_group(ranks)
+                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                 if use_cpu:
-                    group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else group
+                    )
                 else:
                     group_cpu = None
 
@@ -376,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
                 rank = i * self.nettest_parallel_size + j
                 if rank < self.world_size:
                     ranks.append(rank)
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
 
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 2633a9c..998bdb1 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -9,6 +9,7 @@ import torch
 
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
+from internlm.utils.timeout import llm_timeout
 
 from .base_scheduler import BaseScheduler, SchedulerHook
 
@@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
 
         return output, loss
 
+    @llm_timeout(func_name="nopp_forward_backward_step")
     def forward_backward_step(
         self,
         engine: Engine,
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index 501794d..4714474 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -15,6 +15,7 @@ from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout
 
 from .base_scheduler import BaseScheduler, SchedulerHook
 
@@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):
 
         return output, label, accum_loss
 
+    @llm_timeout(func_name="nointerleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -1248,6 +1250,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         # 3. Cooldown
         self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
 
+    @llm_timeout(func_name="interleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index bd45183..0945337 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -13,6 +13,7 @@ from internlm.core.context import global_context as gpc
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout
 
 logger = get_logger(__file__)
 
@@ -410,6 +411,7 @@ def launch_from_torch(
     )
 
 
+@llm_timeout(func_name="initialize_distributed_env")
 def initialize_distributed_env(
     config: str,
     launcher: str = "slurm",
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 0e4343a..0e44c99 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.timeout import llm_timeout
 
 from .utils import compute_norm
 
@@ -506,6 +507,7 @@ class HybridZeroOptimizer(BaseOptimizer):
 
         return norm
 
+    @llm_timeout(func_name="optim_step")
     def step(self, closure=None):
         """Performs a single optimization step.
 
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index fec9239..3a2e0bd 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -40,10 +40,12 @@ from internlm.utils.parallel import (
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER
+from internlm.utils.timeout import llm_timeout
 
 logger = get_logger(__file__)
 
 
+@llm_timeout(func_name="initialize_model")
 def initialize_model():
     """
     Initialize model.
@@ -88,6 +90,7 @@ def initialize_model():
     return model
 
 
+@llm_timeout(func_name="initialize_optimizer")
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.
@@ -124,6 +127,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     return optimizer, beta2_scheduler, lr_scheduler
 
 
+@llm_timeout(func_name="get_train_data_loader")
 def get_train_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
 ):
@@ -196,6 +200,7 @@ def get_train_data_loader(
     return train_dl, dataset_types
 
 
+@llm_timeout(func_name="get_validation_data_loader")
 def get_validation_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
 ):
@@ -257,6 +262,7 @@ def get_validation_data_loader(
     return val_dls
 
 
+@llm_timeout(func_name="load_new_batch")
 def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
     """
     Load and return the new batch data based on training data loader.
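Note: the per-group timeouts added to the process-group changes earlier in this diff only take effect for NCCL when asynchronous error handling is enabled, which is why internlm/utils/timeout.py exports NCCL_ASYNC_ERROR_HANDLING=1. Below is a minimal standalone sketch of that pattern; the 30-second threshold, the helper name init_groups, and the host/port arguments are illustrative assumptions, not InternLM defaults.

# Standalone sketch of the process-group timeout pattern used in this diff; values are illustrative.
import datetime
import os

import torch.distributed as dist

os.environ.setdefault("NCCL_ASYNC_ERROR_HANDLING", "1")  # allow hung NCCL collectives to abort
NCCL_TIMEOUT = datetime.timedelta(seconds=30)  # assumed threshold, not InternLM's default


def init_groups(rank: int, world_size: int, host: str, port: int):
    # Default (NCCL) group: collectives raise instead of hanging once the timeout expires.
    dist.init_process_group(
        rank=rank,
        world_size=world_size,
        backend="nccl",
        init_method=f"tcp://[{host}]:{port}",
        timeout=NCCL_TIMEOUT,
    )
    # Optional CPU group with the same timeout, mirroring the gloo fallback in the diff.
    return dist.new_group(list(range(world_size)), backend="gloo", timeout=NCCL_TIMEOUT)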
@@ -314,6 +320,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )
 
 
+@llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
     logger,
diff --git a/internlm/utils/logger.py b/internlm/utils/logger.py
index 679913a..6111553 100644
--- a/internlm/utils/logger.py
+++ b/internlm/utils/logger.py
@@ -84,7 +84,7 @@ def initialize_uniscale_logger(
             job_name and launch_time and file_name
         ), "If file_path is None, job_name, launch_time and file_name must be setted."
         log_file_name = file_name
-        log_folder = os.path.join(job_name, launch_time, "logs")
+        log_folder = os.path.join("RUN", job_name, launch_time, "logs")
         log_dir = os.path.join(log_folder, log_file_name)
         file_path = log_dir
 
diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py
index 21d76d1..b8f7ad6 100644
--- a/internlm/utils/model_checkpoint.py
+++ b/internlm/utils/model_checkpoint.py
@@ -33,6 +33,7 @@ from internlm.utils.storage_manager import (
     llm_save,
     try_get_storage_backend,
 )
+from internlm.utils.timeout import llm_timeout
 
 logger = get_logger(__file__)
 
@@ -727,6 +728,7 @@ now step_count is {train_state.step_count}",
         if load_content_str:
             logger.info(f"===========Load contents are: {load_content_str}")
 
+    @llm_timeout(func_name="save_checkpoint")
     def save_checkpoint(
         self,
         folder,
diff --git a/internlm/utils/timeout.py b/internlm/utils/timeout.py
index 07a0911..7a96841 100644
--- a/internlm/utils/timeout.py
+++ b/internlm/utils/timeout.py
@@ -1,4 +1,13 @@
+import datetime
+import os
 import signal
+import socket
+import traceback
+from functools import wraps
+
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)
 
 
 class Timeout:
@@ -24,3 +33,81 @@ class Timeout:
 
     def __exit__(self, error_type, value, traceback):
         signal.alarm(0)
+
+
+ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)
+
+
+timeout_threshold_dict = {
+    "initialize_distributed_env": 120,
+    "nopp_forward_backward_step": 360,
+    "initialize_model": 10,
+    "initialize_optimizer": 20,
+    "optim_step": 30,
+    "get_train_data_loader": 600,
+    "get_validation_data_loader": 60,
+    "load_new_batch": 10,
+    "record_current_batch_training_metrics": 10,
+    "save_checkpoint": 1200,
+    "interleaved_forward_backward_step": 600,
+    "nointerleaved_forward_backward_step": 600,
+}
+
+if ENABLE_TIMEOUT is not None:
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
+else:
+    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
+
+
+def try_get_gpc_rank():
+    try:
+        from internlm.core.context import global_context as gpc
+
+        rank = gpc.get_global_rank()
+    except:  # noqa  # pylint: disable=bare-except
+        rank = "unknown"
+
+    return f"host-{socket.gethostname()}-rank-{rank}"
+
+
+def llm_timeout(seconds=0, func_name=None):
+    """Timeout decorator. Note that this decorator is not reentrant,
+    since nested use would reset the alarm signal.
+
+    Args:
+        seconds (int, optional): timeout threshold; overridden by ``timeout_threshold_dict`` for registered names. Defaults to 0.
+        func_name (str, optional): name of the monitored function. Defaults to the decorated function's ``__name__``.
+    """
+
+    def decorator(func):
+        nonlocal func_name
+        if func_name is None:
+            func_name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def _handle_timeout(signum, frame):
+                raise TimeoutError
+
+            nonlocal seconds
+            seconds = timeout_threshold_dict.get(func_name, seconds)
+
+            if seconds > 0:
+                signal.signal(signal.SIGALRM, _handle_timeout)
+                signal.alarm(seconds)
+
+            try:
+                result = func(*args, **kwargs)
+            except TimeoutError as e:
+                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\n {traceback.format_exc()}")
+                raise e
+            finally:
+                signal.alarm(0)
+
+            return result
+
+        return wrapper
+
+    return decorator
diff --git a/tests/test_utils/common_fixture.py b/tests/test_utils/common_fixture.py
index b5f61e3..d6a19b6 100644
--- a/tests/test_utils/common_fixture.py
+++ b/tests/test_utils/common_fixture.py
@@ -127,12 +127,12 @@ def reset_seed():
 
 
 @pytest.fixture(scope="module")
-def init_dist_and_model():
+def init_dist_and_model(rank=0, world_size=1):
     from internlm.initialize import initialize_distributed_env
 
-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["WORLD_SIZE"] = "1"
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = "12377"
     initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
diff --git a/tests/test_utils/test_timeout.py b/tests/test_utils/test_timeout.py
new file mode 100644
index 0000000..a3f15f9
--- /dev/null
+++ b/tests/test_utils/test_timeout.py
@@ -0,0 +1,119 @@
+import fcntl
+import os
+import time
+from multiprocessing import Process
+
+import pytest
+import torch
+import torch.distributed as dist
+
+os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa  # pylint: disable=wrong-import-position
+os.environ["NCCL_TIMEOUT"] = "5"
+from internlm.utils.timeout import llm_timeout
+from tests.test_utils.common_fixture import (  # noqa  # pylint: disable=unused-import
+    init_config,
+)
+
+WORLD_SIZE = 2
+
+
+@llm_timeout(2, "fake_timeout_func")
+def fake_timeout_func():
+    time.sleep(10)
+
+
+@llm_timeout(10, "nccl_timeout_func")
+def nccl_timeout_func(rank):
+    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
+    # 'NCCL_ASYNC_ERROR_HANDLING' does not take effect on the first collective communication.
+    buff = torch.ones([64, 64]).cuda(rank)
+    dist.all_reduce(buff)  # lazy communicator init
+    torch.cuda.synchronize()
+    if rank == 0:
+        dist.all_reduce(buff)
+        torch.cuda.synchronize()  # the main thread will hang here.
+    else:
+        time.sleep(9999)
+
+
+@llm_timeout(10, "try_file_lock")
+def try_file_lock(rank, stop_file_path):
+    if rank == 1:
+        time.sleep(5)
+
+    with open(stop_file_path, "r", encoding="utf-8") as f:
+        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 hangs here waiting for the lock.
+        if rank == 0:
+            time.sleep(99999)  # rank 0 hangs here while holding the lock.
+        f.seek(0)
+        f.read()
+        fcntl.flock(f, fcntl.LOCK_UN)
+
+
+def local_timeout(rank, _):
+
+    try:
+        fake_timeout_func()
+    except TimeoutError as e:
+        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
+    else:
+        assert False, "It should timeout!"
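For reference, a minimal usage sketch of the llm_timeout decorator defined above. The name demo_step and the 3-second budget are made up for illustration; names that are not registered in timeout_threshold_dict fall back to the seconds argument, and SIGALRM-based timeouts only fire in the main thread of a POSIX process.

# Minimal usage sketch; "demo_step" and the 3-second budget are illustrative only.
import time

from internlm.utils.timeout import llm_timeout


@llm_timeout(seconds=3, func_name="demo_step")  # unregistered name -> falls back to seconds=3
def demo_step():
    time.sleep(10)  # simulate a hang that exceeds the budget


if __name__ == "__main__":
    try:
        demo_step()
    except TimeoutError:
        print("demo_step exceeded its 3s budget")  # raised by the SIGALRM handler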
+
+
+def gpc_timeout(rank, world_size):
+
+    from internlm.initialize import initialize_distributed_env
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "12377"
+    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
+
+    try:
+        nccl_timeout_func(rank)
+    except TimeoutError as e:
+        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
+        time.sleep(5)  # wait for rank 0 to be killed
+    else:
+        time.sleep(5)  # give the NCCL watchdog some time to kill rank 0.
+        assert False, "It should timeout!"
+
+
+def file_lock_timeout(rank, _, stop_file_path):
+    if rank == 0:
+        with open(stop_file_path, "w"):
+            pass
+    try:
+        try_file_lock(rank, stop_file_path)
+    except TimeoutError as e:
+        print(e, flush=True)
+    else:
+        assert False, "It should timeout!"
+    finally:
+        if rank == 0:
+            os.remove(stop_file_path)
+
+
+timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]
+
+
+@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
+def test_timeout(timeout_func_and_args):
+    timeout_func, world_size, other_args = timeout_func_and_args
+    procs = []
+    for i in range(world_size):
+        if other_args is None:
+            args = (i, world_size)
+        else:
+            args = (i, world_size, other_args)
+        proc = Process(target=timeout_func, args=args)
+        proc.start()
+        procs.append(proc)
+
+    for proc in procs:
+        proc.join(15)
+        if proc.is_alive():
+            proc.terminate()
+            proc.join()
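The new test module relies on INTERNLM_ENABLE_TIMEOUT and NCCL_TIMEOUT being exported before internlm.utils.timeout is first imported, since the module reads them at import time. A small sketch of that import-order requirement, assuming a fresh interpreter; the asserted values simply mirror the module logic above and the 5-second value used in the test.

# Illustrative check of the import-order requirement relied on by test_timeout.py.
import datetime
import os

os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # must be set before the import below
os.environ["NCCL_TIMEOUT"] = "5"

from internlm.utils.timeout import LLM_NCCL_TIMEOUT, timeout_threshold_dict  # noqa: E402

assert LLM_NCCL_TIMEOUT == datetime.timedelta(seconds=5)
assert timeout_threshold_dict["optim_step"] == 30  # thresholds stay enabled
assert os.environ["NCCL_ASYNC_ERROR_HANDLING"] == "1"  # set as a side effect of the import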