|
|
|
@ -39,16 +39,9 @@ def check_AB():
|
|
|
|
|
B.requires_grad = True |
|
|
|
|
|
|
|
|
|
out_shape = (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, 4 * HIDDEN_SIZE // TESSERACT_DIM) |
|
|
|
|
out = Matmul_AB_2p5D.apply( |
|
|
|
|
A, B, |
|
|
|
|
TESSERACT_DIM, out_shape, |
|
|
|
|
i, j, k, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_ROW, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_COL, |
|
|
|
|
data_parallel_rank, |
|
|
|
|
pipeline_parallel_rank, |
|
|
|
|
pipeline_parallel_size, |
|
|
|
|
tensor_parallel_size) |
|
|
|
|
out = Matmul_AB_2p5D.apply(A, B, TESSERACT_DIM, out_shape, i, j, k, ParallelMode.PARALLEL_2P5D_ROW, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, pipeline_parallel_rank, |
|
|
|
|
pipeline_parallel_size, tensor_parallel_size) |
|
|
|
|
|
|
|
|
|
C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE) |
|
|
|
|
A_master = A_master.clone() |
|
|
|
@ -116,16 +109,10 @@ def check_ABT():
|
|
|
|
|
B = B.clone() |
|
|
|
|
B.requires_grad = True |
|
|
|
|
|
|
|
|
|
out = Matmul_ABT_2p5D.apply( |
|
|
|
|
C, B, |
|
|
|
|
TESSERACT_DIM, (BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), |
|
|
|
|
i, j, k, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_ROW, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_COL, |
|
|
|
|
data_parallel_rank, |
|
|
|
|
pipeline_parallel_rank, |
|
|
|
|
pipeline_parallel_size, |
|
|
|
|
tensor_parallel_size) |
|
|
|
|
out = Matmul_ABT_2p5D.apply(C, B, TESSERACT_DIM, |
|
|
|
|
(BATCH_SIZE // TESSERACT_DIM, SEQ_LENGTH, HIDDEN_SIZE // TESSERACT_DIM), i, j, k, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, data_parallel_rank, |
|
|
|
|
pipeline_parallel_rank, pipeline_parallel_size, tensor_parallel_size) |
|
|
|
|
|
|
|
|
|
A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE) |
|
|
|
|
C_master = C_master.clone() |
|
|
|
@ -191,16 +178,10 @@ def check_ATB():
|
|
|
|
|
C = C.clone() |
|
|
|
|
C.requires_grad = True |
|
|
|
|
|
|
|
|
|
out = Matmul_ATB_2p5D.apply( |
|
|
|
|
A, C, |
|
|
|
|
TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), |
|
|
|
|
i, j, k, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_ROW, |
|
|
|
|
ParallelMode.PARALLEL_2P5D_COL, |
|
|
|
|
data_parallel_rank, |
|
|
|
|
pipeline_parallel_rank, |
|
|
|
|
pipeline_parallel_size, |
|
|
|
|
tensor_parallel_size) |
|
|
|
|
out = Matmul_ATB_2p5D.apply(A, C, TESSERACT_DIM, (HIDDEN_SIZE // TESSERACT_DIM, 4 * HIDDEN_SIZE // TESSERACT_DIM), |
|
|
|
|
i, j, k, ParallelMode.PARALLEL_2P5D_ROW, ParallelMode.PARALLEL_2P5D_COL, |
|
|
|
|
data_parallel_rank, pipeline_parallel_rank, pipeline_parallel_size, |
|
|
|
|
tensor_parallel_size) |
|
|
|
|
|
|
|
|
|
B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE) |
|
|
|
|
A_master = A_master.clone() |
|
|
|
@ -208,8 +189,7 @@ def check_ATB():
|
|
|
|
|
C_master = C_master.clone() |
|
|
|
|
C_master.requires_grad = True |
|
|
|
|
B_master = torch.matmul( |
|
|
|
|
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), |
|
|
|
|
C_master.view(-1, C_master.shape[-1])) |
|
|
|
|
A_master.view(-1, A_master.shape[-1]).transpose(0, 1), C_master.view(-1, C_master.shape[-1])) |
|
|
|
|
B = torch.chunk(B_master, TESSERACT_DIM, dim=0)[i] |
|
|
|
|
B = torch.chunk(B, TESSERACT_DIM, dim=-1)[j] |
|
|
|
|
check_equal(out, B) |
|
|
|
|