mirror of https://github.com/hpcaitech/ColossalAI
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
79 lines
2.1 KiB
79 lines
2.1 KiB
from abc import abstractclassmethod
|
|
from typing import Dict, List
|
|
|
|
from colossal_eval.utils import Conversation, prompt_templates
|
|
|
|
from colossalai.logging import DistributedLogger
|
|
|
|
|
|
class BaseModel:
|
|
"""
|
|
Base class for model wrapper.
|
|
|
|
Args:
|
|
path: The path to the model.
|
|
model_max_length: The maximum sequence length of the model.
|
|
prompt_template: The model's prompt template.
|
|
batch_size: Batch size for inference.
|
|
logger: Logger for the model.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
path: str,
|
|
model_max_length: int = 2048,
|
|
prompt_template: Conversation = None,
|
|
batch_size: int = 1,
|
|
logger: DistributedLogger = None,
|
|
):
|
|
self.path = path
|
|
self.model_max_length = model_max_length
|
|
|
|
if prompt_template:
|
|
self.prompt_template = prompt_template
|
|
else:
|
|
self.prompt_template = prompt_templates["plain"]
|
|
|
|
self.batch_size = batch_size
|
|
self.logger = logger
|
|
|
|
@abstractclassmethod
|
|
def inference(self, data: List[Dict]) -> None:
|
|
"""
|
|
Infer the given data.
|
|
This function will call self.generate() to get model outputs and also self.model(input) to get logits.
|
|
|
|
Args:
|
|
data: The data for inference.
|
|
"""
|
|
|
|
@abstractclassmethod
|
|
def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
|
|
"""
|
|
Generate results given a list of inputs.
|
|
|
|
Args:
|
|
inputs: A list of strings.
|
|
max_new_tokens: The maximum length of the output.
|
|
|
|
Returns:
|
|
A list of generated strings.
|
|
"""
|
|
|
|
@abstractclassmethod
|
|
def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
|
|
"""
|
|
Get loss given batch and batch with target.
|
|
Use their length difference after tokenization to mask the loss and only compute loss at target tokens.
|
|
|
|
Args:
|
|
batch: batch prompt without target answer.
|
|
batch_target: batch prompt with target answer.
|
|
|
|
Returns:
|
|
A list of loss.
|
|
"""
|
|
|
|
def to(self, device):
|
|
self.model.to(device)
|