feat(model/overlap_handler.py): optimize reduce scatter mem pool

pull/456/head
huangting4201 2023-10-23 13:31:23 +08:00
parent b20f47a1fe
commit e7f9f1d208
3 changed files with 20 additions and 21 deletions

View File

@@ -125,37 +125,38 @@ class FSTPOverlapHandler:
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
-            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
+            self.reduce_scatter_memory_pool[key] = []

         # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
-            self.reduce_scatter_memory_pool[key]["data"].append(
+        if len(self.reduce_scatter_memory_pool[key]) == 0:
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
             return_idx = 0
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
         else: # if not empty
-            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
-                if used is False:
-                    self.reduce_scatter_memory_pool[key]["used"][index] = True
+            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+                if mem_item.idle is True:
+                    self.reduce_scatter_memory_pool[key][index].idle = False
                     return_idx = index
-                    return return_idx
+                    return self.reduce_scatter_memory_pool[key][return_idx]
             # if the memory pool is all used
-            length = len(self.reduce_scatter_memory_pool[key]["data"])
-            self.reduce_scatter_memory_pool[key]["data"].append(
+            cur_len = len(self.reduce_scatter_memory_pool[key])
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = length
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+            return_idx = cur_len
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]

-    def release_reduce_scatter_memory(self, size, index):
-        self.reduce_scatter_memory_pool[size]["used"][index] = False
+    def release_reduce_scatter_memory(self, key, index):
+        self.reduce_scatter_memory_pool[key][index].idle = True

     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]

View File

@@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
-    setattr(output, "index", index)
+    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )

View File

@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
                     _param.grad.add_(_grad)

                     # release cuda memory.
-                    gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+                    gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
                     self._fstp_handler.reduce_scatter_handlers[_key] = None
                     bucket.reset_by_rank(reduce_rank)