[gemini] fix gemini optimizer: saving Shardformer in Gemini raised "list assignment index out of range" (#5085)

pull/5076/head
flybird11111 2023-11-22 11:14:25 +08:00 committed by GitHub
parent 0d482302a1
commit 4ccb9ded7d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 6 additions and 6 deletions

View File

@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
param = self.id_to_real_params[param_id]
fake_param = self.id_to_fake_params.get(param_id, None)
chunk = self.chunk_manager.get_chunk(param)
dp_group = chunk.torch_pg
rank = dist.get_rank(dp_group)
zero_group = chunk.torch_pg
rank = dist.get_rank(zero_group)
master_rank = 0
collected_states = {}
@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
local_state_names = None
if fake_param is not None:
local_state_names = list(self.optim.state[fake_param].keys())
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
state_names = None
for names in gathered_state_names:
if names is not None:
@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
_, shard_offset, shard_size = self.get_offsets(param_id)
# Collectors gather state shards through all_gathering.
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
dist.barrier()
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
if is_collector:
for state_shard in gathered_state_shards: