mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support fp8 amp for hybrid parallel plugin (#5975)
* [fp8] support fp8 amp for hybrid parallel plugin
* [test] add fp8 hook test
* [fp8] fix fp8 linear compatibility
parent 76ea16466f
commit ccabcf6485
@@ -0,0 +1,23 @@
+import torch.nn.functional as F
+
+from colossalai.quantization.fp8 import linear_fp8
+from colossalai.tensor.param_op_hook import ColoParamOpHook
+
+
+class FP8Hook(ColoParamOpHook):
+    def pre_forward(self, params) -> None:
+        pass
+
+    def post_forward(self, params) -> None:
+        pass
+
+    def pre_backward(self, params) -> None:
+        pass
+
+    def post_backward(self, params) -> None:
+        pass
+
+    def rewrite_op(self, func):
+        if func is F.linear:
+            return linear_fp8
+        return func
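The new FP8Hook (colossalai/booster/plugin/fp8_hook.py, judging by the import paths used elsewhere in this diff) keeps all pre/post callbacks as no-ops; its only behavior is rewriting F.linear to linear_fp8 while the hook is active. A minimal usage sketch, mirroring the test added at the end of this commit and assuming an FP8-capable accelerator (compute capability >= 9.0):

import torch
import torch.nn as nn
import torch.nn.functional as F

from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

# weights must be ColoParameter for the hook machinery to intercept the op
w = nn.Parameter(torch.rand(32, 16, device="cuda", dtype=torch.bfloat16))
w.__class__ = ColoParameter
w.__init__(w, requires_grad=True)  # as done in the new test
x = torch.rand(2, 16, device="cuda", dtype=torch.bfloat16)

with ColoParamOpHookManager.use_hooks(FP8Hook()):
    y = F.linear(x, w)  # rewritten to linear_fp8 by FP8Hook.rewrite_op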
@@ -40,6 +40,7 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
+from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase
 
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
@@ -66,6 +67,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         ddp_config: dict,
         custom_policy: Policy,
         overlap_allgather: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         self.stage_manager = shard_config.pipeline_stage_manager
         self.shard_config = shard_config
@@ -75,6 +77,7 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
         self.use_dpp = use_ddp
         self.require_grad_sync = True
         self.overlap_allgather = overlap_allgather
+        self.use_fp8 = use_fp8
 
         shardformer = ShardFormer(shard_config)
         if custom_policy is not None:
@@ -112,8 +115,12 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             module = DDP(module, process_group=dp_group, **ddp_config)
 
         super().__init__(module)
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
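The constructor now collects every active op hook in self.op_hooks and re-classes trainable parameters as ColoParameter whenever any hook is in use. A hedged illustration of why the cast matters: hook callbacks and rewrite_op fire from ColoParameter.__torch_function__, so ops on plain nn.Parameter weights would bypass them entirely, and the in-place __class__ swap avoids copying any data.

import torch
import torch.nn as nn

from colossalai.tensor.colo_parameter import ColoParameter

p = nn.Parameter(torch.empty(8, 8))
if p.requires_grad and type(p) is not ColoParameter:
    p.__class__ = ColoParameter  # in-place re-class, no data copy
assert isinstance(p, ColoParameter)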
@@ -223,7 +230,11 @@ class HybridParallelModule(ModelWrapper, AMPModelMixin):
             wait_all_gather_handle(p)
 
     def _wait_all_gather(self):
-        return ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
+        return (
+            ColoParamOpHookManager.use_hooks(*self.op_hooks)
+            if (self.overlap_allgather or self.use_fp8)
+            else nullcontext()
+        )
 
 
 def get_param_info(optim: Optimizer):
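_wait_all_gather now returns a context that activates every registered hook (ZeroOpHook and/or FP8Hook) instead of only the all-gather hook; inside HybridParallelModule this context presumably wraps the forward pass, though that wrapper is not part of this diff. A hedged, stand-alone illustration of the same pattern:

from contextlib import nullcontext

from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHookManager

# pick the hook context only when at least one hook is wanted
op_hooks, use_fp8, overlap_allgather = [FP8Hook()], True, False
ctx = (
    ColoParamOpHookManager.use_hooks(*op_hooks)
    if (overlap_allgather or use_fp8)
    else nullcontext()
)
with ctx:
    pass  # parameter ops executed here are routed through the active hooks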
@@ -1019,6 +1030,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         overlap_p2p: bool = True,
         overlap_allgather: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
 
@@ -1063,6 +1075,7 @@ class HybridParallelPlugin(PipelinePluginBase):
         self.enable_flash_attention = enable_flash_attention
         self.enable_jit_fused = enable_jit_fused
         self.enable_sequence_parallelism = enable_sequence_parallelism
+        self.use_fp8 = use_fp8
         if dp_outside:
             self.dp_axis, self.pp_axis, self.tp_axis, self.sp_axis = 0, 1, 2, 3
             self.pg_mesh = ProcessGroupMesh(self.dp_size, self.pp_size, self.tp_size, self.sp_size)
@@ -1243,6 +1256,7 @@ class HybridParallelPlugin(PipelinePluginBase):
                 ddp_config=self.ddp_config,
                 custom_policy=self.custom_policy,
                 overlap_allgather=(self.zero_stage > 0 and self.zero_config["overlap_allgather"]),
+                use_fp8=self.use_fp8,
             )
         if optimizer is not None and not isinstance(optimizer, OptimizerWrapper):
            if zero_stage == 0:
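The three HybridParallelPlugin hunks above thread a new use_fp8 flag from the plugin constructor down to HybridParallelModule, alongside the pre-existing fp8_communication flag (which targets collective communication rather than the linear-layer GEMM). A hedged end-to-end sketch of enabling the new flag from user code, assuming a torchrun launch and a recent API where launch_from_torch takes no config; the tiny model and optimizer are placeholders:

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

colossalai.launch_from_torch()
plugin = HybridParallelPlugin(
    tp_size=1,
    pp_size=1,
    precision="bf16",
    use_fp8=True,               # FP8 matmul in linear layers via FP8Hook
    fp8_communication=False,    # separate knob for FP8 collectives
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 32)
optimizer = torch.optim.AdamW(model.parameters())
model, optimizer, *_ = booster.boost(model, optimizer)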
@@ -431,7 +431,8 @@ class _LinearFp8(torch.autograd.Function):
         if bias is not None:
             assert bias.dtype == x.dtype, "Bias should have the same dtype as input."
         # ensure x and w are row-major
-        assert x.is_contiguous() and w.is_contiguous(), "Input and weight should be contiguous."
+        x = x.contiguous()
+        w = w.contiguous()
         ctx.x_shape = x.shape
         ctx.has_bias = bias is not None
         ctx.out_dtype = x.dtype
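This hunk in the FP8 linear kernel (colossalai.quantization.fp8, per the imports) replaces a hard assertion with contiguous() calls, which appears to be the "fix fp8 linear compatibility" item in the commit message: non-contiguous inputs such as transposed views are now copied into row-major layout instead of crashing. A small illustration of the difference:

import torch

w = torch.rand(16, 32).t()      # a transposed view is a valid F.linear weight but not row-major
assert not w.is_contiguous()    # this used to trip the old assertion
w = w.contiguous()              # now copied into row-major layout before the FP8 GEMM
assert w.is_contiguous()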
@@ -61,6 +61,8 @@ class ColoParameter(ColoTensor, torch.nn.Parameter):
                 with torch._C.DisableTorchFunction():
                     new_args = ColoParamOpHookManager.pre_op(params, *args, *kwargs.values())
                 args, kwargs = replace_args(args, kwargs, new_args)
+                with torch._C.DisableTorchFunction():
+                    func = ColoParamOpHookManager.rewrite_op(func)
                 ret = super().__torch_function__(func, types, args, kwargs)
                 with torch._C.DisableTorchFunction():
                     ret = ColoParamOpHookManager.post_op(params, ret)
@@ -30,6 +30,9 @@ class ColoParamOpHook(ABC):
     def post_backward(self, params: List[torch.Tensor]) -> None:
         pass
 
+    def rewrite_op(self, func) -> Any:
+        return func
+
 
 class ColoParamOpHookManager:
     """
@@ -101,6 +104,12 @@ class ColoParamOpHookManager:
     def has_hook() -> bool:
         return len(ColoParamOpHookManager.hooks) > 0
 
+    @staticmethod
+    def rewrite_op(func) -> Any:
+        for hook in ColoParamOpHookManager.hooks:
+            func = hook.rewrite_op(func)
+        return func
+
 
 class PreFwdPostBwd(torch.autograd.Function):
     @staticmethod
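Together, the param_op_hook.py hunks add a rewrite_op extension point: the base hook returns func unchanged, the manager threads func through every active hook in registration order, and ColoParameter.__torch_function__ applies the rewritten op before dispatch. Hooks therefore compose; the test below relies on exactly this by subclassing FP8Hook. A hedged sketch of composition with a hypothetical LoggingHook that wraps whatever callable the previous hooks produced:

from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHook, ColoParamOpHookManager


class LoggingHook(ColoParamOpHook):
    def pre_forward(self, params) -> None:
        pass

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass

    def rewrite_op(self, func):
        def wrapped(*args, **kwargs):
            print("dispatching", getattr(func, "__name__", func))
            return func(*args, **kwargs)

        return wrapped


# FP8Hook rewrites F.linear -> linear_fp8 first, then LoggingHook wraps the result.
with ColoParamOpHookManager.use_hooks(FP8Hook(), LoggingHook()):
    pass  # parameter ops issued here go through the composed rewrite chain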
@@ -0,0 +1,50 @@
+import pytest
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from colossalai.accelerator import get_accelerator
+from colossalai.booster.plugin.fp8_hook import FP8Hook
+from colossalai.quantization.fp8 import linear_fp8
+from colossalai.tensor.colo_parameter import ColoParameter
+from colossalai.tensor.param_op_hook import ColoParamOpHookManager
+from colossalai.utils import get_current_device
+
+REPLACED = False
+TRIGGERED = False
+
+
+def new_linear_fp8(x, w, bias=None):
+    global TRIGGERED
+    TRIGGERED = True
+    return linear_fp8(x, w, bias)
+
+
+class FP8TestHook(FP8Hook):
+    def rewrite_op(self, func):
+        func = super().rewrite_op(func)
+        if func is linear_fp8:
+            global REPLACED
+            REPLACED = True
+            return new_linear_fp8
+        return func
+
+
+D_IN, D_OUT = 16, 32
+B, S = 2, 64
+DTYPE = torch.bfloat16
+
+
+@pytest.mark.skipif(get_accelerator().get_device_capability()[0] < 9, reason="Test requires device capability >= 9.0")
+def test_fp8_hook():
+    # create tensors
+    w = nn.Parameter(torch.rand(D_OUT, D_IN, device=get_current_device(), dtype=DTYPE))
+    x = torch.rand(B, S, D_IN, device=get_current_device(), dtype=DTYPE, requires_grad=True)
+    w.__class__ = ColoParameter
+    w.__init__(w, requires_grad=True)
+    hook = FP8TestHook()
+    with ColoParamOpHookManager.use_hooks(hook):
+        o = F.linear(x, w)
+    assert o.shape == (B, S, D_OUT)
+    assert REPLACED
+    assert TRIGGERED