mirror of https://github.com/hpcaitech/ColossalAI
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    """Parameter-op hook that swaps ``F.linear`` for the FP8 linear kernel."""

    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        # Substitute the FP8 implementation for torch's functional linear;
        # leave every other op untouched.
        if func is F.linear:
            return linear_fp8
        return func
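For context, below is a small self-contained sketch of the rewrite pattern this hook relies on: rewrite_op receives the functional op about to run and may return a replacement. The names DemoHook and linear_fp8_stub, and the manual dispatch shown, are hypothetical illustration only and not ColossalAI's actual hook-manager code; in ColossalAI the substitution is driven by the param-op hook machinery that wraps parameter operations.

# Hypothetical, self-contained illustration of op rewriting (names below are
# made up for the example and are not part of ColossalAI's API).
import torch
import torch.nn.functional as F


def linear_fp8_stub(x, w, b=None):
    # Stand-in for colossalai.quantization.fp8.linear_fp8: same call signature,
    # but simply delegates to the regular linear here.
    return F.linear(x, w, b)


class DemoHook:
    def rewrite_op(self, func):
        # Same shape as FP8Hook.rewrite_op: map F.linear to the FP8 variant.
        if func is F.linear:
            return linear_fp8_stub
        return func


hook = DemoHook()
x = torch.randn(4, 8)
w = torch.randn(16, 8)

op = hook.rewrite_op(F.linear)  # the hook hands back the substituted kernel
out = op(x, w)
assert out.shape == (4, 16)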