diff --git a/colossalai/booster/plugin/low_level_zero_plugin.py b/colossalai/booster/plugin/low_level_zero_plugin.py
index 4188491c2..4082ffada 100644
--- a/colossalai/booster/plugin/low_level_zero_plugin.py
+++ b/colossalai/booster/plugin/low_level_zero_plugin.py
@@ -63,7 +63,7 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True, use_fp8: bool = False
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -77,7 +77,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.module = module
         self.convert_fn = None
         self.use_fp8 = use_fp8
-        if self.dtype is not None:
+        if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         self.op_hooks = []
@@ -342,6 +342,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
@@ -372,6 +373,8 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
+        self.cast_inputs = cast_inputs
+        self.use_fp8 = use_fp8
 
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -490,6 +493,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
            )
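
For context, a minimal usage sketch of the new flag (assuming a standard Booster workflow; the model, optimizer, and distributed launch are placeholders and omitted or stubbed here). With `cast_inputs=False`, `LowLevelZeroModel` leaves `convert_fn` as `None`, so forward inputs keep their original dtype instead of being cast to the plugin's low-precision dtype.

```python
import torch.nn as nn
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Toy model purely for illustration.
model = nn.Linear(16, 16)

# cast_inputs=False skips the _convert_floating_point input cast even though
# precision="bf16" still drives the ZeRO optimizer's mixed-precision handling.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", cast_inputs=False)
booster = Booster(plugin=plugin)

# In a real run, colossalai.launch_from_torch(...) must be called first, then:
# model, optimizer, *_ = booster.boost(model, optimizer)
```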