[hotfix] issue #2388

pull/2389/head
jiaruifang 2023-01-07 18:23:02 +08:00
parent 4e96039649
commit 69d9180c4b
2 changed files with 16 additions and 13 deletions

View File

@ -16,17 +16,17 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@custom_fwd(cast_inputs=torch.float32)
def forward(ctx, input, weight, bias, normalized_shape, eps):
try:
import colossalai._C.layer_norm
from colossalai._C import layer_norm
except ImportError:
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
layer_norm = LayerNormBuilder().load()
ctx.normalized_shape = normalized_shape
ctx.eps = eps
input_ = input.contiguous()
weight_ = weight.contiguous()
bias_ = bias.contiguous()
output, mean, invvar = colossalai._C.layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_,
ctx.eps)
output, mean, invvar = layer_norm.forward_affine(input_, ctx.normalized_shape, weight_, bias_, ctx.eps)
ctx.save_for_backward(input_, weight_, bias_, mean, invvar)
return output
@ -35,14 +35,15 @@ class FusedLayerNormAffineFunction(torch.autograd.Function):
@custom_bwd
def backward(ctx, grad_output):
try:
import colossalai._C.layer_norm
from colossalai._C import layer_norm
except ImportError:
raise RuntimeError('FusedLayerNormAffineFunction requires cuda extensions')
from colossalai.kernel.op_builder.layernorm import LayerNormBuilder
layer_norm = LayerNormBuilder().load()
input_, weight_, bias_, mean, invvar = ctx.saved_tensors
grad_input = grad_weight = grad_bias = None
grad_input, grad_weight, grad_bias \
= colossalai._C.layer_norm.backward_affine(
= layer_norm.backward_affine(
grad_output.contiguous(), mean, invvar,
input_, ctx.normalized_shape,
weight_, bias_, ctx.eps)

View File

@ -53,26 +53,28 @@ class ScaledMaskedSoftmax(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, mask, scale):
try:
import colossalai._C.scaled_masked_softmax
from colossalai._C import scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
scale_t = torch.tensor([scale])
softmax_results = colossalai._C.scaled_masked_softmax.forward(inputs, mask, scale_t[0])
softmax_results = scaled_masked_softmax.forward(inputs, mask, scale_t[0])
ctx.save_for_backward(softmax_results, scale_t)
return softmax_results
@staticmethod
def backward(ctx, output_grads):
try:
import colossalai._C.scaled_masked_softmax
from colossalai._C import scaled_masked_softmax
except ImportError:
raise RuntimeError('ScaledMaskedSoftmax requires cuda extensions')
from colossalai.kernel.op_builder.scaled_masked_softmax import ScaledMaskedSoftmaxBuilder
scaled_masked_softmax = ScaledMaskedSoftmaxBuilder().load()
softmax_results, scale_t = ctx.saved_tensors
input_grads = colossalai._C.scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
input_grads = scaled_masked_softmax.backward(output_grads, softmax_results, scale_t[0])
return input_grads, None, None