ColossalAI/colossalai/inference/utils.py

105 lines
3.8 KiB
Python

"""
Utils for model inference
"""
import os
import re
from pathlib import Path
from typing import Optional, Tuple
import torch
from torch import nn
def init_to_get_rotary(self, base=10000, use_elem=False):
"""
This function initializes the rotary positional embedding, it is compatible for all models and is called in ShardFormer
Args:
self : Model that holds the rotary positional embedding
base : calculation arg
use_elem : activated when using chatglm-based models
"""
self.config.head_dim_ = self.config.hidden_size // self.config.num_attention_heads
if not hasattr(self.config, "rope_scaling"):
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.rope_scaling.factor if self.config.rope_scaling is not None else 1.0
if hasattr(self.config, "max_sequence_length"):
max_seq_len = self.config.max_sequence_length
elif hasattr(self.config, "max_position_embeddings"):
max_seq_len = self.config.max_position_embeddings * rope_scaling_factor
else:
max_seq_len = 2048 * rope_scaling_factor
base = float(base)
# NTK ref: https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
ntk_alpha = os.environ.get("INFER_NTK_ALPHA", None)
if ntk_alpha is not None:
ntk_alpha = float(ntk_alpha)
assert ntk_alpha >= 1, "NTK alpha must be greater than or equal to 1"
if ntk_alpha > 1:
print(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (self.head_dim_ / (self.head_dim_ - 2))) # Base change formula
n_elem = self.config.head_dim_
if use_elem:
n_elem //= 2
inv_freq = 1.0 / (base ** (torch.arange(0, n_elem, 2, device="cpu", dtype=torch.float32) / n_elem))
t = torch.arange(max_seq_len + 1024 * 64, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
"""
Check whether the checkpoint has an index file.
Args:
checkpoint_path (str): path to the checkpoint.
Returns:
Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
"""
checkpoint_path = Path(checkpoint_path)
if checkpoint_path.is_file():
# check if it is .index.json
reg = re.compile("(.*?).index((\..*)?).json")
if reg.fullmatch(checkpoint_path.name) is not None:
return True, checkpoint_path
else:
return False, None
elif checkpoint_path.is_dir():
index_files = list(checkpoint_path.glob("*.index.*json"))
for index_file in index_files:
if "safetensors" in index_file.__str__():
return True, index_file.__str__() # return the safetensors file first
if len(index_files) == 1:
return True, index_files[0]
else:
assert (
len(index_files) == 1
), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
return False, None
else:
raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
def get_model_size(model: nn.Module):
"""Calculates the total size of the model weights (including biases) in bytes.
Args:
model: The PyTorch model to analyze.
Returns:
The total size of the model weights in bytes.
"""
total_size = 0
for key, param in model.named_parameters():
total_size += param.element_size() * param.numel()
return total_size / (1024**3)