You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/applications/ColossalEval/colossal_eval/models/base.py

79 lines
2.1 KiB

from abc import abstractclassmethod, abstractmethod
from typing import Dict, List, Optional

from colossal_eval.utils import Conversation, prompt_templates
from colossalai.logging import DistributedLogger
class BaseModel:
    """
    Base class for model wrapper.

    Stores the configuration shared by every wrapper and declares the hooks
    (``inference``, ``generate`` and ``get_loss``) that concrete wrappers
    override.

    Note:
        The hooks are marked with ``@abstractmethod`` so that they stay
        ordinary instance methods (``self`` is the wrapper instance). This
        class does not use ``abc.ABCMeta``, so the marker is informational
        only: instantiation is not blocked, and the base implementations have
        empty bodies that return ``None``.

    Args:
        path: The path to the model.
        model_max_length: The maximum sequence length of the model.
        prompt_template: The model's prompt template. When falsy (e.g. ``None``),
            the ``"plain"`` entry of ``prompt_templates`` is used instead.
        batch_size: Batch size for inference.
        logger: Logger for the model.
    """

    def __init__(
        self,
        path: str,
        model_max_length: int = 2048,
        prompt_template: Optional[Conversation] = None,
        batch_size: int = 1,
        logger: Optional[DistributedLogger] = None,
    ):
        self.path = path
        self.model_max_length = model_max_length
        # Truthiness (not an ``is None`` check) decides the fallback, exactly as
        # before; the default template is only looked up when it is needed.
        self.prompt_template = prompt_template or prompt_templates["plain"]
        self.batch_size = batch_size
        self.logger = logger

    @abstractmethod
    def inference(self, data: List[Dict]) -> None:
        """
        Infer the given data.

        This function will call self.generate() to get model outputs and also
        self.model(input) to get logits.

        Args:
            data: The data for inference.
        """

    @abstractmethod
    def generate(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        """
        Generate results given a list of inputs.

        Args:
            inputs: A list of strings.
            max_new_tokens: The maximum length of the output.

        Returns:
            A list of generated strings.
        """

    @abstractmethod
    def get_loss(self, batch: List[str], batch_target: List[str]) -> List[float]:
        """
        Get loss given batch and batch with target.

        Use their length difference after tokenization to mask the loss and
        only compute loss at target tokens.

        Args:
            batch: batch prompt without target answer.
            batch_target: batch prompt with target answer.

        Returns:
            A list of loss.
        """

    def to(self, device):
        """
        Forward ``device`` to ``self.model.to``.

        ``self.model`` is not assigned by ``BaseModel.__init__``; presumably
        subclasses set it before this method is called (otherwise an
        ``AttributeError`` is raised).

        Args:
            device: Passed unchanged to ``self.model.to``.
        """
        self.model.to(device)