|
|
|
@ -2,6 +2,7 @@ import enum
|
|
|
|
|
import logging
|
|
|
|
|
import os
|
|
|
|
|
import warnings
|
|
|
|
|
from contextlib import nullcontext
|
|
|
|
|
from functools import partial
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from types import MethodType
|
|
|
|
@ -34,7 +35,10 @@ from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
|
|
|
|
|
from colossalai.interface.optimizer import DistributedOptim
|
|
|
|
|
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
|
|
|
|
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
|
|
|
|
from colossalai.tensor.colo_parameter import ColoParameter
|
|
|
|
|
from colossalai.tensor.param_op_hook import ColoParamOpHookManager
|
|
|
|
|
from colossalai.zero import LowLevelZeroOptimizer
|
|
|
|
|
from colossalai.zero.low_level.zero_hook import ZeroOpHook, wait_all_gather_handle
|
|
|
|
|
|
|
|
|
|
from .dp_plugin_base import DPPluginBase
|
|
|
|
|
from .torch_ddp_plugin import TorchDDPCheckpointIO
|
|
|
|
@ -58,7 +62,7 @@ class OptimizerParamCheckState(enum.Enum):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|
|
|
|
def __init__(self, module: nn.Module, precision: str) -> None:
|
|
|
|
|
def __init__(self, module: nn.Module, precision: str, overlap_allgather: bool = False) -> None:
|
|
|
|
|
super().__init__(module)
|
|
|
|
|
self.dtype = None
|
|
|
|
|
if precision == "fp16":
|
|
|
|
@ -72,13 +76,26 @@ class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
|
|
|
|
|
self.convert_fn = None
|
|
|
|
|
if self.dtype is not None:
|
|
|
|
|
self.convert_fn = partial(_convert_floating_point, dtype=self.dtype)
|
|
|
|
|
self.overlap_allgather = overlap_allgather
|
|
|
|
|
if overlap_allgather:
|
|
|
|
|
self.op_hook = ZeroOpHook()
|
|
|
|
|
for p in module.parameters():
|
|
|
|
|
if p.requires_grad and type(p) is not ColoParameter:
|
|
|
|
|
p.__class__ = ColoParameter
|
|
|
|
|
p.__init__(p, requires_grad=True)
|
|
|
|
|
|
|
|
|
|
def forward(self, *args, **kwargs):
|
|
|
|
|
if self.convert_fn is not None:
|
|
|
|
|
args = tree_map(self.convert_fn, args)
|
|
|
|
|
kwargs = tree_map(self.convert_fn, kwargs)
|
|
|
|
|
ctx = ColoParamOpHookManager.use_hooks(self.op_hook) if self.overlap_allgather else nullcontext()
|
|
|
|
|
with ctx:
|
|
|
|
|
return super().forward(*args, **kwargs)
|
|
|
|
|
|
|
|
|
|
def _force_wait_all_gather(self):
|
|
|
|
|
for p in self.module.parameters():
|
|
|
|
|
wait_all_gather_handle(p)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|
|
|
|
def save_unsharded_optimizer(self, optimizer: OptimizerWrapper, checkpoint: str, gather_dtensor: bool = False):
|
|
|
|
@ -209,6 +226,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|
|
|
|
|
|
|
|
|
def load_unsharded_model(self, model: ModelWrapper, checkpoint: str, strict: bool = True):
|
|
|
|
|
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
|
|
|
|
model._force_wait_all_gather()
|
|
|
|
|
super().load_unsharded_model(model, checkpoint, strict)
|
|
|
|
|
model.update_master_params()
|
|
|
|
|
|
|
|
|
@ -221,9 +239,30 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|
|
|
|
load_sub_module: bool = True,
|
|
|
|
|
):
|
|
|
|
|
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
|
|
|
|
model._force_wait_all_gather()
|
|
|
|
|
super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
|
|
|
|
|
model.update_master_params()
|
|
|
|
|
|
|
|
|
|
def save_unsharded_model(self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool):
|
|
|
|
|
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
|
|
|
|
model._force_wait_all_gather()
|
|
|
|
|
return super().save_unsharded_model(model, checkpoint, gather_dtensor, use_safetensors)
|
|
|
|
|
|
|
|
|
|
def save_sharded_model(
|
|
|
|
|
self,
|
|
|
|
|
model: ModelWrapper,
|
|
|
|
|
checkpoint_path: str,
|
|
|
|
|
gather_dtensor: bool = True,
|
|
|
|
|
prefix: Optional[str] = None,
|
|
|
|
|
max_shard_size: int = 1024,
|
|
|
|
|
use_safetensors: bool = False,
|
|
|
|
|
):
|
|
|
|
|
assert isinstance(model, LowLevelZeroModel), "Please boost the model before loading!"
|
|
|
|
|
model._force_wait_all_gather()
|
|
|
|
|
return super().save_sharded_model(
|
|
|
|
|
model, checkpoint_path, gather_dtensor, prefix, max_shard_size, use_safetensors
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
|
|
|
|
if os.path.isfile(checkpoint):
|
|
|
|
|
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
|
|
|
@ -231,6 +270,7 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
|
|
|
|
|
from peft import PeftModel
|
|
|
|
|
|
|
|
|
|
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
|
|
|
|
model._force_wait_all_gather()
|
|
|
|
|
peft_model = model.unwrap()
|
|
|
|
|
assert isinstance(
|
|
|
|
|
peft_model, PeftModel
|
|
|
|
@ -290,6 +330,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|
|
|
|
reduce_bucket_size_in_m: int = 12,
|
|
|
|
|
communication_dtype: Optional[torch.dtype] = None,
|
|
|
|
|
overlap_communication: bool = True,
|
|
|
|
|
overlap_allgather: bool = False,
|
|
|
|
|
cpu_offload: bool = False,
|
|
|
|
|
master_weights: bool = True,
|
|
|
|
|
verbose: bool = False,
|
|
|
|
@ -315,6 +356,7 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|
|
|
|
partition_grad=(stage == 2),
|
|
|
|
|
cpu_offload=cpu_offload,
|
|
|
|
|
master_weights=master_weights,
|
|
|
|
|
overlap_allgather=overlap_allgather,
|
|
|
|
|
)
|
|
|
|
|
self.lora_enabled = False
|
|
|
|
|
self.verbose = verbose
|
|
|
|
@ -431,7 +473,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|
|
|
|
self.add_lora_params_to_optimizer(model, optimizer)
|
|
|
|
|
|
|
|
|
|
if not isinstance(model, ModelWrapper):
|
|
|
|
|
model = LowLevelZeroModel(model, self.precision)
|
|
|
|
|
model = LowLevelZeroModel(
|
|
|
|
|
model, self.precision, overlap_allgather=self.zero_optim_kwargs["overlap_allgather"]
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# TODO: Support Galore + ZeRO
|
|
|
|
|
zero_stage = self.stage
|
|
|
|
|