mirror of https://github.com/hpcaitech/ColossalAI
[gemini] fix Gemini optimizer: saving Shardformer in Gemini raised "list assignment index out of range" (#5085)
parent
0d482302a1
commit
4ccb9ded7d
|
@ -423,8 +423,8 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
param = self.id_to_real_params[param_id]
|
||||
fake_param = self.id_to_fake_params.get(param_id, None)
|
||||
chunk = self.chunk_manager.get_chunk(param)
|
||||
dp_group = chunk.torch_pg
|
||||
rank = dist.get_rank(dp_group)
|
||||
zero_group = chunk.torch_pg
|
||||
rank = dist.get_rank(zero_group)
|
||||
master_rank = 0
|
||||
collected_states = {}
|
||||
|
||||
|
@ -432,9 +432,9 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
local_state_names = None
|
||||
if fake_param is not None:
|
||||
local_state_names = list(self.optim.state[fake_param].keys())
|
||||
gathered_state_names = [None for _ in range(dist.get_world_size(dp_group))]
|
||||
gathered_state_names = [None for _ in range(dist.get_world_size(zero_group))]
|
||||
dist.barrier()
|
||||
dist.all_gather_object(gathered_state_names, local_state_names, dp_group)
|
||||
dist.all_gather_object(gathered_state_names, local_state_names, zero_group)
|
||||
state_names = None
|
||||
for names in gathered_state_names:
|
||||
if names is not None:
|
||||
|
@ -512,10 +512,10 @@ class GeminiOptimizer(OptimizerWrapper):
|
|||
_, shard_offset, shard_size = self.get_offsets(param_id)
|
||||
|
||||
# Collectors gather state shards through all_gathering.
|
||||
gathered_state_shards = [None for _ in range(dist.get_world_size(dp_group))]
|
||||
gathered_state_shards = [None for _ in range(dist.get_world_size(zero_group))]
|
||||
|
||||
dist.barrier()
|
||||
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size])
|
||||
dist.all_gather_object(gathered_state_shards, [compacted_states, shard_offset, shard_size], group=zero_group)
|
||||
|
||||
if is_collector:
|
||||
for state_shard in gathered_state_shards:
|
||||
|
|
Loading…
Reference in New Issue