Merge pull request #6024 from wangbluo/fix_merge

[fp8] merge
pull/6026/head
Wang Binluo 2024-08-22 11:07:04 +08:00 committed by GitHub
commit afe845ff15
1 changed file with 11 additions and 2 deletions


@@ -63,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
     def __init__(
-        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+        self,
+        module: nn.Module,
+        precision: str,
+        overlap_allgather: bool = False,
+        cast_inputs: bool = True,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__(module)
         self.dtype = None
@@ -77,7 +82,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         self.module = module
         self.convert_fn = None
         self.use_fp8 = use_fp8
-        if self.dtype is not None:
+        if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         self.op_hooks = []
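
The gate above is the core of the change: with cast_inputs=True (the default, matching the old behaviour) the wrapper builds convert_fn so floating-point inputs are cast to the model's low-precision dtype before the forward pass; with cast_inputs=False the inputs are passed through untouched. A minimal sketch of that behaviour, assuming _convert_floating_point simply casts floating-point tensors to the target dtype (the wrapper class and demo below are illustrative, not the library code):

from functools import partial

import torch


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # Illustrative stand-in for the helper referenced in the diff: cast only
    # floating-point tensors, leave ints/bools/non-tensors untouched.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


class TinyLowPrecisionWrapper:
    # Hypothetical mini-wrapper mirroring the gate in LowLevelZeroModel.__init__.
    def __init__(self, module: torch.nn.Module, dtype, cast_inputs: bool = True):
        self.module = module
        self.convert_fn = None
        # Same condition as the diff: cast inputs only if a target dtype exists
        # and the caller asked for it.
        if dtype is not None and cast_inputs:
            self.convert_fn = partial(_convert_floating_point, dtype=dtype)

    def __call__(self, *args):
        if self.convert_fn is not None:
            args = tuple(self.convert_fn(a) for a in args)
        return self.module(*args)


layer = torch.nn.Linear(4, 4).half()
x = torch.randn(2, 4)  # fp32 input
wrapped = TinyLowPrecisionWrapper(layer, dtype=torch.float16, cast_inputs=True)
print(wrapped(x).dtype)  # torch.float16
# With cast_inputs=False the fp32 tensor would reach the fp16 module as-is,
# which is what a caller wants when it manages dtype conversion itself.
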
@@ -342,6 +347,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
         fp8_communication: bool = False,
         use_fp8: bool = False,
     ) -> None:
@@ -372,6 +378,8 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
+        self.cast_inputs = cast_inputs
+
         self.use_fp8 = use_fp8
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -490,6 +498,7 @@ class LowLevelZeroPlugin(DPPluginBase):
                 model,
                 self.precision,
                 overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                cast_inputs=self.cast_inputs,
                 use_fp8=self.use_fp8,
             )
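
For context, a hedged usage sketch of how the new flag surfaces to users. Only cast_inputs (together with the existing stage/precision/use_fp8 arguments) comes from this diff; the launch/Booster calls are the standard ColossalAI flow and the model, optimizer, and hyperparameters below are placeholders:

import torch

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()  # run under torchrun; exact launch args may vary by version

plugin = LowLevelZeroPlugin(
    stage=2,
    precision="fp16",
    cast_inputs=False,  # new flag: keep forward inputs in their original dtype
)
booster = Booster(plugin=plugin)

model = torch.nn.Linear(16, 16)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
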