[fp8] refactor fp8 linear with compile (#5993)

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test
pull/5998/head
Hongxin Liu 3 months ago committed by GitHub
parent b2483c8e31
commit 0978080a69
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -7,6 +7,8 @@ import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
@ -664,13 +666,14 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0
@torch.compile(mode="reduce-overhead", fullgraph=True)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
else:
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
out = _linear_fp8(input, weight, bias)
if SUPPORT_TORCH_COMPILE:
# avoid modifying the tensor created from cuda graph
out = out.clone()
return out

Loading…
Cancel
Save