|
|
|
@ -111,18 +111,17 @@ def colo_embedding(input_tensor: GeneralTensor,
|
|
|
|
|
assert isinstance(weight, ColoTensor)
|
|
|
|
|
input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
|
|
|
|
|
|
|
|
|
|
if not weight.has_compute_spec(): # No Model Parallel Applied
|
|
|
|
|
if not weight.has_compute_spec(): # No Model Parallel Applied
|
|
|
|
|
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
|
|
|
|
return ColoTensor.from_torch_tensor(
|
|
|
|
|
tensor=F.embedding(input_tensor,
|
|
|
|
|
weight,
|
|
|
|
|
padding_idx=padding_idx,
|
|
|
|
|
max_norm=max_norm,
|
|
|
|
|
norm_type=norm_type,
|
|
|
|
|
scale_grad_by_freq=scale_grad_by_freq,
|
|
|
|
|
sparse=sparse),
|
|
|
|
|
spec=ColoTensorSpec(weight.get_process_group()))
|
|
|
|
|
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
|
|
|
|
return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
|
|
|
|
|
weight,
|
|
|
|
|
padding_idx=padding_idx,
|
|
|
|
|
max_norm=max_norm,
|
|
|
|
|
norm_type=norm_type,
|
|
|
|
|
scale_grad_by_freq=scale_grad_by_freq,
|
|
|
|
|
sparse=sparse),
|
|
|
|
|
spec=ColoTensorSpec(weight.get_process_group()))
|
|
|
|
|
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
|
|
|
|
if weight.is_shard_1drow():
|
|
|
|
|
mode = 'row'
|
|
|
|
|
elif weight.is_shard_1dcol():
|
|
|
|
|