From 0978080a69249befd67994391276b1dd84d0965d Mon Sep 17 00:00:00 2001
From: Hongxin Liu
Date: Tue, 13 Aug 2024 16:07:26 +0800
Subject: [PATCH] [fp8] refactor fp8 linear with compile (#5993)

* [fp8] refactor fp8 linear with compile

* [fp8] fix linear test

* [fp8] fix linear test
---
 colossalai/quantization/fp8.py | 19 +++++++++++--------
 1 file changed, 11 insertions(+), 8 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index f2bffa09f..fe87e317d 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -7,6 +7,8 @@ import torch.nn.functional as F
 from packaging.version import Version
 from torch.distributed import ReduceOp
 
+SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")
+
 
 def cast_to_fp8(inp: torch.Tensor, fp8_format="e4m3", per_channel_scale=False) -> Tuple[torch.Tensor, torch.Tensor]:
     r"""
@@ -664,13 +666,14 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-if Version(torch.__version__) >= Version("2.3.0"):  # TODO failed on torch < 2.3.0
+@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
+def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
 
-    @torch.compile(mode="reduce-overhead", fullgraph=True)
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
 
-else:
-
-    def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-        return _LinearFp8.apply(input, weight, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    out = _linear_fp8(input, weight, bias)
+    if SUPPORT_TORCH_COMPILE:
+        # avoid modifying the tensor created from cuda graph
+        out = out.clone()
+    return out
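
For reviewers, below is a standalone sketch of the pattern this patch converges on: one compiled helper gated by torch.compile's `disable=` keyword instead of duplicated module-level `if/else` definitions, plus a defensive clone of the compiled output. It is illustrative only: `_scaled_linear`/`scaled_linear` are hypothetical stand-ins for `_linear_fp8`/`linear_fp8`, with `F.linear` in place of the FP8 autograd function `_LinearFp8.apply`.

from typing import Optional

import torch
import torch.nn.functional as F
from packaging.version import Version

# Same gate as the patch: compilation with mode="reduce-overhead" is only
# exercised on torch >= 2.3.0; older versions fall back to eager execution.
SUPPORT_TORCH_COMPILE = Version(torch.__version__) >= Version("2.3.0")


@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
def _scaled_linear(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Hypothetical stand-in for _LinearFp8.apply in the patch.
    return F.linear(x, w, b)


def scaled_linear(x: torch.Tensor, w: torch.Tensor, b: Optional[torch.Tensor] = None) -> torch.Tensor:
    out = _scaled_linear(x, w, b)
    if SUPPORT_TORCH_COMPILE:
        # mode="reduce-overhead" replays a captured CUDA graph whose output
        # buffer is reused across calls, so hand the caller a private copy.
        out = out.clone()
    return out

The clone sits outside the compiled function, presumably on purpose: a clone captured inside the CUDA graph would itself live in graph-managed memory and be overwritten on the next replay, whereas an eager clone gives the caller storage the graph never touches. Gating via `disable=` also keeps a single definition of the wrapper, which is what lets the patch drop the old "TODO failed on torch < 2.3.0" branch.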