mirror of https://github.com/hpcaitech/ColossalAI
ai, big-model, data-parallelism, deep-learning, distributed-computing, foundation-models, heterogeneous-training, hpc, inference, large-scale, model-parallelism, pipeline-parallelism
You cannot select more than 25 topics.
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
105 lines
3.8 KiB
105 lines
3.8 KiB
import copy |
|
import os |
|
from itertools import product |
|
|
|
import torch |
|
from peft import LoraConfig |
|
from torch import distributed as dist |
|
from torch.optim import AdamW |
|
|
|
import colossalai |
|
from colossalai.booster import Booster |
|
from colossalai.booster.plugin import LowLevelZeroPlugin, TorchDDPPlugin |
|
from colossalai.testing import check_state_dict_equal, clear_cache_before_run, rerun_if_address_is_in_use, spawn |
|
from tests.kit.model_zoo import model_zoo |
|
from tests.test_checkpoint_io.utils import shared_tempdir |
|
|
|
|
|
def _move_to_cuda(batch):
    """Return a new dict with every tensor-like value of ``batch`` moved to "cuda".

    A value counts as tensor-like when ``torch.is_tensor`` accepts it or its class
    name contains "Tensor"; every other value is passed through unchanged. Key order
    is preserved.
    """
    moved = {}
    for key, value in batch.items():
        tensor_like = torch.is_tensor(value) or "Tensor" in value.__class__.__name__
        moved[key] = value.to("cuda") if tensor_like else value
    return moved


def _compare_against_snapshot(trained_model, snapshot):
    """Check ``trained_model``'s parameters against a deep copy taken before the optimizer step.

    Parameters are paired positionally via ``zip`` over both ``named_parameters()``
    streams (iteration stops at the shorter one); only the trained model's names are used.
    """
    paired = zip(trained_model.named_parameters(), snapshot.named_parameters())
    for (name, current), (_, reference) in paired:
        if "lora_" in name:
            # LoRA parameters are expected to be trainable.
            assert current.requires_grad
            # NOTE(review): torch.testing.assert_close returns None on success and raises on
            # mismatch, so `assert not assert_close(...)` passes only when the tensors ARE
            # close -- it does not verify that the LoRA weights were updated. Before inverting
            # it, see the optimizer note in check_fwd_bwd and the 5e-3 tolerances.
            assert not torch.testing.assert_close(
                current.to(reference.device).to(reference.dtype), reference, atol=5e-3, rtol=5e-3
            )
        elif not current.requires_grad:
            # Frozen non-LoRA parameters must match the snapshot within tolerance.
            torch.testing.assert_close(
                current.to(reference.device).to(reference.dtype), reference, atol=5e-3, rtol=5e-3
            )


@clear_cache_before_run()
def check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_type):
    """Round-trip a LoRA adapter through save/load, then run one training step on it.

    For every combination of plugin (``TorchDDPPlugin``, ``LowLevelZeroPlugin``) and
    ``quantize`` flag (``False``, ``True``):

    1. enable LoRA on a fresh model from ``model_fn`` and boost it;
    2. save the adapter with ``save_lora_as_pretrained`` and assert that
       ``adapter_model.bin`` is smaller than 1 MiB;
    3. load the adapter into a deep copy of the same base model via
       ``enable_lora(pretrained_dir=...)``, boost it, and assert both state dicts match;
    4. run forward, ``booster.backward``, ``clip_grad_by_norm(1.0)`` and ``step`` on the
       reloaded model, then compare its parameters with a copy taken before the step.

    Args:
        model_fn: zero-argument factory returning a new model.
        data_gen_fn: zero-argument factory returning a dict of keyword inputs for the model.
        output_transform_fn: maps the raw model output to what ``loss_fn`` consumes.
        loss_fn: loss callable; passed to ``booster.boost`` as the criterion.
        task_type: PEFT task type forwarded to ``LoraConfig`` (may be ``None``).
    """
    # NOTE(review): this model is used only as the parameter source of each iteration's
    # AdamW optimizer, while the forward/backward below runs through the reloaded LoRA
    # model. Unless booster.boost rebinds the optimizer's parameter groups (not visible
    # here), optimizer.step() cannot update the parameters compared afterwards.
    optimizer_source = model_fn()
    lora_config = LoraConfig(task_type=task_type, r=8, lora_alpha=32, lora_dropout=0.1)

    plugins = [TorchDDPPlugin(), LowLevelZeroPlugin()]
    lora_kwargs_options = [{"lora_config": lora_config, "quantize": quantize} for quantize in (False, True)]

    for plugin, lora_kwargs in product(plugins, lora_kwargs_options):
        # Two copies of one freshly built model: the first gets a new adapter that is
        # saved, the second reloads that adapter from disk.
        saved_model = model_fn()
        loaded_model = copy.deepcopy(saved_model)

        optimizer = AdamW(optimizer_source.parameters(), lr=0.001)
        criterion = loss_fn

        booster = Booster(plugin=plugin)
        saved_model = booster.enable_lora(saved_model, **lora_kwargs)
        saved_model, optimizer, criterion, _, _ = booster.boost(saved_model, optimizer, criterion)

        with shared_tempdir() as tempdir:
            ckpt_dir = os.path.join(tempdir, "ckpt")
            booster.save_lora_as_pretrained(saved_model, ckpt_dir)
            dist.barrier()

            # The saved LoRA checkpoint should be small: adapter_model.bin under 1 MiB.
            adapter_size_mb = os.path.getsize(os.path.join(ckpt_dir, "adapter_model.bin")) / (1024 * 1024)
            assert adapter_size_mb < 1

            # Reload while the temporary checkpoint directory still exists.
            loaded_model = booster.enable_lora(loaded_model, pretrained_dir=ckpt_dir, **lora_kwargs)
            loaded_model, _, _, _, _ = booster.boost(loaded_model)

            check_state_dict_equal(saved_model.state_dict(), loaded_model.state_dict())

        # One training step on the reloaded model, checked against a pre-step snapshot.
        snapshot = copy.deepcopy(loaded_model)

        inputs = _move_to_cuda(data_gen_fn())
        output = output_transform_fn(loaded_model(**inputs))
        loss = criterion(output)

        booster.backward(loss, optimizer)
        optimizer.clip_grad_by_norm(1.0)
        optimizer.step()

        _compare_against_snapshot(loaded_model, snapshot)
|
|
|
|
|
def run_lora_test():
    """Run ``check_fwd_bwd`` on every entry of the "transformers_llama" model-zoo sub-registry.

    The PEFT task type is picked from the entry name; entries not listed below get ``None``.
    """
    # Keys must match the registry names exactly. The causal-LM entry is spelled "casual"
    # here -- presumably mirroring how the model zoo registers it; verify before correcting.
    task_types = {
        "transformers_llama_for_casual_lm": "CAUSAL_LM",
        "transformers_llama_for_sequence_classification": "SEQ_CLS",
    }
    registry = model_zoo.get_sub_registry("transformers_llama")
    for entry_name, entry in registry.items():
        model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = entry
        check_fwd_bwd(model_fn, data_gen_fn, output_transform_fn, loss_fn, task_types.get(entry_name))
|
|
|
|
|
def run_dist(rank, world_size, port):
    """Per-process worker: call ``colossalai.launch`` (NCCL backend, localhost), then run the LoRA checks."""
    colossalai.launch(
        backend="nccl",
        host="localhost",
        port=port,
        rank=rank,
        world_size=world_size,
    )
    run_lora_test()
|
|
|
|
|
@rerun_if_address_is_in_use()
def test_torch_ddp_lora():
    """Launch ``run_dist`` through ColossalAI's ``spawn`` helper with two processes."""
    num_processes = 2
    spawn(run_dist, num_processes)
|
|
|