diff --git a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py
index 83442df70..a5e37b1ec 100644
--- a/tests/test_layers/test_2d/checks_2d/check_operation_2d.py
+++ b/tests/test_layers/test_2d/checks_2d/check_operation_2d.py
@@ -41,18 +41,8 @@ def check_AB():
 
     out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)
 
-    out = Matmul_AB_2D.apply(
-        A, B,
-        DEPTH,
-        out_shape,
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
+                             data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
 
     C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
     A_master = A_master.clone()
@@ -119,17 +109,9 @@ def check_ABT():
     B = B.clone()
     B.requires_grad = True
 
-    out = Matmul_ABT_2D.apply(
-        C, B,
-        DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH),
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j,
+                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
+                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     C_master = C_master.clone()
@@ -194,17 +176,9 @@ def check_ATB():
     C = C.clone()
     C.requires_grad = True
 
-    out = Matmul_ATB_2D.apply(
-        A, C,
-        DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH),
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j,
+                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
+                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
 
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
     A_master = A_master.clone()
@@ -212,8 +186,7 @@ def check_ATB():
     C_master = C_master.clone()
     C_master.requires_grad = True
     B_master = torch.matmul(
-        A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
-        C_master.view(-1, C_master.shape[-1]))
+        A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
     B = torch.chunk(B_master, DEPTH, dim=0)[i]
     B = torch.chunk(B, DEPTH, dim=-1)[j]
     check_equal(out, B)