diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 1f897c1be..419932a00 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -500,7 +500,7 @@ class RingAttention(torch.autograd.Function):
         k,
         v,
         sp_group,
-        tp_group,
+        tp_group: Optional[dist.ProcessGroup],
         attention_mask_type,
         cu_seqlens=None,
         max_seqlen=None,
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 798fca88f..8f476ab86 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -858,12 +858,14 @@ def get_gpt2_flash_attention_forward(shard_config: Optional[ShardConfig] = None)
         sp_mode = shard_config.sequence_parallelism_mode
         sp_group = shard_config.sequence_parallel_process_group
+        tp_group = shard_config.tensor_parallel_process_group

         if sp_mode == "ring_attn":
             attn_output = RingAttention.attention(
                 query,
                 key,
                 value,
                 sp_group,
+                tp_group,
                 **attention_mask,
                 dropout_p=dropout_p,
                 scale=scale,
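# Not part of the diff: a minimal sketch of how a caller would wire the new
# tp_group argument after this change, assuming a ShardConfig-like object that
# exposes both process groups. `run_ring_attn` and its parameters are
# hypothetical names used only for illustration.
import torch.distributed as dist  # noqa: F401  (process groups come from ShardConfig)
from colossalai.shardformer.layer.attn import RingAttention

def run_ring_attn(query, key, value, attention_mask, shard_config, dropout_p=0.0, scale=None):
    # Both groups are read from the shard config; tp_group may be None when
    # tensor parallelism is disabled (hence Optional[dist.ProcessGroup] above).
    sp_group = shard_config.sequence_parallel_process_group
    tp_group = shard_config.tensor_parallel_process_group
    return RingAttention.attention(
        query,
        key,
        value,
        sp_group,
        tp_group,             # newly threaded through by this diff
        **attention_mask,     # e.g. attention_mask_type, cu_seqlens, max_seqlen
        dropout_p=dropout_p,
        scale=scale,
    )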