mirror of https://github.com/hpcaitech/ColossalAI
[fp8] support gemini plugin (#5978)
* [fp8] refactor hook
* [fp8] support gemini plugin
* [example] add fp8 option for llama benchmark

pull/5984/head
parent 4b9bec8176
commit 8241c0c054
@@ -363,6 +363,7 @@ class GeminiPlugin(DPPluginBase):
         enable_jit_fused: bool = False,
         enable_sequence_overlap: bool = False,
         enable_async_reduce: bool = True,
+        use_fp8: bool = False,
         verbose: bool = False,
         fp8_communication: bool = False,
     ) -> None:
@@ -397,6 +398,7 @@ class GeminiPlugin(DPPluginBase):
             max_prefetch=max_prefetch,
             enable_async_reduce=enable_async_reduce,
             fp8_communication=fp8_communication,
+            use_fp8=use_fp8,
         )
         self.zero_optim_config = dict(
             gpu_margin_mem_ratio=gpu_margin_mem_ratio,
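Taken together, the two hunks above expose FP8 matmuls through the plugin constructor and forward the flag into the Gemini config. A minimal usage sketch, assuming a torchrun-launched process and a typical Booster setup; the model, optimizer, and the other keyword arguments are illustrative, only use_fp8 (and the pre-existing fp8_communication switch) relate to this change:

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

colossalai.launch_from_torch()  # Gemini assumes a torchrun-launched distributed context

# Illustrative model/optimizer.
model = torch.nn.Linear(1024, 1024)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

plugin = GeminiPlugin(
    precision="bf16",          # illustrative existing option
    use_fp8=True,              # new: run supported matmuls in FP8 via FP8Hook
    fp8_communication=False,   # existing, separate switch for FP8 collectives
)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)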
@@ -31,6 +31,7 @@ from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
 from colossalai.shardformer.layer.utils import SeqParallelUtils
 from colossalai.shardformer.policies.base_policy import Policy
@@ -40,7 +41,6 @@ from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero.low_level import LowLevelZeroOptimizer
 from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle

-from .fp8_hook import FP8Hook
 from .pp_plugin_base import PipelinePluginBase

 SUPPORT_SP_MODE = ["split_gather", "ring", "all_to_all"]
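The two hunks above relocate FP8Hook from the plugin-local module into colossalai.quantization, so the hybrid-parallel plugin and Gemini can share one implementation. Downstream code that imported the old path would switch accordingly; a one-line sketch:

# Old plugin-local location, removed by this refactor:
# from colossalai.booster.plugin.fp8_hook import FP8Hook
# New shared location introduced here:
from colossalai.quantization.fp8_hook import FP8Hook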
@@ -652,5 +652,5 @@ class _LinearFp8(torch.autograd.Function):
         return x_grad.reshape(ctx.x_shape), w_grad, bias_grad


-def linear_fp8(x: torch.Tensor, w: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
-    return _LinearFp8.apply(x, w, bias)
+def linear_fp8(input: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None) -> torch.Tensor:
+    return _LinearFp8.apply(input, weight, bias)
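Renaming the parameters to input/weight makes linear_fp8 signature-compatible with torch.nn.functional.linear, which is what lets a hook substitute one for the other transparently. A small sanity-check sketch, assuming an FP8-capable GPU; dtypes and tolerances here are illustrative, not taken from the repository's tests:

import torch
import torch.nn.functional as F

from colossalai.quantization.fp8 import linear_fp8

# Same (input, weight, bias) calling convention as F.linear after this change.
x = torch.rand(8, 32, device="cuda", dtype=torch.bfloat16)
w = torch.rand(16, 32, device="cuda", dtype=torch.bfloat16)
b = torch.rand(16, device="cuda", dtype=torch.bfloat16)

ref = F.linear(x, w, b)
out = linear_fp8(x, w, b)

# FP8 accumulation is lossy; loose, illustrative tolerances.
assert out.shape == ref.shape
assert torch.allclose(out.float(), ref.float(), rtol=0.1, atol=0.1)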
@@ -15,6 +15,7 @@ from colossalai.checkpoint_io.utils import StateDictSharder, gather_distributed_
 from colossalai.interface import ModelWrapper
 from colossalai.lazy import LazyTensor
 from colossalai.logging import get_dist_logger
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.d_tensor import (
     distribute_tensor,
@@ -99,6 +100,7 @@ class GeminiDDP(ModelWrapper):
         verbose: bool = False,
         enable_async_reduce: bool = True,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         assert mixed_precision in (torch.float16, torch.bfloat16)
         reuse_fp16_chunk = master_weights if not enable_gradient_accumulation else False
@@ -138,6 +140,9 @@ class GeminiDDP(ModelWrapper):
         )
         self.force_outputs_fp32 = force_outputs_fp32
         self.param_op_hook = GeminiZeROHook(self.gemini_manager)
+        self.hooks = [self.param_op_hook]
+        if use_fp8:
+            self.hooks.append(FP8Hook())
         self.fp32_params: List[torch.Tensor] = list()
         self.fp16_params: List[ColoParameter] = list()
         self.grads_device: Dict[torch.Tensor, torch.device] = dict()
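GeminiDDP now keeps a list of parameter-op hooks instead of a single one: the Gemini ZeRO hook is always present, and FP8Hook is appended only when use_fp8 is set. A minimal sketch of what another hook in such a list could look like, assuming the ColoParamOpHook base class exposes the four pre/post forward/backward callbacks (an assumption; check the base class before relying on it):

from colossalai.tensor.param_op_hook import ColoParamOpHook


class CountingHook(ColoParamOpHook):
    """Hypothetical hook that counts how many parameter ops it observes."""

    def __init__(self):
        super().__init__()
        self.num_ops = 0

    # Assumed interface: called around every op that touches ColoParameters.
    def pre_forward(self, params) -> None:
        self.num_ops += 1

    def post_forward(self, params) -> None:
        pass

    def pre_backward(self, params) -> None:
        pass

    def post_backward(self, params) -> None:
        pass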
@@ -310,7 +315,7 @@ class GeminiDDP(ModelWrapper):
             outputs = self._inference_forward(*args, **kwargs)
         else:
             self.gemini_manager.pre_iter(*args)
-            with ColoParamOpHookManager.use_hooks(self.param_op_hook):
+            with ColoParamOpHookManager.use_hooks(*self.hooks):
                 outputs = self.module(*args, **kwargs)

         if self.force_outputs_fp32:
@@ -319,7 +324,7 @@ class GeminiDDP(ModelWrapper):

     def _inference_forward(self, *args, **kwargs):
         """This function is only triggered for inference."""
-        fwd_ctx = ColoParamOpHookManager.use_hooks(self.param_op_hook)
+        fwd_ctx = ColoParamOpHookManager.use_hooks(*self.hooks)
         if not self.scatter_after_inference:
             # gather all chunks
             for chunk in self.chunk_manager.get_chunks(self.fp16_params):
@@ -372,7 +377,7 @@ class GeminiDDP(ModelWrapper):

     def backward(self, loss: torch.Tensor):
         self._pre_backward()
-        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(self.param_op_hook):
+        with self.param_op_hook.switch_to_backward(), ColoParamOpHookManager.use_hooks(*self.hooks):
             loss.backward()
         self._post_backward()

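The three hunks above are the consumption side of the new hook list: training forward, inference forward, and backward all enter ColoParamOpHookManager.use_hooks with the whole list, so FP8Hook (when enabled) wraps exactly the regions the Gemini hook already covered. A condensed sketch of that pattern outside GeminiDDP, reusing the hypothetical CountingHook from above and relying only on use_hooks accepting multiple hooks, which the diff itself shows:

from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.param_op_hook import ColoParamOpHookManager


def run_step(model, inputs, labels, criterion, use_fp8: bool = False):
    # Build the hook list the same way GeminiDDP now does.
    hooks = [CountingHook()]      # stand-in for the always-present Gemini ZeRO hook
    if use_fp8:
        hooks.append(FP8Hook())

    with ColoParamOpHookManager.use_hooks(*hooks):
        loss = criterion(model(inputs), labels)  # hooks fire on ColoParameter ops
        loss.backward()                          # GeminiDDP also wraps backward in the same hooks
    return loss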
@@ -99,6 +99,8 @@ def main():
     parser.add_argument("--prefetch_num", type=int, default=0, help="chunk prefetch max number")
     parser.add_argument("--no_cache", action="store_true")
     parser.add_argument("--use_fp8_comm", action="store_true", default=False, help="for using fp8 during communication")
+    parser.add_argument("--overlap_allgather", action="store_true")
+    parser.add_argument("--use_fp8", action="store_true")
     args = parser.parse_args()

     colossalai.launch_from_torch()
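The benchmark gains two boolean switches. A self-contained sketch of how they behave; the help strings are illustrative, not from the script:

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--overlap_allgather", action="store_true",
                    help="illustrative: overlap all-gather with compute where the plugin supports it")
parser.add_argument("--use_fp8", action="store_true",
                    help="illustrative: run supported matmuls in FP8 via FP8Hook")

args = parser.parse_args([])              # no CLI args: both default to False
assert args.use_fp8 is False and args.overlap_allgather is False

args = parser.parse_args(["--use_fp8"])   # store_true flips the flag when the switch is passed
assert args.use_fp8 is True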
@@ -136,6 +138,7 @@ def main():
             enable_flash_attention=args.xformers,
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "gemini_auto":
         plugin = GeminiPlugin(
@@ -148,6 +151,7 @@ def main():
             max_prefetch=args.prefetch_num,
             enable_async_reduce=not args.disable_async_reduce,
             enable_flash_attention=args.xformers,
+            use_fp8=args.use_fp8,
         )
     elif args.plugin == "fsdp":
         if use_empty_init:
@@ -207,6 +211,8 @@ def main():
             dp_outside=False,
             overlap_p2p=args.overlap,
             enable_metadata_cache=not args.no_cache,
+            overlap_allgather=args.overlap_allgather,
+            use_fp8=args.use_fp8,
             **hybrid_kwargs,
         )
     elif args.plugin == "3d_cpu":
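The "3d" branch also threads both new switches into its plugin. A hedged construction sketch, assuming this branch builds a HybridParallelPlugin (as dp_outside/overlap_p2p/hybrid_kwargs suggest) and that it accepts these keywords; the parallel sizes and precision are illustrative:

from colossalai.booster.plugin import HybridParallelPlugin

plugin = HybridParallelPlugin(
    tp_size=2,                  # illustrative tensor-parallel degree
    pp_size=2,                  # illustrative pipeline-parallel degree
    precision="bf16",           # illustrative
    overlap_allgather=True,     # threaded from --overlap_allgather
    use_fp8=True,               # threaded from --use_fp8
)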
@@ -223,6 +229,7 @@ def main():
             initial_scale=2**8,
             precision="bf16",
             overlap_p2p=args.overlap,
+            use_fp8=args.use_fp8,
         )
     else:
         raise ValueError(f"Unknown plugin {args.plugin}")
@@ -4,8 +4,8 @@ import torch.nn as nn
 import torch.nn.functional as F

 from colossalai.accelerator import get_accelerator
-from colossalai.booster.plugin.fp8_hook import FP8Hook
 from colossalai.quantization.fp8 import linear_fp8
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.utils import get_current_device
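These imports come from the FP8Hook test, updated to the new module path. A rough sketch of how such a test can exercise the hook end to end, under stated assumptions: ColoParameter can wrap an existing nn.Parameter, the hook machinery only fires for ColoParameter operands, and an FP8-capable GPU is available. This is not the repository's actual test.

import torch
import torch.nn as nn

from colossalai.quantization.fp8_hook import FP8Hook
from colossalai.tensor.colo_parameter import ColoParameter
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
from colossalai.utils import get_current_device


def run_linear_under_fp8_hook():
    device = get_current_device()
    model = nn.Linear(16, 32).to(torch.bfloat16).to(device)
    # Assumption: wrapping parameters as ColoParameter routes their ops
    # through the param-op hook machinery (and hence through FP8Hook).
    model.weight = ColoParameter(model.weight)
    model.bias = ColoParameter(model.bias)

    x = torch.rand(2, 16, dtype=torch.bfloat16, device=device)
    with ColoParamOpHookManager.use_hooks(FP8Hook()):
        out = model(x)          # F.linear is expected to dispatch to linear_fp8 here
    assert out.shape == (2, 32)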