mirror of https://github.com/hpcaitech/ColossalAI
Topics: ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

# Flags set by the test hook so the test can verify that the FP8 rewrite happened.
REPLACED = False   # FP8Hook.rewrite_op handed back our wrapper instead of linear_fp8
TRIGGERED = False  # the wrapper was actually invoked during the forward pass


def new_linear_fp8(x, w, bias=None):
    # Thin wrapper around linear_fp8 that records that the FP8 kernel was called.
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        # Let the parent hook decide whether to rewrite the op, then swap in the
        # instrumented wrapper whenever it selected the FP8 linear kernel.
        func = super().rewrite_op(func)
        if func is linear_fp8:
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    # create tensors
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    # Turn the weight into a ColoParameter so ColoParamOpHookManager intercepts ops on it.
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED
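

# Hypothetical standalone entry point (an assumption, not shown in the listing above):
# lets the check run directly with `python test_fp8_hook.py` on a GPU with compute
# capability >= 9.0 (e.g. H100), outside the pytest runner.
if __name__ == "__main__":
    test_fp8_hook()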