2022-06-29 07:05:25 +00:00
|
|
|
import torch
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.fx import GraphModule
|
2023-04-06 06:51:35 +00:00
|
|
|
|
2022-06-29 07:05:25 +00:00
|
|
|
from colossalai.fx import ColoTracer as Tracer
|
2023-04-06 06:51:35 +00:00
|
|
|
from colossalai.testing import clear_cache_before_run
|
2022-06-29 07:05:25 +00:00
|
|
|
|
|
|
|
|
|
|
|
class ControlFlowModel(nn.Module):
    """Toy module whose forward pass branches on the rank of an intermediate tensor.

    Two ``nn.Linear(10, 10)`` layers project ``x`` and ``y`` respectively. The
    projections are summed when the projected ``x`` is 2-D and subtracted
    otherwise, which gives a tracer a Python-level, shape-dependent branch to
    resolve.
    """

    def __init__(self):
        super().__init__()
        # Projection applied to the first input.
        self.linear1 = nn.Linear(10, 10)
        # Projection applied to the second input.
        self.linear2 = nn.Linear(10, 10)

    def forward(self, x, y):
        proj_x = self.linear1(x)
        proj_y = self.linear2(y)
        # Branch on the rank of the projected ``x``. Only one of the two
        # arithmetic ops is ever executed for a given input.
        is_matrix = proj_x.dim() == 2
        return proj_x + proj_y if is_matrix else proj_x - proj_y
|
|
|
|
|
|
|
|
|
2023-04-06 06:51:35 +00:00
|
|
|
@clear_cache_before_run()
def test_control_flow():
    """Check that tracing with meta args specializes data-dependent control flow.

    ``ControlFlowModel.forward`` branches on ``x1.dim()``. The model is traced
    twice with ``Tracer`` using meta tensors of different ranks:

    * a 2-D ``x``, which is expected to record the ``x1 + y1`` branch, and
    * a 1-D ``x``, which is expected to record the ``x1 - y1`` branch.

    Each resulting ``GraphModule`` must reproduce the eager model on real
    inputs that take its branch, and its output must differ elementwise from
    the other graph's output.
    """
    model = ControlFlowModel()
    tracer = Tracer()

    # 2-D ``x``: the condition ``x1.dim() == 2`` holds while tracing.
    graph_branch_true = tracer.trace(model,
                                     meta_args={
                                         'x': torch.rand(4, 10, device='meta'),
                                         'y': torch.rand(4, 10, device='meta')
                                     })
    # 1-D ``x``: the condition is false while tracing. ``y`` stays 2-D, so the
    # result broadcasts to ``(4, 10)``.
    graph_branch_false = tracer.trace(model,
                                      meta_args={
                                          'x': torch.rand(10, device='meta'),
                                          'y': torch.rand(4, 10, device='meta')
                                      })

    # Wrap each graph around the same module so that both share its parameters.
    gm_branch_true = GraphModule(model, graph_branch_true, model.__class__.__name__)
    gm_branch_false = GraphModule(model, graph_branch_false, model.__class__.__name__)
    gm_branch_true.recompile()
    gm_branch_false.recompile()

    # test the true branch: 2-D ``x``, so eager execution also takes the "+" path
    x = torch.rand(4, 10)
    y = torch.rand(4, 10)
    assert torch.all(model(x, y) == gm_branch_true(x, y))
    # The "-" graph must disagree with the "+" graph everywhere.
    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))

    # test the false branch: 1-D ``x``, so eager execution takes the "-" path
    x = torch.rand(10)
    y = torch.rand(4, 10)
    assert torch.all(model(x, y) == gm_branch_false(x, y))
    assert torch.all(gm_branch_false(x, y) != gm_branch_true(x, y))
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Runs the check directly when this file is executed as a script
    # (outside of a test runner).
    test_control_flow()
|