diff --git a/colossalai/nn/layer/parallel_2d/_operation.py b/colossalai/nn/layer/parallel_2d/_operation.py
index c3722b43e..6e839c0e8 100644
--- a/colossalai/nn/layer/parallel_2d/_operation.py
+++ b/colossalai/nn/layer/parallel_2d/_operation.py
@@ -85,30 +85,57 @@ class Matmul_AB_2D(torch.autograd.Function):
         ctx.save_for_backward(A, B)
 
         A_shape = A.shape
-        A = A.reshape((-1, A_shape[-1])).contiguous()
+        A = A.reshape((-1, A_shape[-1]))
         B_shape = B.shape
-        B = B.reshape((-1, B_shape[-1])).contiguous()
+        B = B.reshape((-1, B_shape[-1]))
         C_shape = (A.shape[0], B.shape[-1])
         C = torch.zeros(C_shape, dtype=A.dtype, device=get_current_device())
 
-        A_list = [torch.empty_like(A) for _ in range(gpc.get_world_size(row_parallel_mode)-1)]
-        B_list = [torch.empty_like(B) for _ in range(gpc.get_world_size(col_parallel_mode)-1)]
-        A_list.insert(gpc.get_local_rank(row_parallel_mode), A)
-        B_list.insert(gpc.get_local_rank(col_parallel_mode), B)
-        op_a = dist.all_gather(A_list, A, group=gpc.get_group(row_parallel_mode), async_op=True)
-        op_a.wait()
-        op_b = dist.all_gather(B_list, B, group=gpc.get_group(col_parallel_mode), async_op=True)
-        for op in [op_a, op_b]:
-            op.wait()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        B_list = [torch.empty_like(B) for _ in range(2)]
+
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opb = [None] * 2
+
+        A_list[0].copy_(A)
+        B_list[0].copy_(B)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
 
         for i in range(summa_dim):
-            src_a = i + summa_dim * row_rank
-            src_b = i + summa_dim * col_rank
-            src_a = src_a % summa_dim
-            src_b = src_b % summa_dim
-            A_temp = A_list[src_a]
-            B_temp = B_list[src_b]
-            torch.addmm(C, A_temp, B_temp, out=C)
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.addmm(C, A_list[cur], B_list[cur], out=C)
+            cur = 1 - cur
+            src_a += 1
+            src_b += summa_dim
+
         out = C.reshape(out_shape)
 
         if ctx:
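
Note: the rewritten `Matmul_AB_2D.forward` replaces the blocking all-gather with a pipelined broadcast, using a two-slot circular buffer so the broadcast for step `i + 1` overlaps the `addmm` for step `i`. A minimal, framework-free sketch of that double-buffering pattern (the thread pool stands in for `dist.broadcast(..., async_op=True)`; all names are illustrative, not ColossalAI API):

    from concurrent.futures import ThreadPoolExecutor
    import time

    import torch

    summa_dim = 4
    pool = ThreadPoolExecutor(max_workers=2)

    def fake_broadcast(buf):
        """Stand-in for an async broadcast; returns a waitable handle."""
        def run():
            time.sleep(0.01)   # pretend this is network latency
            buf.add_(0)        # the real call fills buf in place
        return pool.submit(run)

    A, B = torch.randn(64, 64), torch.randn(64, 64)
    C = torch.zeros(64, 64)

    # circular buffer of size 2: one slot being consumed, one being filled
    A_buf = [torch.empty_like(A) for _ in range(2)]
    op = [None, None]

    A_buf[0].copy_(A)
    op[0] = fake_broadcast(A_buf[0])
    cur = 0
    for i in range(summa_dim):
        if i != summa_dim - 1:        # prefetch the next step's operand
            A_buf[1 - cur].copy_(A)
            op[1 - cur] = fake_broadcast(A_buf[1 - cur])
        op[cur].result()              # wait only for the current slot
        C.addmm_(A_buf[cur], B)       # compute overlaps the prefetch
        cur = 1 - cur

Two slots suffice because at most one transfer per operand is in flight while the other buffer is being consumed.
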
@@ -188,21 +215,55 @@ class Matmul_ABT_2D(torch.autograd.Function):
         C_shape = (A.shape[0], B.shape[0])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
 
-        for i in range(summa_dim):
-            B_temp = B.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_b = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(B_temp, src=src_b,
-                           group=gpc.get_group(col_parallel_mode))
-            C_temp = torch.matmul(A, B_temp.transpose(0, 1))
-            src_c = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(row_parallel_mode))
-            if i == col_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        B_list = [torch.empty_like(B) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_b = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opb = [None] * 2
+        opr = [None] * 2
+
+        B_list[0].copy_(B)
+        opb[0] = dist.broadcast(B_list[0], src=src_b, group=col_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                B_list[1 - cur].copy_(B)
+                opb[1 - cur] = dist.broadcast(B_list[1 - cur],
+                                              src=src_b + summa_dim,
+                                              group=col_group,
+                                              async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == col_rank:
+                    C.copy_(C_list[cur])
+
+            if opb[cur] is not None:
+                opb[cur].wait()
+
+            torch.matmul(A, B_list[cur].transpose(0, 1), out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=row_group, async_op=True)
+            cur = 1 - cur
+            src_b += summa_dim
+            src_c += 1
+
+        for op in opr:
+            op.wait()
+
+        if summa_dim - 2 == col_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == col_rank:
+            C.copy_(C_list[1 - cur])
 
         out = C.reshape(out_shape)
 
         if ctx:
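
Note: in `Matmul_ABT_2D` the partial result of step `i` is reduced asynchronously, and its handle is only waited two iterations later, when that `C_list` slot is about to be reused. The reduce destination advances by one rank each step, so the rank whose slice was reduced at step `i` copies it out when `i - 2 == col_rank`; the last two steps are drained after the loop (`summa_dim - 2` from the current slot, `summa_dim - 1` from the other). A standalone sanity check (not ColossalAI code) that this schedule covers every `col_rank` exactly once:

    def copy_schedule(summa_dim):
        """Map each col_rank to the point where its C slice is copied out."""
        when = {}
        for i in range(summa_dim):
            if i - 2 >= 0:                    # in-loop copy: i - 2 == col_rank
                when[i - 2] = f"iteration {i}"
        when[summa_dim - 2] = "epilogue, current slot"
        when[summa_dim - 1] = "epilogue, other slot"
        return when

    for d in (2, 4, 8):
        sched = copy_schedule(d)
        assert sorted(sched) == list(range(d)), sched
        print(d, sched)
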
@@ -284,21 +345,55 @@ class Matmul_ATB_2D(torch.autograd.Function):
         C_shape = (A.shape[-1], B.shape[-1])
         C = torch.empty(C_shape, dtype=A.dtype, device=get_current_device())
 
-        for i in range(summa_dim):
-            A_temp = A.clone()
-            # C_temp = torch.zeros(C_shape, dtype=C.dtype, device=get_current_device())
-            src_a = i + summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.broadcast(A_temp, src=src_a,
-                           group=gpc.get_group(row_parallel_mode))
-            C_temp = torch.matmul(A_temp.transpose(0, 1), B)
-            src_c = col_rank + summa_dim * i + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                pipeline_parallel_rank * tensor_parallel_size
-            dist.reduce(C_temp, dst=src_c,
-                        group=gpc.get_group(col_parallel_mode))
-            if i == row_rank:
-                C = C_temp.clone()
+        # use circular buffer to store the communication tensor
+        # 2 is enough for all cases
+        A_list = [torch.empty_like(A) for _ in range(2)]
+        C_list = [torch.empty_like(C) for _ in range(2)]
+        row_group = gpc.get_group(row_parallel_mode)
+        col_group = gpc.get_group(col_parallel_mode)
+
+        src_a = summa_dim * row_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+        src_c = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
+            pipeline_parallel_rank * tensor_parallel_size
+
+        opa = [None] * 2
+        opr = [None] * 2
+
+        A_list[0].copy_(A)
+        opa[0] = dist.broadcast(A_list[0], src=src_a, group=row_group, async_op=True)
+        cur = 0
+
+        for i in range(summa_dim):
+            if i != summa_dim - 1:
+                A_list[1 - cur].copy_(A)
+                opa[1 - cur] = dist.broadcast(A_list[1 - cur],
+                                              src=src_a + 1,
+                                              group=row_group,
+                                              async_op=True)
+
+            if opr[cur] is not None:
+                opr[cur].wait()
+                if i - 2 == row_rank:
+                    C.copy_(C_list[cur])
+
+            if opa[cur] is not None:
+                opa[cur].wait()
+
+            torch.matmul(A_list[cur].transpose(0, 1), B, out=C_list[cur])
+            opr[cur] = dist.reduce(C_list[cur], dst=src_c, group=col_group, async_op=True)
+            cur = 1 - cur
+            src_a += 1
+            src_c += summa_dim
+
+        for op in opr:
+            op.wait()
+
+        if summa_dim - 2 == row_rank:
+            C.copy_(C_list[cur])
+        if summa_dim - 1 == row_rank:
+            C.copy_(C_list[1 - cur])
 
         out = C.reshape(out_shape)
 
         if ctx:
@@ -374,7 +469,7 @@ class Add_Bias_2D(torch.autograd.Function):
                                     dtype=bias.dtype,
                                     device=get_current_device())
             src_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                       pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.broadcast(bias_temp, src=src_rank,
                            group=gpc.get_group(col_parallel_mode))
 
@@ -408,7 +503,7 @@ class Add_Bias_2D(torch.autograd.Function):
 
             if ctx.bias:
                 dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                           pipeline_parallel_rank * tensor_parallel_size
+                    pipeline_parallel_rank * tensor_parallel_size
                 dist.reduce(output_grad, dst=dst_rank,
                             group=gpc.get_group(col_parallel_mode))
                 if row_rank == 0:
@@ -421,7 +516,7 @@ class Add_Bias_2D(torch.autograd.Function):
             reduce_dim = tuple(range(output_grad.ndim - 1))
             reduce = torch.sum(output_grad, dim=reduce_dim)
             dst_rank = col_rank + data_parallel_rank * pipeline_parallel_size * tensor_parallel_size + \
-                       pipeline_parallel_rank * tensor_parallel_size
+                pipeline_parallel_rank * tensor_parallel_size
             dist.reduce(reduce, dst=dst_rank,
                         group=gpc.get_group(col_parallel_mode))
             if row_rank == 0:
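
Note: the three `Add_Bias_2D` hunks are whitespace-only; they re-indent the continuation line of the same global-rank computation used throughout the file. That computation offsets a tensor-parallel-local index by the data-parallel and pipeline ranks; within the 2D grid the local index is `summa_dim * row_rank + col_rank`, so `src_a` walks a row group with stride 1 and `src_b` walks a column group with stride `summa_dim`. A standalone demo of that arithmetic for a 2x2 grid (data-parallel and pipeline offsets set to zero for brevity; all names here are illustrative):

    summa_dim = 2

    def tp_local(row, col):
        return summa_dim * row + col      # row-major position in the grid

    for row_rank in range(summa_dim):
        for col_rank in range(summa_dim):
            src_a = summa_dim * row_rank  # first source in this row
            src_b = col_rank              # first source in this column
            row_srcs, col_srcs = [], []
            for _ in range(summa_dim):
                row_srcs.append(src_a)
                col_srcs.append(src_b)
                src_a += 1                # next rank in the row group
                src_b += summa_dim        # next rank in the column group
            assert row_srcs == [tp_local(row_rank, c) for c in range(summa_dim)]
            assert col_srcs == [tp_local(r, col_rank) for r in range(summa_dim)]
            print((row_rank, col_rank), row_srcs, col_srcs)
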