[fp8] use torch compile (torch >= 2.3.0) (#5979)

* [fp8] use torch compile (torch >= 2.4.0)

* [fp8] set use_fast_accum in linear

* [chore] formal version check

* [chore] fix sig
pull/5984/head
botbw 2024-08-09 15:51:06 +08:00 committed by GitHub
parent 8241c0c054
commit e4aadeee20
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 22 additions and 5 deletions

View File

@ -1,13 +1,14 @@
from typing import Any, Optional
from typing import Any, Optional, Tuple
import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from packaging.version import Version
from torch.distributed import ReduceOp
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
r"""
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
Args:
@ -624,7 +625,13 @@ class _LinearFp8(torch.autograd.Function):
ctx.inv_scale_x = inv_scale_x
ctx.inv_scale_w = inv_scale_w
out = torch._scaled_mm(
x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
x_fp8,
ctx.w_fp8_t,
bias=bias,
out_dtype=ctx.out_dtype,
scale_a=inv_scale_x,
scale_b=inv_scale_w,
use_fast_accum=True,
)[0]
return out.reshape(*ctx.x_shape[:-1], w.shape[0])
@ -638,6 +645,7 @@ class _LinearFp8(torch.autograd.Function):
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_w,
use_fast_accum=True,
)[0]
w_grad = torch._scaled_mm(
out_grad_fp8.t().contiguous(),
@ -645,6 +653,7 @@ class _LinearFp8(torch.autograd.Function):
out_dtype=ctx.out_dtype,
scale_a=out_grad_scale,
scale_b=ctx.inv_scale_x,
use_fast_accum=True,
)[0]
bias_grad = None
if ctx.has_bias:
@ -652,5 +661,13 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0
@torch.compile(mode="reduce-overhead", fullgraph=True)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)
else:
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)