From e7f9f1d20853e856f175d178bf94350871744b67 Mon Sep 17 00:00:00 2001
From: huangting4201 <1538303371@qq.com>
Date: Mon, 23 Oct 2023 13:31:23 +0800
Subject: [PATCH] feat(model/overlap_handler.py): optimize reduce scatter mem pool

---
 internlm/model/overlap_handler.py             | 35 ++++++++++---------
 internlm/model/utils.py                       |  4 +--
 .../solver/optimizer/hybrid_zero_optim.py     |  2 +-
 3 files changed, 20 insertions(+), 21 deletions(-)

diff --git a/internlm/model/overlap_handler.py b/internlm/model/overlap_handler.py
index b687723..b3c8b8b 100644
--- a/internlm/model/overlap_handler.py
+++ b/internlm/model/overlap_handler.py
@@ -125,37 +125,38 @@ class FSTPOverlapHandler:
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
-            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
+            self.reduce_scatter_memory_pool[key] = []
 
         # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
-            self.reduce_scatter_memory_pool[key]["data"].append(
+        if len(self.reduce_scatter_memory_pool[key]) == 0:
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = 0
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
         else:  # if not empty
-            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
-                if used is False:
-                    self.reduce_scatter_memory_pool[key]["used"][index] = True
+            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+                if mem_item.idle is True:
+                    self.reduce_scatter_memory_pool[key][index].idle = False
                     return_idx = index
-                    return return_idx
+                    return self.reduce_scatter_memory_pool[key][return_idx]
             # if the memory pool is all used
-            length = len(self.reduce_scatter_memory_pool[key]["data"])
-            self.reduce_scatter_memory_pool[key]["data"].append(
+            cur_len = len(self.reduce_scatter_memory_pool[key])
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = length
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+            return_idx = cur_len
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
 
-    def release_reduce_scatter_memory(self, size, index):
-        self.reduce_scatter_memory_pool[size]["used"][index] = False
+    def release_reduce_scatter_memory(self, key, index):
+        self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index cdbed95..8070cbd 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
-    setattr(output, "index", index)
+    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )
diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py
index 08d9722..0d0c8a3 100644
--- a/internlm/solver/optimizer/hybrid_zero_optim.py
+++ b/internlm/solver/optimizer/hybrid_zero_optim.py
@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _param.grad.add_(_grad)
 
                     # release cuda memory.
-                    gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                    gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
                     self._fstp_handler.reduce_scatter_handlers[_key] = None
 
            bucket.reset_by_rank(reduce_rank)
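
Note (not part of the patch): after this change, reduce_scatter_memory_pool maps each tensor size key to a plain list of pooled tensors. Each tensor carries "idle" and "index" attributes set via setattr, get_reduce_scatter_memory returns the tensor itself instead of an index, and release_reduce_scatter_memory(key, index) simply flips "idle" back to True so the buffer can be reused. A minimal standalone sketch of that behaviour follows; the SimpleReduceScatterPool class and the CPU/float32 defaults are illustrative assumptions, not InternLM code.

    # Illustrative sketch only -- mirrors the pool logic introduced by this patch,
    # but uses CPU float32 tensors and a standalone class (assumptions, not InternLM code).
    import torch


    class SimpleReduceScatterPool:
        def __init__(self):
            # maps a tensor shape (tuple) to a list of pooled tensors
            self.pool = {}

        def get(self, key):
            items = self.pool.setdefault(key, [])
            # reuse the first idle tensor if one exists
            for tensor in items:
                if tensor.idle:
                    tensor.idle = False
                    return tensor
            # otherwise grow the pool by one tensor and hand it out
            tensor = torch.zeros(key, dtype=torch.float32).contiguous()
            setattr(tensor, "idle", False)
            setattr(tensor, "index", len(items))
            items.append(tensor)
            return tensor

        def release(self, key, index):
            # mark the tensor as reusable; the memory itself stays allocated
            self.pool[key][index].idle = True


    pool = SimpleReduceScatterPool()
    out = pool.get((4, 8))           # allocates a new (4, 8) buffer
    pool.release((4, 8), out.index)  # returns it to the pool
    assert pool.get((4, 8)) is out   # the same buffer is handed out again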