[inference] overlap communication and computation in Linear1D_Row when stream_chunk_num > 1 (#1876)

pull/1880/head
Jiarui Fang 2022-11-10 17:36:42 +08:00 committed by GitHub
parent 1b494ad73c
commit 986f8cbaa7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 20 additions and 8 deletions

View File

@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer):
input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
if self.stream_chunk_num > 1:
output_parallel_list = [None for i in range(self.stream_chunk_num)]
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
output = torch.cat(output_parallel_list, dim=-1)
if self.training:
raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
with torch.no_grad():
output_parallel_list = [None for i in range(self.stream_chunk_num)]
handle_list = []
for i in range(self.stream_chunk_num):
output_parallel_list[i] = F.linear(input_, self.weight_list[i])
handle = torch.distributed.all_reduce(output_parallel_list[i],
group=gpc.get_group(ParallelMode.PARALLEL_1D),
async_op=True)
handle_list.append(handle)
# output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
for handle in handle_list:
handle.wait()
output = torch.cat(output_parallel_list, dim=-1)
else:
print(input_.shape, self.weight.shape)
output_parallel = F.linear(input_, self.weight)
# output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)

View File

@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
assert HIDDEN_SIZE % 2 == 0
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2)
stream_chunk_num = 4
assert HIDDEN_SIZE % stream_chunk_num == 0
layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)
A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
A_master = torch.randn(A_shape, dtype=dtype, device=device)
@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
layer.weight = Parameter(W)
layer.bias = Parameter(B)
layer.chunk_weight()
layer.eval()
out = layer(A)
A_master = A_master.clone()