diff --git a/colossalai/inference/engine/modeling/bloom.py b/colossalai/inference/engine/modeling/bloom.py
index 4c098d3e4..527bcf55e 100644
--- a/colossalai/inference/engine/modeling/bloom.py
+++ b/colossalai/inference/engine/modeling/bloom.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.nn import functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.bloom.modeling_bloom import (
     BaseModelOutputWithPastAndCrossAttentions,
     BloomAttention,
@@ -86,6 +87,7 @@ class BloomInferenceForwards:
         **deprecated_arguments,
     ):
         r"""
+        This function is only used when pipeline parallelism is enabled.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
@@ -153,6 +155,7 @@ class BloomInferenceForwards:
         tp_group: Optional[dist.ProcessGroup] = None,
         **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        infer_state = infer_state or getattr(self, "infer_state", None)
         logger = logging.get_logger(__name__)
 
         # add warnings here
@@ -183,7 +186,7 @@ class BloomInferenceForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
         # first stage
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             # check inputs and inputs embeds
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -255,10 +258,14 @@ class BloomInferenceForwards:
 
         infer_state.decode_layer_id = 0
 
+        if stage_index is None:
+            stage_index = (0, len(self.h))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if past_key_values is None:
             past_key_values = tuple([None] * (end_idx - start_idx + 1))
 
+        # For HF API compatibility, the kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
         for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
             block = self.h[idx]
             outputs = block(
@@ -274,8 +281,10 @@ class BloomInferenceForwards:
             infer_state.decode_layer_id += 1
 
             hidden_states = outputs[0]
+            if use_cache:
+                next_decoder_cache += (outputs[2 if output_attentions else 1],)
 
-        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+        if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.ln_f(hidden_states)
 
         # update indices
@@ -283,6 +292,12 @@ class BloomInferenceForwards:
         infer_state.seq_len += 1
         infer_state.max_len_in_batch += 1
 
+        next_cache = next_decoder_cache if use_cache else None
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         # always return dict for intermediate stage
         return {"hidden_states": hidden_states}
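
A minimal sketch of the control flow this patch adds, separate from the diff itself: when no pipeline `stage_manager` is present, the forward pass defaults `stage_index` to the full layer range, accumulates the per-layer kv-cache, and returns an HF-style `BaseModelOutputWithPast`; intermediate pipeline stages keep handing a plain dict downstream. Everything here except `BaseModelOutputWithPast` (the name `toy_forward`, the callable `layers`, the toy cache entries) is hypothetical and not ColossalAI's actual API.

```python
# Hypothetical sketch of the no-pipeline fallback path added by this patch.
from typing import Optional, Tuple

import torch
from transformers.modeling_outputs import BaseModelOutputWithPast


def toy_forward(
    hidden_states: torch.Tensor,
    layers,                      # sequence of callables: (h, past) -> (h, present)
    stage_manager=None,          # None means "no pipeline parallelism"
    stage_index: Optional[Tuple[int, int]] = None,
    past_key_values=None,
    use_cache: bool = True,
    return_dict: bool = True,
):
    # Without a pipeline, default the stage to cover every layer.
    if stage_index is None:
        stage_index = (0, len(layers))
    start_idx, end_idx = stage_index

    if past_key_values is None:
        past_key_values = tuple([None] * (end_idx - start_idx))

    # For HF API compatibility, collect each layer's "present" kv-cache entry.
    next_decoder_cache = () if use_cache else None
    for idx, past in zip(range(start_idx, end_idx), past_key_values):
        hidden_states, present = layers[idx](hidden_states, past)
        if use_cache:
            next_decoder_cache += (present,)

    next_cache = next_decoder_cache if use_cache else None
    if stage_manager is None:
        # No pipeline: return the HF-style output (tuple or ModelOutput).
        if not return_dict:
            return (hidden_states, next_cache)
        return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)

    # Intermediate pipeline stages always pass a dict to the next stage.
    return {"hidden_states": hidden_states}


# Usage: two toy layers whose "present" cache is just the input tensor.
layers = [lambda h, past: (h + 1, h.detach()) for _ in range(2)]
out = toy_forward(torch.zeros(1, 4, 8), layers)
assert out.last_hidden_state is not None and len(out.past_key_values) == 2
```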