mirror of https://github.com/InternLM/InternLM
feat(utils): add timeout wrapper for key functions (#286)
parent 7f687bf4b3
commit 37b8c6684e
@@ -18,6 +18,7 @@ import torch.distributed as dist

 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT

 from . import process_group_initializer as pgroup_initializer
 from .process_group_initializer import ParallelMode
@@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
         """
         # initialize the default process group
         init_method = f"tcp://[{host}]:{port}"
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
+        dist.init_process_group(
+            rank=rank,
+            world_size=world_size,
+            backend=backend,
+            init_method=init_method,
+            timeout=LLM_NCCL_TIMEOUT,
+        )

         # None will give the default global process group for pytorch dist operations
         ranks = list(range(world_size))
         if use_cpu:
-            cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
+            cpu_group = (
+                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                if dist.get_backend() != "gloo"
+                else None
+            )
         else:
             cpu_group = None
         self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
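For context, a minimal sketch (not part of the diff; the address, port, rank, and world size are placeholders) of what the new `timeout` argument does. With the NCCL backend the timedelta is only enforced when `NCCL_ASYNC_ERROR_HANDLING` or `NCCL_BLOCKING_WAIT` is set, which is why the `timeout.py` change below sets that variable:

    import datetime
    import os

    import torch.distributed as dist

    # Assumed single-process example; real runs pass the actual rank/world_size.
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"  # must be set before init

    dist.init_process_group(
        backend="nccl",
        init_method="tcp://127.0.0.1:29500",  # placeholder address
        rank=0,
        world_size=1,
        timeout=datetime.timedelta(seconds=60),  # collectives stalling longer than this abort
    )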
@@ -9,6 +9,8 @@ from enum import Enum

 import torch.distributed as dist

+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
+

 # parallel modes
 class ParallelMode(Enum):
@@ -109,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer):

         for i in range(self.rank_num_per_dp_group):
             ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
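The same two-step pattern repeats in every initializer below (model, pipeline, tensor, zero1, nettest): the NCCL group gets `LLM_NCCL_TIMEOUT`, and the optional CPU group falls back to gloo, which accepts the same `timeout` argument. A standalone sketch of the pattern, assuming `dist` is already initialized and `ranks` is a list of global ranks:

    import datetime

    import torch.distributed as dist

    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=60)

    def new_group_with_timeout(ranks, use_cpu=False):
        # GPU/NCCL group with a bounded collective timeout.
        group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
        if use_cpu:
            # Reuse the existing group if the default backend is already gloo;
            # otherwise build a dedicated gloo group so CPU collectives also
            # honor the timeout.
            group_cpu = (
                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
                if dist.get_backend() != "gloo"
                else group
            )
        else:
            group_cpu = None
        return group, group_cpu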
@@ -161,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer):

         for i in range(self.num_group):
             ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
@@ -221,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                 )
             )
             pipe_group_size = len(ranks)
-            pipe_group = dist.new_group(ranks)
+            pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else pipe_group
+                )
             else:
                 group_cpu = None
@@ -271,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer):

         for i in range(self.num_tensor_parallel_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
@@ -327,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
                     i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
                     for k in range(self.zero1_parallel_size)
                 ]
-                group = dist.new_group(ranks)
+                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                 if use_cpu:
-                    group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else group
+                    )
                 else:
                     group_cpu = None
@@ -376,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
                 rank = i * self.nettest_parallel_size + j
                 if rank < self.world_size:
                     ranks.append(rank)
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None
@@ -9,6 +9,7 @@ import torch

 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
+from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

@@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):

         return output, loss

+    @llm_timeout(func_name="nopp_forward_backward_step")
     def forward_backward_step(
         self,
         engine: Engine,
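The `@llm_timeout` decorator added here arms a SIGALRM before the wrapped call and raises `TimeoutError` if the step runs past its threshold. A simplified sketch of the equivalent control flow; `scheduler`, `engine`, and `data_iter` are stand-ins, and 360 seconds is the `nopp_forward_backward_step` threshold from this PR:

    import signal

    def _handle_timeout(signum, frame):
        raise TimeoutError("forward_backward_step exceeded its time budget")

    signal.signal(signal.SIGALRM, _handle_timeout)
    signal.alarm(360)  # threshold for nopp_forward_backward_step
    try:
        output, loss = scheduler.forward_backward_step(engine, data_iter)  # stand-in call
    finally:
        signal.alarm(0)  # always disarm the alarm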
@@ -15,6 +15,7 @@ from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

@@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):

         return output, label, accum_loss

+    @llm_timeout(func_name="nointerleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.

@@ -1248,6 +1250,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         # 3. Cooldown
         self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)

+    @llm_timeout(func_name="interleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.
@@ -13,6 +13,7 @@ from internlm.core.context import global_context as gpc
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)

@@ -410,6 +411,7 @@ def launch_from_torch(
     )


+@llm_timeout(func_name="initialize_distributed_env")
 def initialize_distributed_env(
     config: str,
     launcher: str = "slurm",
@@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.timeout import llm_timeout

 from .utils import compute_norm

@@ -506,6 +507,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         return norm

+    @llm_timeout(func_name="optim_step")
     def step(self, closure=None):
         """Performs a single optimization step.

@@ -40,10 +40,12 @@ from internlm.utils.parallel import (
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)


+@llm_timeout(func_name="initialize_model")
 def initialize_model():
     """
     Initialize model.

@@ -88,6 +90,7 @@ def initialize_model():
     return model


+@llm_timeout(func_name="initialize_optimizer")
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.

@@ -124,6 +127,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     return optimizer, beta2_scheduler, lr_scheduler


+@llm_timeout(func_name="get_train_data_loader")
 def get_train_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
 ):

@@ -196,6 +200,7 @@ def get_train_data_loader(
     return train_dl, dataset_types


+@llm_timeout(func_name="get_validation_data_loader")
 def get_validation_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
 ):

@@ -257,6 +262,7 @@ def get_validation_data_loader(
     return val_dls


+@llm_timeout(func_name="load_new_batch")
 def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
     """
     Load and return the new batch data based on training data loader.

@@ -314,6 +320,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )


+@llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
     logger,
@@ -84,7 +84,7 @@ def initialize_uniscale_logger(
             job_name and launch_time and file_name
         ), "If file_path is None, job_name, launch_time and file_name must be setted."
         log_file_name = file_name
-        log_folder = os.path.join(job_name, launch_time, "logs")
+        log_folder = os.path.join("RUN", job_name, launch_time, "logs")
         log_dir = os.path.join(log_folder, log_file_name)
         file_path = log_dir

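The logger change only prefixes a fixed `RUN` directory to the log folder. With hypothetical values:

    import os

    # Hypothetical job_name/launch_time/file_name, for illustration only.
    old = os.path.join("demo_job", "2023-09-01-10-00-00", "logs", "rank_0.log")
    new = os.path.join("RUN", "demo_job", "2023-09-01-10-00-00", "logs", "rank_0.log")
    print(old)  # demo_job/2023-09-01-10-00-00/logs/rank_0.log
    print(new)  # RUN/demo_job/2023-09-01-10-00-00/logs/rank_0.log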
@@ -33,6 +33,7 @@ from internlm.utils.storage_manager import (
     llm_save,
     try_get_storage_backend,
 )
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)

@@ -727,6 +728,7 @@ now step_count is {train_state.step_count}",
         if load_content_str:
             logger.info(f"===========Load contents are: {load_content_str}")

+    @llm_timeout(func_name="save_checkpoint")
     def save_checkpoint(
         self,
         folder,
@@ -1,4 +1,13 @@
+import datetime
+import os
 import signal
+import socket
+import traceback
+from functools import wraps
+
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)


 class Timeout:
@@ -24,3 +33,81 @@ class Timeout:

     def __exit__(self, error_type, value, traceback):
         signal.alarm(0)
+
+
+ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)
+
+
+timeout_threshold_dict = {
+    "initialize_distributed_env": 120,
+    "nopp_forward_backward_step": 360,
+    "initialize_model": 10,
+    "initialize_optimizer": 20,
+    "optim_step": 30,
+    "get_train_data_loader": 600,
+    "get_validation_data_loader": 60,
+    "load_new_batch": 10,
+    "record_current_batch_training_metrics": 10,
+    "save_checkpoint": 1200,
+    "interleaved_forward_backward_step": 600,
+    "nointerleaved_forward_backward_step": 600,
+}
+
+if ENABLE_TIMEOUT is not None:
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
+else:
+    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
+
+
+def try_get_gpc_rank():
+    try:
+        from internlm.core.context import global_context as gpc
+
+        rank = gpc.get_global_rank()
+    except:  # noqa # pylint: disable=bare-except
+        rank = "unknown"
+
+    return f"host-{socket.gethostname()}-rank-{rank}"
+
+
+def llm_timeout(seconds=0, func_name=None):
+    """Timeout decorator. Note that this decorator is not reentrant;
+    a nested use would reset the outer alarm.
+
+    Args:
+        seconds (int, optional): timeout threshold. Defaults to 0 (disabled),
+            unless func_name has an entry in timeout_threshold_dict.
+        func_name (str, optional): name of the function being watched for timeout.
+    """
+
+    def decorator(func):
+        nonlocal func_name
+        if func_name is None:
+            func_name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def _handle_timeout(signum, frame):
+                raise TimeoutError
+
+            nonlocal seconds
+            seconds = timeout_threshold_dict.get(func_name, seconds)
+
+            if seconds > 0:
+                signal.signal(signal.SIGALRM, _handle_timeout)
+                signal.alarm(seconds)
+
+            try:
+                result = func(*args, **kwargs)
+            except TimeoutError as e:
+                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\n {traceback.format_exc()}")
+                raise e
+            finally:
+                signal.alarm(0)
+
+            return result
+
+        return wrapper
+
+    return decorator
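A usage sketch of the new module (the function name below is hypothetical): the feature is opt-in through `INTERNLM_ENABLE_TIMEOUT`, which must be set before the module is imported since both the env var and the threshold table are read at import time, and names present in `timeout_threshold_dict` override the decorator's `seconds` argument.

    import os
    import time

    os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # must precede the import below

    from internlm.utils.timeout import llm_timeout

    @llm_timeout(seconds=5, func_name="my_slow_step")  # hypothetical function
    def my_slow_step():
        time.sleep(10)  # exceeds the 5s budget, so SIGALRM fires

    try:
        my_slow_step()
    except TimeoutError:
        print("my_slow_step timed out")  # SIGALRM works on Unix, main thread only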
|
@ -127,12 +127,12 @@ def reset_seed():
|
|||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def init_dist_and_model():
|
||||
def init_dist_and_model(rank=0, world_size=1):
|
||||
from internlm.initialize import initialize_distributed_env
|
||||
|
||||
os.environ["RANK"] = "0"
|
||||
os.environ["LOCAL_RANK"] = "0"
|
||||
os.environ["WORLD_SIZE"] = "1"
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "12377"
|
||||
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
|
||||
|
|
|
@@ -0,0 +1,119 @@
+import fcntl
+import os
+import time
+from multiprocessing import Process
+
+import pytest
+import torch
+import torch.distributed as dist
+
+os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa # pylint: disable=wrong-import-position
+os.environ["NCCL_TIMEOUT"] = "5"
+from internlm.utils.timeout import llm_timeout
+from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
+    init_config,
+)
+
+WORLD_SIZE = 2
+
+
+@llm_timeout(2, "fake_timeout_func")
+def fake_timeout_func():
+    time.sleep(10)
+
+
+@llm_timeout(10, "nccl_timeout_func")
+def nccl_timeout_func(rank):
+    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
+    # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication.
+    buff = torch.ones([64, 64]).cuda(rank)
+    dist.all_reduce(buff)  # lazy communicator init
+    torch.cuda.synchronize()
+    if rank == 0:
+        dist.all_reduce(buff)
+        torch.cuda.synchronize()  # main thread will hang here.
+    else:
+        time.sleep(9999)
+
+
+@llm_timeout(10, "try_file_lock")
+def try_file_lock(rank, stop_file_path):
+    if rank == 1:
+        time.sleep(5)
+
+    with open(stop_file_path, "r", encoding="utf-8") as f:
+        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 hangs here.
+        if rank == 0:
+            time.sleep(99999)  # rank 0 hangs here.
+        f.seek(0)
+        f.read()
+        fcntl.flock(f, fcntl.LOCK_UN)
+
+
+def local_timeout(rank, _):
+    try:
+        fake_timeout_func()
+    except TimeoutError as e:
+        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
+    else:
+        assert False, "It should timeout!"
+
+
+def gpc_timeout(rank, world_size):
+    from internlm.initialize import initialize_distributed_env
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "12377"
+    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
+
+    try:
+        nccl_timeout_func(rank)
+    except TimeoutError as e:
+        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
+        time.sleep(5)  # wait for rank 0 to be killed
+    else:
+        time.sleep(5)  # give the watchdog some time to kill rank 0.
+        assert False, "It should timeout!"
+
+
+def file_lock_timeout(rank, _, stop_file_path):
+    if rank == 0:
+        with open(stop_file_path, "w"):
+            pass
+    try:
+        try_file_lock(rank, stop_file_path)
+    except TimeoutError as e:
+        print(e, flush=True)
+    else:
+        assert False, "It should timeout!"
+    finally:
+        if rank == 0:
+            os.remove(stop_file_path)
+
+
+timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]
+
+
+@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
+def test_timeout(timeout_func_and_args):
+    timeout_func, world_size, other_args = timeout_func_and_args
+    procs = []
+    for i in range(world_size):
+        if other_args is None:
+            args = (i, world_size)
+        else:
+            args = (i, world_size, other_args)
+        proc = Process(target=timeout_func, args=args)
+        proc.start()
+        procs.append(proc)
+
+    for proc in procs:
+        proc.join(15)
+        if proc.is_alive():
+            proc.terminate()
+            proc.join()