[NFC] polish colossalai/nn/_ops/embedding.py code style (#1561)

Branch: pull/1550/head
Author: BigOneLiXiaoMing (committed by Frank Lee)
Parent: 08815f0e72
Commit: 0c4c9aa6e0

@@ -111,18 +111,17 @@ def colo_embedding(input_tensor: GeneralTensor,
     assert isinstance(weight, ColoTensor)
     input_tensor = convert_to_colo_tensor(input_tensor, weight.get_process_group())
     if not weight.has_compute_spec():    # No Model Parallel Applied
         assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
-        return ColoTensor.from_torch_tensor(
-            tensor=F.embedding(input_tensor,
-                               weight,
-                               padding_idx=padding_idx,
-                               max_norm=max_norm,
-                               norm_type=norm_type,
-                               scale_grad_by_freq=scale_grad_by_freq,
-                               sparse=sparse),
-            spec=ColoTensorSpec(weight.get_process_group()))
+        return ColoTensor.from_torch_tensor(tensor=F.embedding(input_tensor,
+                                                               weight,
+                                                               padding_idx=padding_idx,
+                                                               max_norm=max_norm,
+                                                               norm_type=norm_type,
+                                                               scale_grad_by_freq=scale_grad_by_freq,
+                                                               sparse=sparse),
+                                            spec=ColoTensorSpec(weight.get_process_group()))
     elif weight.has_compute_pattern(ComputePattern.TP1D):    # Single Model Parallel Applied
         if weight.is_shard_1drow():
             mode = 'row'
         elif weight.is_shard_1dcol():
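
Note: the change above only re-wraps the call layout; the computation is unchanged. As a minimal sketch, assuming illustrative shapes and token indices, the no-model-parallel branch evaluates a plain torch.nn.functional.embedding call with the same keyword arguments (the ColoTensor/ColoTensorSpec wrapping of the result is omitted here):

    import torch
    import torch.nn.functional as F

    vocab_size, embed_dim = 10, 4                       # assumed sizes for illustration
    weight = torch.randn(vocab_size, embed_dim)         # replicated embedding table
    input_ids = torch.tensor([[1, 2, 4], [4, 3, 9]])    # example token indices

    # Same keyword arguments as in the diff, with PyTorch defaults written out.
    out = F.embedding(input_ids,
                      weight,
                      padding_idx=None,
                      max_norm=None,
                      norm_type=2.0,
                      scale_grad_by_freq=False,
                      sparse=False)
    print(out.shape)    # torch.Size([2, 3, 4])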
