[NFC] polish tests/test_layers/test_2d/checks_2d/check_operation_2d.py code style (#1715)

pull/1743/head
binmakeswell 2022-10-17 17:30:42 +08:00 committed by Frank Lee
parent e1d780030d
commit f6389d0813
1 changed file with 9 additions and 36 deletions

View File

@@ -41,18 +41,8 @@ def check_AB():
     out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)
-    out = Matmul_AB_2D.apply(
-        A, B,
-        DEPTH,
-        out_shape,
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_AB_2D.apply(A, B, DEPTH, out_shape, i, j, ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL,
+                             data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
     C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
     A_master = A_master.clone()
@@ -119,17 +109,9 @@ def check_ABT():
     B = B.clone()
     B.requires_grad = True
-    out = Matmul_ABT_2D.apply(
-        C, B,
-        DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH),
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_ABT_2D.apply(C, B, DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH), i, j,
+                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
+                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
     A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
     C_master = C_master.clone()
@@ -194,17 +176,9 @@ def check_ATB():
     C = C.clone()
     C.requires_grad = True
-    out = Matmul_ATB_2D.apply(
-        A, C,
-        DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH),
-        i, j,
-        ParallelMode.PARALLEL_2D_ROW,
-        ParallelMode.PARALLEL_2D_COL,
-        data_parallel_rank,
-        pipeline_parallel_rank,
-        pipeline_parallel_size,
-        tensor_parallel_size
-    )
+    out = Matmul_ATB_2D.apply(A, C, DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH), i, j,
+                              ParallelMode.PARALLEL_2D_ROW, ParallelMode.PARALLEL_2D_COL, data_parallel_rank,
+                              pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size)
     B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
     A_master = A_master.clone()
@@ -212,8 +186,7 @@ def check_ATB():
     C_master = C_master.clone()
     C_master.requires_grad = True
     B_master = torch.matmul(
-        A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
-        C_master.view(-1, C_master.shape[-1]))
+        A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1]))
     B = torch.chunk(B_master, DEPTH, dim=0)[i]
     B = torch.chunk(B, DEPTH, dim=-1)[j]
     check_equal(out, B)