diff --git a/colossalai/nn/_ops/embedding.py b/colossalai/nn/_ops/embedding.py index 25b33e95a..18b59eb34 100644 --- a/colossalai/nn/_ops/embedding.py +++ b/colossalai/nn/_ops/embedding.py @@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), ParallelAction(ComputePattern.TP1D)) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) - output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) + if weight.spec.parallel_action.gather_out: + output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) return output