mirror of https://github.com/hpcaitech/ColossalAI
[fp8] refactor fp8 linear with compile (#5993)
* [fp8] refactor fp8 linear with compile
* [fp8] fix linear test
* [fp8] fix linear test

Branch: pull/5998/head
parent b2483c8e31
commit 0978080a69
@@ -7,6 +7,8 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
+
 
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
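The new SUPPORT_TORCH_COMPILE flag gates torch.compile usage on torch >= 2.3.0. As a standalone illustration (not part of this commit; scaled_add is a made-up function), the disable keyword of torch.compile is what lets the same decorator degrade to plain eager execution on older builds:

    import torch
    from packaging.version import Version

    SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")

    # With disable=True the decorator leaves the function uncompiled,
    # so older torch versions simply run it eagerly.
    @torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
    def scaled_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        return x * 2.0 + y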
@@ -664,13 +666,14 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
-
-    @torch.compile(mode="reduce-overhead", fullgraph=True)
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
-
-else:
-
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
+@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
+def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
+
+
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    out = _linear_fp8(input, weight, bias)
+    if SUPPORT_TORCH_COMPILE:
+        # avoid modifying the tensor created from cuda graph
+        out = out.clone()
+    return out
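A usage sketch for the refactored entry point. The module path is an assumption (the page does not show the file name; colossalai.quantization.fp8 is the likely location), and shapes/dtypes are only illustrative. The reason for the clone in linear_fp8 is that mode="reduce-overhead" may capture _linear_fp8 into a CUDA graph whose output buffer is reused on every replay, so callers receive a private copy they can safely modify in place:

    import torch
    # module path assumed, not shown in this diff
    from colossalai.quantization.fp8 import linear_fp8

    x = torch.randn(16, 32, device="cuda", dtype=torch.bfloat16)
    w = torch.randn(64, 32, device="cuda", dtype=torch.bfloat16)
    b = torch.randn(64, device="cuda", dtype=torch.bfloat16)

    out = linear_fp8(x, w, b)  # compiled _linear_fp8, then .clone() of the result
    out += 1.0                 # safe: out is a clone, not the graph-owned output buffer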