mirror of https://github.com/InternLM/InternLM
feat(model/overlap_handler.py): optimize reduce scatter mem pool
parent b20f47a1fe
commit e7f9f1d208
@@ -125,37 +125,38 @@ class FSTPOverlapHandler:
 
         # if key not in dict
         if key not in self.reduce_scatter_memory_pool:
-            self.reduce_scatter_memory_pool[key] = {"data": [], "used": []}
+            self.reduce_scatter_memory_pool[key] = []
 
         # if the data is empty
-        if len(self.reduce_scatter_memory_pool[key]["data"]) == 0:
-            self.reduce_scatter_memory_pool[key]["data"].append(
+        if len(self.reduce_scatter_memory_pool[key]) == 0:
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = 0
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "idle", False)
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
         else:  # if not empty
-            for index, used in enumerate(self.reduce_scatter_memory_pool[key]["used"]):
-                if used is False:
-                    self.reduce_scatter_memory_pool[key]["used"][index] = True
+            for index, mem_item in enumerate(self.reduce_scatter_memory_pool[key]):
+                if mem_item.idle is True:
+                    self.reduce_scatter_memory_pool[key][index].idle = False
                     return_idx = index
-                    return return_idx
+                    return self.reduce_scatter_memory_pool[key][return_idx]
             # if the memory pool is all used
-            length = len(self.reduce_scatter_memory_pool[key]["data"])
-            self.reduce_scatter_memory_pool[key]["data"].append(
+            cur_len = len(self.reduce_scatter_memory_pool[key])
+            self.reduce_scatter_memory_pool[key].append(
                 torch.zeros(
                     key, dtype=gpc.config.model.get("dtype", torch.half), device=get_current_device()
                 ).contiguous()
             )
-            self.reduce_scatter_memory_pool[key]["used"].append(True)
-            return_idx = length
-            return return_idx
+            setattr(self.reduce_scatter_memory_pool[key][cur_len], "idle", False)
+            return_idx = cur_len
+            setattr(self.reduce_scatter_memory_pool[key][return_idx], "index", return_idx)
+            return self.reduce_scatter_memory_pool[key][return_idx]
 
-    def release_reduce_scatter_memory(self, size, index):
-        self.reduce_scatter_memory_pool[size]["used"][index] = False
+    def release_reduce_scatter_memory(self, key, index):
+        self.reduce_scatter_memory_pool[key][index].idle = True
 
     def _all_gather_block_weight_memory_pool(self, block_index: int):
         fstp_modules = self.index_to_fstp_modules[block_index]
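The net effect of this hunk: each shape key now maps to a plain list of tensors instead of parallel {"data": [...], "used": [...]} lists, and the occupancy flag (`idle`) plus slot number (`index`) ride on the tensors themselves, so callers get the buffer back rather than an integer index. Below is a minimal standalone sketch of that allocator, assuming the `gpc` config and device wiring are replaced by plain constructor arguments; the class and method names are illustrative, not from the repo:

    import torch

    class ReduceScatterMemoryPool:
        """Sketch of the pool after this commit: one list of tensors per
        shape key, with bookkeeping attached to the tensors themselves."""

        def __init__(self, dtype=torch.half, device="cpu"):
            self.pool = {}  # shape tuple -> list of tensors
            self.dtype = dtype
            self.device = device

        def get(self, key):
            blocks = self.pool.setdefault(key, [])
            # Reuse the first idle block, if any.
            for block in blocks:
                if block.idle:
                    block.idle = False
                    return block
            # Pool empty or fully busy: grow it by one block.
            block = torch.zeros(key, dtype=self.dtype, device=self.device).contiguous()
            block.idle = False
            block.index = len(blocks)
            blocks.append(block)
            return block

        def release(self, key, index):
            # Mark the slot reusable; the memory itself stays allocated.
            self.pool[key][index].idle = True

A quick round trip:

    pool = ReduceScatterMemoryPool()
    buf = pool.get((4, 8))           # grows the pool with a zeroed block
    pool.release((4, 8), buf.index)  # hands the slot back
    assert pool.get((4, 8)) is buf   # the same storage is reused

Compared with the old scheme, this removes the risk of the "data" and "used" lists drifting out of sync, and the release call site can identify the slot from the tensor it already holds.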
@@ -164,9 +164,7 @@ def reduce_scatter_raw_memory_pool(input_: Tensor, process_group: ProcessGroup,
     world_size = torch.distributed.get_world_size(process_group)
     assert input_.shape[0] % world_size == 0
     size = (input_.shape[0] // world_size, *input_.shape[1:])
-    index = gpc.fstp_handler.get_reduce_scatter_memory(size)
-    output = gpc.fstp_handler.reduce_scatter_memory_pool[size]["data"][index]
-    setattr(output, "index", index)
+    output = gpc.fstp_handler.get_reduce_scatter_memory(size)
     handle = torch.distributed.reduce_scatter_tensor(
         output, input_.contiguous(), group=process_group, async_op=async_op
     )
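The call site now receives the output buffer directly from `get_reduce_scatter_memory`, with `index` already stamped on it by the allocator, replacing the integer index, the dict lookup, and the manual `setattr`. For the shard-shape arithmetic in the retained lines, a small runnable check (the numbers are illustrative):

    import torch

    # A reduce-scatter splits dim 0 of the input evenly across ranks,
    # so each rank's output buffer holds 1/world_size of that dim.
    input_ = torch.randn(8, 4096)
    world_size = 4
    assert input_.shape[0] % world_size == 0
    size = (input_.shape[0] // world_size, *input_.shape[1:])
    print(size)  # (2, 4096) -- the shape tuple used as the pool key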
@@ -350,7 +350,7 @@ class HybridZeroOptimizer(BaseOptimizer):
             _param.grad.add_(_grad)
 
             # release cuda memory.
-            gpc.fstp_handler.release_reduce_scatter_memory(size=tuple(_grad.size()), index=_grad.index)
+            gpc.fstp_handler.release_reduce_scatter_memory(key=tuple(_grad.size()), index=_grad.index)
             self._fstp_handler.reduce_scatter_handlers[_key] = None
 
             bucket.reset_by_rank(reduce_rank)
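The release side follows the rename: the keyword argument is now `key`, and the slot comes from the `index` attribute the allocator stamped on the gradient tensor. A hedged sketch of the round trip using the standalone pool above (the real code routes through `gpc.fstp_handler` once the async reduce-scatter has completed):

    pool = ReduceScatterMemoryPool()
    _grad = pool.get((2, 4096))  # buffer that reduce_scatter_tensor filled
    # ... _param.grad.add_(_grad) runs here ...
    pool.release(key=tuple(_grad.size()), index=_grad.index)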