fix(utils): fix split cuda memory leak

pull/566/head
877825076@qq.com 2023-12-29 16:15:35 +08:00
parent d418eba094
commit fc60986ed0
1 changed file with 5 additions and 1 deletion

View File

@@ -36,7 +36,11 @@ def _split(input_, parallel_mode, dim=-1):
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = gpc.get_local_rank(parallel_mode)
-    output = tensor_list[rank].contiguous()
+    # After splitting, the small chunk will share the same storage space with the large tensor.
+    # We will need to clone the small chunk, which will create a new storage, otherwise it
+    # will hinder the large tensor's CUDA memory GC.
+    output = tensor_list[rank].contiguous().clone()
     return output