diff --git a/internlm/model/utils.py b/internlm/model/utils.py
index 46fba59..672db9c 100644
--- a/internlm/model/utils.py
+++ b/internlm/model/utils.py
@@ -36,7 +36,11 @@ def _split(input_, parallel_mode, dim=-1):
 
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = gpc.get_local_rank(parallel_mode)
-    output = tensor_list[rank].contiguous()
+
+    # torch.split returns views that share storage with the full input tensor, so
+    # keeping a view alive keeps the whole large tensor's CUDA memory allocated.
+    # Copy this rank's chunk exactly once into its own contiguous storage instead.
+    output = tensor_list[rank].clone(memory_format=torch.contiguous_format)
 
     return output
 