#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn.layer.parallel_2d import Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D
from colossalai.utils import get_current_device
from colossalai.utils import print_rank_0
from common import check_equal, BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH
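

# `common` is a sibling test module that is not reproduced on this page; it
# supplies the size constants (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE, DEPTH) and
# the `check_equal` assertion used below.  The helper sketched here is only an
# assumption of what such a check might look like (name and tolerances are
# hypothetical, not the actual contents of common.py).
def _check_equal_sketch(a, b, rtol=1e-3, atol=1e-2):
    """Hypothetical stand-in for common.check_equal: elementwise closeness assertion."""
    assert torch.allclose(a, b, rtol=rtol, atol=atol), 'tensors are not close enough'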


def check_AB():
    """Check Matmul_AB_2D (C = A @ B) forward and backward against a single-device torch.matmul."""
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    # (i, j) are this rank's coordinates in the DEPTH x DEPTH tensor-parallel grid:
    # i selects the chunk along dim 0, j the chunk along the last dim.
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    # build the full operands on every rank (identical after the broadcast),
    # then keep only this rank's (i, j) block
    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out_shape = (BATCH_SIZE // DEPTH, SEQ_LENGTH, 4 * HIDDEN_SIZE // DEPTH)

    out = Matmul_AB_2D.apply(
        A, B,
        DEPTH,
        out_shape,
        i, j,
        ParallelMode.PARALLEL_2D_ROW,
        ParallelMode.PARALLEL_2D_COL,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size
    )

    # single-device reference: full matmul, then take the same (i, j) block
    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, B_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    # check forward correctness
    check_equal(out, C)
    print_rank_0('AB forward: pass')

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    # check backward correctness
    check_equal(A_grad, A.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    # check backward correctness
    check_equal(B_grad, B.grad)
    print_rank_0('AB backward: pass')
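

# The check above relies on the SUMMA-style block identity behind Matmul_AB_2D:
# with A split into a DEPTH x DEPTH grid of blocks A[i][k] (dim 0 by i, last dim
# by k) and B into blocks B[k][j], the (i, j) block of C = A @ B equals
# sum_k A[i][k] @ B[k][j].  The single-process sketch below illustrates that
# identity with plain torch; it is not part of the distributed test itself and
# its name is hypothetical.
def _matmul_ab_block_identity_sketch(depth=2, b=4, s=4, h=8):
    """Illustrative only: verify the block decomposition of C = A @ B on one device."""
    A = torch.randn(b, s, h)
    B = torch.randn(h, 4 * h)
    C = torch.matmul(A, B)
    for i in range(depth):
        for j in range(depth):
            partial = sum(
                torch.matmul(torch.chunk(torch.chunk(A, depth, dim=0)[i], depth, dim=-1)[k],
                             torch.chunk(torch.chunk(B, depth, dim=0)[k], depth, dim=-1)[j])
                for k in range(depth))
            expected = torch.chunk(torch.chunk(C, depth, dim=0)[i], depth, dim=-1)[j]
            assert torch.allclose(partial, expected, atol=1e-5)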


def check_ABT():
    """Check Matmul_ABT_2D (A = C @ B^T) forward and backward against a single-device reference."""
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    dtype = torch.float
    device = get_current_device()

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_ABT_2D.apply(
        C, B,
        DEPTH, (BATCH_SIZE // DEPTH, SEQ_LENGTH, HIDDEN_SIZE // DEPTH),
        i, j,
        ParallelMode.PARALLEL_2D_ROW,
        ParallelMode.PARALLEL_2D_COL,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size
    )

    # single-device reference: A = C @ B^T, then take the same (i, j) block
    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    check_equal(out, A)
    print_rank_0('ABT forward: pass')

    grad_shape = A_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]

    # backward
    out.backward(grad)

    A_master.backward(grad_master)
    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    check_equal(C_grad, C.grad)

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    check_equal(B_grad, B.grad)
    print_rank_0('ABT backward: pass')
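

# For a linear layer C = A @ B, the two backward products are exactly the shapes
# exercised by check_ABT and check_ATB: dL/dA = dL/dC @ B^T (the ABT pattern) and
# dL/dB = A^T @ dL/dC (the ATB pattern below).  The single-device sketch here
# confirms those formulas with plain autograd; it is illustrative only, not part
# of the distributed test, and its name is hypothetical.
def _matmul_backward_identity_sketch(s=4, h=8):
    """Illustrative only: grad_A = grad_C @ B^T and grad_B = A^T @ grad_C."""
    A = torch.randn(s, h, requires_grad=True)
    B = torch.randn(h, 4 * h, requires_grad=True)
    C = torch.matmul(A, B)
    grad_C = torch.randn_like(C)
    C.backward(grad_C)
    assert torch.allclose(A.grad, torch.matmul(grad_C, B.transpose(0, 1)), rtol=1e-4, atol=1e-5)
    assert torch.allclose(B.grad, torch.matmul(A.transpose(0, 1), grad_C), rtol=1e-4, atol=1e-5)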


def check_ATB():
    """Check Matmul_ATB_2D (B = A^T @ C, with batch and sequence dims flattened) against a single-device reference."""
    data_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.DATA) else gpc.get_local_rank(ParallelMode.DATA)
    pipeline_parallel_rank = 0 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_local_rank(
        ParallelMode.PIPELINE)
    pipeline_parallel_size = 1 if not gpc.is_initialized(ParallelMode.PIPELINE) else gpc.get_world_size(
        ParallelMode.PIPELINE)
    tensor_parallel_size = gpc.get_world_size(ParallelMode.TENSOR)

    device = get_current_device()
    dtype = torch.float

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = C.clone()
    C.requires_grad = True

    out = Matmul_ATB_2D.apply(
        A, C,
        DEPTH, (HIDDEN_SIZE // DEPTH, 4 * HIDDEN_SIZE // DEPTH),
        i, j,
        ParallelMode.PARALLEL_2D_ROW,
        ParallelMode.PARALLEL_2D_COL,
        data_parallel_rank,
        pipeline_parallel_rank,
        pipeline_parallel_size,
        tensor_parallel_size
    )

    # single-device reference: B = A^T @ C with batch/sequence dims flattened
    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = torch.matmul(
        A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
        C_master.view(-1, C_master.shape[-1]))
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    check_equal(out, B)
    print_rank_0('ATB forward: pass')

    grad_shape = B_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]

    out.backward(grad)

    B_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    check_equal(C_grad, C.grad)
    print_rank_0('ATB backward: pass')
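

# The entry point that initializes torch.distributed and the 2D tensor-parallel
# context (and then calls these checks on every rank) lives in a separate
# launcher script that is not shown on this page.  The grouping helper below is
# only a sketch of how the checks might be driven once initialization has been
# done; the function name is hypothetical.
def _run_all_matmul_2d_checks_sketch():
    """Hypothetical driver: assumes gpc and torch.distributed are already initialized."""
    check_AB()
    check_ABT()
    check_ATB()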