mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support gemini plugin (#5978)
* [fp8] refactor hook * [fp8] support gemini plugin * [example] add fp8 option for llama benchmarkpull/5984/head
parent
4b9bec8176
commit
8241c0c054
|
@ -363,6 +363,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
enable_jit_fused: bool = False,
|
||||
enable_sequence_overlap: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
use_fp8: bool = False,
|
||||
verbose: bool = False,
|
||||
fp8_communication: bool = False,
|
||||
) -> None:
|
||||
|
@ -397,6 +398,7 @@ class GeminiPlugin(DPPluginBase):
|
|||
max_prefetch=max_prefetch,
|
||||
enable_async_reduce=enable_async_reduce,
|
||||
fp8_communication=fp8_communication,
|
||||
use_fp8=use_fp8,
|
||||
)
|
||||
self.zero_optim_config = dict(
|
||||
gpu_margin_mem_ratio=gpu_margin_mem_ratio,
|
||||
|
|
|
@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
|||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
@ -40,7 +41,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
|||
from colossalai.zero.low_level import LowLevelZeroOptimizer
|
||||
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
||||
|
||||
from .fp8_hook import FP8Hook
|
||||
from .pp_plugin_base import PipelinePluginBase
|
||||
|
||||
SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
|
||||
|
|
|
@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
|
|||
return x_grad.reshape(ctx.x_shape), w_grad, bias_grad
|
||||
|
||||
|
||||
def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(x, w, bias)
|
||||
def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
return _LinearFp8.apply(input, weight, bias)
|
||||
|
|
|
@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
|
|||
from colossalai.interface import ModelWrapper
|
||||
from colossalai.lazy import LazyTensor
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.d_tensor import (
|
||||
distribute_tensor,
|
||||
|
@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
|
|||
verbose: bool = False,
|
||||
enable_async_reduce: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
assert mixed_precision in (torch.float16, torch.bfloat16)
|
||||
reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
|
||||
|
@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
|
|||
)
|
||||
self.force_outputs_fp32 = force_outputs_fp32
|
||||
self.param_op_hook = GeminiZeROHook(self.gemini_manager)
|
||||
self.hooks = [self.param_op_hook]
|
||||
if use_fp8:
|
||||
self.hooks.append(FP8Hook())
|
||||
self.fp32_params: List[torch.Tensor] = list()
|
||||
self.fp16_params: List[ColoParameter] = list()
|
||||
self.grads_device: Dict[torch.Tensor, torch.device] = dict()
|
||||
|
@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
|
|||
outputs = self._inference_forward(*args, **kwargs)
|
||||
else:
|
||||
self.gemini_manager.pre_iter(*args)
|
||||
with ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with ColoParamOpHookManager.use_hooks(*self.hooks):
|
||||
outputs = self.module(*args, **kwargs)
|
||||
|
||||
if self.force_outputs_fp32:
|
||||
|
@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):
|
|||
|
||||
def _inference_forward(self, *args, **kwargs):
|
||||
"""This function is only triggered for inference."""
|
||||
fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
|
||||
fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
|
||||
if not self.scatter_after_inference:
|
||||
# gather all chunks
|
||||
for chunk in self.chunk_manager.get_chunks(self.fp16_params):
|
||||
|
@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):
|
|||
|
||||
def backward(self, loss: torch.Tensor):
|
||||
self._pre_backward()
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
|
||||
with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
|
||||
loss.backward()
|
||||
self._post_backward()
|
||||
|
||||
|
|
|
@ -99,6 +99,8 @@ def main():
|
|||
parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
|
||||
parser.add_argument("--no_cache", action="store_true")
|
||||
parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
|
||||
parser.add_argument("--overlap_allgather", action="store_true")
|
||||
parser.add_argument("--use_fp8", action="store_true")
|
||||
args = parser.parse_args()
|
||||
|
||||
colossalai.launch_from_torch()
|
||||
|
@ -136,6 +138,7 @@ def main():
|
|||
enable_flash_attention=args.xformers,
|
||||
max_prefetch=args.prefetch_num,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
elif args.plugin == "gemini_auto":
|
||||
plugin = GeminiPlugin(
|
||||
|
@ -148,6 +151,7 @@ def main():
|
|||
max_prefetch=args.prefetch_num,
|
||||
enable_async_reduce=not args.disable_async_reduce,
|
||||
enable_flash_attention=args.xformers,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
elif args.plugin == "fsdp":
|
||||
if use_empty_init:
|
||||
|
@ -207,6 +211,8 @@ def main():
|
|||
dp_outside=False,
|
||||
overlap_p2p=args.overlap,
|
||||
enable_metadata_cache=not args.no_cache,
|
||||
overlap_allgather=args.overlap_allgather,
|
||||
use_fp8=args.use_fp8,
|
||||
**hybrid_kwargs,
|
||||
)
|
||||
elif args.plugin == "3d_cpu":
|
||||
|
@ -223,6 +229,7 @@ def main():
|
|||
initial_scale=2**8,
|
||||
precision="bf16",
|
||||
overlap_p2p=args.overlap,
|
||||
use_fp8=args.use_fp8,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown plugin {args.plugin}")
|
||||
|
|
|
@ -4,8 +4,8 @@ import torch.nn as nn
|
|||
import torch.nn.functional as F
|
||||
|
||||
from colossalai.accelerator import get_accelerator
|
||||
from colossalai.booster.plugin.fp8_hook import FP8Hook
|
||||
from colossalai.quantization.fp8 import linear_fp8
|
||||
from colossalai.quantization.fp8_hook import FP8Hook
|
||||
from colossalai.tensor.colo_parameter import ColoParameter
|
||||
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
||||
from colossalai.utils import get_current_device
|
||||
|
|
Loading…
Reference in New Issue