mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			120 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			120 lines
		
	
	
		
			3.3 KiB
		
	
	
	
		
			Python
		
	
	
| import fcntl
 | |
| import os
 | |
| import time
 | |
| from multiprocessing import Process
 | |
| 
 | |
| import pytest
 | |
| import torch
 | |
| import torch.distributed as dist
 | |
| 
 | |
| os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa  # pylint: disable=wrong-import-position
 | |
| os.environ["NCCL_TIMEOUT"] = "5"
 | |
| from internlm.utils.timeout import llm_timeout
 | |
| from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
 | |
|     init_config,
 | |
| )
 | |
| 
 | |
# Presumably the default world size for the distributed scenarios in this module.
# NOTE(review): not referenced by the code visible here; test_timeout takes each
# scenario's process count from timeout_func_list instead — confirm before removing.
WORLD_SIZE: int = 2
 | |
| 
 | |
| 
 | |
@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
    """Sleep far longer than the 2-second budget handed to ``llm_timeout``.

    ``local_timeout`` calls this and expects the decorator to surface the overrun
    as ``TimeoutError``.
    """
    sleep_seconds = 10
    time.sleep(sleep_seconds)
 | |
| 
 | |
| 
 | |
@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
    """Deliberately deadlock an all-reduce so the timeout machinery has to fire.

    Every rank first joins one all-reduce. Per the linked PyTorch issue,
    ``NCCL_ASYNC_ERROR_HANDLING`` does not take effect on the first collective, so
    that call only initialises the communicator. After that, rank 0 issues a second
    all-reduce that no other rank joins, while the other ranks simply sleep.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
    """
    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
    tensor = torch.ones([64, 64]).cuda(rank)
    dist.all_reduce(tensor)  # lazy communicator init
    torch.cuda.synchronize()

    if rank != 0:
        # Non-zero ranks never join the second collective.
        time.sleep(9999)
        return

    dist.all_reduce(tensor)
    torch.cuda.synchronize()  # main thread will hang at here.
 | |
| 
 | |
| 
 | |
@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
    """Make rank 0 and rank 1 hang around an exclusive ``flock`` on one file.

    Rank 1 first sleeps 5 seconds, so rank 0 takes the lock with no delay. Rank 0
    then sleeps while holding the lock, and rank 1 blocks trying to acquire it.

    NOTE(review): both ranks share the same 10 s budget. If rank 0's timeout
    unwinds the ``with`` block (closing the file) before rank 1's timer fires,
    rank 1 may acquire the lock and return normally. Confirm against
    ``llm_timeout``'s implementation.

    Args:
        rank: Process rank (0 holds the lock, 1 contends for it).
        stop_file_path: Path of an already-existing file to lock.
    """
    is_holder = rank == 0
    is_waiter = rank == 1
    if is_waiter:
        time.sleep(5)

    with open(stop_file_path, "r", encoding="utf-8") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)  # rank 1 hang.
        if is_holder:
            time.sleep(99999)  # rank 0 hang.
        lock_file.seek(0)
        lock_file.read()
        fcntl.flock(lock_file, fcntl.LOCK_UN)
 | |
| 
 | |
| 
 | |
def local_timeout(rank, _):
    """Check that ``llm_timeout`` interrupts a purely local (non-distributed) call.

    Args:
        rank: Index of this worker process; only used in the log line.
        _: World size passed in by ``test_timeout``; unused here.

    Raises:
        AssertionError: If ``fake_timeout_func`` returns without timing out.
    """
    try:
        fake_timeout_func()
    except TimeoutError as e:
        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
    else:
        # ``raise`` rather than ``assert False``: assert statements are removed
        # under ``python -O``, which would silently disable this check.
        raise AssertionError("It should timeout!")
 | |
| 
 | |
| 
 | |
def gpc_timeout(rank, world_size):
    """Check that a hung NCCL collective is surfaced as a timeout.

    Sets up the torch-style rendezvous environment for this rank on localhost,
    initialises InternLM's distributed environment, then runs
    ``nccl_timeout_func``, which deadlocks rank 0 inside a collective. The sleeps
    below leave time for a watchdog to kill rank 0, as the inline comments note.

    Args:
        rank: Global (and local) rank of this worker process.
        world_size: Total number of worker processes.

    Raises:
        AssertionError: If ``nccl_timeout_func`` returns without timing out.
    """
    from internlm.initialize import initialize_distributed_env

    # Single source of truth: the env var and the explicit argument must agree.
    master_port = 12377

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = str(master_port)
    initialize_distributed_env(config=init_config, launcher="torch", master_port=master_port, args_check=False)

    try:
        nccl_timeout_func(rank)
    except TimeoutError as e:
        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
        time.sleep(5)  # wait rank 0 to be killed
    else:
        time.sleep(5)  # give some time to let Watchdog kill rank 0.
        # ``raise`` rather than ``assert False`` so the check survives ``python -O``.
        raise AssertionError("It should timeout!")
 | |
| 
 | |
| 
 | |
def file_lock_timeout(rank, _, stop_file_path):
    """Check that ``llm_timeout`` interrupts a call blocked on a file lock.

    Rank 0 creates ``stop_file_path`` up front and removes it afterwards. Every
    rank then runs ``try_file_lock`` and is expected to time out.

    Args:
        rank: Index of this worker process.
        _: World size passed in by ``test_timeout``; unused here.
        stop_file_path: Path of the file used as the lock target.

    Raises:
        AssertionError: If ``try_file_lock`` returns without timing out.
    """
    if rank == 0:
        # Create (or truncate) the empty lock file; the encoding matches how
        # try_file_lock reopens it.
        with open(stop_file_path, "w", encoding="utf-8"):
            pass
    try:
        try_file_lock(rank, stop_file_path)
    except TimeoutError as e:
        print(e, flush=True)
    else:
        # ``raise`` rather than ``assert False`` so the check survives ``python -O``.
        raise AssertionError("It should timeout!")
    finally:
        if rank == 0:
            os.remove(stop_file_path)
 | |
| 
 | |
| 
 | |
# Scenarios for test_timeout. Each entry is
# (worker function, number of worker processes, extra argument or None).
timeout_func_list = [
    (gpc_timeout, 2, None),
    (local_timeout, 1, None),
    (file_lock_timeout, 2, "test_lock.log"),
]
 | |
| 
 | |
| 
 | |
@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
    """Run one timeout scenario across ``world_size`` child processes.

    Each child receives ``(rank, world_size)``, plus the scenario's extra argument
    when one is given. The parent waits up to 15 seconds for each child in turn
    and terminates any that are still running.

    NOTE(review): child exit codes are never inspected, so an AssertionError
    raised inside a worker does not fail this test. A uniform ``exitcode == 0``
    check would not fit gpc_timeout, whose rank 0 is expected to be killed.
    """
    timeout_func, world_size, other_args = timeout_func_and_args
    extra = () if other_args is None else (other_args,)

    workers = [Process(target=timeout_func, args=(rank, world_size, *extra)) for rank in range(world_size)]
    for worker in workers:
        worker.start()

    for worker in workers:
        worker.join(15)
        if not worker.is_alive():
            continue
        worker.terminate()
        worker.join()
 |