InternLM/tests/test_model/test_embedding.py

66 lines
1.7 KiB
Python

import multiprocessing as mp
import pytest
import torch
from internlm.model.embedding import Embedding1D
from tests.test_model.test_model_internlm import build_environment, seed_all
def check_embedding(args):
    """Check Embedding1D's forward output and weight gradient on one rank.

    ``args`` is unpacked as ``(rank, world_size)`` and forwarded to
    ``build_environment``. After ``seed_all(1024)``, a 4 x 2 embedding table
    is constructed and overwritten with ``torch.randn`` values. The table is
    moved to CUDA and used to look up a fixed (2, 2) batch of token ids. The
    output, and then the weight gradient produced by a random upstream
    gradient, are compared against hard-coded reference values.

    NOTE: the reference tensors depend on the exact sequence of RNG draws made
    after seeding. Do not reorder the seeding, module construction, ``randn``,
    forward, or ``randn_like`` calls below.
    """
    rank, world_size = args
    cuda = torch.device("cuda")
    build_environment(rank, world_size)
    tolerances = dict(rtol=1e-3, atol=5e-3, equal_nan=True)
    num_tokens, embed_dim = 4, 2

    # Seed before building the module so every later random draw is reproducible.
    seed_all(1024)

    table = Embedding1D(
        num_embeddings=num_tokens,
        embedding_dim=embed_dim,
        padding_idx=None,
    )
    # Overwrite the weights in place via ``.data``, so autograd does not record
    # the copy, with seeded CPU random values; then move the module to CUDA.
    table.weight.data.copy_(torch.randn(num_tokens, embed_dim))
    table = table.to(cuda)

    # Forward pass: look up a (2, 2) batch of token ids.
    token_ids = torch.tensor([[0, 2], [1, 3]], device=cuda)
    out = table(token_ids)
    expected_out = torch.tensor(
        [
            [[-1.4837, 0.2671], [0.6002, -0.5496]],
            [[-1.8337, -0.1047], [1.0391, 0.2261]],
        ],
        device=cuda,
    )
    assert torch.allclose(out, expected_out, **tolerances)

    # Backward pass: push a random upstream gradient through the lookup.
    upstream = torch.randn_like(out)
    out.backward(upstream)
    expected_grad = torch.tensor(
        [
            [-0.4461, 0.5602],
            [0.4353, 1.2988],
            [-0.0625, -1.3609],
            [0.9595, -0.1144],
        ],
        device=cuda,
    )
    assert torch.allclose(table.weight.grad, expected_grad, **tolerances)
@pytest.mark.embedding
def test_embedding():
    """Run ``check_embedding`` on 8 ranks, each in its own spawned process.

    The ``spawn`` start method is used rather than ``fork``, since CUDA
    generally cannot be used in forked children. ``Pool.map`` re-raises any
    exception a worker raises, so a failed assertion on any rank fails this
    test.
    """
    world_size = 8
    jobs = [(rank, world_size) for rank in range(world_size)]
    spawn_ctx = mp.get_context("spawn")
    with spawn_ctx.Pool(processes=world_size) as workers:
        workers.map(check_embedding, jobs)
        # Let the workers shut down cleanly before the context manager's
        # exit calls terminate().
        workers.close()
        workers.join()
if __name__ == "__main__":
    # Pass this file's own path instead of a bare filename, which would be
    # resolved relative to the current working directory. Also propagate
    # pytest's exit status so failures are visible to the calling shell.
    raise SystemExit(pytest.main(["-s", "-q", __file__]))