mirror of https://github.com/InternLM/InternLM
				
				
				
			
		
			
				
	
	
		
			249 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
			
		
		
	
	
			249 lines
		
	
	
		
			7.2 KiB
		
	
	
	
		
			Python
		
	
	
| #!/usr/bin/env python
 | |
| # -*- encoding: utf-8 -*-
 | |
| 
 | |
| import bisect
 | |
| import inspect
 | |
| import os
 | |
| import random
 | |
| from contextlib import contextmanager
 | |
| from datetime import datetime
 | |
| from typing import Union
 | |
| 
 | |
| import numpy as np
 | |
| import torch
 | |
| 
 | |
| import internlm
 | |
| 
 | |
| CURRENT_TIME = None
 | |
| 
 | |
| 
 | |
def parse_args():
    """Parse the command line with the parser from ``internlm.get_default_parser()``.

    Returns:
        The namespace produced by ``parse_args()`` on that parser.
    """
    return internlm.get_default_parser().parse_args()
 | |
| 
 | |
| 
 | |
def get_master_node():
    """Return the first hostname of the current Slurm job's node list.

    The node list is read from ``$SLURM_JOB_NODELIST`` and expanded with
    ``scontrol show hostnames``; the first expanded hostname is returned.

    Returns:
        The first hostname printed by ``scontrol``, stripped of surrounding
        whitespace, or an empty string if ``scontrol`` printed nothing.

    Raises:
        RuntimeError: If ``$SLURM_JOB_ID`` is not set.
        subprocess.CalledProcessError: If ``scontrol`` exits with a non-zero
            status.
        OSError: If the ``scontrol`` executable cannot be started.
    """
    import subprocess

    if os.getenv("SLURM_JOB_ID") is None:
        raise RuntimeError("get_master_node can only used in Slurm launch!")

    # Run scontrol directly with an argument list instead of a shell pipeline:
    # piping into `head` made the pipeline's exit status that of `head`, so a
    # failing scontrol was silently reported as an empty hostname.
    output = subprocess.check_output(
        ["scontrol", "show", "hostnames", os.getenv("SLURM_JOB_NODELIST", "")]
    )
    hostnames = output.decode("utf8").splitlines()
    return hostnames[0].strip() if hostnames else ""
 | |
| 
 | |
| 
 | |
def get_process_rank():
    """Read this process's rank from the environment.

    ``$SLURM_PROCID`` is consulted first; ``$RANK`` (used in k8s environments)
    is the fallback.

    Returns:
        The rank as an ``int``, or ``-1`` when neither variable is set.
    """
    for variable in ("SLURM_PROCID", "RANK"):
        raw_rank = os.getenv(variable)
        if raw_rank is not None:
            return int(raw_rank)

    # No launcher variable found; callers receive the -1 sentinel.
    return -1
 | |
| 
 | |
| 
 | |
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    """Move a tensor norm onto the current CUDA device.

    Args:
        norm: A norm value, either a plain float or a tensor.

    Returns:
        ``norm`` unchanged when it is not a tensor or already lives on a CUDA
        device; otherwise a copy on ``torch.cuda.current_device()``.
    """
    if not torch.is_tensor(norm):
        return norm
    if norm.device.type == "cuda":
        return norm
    return norm.to(torch.cuda.current_device())
 | |
| 
 | |
| 
 | |
| def _move_tensor(element):
 | |
|     if not torch.is_tensor(element):
 | |
|         # we expecte the data type if a list of dictionaries
 | |
|         for item in element:
 | |
|             if isinstance(item, dict):
 | |
|                 for key, value in item.items():
 | |
|                     assert not value.is_cuda, "elements are already on devices."
 | |
|                     item[key] = value.to(get_current_device()).detach()
 | |
|             elif isinstance(item, list):
 | |
|                 for index, value in enumerate(item):
 | |
|                     assert not value.is_cuda, "elements are already on devices."
 | |
|                     item[index] = value.to(get_current_device()).detach()
 | |
|             elif torch.is_tensor(item):
 | |
|                 if not item.is_cuda:
 | |
|                     item = item.to(get_current_device()).detach()
 | |
|     else:
 | |
|         assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}"
 | |
|         if not element.is_cuda:
 | |
|             element = element.to(get_current_device()).detach()
 | |
|     return element
 | |
| 
 | |
| 
 | |
def move_to_device(data):
    """Move batch data onto the device reported by ``get_current_device()``.

    Args:
        data: A tensor, a list/tuple of elements (dicts or anything accepted by
            ``_move_tensor``), or a dict.

    Returns:
        A tensor for tensor input, a list for list/tuple input, and a new dict
        for dict input. In dicts, the ``"inference_params"`` entry is handled
        specially: only its ``attention_mask`` is moved, via ``_replace``.

    Raises:
        TypeError: If ``data`` is not a tensor, list, tuple, or dict.
    """

    def _relocate_mapping(mapping):
        # Shared by the top-level dict case and dicts nested in a list/tuple.
        moved = {}
        for name, value in mapping.items():
            if name == "inference_params":
                moved[name] = value._replace(attention_mask=_move_tensor(value.attention_mask))
            else:
                moved[name] = _move_tensor(value)
        return moved

    if isinstance(data, torch.Tensor):
        return data.to(get_current_device())

    if isinstance(data, (list, tuple)):
        return [
            _relocate_mapping(entry) if isinstance(entry, dict) else _move_tensor(entry)
            for entry in data
        ]

    if isinstance(data, dict):
        return _relocate_mapping(data)

    raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
 | |
| 
 | |
| 
 | |
def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    """Return ``norm`` as a tensor, optionally placed on the current CUDA device.

    Args:
        norm: A float (wrapped into a one-element ``torch.Tensor``) or a value
            that is passed through as-is.
        move_to_cuda: When truthy, the result is moved to
            ``torch.cuda.current_device()``.

    Returns:
        The (possibly wrapped and moved) norm.
    """
    as_tensor = torch.Tensor([norm]) if isinstance(norm, float) else norm
    if not move_to_cuda:
        return as_tensor
    return as_tensor.to(torch.cuda.current_device())
 | |
| 
 | |
| 
 | |
def get_current_device() -> torch.device:
    """
    Return the device currently in use.

    When CUDA is available this is the current CUDA device
    (``cuda:<index>``); otherwise it is the CPU device.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
 | |
| 
 | |
| 
 | |
def get_batch_size(data):
    """Return the size of dimension 0 of the first tensor found in ``data``.

    Args:
        data: A tensor, a list/tuple whose first element is a tensor or a dict
            of tensors, or a dict of tensors.

    Returns:
        ``size(0)`` of the tensor itself, of the first list/tuple element (or
        that element's first dict value), or of the first dict value. ``None``
        for any other input type.
    """
    if isinstance(data, torch.Tensor):
        return data.size(0)

    if isinstance(data, (list, tuple)):
        head = data[0]
        if isinstance(head, dict):
            # Use the first value of the leading dict as the sample.
            head = list(head.values())[0]
        return head.size(0)

    if isinstance(data, dict):
        return list(data.values())[0].size(0)

    return None
 | |
| 
 | |
| 
 | |
def filter_kwargs(func, kwargs):
    """Keep only the entries of ``kwargs`` whose keys are parameters of ``func``.

    Args:
        func: The callable whose signature is inspected.
        kwargs: Mapping of candidate keyword arguments.

    Returns:
        A new dict with the entries whose key appears in
        ``inspect.signature(func).parameters``.
    """
    accepted = inspect.signature(func).parameters
    kept = {}
    for name, value in kwargs.items():
        if name in accepted:
            kept[name] = value
    return kept
 | |
| 
 | |
| 
 | |
def launch_time():
    """Return the launch timestamp, computing it on first use.

    The timestamp is formatted as ``"%b%d_%H-%M-%S"`` and stored in the
    module-level ``CURRENT_TIME``, so later calls return the same string.
    """
    global CURRENT_TIME
    # Only stamp the time while the cached value is still falsy.
    CURRENT_TIME = CURRENT_TIME or datetime.now().strftime("%b%d_%H-%M-%S")
    return CURRENT_TIME
 | |
| 
 | |
| 
 | |
def set_random_seed(seed):
    """Seed the Python, NumPy, and PyTorch (CPU and CUDA) random generators.

    It is recommended to use this only for inference.

    Args:
        seed: A positive seed, or ``None`` to leave every generator untouched.

    Raises:
        AssertionError: If ``seed`` is not ``None`` and not greater than zero.
    """
    if seed is None:
        return

    assert seed > 0
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        # Covers every GPU when running multi-GPU.
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
 | |
| 
 | |
| 
 | |
@contextmanager
def conditional_context(context_manager, enable=True):
    """Enter ``context_manager`` around the managed block only when ``enable`` is truthy.

    Args:
        context_manager: The context manager to apply conditionally.
        enable: When falsy, the block runs without entering ``context_manager``.
    """
    if not enable:
        yield
        return

    with context_manager:
        yield
 | |
| 
 | |
| 
 | |
class BatchSkipper:
    """
    Decides whether a given batch index should be skipped.

    ``skip_batches`` is a comma-separated string whose entries are either a
    single index (``"7"``) or an inclusive range (``"3-5"``). An empty string
    skips nothing.
    """

    def __init__(self, skip_batches):
        # Flat list of boundaries: [start0, end0 + 1, start1, end1 + 1, ...].
        boundaries = []
        entries = skip_batches.split(",")
        if skip_batches != "":
            for entry in entries:
                if "-" in entry:
                    first, last = (int(part) for part in entry.split("-"))
                else:
                    first = last = int(entry)
                if boundaries:
                    # Each entry must start at or after the previous exclusive end.
                    assert boundaries[-1] <= first
                boundaries += [first, last + 1]
        self.spans = boundaries

    def __call__(self, batch_count):
        # An odd number of boundaries at or below batch_count means the index
        # lies inside one of the [start, end + 1) spans.
        crossed = bisect.bisect_right(self.spans, batch_count)
        return crossed % 2 != 0
 | |
| 
 | |
| 
 | |
class SingletonMeta(type):
    """
    Metaclass that keeps one instance per class.

    The first call constructs and stores the instance. Later calls return the
    stored instance and must not pass any arguments.
    """

    # Maps each class using this metaclass to its stored instance.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        registry = cls._instances
        if cls in registry:
            # Arguments on a repeat call would be silently ignored, so reject them.
            assert (
                not args and not kwargs
            ), f"{cls.__name__} is a singleton class and a instance has been created."
            return registry[cls]

        registry[cls] = super().__call__(*args, **kwargs)
        return registry[cls]
 | |
| 
 | |
| 
 | |
def get_megatron_flops(
    elapsed_time_per_iter,
    checkpoint=False,
    seq_len=2048,
    hidden_size=12,
    num_layers=32,
    vocab_size=12,
    global_batch_size=4,
    global_world_size=1,
    mlp_ratio=4,
    use_swiglu=True,
):
    """
    Estimate TFLOPS following the Megatron paper:
    https://deepakn94.github.io/assets/papers/megatron-sc21.pdf

    The per-iteration FLOP count is divided by
    ``elapsed_time_per_iter * global_world_size * 10**12``.

    Args:
        elapsed_time_per_iter: Time taken by one iteration (presumably seconds -- TODO confirm).
        checkpoint: Uses a multiplier of 4 instead of 3 on the per-layer term.
        seq_len, hidden_size, num_layers, vocab_size: Model dimensions.
        global_batch_size: Batch size used in the FLOP count.
        global_world_size: Divisor applied alongside the elapsed time.
        mlp_ratio: MLP expansion ratio; scaled by 3/2 when ``use_swiglu`` is set.
        use_swiglu: Whether to apply the 3/2 scaling to ``mlp_ratio``.

    Returns:
        The estimated TFLOPS value.
    """
    layer_multiplier = 4 if checkpoint else 3
    effective_mlp_ratio = mlp_ratio * 3 / 2 if use_swiglu else mlp_ratio

    # Term quadratic in hidden_size.
    hidden_sq_term = (8 + effective_mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2
    # Term quadratic in seq_len.
    seq_sq_term = 4 * global_batch_size * seq_len**2 * hidden_size
    # Term that depends on vocab_size and is not multiplied by num_layers.
    vocab_term = 6 * global_batch_size * seq_len * hidden_size * vocab_size

    per_layer_flops = layer_multiplier * (hidden_sq_term + seq_sq_term)
    flops_per_iteration = per_layer_flops * num_layers + vocab_term

    denominator = elapsed_time_per_iter * global_world_size * (10**12)
    return flops_per_iteration / denominator
 |