diff --git a/colossalai/booster/plugin/gemini_plugin.py b/colossalai/booster/plugin/gemini_plugin.py
index 5ab8f05ad..d0d3275cf 100644
--- a/colossalai/booster/plugin/gemini_plugin.py
+++ b/colossalai/booster/plugin/gemini_plugin.py
@@ -363,6 +363,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
+        use_fp8: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:
@@ -397,6 +398,7 @@ class GeminiPlugin(DPPluginBase):
             max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
+            use_fp8=use_fp8,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
diff --git a/colossalai/booster/plugin/hybrid_parallel_plugin.py b/colossalai/booster/plugin/hybrid_parallel_plugin.py
index f585abdb5..9c0ee9e50 100644
--- a/colossalai/booster/plugin/hybrid_parallel_plugin.py
+++ b/colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
@@ -40,7 +41,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
 
-from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase
 
 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
diff --git a/colossalai/quantization/fp8.py b/colossalai/quantization/fp8.py
index e933680a9..53febd16c 100644
--- a/colossalai/quantization/fp8.py
+++ b/colossalai/quantization/fp8.py
@@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
 
 
-def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(x, w, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
diff --git a/colossalai/booster/plugin/fp8_hook.py b/colossalai/quantization/fp8_hook.py
similarity index 100%
rename from colossalai/booster/plugin/fp8_hook.py
rename to colossalai/quantization/fp8_hook.py
diff --git a/colossalai/zero/gemini/gemini_ddp.py b/colossalai/zero/gemini/gemini_ddp.py
index 0b2039a4d..dbaae6610 100644
--- a/colossalai/zero/gemini/gemini_ddp.py
+++ b/colossalai/zero/gemini/gemini_ddp.py
@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
@@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)
 
         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):
 
     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):
 
     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()
 
diff --git a/examples/language/llama/benchmark.py b/examples/language/llama/benchmark.py
index 2bd9671d8..07583161b 100644
--- a/examples/language/llama/benchmark.py
+++ b/examples/language/llama/benchmark.py
@@ -99,6 +99,8 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true")
     args = parser.parse_args()
 
     colossalai.launch_from_torch()
@@ -136,6 +138,7 @@
             enable_flash_attention=args.xformers,
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -148,6 +151,7 @@
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -207,6 +211,8 @@
             dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
@@ -223,6 +229,7 @@
             initial_scale=2**8,
             precision="bf16",
             overlap_p2p=args.overlap,
+            use_fp8=args.use_fp8,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
diff --git a/tests/test_fp8/test_fp8_hook.py b/tests/test_fp8/test_fp8_hook.py
index 6cc147be7..abd5d09e1 100644
--- a/tests/test_fp8/test_fp8_hook.py
+++ b/tests/test_fp8/test_fp8_hook.py
@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F
 
 from colossalai.accelerator import get_accelerator
-from colossalai.booster.plugin.fp8_hook import FP8Hook
 from colossalai.quantization.fp8 import linear_fp8
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
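
Usage sketch (not part of the patch): a minimal example of driving the new `use_fp8` switch from user code, assuming the standard Booster workflow with a HybridAdam optimizer; the toy model, hyperparameters, and layer sizes below are illustrative assumptions only, and an FP8-capable GPU plus a torchrun launch are assumed.

    # Minimal sketch: enable FP8 linear layers under the Gemini plugin via `use_fp8`.
    # Launch with torchrun so that launch_from_torch can read the distributed env.
    import torch
    import colossalai
    from colossalai.booster import Booster
    from colossalai.booster.plugin import GeminiPlugin
    from colossalai.nn.optimizer import HybridAdam

    colossalai.launch_from_torch()

    # With use_fp8=True, GeminiDDP registers FP8Hook alongside its GeminiZeROHook,
    # so eligible F.linear calls can be dispatched to linear_fp8 (see the test above).
    plugin = GeminiPlugin(use_fp8=True, fp8_communication=False)
    booster = Booster(plugin=plugin)

    # Toy model/optimizer as placeholders for a real workload.
    model = torch.nn.Sequential(
        torch.nn.Linear(1024, 1024),
        torch.nn.GELU(),
        torch.nn.Linear(1024, 1024),
    )
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model, optimizer, *_ = booster.boost(model, optimizer)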