mirror of https://github.com/hpcaitech/ColossalAI
embedding op use gather_out (#1143)
parent
e61dc31b05
commit
ccf3c58c89
|
@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
||||||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||||
ParallelAction(ComputePattern.TP1D))
|
ParallelAction(ComputePattern.TP1D))
|
||||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||||
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
if weight.spec.parallel_action.gather_out:
|
||||||
|
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||||
return output
|
return output
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue