mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fixed memory usage of shardformer module replacement (#5122)
parent 7b789f4dd2
commit 126cf180bc
@@ -473,16 +473,17 @@ class _GatherForwardSplitBackward(torch.autograd.Function):
     @staticmethod
     def backward(ctx, grad_output):
         return _split(grad_output, ctx.dim, ctx.process_group), None, None
 
 
 class HookParameter(torch.autograd.Function):
     """In order to be hooked into Gemini's '__torch_function__', adding a view operation to weight and bias. Used in FusedLayerNorm"""
 
     @staticmethod
     def forward(ctx, input, weight, bias):
         ctx.save_for_backward(weight, bias)
         output = input
         return output
 
     @staticmethod
     def backward(ctx, grad_output):
         weight, bias = ctx.saved_tensors
@@ -491,13 +492,12 @@ class HookParameter(torch.autograd.Function):
         if bias is not None:
             bias = bias.view(bias.shape)
         return grad_output, None, None
 
 
 def hook_paramter_in_backward(input, weight=None, bias=None):
     return HookParameter.apply(input, weight, bias)
 
 
 def _reduce(input_, process_group):
     # skip if only one rank involved
     if dist.get_world_size(process_group) == 1:
@@ -522,7 +522,7 @@ def _split(input_, dim=-1, process_group=None):
 
     tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
     rank = dist.get_rank(process_group)
-    output = tensor_list[rank].contiguous()
+    output = tensor_list[rank].clone().contiguous()
 
     return output
 
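This hunk is the core of the memory fix: torch.split returns views that share storage with the full input tensor, and when such a view is already contiguous (for example a 1-D bias split along its only dimension), .contiguous() is a no-op, so keeping only the local shard still pins the whole tensor's allocation. Adding .clone() copies the shard into its own, smaller buffer and lets the full tensor be freed. A minimal standalone sketch of the effect (illustrative only, not part of this diff; sizes are made up and the PyTorch 2.x storage API is assumed):

import torch

# Illustrative sketch: a contiguous split view keeps the whole buffer alive,
# while .clone() gives the shard its own, smaller storage.
full = torch.randn(4 * 1024 * 1024)                                          # ~16 MiB, 1-D like a bias
view_shard = torch.split(full, full.numel() // 4)[0].contiguous()            # no-op: still a view of `full`
owned_shard = torch.split(full, full.numel() // 4)[0].clone().contiguous()   # independent ~4 MiB buffer

print(view_shard.untyped_storage().nbytes())    # bytes of the *full* tensor's storage
print(owned_shard.untyped_storage().nbytes())   # bytes of one shard only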
@@ -112,7 +112,7 @@ def _split(tensor: torch.Tensor, comm_spec: CommSpec):
     dim = comm_spec.shard_dim
     length = tensor.shape[comm_spec.shard_dim] // dist.get_world_size(process_group)
     start = length * dist.get_rank(process_group)
-    output = torch.narrow(tensor, dim, start, length).contiguous()
+    output = torch.narrow(tensor, dim, start, length).clone().contiguous()
     return output
 
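The same reasoning applies to the torch.narrow variant in this second hunk: narrow always returns a view, so without the clone the sharded output keeps the full tensor's storage reachable. A quick illustrative check of the aliasing (again a standalone snippet, not from the repository):

import torch

# narrow() returns a view sharing memory with `tensor`; clone() breaks the alias.
tensor = torch.randn(8, 512)
shard = torch.narrow(tensor, 0, 0, 2)                        # aliases `tensor`
owned = torch.narrow(tensor, 0, 0, 2).clone().contiguous()   # independent buffer

print(shard.data_ptr() == tensor.data_ptr())    # True  -> shares storage
print(owned.data_ptr() == tensor.data_ptr())    # False -> owns its own storage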