mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
161 lines
5.7 KiB
161 lines
5.7 KiB
""" |
|
Utils for model inference |
|
""" |
|
|
|
import math |
|
import os |
|
import re |
|
from pathlib import Path |
|
from typing import Optional, Tuple |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from colossalai.logging import get_dist_logger |
|
from colossalai.testing import free_port |
|
|
|
# Module-level logger obtained from colossalai.logging.get_dist_logger;
# can_use_flash_attn2 uses it to warn when flash_attn cannot be imported.
logger = get_dist_logger(__name__)
|
|
|
|
|
def init_to_get_rotary(self, base=10000, use_elem=False, device="cuda"):
    """
    Precompute the rotary positional embedding (RoPE) cos/sin tables and attach them to ``self``.

    It is compatible with all models and is called in ShardFormer.

    Side effects:
        - ``self.config.head_dim_`` is set to ``hidden_size // num_attention_heads``.
        - ``self._cos_cached`` / ``self._sin_cached`` are set to ``cos``/``sin`` of the
          position-by-frequency outer product, cast to ``self.dtype`` and moved to ``device``.
          Rows index positions (``max_seq_len + 64 * 1024`` of them); columns index frequencies.

    Args:
        self: Model that holds the rotary positional embedding. Must expose ``config``
            (with ``hidden_size`` and ``num_attention_heads``) and ``dtype``.
        base: RoPE frequency base.
        use_elem: Activated when using chatglm-based models; frequencies are then computed
            over half of the head dimension.
        device: Device the cached tables are moved to. Defaults to ``"cuda"`` (the current
            CUDA device, which is what ``Tensor.cuda()`` targets).

    Environment:
        INFER_NTK_ALPHA: optional NTK-aware scaling factor, must be >= 1. It multiplies the
            table length and rescales ``base`` with the base-change formula.
    """
    config = self.config
    config.head_dim_ = config.hidden_size // config.num_attention_heads
    head_dim = config.head_dim_

    rope_scaling = getattr(config, "rope_scaling", None)
    if rope_scaling is None:
        rope_scaling_factor = 1.0
    elif isinstance(rope_scaling, dict):
        # HuggingFace configs store rope_scaling as a plain dict, e.g. {"type": "linear", "factor": 2.0}.
        rope_scaling_factor = rope_scaling["factor"]
    else:
        rope_scaling_factor = rope_scaling.factor

    if hasattr(config, "max_sequence_length"):
        max_seq_len = config.max_sequence_length
    elif hasattr(config, "max_position_embeddings"):
        max_seq_len = config.max_position_embeddings * rope_scaling_factor
    else:
        max_seq_len = 2048 * rope_scaling_factor
    base = float(base)

    # NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
    ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
    if ntk_alpha is not None:
        ntk_alpha = float(ntk_alpha)
        assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
        if ntk_alpha > 1:
            print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
        max_seq_len *= ntk_alpha
        # Base change formula; head_dim_ is the value stored on config above
        # (the model object itself does not necessarily carry a head_dim_ attribute).
        base = base * (ntk_alpha ** (head_dim / (head_dim - 2)))

    n_elem = head_dim
    if use_elem:
        n_elem //= 2

    inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
    # The table covers 64K positions beyond max_seq_len; dividing positions by the rope
    # scaling factor applies linear position scaling.
    t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
    freqs = torch.outer(t, inv_freq)

    # Cast on CPU first, then move the finished tables to the target device.
    self._cos_cached = torch.cos(freqs).to(self.dtype).to(device)
    self._sin_cached = torch.sin(freqs).to(self.dtype).to(device)
|
|
|
|
|
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
    """
    Check whether the checkpoint has an index file.

    Args:
        checkpoint_path (str): path to the checkpoint; either an index file itself or a
            directory that may contain one.

    Returns:
        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path).
        In a directory, an index file whose name contains "safetensors" is returned first;
        ``(False, None)`` is returned when no ``*.index.*json`` file is present.

    Raises:
        AssertionError: if a directory contains several index files and none of them is a
            safetensors index.
        RuntimeError: if ``checkpoint_path`` is neither an existing file nor a directory.
    """
    checkpoint_path = Path(checkpoint_path)

    if checkpoint_path.is_file():
        # Matches "model.index.json", "pytorch_model.bin.index.json", ...; the dots are escaped
        # so that names such as "model_index.json" are not taken for index files.
        index_pattern = re.compile(r"(.*?)\.index((\..*)?)\.json")
        if index_pattern.fullmatch(checkpoint_path.name) is not None:
            return True, checkpoint_path
        return False, None

    if checkpoint_path.is_dir():
        index_files = list(checkpoint_path.glob("*.index.*json"))

        # Prefer the safetensors index. Only the file name is inspected so that a parent
        # directory whose path happens to contain "safetensors" does not influence the choice.
        for index_file in index_files:
            if "safetensors" in index_file.name:
                return True, index_file

        if not index_files:
            return False, None

        assert (
            len(index_files) == 1
        ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
        return True, index_files[0]

    raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
|
|
|
|
|
def get_model_size(model: nn.Module) -> float:
    """Calculate the total storage of the model parameters (weights and biases) in GiB.

    Args:
        model: The PyTorch model to analyze.

    Returns:
        float: The total parameter size in GiB, i.e. the byte count divided by ``1024**3``.
    """
    # parameters() yields each shared parameter once, exactly like named_parameters().
    total_bytes = sum(param.element_size() * param.numel() for param in model.parameters())
    return total_bytes / (1024**3)
|
|
|
|
|
def find_available_ports(num: int):
    """
    Collect ``num`` port numbers by calling ``free_port()`` once per requested port.

    Args:
        num (int): how many ports to request.

    Returns:
        list: the values returned by ``free_port()``, in call order.
        NOTE(review): the ports are not checked for uniqueness here.

    Raises:
        RuntimeError: if ``free_port()`` raises ``OSError``; the original error is chained
            as the ``__cause__`` so its traceback is preserved.
    """
    try:
        free_ports = [free_port() for _ in range(num)]
    except OSError as e:
        print(f"An OS error occurred: {e}")
        raise RuntimeError("Error finding available ports") from e
    return free_ports
|
|
|
|
|
def get_alibi_slopes(num_heads: int, device: torch.device) -> torch.Tensor:
    """
    Compute the per-head ALiBi slopes.

    Alibi slopes calculation adapted from https://github.com/huggingface/transformers/blob/v4.36.0/src/transformers/models/bloom/modeling_bloom.py#L57

    The first ``2**floor(log2(num_heads))`` heads get a geometric sequence with ratio
    ``2**(-8 / p)``. Any remaining heads take the odd powers of the ratio computed for ``2 * p``.

    Args:
        num_heads (int): The number of attention heads.
        device (torch.device): The device to use.

    Returns:
        torch.Tensor: float32 tensor with one slope per head.
    """
    pow2_heads = 2 ** math.floor(math.log2(num_heads))

    def _ratio(n: int) -> torch.Tensor:
        # 2 ** -(log2(n) - 3) == 8 / n, so this is 2 ** (-8 / n).
        return torch.tensor(2 ** (-(2 ** -(math.log2(n) - 3))), dtype=torch.float32, device=device)

    exponents = torch.arange(1, pow2_heads + 1, dtype=torch.int32, device=device)
    result = _ratio(pow2_heads).pow(exponents)
    if pow2_heads == num_heads:
        return result

    leftover = min(pow2_heads, num_heads - pow2_heads)
    odd_exponents = torch.arange(1, 2 * leftover + 1, 2, dtype=torch.int32, device=device)
    return torch.cat([result, _ratio(2 * pow2_heads).pow(odd_exponents)], dim=0)
|
|
|
|
|
def can_use_flash_attn2(dtype: torch.dtype) -> bool:
    """
    Report whether FlashAttention-2 can be used for ``dtype``.

    Returns True only when ``dtype`` is float16 or bfloat16 and
    ``flash_attn.flash_attn_varlen_func`` can be imported. If that import fails,
    a warning is logged and False is returned.
    """
    half_precision = (torch.float16, torch.bfloat16)
    if dtype in half_precision:
        try:
            from flash_attn import flash_attn_varlen_func  # noqa
        except ImportError:
            logger.warning("flash_attn2 has not been installed yet, we will use triton flash attn instead.")
        else:
            return True
    return False
|
|
|