# InternLM/internlm/utils/common.py
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import bisect
import inspect
import os
import random
from contextlib import contextmanager
from datetime import datetime
from typing import Union
import numpy as np
import torch
import internlm
CURRENT_TIME = None
def parse_args():
    """Parse command-line arguments with InternLM's default argument parser.

    Returns:
        argparse.Namespace: the parsed command-line arguments.
    """
    default_parser = internlm.get_default_parser()
    return default_parser.parse_args()
def get_master_node():
    """Resolve the hostname of the first node in the current Slurm allocation.

    Expands ``$SLURM_JOB_NODELIST`` via ``scontrol show hostnames`` and returns
    the first hostname, conventionally used as the master (rank-0) node.

    Returns:
        str: hostname of the first node in the Slurm node list.

    Raises:
        RuntimeError: if not running inside a Slurm job (``SLURM_JOB_ID`` unset).
    """
    import subprocess

    if os.getenv("SLURM_JOB_ID") is None:
        # Message grammar fixed (was "can only used").
        raise RuntimeError("get_master_node can only be used in Slurm launch!")
    # shell=True is required so the shell expands $SLURM_JOB_NODELIST.
    result = subprocess.check_output('scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1', shell=True)
    return result.decode("utf8").strip()
def get_process_rank():
    """Return the global process rank from the launcher environment.

    Checks ``SLURM_PROCID`` first (Slurm launch), then ``RANK`` (k8s-style
    launch). Returns -1 when neither variable is set.
    """
    slurm_rank = os.getenv("SLURM_PROCID")
    if slurm_rank is not None:
        return int(slurm_rank)
    # In k8s env, we use $RANK.
    k8s_rank = os.getenv("RANK")
    if k8s_rank is not None:
        return int(k8s_rank)
    return -1
def move_norm_to_cuda(norm: Union[float, torch.Tensor]) -> Union[float, torch.Tensor]:
    """Move a norm value onto the current CUDA device.

    Non-tensor inputs (plain floats) and tensors already on a CUDA device are
    returned unchanged.
    """
    if not torch.is_tensor(norm):
        return norm
    if norm.device.type == "cuda":
        return norm
    return norm.to(torch.cuda.current_device())
def _move_tensor(element):
    """Recursively move one batch element onto the current device.

    ``element`` is expected to be either a tensor or a list whose items are
    dicts, lists, or tensors. Nested containers are mutated in place; a bare
    tensor is moved and returned. Moved tensors are detached from autograd.

    Raises:
        AssertionError: if a tensor inside a nested dict/list is already on
            a CUDA device.
    """
    if torch.is_tensor(element):
        if not element.is_cuda:
            element = element.to(get_current_device()).detach()
        return element
    # We expect the data to be a list of dicts / lists / tensors.
    for idx, item in enumerate(element):
        if isinstance(item, dict):
            for key, value in item.items():
                assert not value.is_cuda, "elements are already on devices."
                item[key] = value.to(get_current_device()).detach()
        elif isinstance(item, list):
            for index, value in enumerate(item):
                assert not value.is_cuda, "elements are already on devices."
                item[index] = value.to(get_current_device()).detach()
        elif torch.is_tensor(item):
            if not item.is_cuda:
                # Bug fix: the original rebound the loop variable
                # (item = item.to(...)), which never updated the container.
                # Write back through the index so the moved tensor is stored.
                element[idx] = item.to(get_current_device()).detach()
    return element
def move_to_device(data):
    """Move a batch (tensor / list / tuple / dict) onto the current device.

    Dict entries keyed ``"inference_params"`` are handled specially: only the
    ``attention_mask`` field of the stored object is moved (via ``_replace``),
    since the wrapper object itself is not a tensor.

    Raises:
        TypeError: if ``data`` is not a tensor, list, tuple, or dict.
    """

    def _move_entry(key, value):
        if key == "inference_params":
            return value._replace(attention_mask=_move_tensor(value.attention_mask))
        return _move_tensor(value)

    if isinstance(data, torch.Tensor):
        return data.to(get_current_device())
    if isinstance(data, (list, tuple)):
        moved = []
        for element in data:
            if isinstance(element, dict):
                moved.append({k: _move_entry(k, v) for k, v in element.items()})
            else:
                moved.append(_move_tensor(element))
        return moved
    if isinstance(data, dict):
        return {k: _move_entry(k, v) for k, v in data.items()}
    raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def get_tensor_norm(norm: Union[float, torch.Tensor], move_to_cuda) -> torch.Tensor:
    """Wrap a norm value as a tensor, optionally placing it on the GPU.

    Args:
        norm: a float or tensor norm value; a float is wrapped as a
            one-element tensor.
        move_to_cuda: when True, move the result onto the current CUDA device.
    """
    result = torch.Tensor([norm]) if isinstance(norm, float) else norm
    if move_to_cuda:
        result = result.to(torch.cuda.current_device())
    return result
def get_current_device() -> torch.device:
    """
    Returns currently selected device (gpu/cpu).
    If cuda available, return gpu, otherwise return cpu.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    return torch.device(f"cuda:{torch.cuda.current_device()}")
def get_batch_size(data):
    """Infer the batch size (size of dimension 0) from a batch container.

    Accepts a tensor, a dict of tensors, or a list/tuple whose first element
    is a tensor or a dict of tensors.

    Raises:
        TypeError: for unsupported container types. (The original silently
            fell through and returned None here, which crashed callers later.)
    """
    if isinstance(data, torch.Tensor):
        return data.size(0)
    if isinstance(data, (list, tuple)):
        first = data[0]
        if isinstance(first, dict):
            # Use the first value in insertion order.
            return next(iter(first.values())).size(0)
        return first.size(0)
    if isinstance(data, dict):
        return next(iter(data.values())).size(0)
    raise TypeError(f"Expected batch data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def filter_kwargs(func, kwargs):
    """Return the subset of ``kwargs`` accepted by ``func``'s signature."""
    accepted = inspect.signature(func).parameters
    filtered = {}
    for name, value in kwargs.items():
        if name in accepted:
            filtered[name] = value
    return filtered
def launch_time():
    """Return the launch timestamp, formatted like "Jan01_12-00-00".

    The timestamp is computed on the first call and cached in the
    module-global CURRENT_TIME, so every later call returns the same value.
    """
    global CURRENT_TIME
    timestamp = CURRENT_TIME
    if not timestamp:
        timestamp = datetime.now().strftime("%b%d_%H-%M-%S")
        CURRENT_TIME = timestamp
    return timestamp
def set_random_seed(seed):
    """Set random seed for reproducability.

    Seeds Python's ``random``, NumPy, and all torch RNGs (CPU and every GPU).
    It is recommended to use this only when inference. A ``None`` seed is a
    no-op.

    Args:
        seed (int | None): positive seed value, or None to skip seeding.

    Raises:
        ValueError: if ``seed`` is not positive. (The original used a bare
            ``assert``, which is stripped under ``python -O``.)
    """
    if seed is None:
        return
    if seed <= 0:
        raise ValueError(f"seed must be positive, got {seed}")
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    # if you are using multi-GPU.
    torch.cuda.manual_seed_all(seed)
@contextmanager
def conditional_context(context_manager, enable=True):
    """Enter ``context_manager`` only when ``enable`` is True.

    When disabled, the wrapped body still runs — just outside the manager.
    """
    if not enable:
        yield
    else:
        with context_manager:
            yield
class BatchSkipper:
    """
    BatchSkipper is used to determine whether to skip the current batch_idx.

    ``skip_batches`` is a comma-separated list of batch indices and/or
    inclusive ranges, e.g. ``"2-4,7"``. Calling the instance with a batch
    count returns True when that batch falls inside one of the spans.
    """

    def __init__(self, skip_batches):
        # Cleaned up the original's dead `if skip_batches == "": pass` branch
        # and the unconditional split that followed it.
        #
        # Each interval [start, end] is stored as the half-open boundary pair
        # (start, end + 1); bisect then lands on an odd index exactly when the
        # queried batch lies inside a span.
        spans = []
        if skip_batches != "":
            for interval in skip_batches.split(","):
                if "-" in interval:
                    start, end = map(int, interval.split("-"))
                else:
                    start = end = int(interval)
                if spans:
                    # Intervals must be given in non-decreasing order.
                    assert spans[-1] <= start
                spans.extend((start, end + 1))
        self.spans = spans

    def __call__(self, batch_count):
        # Odd insertion index => batch_count lies inside a [start, end] span.
        index = bisect.bisect_right(self.spans, batch_count)
        return index % 2 == 1
class SingletonMeta(type):
    """
    Singleton Meta.

    Metaclass that caches the first instance of each class and returns it on
    every subsequent call. Re-instantiating with arguments is rejected, since
    those arguments would otherwise be silently ignored.
    """

    # Maps class -> its unique instance; shared across all classes using
    # this metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        else:
            # Grammar fixed in the message (was "a instance").
            # NOTE(review): this check is stripped under `python -O`.
            assert (
                len(args) == 0 and len(kwargs) == 0
            ), f"{cls.__name__} is a singleton class and an instance has been created."
        return cls._instances[cls]
def get_megatron_flops(
    elapsed_time_per_iter,
    checkpoint=False,
    seq_len=2048,
    hidden_size=12,
    num_layers=32,
    vocab_size=12,
    global_batch_size=4,
    global_world_size=1,
    mlp_ratio=4,
    use_swiglu=True,
):
    """
    Calc flops based on the paper of Megatron https://deepakn94.github.io/assets/papers/megatron-sc21.pdf

    Returns the achieved TFLOPS per device for one training iteration.
    """
    # Activation checkpointing replays the forward pass, so the factor is
    # 4 (fwd + bwd + recompute) instead of 3 (fwd + bwd).
    factor = 4 if checkpoint else 3
    # SwiGLU uses three projection matrices instead of two in the MLP.
    effective_mlp_ratio = mlp_ratio * 3 / 2 if use_swiglu else mlp_ratio

    per_layer_matmuls = (8 + effective_mlp_ratio * 4) * global_batch_size * seq_len * hidden_size**2
    per_layer_attention = 4 * global_batch_size * seq_len**2 * hidden_size
    logits_flops = 6 * global_batch_size * seq_len * hidden_size * vocab_size

    flops_per_iteration = factor * (per_layer_matmuls + per_layer_attention) * num_layers + logits_flops
    return flops_per_iteration / (elapsed_time_per_iter * global_world_size * (10**12))