@@ -1,12 +1,15 @@
import logging
import warnings
import enum
import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


class OptimizerParamCheckState(enum.Enum):
    ORIGIN_PARAM_FINDED = 0
    ORIGIN_PARAM_NOT_FIND = -1
    LORA_PARM_EXISTED = -2
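    # Lookup results used by LowLevelZeroPlugin.get_param_group_id below: the origin
    # (base) parameter was found in some optimizer param group, it was not found, or
    # the corresponding LoRA parameter is already registered in a param group.
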
class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    def __init__(self, module: nn.Module, precision: str) -> None:

@@ -208,6 +216,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
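        # update the master copies of the parameters from the freshly loaded working weights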
        model.update_master_params()

    def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        peft_model = model.unwrap()
        assert isinstance(
            peft_model, PeftModel
        ), "The model doesn't have lora adapters, please enable lora before saving."
        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
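    # Illustrative call (not part of this diff; `plugin` and `boosted_model` are
    # assumed names): persist the adapter weights of a boosted model to a directory.
    #
    #     checkpoint_io = plugin.get_checkpoint_io()
    #     checkpoint_io.save_lora_as_pretrained(boosted_model, "lora_checkpoint", use_safetensors=True)
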
class LowLevelZeroPlugin(DPPluginBase):
    """

@@ -287,6 +307,7 @@ class LowLevelZeroPlugin(DPPluginBase):
            cpu_offload=cpu_offload,
            master_weights=master_weights,
        )
        self.lora_enabled = False
        self.verbose = verbose

        # set class name with stage, for better error message

@@ -310,6 +331,66 @@ class LowLevelZeroPlugin(DPPluginBase):
    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def support_lora(self) -> bool:
        return True

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        assert not isinstance(model, LowLevelZeroModel), "LoRA should be enabled before boosting the model."
        self.lora_enabled = True
        warnings.warn("You have enabled LoRA training. Please check the hyperparameters such as lr")

        if pretrained_dir is None:
            peft_model = get_peft_model(model, lora_config)
        else:
            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
        return peft_model
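    # A minimal usage sketch (illustrative, not part of this diff; the LoraConfig
    # fields and the Booster wiring are assumptions):
    #
    #     from peft import LoraConfig
    #     plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
    #     booster = Booster(plugin=plugin)
    #     model = plugin.enable_lora(model, lora_config=LoraConfig(r=8, lora_alpha=16))
    #     model, optimizer, criterion, dataloader, lr_scheduler = booster.boost(
    #         model, optimizer, criterion, dataloader, lr_scheduler
    #     )
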
    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
        origin_param_id = id(origin_param)
        lora_param_id = id(lora_param)
        target_group_id = None
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group['params']:
                if id(p) == lora_param_id:
                    # the lora parameter is already present in the optimizer
                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
                if id(p) == origin_param_id:
                    target_group_id = group_id
        if target_group_id is not None:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
        else:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
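    # Example (illustrative names): with optimizer = SGD([{"params": [w0]}, {"params": [w1]}], lr=1e-3),
    # get_param_group_id(optimizer, w1, w1_lora) returns
    # (1, OptimizerParamCheckState.ORIGIN_PARAM_FINDED) as long as w1_lora has not yet
    # been appended to any param group.
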
    def add_lora_params_to_optimizer(self, model, optimizer):
        """Add the LoRA parameters to the optimizer's param groups."""
        name2param = {}
        for name, param in model.named_parameters():
            name2param[name] = param

        for name, param in name2param.items():
            if 'lora_A' in name or 'lora_B' in name:
                origin_key = name.replace("lora_A.", "")
                origin_key = origin_key.replace("lora_B.", "")
                origin_key = origin_key.replace(f"{model.active_adapter}", "base_layer")
                origin_param = name2param[origin_key]
                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                    warnings.warn(f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups.")
                elif check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED and group_id is not None and group_id >= 0:
                    optimizer.param_groups[group_id]['params'].append(param)
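    # Name-mapping example (illustrative, for a typical peft-wrapped Linear): a LoRA
    # parameter named "base_model.model.fc.lora_A.default.weight" maps back to
    # "base_model.model.fc.base_layer.weight" when the active adapter is "default",
    # so the new LoRA weight joins the param group that already holds its base weight.
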
    def configure(
        self,
        model: nn.Module,

@@ -318,6 +399,13 @@ class LowLevelZeroPlugin(DPPluginBase):
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        if self.lora_enabled:
            from peft import PeftModel

            assert isinstance(
                model, PeftModel
            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True"
            if optimizer is not None:
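                # fold the freshly created LoRA parameters into the optimizer's existing
                # param groups before boosting continues below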
                self.add_lora_params_to_optimizer(model, optimizer)

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)

@@ -339,8 +427,3 @@
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.no_sync()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError