mirror of https://github.com/hpcaitech/ColossalAI
embedding op use gather_out (#1143)
parent
e61dc31b05
commit
ccf3c58c89
|
@ -30,6 +30,7 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor,
|
|||
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]),
|
||||
ParallelAction(ComputePattern.TP1D))
|
||||
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec)
|
||||
if weight.spec.parallel_action.gather_out:
|
||||
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group()))
|
||||
return output
|
||||
|
||||
|
|
Loading…
Reference in New Issue