@@ -8,6 +8,8 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+from .fp8_config import dynamic_kernel
+
 SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.4.0")
 SCALE_BYTES = 4
 try:
@@ -832,11 +834,13 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=dynamic_kernel)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    if input.shape[-1] % 16 != 0 or np.prod(input.shape[:-1]) % 16 != 0:
+        return F.linear(input, weight, bias)
     out = _linear_fp8(input, weight, bias)
     return out
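
The behavioural change in this hunk is the `dynamic=` argument of `torch.compile`: instead of always specializing the compiled kernel to the shapes seen at trace time (`dynamic=False`), the patch defers the choice to the `dynamic_kernel` flag imported from `.fp8_config`. Below is a minimal sketch of what that flag controls, using a toy function rather than the patched `_linear_fp8`; the names `toy_matmul`, `static_fn`, and `dynamic_fn` are illustrative and not part of the repository.

import torch

# With dynamic=False the compiled artifact is specialized to the input shapes
# seen on the first call and may recompile when a new shape arrives; with
# dynamic=True the graph is traced with symbolic shapes so one compiled kernel
# can serve many shapes. The patch routes this choice through the
# dynamic_kernel config flag instead of hard-coding False.

def toy_matmul(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return x @ w.t()

static_fn = torch.compile(toy_matmul, dynamic=False)   # shape-specialized
dynamic_fn = torch.compile(toy_matmul, dynamic=True)   # symbolic shapes

if __name__ == "__main__":
    w = torch.randn(32, 64)
    for batch in (16, 48):                              # two different batch sizes
        x = torch.randn(batch, 64)
        torch.testing.assert_close(static_fn(x, w), dynamic_fn(x, w))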
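The new guard in `linear_fp8` routes misaligned inputs through the ordinary `F.linear` path: it checks that the last (feature) dimension and the product of all leading dimensions are both multiples of 16, presumably because the FP8 GEMM used underneath (e.g. `torch._scaled_mm`) only accepts 16-aligned matrices. A small sketch of that condition in isolation follows; `needs_fallback` and the example shapes are illustrative, not part of the patch.

import numpy as np

# Same condition as the guard added to linear_fp8: fall back when either the
# flattened "row" count or the feature dimension is not a multiple of 16.
def needs_fallback(shape: tuple) -> bool:
    return shape[-1] % 16 != 0 or np.prod(shape[:-1]) % 16 != 0

print(needs_fallback((8, 16, 4096)))  # False: 8*16 = 128 rows and 4096 features, both multiples of 16
print(needs_fallback((3, 7, 4096)))   # True: 3*7 = 21 rows is not a multiple of 16
print(needs_fallback((128, 1000)))    # True: 1000 features is not a multiple of 16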