mirror of https://github.com/hpcaitech/ColossalAI
[fp8] zero support fp8 linear. (#6006)
* fix
* fix
* fix
* zero fp8
* zero fp8
* Update requirements.txt

pull/6024/head
parent 3f09a6145f
commit 0a51319113
@@ -35,6 +35,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
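The only change in this hunk is the new FP8Hook import. Conceptually, FP8 linear support means routing matmul operands through 8-bit floating point. Below is a minimal round-trip sketch of that idea; it is illustrative only, not the hook's actual implementation, and assumes a PyTorch build that ships the float8 dtypes.

import torch

def fp8_roundtrip(x: torch.Tensor) -> torch.Tensor:
    # Scale so the largest magnitude lands near the e4m3 max (~448), cast to fp8,
    # then cast back and undo the scale. A real FP8 linear path would keep the fp8
    # operands for the matmul instead of immediately dequantizing like this sketch.
    scale = x.abs().max().clamp(min=1e-12) / 448.0
    x_fp8 = (x / scale).to(torch.float8_e4m3fn)
    return x_fp8.to(x.dtype) * scale

w = torch.randn(1024, 1024)
print((w - fp8_roundtrip(w)).abs().max())  # quantization error introduced by the fp8 cast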
@@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,11 +77,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
+        self.use_fp8 = use_fp8
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
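The hooks collected in op_hooks only take effect once they are installed around parameter access. Below is a hedged sketch of that activation pattern using ColoParamOpHookManager (imported in the first hunk); it mirrors how the wrapper's forward is expected to use the list, not the exact code of this file.

from contextlib import nullcontext

from colossalai.tensor.param_op_hook import ColoParamOpHookManager

def run_forward_with_hooks(module, op_hooks, *args, **kwargs):
    # With the hooks active, every parameter access inside forward is intercepted:
    # ZeroOpHook can all-gather ZeRO-sharded params, FP8Hook can reroute linear ops to fp8.
    ctx = ColoParamOpHookManager.use_hooks(*op_hooks) if op_hooks else nullcontext()
    with ctx:
        return module(*args, **kwargs)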
@@ -335,6 +343,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         master_weights: bool = True,
         verbose: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -362,6 +371,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.use_fp8 = use_fp8
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -476,7 +486,10 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                use_fp8=self.use_fp8,
             )
 
         # TODO: Support Galore + ZeRO
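End to end, the flag travels LowLevelZeroPlugin(use_fp8=...) -> LowLevelZeroModel(use_fp8=...) -> FP8Hook. Below is a hedged usage sketch of enabling it through the Booster API; the argument values are illustrative rather than taken from this PR, and it assumes a distributed launch (e.g. via torchrun) on a ColossalAI version that includes this change.

import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()

# use_fp8 turns on the FP8Hook for linear layers; the pre-existing fp8_communication
# flag separately casts ZeRO collective communication to fp8.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)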
@@ -259,7 +259,6 @@ def main():
         if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin))
         else nullcontext()
     )
 
     init_kwargs = {}
     if config.model_type == "chatglm":
         init_kwargs["empty_init"] = False
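For context, this benchmark hunk sits where the script picks a weight-initialization context per plugin: only Gemini and hybrid-parallel plugins get lazy init, and chatglm models additionally need empty_init=False. A hedged sketch of that selection pattern follows; the surrounding benchmark code is not part of this hunk.

from contextlib import nullcontext

from colossalai.booster.plugin import GeminiPlugin, HybridParallelPlugin
from colossalai.lazy import LazyInitContext

def choose_init_ctx(plugin):
    # Lazy init defers weight materialization so sharding plugins can place params directly.
    if isinstance(plugin, (GeminiPlugin, HybridParallelPlugin)):
        return LazyInitContext()
    return nullcontext()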