mirror of https://github.com/hpcaitech/ColossalAI
[fp8] add fp8 linear (#5967)
* [fp8] add fp8 linear
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition
* [test] fix fp8 linear test condition

branch: pull/5976/head
parent: afb26de873
commit: 76ea16466f
@@ -1,4 +1,4 @@
-from typing import Any
+from typing import Any, Optional

 import numpy as np
 import torch
@@ -415,3 +415,62 @@ def gather_fp8(output_list, input_, group=None, fp8_format="e5m2"):
        output = tensor_list[i].view(fp8_type)
        scale = scale_list[i]
        output_list[i].copy_(cast_from_fp8(output, scale, input_type))


class _LinearFp8(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx: Any,
        x: torch.Tensor,
        w: torch.Tensor,
        bias: Optional[torch.Tensor],
    ) -> Any:
        assert (
            x.dtype in (torch.bfloat16, torch.float16) and x.dtype == w.dtype
        ), "Only float16 and bfloat16 are allowed."
        if bias is not None:
            assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
        # ensure x and w are row-major
        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
        ctx.x_shape = x.shape
        ctx.has_bias = bias is not None
        ctx.out_dtype = x.dtype
        x = x.reshape(-1, x.shape[-1])

        # quantize activations and weights to e4m3 for the forward GEMM; keep the fp8
        # copies and their scales for the backward pass
        x_fp8, inv_scale_x = cast_to_fp8(x, fp8_format="e4m3")
        w_fp8, inv_scale_w = cast_to_fp8(w, fp8_format="e4m3")
        ctx.x_fp8 = x_fp8
        ctx.w_fp8_t = w_fp8.t()
        ctx.inv_scale_x = inv_scale_x
        ctx.inv_scale_w = inv_scale_w
        # out = x @ w.T (+ bias), computed in fp8 and written out in the input dtype
        out = torch._scaled_mm(
            x_fp8, ctx.w_fp8_t, bias=bias, out_dtype=ctx.out_dtype, scale_a=inv_scale_x, scale_b=inv_scale_w
        )[0]
        return out.reshape(*ctx.x_shape[:-1], w.shape[0])

    @staticmethod
    def backward(ctx: Any, out_grad) -> Any:
        out_grad = out_grad.reshape(-1, out_grad.shape[-1])
        # gradients use e5m2, trading precision for a wider dynamic range
        out_grad_fp8, out_grad_scale = cast_to_fp8(out_grad, fp8_format="e5m2")
        # dL/dx = dL/dout @ w
        x_grad = torch._scaled_mm(
            out_grad_fp8,
            ctx.w_fp8_t.contiguous().t(),
            out_dtype=ctx.out_dtype,
            scale_a=out_grad_scale,
            scale_b=ctx.inv_scale_w,
        )[0]
        # dL/dw = dL/dout.T @ x
        w_grad = torch._scaled_mm(
            out_grad_fp8.t().contiguous(),
            ctx.x_fp8.t().contiguous().t(),
            out_dtype=ctx.out_dtype,
            scale_a=out_grad_scale,
            scale_b=ctx.inv_scale_x,
        )[0]
        bias_grad = None
        if ctx.has_bias:
            bias_grad = out_grad.sum(0)
        return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
    return _LinearFp8.apply(x, w, bias)
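For context, a minimal usage sketch of the new op (not part of the diff): it assumes a GPU with compute capability >= 9.0 and a float16/bfloat16 model, and FP8Linear is a hypothetical wrapper named here only for illustration.

import torch
import torch.nn as nn

from colossalai.quantization.fp8 import linear_fp8


class FP8Linear(nn.Linear):
    # Hypothetical wrapper: same parameters as nn.Linear, but the matmul
    # is routed through the FP8 autograd function added above.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return linear_fp8(x, self.weight, self.bias)


layer = FP8Linear(16, 32, device="cuda", dtype=torch.bfloat16)
x = torch.rand(2, 64, 16, device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = layer(x)        # forward GEMM in fp8 (e4m3)
out.sum().backward()  # backward GEMMs in fp8 (e5m2 gradients)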
@@ -0,0 +1,45 @@
import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device

D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
    # create tensors
    w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
    if use_batch:
        x_shape = (B, S, D_IN)
    else:
        x_shape = (S, D_IN)
    x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_x = x.clone().detach().requires_grad_()
    if use_bias:
        bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
        ref_bias = bias.clone().detach().requires_grad_()
    else:
        bias = None
        ref_bias = None

    out = linear_fp8(x, w, bias)
    assert out.shape == x_shape[:-1] + (D_OUT,)
    out.sum().backward()
    ref_out = F.linear(ref_x, ref_w, ref_bias)
    ref_out.sum().backward()

    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
    if use_bias:
        assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)
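The tolerances above are deliberately loose: per-tensor FP8 quantization of activations, weights, and gradients introduces rounding error on the order of a few percent. A rough stand-alone sketch of that rounding error (it does not reuse cast_to_fp8 from this patch and only assumes a PyTorch build that ships the float8_e4m3fn dtype):

import torch

x = torch.rand(64, 16, dtype=torch.bfloat16)
# per-tensor scale so the largest value maps to the e4m3 maximum (448)
scale = x.abs().max().float() / torch.finfo(torch.float8_e4m3fn).max
x_fp8 = (x.float() / scale).to(torch.float8_e4m3fn)  # quantize
x_deq = x_fp8.float() * scale                        # dequantize
print((x.float() - x_deq).abs().max().item())        # worst-case rounding error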