
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)

* add apis and peft requirement

* add license and implement apis

* add checkpointio apis

* add torchddp fwd_bwd test

* add support_lora methods

* add checkpointio test and debug

* delete unneeded code

* remove peft from LICENSE

* add concrete methods for enable_lora

* simplify enable_lora api

* fix requirements
pull/5670/head

Baizhou Zhang, 1 year ago (committed by Hongxin Liu)
commit 14b0d4c7e5
  1. colossalai/booster/booster.py (57)
  2. colossalai/booster/plugin/gemini_plugin.py (10)
  3. colossalai/booster/plugin/hybrid_parallel_plugin.py (10)
  4. colossalai/booster/plugin/low_level_zero_plugin.py (10)
  5. colossalai/booster/plugin/plugin_base.py (12)
  6. colossalai/booster/plugin/torch_ddp_plugin.py (32)
  7. colossalai/booster/plugin/torch_fsdp_plugin.py (10)
  8. colossalai/checkpoint_io/checkpoint_io_base.py (17)
  9. colossalai/checkpoint_io/general_checkpoint_io.py (3)
  10. requirements/requirements-test.txt (3)
  11. tests/test_lora/test_torch_ddp_lora.py (108)

colossalai/booster/booster.py (57)

@@ -8,6 +8,14 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
from torch.utils.data import DataLoader

SUPPORT_PEFT = False
try:
    import peft

    SUPPORT_PEFT = True
except ImportError:
    pass

import colossalai.interface.pretrained as pretrained_utils
from colossalai.checkpoint_io import GeneralCheckpointIO
from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ class Booster:
        assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
        return self.plugin.no_sync(model, optimizer)

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
    ) -> nn.Module:
        """
        Wrap the passed-in model with LoRA modules for training. If a pretrained directory is provided,
        the LoRA configs and weights are loaded from that directory.
        LoRA in ColossalAI is implemented with the Hugging Face peft library, so the LoRA configuration
        arguments are the same as those of peft.

        Args:
            model (nn.Module): The model to which LoRA modules are added.
            pretrained_dir (str, optional): Path to the pretrained directory; either a local directory
                or the model_id of a PEFT configuration hosted in a model repo on the Hugging Face Hub.
                When set to None, new LoRA configs and weights are created for the model from the
                passed-in lora_config. Defaults to None.
            lora_config (peft.LoraConfig, optional): The LoraConfig passed to peft. Defaults to None.
        """
        if not SUPPORT_PEFT:
            raise ImportError("Please install the Hugging Face peft library to enable LoRA features in ColossalAI!")

        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
        if pretrained_dir is None:
            assert (
                lora_config is not None
            ), "Please provide a configuration for LoRA when a pretrained directory path isn't passed in."
            assert isinstance(
                lora_config, peft.LoraConfig
            ), "The passed-in configuration should be an instance of peft.LoraConfig."
        if lora_config is None:
            assert (
                pretrained_dir is not None
            ), "Please provide a pretrained directory path if not passing in a LoRA configuration."
        return self.plugin.enable_lora(model, pretrained_dir, lora_config)

    def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
        """Load model from checkpoint.
@@ -323,3 +363,20 @@ class Booster:
            checkpoint (str): Path to the checkpoint. It must be a local file path.
        """
        self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)

    def save_lora_as_pretrained(
        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
    ) -> None:
        """
        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.

        Args:
            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint directory. It must be a local path.
            use_safetensors (bool, optional): Whether to save in the safetensors format. Defaults to False.
        """
        if not SUPPORT_PEFT:
            raise ImportError("Please install the Hugging Face peft library to enable LoRA features in ColossalAI!")

        assert self.plugin is not None, "LoRA can only be enabled when a plugin is provided."
        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support LoRA."
        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)
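
Taken together, the two new Booster methods give the following workflow: enable LoRA on the bare model, boost it, train, then save only the adapter weights. The sketch below is illustrative rather than part of the commit; it assumes a Hugging Face transformers base model ("gpt2" is a placeholder) and the TorchDDPPlugin, which is the only plugin given a concrete implementation in this change, and it mirrors the call order exercised by the new test file.

```python
# Illustrative sketch (not part of this commit). Assumes a transformers base model
# and TorchDDPPlugin; distributed init (colossalai.launch) is omitted for brevity.
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

booster = Booster(plugin=TorchDDPPlugin())

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

# enable_lora must be called before boost(); it returns a peft-wrapped model.
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = AdamW(model.parameters(), lr=1e-4)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop: booster.backward(loss, optimizer); optimizer.step() ...

# Persist only the adapter weights and adapter config to a directory.
booster.save_lora_as_pretrained(model, "ckpt/lora_adapter")

# Later: recreate the base model and reload the adapters via pretrained_dir.
new_model = AutoModelForCausalLM.from_pretrained("gpt2")
new_model = booster.enable_lora(new_model, pretrained_dir="ckpt/lora_adapter")
```

Because only the adapter tensors are written, the resulting checkpoint directory stays small; the new test asserts it is under 1 MB for the LLaMA test models.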

colossalai/booster/plugin/gemini_plugin.py (10)

@@ -3,7 +3,7 @@ import logging
import os
import random
from pathlib import Path
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import numpy as np
import torch

@@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
    def support_no_sync(self) -> bool:
        return False

    def support_lora(self) -> bool:
        return False

    def control_precision(self) -> bool:
        return True

@@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase):
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        raise NotImplementedError

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError

colossalai/booster/plugin/hybrid_parallel_plugin.py (10)

@@ -4,7 +4,7 @@ import warnings
from contextlib import contextmanager
from functools import partial
from types import MethodType
from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union

import numpy as np
import torch

@@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
    def support_no_sync(self) -> bool:
        return True

    def support_lora(self) -> bool:
        return False

    def control_checkpoint_io(self) -> bool:
        return True

@@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
            self.zero_stage != 2
        ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
        return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()

    def enable_lora(
        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> Module:
        raise NotImplementedError

colossalai/booster/plugin/low_level_zero_plugin.py (10)

@@ -3,7 +3,7 @@ import os
from functools import partial
from pathlib import Path
from types import MethodType
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn

@@ -296,6 +296,9 @@ class LowLevelZeroPlugin(DPPluginBase):
    def support_no_sync(self) -> bool:
        return self.stage == 1

    def support_lora(self) -> bool:
        return False

    def control_precision(self) -> bool:
        return True

@@ -337,3 +340,8 @@ class LowLevelZeroPlugin(DPPluginBase):
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(optimizer, LowLevelZeroOptimizer)
        return optimizer.no_sync()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError

colossalai/booster/plugin/plugin_base.py (12)

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple

import torch.nn as nn
from torch.optim import Optimizer

@@ -33,6 +33,10 @@ class Plugin(ABC):
    def support_no_sync(self) -> bool:
        pass

    @abstractmethod
    def support_lora(self) -> bool:
        pass

    @abstractmethod
    def configure(
        self,

@@ -63,6 +67,12 @@ class Plugin(ABC):
        Context manager to disable gradient synchronization.
        """

    @abstractmethod
    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
        """
        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
        """

    @abstractmethod
    def prepare_dataloader(
        self,
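
Since support_lora() and enable_lora() are now abstract, every plugin has to override them, as the per-plugin diffs above and below show. A minimal, hypothetical stub (not part of this commit) for a plugin without LoRA support would look like the sketch below; the class name is an assumption and the other abstract methods are omitted.

```python
# Hypothetical plugin stub (not in this commit) showing the new Plugin contract;
# the class name is an assumption and other abstract methods are omitted.
from typing import Dict, Optional

import torch.nn as nn

from colossalai.booster.plugin.plugin_base import Plugin


class MyPlugin(Plugin):
    # ... configure(), no_sync(), prepare_dataloader(), etc. omitted ...

    def support_lora(self) -> bool:
        # Advertise that this plugin cannot wrap models with LoRA.
        return False

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        # Booster.enable_lora() checks support_lora() first, so this path is not
        # reached in normal use, but the abstract method must still be overridden.
        raise NotImplementedError
```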

colossalai/booster/plugin/torch_ddp_plugin.py (32)

@@ -1,4 +1,4 @@
from typing import Callable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union

import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

@@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
        assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
        super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)

    def save_lora_as_pretrained(
        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
    ) -> None:
        """
        Save the LoRA adapters and adapter configuration file to the checkpoint directory.
        """
        from peft import PeftModel

        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
        if self.coordinator.is_master():
            peft_model = model.unwrap()
            assert isinstance(
                peft_model, PeftModel
            ), "The model doesn't have LoRA adapters; please enable LoRA before saving."
            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)


class TorchDDPModel(ModelWrapper):
    def __init__(self, module: nn.Module, *args, **kwargs) -> None:

@@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase):
    def support_no_sync(self) -> bool:
        return True

    def support_lora(self) -> bool:
        return True

    def control_precision(self) -> bool:
        return False

@@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase):
    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
        return model.module.no_sync()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        assert not isinstance(model, TorchDDPModel), "LoRA should be enabled before boosting the model."
        if pretrained_dir is None:
            return get_peft_model(model, lora_config)
        else:
            return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
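
For reference, the two peft entry points used by TorchDDPPlugin.enable_lora behave as follows outside ColossalAI. This is an illustrative plain-peft sketch; "gpt2" and the adapter path are placeholders, not names taken from the commit.

```python
# Plain-peft equivalent of the two branches in TorchDDPPlugin.enable_lora
# (illustrative sketch; "gpt2" and the adapter path are placeholders).
from peft import LoraConfig, PeftModel, get_peft_model
from transformers import AutoModelForCausalLM

# Branch 1: no pretrained_dir -> create fresh LoRA adapters from lora_config.
base = AutoModelForCausalLM.from_pretrained("gpt2")
config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)
peft_model = get_peft_model(base, config)
peft_model.print_trainable_parameters()  # only the lora_* tensors are trainable

# Branch 2: pretrained_dir given -> load an existing adapter and keep it trainable.
base2 = AutoModelForCausalLM.from_pretrained("gpt2")
resumed = PeftModel.from_pretrained(base2, "path/to/adapter_dir", is_trainable=True)
```

Either way the plugin returns a PeftModel wrapping the base model, which is then wrapped by DDP during boost().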

colossalai/booster/plugin/torch_fsdp_plugin.py (10)

@@ -2,7 +2,7 @@ import logging
import os
import warnings
from pathlib import Path
from typing import Callable, Iterable, Iterator, List, Optional, Tuple
from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple

import torch
import torch.nn as nn

@@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
    def support_no_sync(self) -> bool:
        return False

    def support_lora(self) -> bool:
        return False

    def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
        raise NotImplementedError("Torch fsdp no_sync func not supported yet.")

@@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
    def get_checkpoint_io(self) -> CheckpointIO:
        return TorchFSDPCheckpointIO()

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        raise NotImplementedError

colossalai/checkpoint_io/checkpoint_io_base.py (17)

@@ -335,3 +335,20 @@ class CheckpointIO(ABC):
        """
        state_dict = torch.load(checkpoint)
        lr_scheduler.load_state_dict(state_dict)

    # ================================================================================
    # Abstract method for lora saving implementation.
    # ================================================================================

    @abstractmethod
    def save_lora_as_pretrained(
        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
    ) -> None:
        """
        Save the LoRA adapters and adapter configuration file to a pretrained checkpoint directory.

        Args:
            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
            checkpoint (str): Path to the checkpoint directory. It must be a local path.
            use_safetensors (bool, optional): Whether to save in the safetensors format. Defaults to False.
        """

colossalai/checkpoint_io/general_checkpoint_io.py (3)

@@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
                    self.__class__.__name__, "\n\t".join(error_msgs)
                )
            )

    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
        raise NotImplementedError

requirements/requirements-test.txt (3)

@@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon
torchvision
timm
titans
torchaudio
torchaudio>=0.13.1
torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
torchrec==0.2.0
contexttimer
@@ -18,4 +18,5 @@ flash_attn
datasets
pydantic
ray
peft
#auto-gptq now not support torch1.12

tests/test_lora/test_torch_ddp_lora.py (108)

@@ -0,0 +1,108 @@
import copy
import os

import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin
from colossalai.testing import (
    assert_equal,
    assert_not_equal,
    check_state_dict_equal,
    clear_cache_before_run,
    rerun_if_address_is_in_use,
    spawn,
)
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    plugin = TorchDDPPlugin()
    booster = Booster(plugin=plugin)

    model = booster.enable_lora(model, lora_config=lora_config)
    model_copy = copy.deepcopy(model)

    optimizer = AdamW(model.parameters(), lr=0.001)
    criterion = loss_fn

    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)

    data = data_gen_fn()
    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}

    output = model(**data)
    output = output_transform_fn(output)
    loss = criterion(output)

    booster.backward(loss, optimizer)
    optimizer.clip_grad_by_norm(1.0)
    optimizer.step()

    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
        if "lora_" in n1:
            # lora modules require gradients, thus updated
            assert p1.requires_grad
            assert_not_equal(p1.to(p2.device), p2)
        else:
            if not p1.requires_grad:
                assert_equal(p1.to(p2.device), p2)


@clear_cache_before_run()
def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    plugin = TorchDDPPlugin()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    model_save = model_fn()
    model_load = copy.deepcopy(model_save)

    booster = Booster(plugin=plugin)
    model_save = booster.enable_lora(model_save, lora_config=lora_config)
    model_save, _, _, _, _ = booster.boost(model_save)

    with shared_tempdir() as tempdir:
        lora_ckpt_path = os.path.join(tempdir, "ckpt")
        booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
        dist.barrier()

        # The Lora checkpoint should be small in size
        checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
        assert checkpoint_size_mb < 1

        model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
        model_load, _, _, _, _ = booster.boost(model_load)

        check_state_dict_equal(model_save.state_dict(), model_load.state_dict())


def run_lora_test():
    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        task_type = None
        if name == "transformers_llama_for_casual_lm":
            task_type = "CAUSAL_LM"
        if name == "transformers_llama_for_sequence_classification":
            task_type = "SEQ_CLS"
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
        check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)


def run_dist(rank, world_size, port):
    config = {}
    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
    run_lora_test()


@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    spawn(run_dist, 2)