mirror of https://github.com/hpcaitech/ColossalAI

[shardformer] fix modeling of bloom and falcon (#5796)
parent: 587bbf4c6d
commit: aa125bcc91
@@ -475,7 +475,10 @@ class BloomPipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
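
A minimal, standalone sketch of the modulo trick above (toy input_ids and an assumed pad_token_id, not taken from the commit):

    import torch

    pad_token_id = 3  # assumed pad id for this toy example
    input_ids = torch.tensor([
        [5, 6, 7, 3, 3],  # right-padded: last real token at index 2
        [5, 6, 7, 8, 9],  # no padding: the trick must still pick the last index
    ])

    # index of the first pad token, minus one, gives the last real token...
    sequence_lengths = torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1
    # ...and when no pad token exists, argmax returns 0, so -1 wraps around to the
    # last position via the modulo, avoiding reverse indexing that ONNX export cannot trace
    sequence_lengths = sequence_lengths % input_ids.shape[-1]
    print(sequence_lengths)  # tensor([2, 4])
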
@@ -291,18 +291,17 @@ class FalconPipelineForwards:
         if attention_mask_2d is None:
             attention_mask = alibi / math.sqrt(self.config.hidden_size // self.num_heads)
         else:
+            min_dtype = torch.finfo(alibi.dtype).min
             attention_mask = torch.masked_fill(
                 alibi / math.sqrt(self.config.hidden_size // self.num_heads),
                 attention_mask < -1,
-                torch.finfo(alibi.dtype).min,
+                min_dtype,
             )

             # From PyTorch 2.1 onwards, F.scaled_dot_product_attention with the memory-efficient attention backend
             # produces nans if sequences are completely unattended in the attention mask. Details: https://github.com/pytorch/pytorch/issues/110213
-            if seq_length > 1:
-                attention_mask = AttentionMaskConverter._unmask_unattended(
-                    attention_mask, attention_mask_2d, unmasked_value=0.0
-                )
+            if seq_length > 1 and attention_mask.device.type == "cuda":
+                attention_mask = AttentionMaskConverter._unmask_unattended(attention_mask, min_dtype=min_dtype)
     else:
         # PyTorch SDPA does not support head_mask, we fall back on the eager implementation in this case.
         attention_mask = _prepare_4d_causal_attention_mask(
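
The masked_fill above folds the scaled alibi bias into the 4D additive mask and pins positions the mask already rejects to the dtype minimum; the hoisted min_dtype is then reused by the updated AttentionMaskConverter._unmask_unattended call. A self-contained sketch of that pattern, with illustrative shapes and names rather than the real pipeline tensors:

    import math

    import torch

    batch_size, num_heads, seq_length, hidden_size = 2, 4, 5, 32

    alibi = torch.randn(batch_size, num_heads, 1, seq_length)            # per-head alibi bias
    attention_mask = torch.zeros(batch_size, 1, seq_length, seq_length)  # 4D additive mask
    attention_mask[:, :, :, -2:] = torch.finfo(alibi.dtype).min          # last two keys masked

    min_dtype = torch.finfo(alibi.dtype).min
    # scale the alibi bias like an attention score, then force already-masked
    # positions (entries below -1 in the additive mask) down to the dtype minimum
    attn_bias = torch.masked_fill(
        alibi / math.sqrt(hidden_size // num_heads),
        attention_mask < -1,
        min_dtype,
    )
    print(attn_bias.shape)  # torch.Size([2, 4, 5, 5]) after broadcasting against the mask
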
@@ -543,7 +542,10 @@ class FalconPipelineForwards:
             sequence_lengths = -1
         else:
             if input_ids is not None:
-                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(dim=-1) - 1).to(logits.device)
+                # if no pad token found, use modulo instead of reverse indexing for ONNX compatibility
+                sequence_lengths = torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1
+                sequence_lengths = sequence_lengths % input_ids.shape[-1]
+                sequence_lengths = sequence_lengths.to(logits.device)
             else:
                 sequence_lengths = -1
                 logger.warning(
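
As a hypothetical sanity check (not part of the commit), the modulo formulation agrees with the removed torch.ne(...).sum(dim=-1) - 1 indexing whenever the batch is right-padded:

    import torch

    pad_token_id = 3  # assumed pad id
    input_ids = torch.tensor([[5, 6, 3, 3], [5, 6, 7, 3]])

    old = torch.ne(input_ids, pad_token_id).sum(dim=-1) - 1
    new = (torch.eq(input_ids, pad_token_id).int().argmax(-1) - 1) % input_ids.shape[-1]
    print(torch.equal(old, new))  # True for right-padded batches like this one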