Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

24 lines
521 B

import torch.nn.functional as F
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook
class FP8Hook(ColoParamOpHook):
    """Param-op hook whose only effect is to reroute ``F.linear`` to ``linear_fp8``.

    The four lifecycle callbacks are deliberately empty. They are defined here
    presumably because the ``ColoParamOpHook`` interface expects them
    (NOTE(review): confirm whether they are abstract in the base class). All of
    this hook's behavior lives in :meth:`rewrite_op`.
    """

    def pre_forward(self, params) -> None:
        """Intentionally empty; this hook performs no pre-forward work."""

    def post_forward(self, params) -> None:
        """Intentionally empty; this hook performs no post-forward work."""

    def pre_backward(self, params) -> None:
        """Intentionally empty; this hook performs no pre-backward work."""

    def post_backward(self, params) -> None:
        """Intentionally empty; this hook performs no post-backward work."""

    def rewrite_op(self, func):
        """Pick the callable to run in place of *func*.

        Args:
            func: The operator about to be invoked.

        Returns:
            ``linear_fp8`` when *func* is exactly ``torch.nn.functional.linear``;
            otherwise *func* itself, untouched.
        """
        # Match by identity, not equality: only the genuine F.linear object is
        # swapped. NOTE(review): this assumes linear_fp8 accepts the same
        # (input, weight, bias=None) call shape as F.linear. Confirm this in
        # colossalai.quantization.fp8.
        return linear_fp8 if func is F.linear else func