mirror of https://github.com/hpcaitech/ColossalAI
parent: dcc44aab8d
commit: 0d3b0bd864
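
This commit threads a new cast_inputs flag (default True) from LowLevelZeroPlugin into LowLevelZeroModel. With cast_inputs=False, the wrapper skips building convert_fn, so forward inputs are no longer cast automatically to the AMP dtype (fp16/bf16). The hunks below show the change.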
@@ -61,7 +61,9 @@ class OptimizerParamCheckState(enum.Enum):
 
 
 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, cast_inputs: bool = True
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -73,7 +75,7 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
-        if self.dtype is not None:
+        if self.dtype is not None and cast_inputs:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
         if overlap_allgather:
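
For context on the gate above: convert_fn is built from _convert_floating_point, partially applied with the AMP dtype. A minimal sketch of the pattern follows; the helper body and the tree_map application over forward() arguments are illustrative assumptions consistent with this diff, not code shown in it.

from functools import partial

import torch
from torch.utils._pytree import tree_map


def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
    # Cast floating-point tensors to the AMP dtype; leave integer
    # tensors and non-tensor leaves (ints, strings, ...) untouched.
    if isinstance(x, torch.Tensor) and torch.is_floating_point(x):
        return x.to(dtype)
    return x


# With cast_inputs=True (the default) and precision="fp16":
convert_fn = partial(_convert_floating_point, dtype=torch.float16)
args = (torch.randn(2, 2), torch.tensor([1, 2]), 3)
args = tree_map(convert_fn, args)  # fp32 tensor -> fp16; int tensor and plain int unchanged

With cast_inputs=False, convert_fn stays None and the wrapped module receives its inputs exactly as the caller passed them.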
@@ -334,6 +336,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         cpu_offload: bool = False,
         master_weights: bool = True,
         verbose: bool = False,
+        cast_inputs: bool = True,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -361,6 +364,8 @@ class LowLevelZeroPlugin(DPPluginBase):
         self.lora_enabled = False
         self.verbose = verbose
         self.logger = get_dist_logger()
+        self.cast_inputs = cast_inputs
+
         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
 
@@ -475,7 +480,10 @@ class LowLevelZeroPlugin(DPPluginBase):
 
         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                cast_inputs=self.cast_inputs,
             )
 
         # TODO: Support Galore + ZeRO
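
After this change, a caller can opt out of automatic input casting at the plugin level. A minimal usage sketch under the standard Booster flow; the toy model, optimizer, and the launch call are placeholders, and launch_from_torch() assumes a torchrun-style environment.

import torch
import torch.nn as nn

import colossalai
from colossalai.accelerator import get_accelerator
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

colossalai.launch_from_torch()

model = nn.Linear(16, 4)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

# cast_inputs=False: the wrapped module still runs in fp16, but forward
# inputs pass through unchanged; the caller casts them explicitly.
plugin = LowLevelZeroPlugin(stage=2, precision="fp16", cast_inputs=False)
booster = Booster(plugin=plugin)
model, optimizer, *_ = booster.boost(model, optimizer)

x = torch.randn(8, 16, dtype=torch.float16, device=get_accelerator().get_current_device())
out = model(x)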