feat(model/overlap_handler.py): optimize reduce scatter mem pool

pull/456/head
huangting4201 2023-10-23 13:31:23 +08:00
parent b20f47a1fe
commit e7f9f1d208
3 changed files with 20 additions and 21 deletions


@@ -125,37 +125,38 @@ class FSTPOverlapHandler:
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
-            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
+            self.reduce_scatter_memory_pool[key] = []
         # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
-            self.reduce_scatter_memory_pool[key]["data"].append(
+        if len(self.reduce_scatter_memory_pool[key]) == 0:
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = 0
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
         else:  # if not empty
-            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
-                if used is False:
-                    self.reduce_scatter_memory_pool[key]["used"][index] = True
+            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+                if mem_item.idle is True:
+                    self.reduce_scatter_memory_pool[key][index].idle = False
                     return_idx = index
-                    return return_idx
+                    return self.reduce_scatter_memory_pool[key][return_idx]
             # if the memory pool is all used
-            length = len(self.reduce_scatter_memory_pool[key]["data"])
-            self.reduce_scatter_memory_pool[key]["data"].append(
+            cur_len = len(self.reduce_scatter_memory_pool[key])
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = length
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+            return_idx = cur_len
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
 
-    def release_reduce_scatter_memory(self, size, index):
-        self.reduce_scatter_memory_pool[size]["used"][index] = False
+    def release_reduce_scatter_memory(self, key, index):
+        self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
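
Note: the hunk above replaces the per-shape {"data": [...], "used": [...]} bookkeeping with a flat list of pooled tensors, each tagged via setattr with an idle flag and its position in the list, so get_reduce_scatter_memory can hand the tensor itself back to callers. The MemoryPool class below is a hypothetical, stand-alone sketch of that bookkeeping (no gpc config or CUDA device handling); it is not part of the patch.

import torch


class MemoryPool:
    """Illustrative stand-in for the handler's reduce-scatter pool."""

    def __init__(self, dtype=torch.half):
        self.dtype = dtype
        self.pool = {}  # shape tuple -> list of tensors tagged with `idle` / `index`

    def get(self, key):
        items = self.pool.setdefault(key, [])
        # reuse the first idle buffer of this shape, if any
        for item in items:
            if item.idle:
                item.idle = False
                return item
        # otherwise grow the pool by one buffer and hand it out
        item = torch.zeros(key, dtype=self.dtype).contiguous()
        item.idle = False
        item.index = len(items)
        items.append(item)
        return item

    def release(self, key, index):
        # keep the storage allocated, just mark it reusable
        self.pool[key][index].idle = True


# two requests of the same shape reuse one buffer once it is released
pool = MemoryPool()
buf = pool.get((4, 8))
pool.release((4, 8), buf.index)
assert pool.get((4, 8)) is buf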


@@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
-    setattr(output, "index", index)
+    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )
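
Note: on the caller side, get_reduce_scatter_memory now yields the output buffer directly (already carrying its .index), so the reduce-scatter site no longer looks up the pool or attaches the index itself. The snippet below is a hypothetical, simplified rendering of that calling convention; the pool dict only mimics the handler, and the collective call is guarded so the code also runs without an initialized process group.

import torch
import torch.distributed as dist

pool = {}  # shape tuple -> list of pooled output tensors (mimics the handler's pool)


def get_reduce_scatter_memory(key):
    items = pool.setdefault(key, [])
    for item in items:  # reuse an idle buffer of this shape when possible
        if item.idle:
            item.idle = False
            return item
    item = torch.zeros(key, dtype=torch.half).contiguous()
    item.idle, item.index = False, len(items)
    items.append(item)
    return item


def reduce_scatter_raw(input_, process_group=None, async_op=True):
    world_size = dist.get_world_size(process_group) if dist.is_initialized() else 1
    assert input_.shape[0] % world_size == 0
    size = (input_.shape[0] // world_size, *input_.shape[1:])
    output = get_reduce_scatter_memory(size)  # the buffer itself, not an index
    handle = None
    if dist.is_initialized():
        handle = dist.reduce_scatter_tensor(output, input_.contiguous(), group=process_group, async_op=async_op)
    return output, handle


out, _ = reduce_scatter_raw(torch.ones(8, 4, dtype=torch.half))
print(out.shape, out.index)  # torch.Size([8, 4]), index 0 when no process group is initialized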


@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             _param.grad.add_(_grad)
             # release cuda memory.
-            gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+            gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
             self._fstp_handler.reduce_scatter_handlers[_key] = None
         bucket.reset_by_rank(reduce_rank)
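
Note: on the optimizer side, the tensor that arrives as _grad is the pooled buffer itself, so after it is accumulated into the parameter gradient it can be returned by its shape key and the .index tag it received at allocation time. The snippet below is a small hypothetical illustration of that release path, not code from the patch.

import torch

# hypothetical pool holding one (4,) buffer that is currently in use (idle=False, index=0)
pool = {(4,): []}
_grad = torch.ones(4, dtype=torch.half)
_grad.idle, _grad.index = False, 0
pool[(4,)].append(_grad)


def release_reduce_scatter_memory(key, index):
    pool[key][index].idle = True  # storage stays allocated for the next same-shape request


# accumulate the reduce-scattered gradient, then hand the buffer back to the pool
param_grad = torch.zeros(4, dtype=torch.half)
param_grad.add_(_grad)
release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
assert pool[(4,)][0].idle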