From 83cf2f84fb0c08a351a5affc71527556ff8912bc Mon Sep 17 00:00:00 2001
From: wangbluo <2538539015@qq.com>
Date: Tue, 15 Oct 2024 14:50:27 +0800
Subject: [PATCH] fix

---
 colossalai/shardformer/layer/attn.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index 1a175f426..bbd99d162 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -470,14 +470,14 @@ class RingAttention(torch.autograd.Function):
         # Create inner ring groups
         for i in range(inner_ring_size):
             ranks = list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
-            group = pg_mesh.get_group_along_axis(2, ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inner_ring_group = group
         # Create inter ring groups
         for i in range(num_rings):
             ranks = list(range(i, sp_size, num_rings))
-            group = pg_mesh.get_group_along_axis(2, ranks)
+            group = pg_mesh.get_group_along_axis(sp_axis, ranks)
             if sp_rank in ranks:
                 inter_ring_group = group
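
For context: the hunk replaces a hardcoded process-group-mesh axis (2) with the sp_axis variable when building the inner and inter ring groups for RingAttention, presumably so the grouping still works when the sequence-parallel dimension does not sit at axis 2 of the mesh. Below is a minimal sketch of just the rank-grouping arithmetic from those lines; sp_size = 4 and inner_ring_size = 2 are illustrative assumptions (chosen so inner_ring_size equals num_rings, matching the loop bounds in the hunk), and no process groups are created.

# Sketch of the rank lists the patched code passes to
# pg_mesh.get_group_along_axis(sp_axis, ranks); no process groups are
# created here. sp_size and inner_ring_size are assumed values.
sp_size = 4          # sequence-parallel world size (assumption)
inner_ring_size = 2  # ranks per inner ring (assumption)
num_rings = sp_size // inner_ring_size

# Inner rings: contiguous blocks of ranks, as in the first loop of the hunk.
inner_rings = [
    list(range(i * inner_ring_size, (i + 1) * inner_ring_size))
    for i in range(inner_ring_size)
]

# Inter rings: strided rank lists, as in the second loop of the hunk.
inter_rings = [
    list(range(i, sp_size, num_rings))
    for i in range(num_rings)
]

print(inner_rings)  # [[0, 1], [2, 3]]
print(inter_rings)  # [[0, 2], [1, 3]]

In the patched code, each rank keeps the group whose rank list contains its own sp_rank; the change only affects which mesh axis get_group_along_axis is queried along, not which ranks end up in each ring.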