mirror of https://github.com/hpcaitech/ColossalAI
[inference] overlap comm and compute in Linear1D_Row when stream_chunk_num > 1 (#1876)
parent 1b494ad73c
commit 986f8cbaa7
@@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer):
         input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1)
 
         if self.stream_chunk_num > 1:
-            output_parallel_list = [None for i in range(self.stream_chunk_num)]
-            for i in range(self.stream_chunk_num):
-                output_parallel_list[i] = F.linear(input_, self.weight_list[i])
-                output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
-            output = torch.cat(output_parallel_list, dim=-1)
+            if self.training:
+                raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!")
+            with torch.no_grad():
+                output_parallel_list = [None for i in range(self.stream_chunk_num)]
+                handle_list = []
+                for i in range(self.stream_chunk_num):
+                    output_parallel_list[i] = F.linear(input_, self.weight_list[i])
+                    handle = torch.distributed.all_reduce(output_parallel_list[i],
+                                                          group=gpc.get_group(ParallelMode.PARALLEL_1D),
+                                                          async_op=True)
+                    handle_list.append(handle)
+                    # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D)
+                for handle in handle_list:
+                    handle.wait()
+                output = torch.cat(output_parallel_list, dim=-1)
         else:
             print(input_.shape, self.weight.shape)
             output_parallel = F.linear(input_, self.weight)
             # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False)
             output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D)
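For readers skimming the commit: the change replaces the per-chunk blocking reduce_input with a non-blocking all_reduce, so the matmul for the next chunk can run while the reduction for the previous one is still in flight. Below is a minimal standalone sketch of that pattern; the function name, the weight_list layout, and the process-group argument are illustrative assumptions, not the ColossalAI API.

import torch
import torch.distributed as dist
import torch.nn.functional as F


@torch.no_grad()
def chunked_row_linear(input_, weight_list, group=None):
    # Sketch only: assumes torch.distributed is already initialized and that
    # weight_list holds stream_chunk_num slices of the weight along the output
    # dimension, each of shape (out_chunk, in_per_partition).
    partial_outputs = []
    handles = []
    for w in weight_list:
        out = F.linear(input_, w)                                   # partial result for this chunk
        handle = dist.all_reduce(out, group=group, async_op=True)   # start reduction, do not block
        partial_outputs.append(out)
        handles.append(handle)
        # the next iteration's F.linear overlaps with the reduction issued above
    for handle in handles:
        handle.wait()                                               # make sure every reduction finished
    return torch.cat(partial_outputs, dim=-1)                       # stitch the full output width back together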
@@ -514,8 +514,9 @@ def check_linear_row_stream_inference():
 
     i = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
 
-    assert HIDDEN_SIZE % 2 == 0
-    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2)
+    stream_chunk_num = 4
+    assert HIDDEN_SIZE % stream_chunk_num == 0
+    layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num)
 
     A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE)
     A_master = torch.randn(A_shape, dtype=dtype, device=device)
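The test now exercises four chunks instead of two and keeps the divisibility guard, since the chunked path assumes the output width splits evenly across chunks. A tiny illustration of that assumption (the constants here are made up for the example, not the test's values):

import torch

HIDDEN_SIZE = 8                      # assumed value, for illustration only
stream_chunk_num = 4
assert HIDDEN_SIZE % stream_chunk_num == 0

weight = torch.randn(HIDDEN_SIZE, 16)
chunks = torch.chunk(weight, stream_chunk_num, dim=0)
# with an even split, concatenating the per-chunk outputs recovers the full width exactly
assert all(c.shape[0] == HIDDEN_SIZE // stream_chunk_num for c in chunks)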
@@ -537,6 +538,8 @@ def check_linear_row_stream_inference():
     layer.weight = Parameter(W)
     layer.bias = Parameter(B)
+    layer.chunk_weight()
+    layer.eval()
 
     out = layer(A)
 
     A_master = A_master.clone()
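The new test lines also show the expected call order at inference time: split the weight into stream chunks once, put the layer in eval mode (the chunked path raises in training), then run the forward pass. chunk_weight() itself is not shown in this diff; a plausible sketch of what it would do, under the assumption that the weight is sliced along the output dimension, is:

import torch


def chunk_weight_sketch(weight: torch.Tensor, stream_chunk_num: int):
    # Hypothetical stand-in for layer.chunk_weight(): slice the row-parallel
    # weight (out_features, in_per_partition) into stream_chunk_num pieces
    # along the output dimension, one piece per matmul/all_reduce pair.
    return list(torch.chunk(weight, stream_chunk_num, dim=0))


# usage mirroring the test:
#     layer.chunk_weight()
#     layer.eval()
#     out = layer(A)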