mirror of https://github.com/hpcaitech/ColossalAI
[inference] decouple pipeline logic for bloom (#5097)
parent afe3c78d9a
commit 67a07e6f64

@@ -5,6 +5,7 @@ from typing import List, Optional, Tuple, Union
 import torch
 import torch.distributed as dist
 from torch.nn import functional as F
+from transformers.modeling_outputs import BaseModelOutputWithPast
 from transformers.models.bloom.modeling_bloom import (
     BaseModelOutputWithPastAndCrossAttentions,
     BloomAttention,

@@ -86,6 +87,7 @@ class BloomInferenceForwards:
         **deprecated_arguments,
     ):
         r"""
+        This function is only used when pipeline is enabled.
         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
             Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
             `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`

@@ -153,6 +155,7 @@ class BloomInferenceForwards:
         tp_group: Optional[dist.ProcessGroup] = None,
         **deprecated_arguments,
     ) -> Union[Tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+        infer_state = infer_state or getattr(self, "infer_state", None)
         logger = logging.get_logger(__name__)

         # add warnings here
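
For context, the added `infer_state = infer_state or getattr(self, "infer_state", None)` line lets callers either pass the inference state explicitly or attach it to the module ahead of time. A minimal sketch of that fallback pattern, using a hypothetical InferState stand-in rather than ColossalAI's real class:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class InferState:
        # hypothetical stand-in for ColossalAI's batched inference state
        decode_layer_id: int = 0
        seq_len: int = 0
        max_len_in_batch: int = 0

    class TinyModule:
        def forward(self, infer_state: Optional[InferState] = None) -> InferState:
            # fall back to a state attached to the module beforehand,
            # mirroring the line added in this hunk
            infer_state = infer_state or getattr(self, "infer_state", None)
            assert infer_state is not None, "no inference state passed or attached"
            return infer_state

    m = TinyModule()
    m.infer_state = InferState()  # attach once...
    print(m.forward())            # ...then call without threading it through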

@@ -183,7 +186,7 @@ class BloomInferenceForwards:
         head_mask = self.get_head_mask(head_mask, self.config.n_layer)

         # first stage
-        if stage_manager.is_first_stage():
+        if stage_manager is None or stage_manager.is_first_stage():
             # check inputs and inputs embeds
             if input_ids is not None and inputs_embeds is not None:
                 raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
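
This guard is the heart of the decoupling: `stage_manager is None` now marks a non-pipeline run, in which the single process does both the first-stage and the last-stage work. A minimal sketch of the pattern, with a hypothetical StageManager in place of ColossalAI's real pipeline stage manager:

    from typing import List, Optional

    class StageManager:
        # hypothetical stand-in for a pipeline stage manager
        def __init__(self, stage: int, num_stages: int):
            self.stage, self.num_stages = stage, num_stages

        def is_first_stage(self) -> bool:
            return self.stage == 0

        def is_last_stage(self) -> bool:
            return self.stage == self.num_stages - 1

    def forward(stage_manager: Optional[StageManager]) -> List[str]:
        steps = []
        # None means "no pipeline": this process owns every stage's work
        if stage_manager is None or stage_manager.is_first_stage():
            steps.append("embed inputs")            # first-stage work
        steps.append("run this stage's blocks")     # every stage runs its layers
        if stage_manager is None or stage_manager.is_last_stage():
            steps.append("final layernorm")         # last-stage work
        return steps

    print(forward(None))                # non-pipeline: all three steps
    print(forward(StageManager(1, 4)))  # middle stage: only its own blocks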

@@ -255,10 +258,14 @@ class BloomInferenceForwards:

         infer_state.decode_layer_id = 0

+        if stage_index is None:
+            stage_index = (0, len(self.h))
         start_idx, end_idx = stage_index[0], stage_index[1]
         if past_key_values is None:
             past_key_values = tuple([None] * (end_idx - start_idx + 1))

+        # for HF api compatibility, kv-cache must be returned
+        next_decoder_cache = () if use_cache else None
         for idx, past_key_value in zip(range(start_idx, end_idx), past_key_values):
             block = self.h[idx]
             outputs = block(
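
Defaulting `stage_index` to `(0, len(self.h))` makes the non-pipeline path sweep every BloomBlock, while a pipeline stage still gets only the sub-range it owns. A small sketch of that slicing convention (the layer list and ranges are illustrative):

    from typing import Optional, Tuple

    layers = [f"block_{i}" for i in range(8)]  # stand-in for self.h

    def layer_range(stage_index: Optional[Tuple[int, int]]) -> Tuple[int, int]:
        # no pipeline: default to the full layer range, as the hunk above does
        if stage_index is None:
            stage_index = (0, len(layers))
        return stage_index[0], stage_index[1]

    start_idx, end_idx = layer_range(None)    # (0, 8): run every block
    print(layers[start_idx:end_idx])
    start_idx, end_idx = layer_range((2, 5))  # a stage owning blocks 2..4
    print(layers[start_idx:end_idx])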

@@ -274,8 +281,10 @@ class BloomInferenceForwards:

             infer_state.decode_layer_id += 1
             hidden_states = outputs[0]
+            if use_cache:
+                next_decoder_cache += (outputs[2 if output_attentions else 1],)

-        if stage_manager.is_last_stage() or stage_manager.num_stages == 1:
+        if stage_manager is None or stage_manager.is_last_stage() or stage_manager.num_stages == 1:
             hidden_states = self.ln_f(hidden_states)

         # update indices

@@ -283,6 +292,12 @@ class BloomInferenceForwards:
             infer_state.seq_len += 1
             infer_state.max_len_in_batch += 1

+        next_cache = next_decoder_cache if use_cache else None
+        if stage_manager is None:
+            if not return_dict:
+                return (hidden_states, next_cache)
+            return BaseModelOutputWithPast(last_hidden_state=hidden_states, past_key_values=next_cache)
+
         # always return dict for intermediate stage
         return {"hidden_states": hidden_states}
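
The new tail follows the HF contract when no pipeline is active (a plain tuple for `return_dict=False`, otherwise a `BaseModelOutputWithPast`), while intermediate pipeline stages keep returning a dict of hidden states. A sketch, under those assumptions, of how a caller might unpack the three shapes:

    import torch
    from transformers.modeling_outputs import BaseModelOutputWithPast

    def unpack(output) -> torch.Tensor:
        # non-pipeline, return_dict=True: an HF ModelOutput
        # (checked first, since ModelOutput is itself a dict subclass)
        if isinstance(output, BaseModelOutputWithPast):
            return output.last_hidden_state
        # intermediate pipeline stage: a plain dict of hidden states
        if isinstance(output, dict):
            return output["hidden_states"]
        # non-pipeline, return_dict=False: a (hidden_states, next_cache) tuple
        hidden_states, _next_cache = output
        return hidden_states

    hs = torch.zeros(1, 4, 16)
    print(unpack({"hidden_states": hs}).shape)
    print(unpack((hs, None)).shape)
    print(unpack(BaseModelOutputWithPast(last_hidden_state=hs)).shape)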