import torch
from torch.nn import Parameter

from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.nn import Linear2D, LayerNorm2D, TransformerSelfAttention2D, TransformerMLP2D, TransformerLayer2D
from colossalai.utils import get_current_device, print_rank_0
from common import HIDDEN_SIZE, DEPTH, BATCH_SIZE, SEQ_LENGTH, check_equal


def check_linear():
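    """Compare Linear2D on the 2D tensor-parallel grid against a plain
    torch.matmul reference, for both the forward output and the backward
    gradients."""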
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    OUTPUT_SIZE = 2 * HIDDEN_SIZE

    # Local coordinates of this rank in the 2D tensor-parallel grid.
    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = Linear2D(INPUT_SIZE, OUTPUT_SIZE)

    # Broadcast a master input from rank 0, then keep only this rank's partition.
    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    # The weight is partitioned along both grid dimensions.
    W_shape = (INPUT_SIZE, OUTPUT_SIZE)
    W_master = torch.randn(W_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(W_master, src=0)
    W = torch.chunk(W_master, DEPTH, dim=0)[i]
    W = torch.chunk(W, DEPTH, dim=-1)[j]
    W = W.clone()
    W.requires_grad = True

    # The bias is partitioned along the output dimension only.
    B_shape = (OUTPUT_SIZE, )
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[j]
    B = B.clone()
    B.requires_grad = True

    layer.weight = Parameter(W)
    layer.bias = Parameter(B)
    out = layer(A)

    # Reference: run the same computation on the full tensors.
    A_master = A_master.clone()
    A_master.requires_grad = True
    W_master = W_master.clone()
    W_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, W_master) + B_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('linear forward: pass')
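    # Backward: broadcast a shared output gradient, feed each rank its slice,
    # and compare the resulting input/parameter gradients against the reference.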
    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)

    W_grad = W_master.grad
    W_grad = torch.chunk(W_grad, DEPTH, dim=0)[i]
    W_grad = torch.chunk(W_grad, DEPTH, dim=-1)[j]
    check_equal(W_grad, layer.weight.grad)

    # The bias gradient is compared only on ranks where i == 0.
    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
    if i == 0:
        check_equal(B_grad, layer.bias.grad)
    print_rank_0('linear backward: pass')


def check_layernorm():
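    """Compare LayerNorm2D against an explicit mean/variance normalization
    computed on the full input."""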
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    EPS = 1e-12

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layernorm = LayerNorm2D(INPUT_SIZE)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    out = layernorm(A)

    # Reference: normalize with the mean and variance of the full hidden dimension.
    A_master = A_master.clone()
    A_master.requires_grad = True
    E_master = torch.sum(A_master, dim=-1, keepdim=True)
    E_master /= INPUT_SIZE
    V_master = torch.sum(A_master * A_master, dim=-1, keepdim=True)
    V_master /= INPUT_SIZE
    V_master = V_master - E_master * E_master
    V_master = 1.0 / torch.sqrt(V_master + EPS)
    C_master = (A_master - E_master) * V_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]

    check_equal(out, C)
    print_rank_0('layer norm forward: pass')
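    # Backward: propagate a shared gradient through both graphs and compare
    # the input gradients.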
    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[j]
    check_equal(A_grad, A.grad)
    print_rank_0('layer norm backward: pass')


def check_attention():
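    """Smoke-test TransformerSelfAttention2D on the 2D grid: only shapes are
    checked, since the layer is configured with non-zero dropout."""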
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = TransformerSelfAttention2D(
        HIDDEN_SIZE,
        NUM_ATTENTION_HEADS,
        attention_dropout_prob=0.5,
        hidden_dropout_prob=0.5,
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    # The attention mask is also partitioned: batch and head dims are split by DEPTH.
    mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

    out = layer(A, attention_mask)
    assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
    print_rank_0('self attention forward: pass')
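    # Backward: only check that the input gradient has the partitioned shape.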
    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('self attention backward: pass')


def check_mlp():
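    """Smoke-test TransformerMLP2D: only output and input-gradient shapes are
    checked, since the layer is configured with non-zero dropout."""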
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = TransformerMLP2D(
        HIDDEN_SIZE,
        dropout_prob=0.5,
        act_func='gelu',
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    out = layer(A)
    assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
    print_rank_0('mlp forward: pass')
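    # Backward shape check only; values are not compared because of dropout.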
    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('mlp backward: pass')


def check_transformerlayer():
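    """Smoke-test a full TransformerLayer2D block for output and input-gradient
    shapes on the partitioned input."""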
    device = get_current_device()
    dtype = torch.float32
    INPUT_SIZE = HIDDEN_SIZE
    NUM_ATTENTION_HEADS = 2

    j = gpc.get_local_rank(ParallelMode.PARALLEL_2D_ROW)
    i = gpc.get_local_rank(ParallelMode.PARALLEL_2D_COL)

    layer = TransformerLayer2D(
        HIDDEN_SIZE,
        NUM_ATTENTION_HEADS,
        act_func='gelu',
        attention_dropout_prob=0.5,
        hidden_dropout_prob=0.5,
    )

    A_shape = (BATCH_SIZE, SEQ_LENGTH, INPUT_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[j]
    A = A.clone()
    A.requires_grad = True

    mask_shape = (BATCH_SIZE // DEPTH, NUM_ATTENTION_HEADS // DEPTH, SEQ_LENGTH, SEQ_LENGTH)
    attention_mask = torch.zeros(mask_shape, dtype=dtype, device=device)

    out = layer(A, attention_mask)
    assert out.shape == (BATCH_SIZE // DEPTH, SEQ_LENGTH, INPUT_SIZE // DEPTH)
    print_rank_0('transformerlayer forward: pass')
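    # Backward shape check only, mirroring the attention and MLP tests above.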
    grad_shape = out.shape
    grad = torch.randn(grad_shape, dtype=dtype, device=device)

    out.backward(grad)
    assert A.grad.shape == A.shape
    print_rank_0('transformerlayer backward: pass')