You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/tests/test_fx/test_tracer/test_control_flow.py

58 lines
1.8 KiB

import torch
import torch.nn as nn
from torch.fx import GraphModule
from colossalai.fx import ColoTracer as Tracer
class ControlFlowModel(nn.Module):
    """Toy module whose forward pass branches on the rank of an intermediate tensor.

    Used to check that a tracer resolves data-dependent Python control flow
    according to the (meta) input shapes it is given.
    """

    def __init__(self):
        super().__init__()
        # Both projections map a trailing feature dimension of size 10 to 10.
        self.linear1 = nn.Linear(10, 10)
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y):
        """Project ``x`` and ``y`` and combine the projections.

        Returns ``linear1(x) + linear2(y)`` when ``linear1(x)`` is 2-D and
        ``linear1(x) - linear2(y)`` otherwise.
        """
        proj_x = self.linear1(x)
        proj_y = self.linear2(y)
        takes_sum_branch = proj_x.dim() == 2
        return proj_x + proj_y if takes_sum_branch else proj_x - proj_y
def test_control_flow():
    """Check that the tracer specialises data-dependent control flow per meta args.

    ``ControlFlowModel.forward`` branches on ``linear1(x).dim() == 2``. Tracing
    with a 2-D meta ``x`` should capture the ``x1 + y1`` branch, and tracing
    with a 1-D meta ``x`` should capture the ``x1 - y1`` branch. Each resulting
    GraphModule must reproduce eager execution for inputs that take its branch,
    and the two graphs must produce different results on the same inputs.
    """
    model = ControlFlowModel()
    tracer = Tracer()

    # Meta tensors carry shape/dtype but no data; the tracer only needs shapes
    # to decide which branch of ``forward`` to record.
    graph_branch_true = tracer.trace(model,
                                     meta_args={
                                         'x': torch.rand(4, 10, device='meta'),
                                         'y': torch.rand(4, 10, device='meta')
                                     })
    graph_branch_false = tracer.trace(model,
                                      meta_args={
                                          'x': torch.rand(10, device='meta'),
                                          'y': torch.rand(4, 10, device='meta')
                                      })

    gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
    gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
    gm_branch_true.recompile()
    gm_branch_false.recompile()

    # test the true branch: 2-D x, so forward computes x1 + y1
    x = torch.rand(4, 10)
    y = torch.rand(4, 10)
    # torch.equal also compares shapes; torch.all(a == b) could hide a shape
    # mismatch through broadcasting.
    assert torch.equal(model(x, y), gm_branch_true(x, y))
    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))

    # test the false branch: 1-D x, so forward computes x1 - y1
    x = torch.rand(10)
    y = torch.rand(4, 10)
    assert torch.equal(model(x, y), gm_branch_false(x, y))
    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
if __name__ == '__main__':
    # Allow running this test module directly, without a test runner.
    test_control_flow()