[fp8] support gemini plugin (#5978)

* [fp8] refactor hook

* [fp8] support gemini plugin

* [example] add fp8 option for llama benchmark
pull/5984/head
Hongxin Liu 2024-08-09 14:09:48 +08:00 committed by GitHub
parent 4b9bec8176
commit 8241c0c054
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 21 additions and 7 deletions

View File

@ -363,6 +363,7 @@ class GeminiPlugin(DPPluginBase):
enable_jit_fused: bool = False,
enable_sequence_overlap: bool = False,
enable_async_reduce: bool = True,
use_fp8: bool = False,
verbose: bool = False,
fp8_communication: bool = False,
) -> None:
@ -397,6 +398,7 @@ class GeminiPlugin(DPPluginBase):
max_prefetch=max_prefetch,
enable_async_reduce=enable_async_reduce,
fp8_communication=fp8_communication,
use_fp8=use_fp8,
)
self.zero_optim_config = dict(
gpu_margin_mem_ratio=gpu_margin_mem_ratio,

View File

@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
from colossalai.shardformer.layer.utils import SeqParallelUtils
from colossalai.shardformer.policies.base_policy import Policy
@ -40,7 +41,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.zero.low_level import LowLevelZeroOptimizer
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
from .fp8_hook import FP8Hook
from .pp_plugin_base import PipelinePluginBase
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]

View File

@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(x, w, bias)
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return _LinearFp8.apply(input, weight, bias)

View File

@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
from colossalai.interface import ModelWrapper
from colossalai.lazy import LazyTensor
from colossalai.logging import get_dist_logger
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.d_tensor import (
distribute_tensor,
@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
verbose: bool = False,
enable_async_reduce: bool = True,
fp8_communication: bool = False,
use_fp8: bool = False,
) -> None:
assert mixed_precision in (torch.float16, torch.bfloat16)
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
)
self.force_outputs_fp32 = force_outputs_fp32
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
self.hooks = [self.param_op_hook]
if use_fp8:
self.hooks.append(FP8Hook())
self.fp32_params: List[torch.Tensor] = list()
self.fp16_params: List[ColoParameter] = list()
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
outputs = self._inference_forward(*args, **kwargs)
else:
self.gemini_manager.pre_iter(*args)
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
with ColoParamOpHookManager.use_hooks(*self.hooks):
outputs = self.module(*args, **kwargs)
if self.force_outputs_fp32:
@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):
def _inference_forward(self, *args, **kwargs):
"""This function is only triggered for inference."""
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
if not self.scatter_after_inference:
# gather all chunks
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):
def backward(self, loss: torch.Tensor):
self._pre_backward()
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
loss.backward()
self._post_backward()

View File

@ -99,6 +99,8 @@ def main():
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
parser.add_argument("--no_cache", action="store_true")
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
parser.add_argument("--overlap_allgather", action="store_true")
parser.add_argument("--use_fp8", action="store_true")
args = parser.parse_args()
colossalai.launch_from_torch()
@ -136,6 +138,7 @@ def main():
enable_flash_attention=args.xformers,
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
use_fp8=args.use_fp8,
)
elif args.plugin == "gemini_auto":
plugin = GeminiPlugin(
@ -148,6 +151,7 @@ def main():
max_prefetch=args.prefetch_num,
enable_async_reduce=not args.disable_async_reduce,
enable_flash_attention=args.xformers,
use_fp8=args.use_fp8,
)
elif args.plugin == "fsdp":
if use_empty_init:
@ -207,6 +211,8 @@ def main():
dp_outside=False,
overlap_p2p=args.overlap,
enable_metadata_cache=not args.no_cache,
overlap_allgather=args.overlap_allgather,
use_fp8=args.use_fp8,
**hybrid_kwargs,
)
elif args.plugin == "3d_cpu":
@ -223,6 +229,7 @@ def main():
initial_scale=2**8,
precision="bf16",
overlap_p2p=args.overlap,
use_fp8=args.use_fp8,
)
else:
raise ValueError(f"Unknown plugin {args.plugin}")

View File

@ -4,8 +4,8 @@ import torch.nn as nn
import torch.nn.functional as F
from colossalai.accelerator import get_accelerator
from colossalai.booster.plugin.fp8_hook import FP8Hook
from colossalai.quantization.fp8 import linear_fp8
from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device