From ccabcf64858f98d580d9ff9e704f3bf6de57effc Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 7 Aug 2024 18:21:08 +0800 Subject: [PATCH] [fp8] support fp8 amp for hybrid parallel plugin (#5975) * [fp8] support fp8 amp for hybrid parallel plugin * [test] add fp8 hook test * [fp8] fix fp8 linear compatibility --- colossalai/booster/plugin/fp8_hook.py | 23 +++++++++ .../booster/plugin/hybrid_parallel_plugin.py | 18 ++++++- colossalai/quantization/fp8.py | 3 +- colossalai/tensor/colo_parameter.py | 2 + colossalai/tensor/param_op_hook.py | 9 ++++ tests/test_fp8/test_fp8_hook.py | 50 +++++++++++++++++++ 6 files changed, 102 insertions(+), 3 deletions(-) create mode 100644 colossalai/booster/plugin/fp8_hook.py create mode 100644 tests/test_fp8/test_fp8_hook.py diff --git a/colossalai/booster/plugin/fp8_hook.py b/colossalai/booster/plugin/fp8_hook.py new file mode 100644 index 000000000..6171dd755 --- /dev/null +++ b/colossalai/booster/plugin/fp8_hook.py @@ -0,0 +1,23 @@ +import torch.nn.functional as F + +from colossalai.quantization.fp8 import linear_fp8 +from colossalai.tensor.param_op_hook import ColoParamOpHook + + +class FP8Hook(ColoParamOpHook): + def pre_forward(self, params) -> None: + pass + + def post_forward(self, params) -> None: + pass + + def pre_backward(self, params) -> None: + pass + + def post_backward(self, params) -> None: + pass + + def rewrite_op(self, func): + if func is F.linear: + return linear_fp8 + return func diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py index bf0788650..f585abdb5 100644 --- a/colossalai/booster/plugin/hybrid_parallel_plugin.py +++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py @@ -40,6 +40,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager from colossalai.zero.low_level import LowLevelZeroOptimizer from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle +from .fp8_hook import FP8Hook from .pp_plugin_base import PipelinePluginBase SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"] @@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): ddp_config: dict, custom_policy: Policy, overlap_allgather: bool = False, + use_fp8: bool = False, ) -> None: self.stage_manager = shard_config.pipeline_stage_manager self.shard_config = shard_config @@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): self.use_dpp = use_ddp self.require_grad_sync = True self.overlap_allgather = overlap_allgather + self.use_fp8 = use_fp8 shardformer = ShardFormer(shard_config) if custom_policy is not None: @@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): module = DDP(module, process_group=dp_group, **ddp_config) super().__init__(module) + self.op_hooks = [] if overlap_allgather: - self.op_hook = ZeroOpHook() + self.op_hooks.append(ZeroOpHook()) + if use_fp8: + self.op_hooks.append(FP8Hook()) + if overlap_allgather or use_fp8: for p in module.parameters(): if p.requires_grad and type(p) is not ColoParameter: p.__class__ = ColoParameter @@ -223,7 +230,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin): wait_all_gather_handle(p) def _wait_all_gather(self): - return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext() + return ( + ColoParamOpHookManager.use_hooks(*self.op_hooks) + if (self.overlap_allgather or self.use_fp8) + else nullcontext() + ) def get_param_info(optim: Optimizer): @@ -1019,6 +1030,7 @@ class 
HybridParallelPlugin(PipelinePluginBase): overlap_p2p: bool = True, overlap_allgather: bool = False, fp8_communication: bool = False, + use_fp8: bool = False, ) -> None: super().__init__() @@ -1063,6 +1075,7 @@ class HybridParallelPlugin(PipelinePluginBase): self.enable_flash_attention = enable_flash_attention self.enable_jit_fused = enable_jit_fused self.enable_sequence_parallelism = enable_sequence_parallelism + self.use_fp8 = use_fp8 if dp_outside: self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3 self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size) @@ -1243,6 +1256,7 @@ class HybridParallelPlugin(PipelinePluginBase): ddp_config=self.ddp_config, custom_policy=self.custom_policy, overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]), + use_fp8=self.use_fp8, ) if optimizer is not None and not isinstance(optimizer, OptimizerWrapper): if zero_stage == 0: diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py index bc8c3ced4..6d777e8a4 100644 --- a/colossalai/quantization/fp8.py +++ b/colossalai/quantization/fp8.py @@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function): if bias is not None: assert bias.dtype == x.dtype, "Bias should have the same dtype as input." # ensure x and w are row-major - assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous." + x = x.contiguous() + w = w.contiguous() ctx.x_shape = x.shape ctx.has_bias = bias is not None ctx.out_dtype = x.dtype diff --git a/colossalai/tensor/colo_parameter.py b/colossalai/tensor/colo_parameter.py index acb9fc4ae..8992b89a3 100644 --- a/colossalai/tensor/colo_parameter.py +++ b/colossalai/tensor/colo_parameter.py @@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter): with torch._C.DisableTorchFunction(): new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values()) args, kwargs = replace_args(args, kwargs, new_args) + with torch._C.DisableTorchFunction(): + func = ColoParamOpHookManager.rewrite_op(func) ret = super().__torch_function__(func, types, args, kwargs) with torch._C.DisableTorchFunction(): ret = ColoParamOpHookManager.post_op(params, ret) diff --git a/colossalai/tensor/param_op_hook.py b/colossalai/tensor/param_op_hook.py index 40de43c43..c8dd5a0c8 100644 --- a/colossalai/tensor/param_op_hook.py +++ b/colossalai/tensor/param_op_hook.py @@ -30,6 +30,9 @@ class ColoParamOpHook(ABC): def post_backward(self, params: List[torch.Tensor]) -> None: pass + def rewrite_op(self, func) -> Any: + return func + class ColoParamOpHookManager: """ @@ -101,6 +104,12 @@ class ColoParamOpHookManager: def has_hook() -> bool: return len(ColoParamOpHookManager.hooks) > 0 + @staticmethod + def rewrite_op(func) -> Any: + for hook in ColoParamOpHookManager.hooks: + func = hook.rewrite_op(func) + return func + class PreFwdPostBwd(torch.autograd.Function): @staticmethod diff --git a/tests/test_fp8/test_fp8_hook.py b/tests/test_fp8/test_fp8_hook.py new file mode 100644 index 000000000..6cc147be7 --- /dev/null +++ b/tests/test_fp8/test_fp8_hook.py @@ -0,0 +1,50 @@ +import pytest +import torch +import torch.nn as nn +import torch.nn.functional as F + +from colossalai.accelerator import get_accelerator +from colossalai.booster.plugin.fp8_hook import FP8Hook +from colossalai.quantization.fp8 import linear_fp8 +from colossalai.tensor.colo_parameter import ColoParameter +from colossalai.tensor.param_op_hook import ColoParamOpHookManager +from colossalai.utils import get_current_device + 
+REPLACED = False +TRIGGERED = False + + +def new_linear_fp8(x, w, bias=None): + global TRIGGERED + TRIGGERED = True + return linear_fp8(x, w, bias) + + +class FP8TestHook(FP8Hook): + def rewrite_op(self, func): + func = super().rewrite_op(func) + if func is linear_fp8: + global REPLACED + REPLACED = True + return new_linear_fp8 + return func + + +D_IN, D_OUT = 16, 32 +B, S = 2, 64 +DTYPE = torch.bfloat16 + + +@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0") +def test_fp8_hook(): + # create tensors + w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE)) + x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True) + w.__class__ = ColoParameter + w.__init__(w, requires_grad=True) + hook = FP8TestHook() + with ColoParamOpHookManager.use_hooks(hook): + o = F.linear(x, w) + assert o.shape == (B, S, D_OUT) + assert REPLACED + assert TRIGGERED
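
For reference, below is a minimal usage sketch of the new `use_fp8` flag (not part of the patch). The model, tensor sizes, and launch setup are placeholders; it assumes the standard `colossalai.launch_from_torch` / `Booster` workflow (run via `torchrun`, e.g. `torchrun --nproc_per_node=1 demo.py`) and a GPU with FP8 support, matching the capability >= 9.0 guard in the test above.

    import colossalai
    import torch
    import torch.nn as nn
    from colossalai.booster import Booster
    from colossalai.booster.plugin import HybridParallelPlugin

    colossalai.launch_from_torch()

    # Placeholder model and optimizer; any module whose weights become ColoParameters works.
    model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024)).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    # use_fp8=True installs FP8Hook, so eligible F.linear calls on ColoParameter weights
    # are rewritten to linear_fp8; fp8_communication remains a separate, orthogonal flag.
    plugin = HybridParallelPlugin(tp_size=1, pp_size=1, precision="bf16", use_fp8=True)
    booster = Booster(plugin=plugin)
    model, optimizer, *_ = booster.boost(model, optimizer)

    # Dimensions are kept multiples of 16 to satisfy the FP8 matmul shape requirements.
    x = torch.rand(16, 1024, device="cuda", dtype=torch.bfloat16)
    loss = model(x).float().mean()
    booster.backward(loss, optimizer)
    optimizer.step()

The sketch only exercises the forward/backward rewrite added here; pipeline or tensor parallel sizes, ZeRO stages, and `fp8_communication` can be combined with `use_fp8` independently, as the plugin changes above suggest.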