mirror of https://github.com/hpcaitech/ColossalAI
[fp8] use torch compile (torch >= 2.3.0) (#5979)
* [fp8] use torch compile (torch >= 2.4.0) * [fp8] set use_fast_accum in linear * [chore] formal version check * [chore] fix sigpull/5984/head
parent
8241c0c054
commit
e4aadeee20
|
@ -1,13 +1,14 @@
|
|||
from typing import Any, Optional
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
from packaging.version import Version
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> (torch.Tensor, torch.Tensor):
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
casting torch Tensor into specified fp8 tensor with per-channel scaling or per-tensor scaling.
|
||||
Args:
|
||||
|
@ -624,7 +625,13 @@ class _LinearFp8(torch.autograd.Function):
|
|||
ctx.inv_scale_x = inv_scale_x
|
||||
ctx.inv_scale_w = inv_scale_w
|
||||
out = torch._scaled_mm(
|
||||
x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
|
||||
x_fp8,
|
||||
ctx.w_fp8_t,
|
||||
bias=bias,
|
||||
out_dtype=ctx.out_dtype,
|
||||
scale_a=inv_scale_x,
|
||||
scale_b=inv_scale_w,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
return out.reshape(*ctx.x_shape[:-1], w.shape[0])
|
||||
|
||||
|
@ -638,6 +645,7 @@ class _LinearFp8(torch.autograd.Function):
|
|||
out_dtype=ctx.out_dtype,
|
||||
scale_a=out_grad_scale,
|
||||
scale_b=ctx.inv_scale_w,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
w_grad = torch._scaled_mm(
|
||||
out_grad_fp8.t().contiguous(),
|
||||
|
@ -645,6 +653,7 @@ class _LinearFp8(torch.autograd.Function):
|
|||
out_dtype=ctx.out_dtype,
|
||||
scale_a=out_grad_scale,
|
||||
scale_b=ctx.inv_scale_x,
|
||||
use_fast_accum=True,
|
||||
)[0]
|
||||
bias_grad = None
|
||||
if ctx.has_bias:
|
||||
|
@ -652,5 +661,13 @@ class _LinearFp8(torch.autograd.Function):
|
|||
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
||||
|
||||
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0
|
||||
|
||||
@torch.compile(mode="reduce-overhead", fullgraph=True)
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
||||
else:
|
||||
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
|
Loading…
Reference in New Issue