From 91ed32c2569b18f03a870f82f7f3ddb4b2da4e4f Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Wed, 25 Sep 2024 19:00:38 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py     | 2 +-
 colossalai/shardformer/modeling/gpt2.py  | 2 +-
 colossalai/shardformer/modeling/llama.py | 2 +-
 3 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 15ad09baa..2cc6f3163 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -500,8 +500,8 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
+        tp_group=None,
         cu_seqlens=None,
         max_seqlen=None,
         valid_indices=None,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 01d47bcd0..6be75a3c6 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -866,7 +866,7 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
             key,
             value,
             sp_group,
-            tp_group,
+            tp_group=tp_group,
             **attention_mask,
             dropout_p=dropout_p,
             scale=scale,
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index b5f505dce..08f4bc90d 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -571,7 +571,7 @@ def get_llama_flash_attention_forward(shard_config: ShardConfig, sp_mode=None, s
             key_states,
             value_states,
             sp_group,
-            tp_group,
+            tp_group=tp_group,
             **attention_mask,
             inner_ring_size=shard_config.inner_ring_size,
         )
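
Note: a minimal, self-contained sketch of why the call sites switch from passing tp_group
positionally to tp_group=tp_group. The stub below is hypothetical, not the real
RingAttention.attention signature, and the assumption that **attention_mask supplies
attention_mask_type comes from the surrounding call sites rather than this diff. The point
it illustrates: once the patch moves tp_group behind attention_mask_type and gives it a
default, a positional process-group argument would bind to attention_mask_type and collide
with the keyword unpacked from **attention_mask.

# Hypothetical stand-in, not the real ColossalAI API: parameter order mirrors the
# post-patch signature shown in the attn.py hunk above.
def ring_attention_stub(q, k, v, sp_group, attention_mask_type, tp_group=None, cu_seqlens=None):
    return attention_mask_type, tp_group

# Assumed shape of the kwargs dict unpacked at the call sites.
attention_mask = {"attention_mask_type": "causal"}

# Pre-patch call style: the positional tp_group now binds to attention_mask_type,
# so the keyword coming from **attention_mask raises a duplicate-argument TypeError.
try:
    ring_attention_stub("q", "k", "v", "sp_group", "tp_group_handle", **attention_mask)
except TypeError as exc:
    print("positional tp_group no longer works:", exc)

# Post-patch call style (what gpt2.py and llama.py are changed to): keyword binding stays correct.
mask_type, tp = ring_attention_stub("q", "k", "v", "sp_group", tp_group="tp_group_handle", **attention_mask)
assert (mask_type, tp) == ("causal", "tp_group_handle")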