mirror of https://github.com/InternLM/InternLM
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import bisect
import inspect
import os
import random
from contextlib import contextmanager
from datetime import datetime
from typing import Union

import numpy as np
import torch

import internlm

CURRENT_TIME = None


def parse_args():
    parser = internlm.get_default_parser()
    args = parser.parse_args()

    return args


def get_master_node():
    import subprocess

    if os.getenv("SLURM_JOB_ID") is None:
        raise RuntimeError("get_master_node can only be used in a Slurm launch!")
    result = subprocess.check_output('scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True)
    result = result.decode("utf8").strip()
    return result


def get_process_rank():
    proc_rank = -1
    if os.getenv("SLURM_PROCID") is not None:
        proc_rank = int(os.getenv("SLURM_PROCID"))
    elif os.getenv("RANK") is not None:
        # In a k8s env, we use $RANK.
        proc_rank = int(os.getenv("RANK"))

    # assert proc_rank != -1, "get_process_rank can't get the right process rank!"
    return proc_rank
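

# Illustrative sketch, not part of upstream InternLM: one way get_master_node()
# and get_process_rank() could be combined to bootstrap torch.distributed under
# Slurm. The port, backend and world-size source are assumptions of this example.
def _example_init_distributed(port=29500):
    import torch.distributed as dist

    rank = get_process_rank()  # read from $SLURM_PROCID or $RANK
    world_size = int(os.getenv("SLURM_NTASKS", "1"))
    master_addr = get_master_node()  # first host in $SLURM_JOB_NODELIST

    dist.init_process_group(
        backend="nccl",
        init_method=f"tcp://{master_addr}:{port}",
        rank=rank,
        world_size=world_size,
    )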


def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    if torch.is_tensor(norm) and norm.device.type != "cuda":
        norm = norm.to(torch.cuda.current_device())
    return norm


def _move_tensor(element):
    if not torch.is_tensor(element):
        # we expect the data to be a list of dictionaries (or of lists/tensors)
        for item in element:
            if isinstance(item, dict):
                for key, value in item.items():
                    assert not value.is_cuda, "elements are already on devices."
                    item[key] = value.to(get_current_device()).detach()
            elif isinstance(item, list):
                for index, value in enumerate(item):
                    assert not value.is_cuda, "elements are already on devices."
                    item[index] = value.to(get_current_device()).detach()
            elif torch.is_tensor(item):
                if not item.is_cuda:
                    item = item.to(get_current_device()).detach()
    else:
        assert torch.is_tensor(element), f"element should be of type tensor, but got {type(element)}"
        if not element.is_cuda:
            element = element.to(get_current_device()).detach()
    return element


def move_to_device(data):
    if isinstance(data, torch.Tensor):
        data = data.to(get_current_device())
    elif isinstance(data, (list, tuple)):
        data_to_return = []
        for element in data:
            if isinstance(element, dict):
                data_to_return.append(
                    {
                        k: (
                            _move_tensor(v)
                            if k != "inference_params"
                            else v._replace(attention_mask=_move_tensor(v.attention_mask))
                        )
                        for k, v in element.items()
                    }
                )
            else:
                data_to_return.append(_move_tensor(element))
        data = data_to_return
    elif isinstance(data, dict):
        data = {
            k: (
                _move_tensor(v)
                if k != "inference_params"
                else v._replace(attention_mask=_move_tensor(v.attention_mask))
            )
            for k, v in data.items()
        }
    else:
        raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
    return data
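

# Illustrative sketch, not part of upstream: move_to_device() walks tensors,
# dicts, lists and tuples and moves every tensor it finds to the current device.
# The batch layout below is a made-up example.
def _example_move_batch():
    batch = (
        {"input_ids": torch.zeros(4, 16, dtype=torch.long)},  # dict of tensors
        torch.ones(4, 16, dtype=torch.bool),  # bare tensor
    )
    return move_to_device(batch)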


def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    if isinstance(norm, float):
        norm = torch.Tensor([norm])
    if move_to_cuda:
        norm = norm.to(torch.cuda.current_device())
    return norm


def get_current_device() -> torch.device:
    """
    Return the currently selected device: the current GPU if CUDA is available,
    otherwise the CPU.
    """
    if torch.cuda.is_available():
        return torch.device(f"cuda:{torch.cuda.current_device()}")
    else:
        return torch.device("cpu")


def get_batch_size(data):
    if isinstance(data, torch.Tensor):
        return data.size(0)
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], dict):
            return data[0][list(data[0].keys())[0]].size(0)
        return data[0].size(0)
    elif isinstance(data, dict):
        return data[list(data.keys())[0]].size(0)
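

# Illustrative sketch, not part of upstream: get_batch_size() reads dimension 0
# of the first tensor it can reach, whatever the container layout.
def _example_batch_size():
    assert get_batch_size(torch.zeros(8, 16)) == 8
    assert get_batch_size({"input_ids": torch.zeros(8, 16)}) == 8
    assert get_batch_size([{"input_ids": torch.zeros(8, 16)}]) == 8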


def filter_kwargs(func, kwargs):
    sig = inspect.signature(func)
    return {k: v for k, v in kwargs.items() if k in sig.parameters}
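

# Illustrative sketch, not part of upstream: filter_kwargs() drops keyword
# arguments the callee does not accept, so one oversized config dict can be
# reused across several constructors. The callee below is hypothetical.
def _example_filter_kwargs():
    def build_optimizer(lr, weight_decay=0.0):
        return {"lr": lr, "weight_decay": weight_decay}

    config = {"lr": 1e-4, "weight_decay": 0.1, "warmup_steps": 100}
    return build_optimizer(**filter_kwargs(build_optimizer, config))  # warmup_steps is dropped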


def launch_time():
    global CURRENT_TIME
    if not CURRENT_TIME:
        CURRENT_TIME = datetime.now().strftime("%b%d_%H-%M-%S")
    return CURRENT_TIME


def set_random_seed(seed):
    """Set the random seed for reproducibility."""
    # It is recommended to use this only for inference.
    if seed is not None:
        assert seed > 0
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        torch.cuda.manual_seed(seed)
        # if you are using multi-GPU.
        torch.cuda.manual_seed_all(seed)


@contextmanager
def conditional_context(context_manager, enable=True):
    if enable:
        with context_manager:
            yield
    else:
        yield
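

# Illustrative sketch, not part of upstream: conditional_context() lets one code
# path run with or without a context manager such as torch.no_grad(), toggled by
# a flag. `model` and `inputs` are stand-ins here.
def _example_forward(model, inputs, inference=True):
    with conditional_context(torch.no_grad(), enable=inference):
        return model(inputs)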


class BatchSkipper:
    """
    BatchSkipper is used to determine whether to skip the current batch_idx.
    """

    def __init__(self, skip_batches):
        # skip_batches is a comma-separated list of indices and inclusive ranges,
        # e.g. "2,5-7". It is encoded as a flat, sorted list of half-open span
        # boundaries: "2,5-7" -> [2, 3, 5, 8].
        spans = []
        if skip_batches != "":
            intervals = skip_batches.split(",")
            for interval in intervals:
                if "-" in interval:
                    start, end = map(int, interval.split("-"))
                else:
                    start, end = int(interval), int(interval)
                if spans:
                    assert spans[-1] <= start
                spans.extend((start, end + 1))
        self.spans = spans

    def __call__(self, batch_count):
        # An odd number of boundaries at or before batch_count means it falls
        # inside one of the [start, end + 1) skip spans.
        index = bisect.bisect_right(self.spans, batch_count)
        return index % 2 == 1
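

# Illustrative sketch, not part of upstream: "2,5-7" marks batches 2, 5, 6 and 7
# to be skipped, so the skipper returns True exactly for those indices.
def _example_batch_skipper():
    skipper = BatchSkipper("2,5-7")
    return [idx for idx in range(10) if skipper(idx)]  # -> [2, 5, 6, 7]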


class SingletonMeta(type):
    """
    Singleton Meta.
    """

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        else:
            assert (
                len(args) == 0 and len(kwargs) == 0
            ), f"{cls.__name__} is a singleton class and an instance has been created."
        return cls._instances[cls]
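

# Illustrative sketch, not part of upstream: a class with SingletonMeta as its
# metaclass always yields the same instance; constructor arguments are only
# honoured on the first call, and a later call with arguments would assert.
def _example_singleton():
    class Counter(metaclass=SingletonMeta):
        def __init__(self, start=0):
            self.value = start

    first = Counter(start=5)
    second = Counter()  # returns the existing instance
    return first is second  # -> True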


def get_megatron_flops(
    elapsed_time_per_iter,
    checkpoint=False,
    seq_len=2048,
    hidden_size=12,
    num_layers=32,
    vocab_size=12,
    global_batch_size=4,
    global_world_size=1,
    mlp_ratio=4,
    use_swiglu=True,
):
    """
    Calculate FLOPS based on the Megatron paper: https://deepakn94.github.io/assets/papers/megatron-sc21.pdf
    """

    checkpoint_activations_factor = 4 if checkpoint else 3

    if use_swiglu:
        mlp_ratio = mlp_ratio * 3 / 2

    flops_per_iteration = (
        checkpoint_activations_factor
        * (
            (8 + mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2
            + 4 * global_batch_size * seq_len**2 * hidden_size
        )
    ) * num_layers + 6 * global_batch_size * seq_len * hidden_size * vocab_size

    tflops = flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))
    return tflops
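

# Illustrative sketch, not part of upstream: estimating achieved TFLOPS per GPU
# for one training step. The model shape and timing below are made-up numbers,
# not a reference configuration.
def _example_flops():
    return get_megatron_flops(
        elapsed_time_per_iter=1.5,  # seconds per iteration
        checkpoint=False,
        seq_len=2048,
        hidden_size=4096,
        num_layers=32,
        vocab_size=100000,
        global_batch_size=512,
        global_world_size=64,
        mlp_ratio=8 / 3,
        use_swiglu=True,
    )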