feat(utils): add timeout wrapper for key functions (#286)

pull/298/head
Guoteng 2023-09-07 17:26:17 +08:00 committed by GitHub
parent 7f687bf4b3
commit 37b8c6684e
12 changed files with 280 additions and 19 deletions
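
In short, this commit adds two complementary timeout mechanisms. First, a new llm_timeout decorator (added in the internlm.utils.timeout module below) arms a SIGALRM-based watchdog around key training entry points - distributed/model/optimizer/dataloader initialization, the forward-backward steps, optimizer step and checkpoint saving - with per-function thresholds kept in timeout_threshold_dict. Second, every torch.distributed process group is now created with an explicit LLM_NCCL_TIMEOUT, so hung collectives are aborted instead of blocking indefinitely. Both mechanisms are opted into via the INTERNLM_ENABLE_TIMEOUT environment variable. For orientation, here is a minimal, self-contained sketch of the SIGALRM pattern the decorator relies on; `watchdog` and `slow_step` are illustrative names only, it assumes a POSIX main thread where signal.SIGALRM is available, and the real implementation appears later in this diff.

# Illustrative sketch of the SIGALRM timeout pattern used by this commit.
import signal
import time
from functools import wraps


def watchdog(seconds):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle(signum, frame):
                raise TimeoutError(f"{func.__name__} exceeded {seconds}s")

            signal.signal(signal.SIGALRM, _handle)
            signal.alarm(seconds)  # arm the alarm
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # always disarm, on success or error

        return wrapper

    return decorator


@watchdog(seconds=2)
def slow_step():
    time.sleep(10)  # stands in for a hung collective or data loader


try:
    slow_step()
except TimeoutError as exc:
    print(exc)  # -> slow_step exceeded 2s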

View File

@@ -18,6 +18,7 @@ import torch.distributed as dist
from internlm.utils.common import SingletonMeta
from internlm.utils.logger import get_logger
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
from . import process_group_initializer as pgroup_initializer
from .process_group_initializer import ParallelMode
@@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
"""
# initialize the default process group
init_method = f"tcp://[{host}]:{port}"
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
dist.init_process_group(
rank=rank,
world_size=world_size,
backend=backend,
init_method=init_method,
timeout=LLM_NCCL_TIMEOUT,
)
# None will give the default global process group for pytorch dist operations
ranks = list(range(world_size))
if use_cpu:
cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
cpu_group = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else None
)
else:
cpu_group = None
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
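
Note that the default NCCL group and the optional gloo CPU group above now share the same LLM_NCCL_TIMEOUT. A hedged sketch of how a caller might opt in before training starts; the two environment variables are the switches read by internlm.utils.timeout in this commit, the rest is illustrative and assumes the internlm package from this repository is importable.

import os

# must be set before internlm.utils.timeout is imported, since the switch is
# evaluated at module import time
os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # enable per-function and NCCL timeouts
os.environ["NCCL_TIMEOUT"] = "60"            # NCCL collective timeout, in seconds

from internlm.utils.timeout import LLM_NCCL_TIMEOUT

print(LLM_NCCL_TIMEOUT)                         # 0:01:00 (a 60-second timedelta)
print(os.environ["NCCL_ASYNC_ERROR_HANDLING"])  # "1", set as a side effect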

View File

@@ -9,6 +9,8 @@ from enum import Enum
import torch.distributed as dist
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
# parallel modes
class ParallelMode(Enum):
@@ -109,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer):
for i in range(self.rank_num_per_dp_group):
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@@ -161,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer):
for i in range(self.num_group):
ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@@ -221,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
)
)
pipe_group_size = len(ranks)
pipe_group = dist.new_group(ranks)
pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else pipe_group
)
else:
group_cpu = None
@@ -271,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
for i in range(self.num_tensor_parallel_group):
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@@ -327,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
for k in range(self.zero1_parallel_size)
]
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
@@ -376,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
rank = i * self.nettest_parallel_size + j
if rank < self.world_size:
ranks.append(rank)
group = dist.new_group(ranks)
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
if use_cpu:
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
group_cpu = (
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
if dist.get_backend() != "gloo"
else group
)
else:
group_cpu = None
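
The same three-line change is repeated for the Data, Model, Pipeline, Tensor, Zero1 and Nettest initializers above. A hypothetical helper (not part of this commit) that captures the shared pattern, assuming torch.distributed is already initialized:

import torch.distributed as dist

from internlm.utils.timeout import LLM_NCCL_TIMEOUT


def new_group_with_cpu(ranks, use_cpu):
    """Create the backend group plus an optional gloo group for CPU collectives."""
    group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
    if not use_cpu:
        return group, None
    if dist.get_backend() == "gloo":
        # the default backend already runs on CPU, so reuse it
        return group, group
    return group, dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)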

View File

@@ -9,6 +9,7 @@ import torch
from internlm.core.engine import Engine
from internlm.utils.common import conditional_context
from internlm.utils.timeout import llm_timeout
from .base_scheduler import BaseScheduler, SchedulerHook
@@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
return output, loss
@llm_timeout(func_name="nopp_forward_backward_step")
def forward_backward_step(
self,
engine: Engine,

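The func_name passed to the decorator here ("nopp_forward_backward_step") is the key used to look up this function's 360-second budget in timeout_threshold_dict, defined in the internlm.utils.timeout module later in this diff; the other entry points decorated in this commit follow the same convention.
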
View File

@@ -15,6 +15,7 @@ from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
from .base_scheduler import BaseScheduler, SchedulerHook
@@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):
return output, label, accum_loss
@llm_timeout(func_name="nointerleaved_forward_backward_step")
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
Returns a tuple with losses if the last stage, an empty tuple otherwise.
@@ -1248,6 +1250,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
# 3. Cooldown
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
@llm_timeout(func_name="interleaved_forward_backward_step")
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
"""Run interleaved 1F1B schedule (model split into model chunks), with
communication between pipeline stages as needed.

View File

@@ -13,6 +13,7 @@ from internlm.core.context import global_context as gpc
from internlm.monitor import initialize_light_monitor
from internlm.utils.common import get_master_node
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
@@ -410,6 +411,7 @@ def launch_from_torch(
)
@llm_timeout(func_name="initialize_distributed_env")
def initialize_distributed_env(
config: str,
launcher: str = "slurm",

View File

@@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.timeout import llm_timeout
from .utils import compute_norm
@@ -506,6 +507,7 @@ class HybridZeroOptimizer(BaseOptimizer):
return norm
@llm_timeout(func_name="optim_step")
def step(self, closure=None):
"""Performs a single optimization step.

View File

@@ -40,10 +40,12 @@ from internlm.utils.parallel import (
sync_model_param_within_tp,
)
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
@llm_timeout(func_name="initialize_model")
def initialize_model():
"""
Initialize model.
@@ -88,6 +90,7 @@ def initialize_model():
return model
@llm_timeout(func_name="initialize_optimizer")
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
"""
Initialize optimizer.
@@ -124,6 +127,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
return optimizer, beta2_scheduler, lr_scheduler
@llm_timeout(func_name="get_train_data_loader")
def get_train_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
):
@@ -196,6 +200,7 @@ def get_train_data_loader(
return train_dl, dataset_types
@llm_timeout(func_name="get_validation_data_loader")
def get_validation_data_loader(
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
):
@@ -257,6 +262,7 @@ def get_validation_data_loader(
return val_dls
@llm_timeout(func_name="load_new_batch")
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
"""
Load and return the new batch data based on training data loader.
@@ -314,6 +320,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
)
@llm_timeout(func_name="record_current_batch_training_metrics")
def record_current_batch_training_metrics(
get_tflops_func,
logger,

View File

@@ -84,7 +84,7 @@ def initialize_uniscale_logger(
job_name and launch_time and file_name
), "If file_path is None, job_name, launch_time and file_name must be setted."
log_file_name = file_name
log_folder = os.path.join(job_name, launch_time, "logs")
log_folder = os.path.join("RUN", job_name, launch_time, "logs")
log_dir = os.path.join(log_folder, log_file_name)
file_path = log_dir
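
With this change, uniscale log files are rooted under a top-level RUN/ directory. A quick illustration of the resulting path; the job name, launch time and file name below are made-up values:

import os

print(os.path.join("RUN", "demo_job", "2023-09-07_17-26-17", "logs", "rank_0.log"))
# -> RUN/demo_job/2023-09-07_17-26-17/logs/rank_0.log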

View File

@@ -33,6 +33,7 @@ from internlm.utils.storage_manager import (
llm_save,
try_get_storage_backend,
)
from internlm.utils.timeout import llm_timeout
logger = get_logger(__file__)
@@ -727,6 +728,7 @@ now step_count is {train_state.step_count}",
if load_content_str:
logger.info(f"===========Load contents are: {load_content_str}")
@llm_timeout(func_name="save_checkpoint")
def save_checkpoint(
self,
folder,

View File

@@ -1,4 +1,13 @@
import datetime
import os
import signal
import socket
import traceback
from functools import wraps
from internlm.utils.logger import get_logger
logger = get_logger(__file__)
class Timeout:
@@ -24,3 +33,81 @@
    def __exit__(self, error_type, value, traceback):
        signal.alarm(0)
ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)


timeout_threshold_dict = {
    "initialize_distributed_env": 120,
    "nopp_forward_backward_step": 360,
    "initialize_model": 10,
    "initialize_optimizer": 20,
    "optim_step": 30,
    "get_train_data_loader": 600,
    "get_validation_data_loader": 60,
    "load_new_batch": 10,
    "record_current_batch_training_metrics": 10,
    "save_checkpoint": 1200,
    "interleaved_forward_backward_step": 600,
    "nointerleaved_forward_backward_step": 600,
}

if ENABLE_TIMEOUT is not None:
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
else:
    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)


def try_get_gpc_rank():
    try:
        from internlm.core.context import global_context as gpc

        rank = gpc.get_global_rank()
    except:  # noqa # pylint: disable=bare-except
        rank = "unknown"

    return f"host-{socket.gethostname()}-rank-{rank}"
def llm_timeout(seconds=0, func_name=None):
    """Timeout decorator. Note that this decorator is not reentrant:
    a nested use would reset the alarm signal.

    Args:
        seconds (int, optional): timeout threshold in seconds. Defaults to 0 and is
            overridden by ``timeout_threshold_dict`` when ``func_name`` is registered there.
        func_name (str, optional): name of the guarded function, used to look up its
            threshold. Defaults to the decorated function's ``__name__``.
    """

    def decorator(func):
        nonlocal func_name
        if func_name is None:
            func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle_timeout(signum, frame):
                raise TimeoutError

            nonlocal seconds
            seconds = timeout_threshold_dict.get(func_name, seconds)

            if seconds > 0:
                signal.signal(signal.SIGALRM, _handle_timeout)
                signal.alarm(seconds)
            try:
                result = func(*args, **kwargs)
            except TimeoutError as e:
                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\n {traceback.format_exc()}")
                raise e
            finally:
                signal.alarm(0)
            return result

        return wrapper

    return decorator
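
A hedged usage sketch for the decorator above; `train_step` is an illustrative name, not a function from this repository. Because the effective threshold is looked up as timeout_threshold_dict.get(func_name, seconds), a registered name such as "optim_step" is governed by INTERNLM_ENABLE_TIMEOUT (its entry is zeroed when the flag is absent), whereas an unregistered name falls back to the explicit seconds argument. SIGALRM is POSIX-only and must run in the main thread.

import time

from internlm.utils.timeout import llm_timeout


@llm_timeout(seconds=5, func_name="train_step")  # unregistered name -> 5 s fallback applies
def train_step():
    time.sleep(10)  # stands in for a stalled step


try:
    train_step()
except TimeoutError:
    print("train_step exceeded its 5 s budget")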

View File

@@ -127,12 +127,12 @@ def reset_seed():
@pytest.fixture(scope="module")
def init_dist_and_model():
def init_dist_and_model(rank=0, world_size=1):
from internlm.initialize import initialize_distributed_env
os.environ["RANK"] = "0"
os.environ["LOCAL_RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

View File

@@ -0,0 +1,119 @@
import fcntl
import os
import time
from multiprocessing import Process
import pytest
import torch
import torch.distributed as dist
os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1" # noqa # pylint: disable=wrong-import-position
os.environ["NCCL_TIMEOUT"] = "5"
from internlm.utils.timeout import llm_timeout
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
init_config,
)
WORLD_SIZE = 2
@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
time.sleep(10)
@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
# see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
# 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication.
buff = torch.ones([64, 64]).cuda(rank)
dist.all_reduce(buff) # lazy communicator init
torch.cuda.synchronize()
if rank == 0:
dist.all_reduce(buff)
torch.cuda.synchronize() # the main thread will hang here.
else:
time.sleep(9999)
@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
if rank == 1:
time.sleep(5)
with open(stop_file_path, "r", encoding="utf-8") as f:
fcntl.flock(f, fcntl.LOCK_EX) # rank 1 hang.
if rank == 0:
time.sleep(99999) # rank 0 hang.
f.seek(0)
f.read()
fcntl.flock(f, fcntl.LOCK_UN)
def local_timeout(rank, _):
try:
fake_timeout_func()
except TimeoutError as e:
print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
else:
assert False, "It should timeout!"
def gpc_timeout(rank, world_size):
from internlm.initialize import initialize_distributed_env
os.environ["RANK"] = str(rank)
os.environ["LOCAL_RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "12377"
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
try:
nccl_timeout_func(rank)
except TimeoutError as e:
print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
time.sleep(5) # wait rank 0 to be killed
else:
time.sleep(5) # give some time to let Watchdog kill rank 0.
assert False, "It should timeout!"
def file_lock_timeout(rank, _, stop_file_path):
if rank == 0:
with open(stop_file_path, "w"):
pass
try:
try_file_lock(rank, stop_file_path)
except TimeoutError as e:
print(e, flush=True)
else:
assert False, "It should timeout!"
finally:
if rank == 0:
os.remove(stop_file_path)
timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]
@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
timeout_func, world_size, other_args = timeout_func_and_args
procs = []
for i in range(world_size):
if other_args is None:
args = (i, world_size)
else:
args = (i, world_size, other_args)
proc = Process(target=timeout_func, args=args)
proc.start()
procs.append(proc)
for proc in procs:
proc.join(15)
if proc.is_alive():
proc.terminate()
proc.join()
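
Taken together, the three parametrized cases cover the failure modes the wrapper is meant to catch: local_timeout exercises the plain SIGALRM path, gpc_timeout checks that a rank stuck in an NCCL collective is torn down by the watchdog (NCCL_TIMEOUT is set to 5 seconds at the top of the file) while its peer's llm_timeout fires, and file_lock_timeout verifies that a hang outside any collective, blocking on a file lock, is also converted into a TimeoutError. The parent test gives each child process 15 seconds to join and terminates stragglers, so the suite itself cannot hang.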