Browse Source

embedding op use gather_out (#1143)

pull/1144/head
ver217 2 years ago committed by GitHub
parent
commit
ccf3c58c89
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 3
      colossalai/nn/_ops/embedding.py

3
colossalai/nn/_ops/embedding.py

@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
ParallelAction(ComputePattern.TP1D)) ParallelAction(ComputePattern.TP1D))
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) if weight.spec.parallel_action.gather_out:
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
return output return output

Loading…
Cancel
Save