mirror of https://github.com/hpcaitech/ColossalAI
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement * add liscense and implement apis * add checkpointio apis * add torchddp fwd_bwd test * add support_lora methods * add checkpointio test and debug * delete unneeded codes * remove peft from LICENSE * add concrete methods for enable_lora * simplify enable_lora api * fix requirementspull/5670/head
parent
c1594e4bad
commit
14b0d4c7e5
|
@ -8,6 +8,14 @@ from torch.optim import Optimizer
|
|||
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
SUPPORT_PEFT = False
|
||||
try:
|
||||
import peft
|
||||
|
||||
SUPPORT_PEFT = True
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import colossalai.interface.pretrained as pretrained_utils
|
||||
from colossalai.checkpoint_io import GeneralCheckpointIO
|
||||
from colossalai.interface import ModelWrapper, OptimizerWrapper
|
||||
|
@ -221,6 +229,38 @@ class Booster:
|
|||
assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
|
||||
return self.plugin.no_sync(model, optimizer)
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
|
||||
) -> nn.Module:
|
||||
"""
|
||||
Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
|
||||
Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
|
||||
|
||||
Args:
|
||||
model (nn.Module): The model to be appended with LoRA modules.
|
||||
pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
|
||||
or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
|
||||
When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
|
||||
lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
|
||||
"""
|
||||
if not SUPPORT_PEFT:
|
||||
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
|
||||
|
||||
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
|
||||
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
|
||||
if pretrained_dir is None:
|
||||
assert (
|
||||
lora_config is not None
|
||||
), "Please provide configuration for Lora when pretrained directory path isn't passed in."
|
||||
assert isinstance(
|
||||
lora_config, peft.LoraConfig
|
||||
), "The passed in configuration should be an instance of peft.LoraConfig."
|
||||
if lora_config is None:
|
||||
assert (
|
||||
pretrained_dir is not None
|
||||
), "Please provide pretrained directory path if not passing in lora configuration."
|
||||
return self.plugin.enable_lora(model, pretrained_dir, lora_config)
|
||||
|
||||
def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
|
||||
"""Load model from checkpoint.
|
||||
|
||||
|
@ -323,3 +363,20 @@ class Booster:
|
|||
checkpoint (str): Path to the checkpoint. It must be a local file path.
|
||||
"""
|
||||
self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
|
||||
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
||||
|
||||
Args:
|
||||
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
||||
"""
|
||||
if not SUPPORT_PEFT:
|
||||
raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
|
||||
assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
|
||||
assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
|
||||
self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
|
||||
|
|
|
@ -3,7 +3,7 @@ import logging
|
|||
import os
|
||||
import random
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return False
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
def control_precision(self) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase):
|
|||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
raise NotImplementedError
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -4,7 +4,7 @@ import warnings
|
|||
from contextlib import contextmanager
|
||||
from functools import partial
|
||||
from types import MethodType
|
||||
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
|
||||
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
self.zero_stage != 2
|
||||
), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
|
||||
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> Module:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -3,7 +3,7 @@ import os
|
|||
from functools import partial
|
||||
from pathlib import Path
|
||||
from types import MethodType
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -296,6 +296,9 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return self.stage == 1
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
def control_precision(self) -> bool:
|
||||
return True
|
||||
|
||||
|
@ -337,3 +340,8 @@ class LowLevelZeroPlugin(DPPluginBase):
|
|||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert isinstance(optimizer, LowLevelZeroOptimizer)
|
||||
return optimizer.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.optim import Optimizer
|
||||
|
@ -33,6 +33,10 @@ class Plugin(ABC):
|
|||
def support_no_sync(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def support_lora(self) -> bool:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def configure(
|
||||
self,
|
||||
|
@ -63,6 +67,12 @@ class Plugin(ABC):
|
|||
Context manager to disable gradient synchronization.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
|
||||
"""
|
||||
Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def prepare_dataloader(
|
||||
self,
|
||||
|
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Callable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
|
@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
|
|||
assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
|
||||
super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
|
||||
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to checkpoint directory.
|
||||
"""
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
if self.coordinator.is_master():
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
|
||||
|
||||
|
||||
class TorchDDPModel(ModelWrapper):
|
||||
def __init__(self, module: nn.Module, *args, **kwargs) -> None:
|
||||
|
@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return True
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return True
|
||||
|
||||
def control_precision(self) -> bool:
|
||||
return False
|
||||
|
||||
|
@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase):
|
|||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
|
||||
return model.module.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> nn.Module:
|
||||
from peft import PeftModel, get_peft_model
|
||||
|
||||
assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
|
||||
if pretrained_dir is None:
|
||||
return get_peft_model(model, lora_config)
|
||||
else:
|
||||
return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
|
||||
|
|
|
@ -2,7 +2,7 @@ import logging
|
|||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
|
||||
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
def support_no_sync(self) -> bool:
|
||||
return False
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
|
||||
def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
|
||||
raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
|
||||
|
||||
|
@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
|
|||
|
||||
def get_checkpoint_io(self) -> CheckpointIO:
|
||||
return TorchFSDPCheckpointIO()
|
||||
|
||||
def enable_lora(
|
||||
self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
) -> nn.Module:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -335,3 +335,20 @@ class CheckpointIO(ABC):
|
|||
"""
|
||||
state_dict = torch.load(checkpoint)
|
||||
lr_scheduler.load_state_dict(state_dict)
|
||||
|
||||
# ================================================================================
|
||||
# Abstract method for lora saving implementation.
|
||||
# ================================================================================
|
||||
|
||||
@abstractmethod
|
||||
def save_lora_as_pretrained(
|
||||
self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
|
||||
|
||||
Args:
|
||||
model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
|
||||
checkpoint (str): Path to the checkpoint directory. It must be a local path.
|
||||
use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
|
||||
"""
|
||||
|
|
|
@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
|
|||
self.__class__.__name__, "\n\t".join(error_msgs)
|
||||
)
|
||||
)
|
||||
|
||||
def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon
|
|||
torchvision
|
||||
timm
|
||||
titans
|
||||
torchaudio
|
||||
torchaudio>=0.13.1
|
||||
torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
|
||||
torchrec==0.2.0
|
||||
contexttimer
|
||||
|
@ -18,4 +18,5 @@ flash_attn
|
|||
datasets
|
||||
pydantic
|
||||
ray
|
||||
peft
|
||||
#auto-gptq now not support torch1.12
|
||||
|
|
|
@ -0,0 +1,108 @@
|
|||
import copy
|
||||
import os
|
||||
|
||||
import torch
|
||||
from peft import LoraConfig
|
||||
from torch import distributed as dist
|
||||
from torch.optim import AdamW
|
||||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import TorchDDPPlugin
|
||||
from colossalai.testing import (
|
||||
assert_equal,
|
||||
assert_not_equal,
|
||||
check_state_dict_equal,
|
||||
clear_cache_before_run,
|
||||
rerun_if_address_is_in_use,
|
||||
spawn,
|
||||
)
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_checkpoint_io.utils import shared_tempdir
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
|
||||
model = model_fn()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
plugin = TorchDDPPlugin()
|
||||
booster = Booster(plugin=plugin)
|
||||
|
||||
model = booster.enable_lora(model, lora_config=lora_config)
|
||||
model_copy = copy.deepcopy(model)
|
||||
|
||||
optimizer = AdamW(model.parameters(), lr=0.001)
|
||||
criterion = loss_fn
|
||||
|
||||
model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
|
||||
|
||||
data = data_gen_fn()
|
||||
data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
|
||||
|
||||
output = model(**data)
|
||||
output = output_transform_fn(output)
|
||||
loss = criterion(output)
|
||||
|
||||
booster.backward(loss, optimizer)
|
||||
optimizer.clip_grad_by_norm(1.0)
|
||||
optimizer.step()
|
||||
|
||||
for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
|
||||
if "lora_" in n1:
|
||||
# lora modules require gradients, thus updated
|
||||
assert p1.requires_grad
|
||||
assert_not_equal(p1.to(p2.device), p2)
|
||||
else:
|
||||
if not p1.requires_grad:
|
||||
assert_equal(p1.to(p2.device), p2)
|
||||
|
||||
|
||||
@clear_cache_before_run()
|
||||
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
|
||||
plugin = TorchDDPPlugin()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
model_save = model_fn()
|
||||
model_load = copy.deepcopy(model_save)
|
||||
|
||||
booster = Booster(plugin=plugin)
|
||||
model_save = booster.enable_lora(model_save, lora_config=lora_config)
|
||||
model_save, _, _, _, _ = booster.boost(model_save)
|
||||
|
||||
with shared_tempdir() as tempdir:
|
||||
lora_ckpt_path = os.path.join(tempdir, "ckpt")
|
||||
booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
|
||||
dist.barrier()
|
||||
|
||||
# The Lora checkpoint should be small in size
|
||||
checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
|
||||
assert checkpoint_size_mb < 1
|
||||
|
||||
model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
|
||||
model_load, _, _, _, _ = booster.boost(model_load)
|
||||
|
||||
check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
|
||||
|
||||
|
||||
def run_lora_test():
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
task_type = None
|
||||
if name == "transformers_llama_for_casual_lm":
|
||||
task_type = "CAUSAL_LM"
|
||||
if name == "transformers_llama_for_sequence_classification":
|
||||
task_type = "SEQ_CLS"
|
||||
check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
|
||||
check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port):
|
||||
config = {}
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
|
||||
run_lora_test()
|
||||
|
||||
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_torch_ddp_lora():
|
||||
spawn(run_dist, 2)
|
Loading…
Reference in New Issue