mirror of https://github.com/hpcaitech/ColossalAI

[Feat] Tensor Model Parallel Support For Inference (#5563)

* tensor parallel support naive source
* [fix] precision, model load and refactor the framework
* add tp unit test
* docstring
* fix do_sample

Branch: pull/5611/head
Parent: be396ad6cc
Commit: e37ee2fb65
@@ -5,8 +5,17 @@ from typing import Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 import torch.nn as nn
-from transformers import GenerationConfig, PreTrainedTokenizer, PreTrainedTokenizerFast
+from torch import distributed as dist
+from transformers import (
+    AutoConfig,
+    AutoModelForCausalLM,
+    GenerationConfig,
+    PreTrainedTokenizer,
+    PreTrainedTokenizerFast,
+)
+from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
+from colossalai.accelerator import get_accelerator
 from colossalai.cluster import ProcessGroupMesh
 from colossalai.inference.batch_bucket import BatchBucket
 from colossalai.inference.config import InferenceConfig, InputMetaData
@@ -14,6 +23,8 @@ from colossalai.inference.graph_runner import CUDAGraphRunner
 from colossalai.inference.modeling.policy import model_policy_map
 from colossalai.inference.spec import Drafter, GlideInput
 from colossalai.inference.struct import Sequence
+from colossalai.inference.utils import get_model_size, has_index_file
+from colossalai.interface import ModelWrapper
 from colossalai.logging import get_dist_logger
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -25,10 +36,10 @@ __all__ = ["InferenceEngine"]
 
 PP_AXIS, TP_AXIS = 0, 1
 
-_supported_models = [
-    "LlamaForCausalLM",
-    "BaichuanForCausalLM",
-]
+_supported_models = {
+    "LlamaForCausalLM": LlamaForCausalLM,
+    "BaichuanForCausalLM": AutoModelForCausalLM,
+}
 
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]
 
@@ -39,7 +50,7 @@ class InferenceEngine:
     InferenceEngine which manages the inference process..
 
     Args:
-        model (nn.Module): Path or nn.Module of this model.
+        model_or_path (nn.Module or str): Path or nn.Module of this model.
        tokenizer Optional[(Union[PreTrainedTokenizer, PreTrainedTokenizerFast])]: Path of the tokenizer to use.
        inference_config (Optional[InferenceConfig], optional): Store the configuration information related to inference.
        verbose (bool): Determine whether or not to log the generation process.
@@ -48,53 +59,25 @@ class InferenceEngine:
     def __init__(
         self,
-        model: nn.Module,
+        model_or_path: Union[nn.Module, str],
         tokenizer: Union[PreTrainedTokenizer, PreTrainedTokenizerFast],
         inference_config: InferenceConfig,
         verbose: bool = False,
         model_policy: Policy = None,
     ) -> None:
         self.inference_config = inference_config
-        self.model_config = model.config
-        self.model = model
-        self.device = torch.device("cuda")
         self.dtype = inference_config.dtype
-        self.tokenizer = tokenizer
-        self.tokenizer.pad_token = self.tokenizer.eos_token
         self.high_precision = inference_config.high_precision
-        self._verify_args()
-
-        self.generation_config = inference_config.to_generation_config(self.model_config)
-        model.eval()
-        model = model.to(self.dtype)
-        model = model.to(self.device)
-
-        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
-        self.use_spec_dec = False
-        self.drafter_model = None
-        self.drafter = None
-        self.use_glide = False
-        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
-
-        if model_policy is None:
-            if self.inference_config.pad_input:
-                model_type = "padding_" + self.model_config.model_type
-            else:
-                model_type = "nopadding_" + self.model_config.model_type
-            model_policy = model_policy_map[model_type]()
-
-        pg_mesh = ProcessGroupMesh(inference_config.pp_size, inference_config.tp_size)
-
-        self.model = self._shardformer(
-            model,
-            model_policy,
-            None,
-            pg_mesh.get_group_along_axis(TP_AXIS) if inference_config.pp_size * inference_config.tp_size > 1 else None,
-        )
 
         self.verbose = verbose
-        if verbose:
-            self.logger = get_dist_logger(__name__)
+        self.logger = get_dist_logger(__name__)
+
+        self.init_model(model_or_path, model_policy)
+
+        self.generation_config = inference_config.to_generation_config(self.model_config)
+
+        self.tokenizer = tokenizer
+        self.tokenizer.pad_token = self.tokenizer.eos_token
 
         self.request_handler = RequestHandler(self.inference_config, self.model_config)
         self.k_cache, self.v_cache = self.request_handler.get_kvcache()
@@ -111,6 +94,91 @@ class InferenceEngine:
             self.capture_model(self.k_cache, self.v_cache)
 
+        # Model and relatable attrs of speculative decoding will be set by `enable_spec_dec`
+        self.use_spec_dec = False
+        self.drafter_model = None
+        self.drafter = None
+        self.use_glide = False
+        self.n_spec_tokens = self.inference_config.max_n_spec_tokens
+
+        self._verify_args()
+
+    def init_model(self, model_or_path: Union[nn.Module, str], model_policy: Policy = None):
+        """
+        Shard model or/and Load weight
+
+        Args:
+            model_or_path Union[nn.Module, str]: path to the checkpoint or model of transformer format.
+            model_policy (Policy): the policy to replace the model
+        """
+
+        if isinstance(model_or_path, str):
+            try:
+                hf_config = AutoConfig.from_pretrained(model_or_path, trust_remote_code=True)
+                arch = getattr(hf_config, "architectures")[0]
+                model = _supported_models[arch](hf_config)
+            except Exception as e:
+                self.logger.error(
+                    f"An exception occurred during loading model: {e}, model should be loaded by transformers\n"
+                )
+        else:
+            model = model_or_path
+
+        self.model_config = model.config
+
+        torch.cuda.empty_cache()
+        init_gpu_memory = torch.cuda.mem_get_info()[0]
+
+        self.device = get_accelerator().get_current_device()
+        if self.verbose:
+            self.logger.info(f"the device is {self.device}")
+
+        model = model.to(self.dtype).eval()
+
+        if self.verbose:
+            self.logger.info(
+                f"Before the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(model)} GB, model's device is: {model.device}"
+            )
+
+        if model_policy is None:
+            if self.inference_config.pad_input:
+                model_type = "padding_" + self.model_config.model_type
+            else:
+                model_type = "nopadding_" + self.model_config.model_type
+            model_policy = model_policy_map[model_type]()
+
+        pg_mesh = ProcessGroupMesh(self.inference_config.pp_size, self.inference_config.tp_size)
+        tp_group = pg_mesh.get_group_along_axis(TP_AXIS)
+
+        self.model = self._shardformer(
+            model,
+            model_policy,
+            None,
+            tp_group=tp_group,
+        )
+
+        self.model = ModelWrapper(model).to(self.device)
+
+        if self.verbose:
+            self.logger.info(
+                f"After the shard, Rank: [{dist.get_rank()}], model size: {get_model_size(self.model)} GB, model's device is: {model.device}"
+            )
+
+        if isinstance(model_or_path, str):
+            from colossalai.inference.core.plugin import InferCheckpoint_io
+
+            cpt_io = InferCheckpoint_io()
+            if_has_index_file, model_index_file = has_index_file(model_or_path)
+            assert if_has_index_file, "the model path is invalid"
+            cpt_io.load_model(self.model, model_index_file)
+
+        free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
+        peak_memory = init_gpu_memory - free_gpu_memory
+        if self.verbose:
+            self.logger.info(
+                f"Rank [{dist.get_rank()}], Model Weight Max Occupy {peak_memory / (1024 ** 3)} GB, Model size: {get_model_size(self.model)} GB"
+            )
+
     @torch.inference_mode()
     def capture_model(self, k_cache: List[torch.Tensor], v_cache: List[torch.Tensor]):
         assert self.use_cuda_graph, "please turn on the cuda graph"
@@ -194,8 +262,11 @@ class InferenceEngine:
             raise TypeError(
                 f"the tokenizer type must be PreTrainedTokenizer or PreTrainedTokenizerFast, but got {type(self.tokenizer)}"
             )
-        if self.model.__class__.__name__ not in _supported_models:
-            raise ValueError(f"Model {self.model.__class__.__name__} is not supported.")
+        if isinstance(self.model, ModelWrapper):
+            model = self.model.module
+        assert (
+            model.__class__.__name__ in _supported_models.keys()
+        ), f"Model {self.model.__class__.__name__} is not supported."
 
     def _shardformer(
         self,
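For illustration, a minimal sketch of how the refactored engine above is meant to be driven once model_or_path can be a checkpoint path: init_model() rebuilds the architecture from AutoConfig, shards it through ShardFormer, then streams weights in through InferCheckpoint_io. The checkpoint path, prompt, launch call, and generate() keyword below are assumptions for the sketch, not part of the commit.

# Illustrative only: construct the engine from a checkpoint path with tensor parallelism.
# "/path/to/llama" is a placeholder; generate()'s keyword is assumed from the engine API.
import colossalai
from transformers import AutoTokenizer
from colossalai.inference.config import InferenceConfig
from colossalai.inference.core.engine import InferenceEngine

colossalai.launch_from_torch(config={})
tokenizer = AutoTokenizer.from_pretrained("/path/to/llama")
inference_config = InferenceConfig(max_output_len=64, dtype="fp16", tp_size=2)  # tp_size=2 assumed
engine = InferenceEngine("/path/to/llama", tokenizer, inference_config, verbose=True)
outputs = engine.generate(prompts=["Introduce Paris in one sentence."])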
@@ -0,0 +1,140 @@
+import logging
+import os
+from functools import reduce
+from pathlib import Path
+from typing import Optional
+
+import torch
+
+from colossalai.checkpoint_io.general_checkpoint_io import GeneralCheckpointIO
+from colossalai.checkpoint_io.index_file import CheckpointIndexFile
+from colossalai.checkpoint_io.utils import is_safetensors_available, load_shard_state_dict, load_state_dict_into_model
+from colossalai.cluster import DistCoordinator
+from colossalai.interface import ModelWrapper
+
+try:
+    from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
+except ImportError:
+    _EXTRA_STATE_KEY_SUFFIX = "_extra_state"
+
+
+class InferCheckpoint_io(GeneralCheckpointIO):
+    """
+    This class is for inference model loading, most codes are copied from colossalai.checkpoint_io.hybrid_parallel_checkpoint_io.HybridParallelCheckpointIO.
+    Origin HybridParallelCheckpointIO contains some codes about MixPrecision-Training, so we remove them and build a relatively clean class specifically for Inference.
+    """
+
+    def __init__(
+        self,
+        verbose: bool = True,
+    ) -> None:
+        super().__init__()
+        self.verbose = verbose
+        self.coordinator = DistCoordinator()
+
+    def load_sharded_model(self, model: ModelWrapper, checkpoint_index_file: Path, strict: bool = False):
+        """
+        Load sharded model with the given path to index file of checkpoint folder.
+
+        Args:
+            model (nn.Module): The model to be loaded.
+            checkpoint_index_file (str): Path to the index file of checkpointing folder.
+            strict (bool, optional): For name matching during loading state_dict. Defaults to False.
+                This argument should be manually set to False since params on same device might be stored in different files.
+        """
+        assert isinstance(model, ModelWrapper), "Please boost the model before loading!"
+        model = model.unwrap()
+
+        # Check whether the checkpoint uses safetensors.
+        use_safetensors = False
+        if "safetensors" in checkpoint_index_file.name:
+            use_safetensors = True
+
+        if use_safetensors and not is_safetensors_available():
+            raise ImportError("`safe_serialization` requires the `safetensors` library: `pip install safetensors`.")
+
+        # Read checkpoint index file.
+        ckpt_index_file = CheckpointIndexFile.from_file(checkpoint_index_file)
+        ckpt_root_path = ckpt_index_file.root_path
+        weight_map = ckpt_index_file.weight_map
+        strict = False
+
+        # Load params & buffers to model.
+        # Keep a record of loaded files so that file will not be repeatedly loaded.
+        loaded_file = set()
+
+        missing_keys = []
+        missing_file_keys = []
+
+        def _load(name: str):
+            if name not in weight_map:
+                missing_file_keys.append(name)
+                return
+            filename = weight_map[name]
+
+            # If this param/buffer has been loaded before, directly return.
+            if filename in loaded_file:
+                return
+
+            file_path = os.path.join(ckpt_root_path, filename)
+            state_dict = load_shard_state_dict(Path(file_path), use_safetensors)
+
+            load_state_dict_into_model(
+                model, state_dict, missing_keys=missing_keys, strict=strict, load_sub_module=True
+            )
+            loaded_file.add(filename)
+
+        # Load parameters.
+        for name, _ in model.named_parameters():
+            _load(name)
+
+        # Load buffers.
+        non_persistent_buffers = set()
+        for n, m in model.named_modules():
+            non_persistent_buffers |= set(".".join((n, b)) for b in m._non_persistent_buffers_set)
+        for name, buf in model.named_buffers():
+            if buf is not None and name not in non_persistent_buffers:
+                _load(name)
+
+        # Load extra states.
+        extra_state_key = _EXTRA_STATE_KEY_SUFFIX
+        if (
+            getattr(model.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
+            is not torch.nn.Module.get_extra_state
+        ):
+            _load(extra_state_key)
+
+        if self.verbose and self.coordinator.is_master():
+            logging.info(f"The model has been successfully loaded from sharded checkpoint: {ckpt_root_path}.")
+
+        if len(missing_keys) == 0:
+            raise RuntimeError(
+                "No weigth is loaded into the model. Please check the checkpoint files and the model structure."
+            )
+
+        remain_keys = reduce(lambda a, b: a & b, map(set, missing_keys))
+        remain_keys = remain_keys.union(set(missing_file_keys))
+        if len(remain_keys) > 0:
+            if strict:
+                error_msgs = "Missing key(s) in state_dict: {}. ".format(
+                    ", ".join('"{}"'.format(k) for k in missing_keys)
+                )
+                raise RuntimeError(
+                    "Error(s) in loading state_dict for {}:\n\t{}".format(
+                        self.__class__.__name__, "\n\t".join(error_msgs)
+                    )
+                )
+            else:
+                if self.coordinator.is_master():
+                    logging.info(f"The following keys are not loaded from checkpoint: {remain_keys}")
+
+    def save_sharded_model(
+        self,
+        model: ModelWrapper,
+        checkpoint: str,
+        gather_dtensor: bool = True,
+        prefix: Optional[str] = None,
+        size_per_shard: int = 1024,
+        use_safetensors: bool = False,
+    ) -> None:
+        return NotImplementedError
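A rough sketch of the checkpoint-loading path that init_model() takes for a string model_or_path. The path and the already-sharded module variable are placeholders; only the calls shown in the diff are used.

# Sketch of the load path used by init_model() above; "/path/to/llama" and
# sharded_model are placeholders for the checkpoint dir and the ShardFormer-processed module.
from colossalai.inference.core.plugin import InferCheckpoint_io
from colossalai.inference.utils import has_index_file
from colossalai.interface import ModelWrapper

model = ModelWrapper(sharded_model)                       # loader expects a wrapped model
cpt_io = InferCheckpoint_io()
has_index, index_file = has_index_file("/path/to/llama")  # locate *.index.json
assert has_index, "the model path is invalid"
cpt_io.load_model(model, index_file)                      # reads shards file-by-file via the index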
@@ -140,7 +140,7 @@ class RequestHandler:
 
         fd_inter_tensor.initialize(
             max_batch_size=max_n_tokens,
-            num_attn_heads=model_config.num_attention_heads,
+            num_attn_heads=model_config.num_attention_heads // inference_config.tp_size,
             kv_max_split_num=kv_max_split_num,
             head_dim=head_dim,
             dtype=self.dtype,
@@ -150,7 +150,7 @@ class RequestHandler:
         # TODO In the continuous batching scenario, the batch size may be greater than max_batch_size,
         # which may cause bugs and this issue should be fixed later.
         self.running_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
@@ -161,7 +161,7 @@ class RequestHandler:
             device=device,
         )
         self.prefill_bb = BatchBucket(
-            num_heads=model_config.num_attention_heads,
+            num_heads=model_config.num_attention_heads // inference_config.tp_size,
             head_dim=head_dim,
             max_batch_size=self.max_batch_size,
             max_length=inference_config.max_input_len + inference_config.max_output_len,
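The three edits above size the flash-decoding workspace and both batch buckets with the per-rank head count rather than the full model's. The arithmetic, with illustrative numbers only:

# Illustrative numbers only: how the per-rank head count above is derived.
num_attention_heads = 32            # full model (e.g. a Llama-7B-like config)
tp_size = 2                         # inference_config.tp_size
heads_per_rank = num_attention_heads // tp_size   # 16: what each TP rank allocates KV cache for
assert heads_per_rank * tp_size == num_attention_heads, "head count must divide evenly by tp_size"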
@@ -1,8 +1,11 @@
 # This code is adapted from huggingface transformers: https://github.com/huggingface/transformers/blob/v4.34.1/src/transformers/models/llama/modeling_llama.py
-from typing import List, Optional, Tuple
+import itertools
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch import nn
+from torch.distributed import ProcessGroup
 from transformers.models.llama.modeling_llama import (
     LlamaAttention,
     LlamaConfig,
@@ -26,6 +29,8 @@ from colossalai.kernel.triton import (
     rotary_embedding,
 )
 from colossalai.logging import get_dist_logger
+from colossalai.shardformer.layer.parallel_module import ParallelModule
+from colossalai.tensor.d_tensor import distribute_tensor, is_distributed_tensor
 
 inference_ops = InferenceOpsLoader().load()
 
@@ -68,7 +73,8 @@ def llama_causal_lm_forward(
         use_cuda_kernel=inputmetadata.use_cuda_kernel,  # Note currently the cuda kernel of layernorm, rotary_embedding_and_cache_copy couldn't pass the unitest but triton kernel could
         high_precision=inputmetadata.high_precision,
     )
-    logits = torch.mm(hidden_states, self.lm_head.weight)
+
+    logits = self.lm_head(hidden_states)
     return logits
 
 
@@ -109,6 +115,7 @@ def llama_model_forward(
             logger.warning("CUDA kernel is disabled for speculative-decoding.")
 
     hidden_states = self.embed_tokens(input_tokens_ids)
+
     cu_seqlens = None
 
     # NOTE (yuanheng-zhao): we do not use cuda kernels for speculative-decoding for now
@@ -126,7 +133,7 @@ def llama_model_forward(
         cos_sin = (self._cos_cached[rotary_indexes], self._sin_cached[rotary_indexes])
 
     elif use_cuda_kernel:
-        if inputmetadata != torch.float32 and use_flash_attn2:
+        if inputmetadata.dtype != torch.float32 and use_flash_attn2:
             cu_seqlens = F.pad(torch.cumsum(sequence_lengths, dim=0, dtype=torch.torch.int32), (1, 0))
 
         hidden_dim = self._cos_cached.size(-1)
@@ -270,7 +277,129 @@ def llama_rmsnorm_forward(
     return rms_layernorm(hidden_states, self.weight.data, self.variance_epsilon, norm_output, residual)
 
 
-class NopadLlamaAttention(LlamaAttention):
+class NopadLlamaMLP(ParallelModule, LlamaMLP):
+    def __init__(
+        self,
+        config: LlamaConfig,
+        mlp_gproj_w: torch.Tensor = None,
+        mlp_uproj_w: torch.Tensor = None,
+        mlp_dproj: ParallelModule = None,
+        process_group: ProcessGroup = None,
+    ):
+        """A Unified Layer for
+
+        Args:
+            config (LlamaConfig): Holding the Llama model config.
+            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
+            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
+            mlp_dproj (Linear1D_Row, optional): The Linear1D_Row mlp_dproj weight. Defaults to None.
+        """
+        ParallelModule.__init__(self)
+        self.config = config
+        assert is_distributed_tensor(
+            mlp_gproj_w
+        ), "mlp_gproj_w must be dtensor so we could get the layout of the weight"
+        self.helper_layout = (
+            mlp_gproj_w.dist_layout
+        )  # NOTE this is a hack for the right load/shard of gate_up_weight(used in _load_from_state_dict)
+        self.gate_up_weight = nn.Parameter(
+            torch.stack([mlp_gproj_w.transpose(0, 1), mlp_uproj_w.transpose(0, 1)], dim=0)
+        )
+        self.down_proj = mlp_dproj
+        self.process_group = process_group
+
+    @staticmethod
+    def from_native_module(
+        module: LlamaMLP, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
+        """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
+
+        Args:
+            module (LlamaMLP): The origin LlamaMLP layer.
+        """
+
+        config = module.config
+
+        mlp_gproj_w = module.gate_proj.weight
+        assert is_distributed_tensor(
+            module.gate_proj.weight
+        ), "gate_proj.weight must be dtensor so we could get the layout of the weight"
+        mlp_uproj_w = module.up_proj.weight
+        mlp_dproj = module.down_proj
+
+        mlp_layer = NopadLlamaMLP(
+            config=config,
+            mlp_gproj_w=mlp_gproj_w,
+            mlp_uproj_w=mlp_uproj_w,
+            mlp_dproj=mlp_dproj,
+            process_group=process_group,
+        )
+
+        return mlp_layer
+
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        # NOTE This is a hack to ensure we could load the right weight from LlamaMLP checkpoint due to the use of torch.stack(gate_weight, up_weight)
+
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
+
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
+
+        key = "gate_up_weight"
+        k1 = "gate_proj.weight"
+        k2 = "up_proj.weight"
+
+        gate_w = state_dict[prefix + k1]
+        up_w = state_dict[prefix + k2]
+
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+        gate_w = distribute_tensor(gate_w, device_mesh, sharding_spec)
+        up_w = distribute_tensor(up_w, device_mesh, sharding_spec)
+
+        gate_up_w = torch.stack([gate_w.T, up_w.T], dim=0)
+
+        input_param = nn.Parameter(
+            gate_up_w
+        )  # NOTE gate_up_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+        param = local_state[key]
+
+        try:
+            with torch.no_grad():
+                param.copy_(input_param)
+        except Exception as ex:
+            error_msgs.append(
+                'While copying the parameter named "{}", '
+                "whose dimensions in the model are {} and "
+                "whose dimensions in the checkpoint are {}, "
+                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+            )
+
+        strict = False  # to avoid unexpected_keys
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        """
+        Args:
+            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
+        """
+        hidden_states = hidden_states.expand(2, -1, -1)
+        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
+        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
+
+        return self.down_proj(act_out)
+
+    def extra_repr(self) -> str:
+        return f"gate_up_proj MergedLinear1D_Col: in_features={self.gate_up_weight.shape[1]}x2, out_features={self.gate_up_weight.shape[2]}, bias=False"
+
+
+class NopadLlamaAttention(ParallelModule, LlamaAttention):
     def __init__(
         self,
         config: LlamaConfig,
@@ -278,7 +407,11 @@ class NopadLlamaAttention(LlamaAttention):
         attn_qproj_w: torch.Tensor = None,
         attn_kproj_w: torch.Tensor = None,
         attn_vproj_w: torch.Tensor = None,
-        attn_oproj_w: torch.Tensor = None,
+        attn_oproj: ParallelModule = None,
+        process_group: ProcessGroup = None,
+        num_heads: int = None,
+        hidden_size: int = None,
+        num_key_value_heads: int = None,
     ):
         """This layer will replace the LlamaAttention.
 
@@ -288,36 +421,54 @@ class NopadLlamaAttention(LlamaAttention):
             attn_qproj_w (torch.Tensor, optional): The transposed q_proj weight. Defaults to None.
             attn_kproj_w (torch.Tensor, optional): The transposed k_proj weight. Defaults to None.
             attn_vproj_w (torch.Tensor, optional): The transposed v_proj weight. Defaults to None.
-            attn_oproj_w (torch.Tensor, optional): The transposed o_proj weight. Defaults to None.
+            attn_oproj (Linear1D_Row, optional): The Linear1D_Row o_proj weight. Defaults to None.
         """
-        super().__init__(config, layer_idx)
-        self.q_proj_weight = attn_qproj_w
-        self.k_proj_weight = attn_kproj_w
-        self.v_proj_weight = attn_vproj_w
-        self.o_proj_weight = attn_oproj_w
+        ParallelModule.__init__(self)
+        self.config = config
+        self.layer_idx = layer_idx
+
+        self.o_proj = attn_oproj
+        self.process_group = process_group
 
+        self.attention_dropout = config.attention_dropout
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = self.hidden_size // self.num_heads
+        self.num_key_value_heads = num_key_value_heads
+        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
+        self.max_position_embeddings = config.max_position_embeddings
+        self.rope_theta = config.rope_theta
+        self.is_causal = True
 
         if self.num_heads == self.num_key_value_heads:
-            qkv_weight_list = [self.q_proj_weight, self.k_proj_weight, self.v_proj_weight]
-            self.qkv_weight = torch.stack(qkv_weight_list, dim=0)
-
-            self.q_proj = None
-            self.k_proj = None
-            self.v_proj = None
+            qkv_weight_list = [attn_qproj_w.transpose(0, 1), attn_kproj_w.transpose(0, 1), attn_vproj_w.transpose(0, 1)]
+            self.qkv_weight = nn.Parameter(torch.stack(qkv_weight_list, dim=0))
+            self.helper_layout = (
+                attn_qproj_w.dist_layout
+            )  # NOTE this is a hack for the right load/shard of qkv_weight(used in _load_from_state_dict)
+        else:
+            self.q_proj_weight = attn_qproj_w
+            self.k_proj_weight = attn_kproj_w
+            self.v_proj_weight = attn_vproj_w
 
     @staticmethod
-    def from_native_module(module: LlamaAttention, *args, **kwargs) -> LlamaAttention:
+    def from_native_module(
+        module: LlamaAttention, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> ParallelModule:
         """Used for initialize the weight of NopadLlamaAttention by origin LlamaAttention.
 
         Args:
             module (LlamaAttention): The origin LlamaAttention layer.
         """
 
         config = module.config
         layer_idx = module.layer_idx
 
-        attn_qproj_w = module.q_proj.weight.transpose(0, 1)
-        attn_kproj_w = module.k_proj.weight.transpose(0, 1)
-        attn_vproj_w = module.v_proj.weight.transpose(0, 1)
-        attn_oproj_w = module.o_proj.weight.transpose(0, 1)
+        attn_qproj_w = module.q_proj.weight
+        attn_kproj_w = module.k_proj.weight
+        attn_vproj_w = module.v_proj.weight
+        assert is_distributed_tensor(attn_qproj_w), "attn_qproj_w must be dist tensor"
+        attn_oproj = module.o_proj
 
         attn_layer = NopadLlamaAttention(
             config=config,
@@ -325,7 +476,11 @@ class NopadLlamaAttention(LlamaAttention):
             attn_qproj_w=attn_qproj_w,
             attn_kproj_w=attn_kproj_w,
             attn_vproj_w=attn_vproj_w,
-            attn_oproj_w=attn_oproj_w,
+            attn_oproj=attn_oproj,
+            process_group=process_group,
+            num_heads=module.num_heads,
+            hidden_size=module.hidden_size,
+            num_key_value_heads=module.num_key_value_heads,
         )
 
         return attn_layer
@@ -487,63 +642,57 @@ class NopadLlamaAttention(LlamaAttention):
         )
 
         attn_output = attn_output.view(-1, self.hidden_size)
-        attn_output = torch.mm(attn_output, self.o_proj_weight)
-
+        attn_output = self.o_proj(attn_output)
         return attn_output
 
 
-# NOTE This will cause difference as out length increases.
-class NopadLlamaMLP(LlamaMLP):
-    def __init__(
-        self,
-        config: LlamaConfig,
-        mlp_gproj_w: torch.Tensor = None,
-        mlp_uproj_w: torch.Tensor = None,
-        mlp_dproj_w: torch.Tensor = None,
+    def _load_from_state_dict(
+        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
     ):
-        """This layer will replace the LlamaAttention.
+        # NOTE This is a hack to ensure we could load the right weight from LlamaAttention checkpoint due to the use of torch.stack(q_weight, k_weight, v_weight)
+        for hook in self._load_state_dict_pre_hooks.values():
+            hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
 
-        Args:
-            config (LlamaConfig): Holding the Llama model config.
-            mlp_gproj_w (torch.Tensor, optional): The transposed gate_proj weight. Defaults to None.
-            mlp_uproj_w (torch.Tensor, optional): The transposed up_proj weight. Defaults to None.
-            mlp_dproj_w (torch.Tensor, optional): The transposed down_proj weight. Defaults to None.
-        """
-        super().__init__(config)
-        self.gate_up_weight = torch.stack([mlp_gproj_w, mlp_uproj_w], dim=0)
-        self.down_proj_weight = mlp_dproj_w
-        self.gate_proj = None
-        self.up_proj = None
-        self.down_proj = None
+        persistent_buffers = {k: v for k, v in self._buffers.items() if k not in self._non_persistent_buffers_set}
+        local_name_params = itertools.chain(self._parameters.items(), persistent_buffers.items())
+        local_state = {k: v for k, v in local_name_params if v is not None}
 
-    @staticmethod
-    def from_native_module(module: LlamaMLP, *args, **kwargs) -> LlamaMLP:
-        """Used for initialize the weight of NopadLlamaMLP by origin LlamaMLP.
+        key = "qkv_weight"
+        k1 = "q_proj.weight"
+        k2 = "k_proj.weight"
+        k3 = "v_proj.weight"
+        q_w = state_dict[prefix + k1]
+        k_w = state_dict[prefix + k2]
+        v_w = state_dict[prefix + k3]
 
-        Args:
-            module (LlamaMLP): The origin LlamaMLP layer.
-        """
-        config = module.config
+        device_mesh = self.helper_layout.device_mesh
+        sharding_spec = self.helper_layout.sharding_spec
+        q_w = distribute_tensor(q_w, device_mesh, sharding_spec)
+        k_w = distribute_tensor(k_w, device_mesh, sharding_spec)
+        v_w = distribute_tensor(v_w, device_mesh, sharding_spec)
 
-        mlp_gproj_w = module.gate_proj.weight.transpose(0, 1)
-        mlp_uproj_w = module.up_proj.weight.transpose(0, 1)
-        mlp_dproj_w = module.down_proj.weight.transpose(0, 1)
+        qkv_w = torch.stack([q_w.T, k_w.T, v_w.T], dim=0)
 
-        mlp_layer = NopadLlamaMLP(
-            config=config,
-            mlp_gproj_w=mlp_gproj_w,
-            mlp_uproj_w=mlp_uproj_w,
-            mlp_dproj_w=mlp_dproj_w,
+        input_param = nn.Parameter(
+            qkv_w
+        )  # NOTE qkv_weight doesn't have to be a distensor, Like input_param = sharded_tensor_to_param(input_param)
+
+        param = local_state[key]
+
+        try:
+            with torch.no_grad():
+                param.copy_(input_param)
+        except Exception as ex:
+            error_msgs.append(
+                'While copying the parameter named "{}", '
+                "whose dimensions in the model are {} and "
+                "whose dimensions in the checkpoint are {}, "
+                "an exception occurred : {}.".format(key, param.size(), input_param.size(), ex.args)
+            )
+
+        strict = False  # to avoid unexpected_keys
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
         )
 
-        return mlp_layer
-
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-        """
-        Args:
-            hidden_states (torch.Tensor): input to the layer of shape [token_num, embed_dim].
-        """
-        hidden_states = hidden_states.expand(2, -1, -1)
-        gate_up_proj_out = torch.bmm(hidden_states, self.gate_up_weight)
-        act_out = inference_ops.silu_and_mul(gate_up_proj_out)
-        return torch.mm(act_out, self.down_proj_weight)
+    def extra_repr(self) -> str:
+        return f"qkv_weight_proj MergedLinear1D_Col: in_features={self.qkv_weight.shape[1]}x3, out_features={self.qkv_weight.shape[2]}, bias=False"
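A shape-level sketch of what the fused qkv (and gate_up) weights above buy: the per-projection weights are pre-transposed and stacked so one batched matmul replaces two or three separate projections. Dimensions below are illustrative, not taken from the diff; under TP each rank would hold a hidden // tp_size slice of the output dimension.

# Shape sketch only; dimensions are illustrative.
import torch

token_num, hidden = 7, 4096
hidden_states = torch.randn(token_num, hidden)
qkv_weight = torch.randn(3, hidden, hidden)               # stacked, pre-transposed q/k/v weights

fused_in = hidden_states.expand(3, token_num, hidden)     # broadcast view, no copy
q, k, v = torch.bmm(fused_in, qkv_weight).unbind(0)       # one bmm yields all three projections
# The MLP variant stacks two weights (gate/up); silu_and_mul fuses the activation, and the
# row-parallel down_proj / o_proj then all-reduces partial results across the TP group.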
@@ -1,4 +1,3 @@
-from torch.nn import Parameter
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM, LlamaModel, LlamaRMSNorm
 
 from colossalai.inference.modeling.models.nopadding_llama import (
@@ -10,6 +9,7 @@ from colossalai.inference.modeling.models.nopadding_llama import (
     llama_rmsnorm_forward,
 )
 from colossalai.inference.utils import init_to_get_rotary
+from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row
 from colossalai.shardformer.policies.base_policy import ModulePolicyDescription, SubModuleReplacementDescription
 from colossalai.shardformer.policies.llama import LlamaForCausalLMPolicy
 
@@ -21,26 +21,69 @@ class NoPaddingLlamaModelInferPolicy(LlamaForCausalLMPolicy):
     def module_policy(self):
         policy = super().module_policy()
 
-        decoder_attribute_replacement = {
-            "lm_head.weight": Parameter(self.model.lm_head.weight.transpose(0, 1), requires_grad=False),
-        }
-        policy[LlamaForCausalLM] = ModulePolicyDescription(
-            attribute_replacement=decoder_attribute_replacement,
-        )
+        if self.shard_config.enable_tensor_parallelism:
+            decoder_attribute_replacement = {
+                "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
+            }
+            if getattr(self.model.config, "num_key_value_heads", False):
+                decoder_attribute_replacement["self_attn.num_key_value_heads"] = (
+                    self.model.config.num_key_value_heads // self.shard_config.tensor_parallel_size
+                )
+        else:
+            decoder_attribute_replacement = None
 
         policy[LlamaDecoderLayer] = ModulePolicyDescription(
+            attribute_replacement=decoder_attribute_replacement,
             sub_module_replacement=[
                 SubModuleReplacementDescription(
+                    suffix="mlp.gate_proj",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="mlp.up_proj",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="mlp.down_proj",
+                    target_module=Linear1D_Row,
+                ),
+                SubModuleReplacementDescription(
                     suffix="mlp",
                     target_module=NopadLlamaMLP,
                 ),
                 SubModuleReplacementDescription(
+                    suffix="self_attn.q_proj",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn.k_proj",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn.v_proj",
+                    target_module=Linear1D_Col,
+                ),
+                SubModuleReplacementDescription(
+                    suffix="self_attn.o_proj",
+                    target_module=Linear1D_Row,
+                ),
+                SubModuleReplacementDescription(
                     suffix="self_attn",
                     target_module=NopadLlamaAttention,
                 ),
-            ]
+            ],
         )
 
+        policy[LlamaForCausalLM] = ModulePolicyDescription(
+            sub_module_replacement=[
+                SubModuleReplacementDescription(
+                    suffix="lm_head", target_module=Linear1D_Col, kwargs={"gather_output": True}
+                )
+            ],
+        )
+
+        # self.shard_config._infer()
         self.append_or_create_method_replacement(
             description={"forward": llama_causal_lm_forward}, policy=policy, target_key=LlamaForCausalLM
         )
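For reference, roughly what the engine's _shardformer() call does with this policy: column-parallel splits for gate/up/q/k/v and lm_head (with gathered output), row-parallel splits for down_proj and o_proj, then the Nopad modules wrap the already-split layers. The exact ShardConfig arguments the engine passes are not shown in this hunk, so treat the flags below as assumptions; tp_group and model are placeholders.

# Assumed sketch of applying the inference policy through ShardFormer; tp_group would come
# from ProcessGroupMesh.get_group_along_axis(TP_AXIS) inside the engine.
from colossalai.shardformer import ShardConfig, ShardFormer

shard_config = ShardConfig(
    tensor_parallel_process_group=tp_group,   # placeholder for the engine-created TP group
    pipeline_stage_manager=None,
    enable_tensor_parallelism=True,           # enables the Linear1D_Col / Linear1D_Row branch above
)
shardformer = ShardFormer(shard_config=shard_config)
model, _ = shardformer.optimize(model, NoPaddingLlamaModelInferPolicy())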
@@ -2,8 +2,12 @@
 Utils for model inference
 """
 import os
+import re
+from pathlib import Path
+from typing import Optional, Tuple
 
 import torch
+from torch import nn
 
 
 def init_to_get_rotary(self, base=10000, use_elem=False):
@@ -49,3 +53,52 @@ def init_to_get_rotary(self, base=10000, use_elem=False):
 
     self._cos_cached = torch.cos(freqs).to(self.dtype).cuda()
     self._sin_cached = torch.sin(freqs).to(self.dtype).cuda()
+
+
+def has_index_file(checkpoint_path: str) -> Tuple[bool, Optional[Path]]:
+    """
+    Check whether the checkpoint has an index file.
+
+    Args:
+        checkpoint_path (str): path to the checkpoint.
+
+    Returns:
+        Tuple[bool, Optional[Path]]: a tuple of (has_index_file, index_file_path)
+    """
+    checkpoint_path = Path(checkpoint_path)
+    if checkpoint_path.is_file():
+        # check if it is .index.json
+        reg = re.compile("(.*?).index((\..*)?).json")
+        if reg.fullmatch(checkpoint_path.name) is not None:
+            return True, checkpoint_path
+        else:
+            return False, None
+    elif checkpoint_path.is_dir():
+        index_files = list(checkpoint_path.glob("*.index.*json"))
+
+        for index_file in index_files:
+            if "safetensors" in index_file.__str__():
+                return True, index_file.__str__()  # return the safetensors file first
+
+        if len(index_files) == 1:
+            return True, index_files[0]
+        else:
+            assert (
+                len(index_files) == 1
+            ), f"Expected to find one .index.json file in {checkpoint_path}, but found {len(index_files)}"
+            return False, None
+    else:
+        raise RuntimeError(f"Invalid checkpoint path {checkpoint_path}. Expected a file or a directory.")
+
+
+def get_model_size(model: nn.Module):
+    """Calculates the total size of the model weights (including biases) in bytes.
+    Args:
+        model: The PyTorch model to analyze.
+    Returns:
+        The total size of the model weights in bytes.
+    """
+    total_size = 0
+    for key, param in model.named_parameters():
+        total_size += param.element_size() * param.numel()
+    return total_size / (1024**3)
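Quick usage of the two helpers added above; the checkpoint path is a placeholder and the model is any nn.Module used only for sizing.

# Placeholder path; has_index_file raises for paths that do not exist.
import torch
from colossalai.inference.utils import get_model_size, has_index_file

found, index_file = has_index_file("/path/to/llama")  # (True, .../*.index.json) or (False, None)
model = torch.nn.Linear(4096, 4096)                   # stand-in module
print(found, index_file)
print(f"{get_model_size(model):.3f} GB")              # note: the value is in GiB despite the docstring saying bytes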
@@ -40,7 +40,7 @@ def check_inference_engine(use_cuda_graph=False, batch_size=32):
 
     input_len = 1024
     output_len = 128
-    do_sample = True
+    do_sample = False
     top_p = 0.5
     top_k = 50
 
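Switching this check to greedy decoding makes the CUDA-graph and eager runs token-for-token comparable; with sampling enabled, RNG differences alone could fail the comparison. Roughly, and purely for illustration:

# Illustrative: greedy decoding keeps outputs deterministic for the graph-vs-eager comparison.
generation_config = GenerationConfig(do_sample=False, top_p=0.5, top_k=50, max_new_tokens=output_len)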
@@ -3,24 +3,27 @@ import random
 import numpy as np
 import pytest
 import torch
+import torch.distributed as dist
+from torch.multiprocessing import Manager
 from transformers import AutoTokenizer, GenerationConfig, LlamaConfig, LlamaForCausalLM
 
 import colossalai
 from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
-from colossalai.inference.flash_decoding_utils import FDIntermTensors
 from colossalai.inference.modeling.models.glide_llama import GlideLlamaConfig, GlideLlamaForCausalLM
+from colossalai.inference.modeling.policy import NoPaddingLlamaModelInferPolicy
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
 
 def setup_seed(seed):
     torch.manual_seed(seed)
+    torch.random.manual_seed(seed)
     torch.cuda.manual_seed_all(seed)
     np.random.seed(seed)
     random.seed(seed)
 
 
-def check_inference_engine(use_engine=False, prompt_template=None):
+def check_inference_engine(use_engine=False, prompt_template=None, do_sample=True, policy=None):
     setup_seed(20)
     tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
     model = LlamaForCausalLM(
@@ -36,13 +39,19 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     ]
 
     output_len = 38
-    do_sample = True
+    do_sample = do_sample
     top_p = 0.5
     top_k = 50
 
     if use_engine:
-        inference_config = InferenceConfig(max_output_len=output_len, prompt_template=prompt_template, dtype="fp32")
-        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+        inference_config = InferenceConfig(
+            max_output_len=output_len,
+            prompt_template=prompt_template,
+            dtype="fp32",
+            use_cuda_kernel=True,
+            tp_size=dist.get_world_size(),
+        )
+        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True, model_policy=policy)
         assert inference_engine.generation_config.max_new_tokens == output_len
         inference_engine.add_request(prompts=inputs)
         assert inference_engine.request_handler._has_waiting()
@@ -69,20 +78,14 @@ def check_inference_engine(use_engine=False, prompt_template=None):
     return outputs
 
 
-@parameterize("prompt_template", [None, "llama"])
-def check_output_consistency(prompt_template):
-    cai_outputs = check_inference_engine(use_engine=True, prompt_template=prompt_template)
-    transformer_outputs = check_inference_engine(use_engine=False, prompt_template=prompt_template)
+def run_engine(world_size, **kwargs):
+    manager = Manager()
+    result_list = manager.list([-1] * world_size)  # Create a shared list
 
-    for s1, s2 in zip(cai_outputs, transformer_outputs):
-        assert s1 == s2, f"\nColossalAI Output: {s1}\nTransformers Output: {s2}"
-
-    # clear singleton flash decoding tensors
-    FDIntermTensors._instances = {}
+    spawn(run_dist, world_size, func_to_run=check_inference_engine, ret=result_list, **kwargs)
+    return result_list[0]
 
 
 @parameterize("num_layers", [1])
 @parameterize("max_length", [100])
 def check_spec_dec(num_layers, max_length):
     torch.manual_seed(123)
@@ -152,16 +155,47 @@ def check_spec_dec(num_layers, max_length):
     assert len(out_token_ids) == 1 and len(out_token_ids[0]) == max_length
 
 
-def run_dist(rank, world_size, port):
+def run_dist(rank, world_size, port, func_to_run, ret=None, **kwargs):
     colossalai.launch(config={}, rank=rank, world_size=world_size, port=port, host="localhost")
-    check_output_consistency()
-    check_spec_dec()
+
+    if ret:
+        ret[rank] = func_to_run(**kwargs)
+    else:
+        func_to_run(**kwargs)
+
+
+@parameterize("prompt_template", [None, "llama"])
+@parameterize("do_sample", [False])
+def test_tp_engine(prompt_template, do_sample):
+    kwargs1 = {
+        "use_engine": True,
+        "prompt_template": prompt_template,
+        "do_sample": do_sample,
+        "policy": NoPaddingLlamaModelInferPolicy(),
+    }
+
+    kwargs2 = {"use_engine": False, "prompt_template": prompt_template, "do_sample": do_sample, "policy": None}
+
+    colossal_tp_1_output = run_engine(1, **kwargs1)
+    colossal_tp_2_output = run_engine(2, **kwargs1)
+    transformer_tp_1_output = run_engine(1, **kwargs2)
+
+    for s1, s2, s3 in zip(colossal_tp_1_output, colossal_tp_2_output, transformer_tp_1_output):
+        assert s1 == s3, f"\nColossalAI TP=1 Output: {s1}\nTransformers Output: {s3}"
+        assert s1 == s2, f"\nColossalAI TP=1 Output: {s1}\nColossalAI TP=2 Output: {s2}"
+
+
+@parameterize("num_layers", [1])
+@parameterize("max_length", [100])
+def test_spec_dec(num_layers, max_length):
+    spawn(run_dist, 1, func_to_run=check_spec_dec, num_layers=num_layers, max_length=max_length)
+
+
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_inference_engine():
-    spawn(run_dist, 1)
+    test_tp_engine()
+    test_spec_dec()
 
 
 if __name__ == "__main__":
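The run_engine helper above hands each spawned rank's generated strings back to the parent through a multiprocessing Manager list and returns rank 0's result. The same pattern in isolation, with an illustrative worker (names and string values are placeholders):

# Stand-alone sketch of the result-passing pattern used by run_engine above.
from torch.multiprocessing import Manager
from colossalai.testing import spawn

def _worker(rank, world_size, port, ret=None):
    # colossalai.launch(...) would normally go here, as in run_dist above
    ret[rank] = f"result-from-rank-{rank}"

def collect(world_size):
    manager = Manager()
    results = manager.list([-1] * world_size)  # shared across spawned processes
    spawn(_worker, world_size, ret=results)
    return list(results)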