You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_lora/test_lora.py

107 lines
3.8 KiB

import copy
import os
from itertools import product
import torch
from peft import LoraConfig
from torch import distributed as dist
from torch.optim import AdamW
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_checkpoint_io.utils import shared_tempdir
@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    """Round-trip a LoRA checkpoint and run one training step per plugin/config pair.

    For every combination of booster plugin (``TorchDDPPlugin``,
    ``LowLevelZeroPlugin``) and LoRA setting (``quantize`` off / on) this:

    1. enables LoRA on a fresh model, boosts it and saves the adapter,
    2. asserts the saved ``adapter_model.bin`` is smaller than 1 MiB,
    3. loads the adapter into a deep copy taken before LoRA was enabled and
       asserts both state dicts are equal,
    4. runs forward, backward, gradient clipping and an optimizer step on the
       loaded model and compares its parameters with a pre-step snapshot.

    Args:
        model_fn: Zero-argument callable returning a new model instance.
        data_gen_fn: Zero-argument callable returning a dict of model inputs.
        output_transform_fn: Callable applied to the raw model output before
            the loss is computed.
        loss_fn: Callable computing the loss from the transformed output.
        task_type: Forwarded to ``LoraConfig(task_type=...)``; may be ``None``.

    Raises:
        AssertionError: If any of the checkpoint or parameter checks fails.
    """
    model = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    test_plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
    test_configs = [
        {
            "lora_config": lora_config,
            "quantize": False,
        },
        {
            "lora_config": lora_config,
            "quantize": True,
        },
    ]
    for plugin, test_config in product(test_plugins, test_configs):
        # checkpoint loaded model
        model_save = model_fn()
        model_load = copy.deepcopy(model_save)

        # NOTE(review): the optimizer wraps the parameters of `model` (built
        # once above, never LoRA-enabled or boosted) and is boosted together
        # with `model_save`, whereas the forward/backward below runs on
        # `model_load`. Presumably `optimizer.step()` therefore does not touch
        # the weights compared at the end -- confirm whether that is intended.
        optimizer = AdamW(model.parameters(), lr=0.001)
        criterion = loss_fn

        booster = Booster(plugin=plugin)
        model_save = booster.enable_lora(model_save, **test_config)
        model_save, optimizer, criterion, _, _ = booster.boost(model_save, optimizer, criterion)

        with shared_tempdir() as tempdir:
            lora_ckpt_path = os.path.join(tempdir, "ckpt")
            booster.save_lora_as_pretrained(model_save, lora_ckpt_path)
            dist.barrier()

            # The LoRA checkpoint should be small in size
            checkpoint_size_mb = os.path.getsize(os.path.join(lora_ckpt_path, "adapter_model.bin")) / (1024 * 1024)
            assert checkpoint_size_mb < 1

            model_load = booster.enable_lora(model_load, pretrained_dir=lora_ckpt_path, **test_config)
            model_load, _, _, _, _ = booster.boost(model_load)

            check_state_dict_equal(model_save.state_dict(), model_load.state_dict())

        # test fwd bwd correctness
        test_model = model_load
        # Snapshot taken before the step; used as the reference below.
        model_copy = copy.deepcopy(model_load)

        data = data_gen_fn()
        # Move tensor-like inputs to the GPU, leave everything else untouched.
        data = {
            k: v.to("cuda") if torch.is_tensor(v) or "Tensor" in v.__class__.__name__ else v for k, v in data.items()
        }

        output = test_model(**data)
        output = output_transform_fn(output)
        loss = criterion(output)

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()

        for (n1, p1), (_, p2) in zip(test_model.named_parameters(), model_copy.named_parameters()):
            if "lora_" in n1:
                # LoRA parameters must be trainable.
                assert p1.requires_grad
                # `assert_close` returns None when the tensors match and raises
                # otherwise, so this verifies that the LoRA weights stay within
                # tolerance of the pre-step snapshot.
                # NOTE(review): it does NOT verify that the step changed them;
                # a real "weights were updated" check would need a different
                # comparison and an optimizer bound to `model_load`.
                torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
            elif not p1.requires_grad:
                # Frozen non-LoRA weights must match the snapshot.
                torch.testing.assert_close(p1.to(p2.device).to(p2.dtype), p2, atol=5e-3, rtol=5e-3)
def run_lora_test():
    """Run ``check_fwd_bwd`` on every entry of the ``transformers_llama`` sub-registry.

    The registry name selects the LoRA task type passed to ``check_fwd_bwd``;
    names without a mapping are checked with ``task_type=None``.
    """
    # NOTE(review): "casual" (sic) -- presumably mirrors the key registered in
    # model_zoo; confirm before "correcting" the spelling.
    task_type_by_name = {
        "transformers_llama_for_casual_lm": "CAUSAL_LM",
        "transformers_llama_for_sequence_classification": "SEQ_CLS",
    }

    llama_models = model_zoo.get_sub_registry("transformers_llama")
    for name, entry in llama_models.items():
        # The fifth element of each registry entry is not used here.
        model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = entry
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type_by_name.get(name))
def run_dist(rank, world_size, port):
    """Launch ColossalAI for this process, then run the LoRA tests.

    Args:
        rank: Passed to ``colossalai.launch`` as ``rank``.
        world_size: Passed to ``colossalai.launch`` as ``world_size``.
        port: Passed to ``colossalai.launch`` as ``port``; the host is
            always ``"localhost"``.
    """
    colossalai.launch(
        config={},
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    run_lora_test()
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    """Test entry point: hand ``run_dist`` to ``spawn`` with an argument of 2."""
    # Presumably the number of worker processes -- confirm against `spawn`.
    num_procs = 2
    spawn(run_dist, num_procs)