import copy
from typing import List, Tuple, Union

import torch
from colossalai.utils import get_current_device

from .huggingface import HuggingFaceModel

# Label value ignored by torch.nn.CrossEntropyLoss (passed as ignore_index below).
IGNORE_INDEX = -100


class ChatGLMModel(HuggingFaceModel):
    """Model wrapper for ChatGLM-6B: prompt truncation and loss on target tokens."""

    def _truncate_middle(self, inputs: List[str], max_new_tokens: int, add_special_tokens: bool) -> List[str]:
        """Drop the middle of every prompt that exceeds the available token budget.

        The budget is ``self.model_max_length - max_new_tokens``. A prompt whose
        tokenization is longer than the budget is replaced by the decoded first
        ``budget // 2`` tokens followed by the decoded last ``budget // 2`` tokens.
        Prompts within the budget are returned unchanged.

        Args:
            inputs: Prompts to check.
            max_new_tokens: Number of tokens to reserve out of ``self.model_max_length``.
            add_special_tokens: Forwarded to ``self.tokenizer.encode``.

        Returns:
            A new list with the (possibly truncated) prompts; ``inputs`` is not modified.
        """
        truncated_inputs = copy.deepcopy(inputs)
        budget = self.model_max_length - max_new_tokens

        for i, text in enumerate(inputs):
            a_ids = self.tokenizer.encode(text=text, add_special_tokens=add_special_tokens, truncation=False)
            if len(a_ids) <= budget:
                continue

            # Clamp at zero so that a non-positive budget keeps no tokens instead of
            # producing a negative slice size.
            half = max(budget, 0) // 2
            head_ids = a_ids[:half]
            # Slice from an explicit start index: `a_ids[-half:]` would return the
            # whole sequence when half == 0.
            tail_ids = a_ids[len(a_ids) - half :]
            truncated_inputs[i] = self.tokenizer.decode(head_ids, skip_special_tokens=True) + self.tokenizer.decode(
                tail_ids, skip_special_tokens=True
            )

        return truncated_inputs

    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        """Truncate prompts so that prompt tokens plus ``max_new_tokens`` fit in ``self.model_max_length``.

        Args:
            inputs: A list of prompts.
            max_new_tokens: Number of tokens to reserve for the generated/target tokens.

        Returns:
            A list of prompts, truncated in the middle where necessary.
        """
        # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
        return self._truncate_middle(inputs, max_new_tokens, add_special_tokens=False)

    def _compute_grouped_losses(
        self,
        input_ids_list: List[torch.LongTensor],
        labels_list: List[torch.LongTensor],
        batch_target_nums: List[int],
    ) -> Tuple[List[List[float]], List[List[int]]]:
        """Run ``self._calculate_loss`` in chunks of ``self.batch_size`` and regroup the results per question.

        Args:
            input_ids_list: One tensor of input ids per (prompt, target) pair, in question order.
            labels_list: One tensor of labels per (prompt, target) pair, aligned with ``input_ids_list``.
            batch_target_nums: Number of target answers of each question.

        Returns:
            A tuple ``(losses_per_sample, target_token_nums_per_sample)``; entry ``k`` of each
            list holds the values for the target answers of question ``k``.
        """
        # Because of multiple target answers, the final batch size may be greater than self.batch_size.
        # We will generate new batches.
        losses = []
        target_token_nums = []
        for start in range(0, len(input_ids_list), self.batch_size):
            end = start + self.batch_size
            losses_per_batch, target_token_num_per_batch = self._calculate_loss(
                input_ids_list[start:end], labels_list[start:end]
            )
            losses.extend(losses_per_batch)
            target_token_nums.extend(target_token_num_per_batch)

        # Split the flat results back into one sub-list per question.
        losses_per_sample = []
        target_token_nums_per_sample = []
        offset = 0
        for length in batch_target_nums:
            losses_per_sample.append(losses[offset : offset + length])
            target_token_nums_per_sample.append(target_token_nums[offset : offset + length])
            offset += length

        return losses_per_sample, target_token_nums_per_sample

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
    ) -> Tuple[List[List[float]], List[List[int]], None]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompt without target answer.
            batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
            pretrain: Not used by this implementation.

        Returns:
            A tuple ``(losses_per_sample, target_token_nums_per_sample, None)``. For every question,
            ``losses_per_sample`` holds the loss summed over the target tokens of each of its target
            answers, and ``target_token_nums_per_sample`` holds the corresponding number of target tokens.
            The third element is always ``None`` (presumably kept for parity with the base class
            interface; confirm against callers).

        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
        # We don't need to generate new tokens.
        # Target answer's length is usually << model_max_length, but we still call it in case.
        # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first
        # to reserve some space for target answer's tokens.
        batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

        labels_list = []
        input_ids_list = []

        for prompt, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # Adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/ptuning/main.py#L187
                # If there is no history, the prompt is just the query.
                # We don't need to override self.generate() in ChatGLM-6B but need to override it in ChatGLM2-6B.
                # See https://huggingface.co/THUDM/chatglm-6b/blob/main/modeling_chatglm.py#L1276
                target_tokenized = self.tokenizer.encode(text=target, add_special_tokens=False)

                # Get prompt with length model_max_length - len(target_tokenized).
                # Reserve some space for target answer tokens using max_new_tokens.
                # This will generate the correct start_idx and end_idx.
                max_new_tokens = len(target_tokenized)

                # Here 3 tokens are reserved for [gmask_id, bos_token, eos_id]. So we reserve max_new_tokens + 3 tokens.
                # See https://huggingface.co/THUDM/chatglm-6b/blob/main/tokenization_chatglm.py#L323
                prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens + 3)[0]
                input_tokenized = self.tokenizer.encode(prompt_with_correct_length, add_special_tokens=False)

                input_ids = self.tokenizer.build_inputs_with_special_tokens(input_tokenized, target_tokenized)

                target_ids = [IGNORE_INDEX] * len(input_ids)

                # -1 is for eos_token, we don't want to calculate loss on eos token.
                target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]

                input_ids_list.append(torch.LongTensor(input_ids))
                labels_list.append(torch.LongTensor(target_ids))

        losses_per_sample, target_token_nums_per_sample = self._compute_grouped_losses(
            input_ids_list, labels_list, batch_target_nums
        )

        return losses_per_sample, target_token_nums_per_sample, None

    def _calculate_loss(
        self, input_ids_list: List[torch.LongTensor], labels: List[torch.LongTensor]
    ) -> Tuple[List[float], List[int]]:
        """
        Calculate loss only on target tokens.
        A Hugging Face model's forward pass only returns the mean of the loss in a batch, not per sample loss.
        In torch.nn.CrossEntropyLoss(), reduction should be specified as "none" to get per sample loss.

        Args:
            input_ids_list: A batch of input token ids.
            labels: A batch of labels.

        Returns:
            A tuple ``(losses, lens)``: for every sample, the loss summed over its positions
            (ignored positions contribute nothing) and the number of labels that are not IGNORE_INDEX.

        """
        input_ids = torch.nn.utils.rnn.pad_sequence(
            input_ids_list, batch_first=True, padding_value=self.tokenizer.pad_token_id
        ).to(get_current_device())
        # Padded label positions get IGNORE_INDEX so that they never contribute to the loss.
        labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX).to(
            get_current_device()
        )

        outputs = self.model(input_ids)[0]

        # Logits at position t are scored against the label at position t + 1.
        shift_logits = outputs[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        loss_fct = torch.nn.CrossEntropyLoss(reduction="none", ignore_index=IGNORE_INDEX)
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)).view(shift_labels.size())

        lens = (labels != IGNORE_INDEX).sum(-1).cpu().numpy()

        loss_sum = loss.sum(-1).to(torch.float32).cpu().detach().numpy()
        return loss_sum.tolist(), lens.tolist()


class ChatGLM2Model(ChatGLMModel):
    """Model wrapper for ChatGLM2-6B; overrides truncation, generation and loss calculation."""

    def _get_truncated_prompts(self, inputs: List[str], max_new_tokens: int) -> List[str]:
        """Truncate prompts so that prompt tokens plus ``max_new_tokens`` fit in ``self.model_max_length``.

        Unlike the ChatGLM-6B version, prompts are measured with special tokens included.

        Args:
            inputs: A list of prompts.
            max_new_tokens: Number of tokens to reserve for the generated/target tokens.

        Returns:
            A list of prompts, truncated in the middle where necessary.
        """
        # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
        return self._truncate_middle(inputs, max_new_tokens, add_special_tokens=True)

    @torch.no_grad()
    def generate(
        self, inputs: List[str], max_new_tokens: int, **kwargs
    ) -> Tuple[List[str], Union[torch.Tensor, List]]:
        """Generate results given a list of inputs and get logits of the first new token over choices.

        Args:
            inputs: A list of strings.
            max_new_tokens: Max new tokens for generation.
            kwargs: Key arguments for generation

        Returns:
            A tuple ``(decoded_sequences, scores)``: the generated strings and the logits over choices.
            ``scores`` is an empty list when ``self.indices_for_choices`` is falsy.

        Note:
            Currently the function only returns the logits of the first new token.
            It is used for single choice question.
            For multiple choices question, please avoid using the loss over choices.
            You should set argument choices as None in self.inference().

        """
        # Follow the process of model.chat() method in modeling_chatglm2.py
        # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1020
        # See https://huggingface.co/THUDM/chatglm2-6b/blob/main/modeling_chatglm.py#L1001

        query = [self.tokenizer.build_prompt(text, None) for text in inputs]

        truncated_query = self._get_truncated_prompts(query, max_new_tokens)

        encoded_inputs = self.tokenizer(
            truncated_query,
            padding=True,
            truncation=True,
            return_tensors="pt",
            max_length=self.model_max_length - max_new_tokens,
        ).to(get_current_device())

        # Set output_scores=True to get prediction scores.
        outputs = self.model.generate(
            **encoded_inputs, max_new_tokens=max_new_tokens, return_dict_in_generate=True, output_scores=True, **kwargs
        )

        # We only need to decode predicted tokens.
        sequences = outputs.sequences[:, encoded_inputs["input_ids"].shape[1] :]

        scores = []
        if self.indices_for_choices:
            # If the question is a single-choice question, we will return the scores of specific indices for first predicted token.
            # The indices are the tokenization results of the options for the single-choice question.
            # For example, if the options of the question are A, B, C and D, we only returns scores at indices of A, B, C and D.
            for option_indices in self.indices_for_choices:
                scores.append(outputs.scores[0][:, option_indices].detach().cpu())

            # Element-wise maximum across the stacked per-entry score tensors.
            scores = torch.max(torch.stack(scores), dim=0)[0]

        decoded_sequences = self.tokenizer.batch_decode(sequences, skip_special_tokens=True)

        return decoded_sequences, scores

    @torch.no_grad()
    def get_loss(
        self, batch_prompt: List[str], batch_target: List[List[str]], pretrain: bool = False
    ) -> Tuple[List[List[float]], List[List[int]], None]:
        """
        Calculate loss only on target tokens.

        Args:
            batch_prompt: A batch of prompt without target answer.
            batch_target: A batch of target answer. Sometimes one question can have multiple target answers.
            pretrain: Not used by this implementation.

        Returns:
            A tuple ``(losses_per_sample, target_token_nums_per_sample, None)``. For every question,
            ``losses_per_sample`` holds the loss summed over the target tokens of each of its target
            answers, and ``target_token_nums_per_sample`` holds the corresponding number of target tokens.
            The third element is always ``None`` (presumably kept for parity with the base class
            interface; confirm against callers).

        """

        # We set max_new_tokens in self._get_truncated_prompts to 0 because we only need logits to calculate loss.
        # We don't need to generate new tokens.
        # Target answer's length is usually << model_max_length, but we still call it in case.
        # We don't call self._get_truncated_prompts for batch_prompt because we need target answer's length first
        # to reserve some space for target answer's tokens.
        batch_target = [self._get_truncated_prompts(prompt_target, 0) for prompt_target in batch_target]

        # Get the number of target answers for different questions
        batch_target_nums = [len(prompt_target) for prompt_target in batch_target]

        labels_list = []
        input_ids_list = []

        for text, targets in zip(batch_prompt, batch_target):
            for target in targets:
                # Adapted from https://github.com/THUDM/ChatGLM2-6B/blob/main/ptuning/main.py#L180
                prompt = self.tokenizer.build_prompt(text, None)

                target_tokenized = self.tokenizer.encode(
                    text=target, add_special_tokens=False, truncation=True, max_length=self.model_max_length
                )

                # Reserve room for the target tokens when truncating the prompt.
                max_new_tokens = len(target_tokenized)
                prompt_with_correct_length = self._get_truncated_prompts([prompt], max_new_tokens)[0]
                input_tokenized = self.tokenizer.encode(
                    prompt_with_correct_length,
                    add_special_tokens=True,
                    truncation=True,
                    max_length=self.model_max_length,
                )

                input_ids = input_tokenized + target_tokenized + [self.tokenizer.eos_token_id]
                target_ids = [IGNORE_INDEX] * len(input_ids)

                # -1 is for "eos"
                target_ids[-max_new_tokens - 1 : -1] = input_ids[-max_new_tokens - 1 : -1]

                input_ids_list.append(torch.LongTensor(input_ids))
                labels_list.append(torch.LongTensor(target_ids))

        losses_per_sample, target_token_nums_per_sample = self._compute_grouped_losses(
            input_ids_list, labels_list, batch_target_nums
        )

        return losses_per_sample, target_token_nums_per_sample, None