@@ -35,6 +35,7 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
 from colossalai.interface.optimizer import DistributedOptim
 from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
 from colossalai.quantization import BnbQuantizationConfig, quantize_model
+from colossalai.quantization.fp8_hook import FP8Hook
 from colossalai.tensor.colo_parameter import ColoParameter
 from colossalai.tensor.param_op_hook import ColoParamOpHookManager
 from colossalai.zero import LowLevelZeroOptimizer
@@ -62,7 +63,9 @@ class OptimizerParamCheckState(enum.Enum):


 class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
-    def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
+    def __init__(
+        self, module: nn.Module, precision: str, overlap_allgather: bool = False, use_fp8: bool = False
+    ) -> None:
         super().__init__(module)
         self.dtype = None
         if precision == "fp16":
@@ -74,11 +77,16 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
         module = module.to(get_accelerator().get_current_device())
         self.module = module
         self.convert_fn = None
+        self.use_fp8 = use_fp8
         if self.dtype is not None:
             self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
         self.overlap_allgather = overlap_allgather
+        self.op_hooks = []
         if overlap_allgather:
-            self.op_hook = ZeroOpHook()
+            self.op_hooks.append(ZeroOpHook())
+        if use_fp8:
+            self.op_hooks.append(FP8Hook())
+        if overlap_allgather or use_fp8:
             for p in module.parameters():
                 if p.requires_grad and type(p) is not ColoParameter:
                     p.__class__ = ColoParameter
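The switch from a single `self.op_hook` to a `self.op_hooks` list lets the ZeRO all-gather hook and the new FP8 hook be installed together. The forward path is not part of this hunk, but the pattern the hook list implies is roughly the following sketch (the `_forward_with_op_hooks` helper and its wiring are hypothetical; only `ColoParamOpHookManager.use_hooks` and the `op_hooks` / `module` attributes come from the surrounding code):

```python
from contextlib import nullcontext

from colossalai.tensor.param_op_hook import ColoParamOpHookManager


def _forward_with_op_hooks(wrapper, *args, **kwargs):
    # Hypothetical helper: activate every registered param-op hook (ZeroOpHook,
    # FP8Hook, or both) around the wrapped module's forward pass, so parameter
    # accesses are intercepted while the hooks are in scope.
    ctx = ColoParamOpHookManager.use_hooks(*wrapper.op_hooks) if wrapper.op_hooks else nullcontext()
    with ctx:
        return wrapper.module(*args, **kwargs)
```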
@@ -335,6 +343,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         master_weights: bool = True,
         verbose: bool = False,
         fp8_communication: bool = False,
+        use_fp8: bool = False,
     ) -> None:
         super().__init__()
         assert stage in (1, 2), f"LowLevelZeroPlugin only supports stage 1/2 training"
@@ -362,6 +371,7 @@ class LowLevelZeroPlugin(DPPluginBase):
         )
         self.lora_enabled = False
         self.verbose = verbose
+        self.use_fp8 = use_fp8

         # set class name with stage, for better error message
         setattr(self.__class__, "__name__", f"LowLevelZeroPlugin_ZeRO-{stage}")
@@ -476,7 +486,10 @@ class LowLevelZeroPlugin(DPPluginBase):

         if not isinstance(model, ModelWrapper):
             model = LowLevelZeroModel(
-                model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
+                model,
+                self.precision,
+                overlap_allgather=self.zero_optim_kwargs["overlap_allgather"],
+                use_fp8=self.use_fp8,
             )

         # TODO: Support Galore + ZeRO
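For reference, a minimal sketch of how the new `use_fp8` flag would be passed from user code down to the wrapped `LowLevelZeroModel` (the launch/boost boilerplate is the standard ColossalAI flow and the model/optimizer are placeholders, not part of this diff; exact launch arguments vary by ColossalAI version):

```python
import torch.nn as nn
from torch.optim import Adam

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin

# Assumes a distributed run started via torchrun; launch_from_torch reads
# rank/world-size from the environment.
colossalai.launch_from_torch()

# use_fp8=True makes the wrapped LowLevelZeroModel register an FP8Hook;
# fp8_communication (already present before this diff) controls FP8 collectives.
plugin = LowLevelZeroPlugin(stage=2, precision="bf16", use_fp8=True)
booster = Booster(plugin=plugin)

model = nn.Linear(1024, 1024)
optimizer = Adam(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)
```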