From 8dced41ad00560538d4ff1ac46c6bfae67c4b6fd Mon Sep 17 00:00:00 2001
From: ver217
Date: Fri, 29 Jul 2022 13:22:50 +0800
Subject: [PATCH] [zero] zero optim state_dict takes only_rank_0 (#1384)

* zero optim state_dict takes only_rank_0

* fix unit test
---
 colossalai/zero/zero_optimizer.py             | 24 +++++++++++--------
 tests/test_zero/test_zero_optim_state_dict.py |  7 +++---
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/colossalai/zero/zero_optimizer.py b/colossalai/zero/zero_optimizer.py
index c78d517e5..e5a2f9f90 100644
--- a/colossalai/zero/zero_optimizer.py
+++ b/colossalai/zero/zero_optimizer.py
@@ -193,8 +193,9 @@ class ZeroOptimizer(ColossalaiOptimizer):
                     if isinstance(val, torch.Tensor):
                         self.chunk_manager.add_extern_static_tensor(val)
 
-    def state_dict(self):
-        r"""Returns the state of the optimizer as a :class:`dict`. For DP rank != 0, this function returns None.
+    def state_dict(self, only_rank_0: bool = True):
+        r"""Returns the state of the optimizer as a :class:`dict`. If only_rank_0 is True, for DP rank != 0, this function returns None.
+        This saves memory usage.
 
         It contains two entries:
 
@@ -204,7 +205,7 @@ class ZeroOptimizer(ColossalaiOptimizer):
             parameter group is a dict
         """
         is_rank_0 = self.chunk_manager.process_group.dp_local_rank() == 0
-        if not self.chunk_manager.enable_distributed_storage and not is_rank_0:
+        if not self.chunk_manager.enable_distributed_storage and only_rank_0 and not is_rank_0:
             return
         optim_state_dict = super().state_dict()
         scaler_state_dict = self.grad_scaler.state_dict()
@@ -214,14 +215,17 @@ class ZeroOptimizer(ColossalaiOptimizer):
         local_state = {k: convert_state_dict_to_cpu(v) for k, v in optim_state_dict['state'].items() if len(v) > 0}
         if not self.chunk_manager.process_group.has_cpu_groups:
             self.chunk_manager.process_group.set_cpu_groups()
-        dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
         output = [None for _ in range(self.chunk_manager.process_group.dp_world_size())]
-        dist.gather_object(local_state,
-                           output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
-                           dst=dst_rank,
-                           group=self.chunk_manager.process_group.cpu_dp_process_group())
-        if not is_rank_0:
-            return
+        if only_rank_0:
+            dst_rank = self.chunk_manager.process_group.dp_rank_list()[0]
+            dist.gather_object(local_state,
+                               output if self.chunk_manager.process_group.dp_local_rank() == 0 else None,
+                               dst=dst_rank,
+                               group=self.chunk_manager.process_group.cpu_dp_process_group())
+            if not is_rank_0:
+                return
+        else:
+            dist.all_gather_object(output, local_state, group=self.chunk_manager.process_group.cpu_dp_process_group())
         for state in output:
             optim_state_dict['state'].update(state)
         return optim_state_dict
diff --git a/tests/test_zero/test_zero_optim_state_dict.py b/tests/test_zero/test_zero_optim_state_dict.py
index 104ca21bc..cc67242c9 100644
--- a/tests/test_zero/test_zero_optim_state_dict.py
+++ b/tests/test_zero/test_zero_optim_state_dict.py
@@ -45,7 +45,8 @@ def check_state_dict(state_dict, torch_state_dict):
 @parameterize('use_chunk', [False, True])
 @parameterize('use_zero', [False, True])
 @parameterize('placement_policy', ['cuda', 'cpu', 'auto'])
-def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
+@parameterize('only_rank_0', [False, True])
+def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy, only_rank_0):
     get_components_func = non_distributed_component_funcs.get_callable('gpt2')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
 
@@ -76,8 +77,8 @@ def run_zero_optim_state_dict(use_chunk, use_zero, placement_policy):
     optim.load_state_dict(torch_state_dict)
     check_load_state_dict(optim, torch_optim)
 
-    state_dict = optim.state_dict()
-    if pg.rank() == 0:
+    state_dict = optim.state_dict(only_rank_0)
+    if not only_rank_0 or pg.rank() == 0:
         check_state_dict(state_dict, torch_state_dict)
 
 
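
Usage sketch (not part of the patch itself): the snippet below illustrates how the new only_rank_0 flag could be used when checkpointing a ZeroOptimizer. It assumes optim is a ZeroOptimizer set up as in the test above and that the process group is already initialized; the helper names and file paths are hypothetical.

import torch
import torch.distributed as dist


def save_zero_optim_checkpoint(optim, path, only_rank_0=True):
    # With only_rank_0=True (the default), optimizer states are collected on
    # DP rank 0 and every other DP rank gets None, which keeps peak memory low.
    state_dict = optim.state_dict(only_rank_0)
    if state_dict is not None:
        torch.save(state_dict, path)


def save_zero_optim_checkpoint_per_rank(optim, path_template):
    # With only_rank_0=False, the states are all-gathered, so every DP rank
    # holds the full optimizer state and can write its own copy.
    state_dict = optim.state_dict(only_rank_0=False)
    torch.save(state_dict, path_template.format(dist.get_rank()))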