
[fp8] support fp8 amp for hybrid parallel plugin (#5975)

* [fp8] support fp8 amp for hybrid parallel plugin

* [test] add fp8 hook test

* [fp8] fix fp8 linear compatibility
Hongxin Liu · 4 months ago · committed by GitHub · commit ccabcf6485
Changed files (6):
  1. colossalai/booster/plugin/fp8_hook.py (23 lines, new file)
  2. colossalai/booster/plugin/hybrid_parallel_plugin.py (18 lines changed)
  3. colossalai/quantization/fp8.py (3 lines changed)
  4. colossalai/tensor/colo_parameter.py (2 lines changed)
  5. colossalai/tensor/param_op_hook.py (9 lines changed)
  6. tests/test_fp8/test_fp8_hook.py (50 lines, new file)

colossalai/booster/plugin/fp8_hook.py (23 lines, new file)

@@ -0,0 +1,23 @@
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.param_op_hook import ColoParamOpHook


class FP8Hook(ColoParamOpHook):
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        if func is F.linear:
            return linear_fp8
        return func
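A minimal sketch of what the hook does, assuming an FP8-capable GPU (compute capability >= 9.0); it mirrors the new test at the bottom of this commit. With the hook active, F.linear calls whose weight is a ColoParameter are rewritten to linear_fp8.

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

w = nn.Parameter(torch.rand(32, 16, device="cuda", dtype=torch.bfloat16))
w.__class__ = ColoParameter  # hooks only fire for ColoParameter operands
w.__init__(w, requires_grad=True)
x = torch.rand(2, 64, 16, device="cuda", dtype=torch.bfloat16)

with ColoParamOpHookManager.use_hooks(FP8Hook()):
    y = F.linear(x, w)  # dispatched through linear_fp8 by the hook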

colossalai/booster/plugin/hybrid_parallel_plugin.py (18 lines changed)

@@ -40,6 +40,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

+from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase

 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

@@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         ddp_config: dict,
         custom_policy: Policy,
         overlap_allgather: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config

@@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.use_dpp = use_ddp
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
+        self.use_fp8 = use_fp8

         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:

@@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)

         super().__init__(module)
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter

@@ -223,7 +230,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             wait_all_gather_handle(p)

     def _wait_all_gather(self):
-        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+        return (
+            ColoParamOpHookManager.use_hooks(*self.op_hooks)
+            if (self.overlap_allgather or self.use_fp8)
+            else nullcontext()
+        )


 def get_param_info(optim: Optimizer):

@@ -1019,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()

@@ -1063,6 +1075,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.use_fp8 = use_fp8
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)

@@ -1243,6 +1256,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
+                use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
             if zero_stage == 0:
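A minimal end-to-end sketch of the new flag; the launcher call, the tiny model, and the Adam optimizer are ordinary boilerplate assumed for illustration and are not part of this diff.

import torch
import torch.nn as nn

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()  # assumes the process group was set up by torchrun
plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", use_fp8=True)  # new flag from this commit
booster = Booster(plugin=plugin)

model = nn.Linear(16, 32).bfloat16().cuda()
optimizer = torch.optim.Adam(model.parameters())
# the boosted model is a HybridParallelModule with an FP8Hook registered in op_hooks
model, optimizer, *_ = booster.boost(model, optimizer)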

colossalai/quantization/fp8.py (3 lines changed)

@@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function):
         if bias is not None:
             assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
         # ensure x and w are row-major
-        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
+        x = x.contiguous()
+        w = w.contiguous()
         ctx.x_shape = x.shape
         ctx.has_bias = bias is not None
         ctx.out_dtype = x.dtype
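A small sketch of the practical effect, assuming an FP8-capable GPU: a non-contiguous input such as a transposed view no longer trips an assertion, because linear_fp8 now makes its own contiguous copies.

import torch

from colossalai.quantization.fp8 import linear_fp8

w = torch.rand(32, 16, device="cuda", dtype=torch.bfloat16)
x = torch.rand(64, 2, 16, device="cuda", dtype=torch.bfloat16).transpose(0, 1)  # non-contiguous view
assert not x.is_contiguous()
y = linear_fp8(x, w)  # previously failed the contiguity assert; now handled internally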

colossalai/tensor/colo_parameter.py (2 lines changed)

@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                 args, kwargs = replace_args(args, kwargs, new_args)
+                with torch._C.DisableTorchFunction():
+                    func = ColoParamOpHookManager.rewrite_op(func)
                 ret = super().__torch_function__(func, types, args, kwargs)
                 with torch._C.DisableTorchFunction():
                     ret = ColoParamOpHookManager.post_op(params, ret)

colossalai/tensor/param_op_hook.py (9 lines changed)

@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
     def post_backward(self, params: List[torch.Tensor]) -> None:
         pass

+    def rewrite_op(self, func) -> Any:
+        return func
+

 class ColoParamOpHookManager:
     """

@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
     def has_hook() -> bool:
         return len(ColoParamOpHookManager.hooks) > 0

+    @staticmethod
+    def rewrite_op(func) -> Any:
+        for hook in ColoParamOpHookManager.hooks:
+            func = hook.rewrite_op(func)
+        return func
+

 class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
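Since rewrite_op is now part of the ColoParamOpHook interface, any hook can swap an op, not just the FP8 one. The LoggingHook below is a hypothetical sketch (its name and behavior are illustrative, not part of this commit) that wraps F.linear to print call shapes.

import torch.nn.functional as F

from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager


class LoggingHook(ColoParamOpHook):
    # hypothetical hook: rewrites F.linear to a logging wrapper, leaves other ops untouched
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        if func is F.linear:
            def logged_linear(x, w, bias=None):
                print(f"F.linear: x={tuple(x.shape)}, w={tuple(w.shape)}")
                return F.linear(x, w, bias)
            return logged_linear
        return func

# Hooks are applied in registration order; each one may replace the op returned by
# the previous hook, exactly as ColoParamOpHookManager.rewrite_op chains them, e.g.
# with ColoParamOpHookManager.use_hooks(LoggingHook(), FP8Hook()): ...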

tests/test_fp8/test_fp8_hook.py (50 lines, new file)

@@ -0,0 +1,50 @@
import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device

REPLACED = False
TRIGGERED = False


def new_linear_fp8(x, w, bias=None):
    global TRIGGERED
    TRIGGERED = True
    return linear_fp8(x, w, bias)


class FP8TestHook(FP8Hook):
    def rewrite_op(self, func):
        func = super().rewrite_op(func)
        if func is linear_fp8:
            global REPLACED
            REPLACED = True
            return new_linear_fp8
        return func


D_IN, D_OUT = 16, 32
B, S = 2, 64
DTYPE = torch.bfloat16


@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
def test_fp8_hook():
    # create tensors
    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
    w.__class__ = ColoParameter
    w.__init__(w, requires_grad=True)
    hook = FP8TestHook()
    with ColoParamOpHookManager.use_hooks(hook):
        o = F.linear(x, w)
    assert o.shape == (B, S, D_OUT)
    assert REPLACED
    assert TRIGGERED
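Note that the skipif marker means this test only runs on devices with compute capability 9.0 or higher (Hopper-class GPUs); on supported hardware it can be invoked directly, for example with: pytest tests/test_fp8/test_fp8_hook.py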