import copy
from typing import List

import torch

from colossalai.utils import get_current_device

from .huggingface import HuggingFaceModel

IGNORE_INDEX = -100
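# Note: -100 is also the default ignore_index of torch.nn.CrossEntropyLoss, so positions labeled
# with IGNORE_INDEX contribute nothing to the loss.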


class ChatGLMModel(HuggingFaceModel):
    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        truncated_inputs = copy.deepcopy(inputs)
        # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
        for i, input in enumerate(inputs):
            a_ids = self.tokenizer.encode(text=input, truncation=False, add_special_tokens=False)

            if len(a_ids) > self.model_max_length - max_new_tokens:
                half = (self.model_max_length - max_new_tokens) // 2
                prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
                    a_ids[-half:], skip_special_tokens=True
                )
                truncated_inputs[i] = prompt

        return truncated_inputs
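        # Illustrative sketch of the truncation above (assumed numbers, not taken from any config):
        # with model_max_length = 10 and max_new_tokens = 2, a 12-token prompt keeps its first 4 and
        # last 4 tokens and drops the middle, so both the head of the question and the answer-adjacent
        # tail survive.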

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
    ) -> List[List[float]]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompts without target answers.
            batch_target: A batch of target answers. One question can have multiple target answers.

        Returns:
            Per-sample losses, per-sample target token counts, and None.

        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss;
        # we don't need to generate new tokens.
        # A target answer is usually much shorter than model_max_length, but we still truncate it just in case.
        # We don't truncate batch_prompt here because we first need the target answer's length to reserve space for its tokens.
        batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]
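        # For example (hypothetical inputs), batch_target = [["Paris"], ["yes", "y"]] gives
        # batch_target_nums = [1, 2], which is used at the end to split the flattened per-answer
        # losses back into per-question groups.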

        labels_list = []
        input_ids_list = []

        for input, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
                # If there is no history, the prompt is just the query.
                # We don't need to override self.generate() for ChatGLM-6B, but we do for ChatGLM2-6B.
                # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276
                target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False)

                # Get a prompt of length model_max_length - len(target_tokenized).
                # Reserve some space for the target answer's tokens using max_new_tokens.
                # This produces the correct start and end indices for the target span.
                max_new_tokens = len(target_tokenized)

                # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id], so we reserve max_new_tokens + 3 tokens.
                # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323
                prompt_with_correct_length = self._get_truncated_prompts([input], max_new_tokens + 3)[0]
                input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False)

                input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized)
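                # Per the linked ChatGLM-6B tokenizer, this should lay the sequence out roughly as
                #   prompt_tokens + [gmask, bos] + target_tokens + [eos]
                # which is where the 3 reserved special tokens above come from.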

                context_length = input_ids.index(self.tokenizer.bos_token_id)

                target_ids = [IGNORE_INDEX] * len(input_ids)

                # The final -1 excludes the eos token; we don't want to calculate loss on the eos token.
                target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]
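                # Sketch with assumed sizes: a 2-token prompt and a 3-token target give len(input_ids) == 8
                # (prompt + [gmask, bos] + target + eos); the slice above keeps real label ids only at the
                # 3 target positions, leaving the prompt, the special tokens, and the eos at IGNORE_INDEX.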

                input_ids_list.append(torch.LongTensor(input_ids))
                labels_list.append(torch.LongTensor(target_ids))

        # Because of multiple target answers, the final batch size may be greater than self.batch_size.
        # We will generate new batches.
        losses = []
        target_token_nums = []

        batched_input_ids = [
            input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
        ]
        batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]

        for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
            losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
            losses.extend(losses_per_batch)
            target_token_nums.extend(target_token_num_per_batch)

        start_indice = 0
        losses_per_sample = []

        target_token_nums_per_sample = []
        for length in batch_target_nums:
            losses_per_sample.append(losses[start_indice : start_indice + length])
            target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
            start_indice += length

        return losses_per_sample, target_token_nums_per_sample, None

    def _calculate_loss(self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]) -> List[float]:
        """
        Calculate loss only on target tokens.
        Hugging Face's generate() cannot return per-sample loss; it only returns the mean loss over a batch.
        In torch.nn.CrossEntropyLoss(), reduction is therefore set to "none" to get the per-sample loss.

        Args:
            input_ids_list: A batch of input token ids.
            labels: A batch of labels.

        Returns:
            A list of per-sample summed losses and a list of per-sample target token counts.

        """
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
        ).to(get_current_device())
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
            get_current_device()
        )

        outputs = self.model(input_ids)[0]

        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
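        # In a causal LM the logit at position t predicts the token at position t + 1, so
        # outputs[..., :-1, :] is aligned with labels[..., 1:] before computing the loss.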

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

        lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()

        loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
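        # Note: the loss is summed (not averaged) over target positions; the token counts returned
        # alongside it let the caller normalize to a per-token loss if needed.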
        return loss_sum.tolist(), lens.tolist()


class ChatGLM2Model(ChatGLMModel):
    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        truncated_inputs = copy.deepcopy(inputs)
        # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
        for i, input in enumerate(inputs):
            a_ids = self.tokenizer.encode(text=input, add_special_tokens=True, truncation=False)

            if len(a_ids) > self.model_max_length - max_new_tokens:
                half = (self.model_max_length - max_new_tokens) // 2
                prompt = self.tokenizer.decode(a_ids[:half], skip_special_tokens=True) + self.tokenizer.decode(
                    a_ids[-half:], skip_special_tokens=True
                )
                truncated_inputs[i] = prompt

        return truncated_inputs

    @torch.no_grad()
    def generate(self, inputs: List[str], max_new_tokens: int, **kwargs) -> List[str]:
        """Generate results given a list of inputs and get logits of the first new token over choices.

        Args:
            inputs: A list of strings.
            max_new_tokens: Max new tokens for generation.
            kwargs: Keyword arguments for generation.

        Returns:
            A list of generated strings and logits over choices.

        Note:
            Currently the function only returns the logits of the first new token.
            It is used for single-choice questions.
            For multiple-choice questions, please avoid using the loss over choices.
            You should set the argument choices to None in self.inference().

        """
        # Follow the process of the model.chat() method in modeling_chatglm.py.
        # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020
        # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001

        query = []
        for input in inputs:
            prompt = self.tokenizer.build_prompt(input, None)
            query.append(prompt)

        truncated_query = self._get_truncated_prompts(query, max_new_tokens)

        encoded_inputs = self.tokenizer(
            truncated_query,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.model_max_length - max_new_tokens,
        ).to(get_current_device())

        # Set output_scores=True to get prediction scores.
        outputs = self.model.generate(
            **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
        )

        # We only need to decode the newly predicted tokens.
        sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]

        scores = []
        if self.indices_for_choices:
            # If the question is a single-choice question, we return the scores at specific indices for the first predicted token.
            # The indices are the tokenization results of the options for the single-choice question.
            # For example, if the options of the question are A, B, C and D, we only return scores at the indices of A, B, C and D.
            for option_indices in self.indices_for_choices:
                scores.append(outputs.scores[0][:, option_indices].detach().cpu())

            scores = torch.max(torch.stack(scores), dim=0)[0]
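            # Each entry of self.indices_for_choices appears to hold one tokenization variant of the
            # option tokens, so stacking and taking the max keeps the highest logit each option
            # received across variants.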

        decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)

        return decoded_sequences, scores

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
    ) -> List[List[float]]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompts without target answers.
            batch_target: A batch of target answers. One question can have multiple target answers.

        Returns:
            Per-sample losses, per-sample target token counts, and None.

        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss;
        # we don't need to generate new tokens.
        # A target answer is usually much shorter than model_max_length, but we still truncate it just in case.
        # We don't truncate batch_prompt here because we first need the target answer's length to reserve space for its tokens.
        batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

        labels_list = []
        input_ids_list = []

        for input, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
                prompt = self.tokenizer.build_prompt(input, None)

                target_tokenized = self.tokenizer.encode(
                    text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length
                )

                max_new_tokens = len(target_tokenized)
                prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
                input_tokenized = self.tokenizer.encode(
                    prompt_with_correct_length,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.model_max_length,
                )

                input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id]
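                # For ChatGLM2-6B, add_special_tokens=True should prepend the tokenizer's prefix tokens
                # ([gMASK], sop) to the prompt, so the sequence is roughly prefix + prompt + target + eos.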
                target_ids = [IGNORE_INDEX] * len(input_ids)

                # The final -1 excludes the eos token; we don't calculate loss on it.
                target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]

                input_ids_list.append(torch.LongTensor(input_ids))
                labels_list.append(torch.LongTensor(target_ids))

        # Because of multiple target answers, the final batch size may be greater than self.batch_size.
        # We will generate new batches.
        losses = []
        target_token_nums = []

        batched_input_ids = [
            input_ids_list[i : i + self.batch_size] for i in range(0, len(input_ids_list), self.batch_size)
        ]
        batched_labels = [labels_list[i : i + self.batch_size] for i in range(0, len(labels_list), self.batch_size)]

        for batch_input_ids, batch_labels in zip(batched_input_ids, batched_labels):
            losses_per_batch, target_token_num_per_batch = self._calculate_loss(batch_input_ids, batch_labels)
            losses.extend(losses_per_batch)
            target_token_nums.extend(target_token_num_per_batch)

        start_indice = 0
        losses_per_sample = []

        target_token_nums_per_sample = []
        for length in batch_target_nums:
            losses_per_sample.append(losses[start_indice : start_indice + length])
            target_token_nums_per_sample.append(target_token_nums[start_indice : start_indice + length])
            start_indice += length

        return losses_per_sample, target_token_nums_per_sample, None
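

# Minimal usage sketch (illustrative assumptions only: the constructor arguments below are made up;
# the real signature lives in HuggingFaceModel and the evaluation pipeline's configs):
#
#   model = ChatGLM2Model(path="THUDM/chatglm2-6b", model_max_length=2048, batch_size=4)
#   answers, choice_scores = model.generate(["What is 1 + 1?"], max_new_tokens=8)
#   losses, target_token_nums, _ = model.get_loss(["What is 1 + 1?"], [["2", "two"]])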