|
|
@ -30,7 +30,8 @@ def colo_embedding_1Dcol(input_tensor: ColoTensor, |
|
|
|
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), |
|
|
|
distspec.shard(weight.spec.get_process_group(), [-1], [weight.spec.get_process_group_size()]), |
|
|
|
ParallelAction(ComputePattern.TP1D)) |
|
|
|
ParallelAction(ComputePattern.TP1D)) |
|
|
|
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) |
|
|
|
output = ColoTensor.from_torch_tensor(output_parallel, spec=output_spec) |
|
|
|
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) |
|
|
|
if weight.spec.parallel_action.gather_out: |
|
|
|
|
|
|
|
output = output.convert_to_dist_spec(distspec.replicate(weight.spec.get_process_group())) |
|
|
|
return output |
|
|
|
return output |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|