#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch

from colossalai.context import ParallelMode
from colossalai.core import global_context
from colossalai.logging import get_global_dist_logger
from colossalai.nn.layer.parallel_3d._operation import *
from colossalai.utils import get_current_device

from common import *

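# The checks below exercise ColossalAI's 3-D (cubic) tensor parallel primitives.
# Throughout, i, j and k are this process's local ranks along the WEIGHT, INPUT
# and OUTPUT dimensions of the process cube, and DEPTH (imported from `common`,
# assumed to be the size of each cube dimension) controls how every tensor is
# chunked. Each check builds a full-size "master" tensor, keeps it identical on
# all ranks via a broadcast from rank 0, slices out this rank's block, runs the
# parallel op, and compares forward outputs and backward gradients against the
# serial PyTorch reference.
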
def check_AB():
    """Forward/backward check of Matmul_AB_3D (C = A @ B) against torch.matmul."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float
    j = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)

    # build the full activation A on every rank, then take this rank's block
    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    A = A.clone()
    A.requires_grad = True

    # build the full weight B on every rank, then take this rank's block
    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=get_current_device())
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[k]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = torch.chunk(B, DEPTH, dim=-1)[i]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_AB_3D.apply(A, B, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                             ParallelMode.PARALLEL_3D_WEIGHT,
                             ParallelMode.PARALLEL_3D_OUTPUT)

    # serial reference: C = A @ B, sliced to the block this rank should hold
    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    C_master = torch.matmul(A_master, B_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = torch.chunk(C, DEPTH, dim=0)[k]
    # check forward correctness
    logger.info('Rank {} AB forward: {}'.format(rank, check_equal(out, C)))

    # identical upstream gradient on every rank, sliced like the output
    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape,
                              dtype=dtype,
                              device=get_current_device())
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = torch.chunk(grad, DEPTH, dim=0)[k]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
    # check backward correctness
    logger.info('Rank {} AB backward (A_grad): {}'.format(
        rank, check_equal(A_grad, A.grad)))

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    # check backward correctness
    logger.info('Rank {} AB backward (B_grad): {}'.format(
        rank, check_equal(B_grad, B.grad)))


def check_ABT():
    """Forward/backward check of Matmul_ABT_3D (A = C @ B^T) against torch.matmul."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = torch.chunk(C, DEPTH, dim=0)[k]
    C = C.clone()
    C.requires_grad = True

    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[k]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = torch.chunk(B, DEPTH, dim=-1)[i]
    B = B.clone()
    B.requires_grad = True

    out = Matmul_ABT_3D.apply(C, B, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT,
                              ParallelMode.PARALLEL_3D_WEIGHT,
                              ParallelMode.PARALLEL_3D_INPUT)

    # serial reference: A = C @ B^T
    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = B_master.clone()
    B_master.requires_grad = True
    A_master = torch.matmul(C_master, B_master.transpose(0, 1))
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    logger.info('Rank {} ABT forward: {}'.format(rank, check_equal(out, A)))

    grad_shape = A_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[k]
    grad = torch.chunk(grad, DEPTH, dim=0)[j]

    # backward
    out.backward(grad)

    A_master.backward(grad_master)
    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[k]
    logger.info('Rank {} ABT backward (C_grad): {}'.format(
        rank, check_equal(C_grad, C.grad)))

    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[k]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[j]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[i]
    logger.info('Rank {} ABT backward (B_grad): {}'.format(
        rank, check_equal(B_grad, B.grad)))


def check_ATB():
    """Forward/backward check of Matmul_ATB_3D (B = A^T @ C) against torch.matmul."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    device = get_current_device()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    A = A.clone()
    A.requires_grad = True

    C_shape = (BATCH_SIZE, SEQ_LENGTH, 4 * HIDDEN_SIZE)
    C_master = torch.randn(C_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(C_master, src=0)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[j]
    C = torch.chunk(C, DEPTH, dim=0)[k]
    C = C.clone()
    C.requires_grad = True

    out = Matmul_ATB_3D.apply(A, C, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                              ParallelMode.PARALLEL_3D_OUTPUT,
                              ParallelMode.PARALLEL_3D_WEIGHT)

    # serial reference: B = A^T @ C, with A and C flattened over (batch, seq)
    B_shape = (HIDDEN_SIZE, 4 * HIDDEN_SIZE)
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = C_master.clone()
    C_master.requires_grad = True
    B_master = torch.matmul(
        A_master.view(-1, A_master.shape[-1]).transpose(0, 1),
        C_master.view(-1, C_master.shape[-1]))
    B = torch.chunk(B_master, DEPTH, dim=0)[k]
    B = torch.chunk(B, DEPTH, dim=-1)[j]
    B = torch.chunk(B, DEPTH, dim=-1)[i]
    logger.info('Rank {} ATB forward: {}'.format(rank, check_equal(out, B)))

    grad_shape = B_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[k]
    grad = torch.chunk(grad, DEPTH, dim=-1)[j]
    grad = torch.chunk(grad, DEPTH, dim=-1)[i]

    out.backward(grad)

    B_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
    logger.info('Rank {} ATB backward (A_grad): {}'.format(
        rank, check_equal(A_grad, A.grad)))

    C_grad = C_master.grad
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[i]
    C_grad = torch.chunk(C_grad, DEPTH, dim=-1)[j]
    C_grad = torch.chunk(C_grad, DEPTH, dim=0)[k]
    logger.info('Rank {} ATB backward (C_grad): {}'.format(
        rank, check_equal(C_grad, C.grad)))


def check_add():
    """Forward/backward check of Add_3D (bias add) against a plain broadcast add."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    A = A.clone()
    A.requires_grad = True

    # the bias is partitioned twice along its only dimension
    bias_shape = (HIDDEN_SIZE, )
    bias_master = torch.randn(bias_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(bias_master, src=0)
    bias = torch.chunk(bias_master, DEPTH)[j]
    bias = torch.chunk(bias, DEPTH)[i]
    bias = bias.clone()
    bias.requires_grad = True

    out = Add_3D.apply(A, bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                       ParallelMode.PARALLEL_3D_WEIGHT,
                       ParallelMode.PARALLEL_3D_OUTPUT)

    # serial reference
    A_master = A_master.clone()
    A_master.requires_grad = True
    bias_master = bias_master.clone()
    bias_master.requires_grad = True
    C_master = A_master + bias_master
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[k]
    C = torch.chunk(C, DEPTH, dim=0)[j]

    logger.info('Rank {} Add forward: {}'.format(rank, check_equal(out, C)))

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[k]
    grad = torch.chunk(grad, DEPTH, dim=0)[j]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
    logger.info('Rank {} Add backward (A_grad): {}'.format(
        rank, check_equal(A_grad, A.grad)))

    # the bias gradient is only expected on the diagonal j == k ranks;
    # elsewhere it should stay None
    if j == k:
        bias_grad = bias_master.grad
        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
        logger.info('Rank {} Add backward (b_grad): {}'.format(
            rank, check_equal(bias_grad, bias.grad)))
    else:
        logger.info('Rank {} Add backward (b_grad): {}'.format(
            rank,
            # np.count_nonzero(bias.grad.detach().cpu().numpy()) == 0))
            bias.grad is None))


def check_mul():
    """Forward/backward check of Mul_3D (elementwise scale) against torch.mul."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    A = A.clone()
    A.requires_grad = True

    # the scale vector is partitioned twice along its only dimension
    bias_shape = (HIDDEN_SIZE, )
    bias_master = torch.randn(bias_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(bias_master, src=0)
    bias = torch.chunk(bias_master, DEPTH)[j]
    bias = torch.chunk(bias, DEPTH)[i]
    bias = bias.clone()
    bias.requires_grad = True

    out = Mul_3D.apply(A, bias, DEPTH, ParallelMode.PARALLEL_3D_INPUT,
                       ParallelMode.PARALLEL_3D_WEIGHT,
                       ParallelMode.PARALLEL_3D_OUTPUT)

    # serial reference
    A_master = A_master.clone()
    A_master.requires_grad = True
    bias_master = bias_master.clone()
    bias_master.requires_grad = True
    C_master = torch.mul(A_master, bias_master)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=-1)[k]
    C = torch.chunk(C, DEPTH, dim=0)[j]

    logger.info('Rank {} Mul forward: {}'.format(rank, check_equal(out, C)))

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=-1)[k]
    grad = torch.chunk(grad, DEPTH, dim=0)[j]

    out.backward(grad)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
    logger.info('Rank {} Mul backward (A_grad): {}'.format(
        rank, check_equal(A_grad, A.grad)))

    # as in check_add, the scale gradient is only expected on j == k ranks
    if j == k:
        bias_grad = bias_master.grad
        bias_grad = torch.chunk(bias_grad, DEPTH)[j]
        bias_grad = torch.chunk(bias_grad, DEPTH)[i]
        logger.info('Rank {} Mul backward (b_grad): {}'.format(
            rank, check_equal(bias_grad, bias.grad)))
    else:
        logger.info('Rank {} Mul backward (b_grad): {}'.format(
            rank,
            # np.count_nonzero(bias.grad.detach().cpu().numpy()) == 0))
            bias.grad is None))


def check_sum():
    """Forward/backward check of Sum_3D (sum over the last dim) against torch.sum."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    # tensor input
    A_shape = (BATCH_SIZE, SEQ_LENGTH, HIDDEN_SIZE)
    A_master = torch.randn(A_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(A_master, src=0)
    A = torch.chunk(A_master, DEPTH, dim=0)[i]
    A = torch.chunk(A, DEPTH, dim=-1)[k]
    A = torch.chunk(A, DEPTH, dim=0)[j]
    A = A.clone()
    A.requires_grad = True

    out_tensor = Sum_3D.apply(A, -1, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT)

    # serial reference: sum over the last (hidden) dimension
    A_master = A_master.clone()
    A_master.requires_grad = True
    C_master = torch.sum(A_master, dim=-1)
    C = torch.chunk(C_master, DEPTH, dim=0)[i]
    C = torch.chunk(C, DEPTH, dim=0)[j]
    logger.info('Rank {} Sum forward: {}'.format(rank,
                                                 check_equal(out_tensor, C)))

    grad_shape = C_master.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)
    grad = torch.chunk(grad_master, DEPTH, dim=0)[i]
    grad = torch.chunk(grad, DEPTH, dim=0)[j]

    # the upstream gradient is deliberately pre-scaled by 1 / DEPTH before the
    # parallel backward
    out_tensor.backward(grad / DEPTH)

    C_master.backward(grad_master)
    A_grad = A_master.grad
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[i]
    A_grad = torch.chunk(A_grad, DEPTH, dim=-1)[k]
    A_grad = torch.chunk(A_grad, DEPTH, dim=0)[j]
    logger.info('Rank {} Sum backward: {}'.format(rank,
                                                  check_equal(A_grad, A.grad)))


def check_reduce():
    """Forward/backward check of Reduce_3D (reduce to a scalar) against torch.sum."""
    rank = torch.distributed.get_rank()
    logger = get_global_dist_logger()
    dtype = torch.float

    j = A_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_INPUT)
    i = B_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_WEIGHT)
    k = C_rank = global_context.get_local_rank(ParallelMode.PARALLEL_3D_OUTPUT)
    device = get_current_device()

    # scalar case: after three chunks and a squeeze, each rank holds one element of B
    B_shape = (DEPTH * DEPTH, DEPTH)
    B_master = torch.randn(B_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(B_master, src=0)
    B = torch.chunk(B_master, DEPTH, dim=0)[i]
    B = torch.chunk(B, DEPTH, dim=-1)[k]
    B = torch.chunk(B, DEPTH, dim=0)[j]
    B = torch.squeeze(B)
    B = B.clone()
    B.requires_grad = True

    # reduce across all three dimensions of the process cube
    out_scalar = Reduce_3D.apply(B, 0, DEPTH, ParallelMode.PARALLEL_3D_OUTPUT)
    out_scalar = Reduce_3D.apply(out_scalar, 0, DEPTH,
                                 ParallelMode.PARALLEL_3D_INPUT)
    out_scalar = Reduce_3D.apply(out_scalar, 0, DEPTH,
                                 ParallelMode.PARALLEL_3D_WEIGHT)

    # serial reference: the sum of every element of B
    B_master = B_master.clone()
    B_master.requires_grad = True
    D = torch.sum(B_master)
    logger.info('Rank {} Reduce forward: {}'.format(rank,
                                                    check_equal(out_scalar, D)))

    grad_shape = D.shape
    grad_master = torch.randn(grad_shape, dtype=dtype, device=device)
    torch.distributed.broadcast(grad_master, src=0)

    out_scalar.backward(grad_master)

    D.backward(grad_master)
    B_grad = B_master.grad
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[i]
    B_grad = torch.chunk(B_grad, DEPTH, dim=-1)[k]
    B_grad = torch.chunk(B_grad, DEPTH, dim=0)[j]
    B_grad = torch.squeeze(B_grad)
    logger.info('Rank {} Reduce backward: {}'.format(
        rank, check_equal(B_grad, B.grad)))
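

# --- Hypothetical driver (not part of the original test) ---
# A minimal sketch of how the checks above could be chained on every rank. It
# assumes the ColossalAI global context has already been initialized with a
# 3-D tensor parallel configuration (e.g. tensor mode '3d' over DEPTH ** 3
# processes); the exact launch/initialization call is version-dependent and is
# deliberately omitted here.
def check_all():
    check_AB()
    check_ABT()
    check_ATB()
    check_add()
    check_mul()
    check_sum()
    check_reduce()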