[hotfix] fixed memory usage of shardformer module replacement (#5122)

アマデウス 2023-11-28 15:38:26 +08:00 committed by GitHub
parent 7b789f4dd2
commit 126cf180bc
2 changed files with 6 additions and 6 deletions

@@ -477,6 +477,7 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
 
 class HookParameter(torch.autograd.Function):
     """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
+
     @staticmethod
     def forward(ctx, input, weight, bias):
         ctx.save_for_backward(weight, bias)
@@ -497,7 +498,6 @@ def hook_paramter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
-
 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):
 
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = dist.get_rank(process_group)
-    output = tensor_list[rank].contiguous()
+    output = tensor_list[rank].clone().contiguous()
 
     return output
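
This one-line change is the entire memory fix for the shardformer path. torch.split returns views that share storage with the unsplit input, and .contiguous() on a slice that is already contiguous (e.g. when splitting along dim 0) is a no-op that returns the same view, so each rank's shard keeps the whole parent tensor's storage alive. Adding .clone() copies only the shard into fresh storage so the parent can be freed. A minimal sketch of the effect, with illustrative shapes and PyTorch >= 2.0's untyped_storage() introspection (neither is part of the commit):

import torch

full = torch.randn(4, 1024)              # unsplit tensor: 16384 bytes of storage
shard = torch.split(full, 1, dim=0)[0]   # rank 0's slice: a view, already contiguous

# .contiguous() on an already-contiguous view is a no-op and returns the view,
# so the "shard" still pins the parent's entire storage.
leaky = shard.contiguous()
print(leaky.untyped_storage().nbytes())  # 16384: whole parent storage retained

# .clone() materializes just the shard in its own storage; `full` can now be freed.
owned = shard.clone().contiguous()
print(owned.untyped_storage().nbytes())  # 4096: the shard alone

The second of the two changed files applies the same one-line fix to the comm-spec _split: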

@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
     start = length * dist.get_rank(process_group)
-    output = torch.narrow(tensor, dim, start, length).contiguous()
+    output = torch.narrow(tensor, dim, start, length).clone().contiguous()
     return output
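
torch.narrow leaks in the same way: it returns a view, and .contiguous() keeps that view whenever the narrowed region is already contiguous, such as when sharding along dim 0. A quick check of the difference, with illustrative shapes (world_size = 4 and rank = 1 are assumed, not from the commit):

import torch

tensor = torch.randn(8, 256)                     # parent tensor: 8192 bytes
length = tensor.shape[0] // 4                    # shard length for world_size = 4
start = length * 1                               # offset for rank = 1

view = torch.narrow(tensor, 0, start, length).contiguous()
print(view.untyped_storage().nbytes())           # 8192: still the parent's storage

copy = torch.narrow(tensor, 0, start, length).clone().contiguous()
print(copy.untyped_storage().nbytes())           # 2048: the shard alone

The cost is one extra copy of each shard at split time; the payoff is that the full-size parent tensor no longer stays resident on every rank after sharding.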