From 0d3b0bd86498a6de7a38ca5c5905d60a80c9f6a7 Mon Sep 17 00:00:00 2001 From: Hongxin Liu Date: Wed, 21 Aug 2024 10:21:26 +0800 Subject: [PATCH] [plugin] add cast inputs option for zero (#6003) (#6022) --- colossalai/booster/plugin/low_level_zero_plugin.py | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py index cc3755e6f..185d34f12 100644 --- a/colossalai/booster/plugin/low_level_zero_plugin.py +++ b/colossalai/booster/plugin/low_level_zero_plugin.py @@ -61,7 +61,9 @@ class OptimizerParamCheckState(enum.Enum): class LowLevelZeroModel(ModelWrapper, AMPModelMixin): - def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None: + def __init__( + self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True + ) -> None: super().__init__(module) self.dtype = None if precision == "fp16": @@ -73,7 +75,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin): module = module.to(get_accelerator().get_current_device()) self.module = module self.convert_fn = None - if self.dtype is not None: + if self.dtype is not None and cast_inputs: self.convert_fn = partial(_convert_floating_point, dtype=self.dtype) self.overlap_allgather = overlap_allgather if overlap_allgather: @@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase): cpu_offload: bool = False, master_weights: bool = True, verbose: bool = False, + cast_inputs: bool = True, ) -> None: super().__init__() assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training" @@ -361,6 +364,8 @@ class LowLevelZeroPlugin(DPPluginBase): self.lora_enabled = False self.verbose = verbose self.logger = get_dist_logger() + self.cast_inputs = cast_inputs + # set class name with stage, for better error message setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}") @@ -475,7 +480,10 @@ class LowLevelZeroPlugin(DPPluginBase): if not isinstance(model, ModelWrapper): model = LowLevelZeroModel( - model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"] + model, + self.precision, + overlap_allgather=self.zero_optim_kwargs["overlap_allgather"], + cast_inputs=self.cast_inputs, ) # TODO: Support Galore + ZeRO