diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py index a191694c1..7c25bee1a 100644 --- a/colossalai/shardformer/layer/attn.py +++ b/colossalai/shardformer/layer/attn.py @@ -482,7 +482,8 @@ class RingAttention(torch.autograd.Function): for i in range(num_rings): for j in range(num_inner_group): # find inner ring group in one sp groups - ranks = list(range(j + i * num_ring_size, j + (i + 1) * num_ring_size, tp_size)) + start = j + i * num_ring_size + ranks = list(range(start, start + tp_size * inner_ring_size, tp_size)) group = dist.new_group(ranks) if rank in ranks: inner_ring_group = group