@@ -1,12 +1,15 @@
import enum
import logging
import os
import warnings
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn
from torch.nn import Parameter
from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils._pytree import tree_map
@@ -41,6 +44,11 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):

SUPPORTED_PRECISION = ["fp16", "bf16", "fp32"]


class OptimizerParamCheckState(enum.Enum):
    ORIGIN_PARAM_FINDED = 0
    ORIGIN_PARAM_NOT_FIND = -1
    LORA_PARM_EXISTED = -2


class LowLevelZeroModel(ModelWrapper, AMPModelMixin):
    def __init__(self, module: nn.Module, precision: str) -> None:
@@ -208,6 +216,18 @@ class LowLevelZeroCheckpointIO(TorchDDPCheckpointIO):
        super().load_sharded_model(model, checkpoint_index_file, strict, use_safetensors, load_sub_module)
        # Resync the fp32 master weights from the freshly loaded fp16/bf16 working params.
        model.update_master_params()

    def save_lora_as_pretrained(self, model: ModelWrapper, checkpoint: str, use_safetensors: bool) -> None:
        if os.path.isfile(checkpoint):
            logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
            return
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        peft_model = model.unwrap()
        assert isinstance(
            peft_model, PeftModel
        ), "The model doesn't have LoRA adapters; please enable LoRA before saving."
        return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
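A usage sketch for the save path, assuming a Booster wired to this plugin whose `save_lora_as_pretrained` delegates to the checkpoint IO above; the checkpoint path is illustrative and, as the guard above enforces, must be a directory:

# Persist only the adapter weights, not the full model.
booster.save_lora_as_pretrained(model, "checkpoints/lora_adapter", use_safetensors=True)

# The saved adapters can later be reloaded before boosting:
# model = plugin.enable_lora(model, pretrained_dir="checkpoints/lora_adapter")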
class LowLevelZeroPlugin(DPPluginBase):
    """
@@ -287,6 +307,7 @@ class LowLevelZeroPlugin(DPPluginBase):
            cpu_offload=cpu_offload,
            master_weights=master_weights,
        )
        self.lora_enabled = False
        self.verbose = verbose

        # Set the class name with the ZeRO stage, for clearer error messages.
@@ -310,6 +331,66 @@ class LowLevelZeroPlugin(DPPluginBase):
    def supported_devices(self) -> List[str]:
        return ["cuda"]

    def support_lora(self) -> bool:
        return True

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        assert not isinstance(model, LowLevelZeroModel), "LoRA should be enabled before boosting the model."
        self.lora_enabled = True
        warnings.warn("You have enabled LoRA training. Please double-check hyperparameters such as the learning rate.")

        if pretrained_dir is None:
            # Fresh adapters: wrap the model with new LoRA modules built from the config.
            peft_model = get_peft_model(model, lora_config)
        else:
            # Resume: load existing adapter weights and keep them trainable.
            peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
        return peft_model
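A minimal sketch of enabling LoRA before boosting; the `LoraConfig` fields and `target_modules` names are illustrative, and constructing the plugin assumes the distributed environment has already been launched:

from peft import LoraConfig

lora_config = LoraConfig(
    r=8,                                  # rank of the low-rank update matrices
    lora_alpha=16,                        # scaling applied to the update
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],  # illustrative module names
)
plugin = LowLevelZeroPlugin(stage=2, precision="bf16")
model = plugin.enable_lora(model, lora_config=lora_config)  # must precede boost()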
    def get_param_group_id(self, optimizer: Optimizer, origin_param: Parameter, lora_param: Parameter):
        origin_param_id = id(origin_param)
        lora_param_id = id(lora_param)
        target_group_id = None
        for group_id, param_group in enumerate(optimizer.param_groups):
            for p in param_group["params"]:
                if id(p) == lora_param_id:
                    # The LoRA parameter is already tracked by the optimizer.
                    return target_group_id, OptimizerParamCheckState.LORA_PARM_EXISTED
                if id(p) == origin_param_id:
                    target_group_id = group_id
        if target_group_id is not None:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_FINDED
        else:
            return target_group_id, OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND
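A small illustration of this method's contract, assuming `plugin` is an already-constructed `LowLevelZeroPlugin`:

import torch
from torch.optim import SGD

base = torch.nn.Parameter(torch.zeros(4))  # stands in for a frozen base weight
lora = torch.nn.Parameter(torch.zeros(4))  # stands in for a new LoRA weight
opt = SGD([base], lr=1e-3)

# The base parameter is found in group 0; the LoRA parameter is not tracked yet.
group_id, state = plugin.get_param_group_id(opt, base, lora)
assert (group_id, state) == (0, OptimizerParamCheckState.ORIGIN_PARAM_FINDED)

# Once the LoRA parameter is registered, the check short-circuits.
opt.param_groups[0]["params"].append(lora)
_, state = plugin.get_param_group_id(opt, base, lora)
assert state == OptimizerParamCheckState.LORA_PARM_EXISTED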
    def add_lora_params_to_optimizer(self, model, optimizer):
        """Add the trainable LoRA parameters to the param group of their base weights."""
        name2param = {}
        for name, param in model.named_parameters():
            name2param[name] = param

        for name, param in name2param.items():
            if "lora_A" in name or "lora_B" in name:
                # Recover the name of the frozen base weight this LoRA param decorates.
                origin_key = name.replace("lora_A.", "")
                origin_key = origin_key.replace("lora_B.", "")
                origin_key = origin_key.replace(model.active_adapter, "base_layer")
                origin_param = name2param[origin_key]
                group_id, check_state = self.get_param_group_id(optimizer, origin_param, param)
                if check_state == OptimizerParamCheckState.ORIGIN_PARAM_NOT_FIND:
                    warnings.warn(
                        f"Origin parameter {origin_key} related to {name} doesn't exist in optimizer param_groups."
                    )
                elif (
                    check_state == OptimizerParamCheckState.ORIGIN_PARAM_FINDED
                    and group_id is not None
                    and group_id >= 0
                ):
                    optimizer.param_groups[group_id]["params"].append(param)
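The name surgery above is easiest to see on a concrete peft-style parameter name; this sketch uses a hypothetical module path and peft's default adapter name:

# Hypothetical name peft gives a LoRA weight under the "default" adapter.
name = "base_model.model.q_proj.lora_A.default.weight"

origin_key = name.replace("lora_A.", "").replace("lora_B.", "")
origin_key = origin_key.replace("default", "base_layer")

# peft keeps the frozen base weight under `base_layer`.
assert origin_key == "base_model.model.q_proj.base_layer.weight"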
    def configure(
        self,
        model: nn.Module,
@@ -318,6 +399,13 @@ class LowLevelZeroPlugin(DPPluginBase):
        dataloader: Optional[DataLoader] = None,
        lr_scheduler: Optional[LRScheduler] = None,
    ) -> Tuple[nn.Module, OptimizerWrapper, Callable, DataLoader, LRScheduler]:
        if self.lora_enabled:
            from peft import PeftModel

            assert isinstance(
                model, PeftModel
            ), "The model should have been wrapped as a PeftModel when self.lora_enabled is True."
            if optimizer is not None:
                self.add_lora_params_to_optimizer(model, optimizer)

        if not isinstance(model, ModelWrapper):
            model = LowLevelZeroModel(model, self.precision)
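End to end, `configure` assumes roughly this order of operations; a sketch with abridged Booster wiring, where `model` is a plain `nn.Module` and `lora_config` is as in the earlier sketch:

import colossalai
import torch
from colossalai.booster import Booster

colossalai.launch_from_torch()  # distributed init (arguments abridged)
plugin = LowLevelZeroPlugin(stage=2, precision="fp16")
booster = Booster(plugin=plugin)

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)  # built on the base model
model = plugin.enable_lora(model, lora_config=lora_config)  # injects lora_A/lora_B
# boost() calls configure(): the fresh LoRA parameters are appended to the
# param group of their base weights before the optimizer is wrapped for ZeRO.
model, optimizer, *_ = booster.boost(model, optimizer)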
@@ -339,8 +427,3 @@ class LowLevelZeroPlugin(DPPluginBase):
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.no_sync()
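For gradient accumulation, `no_sync` defers ZeRO's gradient synchronization until the last micro-batch of each window; a sketch assuming the `booster`, `model`, and `optimizer` from the flow above, plus a hypothetical `criterion` and `train_dataloader`:

from contextlib import nullcontext

accum_steps = 4  # hypothetical accumulation length
for step, (inputs, targets) in enumerate(train_dataloader):
    is_sync_step = (step + 1) % accum_steps == 0
    # Skip gradient sync on all but the last micro-batch of each window.
    with (booster.no_sync(model, optimizer) if not is_sync_step else nullcontext()):
        loss = criterion(model(inputs), targets) / accum_steps
        booster.backward(loss, optimizer)
    if is_sync_step:
        optimizer.step()
        optimizer.zero_grad()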