mirror of https://github.com/hpcaitech/ColossalAI
[lora] add lora APIs for booster, support lora for TorchDDP (#4981)
* add apis and peft requirement
* add license and implement apis
* add checkpointio apis
* add torchddp fwd_bwd test
* add support_lora methods
* add checkpointio test and debug
* delete unneeded codes
* remove peft from LICENSE
* add concrete methods for enable_lora
* simplify enable_lora api
* fix requirements

parent c1594e4bad
commit 14b0d4c7e5

@@ -8,6 +8,14 @@ from torch.optim import Optimizer
 from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
 from torch.utils.data import DataLoader
 
+SUPPORT_PEFT = False
+try:
+    import peft
+
+    SUPPORT_PEFT = True
+except ImportError:
+    pass
+
 import colossalai.interface.pretrained as pretrained_utils
 from colossalai.checkpoint_io import GeneralCheckpointIO
 from colossalai.interface import ModelWrapper, OptimizerWrapper
@@ -221,6 +229,38 @@ class Booster:
         assert self.plugin.support_no_sync(), f"The plugin {self.plugin.__class__.__name__} does not support no_sync."
         return self.plugin.no_sync(model, optimizer)
 
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: "peft.LoraConfig" = None
+    ) -> nn.Module:
+        """
+        Wrap the passed in model with LoRA modules for training. If pretrained directory is provided, lora configs and weights are loaded from that directory.
+        Lora in ColossalAI is implemented using Huggingface peft library, so the arguments for Lora configuration are same as those of peft.
+
+        Args:
+            model (nn.Module): The model to be appended with LoRA modules.
+            pretrained_dir(str, optional): The path to the pretrained directory, can be a local directory
+                or model_id of a PEFT configuration hosted inside a model repo on the Hugging Face Hub.
+                When set to None, create new lora configs and weights for the model using the passed in lora_config. Defaults to None.
+            lora_config: (peft.LoraConfig, optional): Passed in LoraConfig for peft. Defaults to None.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        if pretrained_dir is None:
+            assert (
+                lora_config is not None
+            ), "Please provide configuration for Lora when pretrained directory path isn't passed in."
+            assert isinstance(
+                lora_config, peft.LoraConfig
+            ), "The passed in configuration should be an instance of peft.LoraConfig."
+        if lora_config is None:
+            assert (
+                pretrained_dir is not None
+            ), "Please provide pretrained directory path if not passing in lora configuration."
+        return self.plugin.enable_lora(model, pretrained_dir, lora_config)
+
     def load_model(self, model: Union[nn.Module, ModelWrapper], checkpoint: str, strict: bool = True) -> None:
         """Load model from checkpoint.
 
@@ -323,3 +363,20 @@ class Booster:
             checkpoint (str): Path to the checkpoint. It must be a local file path.
         """
         self.checkpoint_io.load_lr_scheduler(lr_scheduler, checkpoint)
+
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """
+        if not SUPPORT_PEFT:
+            raise ImportError("Please install Huggingface Peft library to enable lora features in ColossalAI!")
+        assert self.plugin is not None, f"Lora can only be enabled when a plugin is provided."
+        assert self.plugin.support_lora(), f"The plugin {self.plugin.__class__.__name__} does not support lora."
+        self.checkpoint_io.save_lora_as_pretrained(model, checkpoint, use_safetensors)

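For context, here is a minimal end-to-end usage sketch of the Booster LoRA API introduced above. It is not part of this commit: the base model ("gpt2" via transformers), the optimizer settings, and the checkpoint path are placeholder assumptions; only enable_lora, boost, backward, and save_lora_as_pretrained come from the API added here.

# Hypothetical usage sketch of the new Booster LoRA API; model and paths are placeholders.
from peft import LoraConfig
from torch.optim import AdamW
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # assumes launch via torchrun

model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
lora_config = LoraConfig(task_type="CAUSAL_LM", r=8, lora_alpha=32, lora_dropout=0.1)

booster = Booster(plugin=TorchDDPPlugin())
# LoRA must be enabled before boost(); the plugin wraps the model with peft LoRA modules.
model = booster.enable_lora(model, lora_config=lora_config)
optimizer = AdamW(model.parameters(), lr=1e-3)
model, optimizer, *_ = booster.boost(model, optimizer)

# ... training loop: forward, booster.backward(loss, optimizer), optimizer.step() ...

# Writes only the adapter weights and adapter config, not the full base model.
booster.save_lora_as_pretrained(model, "./lora_ckpt")
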
@@ -3,7 +3,7 @@ import logging
 import os
 import random
 from pathlib import Path
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import numpy as np
 import torch
@@ -444,6 +444,9 @@ class GeminiPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True
 
@@ -573,3 +576,8 @@ class GeminiPlugin(DPPluginBase):
 
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

@@ -4,7 +4,7 @@ import warnings
 from contextlib import contextmanager
 from functools import partial
 from types import MethodType
-from typing import Any, Callable, Iterator, List, Optional, OrderedDict, Tuple, Union
+from typing import Any, Callable, Dict, Iterator, List, Optional, OrderedDict, Tuple, Union
 
 import numpy as np
 import torch
@@ -1156,6 +1156,9 @@ class HybridParallelPlugin(PipelinePluginBase):
     def support_no_sync(self) -> bool:
         return True
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_checkpoint_io(self) -> bool:
         return True
 
@@ -1356,3 +1359,8 @@ class HybridParallelPlugin(PipelinePluginBase):
             self.zero_stage != 2
         ), "ZERO2 is not compatible with no_sync function, please run gradient accumulation with gradient synchronization allowed."
         return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
+
+    def enable_lora(
+        self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> Module:
+        raise NotImplementedError

@@ -3,7 +3,7 @@ import os
 from functools import partial
 from pathlib import Path
 from types import MethodType
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -296,6 +296,9 @@ class LowLevelZeroPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return self.stage == 1
 
+    def support_lora(self) -> bool:
+        return False
+
     def control_precision(self) -> bool:
         return True
 
@@ -337,3 +340,8 @@ class LowLevelZeroPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(optimizer, LowLevelZeroOptimizer)
         return optimizer.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

@@ -1,5 +1,5 @@
 from abc import ABC, abstractmethod
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple
 
 import torch.nn as nn
 from torch.optim import Optimizer
@@ -33,6 +33,10 @@ class Plugin(ABC):
     def support_no_sync(self) -> bool:
         pass
 
+    @abstractmethod
+    def support_lora(self) -> bool:
+        pass
+
     @abstractmethod
     def configure(
         self,
@@ -63,6 +67,12 @@ class Plugin(ABC):
         Context manager to disable gradient synchronization.
         """
 
+    @abstractmethod
+    def enable_lora(self, model: nn.Module, pretrained_dir: str, lora_config: Dict) -> nn.Module:
+        """
+        Add LoRA modules to the model passed in. Should only be called in booster.enable_lora().
+        """
+
     @abstractmethod
     def prepare_dataloader(
         self,

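Because support_lora and enable_lora are now abstract on Plugin, every plugin has to declare whether it supports LoRA and, if so, how it injects the adapters. The sketch below only illustrates that contract; the class name is hypothetical and not part of this commit. It mirrors the TorchDDPPlugin implementation shown next, while the other plugins simply return False and raise NotImplementedError.

# Illustrative only: a hypothetical plugin satisfying the new abstract methods.
from typing import Dict, Optional

import torch.nn as nn


class MyLoraCapablePlugin:  # in practice this would subclass Plugin (e.g. DPPluginBase)
    def support_lora(self) -> bool:
        return True

    def enable_lora(
        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
    ) -> nn.Module:
        from peft import PeftModel, get_peft_model

        # Either create fresh LoRA modules from lora_config, or load adapters
        # previously written by CheckpointIO.save_lora_as_pretrained().
        if pretrained_dir is None:
            return get_peft_model(model, lora_config)
        return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
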
@@ -1,4 +1,4 @@
-from typing import Callable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterator, List, Optional, Tuple, Union
 
 import torch.nn as nn
 from torch.nn.parallel import DistributedDataParallel as DDP
@@ -116,6 +116,22 @@ class TorchDDPCheckpointIO(GeneralCheckpointIO):
         assert isinstance(optimizer, OptimizerWrapper), "Please boost the optimizer before loading!"
         super().load_sharded_optimizer(optimizer.unwrap(), index_file_path, prefix)
 
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to checkpoint directory.
+        """
+        from peft import PeftModel
+
+        assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
+        if self.coordinator.is_master():
+            peft_model = model.unwrap()
+            assert isinstance(
+                peft_model, PeftModel
+            ), "The model doesn't have lora adapters, please enable lora before saving."
+            peft_model.save_pretrained(save_directory=checkpoint, safe_serialization=use_safetensors)
+
 
 class TorchDDPModel(ModelWrapper):
     def __init__(self, module: nn.Module, *args, **kwargs) -> None:
@@ -173,6 +189,9 @@ class TorchDDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return True
 
+    def support_lora(self) -> bool:
+        return True
+
     def control_precision(self) -> bool:
         return False
 
@@ -216,3 +235,14 @@ class TorchDDPPlugin(DPPluginBase):
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         assert isinstance(model, TorchDDPModel), "Model is not boosted by TorchDDPPlugin."
         return model.module.no_sync()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        from peft import PeftModel, get_peft_model
+
+        assert not isinstance(model, TorchDDPModel), "Lora should be enabled before boosting the model."
+        if pretrained_dir is None:
+            return get_peft_model(model, lora_config)
+        else:
+            return PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)

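Passing pretrained_dir to booster.enable_lora takes the PeftModel.from_pretrained branch above, which is how adapters saved with save_lora_as_pretrained can be reloaded for further training. A hedged sketch of that flow follows; the base model and checkpoint path are placeholder assumptions, not from this commit.

# Hypothetical resume-from-adapter sketch; "gpt2" and "./lora_ckpt" are placeholders.
from transformers import AutoModelForCausalLM

import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import TorchDDPPlugin

colossalai.launch_from_torch(config={})  # assumes launch via torchrun

base_model = AutoModelForCausalLM.from_pretrained("gpt2")  # placeholder base model
booster = Booster(plugin=TorchDDPPlugin())

# No lora_config is needed here: the adapter configuration is read from the
# checkpoint directory, and is_trainable=True keeps the adapters updatable.
model = booster.enable_lora(base_model, pretrained_dir="./lora_ckpt")
model, *_ = booster.boost(model)
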
@@ -2,7 +2,7 @@ import logging
 import os
 import warnings
 from pathlib import Path
-from typing import Callable, Iterable, Iterator, List, Optional, Tuple
+from typing import Callable, Dict, Iterable, Iterator, List, Optional, Tuple
 
 import torch
 import torch.nn as nn
@@ -318,6 +318,9 @@ class TorchFSDPPlugin(DPPluginBase):
     def support_no_sync(self) -> bool:
         return False
 
+    def support_lora(self) -> bool:
+        return False
+
     def no_sync(self, model: nn.Module, optimizer: OptimizerWrapper) -> Iterator[None]:
         raise NotImplementedError("Torch fsdp no_sync func not supported yet.")
 
@@ -361,3 +364,8 @@ class TorchFSDPPlugin(DPPluginBase):
 
     def get_checkpoint_io(self) -> CheckpointIO:
         return TorchFSDPCheckpointIO()
+
+    def enable_lora(
+        self, model: nn.Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
+    ) -> nn.Module:
+        raise NotImplementedError

@@ -335,3 +335,20 @@ class CheckpointIO(ABC):
         """
         state_dict = torch.load(checkpoint)
         lr_scheduler.load_state_dict(state_dict)
+
+    # ================================================================================
+    # Abstract method for lora saving implementation.
+    # ================================================================================
+
+    @abstractmethod
+    def save_lora_as_pretrained(
+        self, model: Union[nn.Module, ModelWrapper], checkpoint: str, use_safetensors: bool = False
+    ) -> None:
+        """
+        Save the lora adapters and adapter configuration file to a pretrained checkpoint directory.
+
+        Args:
+            model (Union[nn.Module, ModelWrapper]): A model boosted by Booster.
+            checkpoint (str): Path to the checkpoint directory. It must be a local path.
+            use_safetensors (bool, optional): Whether to use safe tensors when saving. Defaults to False.
+        """

@@ -228,3 +228,6 @@ class GeneralCheckpointIO(CheckpointIO):
                         self.__class__.__name__, "\n\t".join(error_msgs)
                     )
                 )
+
+    def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
+        raise NotImplementedError

@@ -5,7 +5,7 @@ git+https://github.com/hpcaitech/pytest-testmon
 torchvision
 timm
 titans
-torchaudio
+torchaudio>=0.13.1
 torchx-nightly==2022.6.29 # torchrec 0.2.0 requires torchx-nightly. This package is updated every day. We fix the version to a specific date to avoid breaking changes.
 torchrec==0.2.0
 contexttimer
@@ -18,4 +18,5 @@ flash_attn
 datasets
 pydantic
 ray
+peft
 #auto-gptq now not support torch1.12

@@ -0,0 +1,108 @@
+import copy
+import os
+
+import torch
+from peft import LoraConfig
+from torch import distributed as dist
+from torch.optim import AdamW
+
+import colossalai
+from colossalai.booster import Booster
+from colossalai.booster.plugin import TorchDDPPlugin
+from colossalai.testing import (
+    assert_equal,
+    assert_not_equal,
+    check_state_dict_equal,
+    clear_cache_before_run,
+    rerun_if_address_is_in_use,
+    spawn,
+)
+from tests.kit.model_zoo import model_zoo
+from tests.test_checkpoint_io.utils import shared_tempdir
+
+
+@clear_cache_before_run()
+def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
+    model = model_fn()
+    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+
+    plugin = TorchDDPPlugin()
+    booster = Booster(plugin=plugin)
+
+    model = booster.enable_lora(model, lora_config=lora_config)
+    model_copy = copy.deepcopy(model)
+
+    optimizer = AdamW(model.parameters(), lr=0.001)
+    criterion = loss_fn
+
+    model, optimizer, criterion, _, _ = booster.boost(model, optimizer, criterion)
+
+    data = data_gen_fn()
+    data = {k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()}
+
+    output = model(**data)
+    output = output_transform_fn(output)
+    loss = criterion(output)
+
+    booster.backward(loss, optimizer)
+    optimizer.clip_grad_by_norm(1.0)
+    optimizer.step()
+
+    for (n1, p1), (n2, p2) in zip(model.named_parameters(), model_copy.named_parameters()):
+        if "lora_" in n1:
+            # lora modules require gradients, thus updated
+            assert p1.requires_grad
+            assert_not_equal(p1.to(p2.device), p2)
+        else:
+            if not p1.requires_grad:
+                assert_equal(p1.to(p2.device), p2)
+
+
+@clear_cache_before_run()
+def check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
+    plugin = TorchDDPPlugin()
+    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
+
+    model_save = model_fn()
+    model_load = copy.deepcopy(model_save)
+
+    booster = Booster(plugin=plugin)
+    model_save = booster.enable_lora(model_save, lora_config=lora_config)
+    model_save, _, _, _, _ = booster.boost(model_save)
+
+    with shared_tempdir() as tempdir:
+        lora_ckpt_path = os.path.join(tempdir, "ckpt")
+        booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
+        dist.barrier()
+
+        # The Lora checkpoint should be small in size
+        checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
+        assert checkpoint_size_mb < 1
+
+        model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path)
+        model_load, _, _, _, _ = booster.boost(model_load)
+
+        check_state_dict_equal(model_save.state_dict(), model_load.state_dict())
+
+
+def run_lora_test():
+    sub_model_zoo = model_zoo.get_sub_registry("transformers_llama")
+    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
+        task_type = None
+        if name == "transformers_llama_for_casual_lm":
+            task_type = "CAUSAL_LM"
+        if name == "transformers_llama_for_sequence_classification":
+            task_type = "SEQ_CLS"
+        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
+        check_checkpoint(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type)
+
+
+def run_dist(rank, world_size, port):
+    config = {}
+    colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
+    run_lora_test()
+
+
+@rerun_if_address_is_in_use()
+def test_torch_ddp_lora():
+    spawn(run_dist, 2)