mirror of https://github.com/hpcaitech/ColossalAI
commit
afe845ff15
|
@ -63,7 +63,12 @@ class OptimizerParamCheckState(enum.Enum):
|
|||
|
||||
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
||||
def __init__(
|
||||
self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
|
||||
self,
|
||||
module: nn.Module,
|
||||
precision: str,
|
||||
overlap_allgather: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
super().__init__(module)
|
||||
self.dtype = None
|
||||
|
@ -77,7 +82,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|||
self.module = module
|
||||
self.convert_fn = None
|
||||
self.use_fp8 = use_fp8
|
||||
if self.dtype is not None:
|
||||
if self.dtype is not None and cast_inputs:
|
||||
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
||||
self.overlap_allgather = overlap_allgather
|
||||
self.op_hooks = []
|
||||
|
@ -342,6 +347,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
cpu_offload: bool = False,
|
||||
master_weights: bool = True,
|
||||
verbose: bool = False,
|
||||
cast_inputs: bool = True,
|
||||
fp8_communication: bool = False,
|
||||
use_fp8: bool = False,
|
||||
) -> None:
|
||||
|
@ -372,6 +378,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
self.lora_enabled = False
|
||||
self.verbose = verbose
|
||||
self.logger = get_dist_logger()
|
||||
self.cast_inputs = cast_inputs
|
||||
|
||||
self.use_fp8 = use_fp8
|
||||
# set class name with stage, for better error message
|
||||
setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
|
||||
|
@ -490,6 +498,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
model,
|
||||
self.precision,
|
||||
overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
|
||||
cast_inputs=self.cast_inputs,
|
||||
use_fp8=self.use_fp8,
|
||||
)
|
||||
|
||||
|
|
Loading…
Reference in New Issue