From 126cf180bcde5603f3cc4935b7f00c09223308c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E3=82=A2=E3=83=9E=E3=83=87=E3=82=A6=E3=82=B9?=
Date: Tue, 28 Nov 2023 15:38:26 +0800
Subject: [PATCH] [hotfix] fixed memory usage of shardformer module replacement
 (#5122)

---
 colossalai/shardformer/layer/_operation.py | 10 +++++-----
 colossalai/tensor/d_tensor/comm_spec.py    |  2 +-
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/shardformer/layer/_operation.py b/colossalai/shardformer/layer/_operation.py
index 0d8c3d453..8fd92a2ed 100644
--- a/colossalai/shardformer/layer/_operation.py
+++ b/colossalai/shardformer/layer/_operation.py
@@ -473,16 +473,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None
-
+

 class HookParameter(torch.autograd.Function):
     """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
+
     @staticmethod
     def forward(ctx, input, weight, bias):
         ctx.save_for_backward(weight, bias)
         output = input
         return output
-
+
     @staticmethod
     def backward(ctx, grad_output):
         weight, bias = ctx.saved_tensors
@@ -491,13 +492,12 @@ class HookParameter(torch.autograd.Function):
         if bias is not None:
             bias = bias.view(bias.shape)
         return grad_output, None, None
-
+

 def hook_paramter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)

-

 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):

     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = dist.get_rank(process_group)
-    output = tensor_list[rank].contiguous()
+    output = tensor_list[rank].clone().contiguous()

     return output

diff --git a/colossalai/tensor/d_tensor/comm_spec.py b/colossalai/tensor/d_tensor/comm_spec.py
index 8f5b52aab..fc017c663 100644
--- a/colossalai/tensor/d_tensor/comm_spec.py
+++ b/colossalai/tensor/d_tensor/comm_spec.py
@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
     start = length * dist.get_rank(process_group)
-    output = torch.narrow(tensor, dim, start, length).contiguous()
+    output = torch.narrow(tensor, dim, start, length).clone().contiguous()
     return output
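
Why the two one-line changes above reduce memory (an illustrative sketch, not part of the patch): a shard produced by torch.split() or torch.narrow() is a view that shares the storage of the full tensor, and when that view happens to be contiguous already, .contiguous() simply returns it, so keeping the shard around keeps the whole unsharded tensor alive. Inserting .clone() copies the shard into its own, shard-sized storage, allowing the full tensor to be freed. The snippet below demonstrates the effect; the tensor shape is made up, and it assumes PyTorch 2.x for Tensor.untyped_storage() (older releases expose .storage() instead).

import torch

full = torch.randn(8, 4096)  # stands in for an unsharded parameter

# Slicing along a dimension where the result stays contiguous: .contiguous() is a
# no-op, so this "shard" is still a view over the full tensor's storage.
view_shard = torch.narrow(full, 0, 0, 1).contiguous()

# .clone() materializes the shard into its own storage, as in the patched _split().
copied_shard = torch.narrow(full, 0, 0, 1).clone().contiguous()

print(view_shard.data_ptr() == full.data_ptr())  # True: same underlying memory
print(view_shard.untyped_storage().nbytes())     # bytes of the whole 8 x 4096 buffer
print(copied_shard.untyped_storage().nbytes())   # bytes of just the 1 x 4096 shard

After module replacement only the local shard is supposed to stay resident on each rank, which is presumably why cloning the slice, rather than relying on .contiguous() alone, fixes the reported memory usage.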