mirror of https://github.com/hpcaitech/ColossalAI
overlap computation and communication in 2d operations (#75)
parent cd9c28e055
commit 632e622de8
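Every hunk below applies the same idea: keep two communication buffers, launch the broadcast (or reduce) for step i + 1 asynchronously, and wait only on the buffer the current matmul needs, so communication for the next SUMMA step overlaps with computation of the current one. A minimal, self-contained sketch of the broadcast/compute overlap follows; all names are illustrative, and it assumes a default process group whose world size equals the number of steps (it is not the patch's own API):

import torch
import torch.distributed as dist

def overlapped_summa_ab(A, B, C, steps):
    # Two-slot circular buffers: slot `cur` feeds the current matmul while
    # slot `1 - cur` receives the broadcast for the next step.
    A_buf = [torch.empty_like(A) for _ in range(2)]
    B_buf = [torch.empty_like(B) for _ in range(2)]
    ops = [None, None]

    # Prefetch step 0; the broadcast from rank 0 overwrites the local copy
    # on every other rank.
    A_buf[0].copy_(A)
    B_buf[0].copy_(B)
    ops[0] = [dist.broadcast(A_buf[0], src=0, async_op=True),
              dist.broadcast(B_buf[0], src=0, async_op=True)]
    cur = 0

    for i in range(steps):
        if i != steps - 1:
            # Launch communication for step i + 1 before computing step i.
            nxt = 1 - cur
            A_buf[nxt].copy_(A)
            B_buf[nxt].copy_(B)
            ops[nxt] = [dist.broadcast(A_buf[nxt], src=i + 1, async_op=True),
                        dist.broadcast(B_buf[nxt], src=i + 1, async_op=True)]

        # Wait only for the buffers the current step needs.
        for op in ops[cur]:
            op.wait()
        torch.addmm(C, A_buf[cur], B_buf[cur], out=C)  # C += A_i @ B_i
        cur = 1 - cur
    return C

In the patch itself the broadcast sources are global ranks derived from the row/column position plus the data-, pipeline-, and tensor-parallel ranks (the src_a/src_b arithmetic), and the row and column broadcasts go to separate process groups.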
@@ -85,30 +85,57 @@ class Matmul_AB_2D(torch.autograd.Function):
         ctx.save_for_backward(A, B)
 
         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
 
-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opb = [None] * 2
+
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
+
         for i in range(summa_dim):
-            src_a = i + summa_dim * row_rank
-            src_b = i + summa_dim * col_rank
-            src_a = src_a % summa_dim
-            src_b = src_b % summa_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += summa_dim
 
         out = C.reshape(out_shape)
 
         if ctx:
@@ -188,21 +215,55 @@ class Matmul_ABT_2D(torch.autograd.Function):
         C_shape = (A.shape[0], B.shape[0])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
 
-        for i in range(summa_dim):
-            B_temp = B.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b,
-                           group=gpc.get_group(col_parallel_mode))
-            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(row_parallel_mode))
-            if i == col_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opb = [None] * 2
+        opr = [None] * 2
+
+        B_list[0].copy_(B)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == col_rank:
+                    C.copy_(C_list[cur])
+
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+            cur = 1 - cur
+            src_b += summa_dim
+            src_c += 1
+
+        for op in opr:
+            op.wait()
+
+        if summa_dim - 2 == col_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == col_rank:
+            C.copy_(C_list[1 - cur])
         out = C.reshape(out_shape)
 
         if ctx:
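The Matmul_ABT_2D hunk above and the Matmul_ATB_2D hunk below also double-buffer the reduce of the partial results: the reduce issued at step i completes while later steps compute, its destination is the rank that owns block i, so the `i - 2 == col_rank` (or `row_rank`) check copies the result out two iterations later, and the last two in-flight reduces are drained after the loop. A sketch of just that bookkeeping, with illustrative names and a default process group of `steps` ranks (not the patch's API):

import torch
import torch.distributed as dist

def overlapped_reduce(partials, rank, steps):
    # partials[i]: this rank's contribution to the result block owned by rank i.
    C = torch.empty_like(partials[0])
    C_buf = [torch.empty_like(C) for _ in range(2)]
    opr = [None, None]
    cur = 0

    for i in range(steps):
        # The reduce issued two iterations ago into this slot is finished now;
        # keep its result only if this rank was the destination.
        if opr[cur] is not None:
            opr[cur].wait()
            if i - 2 == rank:
                C.copy_(C_buf[cur])

        C_buf[cur].copy_(partials[i])
        opr[cur] = dist.reduce(C_buf[cur], dst=i, async_op=True)
        cur = 1 - cur

    # Drain the last two reduces; their results belong to the last two ranks.
    for op in opr:
        if op is not None:
            op.wait()
    if steps - 2 == rank:
        C.copy_(C_buf[cur])
    if steps - 1 == rank:
        C.copy_(C_buf[1 - cur])
    return C

In the patch the destination ranks come from the same global-rank arithmetic as the broadcasts (src_c), and the matmul writes directly into the buffer via out=C_list[cur], so no extra copy is needed before the reduce.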
@@ -284,21 +345,55 @@ class Matmul_ATB_2D(torch.autograd.Function):
         C_shape = (A.shape[-1], B.shape[-1])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
 
-        for i in range(summa_dim):
-            A_temp = A.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a,
-                           group=gpc.get_group(row_parallel_mode))
-            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(col_parallel_mode))
-            if i == row_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opr = [None] * 2
+
+        A_list[0].copy_(A)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == row_rank:
+                    C.copy_(C_list[cur])
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+
+            torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+            cur = 1 - cur
+            src_a += 1
+            src_c += summa_dim
+
+        for op in opr:
+            op.wait()
+
+        if summa_dim - 2 == row_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == row_rank:
+            C.copy_(C_list[1 - cur])
         out = C.reshape(out_shape)
 
         if ctx:
@@ -374,7 +469,7 @@ class Add_Bias_2D(torch.autograd.Function):
                                     dtype=bias.dtype,
                                     device=get_current_device())
         src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                   pipeline_parallel_rank * tensor_parallel_size
+            pipeline_parallel_rank * tensor_parallel_size
         dist.broadcast(bias_temp, src=src_rank,
                        group=gpc.get_group(col_parallel_mode))
 
@@ -408,7 +503,7 @@ class Add_Bias_2D(torch.autograd.Function):
 
         if ctx.bias:
             dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                       pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(output_grad, dst=dst_rank,
                         group=gpc.get_group(col_parallel_mode))
             if row_rank == 0:
@@ -421,7 +516,7 @@ class Add_Bias_2D(torch.autograd.Function):
             reduce_dim = tuple(range(output_grad.ndim - 1))
             reduce = torch.sum(output_grad, dim=reduce_dim)
             dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                       pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(reduce, dst=dst_rank,
                         group=gpc.get_group(col_parallel_mode))
             if row_rank == 0: