# mirror of https://github.com/InternLM/InternLM
import inspect
import logging
import os
import re
from typing import Callable, Dict, List, Optional, Union

import torch

from internlm.apis.inference import SequenceGenerator
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.initialize.launch import launch_from_torch
from internlm.train import initialize_model
from internlm.utils.registry import MODEL_INITIALIZER
from internlm.utils.storage_manager import get_fns, init_storage_manager, llm_load
from tools.transformers.interface import GenerationConfig

logger = logging.getLogger(__file__)
logging.basicConfig(level=logging.INFO)


def merge_pp_within_tp(folder, del_model_prefix: bool = False) -> Dict[str, torch.Tensor]:
    """Merge checkpoints from the different pipeline stages of the model.

    Args:
        folder (str): checkpoint directory.
        del_model_prefix (bool, optional): Whether to remove the "model." prefix from the keys in the state_dict.
            Defaults to False.

    Returns:
        Dict[str, torch.Tensor]: A state_dict that belongs to the model of the current tensor parallel rank.
    """
    assert folder is not None, "Please specify the folder of the pretrained model"
    fns = get_fns(folder)

    model_fns = []
    for fn in fns:
        # Filtering on the "model_t" prefix avoids a conflict with model_config.py.
        if fn.startswith("model_t") and not fn.endswith("md5"):
            model_fns.append(fn)

    max_tp, max_pp = -1, -1
    for fn in model_fns:
        # Checkpoint file names follow the pattern "model_tp{tp}_pp{pp}.pt".
        _, tp, pp = os.path.splitext(fn)[0].split("_")
        max_pp = max(max_pp, int(pp[2:]) + 1)
        max_tp = max(max_tp, int(tp[2:]) + 1)

    if max_tp < 0 or max_pp < 0:
        raise RuntimeError(f"Could not find any model checkpoints starting with 'model_t' in the folder: {folder}")

    assert max_tp == gpc.get_world_size(ParallelMode.TENSOR), (
        f"Expected the TP size of the loaded checkpoint to equal the TP size of the current config, but got {max_tp} "
        f"in the checkpoint and {gpc.get_world_size(ParallelMode.TENSOR)} in the current config"
    )
    tp = gpc.get_local_rank(ParallelMode.TENSOR)

    layer_shift = 0

    tp_states = {}
    for pp in range(max_pp):
        _layer_shift = 0
        model_name = f"model_tp{tp}_pp{pp}.pt"
        states = llm_load(os.path.join(folder, model_name), map_location="cpu")

        keys = list(states.keys())
        for key in keys:
            match = re.search(r"\.\d+\.", key)
            if match is not None:  # A layer-related key: the layer index must be shifted.
                s, e = match.span()
                layer_idx = int(key[s + 1 : e - 1]) + layer_shift
                _layer_shift = max(_layer_shift, int(key[s + 1 : e - 1]))
                name = key[:s] + f".{layer_idx}." + key[e:]
                if del_model_prefix:
                    name = name[6:] if name.startswith("model.") else name
                tp_states[name] = states[key]
            else:
                name = key
                if del_model_prefix:
                    name = name[6:] if name.startswith("model.") else name
                tp_states[name] = states[key]
        layer_shift += _layer_shift + 1

    return tp_states
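
# Example usage (a minimal sketch; the checkpoint path and the key names shown are
# hypothetical and only illustrate the merged layout):
#
#     state_dict = merge_pp_within_tp("local:/path/to/ckpt", del_model_prefix=True)
#     # Layer indices from all pipeline stages are renumbered into one flat sequence
#     # for the current tensor parallel rank, e.g. "blocks.0.*", "blocks.1.*", ...
#     load_info = model.load_state_dict(state_dict, strict=False)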


def match_fn_signature(func: Callable, args_dict: Dict) -> None:
    """Match an argument dictionary against a function's signature.

    Given a function and an argument dictionary, remove the key-value pairs in the dictionary
    that do not match any parameter of the function. The dictionary is modified in place.

    Args:
        func (Callable): the target function.
        args_dict (Dict): the argument dictionary.
    """
    params = set(inspect.signature(func).parameters)
    args_set = set(args_dict).difference(params)
    for _name in args_set:
        args_dict.pop(_name)
    if len(args_set) and gpc.is_rank_for_log():
        logger.warning(f"These args: {args_set} are popped for func: {func.__name__}.")


def get_tp_rank() -> int:
    """Get the tensor parallel rank.

    This script uses torchrun to initialize the environment, so the RANK environment variable is the
    tensor parallel rank.

    Returns:
        int: The tensor parallel rank to which the current process belongs.
    """
    return int(os.environ.get("RANK", 0))


def get_tp_world_size() -> int:
    """Get the tensor parallel world size.

    Returns:
        int: The tensor parallel world size to which the current process belongs.
    """
    return int(os.environ.get("WORLD_SIZE", 0))
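
# Note: both helpers assume a torchrun launch where every process corresponds to one
# tensor parallel rank. For example, `torchrun --nproc_per_node=2 ...` on a single node
# sets RANK to 0 or 1 and WORLD_SIZE to 2, which these helpers report as the TP rank
# and the TP world size.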


def initialize_internlm_model(
    model_type: str,
    ckpt_dir: str,
    model_config: Dict,
    del_model_prefix: bool = False,
    param_dtype: torch.dtype = torch.bfloat16,
    training: bool = False,
    seed: int = 1024,
) -> torch.nn.Module:
    """Initialize an InternLM model.

    Args:
        model_type (str): A model type supported by the internlm framework, such as "INTERNLM".
        ckpt_dir (str): Directory where the model checkpoints are stored. Its format needs to be one of:
            (a) local path, such as: "local:{your local path}";
            (b) boto3 path, such as: "boto3:s3://{bucket name}.{ip}/{your ceph path}".
        model_config (Dict): Configuration of the model.
        del_model_prefix (bool, optional): Whether to remove the "model." prefix from the keys in the state_dict.
            Defaults to False.
        param_dtype (torch.dtype, optional): The dtype of the model at inference time. This value can also be a
            string; use "torch.tf32" when you want to run inference with TF32. Defaults to torch.bfloat16.
        training (bool, optional): Whether to call model.train() instead of model.eval(). Defaults to False.
        seed (int, optional): Random seed. Defaults to 1024.
    """

    if gpc.is_rank_for_log():
        logger.info(f"tp world size: {get_tp_world_size()}.")

    if isinstance(param_dtype, str):
        if param_dtype == "torch.tf32":
            # TF32 is not a torch.dtype: compute in float32 and enable the TF32 backends instead.
            param_dtype = torch.float32
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cuda.matmul.allow_tf32 = True
        else:
            param_dtype = eval(param_dtype)  # pylint: disable=W0123

    if not isinstance(param_dtype, torch.dtype):
        raise ValueError("Parameter ``param_dtype`` is not right.")

    try:
        init_storage_manager(False, None, None)
    except AssertionError:
        # The storage manager has already been initialized; nothing to do.
        pass

    model_config["dtype"] = param_dtype
    model_config["parallel_output"] = False
    # Drop config entries that the model initializer does not accept.
    match_fn_signature(MODEL_INITIALIZER.get_module(model_type), model_config)
    if gpc.is_rank_for_log():
        logger.info(f"model_config: {model_config}.")
    launch_from_torch(
        config=dict(
            model_type=model_type,
            model=model_config,
            parallel=dict(
                zero1=dict(size=1, fsdp=False),
                pipeline=dict(size=1, interleaved_overlap=True),
                tensor=get_tp_world_size(),
                sequence_parallel=0,
            ),
        ),
        seed=seed,
    )
    model = initialize_model()
    # Directly get the origin model without the NativeAMP wrapper.
    model = model.model

    state_dict = merge_pp_within_tp(ckpt_dir, del_model_prefix=del_model_prefix)
    load_info = model.load_state_dict(state_dict, strict=False)
    logger.info(f"Rank:{gpc.get_local_rank(ParallelMode.TENSOR)}. Load info: {load_info}.")

    model.to(param_dtype)
    if training:
        model.train()
    else:
        model.eval()
    model.cuda()
    torch.distributed.barrier()
    return model
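
# Note: initialize_internlm_model() must run inside a torchrun-launched process group
# with CUDA available, since it calls launch_from_torch(), model.cuda() and
# torch.distributed.barrier(); see the __main__ example at the bottom of this file.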


def get_model_device(model):
    """Return the device of the model's first parameter."""
    return next(model.parameters()).device


def internlm_interactive_generation(
    model,
    tokenizer,
    prompt,
    generation_config: Optional[GenerationConfig] = None,
    additional_eos_token_list: Optional[Union[int, List[int]]] = None,
):
    """Stream generation for a single prompt, yielding the decoded output accumulated so far."""
    sequence_generator = SequenceGenerator(
        decoder=model,
        eos_token_id=tokenizer.eos_id(),
        pad_token_id=tokenizer.eos_id(),
        bos_token_id=tokenizer.bos_id(),
        additional_eos_token_list=additional_eos_token_list,
    )
    additional_eos_token_list = torch.LongTensor(additional_eos_token_list)
    input_ids = [tokenizer.bos_id()] + tokenizer.encode(prompt)
    input_ids = torch.LongTensor([input_ids]).to(get_model_device(model))
    output_generator = sequence_generator.streaming_generate(
        tokens=input_ids,
        max_length=generation_config.max_length,
        do_sample=generation_config.do_sample,
        temperature=generation_config.temperature,
        top_k=40,
        top_p=generation_config.top_p,
        repetition_penalty=generation_config.repetition_penalty,
        length_penalty=1.0,
    )
    for token_ids in output_generator:
        token_ids = token_ids.cpu().tolist()[0]
        # Strip the trailing EOS-like token before decoding.
        if torch.any(additional_eos_token_list == token_ids[-1]):
            token_ids.pop()
        cur_output = tokenizer.decode(token_ids)
        cur_output = cur_output[len(prompt) :]
        yield cur_output


if __name__ == "__main__":
    """
    Here is a simple example of generation with the original InternLM model architecture.
    Use the following command to run:
    >>> torchrun --master_port 12331 --nnodes=1 --node_rank=0 --nproc_per_node=1 tools/load_internlm_model.py
    """
    model = initialize_internlm_model(
        model_type="INTERNLM",
        ckpt_dir="[Please replace this with the directory where the internlm model weights are stored]",
        model_config=dict(
            checkpoint=False,
            num_attention_heads=32,
            embed_split_hidden=True,
            vocab_size=103168,
            embed_grad_scale=1,
            parallel_output=False,
            hidden_size=4096,
            num_layers=32,
            mlp_ratio=8 / 3,
            apply_post_layer_norm=False,
            dtype="torch.bfloat16",
            norm_type="rmsnorm",
            layer_norm_epsilon=1e-5,
            use_flash_attn=True,
            num_chunks=1,
            use_dynamic_ntk_rope=True,
        ),
        del_model_prefix=True,
    )

    from sentencepiece import SentencePieceProcessor

    prompt = """<|User|>:{query}<eoh>\n<|Bot|>:"""
    prompt = prompt.replace("{query}", "hello")
    tokenizer = SentencePieceProcessor("tools/V7_sft.model")  # pylint: disable=E1121

    generation_config = GenerationConfig()
    output_generator = internlm_interactive_generation(
        model=model,
        tokenizer=tokenizer,
        prompt=prompt,
        generation_config=generation_config,
        additional_eos_token_list=[103028],
    )

    for text in output_generator:
        print(text)