mirror of https://github.com/hpcaitech/ColossalAI
[fp8] linear perf enhancement
parent 88fa096d78
commit 1a2e90dcc1
@@ -728,14 +728,11 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


-@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)


 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     out = _linear_fp8(input, weight, bias)
-    if SUPPORT_TORCH_COMPILE:
-        # avoid modifying the tensor created from cuda graph
-        out = out.clone()
     return out
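Context for the change: the old wrapper cloned the output because "reduce-overhead" mode may capture the call in a CUDA graph, so the returned tensor can alias a graph-owned buffer that is overwritten on the next replay (the removed comment says as much). "max-autotune-no-cudagraphs" keeps Inductor's kernel autotuning but skips CUDA graphs, so the output is an ordinary tensor and the defensive clone() can go, while dynamic=False lets kernels specialize to fixed shapes. Below is a minimal sketch of the two modes, assuming PyTorch >= 2.0; the SUPPORT_TORCH_COMPILE flag and function names here are illustrative stand-ins, not the ColossalAI source.

    from typing import Optional

    import torch
    import torch.nn.functional as F

    SUPPORT_TORCH_COMPILE = True  # stand-in for the library's version/capability check

    # Old mode: "reduce-overhead" enables CUDA graphs, so the returned tensor may
    # live in a graph-owned buffer that is reused on the next replay; callers
    # that mutate the result must clone it first.
    @torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
    def linear_old_mode(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, w, b)

    # New mode: "max-autotune-no-cudagraphs" autotunes matmul/Triton configs but
    # skips CUDA graphs, so the output is an ordinary tensor and no defensive
    # clone is needed; dynamic=False specializes the compiled kernels to the
    # shapes actually seen.
    @torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
    def linear_new_mode(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
        return F.linear(x, w, b)

    if __name__ == "__main__" and torch.cuda.is_available():
        x = torch.randn(16, 64, device="cuda")
        w = torch.randn(128, 64, device="cuda")
        out = linear_new_mode(x, w)
        out.add_(1.0)  # safe to modify in place: no CUDA-graph-owned storage involved

Dropping the clone avoids an extra copy of the activation on every forward call, which is presumably the overhead the commit title refers to.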