mirror of https://github.com/InternLM/InternLM

feat: add runtime diag (#297)

* feat: add runtime diag
* add diag_outlier_ratio

---------

Co-authored-by: yingtongxiong <974106207@qq.com>

parent 06807a6fd5
commit 1ee31ff9b1

@@ -56,6 +56,8 @@ data = dict(
    min_length=50,
    # train_folder=TRAIN_FOLDER,
    # valid_folder=VALID_FOLDER,
    empty_cache_and_diag_interval=10,
    diag_outlier_ratio=1.1,
)

grad_scaler = dict(

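For reference, a minimal sketch of how these two new knobs read in a config. Everything except the two added options is an assumption chosen only for illustration:

# Hypothetical excerpt of a training config; only the two options added by this
# commit are real, the surrounding field is an assumption for illustration.
data = dict(
    seq_len=2048,                       # assumed, not part of this change
    empty_cache_and_diag_interval=10,   # run empty_cache + diagnostics every 10 batches
    diag_outlier_ratio=1.1,             # flag a rank if it is >10% slower than the average
)
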
@@ -98,6 +98,13 @@ def args_sanity_check():
    if "valid_every" not in data:
        data._add_item("valid_every", 0)

    if "empty_cache_and_diag_interval" not in data:
        data._add_item("empty_cache_and_diag_interval", 50)

    if "diag_outlier_ratio" not in data:
        data._add_item("diag_outlier_ratio", 1.1)
    data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)

    if gpc.is_rank_for_log():
        logger.info("+" * 15 + " Data Info " + "+" * 15)  # pylint: disable=W1201
        logger.info(f"seq_len: {data.seq_len}")

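A plain-Python sketch of the defaulting behavior above. The real code uses the gpc Config object and its _add_item helper; here a plain dict and a made-up helper name stand in for them:

def apply_diag_defaults(data: dict) -> dict:
    """Mirror of the sanity-check defaults, using a plain dict as a stand-in."""
    data.setdefault("empty_cache_and_diag_interval", 50)
    data.setdefault("diag_outlier_ratio", 1.1)
    # clamp: a ratio below 1 would flag perfectly normal ranks as outliers
    data["diag_outlier_ratio"] = max(1, data["diag_outlier_ratio"])
    return data


print(apply_diag_defaults({"diag_outlier_ratio": 0.8}))
# -> {'diag_outlier_ratio': 1, 'empty_cache_and_diag_interval': 50}
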
@@ -570,6 +570,7 @@ class HybridZeroOptimizer(BaseOptimizer):

        # check for overflow
        found_inf = False
        found_nan = False
        # if there are INF values in grads, the compute_norm func would also return -1
        # thus, we try to avoid calling _check_overflow here
        # found_inf = self._check_overflow()
@@ -578,9 +579,13 @@ class HybridZeroOptimizer(BaseOptimizer):
        if -1 in norms.values():
            found_inf = True

        if -2 in norms.values():
            found_nan = True

        loss_scale = float(self.loss_scale.item())  # backup
        if gpc.config.model.dtype is not torch.float32:
            self.grad_scaler.update(found_inf)

        # update loss scale if overflow occurs
        if found_inf:
            if gpc.is_rank_for_log():
@@ -593,6 +598,17 @@ class HybridZeroOptimizer(BaseOptimizer):
            self.zero_grad()
            return False, norms

        if found_nan:
            if gpc.is_rank_for_log():
                logger.warning("Nan grad norm occurs, please check it.")
                send_alert_message(
                    address=gpc.config.monitor.alert.feishu_alert_address,
                    message="Nan grad norm occurs, please check it.",
                )
            self._grad_store._averaged_gradients = dict()
            self.zero_grad()
            return False, norms

        # copy the grad of fp16 param to fp32 param
        single_grad_partition_groups = []
        for group_id in range(self.num_param_groups):

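The norm dict uses sentinel values to signal bad gradients. A small sketch of the resulting step/skip decision; the function and return strings are illustrative, not the optimizer's API:

def step_decision(norms: dict) -> str:
    """Illustrative decision logic: -1 in any norm means inf, -2 means NaN."""
    if -1 in norms.values():
        return "skip step, let the grad scaler shrink the loss scale"
    if -2 in norms.values():
        return "skip step, send an alert, keep the loss scale"
    return "proceed: unscale, clip, and apply the update"


print(step_decision({0: 1.7, 1: -2}))  # skip step, send an alert, keep the loss scale
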
@@ -311,6 +311,9 @@ def compute_norm(gradients, parameters, last_stage=False, previous_norm=None, no
    if total_norm == float("inf") or total_norm == -float("inf"):
        total_norm = -1

    if math.isnan(total_norm):
        total_norm = -2

    return total_norm


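A standalone sketch of the sentinel encoding compute_norm now applies (inf maps to -1 as before, NaN to -2); the helper name is made up for this example:

import math


def encode_total_norm(total_norm: float) -> float:
    """Map inf -> -1 and NaN -> -2, otherwise return the norm unchanged."""
    if total_norm == float("inf") or total_norm == -float("inf"):
        return -1.0
    if math.isnan(total_norm):
        return -2.0
    return total_norm


assert encode_total_norm(float("inf")) == -1.0
assert encode_total_norm(float("nan")) == -2.0
assert encode_total_norm(3.5) == 3.5
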
@@ -354,6 +354,7 @@ def record_current_batch_training_metrics(

    set_env_var(key="LAST_ACTIVE_TIMESTAMP", value=int(time.time()))

    timer.store_last_timers()
    if success_update in (0, True):
        train_state.num_consumed_tokens += batch[1].nelement() * gpc.get_world_size(ParallelMode.DATA)
    if is_no_pp_or_last_stage():

@@ -9,7 +9,9 @@ import torch.distributed as dist
from flash_attn.modules.mha import FlashSelfAttention, SelfAttention
from torch.utils import benchmark

from internlm.monitor import send_alert_message
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer

try:
    import GPUtil
@@ -24,6 +26,23 @@ from internlm.utils.common import get_current_device
logger = get_logger(__file__)


def empty_cache_and_diag(batch_count, interval=50):
    """empty cuda cache and run diag bench or tests."""
    if interval <= 0:
        interval = 50
    if batch_count % int(interval) == 0:
        # there is no need to do diag on the first batch
        if batch_count > 0:
            if gpc.is_rank_for_log():
                logger.info("Empty Cache and Diagnosis GPU/NCCL/Timer ...")
            with torch.no_grad():
                timer_diagnosis()
                bench_gpu()
                bench_net()
        # do empty_cache after the bench
        torch.cuda.empty_cache()


def benchmark_forward(
    test_fn,
    *inputs,
@@ -81,14 +100,78 @@ def get_cpu_temperature():
    return cpu_temperature


def timer_diagnosis():
    """Diagnose running time."""

    if len(timer.names) == 0 or len(timer.times) == 0:
        return

    world_size = gpc.get_world_size(ParallelMode.DATA)
    if world_size < 2:
        return

    # if gpc.is_rank_for_log():
    #     logger.info("Diagnosis running timers ...")

    # detect slow rank compared to other ranks in the same DP group
    running_time = torch.Tensor(timer.times).to(device=get_current_device())
    avg_time = running_time.detach().clone()
    if world_size <= 4:
        dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.AVG, group=gpc.get_group(ParallelMode.DATA))
    else:
        running_time_max = avg_time.detach().clone()
        running_time_min = avg_time.detach().clone()
        dist.all_reduce(running_time_max, op=torch.distributed.ReduceOp.MAX, group=gpc.get_group(ParallelMode.DATA))
        dist.all_reduce(running_time_min, op=torch.distributed.ReduceOp.MIN, group=gpc.get_group(ParallelMode.DATA))
        dist.all_reduce(avg_time, op=torch.distributed.ReduceOp.SUM, group=gpc.get_group(ParallelMode.DATA))
        avg_time = (avg_time - running_time_max - running_time_min) / (world_size - 2)

    diag_result = running_time > avg_time * gpc.config.data.diag_outlier_ratio
    diag_result = diag_result.tolist()
    avg_time = avg_time.tolist()

    for slow, name, time, avg in zip(diag_result, timer.names, timer.times, avg_time):
        if slow is False or avg < 0.5:
            continue
        msg = (
            f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than avg on {name}, "
            f"Hostname {socket.gethostname()}, "
            f"its time {time:.2f}, avg {avg:.2f}, "
            f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
        )
        logger.warning(msg)
        send_alert_message(
            address=gpc.config.monitor.alert.feishu_alert_address,
            message=msg,
        )

    # detect slow rank compared to historical timer data
    for name, time in zip(timer.names, timer.times):
        if name not in timer.hist or len(timer.hist[name]) < 5:
            continue
        hist_avg = sum(timer.hist[name]) / len(timer.hist[name])
        if time > hist_avg * gpc.config.data.diag_outlier_ratio and time > 0.5:
            msg = (
                f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} is slower than hist avg on {name}, "
                f"Hostname {socket.gethostname()}, "
                f"its time {time:.2f}, hist_avg {hist_avg:.2f}, "
                f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
            )
            logger.warning(msg)
            send_alert_message(
                address=gpc.config.monitor.alert.feishu_alert_address,
                message=msg,
            )


def bench_net():
    """Benchmark nccl performance for slow node detection."""

    if gpc.get_world_size(ParallelMode.GLOBAL) <= 1:
        return

    if gpc.is_rank_for_log():
        logger.info("benchmarking network speed ...")
    # if gpc.is_rank_for_log():
    #     logger.info("benchmarking network speed ...")

    repeats = 100
    input_data = torch.randn(
@@ -113,20 +196,25 @@ def bench_net():
    allreduce_time_avg = allreduce_time / gpc.get_world_size(ParallelMode.GLOBAL)
    allreduce_time_avg = float(allreduce_time_avg.item())

    if allreduce_time_this >= allreduce_time_avg * 1.05:
        logger.warning(
    if allreduce_time_this >= allreduce_time_avg * gpc.config.data.diag_outlier_ratio:
        msg = (
            f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} NCCL test is slower than avg, "
            f"Hostname {socket.gethostname()}, "
            f"allreduce_time {allreduce_time_this:.2f}, avg {allreduce_time_avg:.2f}, "
            f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
        )
        logger.warning(msg)
        send_alert_message(
            address=gpc.config.monitor.alert.feishu_alert_address,
            message=msg,
        )


def bench_gpu(use_flash_attn=True):
    """Benchmark single GPU performance for slow node detection."""

    if gpc.is_rank_for_log():
        logger.info("benchmarking gpu speed ...")
    # if gpc.is_rank_for_log():
    #     logger.info("benchmarking gpu speed ...")

    headdim = 64
    dim = 2048
@@ -154,10 +242,15 @@ def bench_gpu(use_flash_attn=True):
    speed_avg = speed / gpc.get_world_size(ParallelMode.GLOBAL)
    speed_avg = float(speed_avg.item())

    if speed_this <= speed_avg * 0.95:
        logger.warning(
    if speed_this <= speed_avg / gpc.config.data.diag_outlier_ratio:
        msg = (
            f"Rank {gpc.get_local_rank(ParallelMode.GLOBAL)} GPU is slower than avg, "
            f"Hostname {socket.gethostname()}, "
            f"tflops {speed_this:.2f}, avg {speed_avg:.2f}, "
            f"CPU temp {get_cpu_temperature()}, GPU temp {get_gpu_temperature()}"
        )
        logger.warning(msg)
        send_alert_message(
            address=gpc.config.monitor.alert.feishu_alert_address,
            message=msg,
        )

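The slow-rank check in timer_diagnosis compares each rank's timer against a (trimmed) average over the data-parallel group; the real code also skips timers whose average is under 0.5 s. A torch-free sketch of the comparison, assuming each list entry is one rank's time for the same timer:

def is_slow_rank(my_time: float, all_times: list, outlier_ratio: float = 1.1) -> bool:
    """Trimmed-mean outlier test: for more than 4 ranks, drop one max and one min first."""
    n = len(all_times)
    if n <= 4:
        avg = sum(all_times) / n
    else:
        avg = (sum(all_times) - max(all_times) - min(all_times)) / (n - 2)
    return my_time > avg * outlier_ratio


times = [1.00, 1.02, 0.98, 1.01, 1.40, 1.03]
print([is_slow_rank(t, times) for t in times])
# -> [False, False, False, False, True, False]  (only the 1.40 s rank is flagged)
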
@@ -16,8 +16,12 @@ class _Timer:
        self.start_time = time.time()
        self.stream = torch.cuda.current_stream()

    def start(self):
    def start(self, reset_all=True):
        """Start the timer."""
        # need to reset all timers in a new batch
        if self.name_ == "one-batch" and reset_all is True:
            megatron_timer.reset()

        assert not self.started_, "timer has already been started"
        self.stream.synchronize()
        self.start_time = time.time()
@@ -48,7 +52,7 @@ class _Timer:
            self.reset()
        # If timing was in progress, set it back.
        if started_:
            self.start()
            self.start(reset_all=False)
        return elapsed_


@@ -57,12 +61,29 @@ class Timers:

    def __init__(self):
        self.timers = {}
        self.hist = {}
        self.names = []
        self.times = []

    def __call__(self, name):
        if name not in self.timers:
            self.timers[name] = _Timer(name)
        return self.timers[name]

    def store_last_timers(self):
        """Store timers to two lists."""
        self.names = []
        self.times = []
        for key, value in self.timers.items():
            seconds = round(float(value.elapsed(reset=False)), 4)
            self.names.append(key)
            self.times.append(seconds)
            if key not in self.hist:
                self.hist[key] = []
            self.hist[key].append(seconds)
            if len(self.hist[key]) > 10:
                self.hist[key].pop(0)

    def write(self, names, writer, iteration, normalizer=1.0, reset=False):
        """Write timers to a tensorboard writer"""
        # currently when using add_scalars,

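store_last_timers keeps a short rolling history per timer name (the last 10 values), which the historical check in timer_diagnosis consumes once at least 5 samples exist. A sketch of that window; a deque is used for brevity where the real code pops from a plain list:

from collections import deque

hist = {}


def record(name: str, seconds: float) -> None:
    """Append a timing sample, keeping only the 10 most recent per name."""
    hist.setdefault(name, deque(maxlen=10)).append(round(seconds, 4))


for step in range(15):
    record("one-batch", 1.0 + 0.01 * step)

print(len(hist["one-batch"]), hist["one-batch"][0])  # 10 1.05 -- oldest kept sample
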
							
								
								
									
train.py

@@ -35,6 +35,7 @@ from internlm.utils.common import (
    parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import empty_cache_and_diag
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@@ -193,9 +194,7 @@ def main(args):
    with initialize_llm_profile(profiling=args.profiling, start_time=current_time) as prof:
        # start iterating the train data and begin training
        for batch_count in range(train_state.batch_count, total_steps):
            if batch_count % 50 == 0:
                torch.cuda.empty_cache()

            empty_cache_and_diag(batch_count, interval=gpc.config.data.empty_cache_and_diag_interval)
            start_time = time.time()
            timer("one-batch").start()

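At the call site the diagnostic pass is gated purely by batch count. A tiny sketch of which batches trigger it for a given interval (batch 0 only empties the CUDA cache); the helper name is made up for this example:

def runs_diag(batch_count: int, interval: int = 50) -> bool:
    """True when empty_cache_and_diag would run the benchmark/diagnosis pass."""
    if interval <= 0:
        interval = 50
    return batch_count % interval == 0 and batch_count > 0


print([b for b in range(41) if runs_diag(b, interval=10)])  # -> [10, 20, 30, 40]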