Browse Source

[zero] zero optim state_dict takes only_rank_0 (#1384)

* zero optim state_dict takes only_rank_0

* fix unit test
pull/1387/head
ver217 2 years ago committed by GitHub
parent
commit
8dced41ad0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 24
      colossalai/zero/zero_optimizer.py
  2. 7
      tests/test_zero/test_zero_optim_state_dict.py

24
colossalai/zero/zero_optimizer.py

@ -193,8 +193,9 @@ class ZeroOptimizer(ColossalaiOptimizer):
if isinstance(val, torch.Tensor):
self.chunk_manager.add_extern_static_tensor(val)
def state_dict(self):
r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
def state_dict(self, only_rank_0: bool = True):
r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
This saves memory usage.
It contains two entries:
@ -204,7 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
parameter group is a dict
"""
is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
return
optim_state_dict = super().state_dict()
scaler_state_dict = self.grad_scaler.state_dict()
@ -214,14 +215,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
if not self.chunk_manager.process_group.has_cpu_groups:
self.chunk_manager.process_group.set_cpu_groups()
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
dist.gather_object(local_state,
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
dst=dst_rank,
group=self.chunk_manager.process_group.cpu_dp_process_group())
if not is_rank_0:
return
if only_rank_0:
dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
dist.gather_object(local_state,
output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
dst=dst_rank,
group=self.chunk_manager.process_group.cpu_dp_process_group())
if not is_rank_0:
return
else:
dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
for state in output:
optim_state_dict['state'].update(state)
return optim_state_dict

7
tests/test_zero/test_zero_optim_state_dict.py

@ -45,7 +45,8 @@ def check_state_dict(state_dict, torch_state_dict):
@parameterize('use_chunk', [False, True])
@parameterize('use_zero', [False, True])
@parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
@parameterize('only_rank_0', [False, True])
def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@ -76,8 +77,8 @@ def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
optim.load_state_dict(torch_state_dict)
check_load_state_dict(optim, torch_optim)
state_dict = optim.state_dict()
if pg.rank() == 0:
state_dict = optim.state_dict(only_rank_0)
if not only_rank_0 or pg.rank() == 0:
check_state_dict(state_dict, torch_state_dict)

Loading…
Cancel
Save