From 4ccb9ded7d774d37c87cac9c133524281c143b94 Mon Sep 17 00:00:00 2001 From: flybird11111 <1829166702@qq.com> Date: Wed, 22 Nov 2023 11:14:25 +0800 Subject: [PATCH] [gemini]fix gemini optimzer, saving Shardformer in Gemini got list assignment index out of range (#5085) --- colossalai/zero/gemini/gemini_optimizer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/colossalai/zero/gemini/gemini_optimizer.py b/colossalai/zero/gemini/gemini_optimizer.py index 50d4f51d3..8f828bd6c 100644 --- a/colossalai/zero/gemini/gemini_optimizer.py +++ b/colossalai/zero/gemini/gemini_optimizer.py @@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper): param = self.id_to_real_params[param_id] fake_param = self.id_to_fake_params.get(param_id, None) chunk = self.chunk_manager.get_chunk(param) - dp_group = chunk.torch_pg - rank = dist.get_rank(dp_group) + zero_group = chunk.torch_pg + rank = dist.get_rank(zero_group) master_rank = 0 collected_states = {} @@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper): local_state_names = None if fake_param is not None: local_state_names = list(self.optim.state[fake_param].keys()) - gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))] + gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))] dist.barrier() - dist.all_gather_object(gathered_state_names, local_state_names, dp_group) + dist.all_gather_object(gathered_state_names, local_state_names, zero_group) state_names = None for names in gathered_state_names: if names is not None: @@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper): _, shard_offset, shard_size = self.get_offsets(param_id) # Collectors gather state shards through all_gathering. - gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))] + gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))] dist.barrier() - dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size]) + dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group) if is_collector: for state_shard in gathered_state_shards: