[inference] decouple pipeline logic for bloom (#5097)

pull/5079/head^2
Hongxin Liu 2023-11-22 17:49:25 +08:00 committed by GitHub
parent afe3c78d9a
commit 67a07e6f64
1 changed file with 17 additions and 2 deletions


@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.nn import functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.bloom.modeling_bloom import (
     BaseModelOutputWithPastAndCrossAttentions,
     BloomAttention,
@@ -86,6 +87,7 @@ class BloomInferenceForwards:
         **deprecated_arguments,
     ):
         r"""
+        This function is only used when pipeline is enabled.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
@@ -153,6 +155,7 @@ class BloomInferenceForwards:
         tp_group: Optional[dist.ProcessGroup] = None,
         **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        infer_state = infer_state or getattr(self, "infer_state", None)
         logger = logging.get_logger(__name__)
         # add warnings here
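
This fallback lets the same forward serve both call paths: the pipeline schedule passes infer_state explicitly, while the non-pipeline engine can stash it on the module ahead of time. A minimal sketch of the argument-or-attribute pattern, with hypothetical names:

    # Minimal sketch of the argument-or-attribute fallback (hypothetical names).
    class TinyForward:
        def __init__(self, infer_state=None):
            # the engine may attach state ahead of time instead of passing it per call
            self.infer_state = infer_state

        def forward(self, x, infer_state=None):
            # prefer the explicit argument; fall back to the cached attribute
            infer_state = infer_state or getattr(self, "infer_state", None)
            if infer_state is None:
                raise ValueError("no inference state available")
            return x

    runner = TinyForward(infer_state={"seq_len": 0})
    runner.forward("token")                   # uses the attribute
    runner.forward("token", {"seq_len": 3})   # explicit argument wins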
@@ -183,7 +186,7 @@ class BloomInferenceForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)
 
         # first stage
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             # check inputs and inputs embeds
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
@@ -255,10 +258,14 @@
         infer_state.decode_layer_id = 0
+        if stage_index is None:
+            stage_index = (0, len(self.h))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if past_key_values is None:
             past_key_values = tuple([None] * (end_idx - start_idx + 1))
+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
 
         for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
             block = self.h[idx]
             outputs = block(
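
With no pipeline, stage_index now defaults to the full layer range, so a single process walks every block; under a pipeline, the schedule passes the (start, end) slice this stage owns. A toy version of the slicing with stand-in names:

    # Toy version of the per-stage layer slice (stand-in names).
    blocks = ["h0", "h1", "h2", "h3"]      # stand-in for self.h
    stage_index = None                     # no pipeline -> default to all layers
    if stage_index is None:
        stage_index = (0, len(blocks))
    start_idx, end_idx = stage_index[0], stage_index[1]
    past_key_values = tuple([None] * (end_idx - start_idx + 1))
    for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
        print(idx, blocks[idx], past_key_value)   # each owned block visited once

Note that zip stops at the shorter sequence, so the one extra None slot allocated in past_key_values is simply never consumed.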
@@ -274,8 +281,10 @@
             infer_state.decode_layer_id += 1
 
             hidden_states = outputs[0]
+            if use_cache:
+                next_decoder_cache += (outputs[2 if output_attentions else 1],)
 
-        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+        if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.ln_f(hidden_states)
 
         # update indices
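
Accumulating next_decoder_cache restores the Hugging Face habit of returning the kv-cache from forward. The index expression picks the cache entry out of each block's output tuple, whose layout shifts when attentions are also returned; the layout below is assumed from the diff, and fake_block is invented purely for illustration:

    # Toy of the output-tuple indexing (fake_block is invented for illustration).
    def fake_block(hidden, use_cache, output_attentions):
        outputs = (hidden,)
        if output_attentions:
            outputs += ("attn",)           # cache entry shifts to index 2
        if use_cache:
            outputs += ("present",)
        return outputs

    use_cache, output_attentions = True, False
    next_decoder_cache = () if use_cache else None
    hidden = 0.0
    for _ in range(3):
        outputs = fake_block(hidden, use_cache, output_attentions)
        hidden = outputs[0]
        if use_cache:
            next_decoder_cache += (outputs[2 if output_attentions else 1],)
    print(next_decoder_cache)              # ('present', 'present', 'present')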
@@ -283,6 +292,12 @@
             infer_state.seq_len += 1
             infer_state.max_len_in_batch += 1
 
+        next_cache = next_decoder_cache if use_cache else None
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         # always return dict for intermediate stage
         return {"hidden_states": hidden_states}