
fix

Branch: pull/6071/head
Author: wangbluo, 2 months ago
Commit: 91ed32c256

3 changed files, 2 changes each:
  1. colossalai/shardformer/layer/attn.py
  2. colossalai/shardformer/modeling/gpt2.py
  3. colossalai/shardformer/modeling/llama.py

colossalai/shardformer/layer/attn.py (2 changes)

@@ -500,8 +500,8 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
+        tp_group=None,
         cu_seqlens=None,
         max_seqlen=None,
         valid_indices=None,
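
In words: the attn.py hunk removes the positional tp_group: Optional[dist.ProcessGroup] parameter that sat between sp_group and attention_mask_type, and re-adds tp_group after attention_mask_type as a keyword argument defaulting to None. A minimal before/after sketch (only the parameter names visible in the hunk are real; the leading q parameter and the bodies are assumed):

# Sketch only, not the actual RingAttention.attention signature.

# Before this commit: tp_group was a required positional parameter
# placed ahead of attention_mask_type.
def attention_before(q, k, v, sp_group, tp_group, attention_mask_type,
                     cu_seqlens=None, max_seqlen=None, valid_indices=None):
    ...

# After this commit: tp_group follows attention_mask_type and defaults
# to None, so callers are expected to pass it by keyword.
def attention_after(q, k, v, sp_group, attention_mask_type, tp_group=None,
                    cu_seqlens=None, max_seqlen=None, valid_indices=None):
    ...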

colossalai/shardformer/modeling/gpt2.py (2 changes)

@@ -866,7 +866,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             key,
             value,
             sp_group,
-            tp_group,
+            tp_group=tp_group,
             **attention_mask,
             dropout_p=dropout_p,
             scale=scale,

colossalai/shardformer/modeling/llama.py (2 changes)

@@ -571,7 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             key_states,
             value_states,
             sp_group,
-            tp_group,
+            tp_group=tp_group,
             **attention_mask,
             inner_ring_size=shard_config.inner_ring_size,
         )
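
Both modeling files make the mirror-image change at the RingAttention.attention call site: tp_group is now passed by keyword rather than positionally. Because each call also expands **attention_mask, which supplies attention_mask_type, a positional tp_group would land in the attention_mask_type slot under the new signature and then collide with the expanded kwarg. A self-contained toy (hypothetical names, not the ColossalAI API) that reproduces that failure mode and shows the fixed call:

# Toy stand-in for the new parameter order: attention_mask_type comes first,
# tp_group now has a default and is meant to be passed by keyword.
def attention(q, k, v, sp_group, attention_mask_type, tp_group=None, **kwargs):
    return attention_mask_type, tp_group

attention_mask = {"attention_mask_type": "causal"}

# Old call style (tp_group positional): it fills the attention_mask_type slot,
# and **attention_mask then supplies the same argument a second time.
try:
    attention("q", "k", "v", "sp", "tp", **attention_mask)
except TypeError as err:
    print(err)  # ... got multiple values for argument 'attention_mask_type'

# Fixed call, mirroring the gpt2.py and llama.py hunks: tp_group by keyword.
print(attention("q", "k", "v", "sp", tp_group="tp", **attention_mask))
# -> ('causal', 'tp')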
