From 1a2e90dcc13e64f3bf273d36315b26095c0dd64c Mon Sep 17 00:00:00 2001
From: botbw
Date: Thu, 15 Aug 2024 03:12:08 +0000
Subject: [PATCH] [fp8] linear perf enhancement

---
 colossalai/quantization/fp8.py | 5 +----
 1 file changed, 1 insertion(+), 4 deletions(-)

diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index 8ada42935..5b606616e 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -728,14 +728,11 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-@torch.compile(mode="reduce-overhead", disable=not SUPPORT_TORCH_COMPILE)
+@torch.compile(mode="max-autotune-no-cudagraphs", disable=not SUPPORT_TORCH_COMPILE, dynamic=False)
 def _linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     return _LinearFp8.apply(input, weight, bias)
 
 
 def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
     out = _linear_fp8(input, weight, bias)
-    if SUPPORT_TORCH_COMPILE:
-        # avoid modifying the tensor created from cuda graph
-        out = out.clone()
     return out
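
Note (not part of the patch): the two changes are coupled. `mode="reduce-overhead"` may capture CUDA graphs, and a CUDA graph replay reuses its output buffers, so callers had to `clone()` the result before the next call could overwrite it. `mode="max-autotune-no-cudagraphs"` autotunes kernel configurations without capturing CUDA graphs, which makes the defensive clone unnecessary, and `dynamic=False` specializes the compiled kernel to static shapes. A minimal sketch of the difference, assuming a CUDA device and a PyTorch 2.x build where `torch.compile` is available (the functions `f`/`g` are illustrative, not from the repo):

    import torch

    @torch.compile(mode="reduce-overhead")  # may capture CUDA graphs after warm-up
    def f(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    @torch.compile(mode="max-autotune-no-cudagraphs", dynamic=False)  # no CUDA graphs
    def g(x: torch.Tensor) -> torch.Tensor:
        return x * 2

    if torch.cuda.is_available():
        x = torch.randn(8, device="cuda")
        a = f(x)
        # Once a CUDA graph is captured, `a` may be backed by graph-managed
        # memory, so it must be cloned before the next replay rewrites it.
        a_saved = a.clone()
        f(x)       # a later replay can overwrite the buffer behind `a`
        b = g(x)   # no graph replay: `b` is an ordinary tensor, no clone needed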