mirror of https://github.com/hpcaitech/ColossalAI
[fp8] refactor fp8 linear with compile (#5993)
* [fp8] refactor fp8 linear with compile * [fp8] fix linear test * [fp8] fix linear testpull/5998/head
parent
b2483c8e31
commit
0978080a69
|
@ -7,6 +7,8 @@ import torch.nn.functional as F
|
|||
from packaging.version import Version
|
||||
from torch.distributed import ReduceOp
|
||||
|
||||
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
|
||||
|
||||
|
||||
def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
r"""
|
||||
|
@ -664,13 +666,14 @@ class _LinearFp8(torch.autograd.Function):
|
|||
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
||||
|
||||
|
||||
if Version(torch.__version__) >= Version("2.3.0"): # TODO failed on torch < 2.3.0
|
||||
@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
|
||||
def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
||||
@torch.compile(mode="reduce-overhead", fullgraph=True)
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
||||
else:
|
||||
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
out = _linear_fp8(input, weight, bias)
|
||||
if SUPPORT_TORCH_COMPILE:
|
||||
# avoid modifying the tensor created from cuda graph
|
||||
out = out.clone()
|
||||
return out
|
||||
|
|
Loading…
Reference in New Issue