overlap computation and communication in 2d operations (#75)

HELSON 2021-12-16 16:05:15 +08:00 committed by GitHub
parent cd9c28e055
commit 632e622de8
1 changed file with 144 additions and 49 deletions
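
This commit replaces the blocking collectives in the SUMMA-style 2D matmul ops (Matmul_AB_2D, Matmul_ABT_2D, Matmul_ATB_2D) with double-buffered asynchronous broadcasts and reduces, so the communication for step i + 1 is already in flight while the GEMM for step i runs. Below is a minimal sketch of the broadcast-side pattern, outside of ColossalAI: it assumes an already-initialized default torch.distributed process group in which rank i holds one column block of A and one row block of B; the function name and the single-group layout are illustrative simplifications, not ColossalAI's API.

```python
import torch
import torch.distributed as dist


def overlapped_matmul(a, b, group=None):
    """Double-buffered C = sum_i A_i @ B_i over the ranks of `group` (sketch).

    Assumes torch.distributed is initialized and rank i of the default group
    holds column block A_i (m x k) of A and row block B_i (k x n) of B.
    """
    world_size = dist.get_world_size(group)
    c = torch.zeros(a.shape[0], b.shape[1], dtype=a.dtype, device=a.device)

    # circular buffer: one tile is consumed while the next is in flight
    a_buf = [torch.empty_like(a) for _ in range(2)]
    b_buf = [torch.empty_like(b) for _ in range(2)]
    ops = [None, None]

    a_buf[0].copy_(a)
    b_buf[0].copy_(b)
    ops[0] = [dist.broadcast(a_buf[0], src=0, group=group, async_op=True),
              dist.broadcast(b_buf[0], src=0, group=group, async_op=True)]
    cur = 0
    for i in range(world_size):
        if i != world_size - 1:
            # prefetch the tiles of step i + 1 while step i is computed below
            a_buf[1 - cur].copy_(a)
            b_buf[1 - cur].copy_(b)
            ops[1 - cur] = [dist.broadcast(a_buf[1 - cur], src=i + 1, group=group, async_op=True),
                            dist.broadcast(b_buf[1 - cur], src=i + 1, group=group, async_op=True)]
        for op in ops[cur]:
            op.wait()
        c.addmm_(a_buf[cur], b_buf[cur])  # accumulate A_i @ B_i into C
        cur = 1 - cur
    return c
```

Two buffers are enough because at any moment only the tile being consumed and the tile being prefetched have to coexist, which is what the in-diff comment "2 is enough for all cases" refers to.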

@@ -85,30 +85,57 @@ class Matmul_AB_2D(torch.autograd.Function):
             ctx.save_for_backward(A, B)
         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        opa = [None] * 2
+        opb = [None] * 2
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
         for i in range(summa_dim):
-            src_a = i + summa_dim * row_rank
-            src_b = i + summa_dim * col_rank
-            src_a = src_a % summa_dim
-            src_b = src_b % summa_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += summa_dim
         out = C.reshape(out_shape)
         if ctx:
@@ -188,21 +215,55 @@ class Matmul_ABT_2D(torch.autograd.Function):
         C_shape = (A.shape[0], B.shape[0])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
-        for i in range(summa_dim):
-            B_temp = B.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b,
-                           group=gpc.get_group(col_parallel_mode))
-            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(row_parallel_mode))
-            if i == col_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        opb = [None] * 2
+        opr = [None] * 2
+        B_list[0].copy_(B)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == col_rank:
+                    C.copy_(C_list[cur])
+            if opb[cur] is not None:
+                opb[cur].wait()
+            torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+            cur = 1 - cur
+            src_b += summa_dim
+            src_c += 1
+        for op in opr:
+            op.wait()
+        if summa_dim - 2 == col_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == col_rank:
+            C.copy_(C_list[1 - cur])
         out = C.reshape(out_shape)
         if ctx:
@@ -284,21 +345,55 @@ class Matmul_ATB_2D(torch.autograd.Function):
         C_shape = (A.shape[-1], B.shape[-1])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
-        for i in range(summa_dim):
-            A_temp = A.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a,
-                           group=gpc.get_group(row_parallel_mode))
-            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(col_parallel_mode))
-            if i == row_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+                pipeline_parallel_rank * tensor_parallel_size
+        opa = [None] * 2
+        opr = [None] * 2
+        A_list[0].copy_(A)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        cur = 0
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == row_rank:
+                    C.copy_(C_list[cur])
+            if opa[cur] is not None:
+                opa[cur].wait()
+            torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+            cur = 1 - cur
+            src_a += 1
+            src_c += summa_dim
+        for op in opr:
+            op.wait()
+        if summa_dim - 2 == row_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == row_rank:
+            C.copy_(C_list[1 - cur])
         out = C.reshape(out_shape)
         if ctx:
@@ -374,7 +469,7 @@ class Add_Bias_2D(torch.autograd.Function):
                                     dtype=bias.dtype,
                                     device=get_current_device())
             src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.broadcast(bias_temp, src=src_rank,
                            group=gpc.get_group(col_parallel_mode))
@@ -408,7 +503,7 @@ class Add_Bias_2D(torch.autograd.Function):
         if ctx.bias:
             dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(output_grad, dst=dst_rank,
                         group=gpc.get_group(col_parallel_mode))
             if row_rank == 0:
@@ -421,7 +516,7 @@ class Add_Bias_2D(torch.autograd.Function):
             reduce_dim = tuple(range(output_grad.ndim - 1))
             reduce = torch.sum(output_grad, dim=reduce_dim)
             dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(reduce, dst=dst_rank,
                         group=gpc.get_group(col_parallel_mode))
             if row_rank == 0:
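
In the transposed variants above (Matmul_ABT_2D, Matmul_ATB_2D), the partial product of step i is reduced asynchronously to the rank that owns slice i of the output, so that rank can only copy the result out of the circular buffer two iterations later, once opr[cur].wait() has returned; the two copies after the loop cover the last two steps. A stripped-down sketch of just that timing, under the same assumptions as the earlier sketch (initialized default process group, illustrative names):

```python
import torch
import torch.distributed as dist


def overlapped_reduce(parts, group=None):
    """Sketch: step i reduces parts[i] to rank i while step i + 1 is prepared.

    `parts` is a list of world_size equally shaped tensors, this rank's
    partial result for each step; returns the reduced slice this rank owns.
    """
    world_size = dist.get_world_size(group)
    rank = dist.get_rank(group)
    out = torch.empty_like(parts[0])
    buf = [torch.empty_like(parts[0]) for _ in range(2)]
    ops = [None, None]
    cur = 0
    for i in range(world_size):
        if ops[cur] is not None:
            ops[cur].wait()        # the reduce issued at step i - 2 has finished
            if i - 2 == rank:      # ...and it was addressed to this rank
                out.copy_(buf[cur])
        buf[cur].copy_(parts[i])   # stands in for the step-i matmul into C_list[cur]
        ops[cur] = dist.reduce(buf[cur], dst=i, group=group, async_op=True)
        cur = 1 - cur
    for op in ops:
        if op is not None:
            op.wait()
    # the reduces of the last two steps complete only after the loop
    if world_size - 2 == rank:
        out.copy_(buf[cur])
    if world_size - 1 == rank:
        out.copy_(buf[1 - cur])
    return out
```

Waiting on opr[cur] any earlier would serialize each reduce with the matmul that produced it and give up the overlap; deferring the consuming copy by two steps is the price of keeping only two buffers.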