From ccf3c58c89fa1df4f7c5cf68c464045bc9905a3f Mon Sep 17 00:00:00 2001 From: ver217 Date: Tue, 21 Jun 2022 13:21:20 +0800 Subject: [PATCH] embedding op use gather_out (#1143) --- colossalai/nn/_ops/embedding.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py index 25b33e95a..18b59eb34 100644 --- a/colossalai/nn/_ops/embedding.py +++ b/colossalai/nn/_ops/embedding.py @@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), ParallelAction(ComputePattern.TP1D)) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) + if weight.spec.parallel_action.gather_out: + output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) return output