import pytest
import torch
import transformers

from topo_utils import MLP, check_topo, split_model_and_get_DAG

# Dummy input shape used by the OPT data generator below.
BATCH_SIZE = 1
SEQ_LENGTH = 16


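# NOTE: pytest.mark.skip is unconditional, so pytest always skips this test;
# the reason string only records why it was disabled.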
@pytest.mark.skip('ShapeProp is not compatible with PyTorch 1.11.0')
def test_opt():
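    """Split each model in MODEL_LIST with split_model_and_get_DAG and validate the returned topology with check_topo."""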
    MODEL_LIST = [
        MLP,
        transformers.OPTModel,
    ]

    CONFIGS = [
        {
            'dim': 10,
            'layers': 12
        },
        transformers.OPTConfig(vocab_size=100, hidden_size=128, num_hidden_layers=4, num_attention_heads=4),
    ]

    def data_gen_MLP():
        # Dummy input for the MLP model.
        x = torch.zeros((16, 10))
        kwargs = dict(x=x)
        return kwargs

    def data_gen_OPT():
        # Dummy token ids and attention mask for the OPT model.
        input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        attention_mask = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
        kwargs = dict(input_ids=input_ids, attention_mask=attention_mask)
        return kwargs

    DATAGEN = [
        data_gen_MLP,
        data_gen_OPT,
    ]

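    # MODEL_LIST, CONFIGS, and DATAGEN are paired by index: each model class is
    # instantiated from its config and traced with inputs from the matching generator.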
    for i, model_cls in enumerate(MODEL_LIST):
        model = model_cls(config=CONFIGS[i])
        top_mod, topo = split_model_and_get_DAG(model, DATAGEN[i])
        # print(f'{top_mod=}\n----\n{topo=}')
        check_topo(top_mod, topo)


if __name__ == '__main__':
    test_opt()