From e76f76c08b0c76b5800861d32469a5834b3416f6 Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Thu, 28 Apr 2022 10:57:14 +0800 Subject: [PATCH] [Tensor] test parameters() as member function (#896) --- tests/test_tensor/test_model.py | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/tests/test_tensor/test_model.py b/tests/test_tensor/test_model.py index cc0bf73b5..1561cc177 100644 --- a/tests/test_tensor/test_model.py +++ b/tests/test_tensor/test_model.py @@ -98,6 +98,35 @@ def run_1d_col_tp(): if i > 5: break +# Test the overrided parameters() and named_parameters() member functions +def test_model_parameters(): + # build a module with 2 Linear, 4 parameters in total. + class Net(torch.nn.Module): + + def __init__(self): + super().__init__() + self.fcs = torch.nn.Sequential(torch.nn.Linear(2, 3), torch.nn.Linear(3, 2)) + self.extra_param = torch.nn.Parameter(torch.randn(2)) + + with ColoInitContext(device=get_current_device()): + model = Net() + + param_cnt = 0 + for name, p in model.named_parameters(): + param_cnt += 1 + assert param_cnt == 5 + + param_cnt = 0 + for name, p in model.named_parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 1 + + param_cnt = 0 + for p in model.fcs[0].parameters(recurse=False): + param_cnt += 1 + assert param_cnt == 2 + + def run_1d_row_tp(): # A simple net with two stacked nn.Linear get_components_func = non_distributed_component_funcs.get_callable('simple_net') @@ -179,4 +208,5 @@ def test_simple_net(world_size): if __name__ == '__main__': - test_simple_net() + # test_simple_net() + test_model_parameters()