embedding op use gather_out (#1143)

pull/1144/head
ver217 2022-06-21 13:21:20 +08:00 committed by GitHub
parent e61dc31b05
commit ccf3c58c89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 1 deletions

View File

@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D)) ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) if weight.spec.parallel_action.gather_out:
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output return output