mirror of https://github.com/hpcaitech/ColossalAI
parent
09c5f72595
commit
30f4e31a33
|
@ -1,7 +1,7 @@
|
||||||
from .base import BaseModel
|
from .base import BaseModel
|
||||||
from .critic import Critic
|
from .critic import Critic
|
||||||
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
|
from .generation import generate, generate_streaming, prepare_inputs_fn, update_model_kwargs_fn
|
||||||
from .lora import convert_to_lora_module
|
from .lora import LoraConfig, convert_to_lora_module, lora_manager
|
||||||
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
from .loss import DpoLoss, KTOLoss, LogExpLoss, LogSigLoss, PolicyLoss, ValueLoss
|
||||||
from .reward_model import RewardModel
|
from .reward_model import RewardModel
|
||||||
from .utils import disable_dropout
|
from .utils import disable_dropout
|
||||||
|
@ -14,6 +14,8 @@ __all__ = [
|
||||||
"ValueLoss",
|
"ValueLoss",
|
||||||
"LogSigLoss",
|
"LogSigLoss",
|
||||||
"LogExpLoss",
|
"LogExpLoss",
|
||||||
|
"LoraConfig",
|
||||||
|
"lora_manager",
|
||||||
"convert_to_lora_module",
|
"convert_to_lora_module",
|
||||||
"DpoLoss",
|
"DpoLoss",
|
||||||
"KTOLoss" "generate",
|
"KTOLoss" "generate",
|
||||||
|
|
|
@ -5,10 +5,11 @@ LORA utils
|
||||||
import dataclasses
|
import dataclasses
|
||||||
import math
|
import math
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Optional
|
from typing import List, Optional, Union
|
||||||
|
|
||||||
import loralib as lora
|
import loralib as lora
|
||||||
import torch
|
import torch
|
||||||
|
import torch.distributed as dist
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
@ -18,148 +19,349 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass
|
@dataclasses.dataclass
|
||||||
class LoRAManager:
|
class LoraManager:
|
||||||
merge_weights: bool = False
|
able_to_merge: bool = True
|
||||||
|
|
||||||
|
|
||||||
LORA_MANAGER = LoRAManager()
|
lora_manager = LoraManager()
|
||||||
|
|
||||||
|
|
||||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
@dataclasses.dataclass
|
||||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
class LoraConfig:
|
||||||
|
r: int = 0
|
||||||
|
lora_alpha: int = 32
|
||||||
|
linear_lora_dropout: float = 0.1
|
||||||
|
embedding_lora_dropout: float = 0.0
|
||||||
|
lora_train_bias: str = "none"
|
||||||
|
lora_initialization_method: str = "kaiming_uniform"
|
||||||
|
target_modules: List = None
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_file(cls, config_file: str):
|
||||||
|
import json
|
||||||
|
|
||||||
|
with open(config_file, "r") as f:
|
||||||
|
config = json.load(f)
|
||||||
|
return cls(**config)
|
||||||
|
|
||||||
|
|
||||||
|
class LoraBase(lora.LoRALayer, nn.Module):
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
weight: nn.Parameter,
|
|
||||||
bias: Optional[nn.Parameter],
|
|
||||||
r: int = 0,
|
r: int = 0,
|
||||||
lora_alpha: int = 1,
|
lora_alpha: int = 32,
|
||||||
lora_dropout: float = 0.0,
|
lora_dropout: float = 0.1,
|
||||||
# Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
lora_initialization_method: str = "kaiming_uniform",
|
||||||
fan_in_fan_out: bool = False,
|
|
||||||
):
|
):
|
||||||
nn.Module.__init__(self)
|
nn.Module.__init__(self)
|
||||||
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
lora.LoRALayer.__init__(self, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, merge_weights=False)
|
||||||
self.weight = weight
|
self.r = r
|
||||||
self.bias = bias
|
self.lora_alpha = lora_alpha
|
||||||
|
self.lora_dropout = nn.Dropout(lora_dropout)
|
||||||
out_features, in_features = weight.shape
|
self.merged = False
|
||||||
self.in_features = in_features
|
self.lora_initialization_method = lora_initialization_method
|
||||||
self.out_features = out_features
|
self.weight = None
|
||||||
|
self.bias = None
|
||||||
self.fan_in_fan_out = fan_in_fan_out
|
self.lora_A = None
|
||||||
# Actual trainable parameters
|
self.lora_B = None
|
||||||
if r > 0:
|
|
||||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)), requires_grad=False)
|
|
||||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
|
||||||
self.scaling = self.lora_alpha / self.r
|
|
||||||
# Freezing the pre-trained weight matrix
|
|
||||||
self.weight.requires_grad = False
|
|
||||||
self.reset_parameters()
|
|
||||||
if fan_in_fan_out:
|
|
||||||
self.weight.data = self.weight.data.T
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
if hasattr(self, "lora_A"):
|
if hasattr(self, "lora_A"):
|
||||||
# Initialize A with the default values for nn.Linear and set B to zero.
|
if self.lora_initialization_method == "kaiming_uniform" or self.weight.size() != (
|
||||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
self.out_features,
|
||||||
nn.init.zeros_(self.lora_B)
|
self.in_features,
|
||||||
|
):
|
||||||
|
# Initialize A with the default values for nn.Linear and set B to zero.
|
||||||
|
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||||
|
nn.init.zeros_(self.lora_B)
|
||||||
|
elif self.lora_initialization_method == "PiSSA":
|
||||||
|
# PiSSA method in this paper: https://arxiv.org/abs/2404.02948
|
||||||
|
# Assume the SVD of the original weights is W = USV^T
|
||||||
|
# Initialize a frozen weight to U[:,r:]S[r:,r:]V^T[:,r:] to store less significent part of W
|
||||||
|
# Only A, B are trainable, which are initialized to S[r:,:r]^0.5V^T[:,:r] and U[:,:r]S[r:,:r] respectively
|
||||||
|
# self.scaling = 1.
|
||||||
|
# SVD
|
||||||
|
U, S, Vh = torch.svd_lowrank(
|
||||||
|
self.weight.to(torch.float32).data, self.r, niter=4
|
||||||
|
) # U: [out_features, in_features], S: [in_features], V: [in_features, in_features]
|
||||||
|
# weight_backup = self.weight.clone()
|
||||||
|
|
||||||
|
# Initialize A, B
|
||||||
|
S = S / self.scaling
|
||||||
|
self.lora_B.data = (U @ torch.diag(torch.sqrt(S))).to(torch.float32).contiguous()
|
||||||
|
self.lora_A.data = (torch.diag(torch.sqrt(S)) @ Vh.T).to(torch.float32).contiguous()
|
||||||
|
# Initialize weight
|
||||||
|
# To reduce floating point error, we use residual instead of directly using U[:, :self.r] @ S[:self.r] @ Vh[:self.r, :]
|
||||||
|
self.weight.data = (
|
||||||
|
((self.weight - self.scaling * self.lora_B @ self.lora_A)).contiguous().to(self.weight.dtype)
|
||||||
|
)
|
||||||
|
self.lora_A.requires_grad = True
|
||||||
|
self.lora_B.requires_grad = True
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown LoRA initialization method {self.lora_initialization_method}")
|
||||||
|
|
||||||
def train(self, mode: bool = True):
|
def train(self, mode: bool = True):
|
||||||
"""
|
"""
|
||||||
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
|
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def T(w):
|
|
||||||
return w.T if self.fan_in_fan_out else w
|
|
||||||
|
|
||||||
self.training = mode
|
self.training = mode
|
||||||
if LORA_MANAGER.merge_weights:
|
if mode and self.merged:
|
||||||
if mode and self.merged:
|
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
||||||
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
raise NotImplementedError("LoRA unmerge is not tested.")
|
||||||
raise NotImplementedError("LoRA unmerge is not tested.")
|
elif not mode and not self.merged and lora_manager.able_to_merge:
|
||||||
# Make sure that the weights are not merged
|
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
||||||
if self.r > 0:
|
# Merge the weights and mark it
|
||||||
if not hasattr(self, "lora_A") or not hasattr(self, "lora_B"):
|
if self.r > 0:
|
||||||
# FIXME(csric): temporary fix
|
self.weight.data += self.lora_B @ self.lora_A * self.scaling
|
||||||
self.lora_A = nn.Parameter(self.weight.new_empty((self.r, self.in_features)))
|
delattr(self, "lora_A")
|
||||||
self.lora_B = nn.Parameter(self.weight.new_empty((self.out_features, self.r)))
|
delattr(self, "lora_B")
|
||||||
self.reset_parameters()
|
self.merged = True
|
||||||
else:
|
|
||||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
self.merged = False
|
|
||||||
elif not mode and not self.merged:
|
|
||||||
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
|
||||||
# Merge the weights and mark it
|
|
||||||
if self.r > 0:
|
|
||||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
|
||||||
delattr(self, "lora_A")
|
|
||||||
delattr(self, "lora_B")
|
|
||||||
self.merged = True
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor):
|
|
||||||
def T(w):
|
|
||||||
return w.T if self.fan_in_fan_out else w
|
|
||||||
|
|
||||||
|
class LoraLinear(LoraBase):
|
||||||
|
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: nn.Parameter,
|
||||||
|
bias: Union[nn.Parameter, bool],
|
||||||
|
r: int = 0,
|
||||||
|
lora_alpha: int = 32,
|
||||||
|
lora_dropout: float = 0.0,
|
||||||
|
lora_initialization_method: str = "kaiming_uniform",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
|
||||||
|
)
|
||||||
|
self.weight = weight
|
||||||
|
self.bias = bias
|
||||||
|
if bias is True:
|
||||||
|
self.bias = nn.Parameter(torch.zeros(weight.shape[0]))
|
||||||
|
if bias is not None:
|
||||||
|
self.bias.requires_grad = True
|
||||||
|
|
||||||
|
out_features, in_features = weight.shape
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
|
||||||
|
self.lora_initialization_method = lora_initialization_method
|
||||||
|
# Actual trainable parameters
|
||||||
|
if r > 0:
|
||||||
|
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
|
||||||
|
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
|
||||||
|
self.scaling = self.lora_alpha / self.r
|
||||||
|
# Freezing the pre-trained weight matrix
|
||||||
|
self.weight.requires_grad = False
|
||||||
|
self.reset_parameters()
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
if self.r > 0 and not self.merged:
|
if self.r > 0 and not self.merged:
|
||||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
result = F.linear(x, self.weight, bias=self.bias)
|
||||||
if self.r > 0:
|
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
||||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
|
||||||
return result
|
return result
|
||||||
else:
|
else:
|
||||||
return F.linear(x, T(self.weight), bias=self.bias)
|
return F.linear(x, self.weight, bias=self.bias)
|
||||||
|
|
||||||
|
|
||||||
def _lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
class LoraEmbedding(LoraBase):
|
||||||
|
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
weight: nn.Parameter,
|
||||||
|
r: int = 0,
|
||||||
|
lora_alpha: int = 32,
|
||||||
|
lora_dropout: float = 0.1,
|
||||||
|
num_embeddings: int = None,
|
||||||
|
embedding_dim: int = None,
|
||||||
|
padding_idx: Optional[int] = None,
|
||||||
|
max_norm: Optional[float] = None,
|
||||||
|
norm_type: float = 2.0,
|
||||||
|
scale_grad_by_freq: bool = False,
|
||||||
|
sparse: bool = False,
|
||||||
|
lora_initialization_method: str = "kaiming_uniform",
|
||||||
|
):
|
||||||
|
super().__init__(
|
||||||
|
r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout, lora_initialization_method=lora_initialization_method
|
||||||
|
)
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.max_norm = max_norm
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.scale_grad_by_freq = scale_grad_by_freq
|
||||||
|
self.sparse = sparse
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
|
||||||
|
self.weight = weight
|
||||||
|
|
||||||
|
in_features, out_features = num_embeddings, embedding_dim
|
||||||
|
self.in_features = in_features
|
||||||
|
self.out_features = out_features
|
||||||
|
assert lora_initialization_method in ["kaiming_uniform", "PiSSA"]
|
||||||
|
self.lora_initialization_method = lora_initialization_method
|
||||||
|
|
||||||
|
# Actual trainable parameters
|
||||||
|
if r > 0:
|
||||||
|
self.lora_A = nn.Parameter(torch.randn((r, in_features)))
|
||||||
|
self.lora_B = nn.Parameter(torch.randn((out_features, r)))
|
||||||
|
self.scaling = self.lora_alpha / self.r
|
||||||
|
# Freezing the pre-trained weight matrix
|
||||||
|
self.weight.requires_grad = False
|
||||||
|
|
||||||
|
# reset parameters
|
||||||
|
nn.init.zeros_(self.lora_A)
|
||||||
|
nn.init.normal_(self.lora_B)
|
||||||
|
|
||||||
|
def _embed(self, x: torch.Tensor, weight) -> torch.Tensor:
|
||||||
|
return F.embedding(
|
||||||
|
x,
|
||||||
|
weight,
|
||||||
|
padding_idx=self.padding_idx,
|
||||||
|
max_norm=self.max_norm,
|
||||||
|
norm_type=self.norm_type,
|
||||||
|
scale_grad_by_freq=self.scale_grad_by_freq,
|
||||||
|
sparse=self.sparse,
|
||||||
|
)
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor):
|
||||||
|
base_embedding = self._embed(x, self.weight)
|
||||||
|
# base_embedding.requires_grad = True # force the embedding layer to be trainable for gradient checkpointing
|
||||||
|
if self.r > 0 and not self.merged:
|
||||||
|
lora_A_embedding = self._embed(x, self.lora_A.t())
|
||||||
|
embedding = base_embedding + (lora_A_embedding @ self.lora_B.t()) * self.scaling
|
||||||
|
return embedding
|
||||||
|
else:
|
||||||
|
return base_embedding
|
||||||
|
|
||||||
|
def train(self, mode: bool = True):
|
||||||
|
"""
|
||||||
|
This function runs when model.train() is invoked. It is used to prepare the linear layer for training
|
||||||
|
"""
|
||||||
|
|
||||||
|
self.training = mode
|
||||||
|
if mode and self.merged:
|
||||||
|
warnings.warn("Invoke module.train() would unmerge LoRA weights.")
|
||||||
|
raise NotImplementedError("LoRA unmerge is not tested.")
|
||||||
|
elif not mode and not self.merged and lora_manager.able_to_merge:
|
||||||
|
warnings.warn("Invoke module.eval() would merge LoRA weights.")
|
||||||
|
# Merge the weights and mark it
|
||||||
|
if self.r > 0:
|
||||||
|
self.weight.data += self.lora_A.t() @ self.lora_B.t() * self.scaling
|
||||||
|
delattr(self, "lora_A")
|
||||||
|
delattr(self, "lora_B")
|
||||||
|
self.merged = True
|
||||||
|
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
def _lora_linear_wrapper(linear: nn.Linear, lora_config: LoraConfig) -> LoraLinear:
|
||||||
"""
|
"""
|
||||||
Wraps a linear layer with LoRA functionality.
|
Wraps a linear layer with LoRA functionality.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
linear (nn.Linear): The linear layer to be wrapped.
|
linear (nn.Linear): The linear layer to be wrapped.
|
||||||
lora_rank (int): The rank of the LoRA decomposition.
|
lora_rank (int): The rank of the LoRA decomposition.
|
||||||
|
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||||
|
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
LoraLinear: The wrapped linear layer with LoRA functionality.
|
LoraLinear: The wrapped linear layer with LoRA functionality.
|
||||||
"""
|
"""
|
||||||
assert (
|
assert (
|
||||||
lora_rank <= linear.in_features
|
lora_config.r <= linear.in_features
|
||||||
), f"LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})"
|
), f"LoRA rank ({lora_config.r}) must be less than or equal to in features ({linear.in_features})"
|
||||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank)
|
bias = None
|
||||||
|
if lora_config.lora_train_bias in ["all", "lora"]:
|
||||||
|
bias = linear.bias
|
||||||
|
if bias is None:
|
||||||
|
bias = True
|
||||||
|
lora_linear = LoraLinear(
|
||||||
|
linear.weight, bias, r=lora_config.r, lora_initialization_method=lora_config.lora_initialization_method
|
||||||
|
)
|
||||||
return lora_linear
|
return lora_linear
|
||||||
|
|
||||||
|
|
||||||
def _convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
def _convert_to_lora_recursively(module: nn.Module, parent_name: str, lora_config: LoraConfig) -> None:
|
||||||
"""
|
"""
|
||||||
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
|
Recursively converts the given module and its children to LoRA (Low-Rank Approximation) form.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module (nn.Module): The module to convert to LoRA form.
|
module (nn.Module): The module to convert to LoRA form.
|
||||||
lora_rank (int): The rank of the LoRA approximation.
|
lora_rank (int): The rank of the LoRA approximation.
|
||||||
|
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||||
|
parent_name (str): The name of the parent module.
|
||||||
|
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
for name, child in module.named_children():
|
for name, child in module.named_children():
|
||||||
if isinstance(child, nn.Linear):
|
if isinstance(child, nn.Linear):
|
||||||
setattr(module, name, _lora_linear_wrapper(child, lora_rank))
|
if lora_config.target_modules is None or any(
|
||||||
|
[name in target_module for target_module in lora_config.target_modules]
|
||||||
|
):
|
||||||
|
if dist.is_initialized() and dist.get_rank() == 0:
|
||||||
|
logger.info(f"Converting {parent_name}.{name} to LoRA")
|
||||||
|
setattr(module, name, _lora_linear_wrapper(child, lora_config))
|
||||||
|
elif isinstance(child, nn.Embedding):
|
||||||
|
if lora_config.target_modules is None or any(
|
||||||
|
[name in target_module for target_module in lora_config.target_modules]
|
||||||
|
):
|
||||||
|
if dist.is_initialized() and dist.get_rank() == 0:
|
||||||
|
logger.info(f"Converting {parent_name}.{name} to LoRA")
|
||||||
|
setattr(
|
||||||
|
module,
|
||||||
|
name,
|
||||||
|
LoraEmbedding(
|
||||||
|
child.weight,
|
||||||
|
r=lora_config.r,
|
||||||
|
lora_alpha=lora_config.lora_alpha,
|
||||||
|
lora_dropout=lora_config.embedding_lora_dropout,
|
||||||
|
num_embeddings=child.num_embeddings,
|
||||||
|
embedding_dim=child.embedding_dim,
|
||||||
|
padding_idx=child.padding_idx,
|
||||||
|
max_norm=child.max_norm,
|
||||||
|
norm_type=child.norm_type,
|
||||||
|
scale_grad_by_freq=child.scale_grad_by_freq,
|
||||||
|
sparse=child.sparse,
|
||||||
|
lora_initialization_method=lora_config.lora_initialization_method,
|
||||||
|
),
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
_convert_to_lora_recursively(child, lora_rank)
|
_convert_to_lora_recursively(child, f"{parent_name}.{name}", lora_config)
|
||||||
|
|
||||||
|
|
||||||
def convert_to_lora_module(module: nn.Module, lora_rank: int, lora_train_bias: str = "none") -> nn.Module:
|
def convert_to_lora_module(module: nn.Module, lora_config: LoraConfig) -> nn.Module:
|
||||||
"""Convert a torch.nn.Module to a LoRA module.
|
"""Convert a torch.nn.Module to a LoRA module.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
module (nn.Module): The module to convert.
|
module (nn.Module): The module to convert.
|
||||||
lora_rank (int): LoRA rank.
|
lora_rank (int): LoRA rank.
|
||||||
|
lora_train_bias (str): Whether to train the bias. Can be "none", "all", "lora".
|
||||||
|
lora_initialization_method (str): The initialization method for LoRA. Can be "kaiming_uniform" or "PiSSA".
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
nn.Module: The converted module.
|
nn.Module: The converted module.
|
||||||
"""
|
"""
|
||||||
if lora_rank <= 0:
|
if lora_config.r <= 0:
|
||||||
return module
|
return module
|
||||||
_convert_to_lora_recursively(module, lora_rank)
|
# make all parameter not trainable, if lora_train_bias is "all", set bias to trainable
|
||||||
lora.mark_only_lora_as_trainable(module, lora_train_bias)
|
total_parameter_size = 0
|
||||||
|
for name, p in module.named_parameters():
|
||||||
|
p.requires_grad = False
|
||||||
|
if "bias" in name and lora_config.lora_train_bias == "all":
|
||||||
|
p.requires_grad = True
|
||||||
|
total_parameter_size += p.numel()
|
||||||
|
_convert_to_lora_recursively(module, "", lora_config)
|
||||||
|
trainable_parameter_size = 0
|
||||||
|
for name, p in module.named_parameters():
|
||||||
|
if p.requires_grad == True:
|
||||||
|
trainable_parameter_size += p.numel()
|
||||||
|
if dist.is_initialized() and dist.get_rank() == 0:
|
||||||
|
logger.info(
|
||||||
|
f"Trainable parameter size: {trainable_parameter_size/1024/1024:.2f}M\nOriginal trainable parameter size: {total_parameter_size/1024/1024:.2f}M\nPercentage: {trainable_parameter_size/total_parameter_size*100:.2f}%"
|
||||||
|
)
|
||||||
return module
|
return module
|
||||||
|
|
|
@ -9,6 +9,7 @@
|
||||||
- [Install Requirements](#install-requirements)
|
- [Install Requirements](#install-requirements)
|
||||||
- [Get Start with ColossalRun](#get-start-with-colossalrun)
|
- [Get Start with ColossalRun](#get-start-with-colossalrun)
|
||||||
- [Training Configuration](#training-configuration)
|
- [Training Configuration](#training-configuration)
|
||||||
|
- [Parameter Efficient Finetuning (PEFT)](#parameter-efficient-finetuning-peft)
|
||||||
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
|
- [RLHF Stage 1: Supervised Instruction Tuning](#rlhf-training-stage1---supervised-instructs-tuning)
|
||||||
- [Step 1: Data Collection](#step-1-data-collection)
|
- [Step 1: Data Collection](#step-1-data-collection)
|
||||||
- [Step 2: Preprocessing](#step-2-preprocessing)
|
- [Step 2: Preprocessing](#step-2-preprocessing)
|
||||||
|
@ -377,35 +378,6 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
<details><summary><b>Low Rank Adaption</b></summary>
|
|
||||||
|
|
||||||
|
|
||||||
Details about Low Rank Adaption (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). It dramatically reduces the VRAM consumption at the cost of sacrifice model capability. It is suitable for training LLM with constrained resources.
|
|
||||||
|
|
||||||
|
|
||||||
To enable LoRA, set --lora_rank to a positive value (usually between 20 and 64).
|
|
||||||
```
|
|
||||||
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
|
||||||
--pretrain $PRETRAINED_MODEL_PATH \
|
|
||||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
|
||||||
--dataset ${dataset[@]} \
|
|
||||||
--save_interval 5000 \
|
|
||||||
--save_path $SAVE_DIR \
|
|
||||||
--config_file $CONFIG_FILE \
|
|
||||||
--plugin zero2_cpu \
|
|
||||||
--batch_size 4 \
|
|
||||||
--max_epochs 1 \
|
|
||||||
--accumulation_steps 4 \
|
|
||||||
--lr 2e-5 \
|
|
||||||
--max_len 2048 \
|
|
||||||
--lora_rank 32 \ # This enables LoRA
|
|
||||||
--use_wandb
|
|
||||||
```
|
|
||||||
|
|
||||||
|
|
||||||
</details>
|
|
||||||
|
|
||||||
|
|
||||||
<details><summary><b>Other Training Arguments</b></summary>
|
<details><summary><b>Other Training Arguments</b></summary>
|
||||||
|
|
||||||
|
|
||||||
|
@ -430,6 +402,60 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
|
||||||
- use_wandb: if this flag is up, you can view logs on wandb.
|
- use_wandb: if this flag is up, you can view logs on wandb.
|
||||||
|
|
||||||
|
|
||||||
|
</details>
|
||||||
|
|
||||||
|
### Parameter Efficient Finetuning (PEFT)
|
||||||
|
|
||||||
|
Currently, we support LoRA (low-rank adaptation) and PiSSA (principal singular values and singular vectors adaptation). Both reduce run-time VRAM consumption and training time at the cost of some overall model performance.
|
||||||
|
|
||||||
|
|
||||||
|
<details><summary><b>Low Rank Adaptation and PiSSA</b></summary>
|
||||||
|
|
||||||
|
|
||||||
|
Details about Low-Rank Adaptation (LoRA) can be found in the paper: [LoRA: Low-Rank Adaptation of Large Language Models](https://arxiv.org/abs/2106.09685). Details about Principal Singular Values and Singular Vectors Adaptation (PiSSA) can be found in the paper: [PiSSA: Principal Singular Values and Singular Vectors Adaptation of Large Language Models](https://arxiv.org/abs/2404.02948). Both reduce run-time VRAM consumption and training time at the cost of some overall model performance, which makes them suitable for training LLMs with constrained resources.
|
||||||
|
|
||||||
|
To use LoRA/PiSSA in training, create a config file like the following example and set `--lora_config` to the path of that file.
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"r": 128,
|
||||||
|
"embedding_lora_dropout": 0.0,
|
||||||
|
"linear_lora_dropout": 0.1,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_train_bias": "all",
|
||||||
|
"lora_initialization_method": "PiSSA",
|
||||||
|
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
|
||||||
|
}
|
||||||
|
```
|
||||||
|
#### Lora Parameters
|
||||||
|
- r: lora rank
|
||||||
|
- embedding_lora_dropout: dropout probability for embedding layer
|
||||||
|
- linear_lora_dropout: dropout probability for linear layer
|
||||||
|
- lora_alpha: lora alpha, controls how much the adaptor can deviate from the pretrained model.
|
||||||
|
- lora_train_bias: whether to add trainable biases to LoRA layers. Choose from "all" (all layers, including but not limited to LoRA layers, will have trainable biases), "none" (no trainable biases), or "lora" (only LoRA layers will have trainable biases).
|
||||||
|
- lora_initialization_method: how to initialize the LoRA weights. Choose one from ["kaiming_uniform", "PiSSA"]; defaults to "kaiming_uniform". Use "kaiming_uniform" for standard LoRA and "PiSSA" for PiSSA.
|
||||||
|
- target_modules: which module(s) should be converted to LoRA layers. If a module's name matches one of the keywords in `target_modules` and the module is a linear or embedding layer, the module will be converted; otherwise, the module will be frozen. Setting this field to None will automatically convert all linear and embedding layers to their LoRA counterparts. Note that this example only works for LLaMA; for other models, you need to modify it.
|
||||||
|
|
||||||
|
|
||||||
|
```
|
||||||
|
colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile train_sft.py \
|
||||||
|
--pretrain $PRETRAINED_MODEL_PATH \
|
||||||
|
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||||
|
--dataset ${dataset[@]} \
|
||||||
|
--save_interval 5000 \
|
||||||
|
--save_path $SAVE_DIR \
|
||||||
|
--config_file $CONFIG_FILE \
|
||||||
|
--plugin zero2_cpu \
|
||||||
|
--batch_size 4 \
|
||||||
|
--max_epochs 1 \
|
||||||
|
--accumulation_steps 4 \
|
||||||
|
--lr 2e-5 \
|
||||||
|
--max_len 2048 \
|
||||||
|
--lora_config /PATH/TO/THE/LORA/CONFIG/FILE.json \ # Setting this enables LoRA
|
||||||
|
--use_wandb
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
{
|
||||||
|
"r": 128,
|
||||||
|
"embedding_lora_dropout": 0.0,
|
||||||
|
"linear_lora_dropout": 0.1,
|
||||||
|
"lora_alpha": 32,
|
||||||
|
"lora_train_bias": "all",
|
||||||
|
"lora_initialization_method": "PiSSA",
|
||||||
|
"target_modules": ["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj", "embed_tokens"]
|
||||||
|
}
|
|
@ -6,7 +6,7 @@ from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||||
from coati.models import convert_to_lora_module, disable_dropout
|
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||||
from coati.trainer import DPOTrainer
|
from coati.trainer import DPOTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -23,8 +23,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -115,7 +118,7 @@ def train(args):
|
||||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
disable_dropout(model)
|
|
||||||
if not args.disable_reference_model:
|
if not args.disable_reference_model:
|
||||||
if args.use_flash_attn:
|
if args.use_flash_attn:
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||||
|
@ -125,15 +128,19 @@ def train(args):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
disable_dropout(ref_model)
|
|
||||||
else:
|
else:
|
||||||
ref_model = None
|
ref_model = None
|
||||||
if args.lora_rank > 0:
|
if args.lora_config is not None:
|
||||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
disable_dropout(model)
|
||||||
|
disable_dropout(ref_model)
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
|
@ -280,11 +287,8 @@ def train(args):
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
|
||||||
model.eval()
|
model.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
if args.save_dir is not None:
|
if args.save_dir is not None:
|
||||||
|
@ -343,15 +347,8 @@ if __name__ == "__main__":
|
||||||
help="Disable the reference model (enabled by default)",
|
help="Disable the reference model (enabled by default)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument(
|
|
||||||
"--lora_train_bias",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
parser.add_argument("--log_dir", default=None, type=str)
|
parser.add_argument("--log_dir", default=None, type=str)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
|
from coati.dataset import DataCollatorForKTODataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||||
from coati.models import convert_to_lora_module, disable_dropout
|
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||||
from coati.trainer import KTOTrainer
|
from coati.trainer import KTOTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -23,8 +23,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -115,7 +118,7 @@ def train(args):
|
||||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
disable_dropout(model)
|
|
||||||
if args.use_flash_attn:
|
if args.use_flash_attn:
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(
|
ref_model = AutoModelForCausalLM.from_pretrained(
|
||||||
args.pretrain,
|
args.pretrain,
|
||||||
|
@ -124,13 +127,17 @@ def train(args):
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
|
if args.lora_config is not None:
|
||||||
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
disable_dropout(ref_model)
|
disable_dropout(ref_model)
|
||||||
if args.lora_rank > 0:
|
disable_dropout(model)
|
||||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
|
@ -299,11 +306,8 @@ def train(args):
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
|
||||||
model.eval()
|
model.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
if args.save_dir is not None:
|
if args.save_dir is not None:
|
||||||
|
@ -355,15 +359,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--batch_size", type=int, default=4)
|
parser.add_argument("--batch_size", type=int, default=4)
|
||||||
|
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument(
|
|
||||||
"--lora_train_bias",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--auto_weight", default=False, action="store_true")
|
parser.add_argument("--auto_weight", default=False, action="store_true")
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
|
|
|
@ -6,7 +6,7 @@ from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||||
from coati.models import convert_to_lora_module, disable_dropout
|
from coati.models import LoraConfig, convert_to_lora_module, disable_dropout
|
||||||
from coati.trainer import ORPOTrainer
|
from coati.trainer import ORPOTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -23,8 +23,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -114,13 +117,16 @@ def train(args):
|
||||||
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
coordinator.print_on_master(msg="Flash-attention enabled successfully")
|
||||||
else:
|
else:
|
||||||
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
model = AutoModelForCausalLM.from_pretrained(args.pretrain)
|
||||||
|
if args.lora_config is not None:
|
||||||
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
disable_dropout(model)
|
disable_dropout(model)
|
||||||
if args.lora_rank > 0:
|
|
||||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
|
@ -262,11 +268,8 @@ def train(args):
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
|
||||||
model.eval()
|
model.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
if args.save_dir is not None:
|
if args.save_dir is not None:
|
||||||
|
@ -322,15 +325,8 @@ if __name__ == "__main__":
|
||||||
help="Disable the reference model (enabled by default)",
|
help="Disable the reference model (enabled by default)",
|
||||||
)
|
)
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument(
|
|
||||||
"--lora_train_bias",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
parser.add_argument("--log_dir", default=None, type=str)
|
parser.add_argument("--log_dir", default=None, type=str)
|
||||||
|
|
|
@ -13,7 +13,7 @@ from coati.dataset import (
|
||||||
load_tokenized_dataset,
|
load_tokenized_dataset,
|
||||||
setup_conversation_template,
|
setup_conversation_template,
|
||||||
)
|
)
|
||||||
from coati.models import Critic, RewardModel, convert_to_lora_module, disable_dropout
|
from coati.models import Critic, LoraConfig, RewardModel, convert_to_lora_module, disable_dropout, lora_manager
|
||||||
from coati.trainer import PPOTrainer
|
from coati.trainer import PPOTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -31,8 +31,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -81,20 +84,26 @@ def train(args):
|
||||||
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
|
ref_model = AutoModelForCausalLM.from_pretrained(args.pretrain, local_files_only=True)
|
||||||
reward_model = RewardModel(args.rm_pretrain)
|
reward_model = RewardModel(args.rm_pretrain)
|
||||||
critic = Critic(args.rm_pretrain)
|
critic = Critic(args.rm_pretrain)
|
||||||
|
|
||||||
|
if args.lora_config is not None:
|
||||||
|
actor = convert_to_lora_module(actor, lora_config=lora_config)
|
||||||
|
critic = convert_to_lora_module(critic, lora_config=lora_config)
|
||||||
|
for name, module in actor.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
for name, module in critic.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
lora_manager.able_to_merge = False
|
||||||
|
|
||||||
# Disable dropout
|
# Disable dropout
|
||||||
disable_dropout(actor)
|
disable_dropout(actor)
|
||||||
disable_dropout(critic)
|
disable_dropout(critic)
|
||||||
|
|
||||||
if args.lora_rank > 0:
|
if args.grad_checkpoint:
|
||||||
actor = convert_to_lora_module(actor, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
actor.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
critic = convert_to_lora_module(critic, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
critic.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
|
|
||||||
if args.grad_checkpoint and args.lora_rank == 0:
|
|
||||||
actor.gradient_checkpointing_enable()
|
|
||||||
critic.model.gradient_checkpointing_enable()
|
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
elif args.lora_rank > 0:
|
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||||
|
@ -421,11 +430,9 @@ def train(args):
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
lora_manager.able_to_merge = True
|
||||||
actor.eval()
|
actor.eval()
|
||||||
critic.eval()
|
critic.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
|
@ -484,11 +491,9 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--train_batch_size", type=int, default=16)
|
parser.add_argument("--train_batch_size", type=int, default=16)
|
||||||
parser.add_argument("--experience_batch_size", type=int, default=16)
|
parser.add_argument("--experience_batch_size", type=int, default=16)
|
||||||
parser.add_argument("--ptx_batch_size", type=int, default=4)
|
parser.add_argument("--ptx_batch_size", type=int, default=4)
|
||||||
parser.add_argument("--lora_train_bias", type=str, default="none")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--lr", type=float, default=9e-6)
|
parser.add_argument("--lr", type=float, default=9e-6)
|
||||||
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
parser.add_argument("--critic_lr", type=float, default=9e-6)
|
||||||
parser.add_argument("--kl_coef", type=float, default=0.1)
|
parser.add_argument("--kl_coef", type=float, default=0.1)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
from coati.dataset import DataCollatorForPreferenceDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||||
from coati.models import LogExpLoss, LogSigLoss, RewardModel, convert_to_lora_module
|
from coati.models import LogExpLoss, LogSigLoss, LoraConfig, RewardModel, convert_to_lora_module
|
||||||
from coati.trainer import RewardModelTrainer
|
from coati.trainer import RewardModelTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoTokenizer
|
from transformers import AutoTokenizer
|
||||||
|
@ -25,8 +25,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -58,9 +61,11 @@ def train(args):
|
||||||
args.pretrain,
|
args.pretrain,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0:
|
if lora_config is not None:
|
||||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
# ==============================
|
# ==============================
|
||||||
# Initialize Booster
|
# Initialize Booster
|
||||||
# ==============================
|
# ==============================
|
||||||
|
@ -122,11 +127,9 @@ def train(args):
|
||||||
|
|
||||||
booster = Booster(plugin=plugin)
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
if args.grad_checkpoint and args.lora_rank == 0:
|
if args.grad_checkpoint:
|
||||||
model.model.gradient_checkpointing_enable() # TODO: support gradient checkpoint for the last linear layer
|
model.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
elif args.lora_rank > 0:
|
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing will be disabled when LoRA is enabled")
|
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
tokenizer_dir = args.tokenizer_dir if args.tokenizer_dir is not None else args.pretrain
|
||||||
|
@ -272,16 +275,13 @@ def train(args):
|
||||||
|
|
||||||
trainer.fit(
|
trainer.fit(
|
||||||
train_preference_dataloader=train_dataloader,
|
train_preference_dataloader=train_dataloader,
|
||||||
eval_preference_dataloader=None,
|
eval_preference_dataloader=eval_dataloader,
|
||||||
log_dir=args.log_dir,
|
log_dir=args.log_dir,
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
|
||||||
model.eval()
|
model.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
if args.save_dir is not None:
|
if args.save_dir is not None:
|
||||||
|
@ -330,15 +330,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--batch_size", type=int, default=4)
|
parser.add_argument("--batch_size", type=int, default=4)
|
||||||
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="fp16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
|
parser.add_argument("--loss_fn", type=str, default="log_sig", choices=["log_sig", "log_exp"], help="Loss function")
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument(
|
|
||||||
"--lora_train_bias",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
parser.add_argument("--log_dir", default=None, type=str)
|
parser.add_argument("--log_dir", default=None, type=str)
|
||||||
|
|
|
@ -7,7 +7,7 @@ from contextlib import nullcontext
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
|
from coati.dataset import DataCollatorForSupervisedDataset, StatefulDistributedSampler, load_tokenized_dataset
|
||||||
from coati.models import convert_to_lora_module
|
from coati.models import LoraConfig, convert_to_lora_module
|
||||||
from coati.trainer import SFTTrainer
|
from coati.trainer import SFTTrainer
|
||||||
from coati.utils import load_checkpoint
|
from coati.utils import load_checkpoint
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer
|
from transformers import AutoModelForCausalLM, AutoTokenizer
|
||||||
|
@ -24,8 +24,11 @@ logger = get_dist_logger()
|
||||||
|
|
||||||
|
|
||||||
def train(args):
|
def train(args):
|
||||||
|
lora_config = None
|
||||||
|
if args.lora_config is not None:
|
||||||
|
lora_config = LoraConfig.from_file(args.lora_config)
|
||||||
# check lora compatibility
|
# check lora compatibility
|
||||||
if "gemini" in args.plugin and args.lora_rank > 0:
|
if "gemini" in args.plugin and lora_config is not None and lora_config.r > 0:
|
||||||
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("LoRA is not supported in GeminiPlugin. Please use other plugin")
|
||||||
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
if args.plugin == "gemini_auto" and args.accumulation_steps > 1:
|
||||||
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
raise ValueError("Gradient accumulation is not supported in GeminiPlugin. Please use other plugin")
|
||||||
|
@ -53,8 +56,12 @@ def train(args):
|
||||||
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
torch_dtype=torch.bfloat16 if args.mixed_precision == "bf16" else torch.float16,
|
||||||
trust_remote_code=True,
|
trust_remote_code=True,
|
||||||
)
|
)
|
||||||
if args.lora_rank > 0:
|
|
||||||
model = convert_to_lora_module(model, args.lora_rank, lora_train_bias=args.lora_train_bias)
|
if lora_config is not None:
|
||||||
|
model = convert_to_lora_module(model, lora_config=lora_config)
|
||||||
|
for name, module in model.named_modules():
|
||||||
|
if "norm" in name or "gate" in name:
|
||||||
|
module = module.to(torch.float32)
|
||||||
|
|
||||||
if args.plugin == "ddp":
|
if args.plugin == "ddp":
|
||||||
"""
|
"""
|
||||||
|
@ -114,6 +121,15 @@ def train(args):
|
||||||
|
|
||||||
booster = Booster(plugin=plugin)
|
booster = Booster(plugin=plugin)
|
||||||
|
|
||||||
|
# configure optimizer
|
||||||
|
optim = HybridAdam(
|
||||||
|
model_params=model.parameters(),
|
||||||
|
lr=args.lr,
|
||||||
|
betas=(0.9, 0.95),
|
||||||
|
weight_decay=args.weight_decay,
|
||||||
|
adamw_mode=True,
|
||||||
|
)
|
||||||
|
|
||||||
# ======================================================
|
# ======================================================
|
||||||
# Initialize Model, Objective, Optimizer and LR Scheduler
|
# Initialize Model, Objective, Optimizer and LR Scheduler
|
||||||
# ======================================================
|
# ======================================================
|
||||||
|
@ -124,7 +140,7 @@ def train(args):
|
||||||
|
|
||||||
if args.grad_checkpoint:
|
if args.grad_checkpoint:
|
||||||
# Note, for some models, lora may not be compatible with gradient checkpointing
|
# Note, for some models, lora may not be compatible with gradient checkpointing
|
||||||
model.gradient_checkpointing_enable()
|
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
|
||||||
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
|
||||||
|
|
||||||
# configure tokenizer
|
# configure tokenizer
|
||||||
|
@ -149,15 +165,6 @@ def train(args):
|
||||||
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
coordinator.print_on_master(f"Configuration file will be saved at: {args.config_file}")
|
||||||
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
|
coordinator.print_on_master(f"Model checkpoint will be saved at: {args.save_path}")
|
||||||
|
|
||||||
# configure optimizer
|
|
||||||
optim = HybridAdam(
|
|
||||||
model_params=model.parameters(),
|
|
||||||
lr=args.lr,
|
|
||||||
betas=(0.9, 0.95),
|
|
||||||
weight_decay=args.weight_decay,
|
|
||||||
adamw_mode=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
# configure dataset
|
# configure dataset
|
||||||
coordinator.print_on_master(
|
coordinator.print_on_master(
|
||||||
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
f"Max CUDA memory before data loader: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB"
|
||||||
|
@ -217,6 +224,7 @@ def train(args):
|
||||||
lr_scheduler=lr_scheduler,
|
lr_scheduler=lr_scheduler,
|
||||||
dataloader=train_dataloader,
|
dataloader=train_dataloader,
|
||||||
)
|
)
|
||||||
|
|
||||||
torch.set_default_dtype(torch.float)
|
torch.set_default_dtype(torch.float)
|
||||||
|
|
||||||
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
coordinator.print_on_master(f"Booster init max CUDA memory: {torch.cuda.max_memory_allocated() / 1024 ** 2:.2f} MB")
|
||||||
|
@ -277,11 +285,8 @@ def train(args):
|
||||||
use_wandb=args.use_wandb,
|
use_wandb=args.use_wandb,
|
||||||
)
|
)
|
||||||
|
|
||||||
if args.lora_rank > 0 and args.merge_lora_weights:
|
if lora_config is not None and lora_config.r > 0:
|
||||||
from coati.models.lora import LORA_MANAGER
|
|
||||||
|
|
||||||
# NOTE: set model to eval to merge LoRA weights
|
# NOTE: set model to eval to merge LoRA weights
|
||||||
LORA_MANAGER.merge_weights = True
|
|
||||||
model.eval()
|
model.eval()
|
||||||
# save model checkpoint after fitting on only rank0
|
# save model checkpoint after fitting on only rank0
|
||||||
if args.save_path is not None:
|
if args.save_path is not None:
|
||||||
|
@ -328,15 +333,8 @@ if __name__ == "__main__":
|
||||||
parser.add_argument("--batch_size", type=int, default=4)
|
parser.add_argument("--batch_size", type=int, default=4)
|
||||||
parser.add_argument("--max_len", type=int, default=512)
|
parser.add_argument("--max_len", type=int, default=512)
|
||||||
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
|
parser.add_argument("--mixed_precision", type=str, default="bf16", choices=["fp16", "bf16"], help="Mixed precision")
|
||||||
parser.add_argument("--lora_rank", type=int, default=0, help="low-rank adaptation matrices rank")
|
parser.add_argument("--lora_config", type=str, default=None, help="low-rank adaptation config file path")
|
||||||
parser.add_argument(
|
|
||||||
"--lora_train_bias",
|
|
||||||
type=str,
|
|
||||||
default="none",
|
|
||||||
help="'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers",
|
|
||||||
)
|
|
||||||
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
parser.add_argument("--save_interval", type=int, default=1000, help="number of step between two checkpoints")
|
||||||
parser.add_argument("--merge_lora_weights", type=bool, default=True)
|
|
||||||
parser.add_argument("--lr", type=float, default=5e-6)
|
parser.add_argument("--lr", type=float, default=5e-6)
|
||||||
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
parser.add_argument("--config_file", type=str, default=None, help="Config file")
|
||||||
parser.add_argument("--accumulation_steps", type=int, default=8)
|
parser.add_argument("--accumulation_steps", type=int, default=8)
|
||||||
|
|
|
@ -21,16 +21,16 @@ PARENT_LOG_DIR="" # Path to a folder to save training config logs
|
||||||
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
PRETRAINED_MODEL_PATH="" # huggingface or local model path
|
||||||
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
PRETRAINED_TOKENIZER_PATH="" # huggingface or local tokenizer path
|
||||||
declare -a dataset=(
|
declare -a dataset=(
|
||||||
/Your/SFT/Data/arrow/part-00000
|
YOUR/SFT/DATA/DIR/arrow/part-00000
|
||||||
/Your/SFT/Data/arrow/part-00001
|
YOUR/SFT/DATA/DIR/arrow/part-00001
|
||||||
/Your/SFT/Data/arrow/part-00002
|
YOUR/SFT/DATA/DIR/arrow/part-00002
|
||||||
/Your/SFT/Data/arrow/part-00003
|
YOUR/SFT/DATA/DIR/arrow/part-00003
|
||||||
/Your/SFT/Data/arrow/part-00004
|
YOUR/SFT/DATA/DIR/arrow/part-00004
|
||||||
/Your/SFT/Data/arrow/part-00005
|
YOUR/SFT/DATA/DIR/arrow/part-00005
|
||||||
/Your/SFT/Data/arrow/part-00006
|
YOUR/SFT/DATA/DIR/arrow/part-00006
|
||||||
/Your/SFT/Data/arrow/part-00007
|
YOUR/SFT/DATA/DIR/arrow/part-00007
|
||||||
/Your/SFT/Data/arrow/part-00008
|
YOUR/SFT/DATA/DIR/arrow/part-00008
|
||||||
/Your/SFT/Data/arrow/part-00009
|
YOUR/SFT/DATA/DIR/arrow/part-00009
|
||||||
)
|
)
|
||||||
|
|
||||||
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
TIMESTAMP=$(date +%Y-%m-%d-%H-%M-%S)
|
||||||
|
@ -47,15 +47,14 @@ colossalai run --nproc_per_node 4 --master_port 31312 --hostfile ./hostfile trai
|
||||||
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
--tokenizer_dir $PRETRAINED_TOKENIZER_PATH \
|
||||||
--save_interval 2000 \
|
--save_interval 2000 \
|
||||||
--dataset ${dataset[@]} \
|
--dataset ${dataset[@]} \
|
||||||
--save_path $SAVE_DIR \
|
|
||||||
--config_file $CONFIG_FILE \
|
|
||||||
--log_dir $LOG_DIR \
|
|
||||||
--lora_rank 0 \
|
|
||||||
--plugin zero2 \
|
--plugin zero2 \
|
||||||
--batch_size 8 \
|
--batch_size 8 \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
--accumulation_steps 2 \
|
--accumulation_steps 1 \
|
||||||
--lr 5e-5 \
|
--lr 5e-5 \
|
||||||
--max_len 4096 \
|
--max_len 4096 \
|
||||||
|
--use_flash_attn \
|
||||||
--grad_checkpoint \
|
--grad_checkpoint \
|
||||||
--use_flash_attn
|
--save_path $SAVE_DIR \
|
||||||
|
--config_file $CONFIG_FILE \
|
||||||
|
--log_dir $LOG_DIR \
|
||||||
|
|
|
@ -2,6 +2,7 @@ import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.optim as optim
|
import torch.optim as optim
|
||||||
from coati.models import convert_to_lora_module
|
from coati.models import convert_to_lora_module
|
||||||
|
from coati.models.lora import LoraConfig, LoraEmbedding, LoraLinear
|
||||||
from torch.utils.data import DataLoader, TensorDataset
|
from torch.utils.data import DataLoader, TensorDataset
|
||||||
|
|
||||||
|
|
||||||
|
@ -38,7 +39,7 @@ def test_overfit():
|
||||||
# Build and convert model
|
# Build and convert model
|
||||||
model = SimpleNN(input_size, hidden_size, num_classes)
|
model = SimpleNN(input_size, hidden_size, num_classes)
|
||||||
weight_to_compare = model.fc1.weight.detach().clone()
|
weight_to_compare = model.fc1.weight.detach().clone()
|
||||||
model = convert_to_lora_module(model, lora_rank=30)
|
model = convert_to_lora_module(model, lora_config=LoraConfig(r=32))
|
||||||
|
|
||||||
# Loss and optimizer
|
# Loss and optimizer
|
||||||
criterion = nn.CrossEntropyLoss()
|
criterion = nn.CrossEntropyLoss()
|
||||||
|
@ -50,7 +51,6 @@ def test_overfit():
|
||||||
# Forward pass
|
# Forward pass
|
||||||
outputs = model(inputs)
|
outputs = model(inputs)
|
||||||
loss = criterion(outputs, labels)
|
loss = criterion(outputs, labels)
|
||||||
print(loss)
|
|
||||||
# Backward and optimize
|
# Backward and optimize
|
||||||
optimizer.zero_grad()
|
optimizer.zero_grad()
|
||||||
loss.backward()
|
loss.backward()
|
||||||
|
@ -65,5 +65,50 @@ def test_overfit():
|
||||||
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
assert (weight_to_compare - model.fc1.weight).sum() < 0.01
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_linear_accuracy():
|
||||||
|
|
||||||
|
weight = torch.randn(10, 5)
|
||||||
|
linear = nn.Linear(5, 10)
|
||||||
|
linear.weight.data = weight
|
||||||
|
x = torch.randn(10, 5)
|
||||||
|
out_linear = linear(x)
|
||||||
|
|
||||||
|
# lora linear Pissa
|
||||||
|
linear.weight.data = weight
|
||||||
|
lora_linear = LoraLinear(linear.weight, linear.bias, r=2, lora_initialization_method="PiSSA")
|
||||||
|
out_lora = lora_linear(x)
|
||||||
|
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||||
|
|
||||||
|
# lora linear
|
||||||
|
linear.weight.data = weight
|
||||||
|
lora_linear = LoraLinear(linear.weight, linear.bias, r=2)
|
||||||
|
out_lora = lora_linear(x)
|
||||||
|
assert torch.allclose(out_linear, out_lora, atol=1e-5, rtol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
|
def test_lora_embedding_accuracy():
|
||||||
|
weight = torch.randn(10, 5)
|
||||||
|
embedding = nn.Embedding(10, 5)
|
||||||
|
embedding.weight.data = weight
|
||||||
|
x = torch.randint(0, 10, (10,))
|
||||||
|
out_embedding = embedding(x)
|
||||||
|
|
||||||
|
# lora embedding Pissa
|
||||||
|
embedding.weight.data = weight
|
||||||
|
lora_embedding = LoraEmbedding(
|
||||||
|
embedding.weight, r=2, lora_initialization_method="PiSSA", num_embeddings=10, embedding_dim=5
|
||||||
|
)
|
||||||
|
out_lora = lora_embedding(x)
|
||||||
|
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||||
|
|
||||||
|
# lora embedding
|
||||||
|
embedding.weight.data = weight
|
||||||
|
lora_embedding = LoraEmbedding(embedding.weight, r=2, num_embeddings=10, embedding_dim=5)
|
||||||
|
out_lora = lora_embedding(x)
|
||||||
|
assert torch.allclose(out_embedding, out_lora, atol=1e-5, rtol=1e-05)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test_overfit()
|
test_overfit()
|
||||||
|
test_lora_linear_accuracy()
|
||||||
|
test_lora_embedding_accuracy()
|
||||||
|
|
|
@ -30,9 +30,10 @@ MODEL_SAVE_PATH=$TEMP_DIR/rlhf_models
|
||||||
MODELS_DIR=$TEMP_DIR/models_config
|
MODELS_DIR=$TEMP_DIR/models_config
|
||||||
# Skip those tests due to CI tests timeout
|
# Skip those tests due to CI tests timeout
|
||||||
MODELS=('llama')
|
MODELS=('llama')
|
||||||
ADVANCED_PLUGINS=('sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu') # pp is still buggy
|
ADVANCED_PLUGINS=('zero2' 'sp_split_gather' 'sp_ring' 'sp_all_to_all' 'tp_zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu') # pp is still buggy
|
||||||
PLUGINS=('3d' 'gemini' 'gemini_auto' 'zero2' 'zero2_cpu')
|
PLUGINS=('zero2' '3d' 'gemini' 'gemini_auto' 'zero2_cpu')
|
||||||
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
LORA_RANK=('0') # skip to reduce CI execution time, can pass all locally
|
||||||
|
LORA_CONFIG_ENABLE="--lora_config $BASE_DIR/examples/training_scripts/lora_config.json"
|
||||||
|
|
||||||
export OMP_NUM_THREADS=8
|
export OMP_NUM_THREADS=8
|
||||||
|
|
||||||
|
@ -112,6 +113,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
sp='1'
|
sp='1'
|
||||||
sp_mode='split_gather'
|
sp_mode='split_gather'
|
||||||
enable_sequence_parallelism=''
|
enable_sequence_parallelism=''
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
if [[ $plugin == "3d" ]]; then
|
if [[ $plugin == "3d" ]]; then
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='8'
|
bs='8'
|
||||||
|
@ -176,7 +182,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--eval_dataset ${dataset[@]} \
|
--eval_dataset ${dataset[@]} \
|
||||||
--save_path $MODEL_SAVE_PATH \
|
--save_path $MODEL_SAVE_PATH \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--batch_size $bs \
|
--batch_size $bs \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
|
@ -230,6 +236,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
grad_ckpt=$(random_choice "${GRAD_CKPTS[@]}")
|
||||||
tp='1'
|
tp='1'
|
||||||
bs='2'
|
bs='2'
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
if [[ $plugin == "3d" ]]; then
|
if [[ $plugin == "3d" ]]; then
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='8'
|
bs='8'
|
||||||
|
@ -252,7 +263,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--eval_dataset ${dataset[@]} \
|
--eval_dataset ${dataset[@]} \
|
||||||
--save_dir $MODEL_SAVE_PATH \
|
--save_dir $MODEL_SAVE_PATH \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--batch_size $bs \
|
--batch_size $bs \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
|
@ -308,6 +319,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
bs='4'
|
bs='4'
|
||||||
ebs='8'
|
ebs='8'
|
||||||
conversation_template=$(get_conversation_template_config $model)
|
conversation_template=$(get_conversation_template_config $model)
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
if [[ $plugin == "3d" ]]; then
|
if [[ $plugin == "3d" ]]; then
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='16'
|
bs='16'
|
||||||
|
@ -344,7 +360,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--ptx_batch_size 1 \
|
--ptx_batch_size 1 \
|
||||||
--ptx_coef 0.2 \
|
--ptx_coef 0.2 \
|
||||||
--save_path $MODEL_SAVE_PATH \
|
--save_path $MODEL_SAVE_PATH \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--num_episodes 5 \
|
--num_episodes 5 \
|
||||||
--num_collect_steps 1 \
|
--num_collect_steps 1 \
|
||||||
|
@ -404,6 +420,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='8'
|
bs='8'
|
||||||
fi
|
fi
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
grad_accu='2'
|
grad_accu='2'
|
||||||
# gemini_auto and gemini doesn't support gradient accumulation
|
# gemini_auto and gemini doesn't support gradient accumulation
|
||||||
if [[ $plugin == "gemini_auto" ]]; then
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
@ -428,7 +449,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--eval_dataset ${dataset[@]} \
|
--eval_dataset ${dataset[@]} \
|
||||||
--save_dir $MODEL_SAVE_PATH \
|
--save_dir $MODEL_SAVE_PATH \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--batch_size $bs \
|
--batch_size $bs \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
|
@ -482,6 +503,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='8'
|
bs='8'
|
||||||
fi
|
fi
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
grad_accu='2'
|
grad_accu='2'
|
||||||
# gemini_auto and gemini doesn't support gradient accumulation
|
# gemini_auto and gemini doesn't support gradient accumulation
|
||||||
if [[ $plugin == "gemini_auto" ]]; then
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
@ -506,7 +532,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--eval_dataset ${dataset[@]} \
|
--eval_dataset ${dataset[@]} \
|
||||||
--save_dir $MODEL_SAVE_PATH \
|
--save_dir $MODEL_SAVE_PATH \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--batch_size $bs \
|
--batch_size $bs \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
|
@ -560,6 +586,11 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
tp='4'
|
tp='4'
|
||||||
bs='8'
|
bs='8'
|
||||||
fi
|
fi
|
||||||
|
if [[ $plugin == "zero2" ]]; then
|
||||||
|
lora_config=$LORA_CONFIG_ENABLE
|
||||||
|
else
|
||||||
|
lora_config=""
|
||||||
|
fi
|
||||||
grad_accu='2'
|
grad_accu='2'
|
||||||
# gemini_auto and gemini doesn't support gradient accumulation
|
# gemini_auto and gemini doesn't support gradient accumulation
|
||||||
if [[ $plugin == "gemini_auto" ]]; then
|
if [[ $plugin == "gemini_auto" ]]; then
|
||||||
|
@ -584,7 +615,7 @@ for lora_rank in ${LORA_RANK[@]}; do
|
||||||
--eval_dataset ${dataset[@]} \
|
--eval_dataset ${dataset[@]} \
|
||||||
--save_dir $MODEL_SAVE_PATH \
|
--save_dir $MODEL_SAVE_PATH \
|
||||||
--config_file $MODELS_DIR/config.jsonl \
|
--config_file $MODELS_DIR/config.jsonl \
|
||||||
--lora_rank $lora_rank \
|
$lora_config \
|
||||||
--plugin $plugin \
|
--plugin $plugin \
|
||||||
--batch_size $bs \
|
--batch_size $bs \
|
||||||
--max_epochs 1 \
|
--max_epochs 1 \
|
||||||
|
|
Loading…
Reference in New Issue