2023-06-13 06:44:40 +00:00
|
|
|
import os
|
|
|
|
|
|
|
|
import pytest
|
|
|
|
import torch
|
|
|
|
|
|
|
|
import colossalai
|
|
|
|
from colossalai.logging import disable_existing_loggers
|
2023-07-04 01:57:03 +00:00
|
|
|
from colossalai.tensor.d_tensor.api import is_customized_distributed_tensor, is_distributed_tensor
|
|
|
|
from colossalai.testing import (
|
|
|
|
assert_hf_output_close,
|
|
|
|
clear_cache_before_run,
|
|
|
|
parameterize,
|
|
|
|
rerun_if_address_is_in_use,
|
|
|
|
spawn,
|
|
|
|
)
|
2023-06-21 01:32:46 +00:00
|
|
|
from tests.kit.model_zoo import model_zoo
|
|
|
|
from tests.test_shardformer.test_model._utils import build_model, run_forward
|
2023-06-13 06:44:40 +00:00
|
|
|
|
|
|
|
# Set before any model is built so that, presumably, `transformers` suppresses its
# advisory warnings for the whole test run.
os.environ.update(TRANSFORMERS_NO_ADVISORY_WARNINGS='true')
|
2023-06-19 09:57:37 +00:00
|
|
|
|
2023-06-21 01:32:46 +00:00
|
|
|
|
|
|
|
def _gather_sharded_grad(shard_weight, shard_grad):
    """Return the full gradient of ``shard_weight``, all-gathering it across ranks if it is sharded.

    If ``shard_weight`` is a (customized) distributed tensor, each rank holds only a slice of the
    gradient. The slices are all-gathered over the default process group and concatenated along
    dim 0. Otherwise ``shard_grad`` is returned unchanged.
    """
    if not (is_distributed_tensor(shard_weight) or is_customized_distributed_tensor(shard_weight)):
        return shard_grad
    # zeros_like keeps the dtype/device of the local shard, which all_gather requires for its outputs.
    world_size = torch.distributed.get_world_size()
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    # all_gather fills shard_grad_list in place; in sync mode it returns None, so its result is not kept.
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    # NOTE(review): assumes both checked weights are sharded along dim 0 -- confirm against the policy.
    return torch.cat(shard_grad_list, dim=0)


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    """Compare outputs, losses and selected gradients of the original and the sharded model.

    Args:
        org_model: the unsharded reference model.
        sharded_model: the shardformer-sharded copy of ``org_model``.
        data_gen_fn: passed to ``run_forward`` to produce the input batch.
        output_transform_fn: passed to ``run_forward`` to post-process model outputs.
        loss_fn: passed to ``run_forward`` to compute the loss.

    Raises:
        AssertionError: if outputs, losses, or the first layer's ``q_proj`` / ``embed_tokens``
            gradients differ beyond tolerance.
    """
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)

    # forward check
    assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4)

    # run backward
    org_loss.backward()
    shard_loss.backward()

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"

    # unwrap model: task heads keep the base model under `.model`
    if hasattr(org_model, 'model'):
        llama_model = org_model.model
        shard_llama_model = sharded_model.model
    else:
        llama_model = org_model
        shard_llama_model = sharded_model

    # check attention grad
    org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad
    shard_weight = shard_llama_model.layers[0].self_attn.q_proj.weight
    all_shard_grad = _gather_sharded_grad(shard_weight, shard_weight.grad)
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"

    # check embedding grad
    org_grad = llama_model.embed_tokens.weight.grad
    shard_weight = shard_llama_model.embed_tokens.weight
    all_shard_grad = _gather_sharded_grad(shard_weight, shard_weight.grad)
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"
|
|
|
|
|
|
|
|
|
2023-07-04 01:57:03 +00:00
|
|
|
@parameterize('enable_fused_normalization', [True, False])
@parameterize('enable_tensor_parallelism', [True, False])
@parameterize('use_lazy_init', [False, True])
def run_gpt2_llama(enable_fused_normalization, enable_tensor_parallelism, use_lazy_init):
    """Build and check every model registered under ``transformers_llama`` in the model zoo.

    The ``@parameterize`` decorators presumably invoke this once per combination of the listed
    flag values; each flag is forwarded to ``build_model``.
    """
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
    for _model_name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(model_fn, enable_fused_normalization, enable_tensor_parallelism,
                                               use_lazy_init)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
        # Drop the model references first: empty_cache() can only return cached blocks
        # that no live tensor (parameters, gradients) still occupies.
        del org_model, sharded_model
        torch.cuda.empty_cache()
|
|
|
|
|
|
|
|
|
2023-07-04 01:57:03 +00:00
|
|
|
def check_llama(rank, world_size, port):
    """Per-process worker: launch the ColossalAI runtime, then run the LLaMA shardformer checks.

    Args:
        rank: rank of this worker process.
        world_size: total number of worker processes.
        port: port passed to ``colossalai.launch`` together with host ``'localhost'``.
    """
    disable_existing_loggers()
    launch_kwargs = dict(
        config={},
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl',
    )
    colossalai.launch(**launch_kwargs)
    run_gpt2_llama()
|
|
|
|
|
|
|
|
|
2023-06-13 06:44:40 +00:00
|
|
|
@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_llama():
    """Distributed pytest entry point: run ``check_llama`` in four spawned worker processes."""
    num_workers = 4
    spawn(check_llama, num_workers)
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Allow running this test file directly, without going through pytest.
    test_llama()
|