Mirror of https://github.com/InternLM/InternLM
74 lines
1.9 KiB
Python
import multiprocessing as mp
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from internlm.model.embedding import Embedding1D
|
|
from tests.test_model.test_model_internlm import build_environment, seed_all
|
|
|
|
|
|
def check_embedding(args):
    """Check that ``Embedding1D`` is deterministic and matches recorded reference values.

    Intended to run in one worker process per rank (see ``test_embedding``).
    Performs several identical forward passes and asserts they agree
    bit-for-bit. It then compares the output against hard-coded reference
    values. Finally it backpropagates a random upstream gradient and compares
    the weight gradient against reference values.

    Args:
        args: A ``(rank, world_size)`` pair, packed into a single argument so the
            function can be used directly with ``Pool.map``.
    """
    # init
    rank, world_size = args
    device = torch.device("cuda")
    # NOTE(review): presumably sets up the distributed/parallel context for this
    # rank; see tests.test_model.test_model_internlm.build_environment.
    build_environment(rank, world_size)
    rtol, atol = (1e-3, 5e-3)
    vocab_size = 4
    hidden_size = 2
    # Number of repeated forward passes used for the determinism check.
    num_forward_passes = 10

    # fix seed -- every random draw below (any done by Embedding1D's own
    # initialization, the weight overwrite, and the upstream gradient) happens
    # after this call and in a fixed order. The reference values further down
    # depend on that order, so do not reorder these calls.
    seed_all(1024)

    # define embedding
    embedding = Embedding1D(
        num_embeddings=vocab_size,
        embedding_dim=hidden_size,
        padding_idx=None,
    )

    # Overwrite the weights with seeded values drawn on the CPU, then move to the GPU.
    embedding.weight.data.copy_(torch.randn(vocab_size, hidden_size))
    embedding = embedding.to(device)

    # create input
    input_ids = torch.tensor([[0, 2], [1, 3]], device=device)
    output_list = [embedding(input_ids) for _ in range(num_forward_passes)]

    # check only forward logits: every pass must reproduce the first bit-for-bit
    first_output = output_list[0]
    for i, output in enumerate(output_list[1:], start=1):
        assert torch.equal(first_output, output), f"forward pass {i} differs from pass 0"

    # Reference output for ``input_ids`` under the seeded weights above.
    standard_list = [[[-1.4837, 0.2671], [0.6002, -0.5496]], [[-1.8337, -0.1047], [1.0391, 0.2261]]]
    standard_result = torch.tensor(standard_list, device=device)

    # check output (the last pass; all passes were shown identical above)
    result = output_list[-1]
    assert torch.allclose(result, standard_result, rtol=rtol, atol=atol, equal_nan=True)

    # Random upstream gradient with the same shape/device as the output.
    # NOTE(review): reproducible only if seed_all also seeds the CUDA generator.
    loss = torch.randn_like(result)

    # backward -- called once, on the last pass's graph only, so the weight
    # gradient reflects exactly one backward pass.
    result.backward(loss)

    grad = embedding.weight.grad
    standard_glist = [[-0.4461, 0.5602], [0.4353, 1.2988], [-0.0625, -1.3609], [0.9595, -0.1144]]
    standard_grad = torch.tensor(standard_glist, device=device)

    # check grad
    assert torch.allclose(grad, standard_grad, rtol=rtol, atol=atol, equal_nan=True)
|
|
|
|
|
|
@pytest.mark.embedding
def test_embedding():
    """Run ``check_embedding`` once per rank, each rank in its own spawned process."""
    world_size = 8
    # NOTE(review): "spawn" is presumably used because CUDA cannot be
    # re-initialized inside forked child processes.
    spawn_ctx = mp.get_context("spawn")
    rank_args = [[rank, world_size] for rank in range(world_size)]
    with spawn_ctx.Pool(processes=world_size) as workers:
        workers.map(check_embedding, rank_args)
        workers.close()
        workers.join()
|
|
|
|
|
|
if __name__ == "__main__":
    # Run this file through pytest. Passing __file__ instead of a bare file name
    # lets the script be launched from any working directory, not only from
    # the directory that contains it.
    pytest.main(["-s", "-q", __file__])
|