mirror of https://github.com/hpcaitech/ColossalAI
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility

pull/5981/head
Hongxin Liu, 4 months ago, committed by GitHub
6 changed files with 102 additions and 3 deletions
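Per the commit message, this series also enables FP8 mixed precision in the hybrid parallel plugin. Below is a minimal end-to-end sketch of how that might be switched on; it is not part of this diff, and the `use_fp8` flag name and the exact boost workflow are assumptions, not confirmed by the code shown here.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Hypothetical usage; run under torchrun so the distributed launch succeeds.
colossalai.launch_from_torch()

# `use_fp8=True` is an assumed flag name for the FP8 AMP path added by this series.
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)

model = nn.Sequential(nn.Linear(16, 32), nn.GELU(), nn.Linear(32, 16)).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# boost() wraps the model so eligible F.linear calls can be routed through FP8 kernels.
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)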
colossalai/booster/plugin/fp8_hook.py
@@ -0,0 +1,23 @@
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    """Param-op hook that reroutes F.linear to the FP8 linear kernel."""

    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        # Swap the functional linear for the FP8 implementation; leave other ops untouched.
        if func is F.linear:
            return linear_fp8
        return func
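Together with ColoParamOpHookManager, this means any F.linear whose weight is a ColoParameter is dispatched to linear_fp8 while the hook is active. A minimal sketch of driving the hook by hand is below (mirroring the test that follows); constructing the weight directly as a ColoParameter is an assumption about the intended setup, and an FP8-capable GPU (compute capability >= 9.0) is required.

import torch
import torch.nn.functional as F

from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

# Wrap the weight as a ColoParameter so the param-op hook machinery sees it.
w = ColoParameter(torch.rand(32, 16, device="cuda", dtype=torch.bfloat16), requires_grad=True)
x = torch.rand(2, 64, 16, device="cuda", dtype=torch.bfloat16)

# While the hook is installed, rewrite_op swaps F.linear for linear_fp8.
with ColoParamOpHookManager.use_hooks(FP8Hook()):
    y = F.linear(x, w)  # runs through the FP8 kernel on supported hardware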
@@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
    # Record that the rewritten op was actually executed, then defer to the real kernel.
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        # Record that F.linear was rewritten to linear_fp8, and wrap it for the execution check.
        func = super().rewrite_op(func)
        if func is linear_fp8:
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    # create tensors
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    # The hook only fires for ColoParameter, so convert the weight in place.
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED
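The test checks two separate things: REPLACED confirms that rewrite_op substituted linear_fp8 for F.linear, and TRIGGERED confirms the substituted op actually ran during the forward pass. The capability >= 9.0 skip restricts the test to GPUs with native FP8 support.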