import pytest
import torch
import torch.nn.functional as F
from torch.testing import assert_close

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.utils import get_current_device
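
# problem sizes and dtype used by the test: input/output feature dims,
# batch size, sequence length, and the tensor dtype of the reference path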
D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16
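
# skip on GPUs without hardware FP8 support (compute capability < 9.0)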
@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
@pytest.mark.parametrize("use_bias", [True, False])
@pytest.mark.parametrize("use_batch", [True, False])
def test_fp8_linear(use_bias: bool, use_batch: bool):
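    # every tensor below gets a detached clone (ref_*) that runs through the
    # standard bf16 F.linear path, so the FP8 graph and the reference graph
    # can be backpropagated independently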
    # create tensors
    w = torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_w = w.clone().detach().requires_grad_()
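    # the input is 3-D (B, S, D_IN) when use_batch is set, otherwise 2-D (S, D_IN)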
    if use_batch:
        x_shape = (B, S, D_IN)
    else:
        x_shape = (S, D_IN)
    x = torch.rand(x_shape, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    ref_x = x.clone().detach().requires_grad_()
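    # bias is optional; pass None so linear_fp8 and F.linear see the same configuration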
    if use_bias:
        bias = torch.rand(D_OUT, device=get_current_device(), dtype=DTYPE, requires_grad=True)
        ref_bias = bias.clone().detach().requires_grad_()
    else:
        bias = None
        ref_bias = None
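
    # run the FP8 forward, check the output shape, then backprop a scalar loss
    # through both the FP8 path and the bf16 reference path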
    out = linear_fp8(x, w, bias)
    assert out.shape == x_shape[:-1] + (D_OUT,)
    out.sum().backward()
    ref_out = F.linear(ref_x, ref_w, ref_bias)
    ref_out.sum().backward()
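
    # loose tolerances: FP8 quantization introduces sizeable error relative to
    # the bf16 reference, for the output and for every gradient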
    assert_close(out, ref_out, rtol=0.2, atol=0.1)
    assert_close(x.grad, ref_x.grad, rtol=0.2, atol=0.1)
    assert_close(w.grad, ref_w.grad, rtol=0.2, atol=0.1)
    if use_bias:
        assert_close(bias.grad, ref_bias.grad, rtol=0.2, atol=0.1)
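

if __name__ == "__main__":
    # assumed convenience entry point (not required by pytest collection):
    # run a single configuration directly on a capable device
    test_fp8_linear(True, True)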