diff --git a/colossalai/nn/layer/parallel_1d/layers.py b/colossalai/nn/layer/parallel_1d/layers.py index 1976da95a..b64488a12 100644 --- a/colossalai/nn/layer/parallel_1d/layers.py +++ b/colossalai/nn/layer/parallel_1d/layers.py @@ -706,13 +706,22 @@ class Linear1D_Row(ParallelLayer): input_ = split_forward_gather_backward(input_, ParallelMode.PARALLEL_1D, dim=-1) if self.stream_chunk_num > 1: - output_parallel_list = [None for i in range(self.stream_chunk_num)] - for i in range(self.stream_chunk_num): - output_parallel_list[i] = F.linear(input_, self.weight_list[i]) - output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) - output = torch.cat(output_parallel_list, dim=-1) + if self.training: + raise RuntimeError("use stream_chunk_num=1 in Linear1D_Row for training!") + with torch.no_grad(): + output_parallel_list = [None for i in range(self.stream_chunk_num)] + handle_list = [] + for i in range(self.stream_chunk_num): + output_parallel_list[i] = F.linear(input_, self.weight_list[i]) + handle = torch.distributed.all_reduce(output_parallel_list[i], + group=gpc.get_group(ParallelMode.PARALLEL_1D), + async_op=True) + handle_list.append(handle) + # output_parallel_list[i] = reduce_input(output_parallel_list[i], ParallelMode.PARALLEL_1D) + for handle in handle_list: + handle.wait() + output = torch.cat(output_parallel_list, dim=-1) else: - print(input_.shape, self.weight.shape) output_parallel = F.linear(input_, self.weight) # output_parallel = linear_with_async_comm(input_, self.weight, None, ParallelMode.PARALLEL_1D, False) output = reduce_input(output_parallel, ParallelMode.PARALLEL_1D) diff --git a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py index 7d77391ea..668b8a334 100644 --- a/tests/test_layers/test_1d/checks_1d/check_layer_1d.py +++ b/tests/test_layers/test_1d/checks_1d/check_layer_1d.py @@ -514,8 +514,9 @@ def check_linear_row_stream_inference(): i = gpc.get_local_rank(ParallelMode.PARALLEL_1D) - assert HIDDEN_SIZE % 2 == 0 - layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=2) + stream_chunk_num = 4 + assert HIDDEN_SIZE % stream_chunk_num == 0 + layer = Linear1D_Row(OUTPUT_SIZE, INPUT_SIZE, stream_chunk_num=stream_chunk_num) A_shape = (BATCH_SIZE, SEQ_LENGTH, OUTPUT_SIZE) A_master = torch.randn(A_shape, dtype=dtype, device=device) @@ -537,6 +538,8 @@ def check_linear_row_stream_inference(): layer.weight = Parameter(W) layer.bias = Parameter(B) layer.chunk_weight() + layer.eval() + out = layer(A) A_master = A_master.clone()