mirror of https://github.com/hpcaitech/ColossalAI
[lora] lora support hybrid parallel plugin (#5956)
* lora support hybrid plugin * fix * fix * fix * fixcolossalchat
parent
19d1510ea2
commit
75c963686f
|
@ -30,6 +30,7 @@ from colossalai.interface.optimizer import DistributedOptim
|
|||
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
|
||||
from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
|
||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||
from colossalai.quantization import BnbQuantizationConfig, quantize_model
|
||||
from colossalai.shardformer import GradientCheckpointConfig, ShardConfig, ShardFormer
|
||||
from colossalai.shardformer.layer.utils import SeqParallelUtils
|
||||
from colossalai.shardformer.policies.base_policy import Policy
|
||||
|
@ -1187,7 +1188,7 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return True
|
||||
|
||||
def support_lora(self) -> bool:
|
||||
return False
|
||||
return True
|
||||
|
||||
def control_checkpoint_io(self) -> bool:
|
||||
return True
|
||||
|
@ -1415,6 +1416,24 @@ class HybridParallelPlugin(PipelinePluginBase):
|
|||
return optimizer.no_sync() if isinstance(optimizer, HybridParallelZeroOptimizer) else model.no_sync()
|
||||
|
||||
def enable_lora(
|
||||
self, model: Module, pretrained_dir: Optional[str] = None, lora_config: Optional[Dict] = None
|
||||
self,
|
||||
model: Module,
|
||||
pretrained_dir: Optional[str] = None,
|
||||
lora_config: Optional[Dict] = None,
|
||||
bnb_quantization_config: Optional[BnbQuantizationConfig] = None,
|
||||
) -> Module:
|
||||
raise NotImplementedError
|
||||
from peft import PeftModel, get_peft_model
|
||||
|
||||
assert not isinstance(model, HybridParallelModule), "Lora should be enabled before boosting the model."
|
||||
assert self.pp_size == 1 and self.tp_size == 1
|
||||
self.lora_enabled = True
|
||||
warnings.warn("You have enabled LoRa training. Please check the hyperparameters such as lr")
|
||||
|
||||
if bnb_quantization_config is not None:
|
||||
model = quantize_model(model, bnb_quantization_config)
|
||||
|
||||
if pretrained_dir is None:
|
||||
peft_model = get_peft_model(model, lora_config)
|
||||
else:
|
||||
peft_model = PeftModel.from_pretrained(model, pretrained_dir, is_trainable=True)
|
||||
return peft_model
|
||||
|
|
|
@ -947,3 +947,17 @@ class HybridParallelCheckpointIO(GeneralCheckpointIO):
|
|||
state_[k] = v.detach().clone().to(device)
|
||||
|
||||
return state_
|
||||
|
||||
def save_lora_as_pretrained(self, model, checkpoint, use_safetensors):
|
||||
if os.path.isfile(checkpoint):
|
||||
logging.error(f"Provided path ({checkpoint}) should be a directory, not a file")
|
||||
return
|
||||
from peft import PeftModel
|
||||
|
||||
assert isinstance(model, ModelWrapper), "Please boost the model before saving!"
|
||||
model._force_wait_all_gather()
|
||||
peft_model = model.unwrap()
|
||||
assert isinstance(
|
||||
peft_model, PeftModel
|
||||
), "The model doesn't have lora adapters, please enable lora before saving."
|
||||
return peft_model.save_pretrained(checkpoint, safe_serialization=use_safetensors)
|
||||
|
|
|
@ -243,6 +243,9 @@ def _fullname(obj):
|
|||
# patch custom models which are not in transformers
|
||||
# it can be like 'transformers_modules.THUDM.chatglm3-6b.103caa40027ebfd8450289ca2f278eac4ff26405.modeling_chatglm' (from huggingface hub)
|
||||
# or like 'transformers_modules.chatglm.modeling_chatglm' (from local directory)
|
||||
if module.startswith("peft"):
|
||||
klass = obj.base_model.model.__class__
|
||||
module = klass.__module__
|
||||
if module.startswith("transformers_modules"):
|
||||
split_module = module.split(".")
|
||||
if len(split_module) >= 2:
|
||||
|
|
|
@ -9,7 +9,8 @@ from torch.optim import AdamW
|
|||
|
||||
import colossalai
|
||||
from colossalai.booster import Booster
|
||||
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.booster.plugin import HybridParallelPlugin, LowLevelZeroPlugin, TorchDDPPlugin
|
||||
from colossalai.booster.plugin.hybrid_parallel_plugin import HybridParallelModule
|
||||
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.test_checkpoint_io.utils import shared_tempdir
|
||||
|
@ -20,7 +21,7 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
|
|||
model = model_fn()
|
||||
lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)
|
||||
|
||||
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
|
||||
test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin(), HybridParallelPlugin(tp_size=1, pp_size=1)]
|
||||
test_configs = [
|
||||
{
|
||||
"lora_config": lora_config,
|
||||
|
@ -59,6 +60,8 @@ def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type
|
|||
|
||||
# test fwd bwd correctness
|
||||
test_model = model_load
|
||||
if isinstance(model_load, HybridParallelModule):
|
||||
model_load = model_load.module.module
|
||||
model_copy = copy.deepcopy(model_load)
|
||||
|
||||
data = data_gen_fn()
|
||||
|
|
Loading…
Reference in New Issue