feat(utils): add timeout wrapper for key functions (#286)

pull/298/head
Guoteng 2023-09-07 17:26:17 +08:00 committed by GitHub
parent 7f687bf4b3
commit 37b8c6684e
12 changed files with 280 additions and 19 deletions

View File

@@ -18,6 +18,7 @@ import torch.distributed as dist
 from internlm.utils.common import SingletonMeta
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import LLM_NCCL_TIMEOUT

 from . import process_group_initializer as pgroup_initializer
 from .process_group_initializer import ParallelMode

@@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
         """
         # initialize the default process group
         init_method = f"tcp://[{host}]:{port}"
-        dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
+        dist.init_process_group(
+            rank=rank,
+            world_size=world_size,
+            backend=backend,
+            init_method=init_method,
+            timeout=LLM_NCCL_TIMEOUT,
+        )

         # None will give the default global process group for pytorch dist operations
         ranks = list(range(world_size))
         if use_cpu:
-            cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
+            cpu_group = (
+                dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                if dist.get_backend() != "gloo"
+                else None
+            )
         else:
             cpu_group = None
         self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
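
For context on the timeout argument threaded through the hunks above: PyTorch's init_process_group and new_group accept a datetime.timedelta that bounds how long a collective may block, and with NCCL_ASYNC_ERROR_HANDLING=1 the NCCL watchdog aborts a communicator that exceeds it instead of letting the job hang. A minimal standalone sketch of that pattern (the helper name, address, and port are illustrative, not part of this commit):

# Minimal sketch: bounding NCCL collectives with an explicit process-group timeout.
# Assumes a per-rank launcher; helper name, address, and port below are illustrative.
import datetime
import os

import torch
import torch.distributed as dist


def init_with_timeout(rank: int, world_size: int, seconds: int = 60) -> None:
    # Surface NCCL errors through the watchdog instead of hanging indefinitely.
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    dist.init_process_group(
        backend="nccl",
        rank=rank,
        world_size=world_size,
        init_method="tcp://127.0.0.1:29500",
        timeout=datetime.timedelta(seconds=seconds),
    )
    # A collective that blocks longer than `seconds` is aborted and the process
    # sees an error, rather than waiting forever on a missing peer.
    buff = torch.ones(1, device=f"cuda:{rank}")
    dist.all_reduce(buff)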

View File

@@ -9,6 +9,8 @@ from enum import Enum

 import torch.distributed as dist

+from internlm.utils.timeout import LLM_NCCL_TIMEOUT
+
 # parallel modes
 class ParallelMode(Enum):

@@ -109,9 +111,13 @@ class Initializer_Data(ProcessGroupInitializer):
         for i in range(self.rank_num_per_dp_group):
             ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None

@@ -161,9 +167,13 @@ class Initializer_Model(ProcessGroupInitializer):
         for i in range(self.num_group):
             ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None

@@ -221,9 +231,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
                 )
             )
             pipe_group_size = len(ranks)
-            pipe_group = dist.new_group(ranks)
+            pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else pipe_group
+                )
             else:
                 group_cpu = None

@@ -271,9 +285,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
         for i in range(self.num_tensor_parallel_group):
             ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None

@@ -327,9 +345,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
                     i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
                     for k in range(self.zero1_parallel_size)
                 ]
-                group = dist.new_group(ranks)
+                group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
                 if use_cpu:
-                    group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                    group_cpu = (
+                        dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                        if dist.get_backend() != "gloo"
+                        else group
+                    )
                 else:
                     group_cpu = None

@@ -376,9 +398,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
                 rank = i * self.nettest_parallel_size + j
                 if rank < self.world_size:
                     ranks.append(rank)
-            group = dist.new_group(ranks)
+            group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
             if use_cpu:
-                group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
+                group_cpu = (
+                    dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
+                    if dist.get_backend() != "gloo"
+                    else group
+                )
             else:
                 group_cpu = None

View File

@@ -9,6 +9,7 @@ import torch

 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
+from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

@@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):

         return output, loss

+    @llm_timeout(func_name="nopp_forward_backward_step")
     def forward_backward_step(
         self,
         engine: Engine,

View File

@@ -15,6 +15,7 @@ from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

@@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):

         return output, label, accum_loss

+    @llm_timeout(func_name="nointerleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
         Returns a tuple with losses if the last stage, an empty tuple otherwise.

@@ -1248,6 +1250,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
         # 3. Cooldown
         self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)

+    @llm_timeout(func_name="interleaved_forward_backward_step")
     def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
         """Run interleaved 1F1B schedule (model split into model chunks), with
         communication between pipeline stages as needed.

View File

@@ -13,6 +13,7 @@ from internlm.core.context import global_context as gpc
 from internlm.monitor import initialize_light_monitor
 from internlm.utils.common import get_master_node
 from internlm.utils.logger import get_logger
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)

@@ -410,6 +411,7 @@ def launch_from_torch(
     )


+@llm_timeout(func_name="initialize_distributed_env")
 def initialize_distributed_env(
     config: str,
     launcher: str = "slurm",

View File

@@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
 from internlm.utils.common import get_current_device
 from internlm.utils.logger import get_logger
 from internlm.utils.megatron_timers import megatron_timer as timer
+from internlm.utils.timeout import llm_timeout

 from .utils import compute_norm

@@ -506,6 +507,7 @@ class HybridZeroOptimizer(BaseOptimizer):

         return norm

+    @llm_timeout(func_name="optim_step")
     def step(self, closure=None):
         """Performs a single optimization step.
View File

@@ -40,10 +40,12 @@ from internlm.utils.parallel import (
     sync_model_param_within_tp,
 )
 from internlm.utils.registry import MODEL_INITIALIZER
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)


+@llm_timeout(func_name="initialize_model")
 def initialize_model():
     """
     Initialize model.

@@ -88,6 +90,7 @@ def initialize_model():

     return model


+@llm_timeout(func_name="initialize_optimizer")
 def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
     """
     Initialize optimizer.

@@ -124,6 +127,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):

     return optimizer, beta2_scheduler, lr_scheduler


+@llm_timeout(func_name="get_train_data_loader")
 def get_train_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
 ):

@@ -196,6 +200,7 @@ def get_train_data_loader(

     return train_dl, dataset_types


+@llm_timeout(func_name="get_validation_data_loader")
 def get_validation_data_loader(
     num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
 ):

@@ -257,6 +262,7 @@ def get_validation_data_loader(

     return val_dls


+@llm_timeout(func_name="load_new_batch")
 def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
     """
     Load and return the new batch data based on training data loader.

@@ -314,6 +320,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
     )


+@llm_timeout(func_name="record_current_batch_training_metrics")
 def record_current_batch_training_metrics(
     get_tflops_func,
     logger,

View File

@@ -84,7 +84,7 @@ def initialize_uniscale_logger(
             job_name and launch_time and file_name
         ), "If file_path is None, job_name, launch_time and file_name must be setted."
         log_file_name = file_name
-        log_folder = os.path.join(job_name, launch_time, "logs")
+        log_folder = os.path.join("RUN", job_name, launch_time, "logs")
         log_dir = os.path.join(log_folder, log_file_name)
         file_path = log_dir

View File

@@ -33,6 +33,7 @@ from internlm.utils.storage_manager import (
     llm_save,
     try_get_storage_backend,
 )
+from internlm.utils.timeout import llm_timeout

 logger = get_logger(__file__)

@@ -727,6 +728,7 @@ now step_count is {train_state.step_count}",
         if load_content_str:
             logger.info(f"===========Load contents are: {load_content_str}")

+    @llm_timeout(func_name="save_checkpoint")
     def save_checkpoint(
         self,
         folder,

View File

@@ -1,4 +1,13 @@
+import datetime
+import os
 import signal
+import socket
+import traceback
+from functools import wraps
+
+from internlm.utils.logger import get_logger
+
+logger = get_logger(__file__)


 class Timeout:

@@ -24,3 +33,81 @@ class Timeout:

     def __exit__(self, error_type, value, traceback):
         signal.alarm(0)
+
+
+ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)
+
+
+timeout_threshold_dict = {
+    "initialize_distributed_env": 120,
+    "nopp_forward_backward_step": 360,
+    "initialize_model": 10,
+    "initialize_optimizer": 20,
+    "optim_step": 30,
+    "get_train_data_loader": 600,
+    "get_validation_data_loader": 60,
+    "load_new_batch": 10,
+    "record_current_batch_training_metrics": 10,
+    "save_checkpoint": 1200,
+    "interleaved_forward_backward_step": 600,
+    "nointerleaved_forward_backward_step": 600,
+}
+
+if ENABLE_TIMEOUT is not None:
+    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
+else:
+    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
+    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)
+
+
+def try_get_gpc_rank():
+    try:
+        from internlm.core.context import global_context as gpc
+
+        rank = gpc.get_global_rank()
+    except:  # noqa  # pylint: disable=bare-except
+        rank = "unknown"
+
+    return f"host-{socket.gethostname()}-rank-{rank}"
+
+
+def llm_timeout(seconds=0, func_name=None):
+    """Timeout decorator. Note that this decorator is not reentrant:
+    nesting decorated calls resets the outer alarm.
+
+    Args:
+        seconds (int, optional): timeout threshold in seconds. Defaults to 0; a value
+            registered in timeout_threshold_dict takes precedence, and 0 disables the alarm.
+        func_name (str, optional): name of the guarded function, used to look up its
+            threshold. Defaults to the decorated function's __name__.
+    """
+
+    def decorator(func):
+        nonlocal func_name
+        if func_name is None:
+            func_name = func.__name__
+
+        @wraps(func)
+        def wrapper(*args, **kwargs):
+            def _handle_timeout(signum, frame):
+                raise TimeoutError
+
+            nonlocal seconds
+            seconds = timeout_threshold_dict.get(func_name, seconds)
+
+            if seconds > 0:
+                signal.signal(signal.SIGALRM, _handle_timeout)
+                signal.alarm(seconds)
+
+            try:
+                result = func(*args, **kwargs)
+            except TimeoutError as e:
+                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\n {traceback.format_exc()}")
+                raise e
+            finally:
+                signal.alarm(0)
+
+            return result
+
+        return wrapper
+
+    return decorator
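
The new utility ties the pieces together: when INTERNLM_ENABLE_TIMEOUT is set, every function registered in timeout_threshold_dict gets a SIGALRM deadline and process groups inherit LLM_NCCL_TIMEOUT; when it is unset, all thresholds collapse to 0 and the decorator becomes a no-op. A minimal standalone sketch of the same SIGALRM pattern (names and thresholds below are illustrative, not InternLM's table):

# Minimal sketch of a SIGALRM-based timeout decorator; illustrative names and thresholds.
import signal
import time
from functools import wraps

THRESHOLDS = {"slow_step": 2}  # seconds; 0 disables the alarm


def with_timeout(default_seconds=0, name=None):
    def decorator(func):
        key = name or func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            def _raise(signum, frame):
                raise TimeoutError(f"{key} exceeded its deadline")

            seconds = THRESHOLDS.get(key, default_seconds)
            if seconds > 0:
                signal.signal(signal.SIGALRM, _raise)  # main thread only; not reentrant
                signal.alarm(seconds)
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # always clear any pending alarm

        return wrapper

    return decorator


@with_timeout(name="slow_step")
def slow_step():
    time.sleep(10)  # interrupted after 2 seconds


if __name__ == "__main__":
    try:
        slow_step()
    except TimeoutError as exc:
        print(exc)

Because signal.alarm only fires in the main thread and cannot interrupt a native call that never returns to the interpreter, the NCCL-level timeout added elsewhere in this commit acts as a second line of defense for collectives that hang inside the backend.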

View File

@@ -127,12 +127,12 @@ def reset_seed():

 @pytest.fixture(scope="module")
-def init_dist_and_model():
+def init_dist_and_model(rank=0, world_size=1):
     from internlm.initialize import initialize_distributed_env

-    os.environ["RANK"] = "0"
-    os.environ["LOCAL_RANK"] = "0"
-    os.environ["WORLD_SIZE"] = "1"
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
     os.environ["MASTER_ADDR"] = "127.0.0.1"
     os.environ["MASTER_PORT"] = "12377"
     initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

View File

@@ -0,0 +1,119 @@
+import fcntl
+import os
+import time
+from multiprocessing import Process
+
+import pytest
+import torch
+import torch.distributed as dist
+
+os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa  # pylint: disable=wrong-import-position
+os.environ["NCCL_TIMEOUT"] = "5"
+from internlm.utils.timeout import llm_timeout
+from tests.test_utils.common_fixture import (  # noqa  # pylint: disable=unused-import
+    init_config,
+)
+
+WORLD_SIZE = 2
+
+
+@llm_timeout(2, "fake_timeout_func")
+def fake_timeout_func():
+    time.sleep(10)
+
+
+@llm_timeout(10, "nccl_timeout_func")
+def nccl_timeout_func(rank):
+    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
+    # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication.
+    buff = torch.ones([64, 64]).cuda(rank)
+    dist.all_reduce(buff)  # lazy communicator init
+    torch.cuda.synchronize()
+    if rank == 0:
+        dist.all_reduce(buff)
+        torch.cuda.synchronize()  # the main thread will hang here.
+    else:
+        time.sleep(9999)
+
+
+@llm_timeout(10, "try_file_lock")
+def try_file_lock(rank, stop_file_path):
+    if rank == 1:
+        time.sleep(5)
+
+    with open(stop_file_path, "r", encoding="utf-8") as f:
+        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 hangs here.
+        if rank == 0:
+            time.sleep(99999)  # rank 0 hangs here.
+        f.seek(0)
+        f.read()
+        fcntl.flock(f, fcntl.LOCK_UN)
+
+
+def local_timeout(rank, _):
+    try:
+        fake_timeout_func()
+    except TimeoutError as e:
+        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
+    else:
+        assert False, "It should timeout!"
+
+
+def gpc_timeout(rank, world_size):
+    from internlm.initialize import initialize_distributed_env
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "12377"
+    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
+
+    try:
+        nccl_timeout_func(rank)
+    except TimeoutError as e:
+        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
+        time.sleep(5)  # wait for rank 0 to be killed.
+    else:
+        time.sleep(5)  # give the NCCL watchdog some time to kill rank 0.
+        assert False, "It should timeout!"
+
+
+def file_lock_timeout(rank, _, stop_file_path):
+    if rank == 0:
+        with open(stop_file_path, "w"):
+            pass
+
+    try:
+        try_file_lock(rank, stop_file_path)
+    except TimeoutError as e:
+        print(e, flush=True)
+    else:
+        assert False, "It should timeout!"
+    finally:
+        if rank == 0:
+            os.remove(stop_file_path)
+
+
+timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]
+
+
+@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
+def test_timeout(timeout_func_and_args):
+    timeout_func, world_size, other_args = timeout_func_and_args
+    procs = []
+    for i in range(world_size):
+        if other_args is None:
+            args = (i, world_size)
+        else:
+            args = (i, world_size, other_args)
+        proc = Process(target=timeout_func, args=args)
+        proc.start()
+        procs.append(proc)
+
+    for proc in procs:
+        proc.join(15)
+        if proc.is_alive():
+            proc.terminate()
+            proc.join()
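
The new test sets INTERNLM_ENABLE_TIMEOUT and NCCL_TIMEOUT before importing internlm.utils.timeout, because ENABLE_TIMEOUT is read at import time; it therefore needs to run in a fresh interpreter where that module is not yet loaded. A hypothetical standalone runner (the test-file path is inferred from the fixture import above and may not match the actual layout):

# Hypothetical runner; the test path is an assumption based on the fixture import.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(pytest.main(["-q", "-s", "tests/test_utils/test_timeout.py"]))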