mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552)
parent
73e9eb13b7
commit
bd2d789832
|
@ -90,22 +90,21 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
|
||||||
|
|
||||||
# Handle differen parallel actions.
|
# Handle differen parallel actions.
|
||||||
|
|
||||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||||
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
||||||
return ColoTensor.from_torch_tensor(
|
return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
|
||||||
tensor=F.embedding_bag(input_tensor,
|
weight,
|
||||||
weight,
|
offsets=offsets,
|
||||||
offsets=offsets,
|
max_norm=max_norm,
|
||||||
max_norm=max_norm,
|
norm_type=norm_type,
|
||||||
norm_type=norm_type,
|
scale_grad_by_freq=scale_grad_by_freq,
|
||||||
scale_grad_by_freq=scale_grad_by_freq,
|
mode=mode,
|
||||||
mode=mode,
|
sparse=sparse,
|
||||||
sparse=sparse,
|
per_sample_weights=per_sample_weights,
|
||||||
per_sample_weights=per_sample_weights,
|
include_last_offset=include_last_offset,
|
||||||
include_last_offset=include_last_offset,
|
padding_idx=padding_idx),
|
||||||
padding_idx=padding_idx),
|
spec=ColoTensorSpec(weight.get_process_group()))
|
||||||
spec=ColoTensorSpec(weight.get_process_group()))
|
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
||||||
elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied
|
|
||||||
if weight.is_shard_1dcol():
|
if weight.is_shard_1dcol():
|
||||||
tp_mode = 'col'
|
tp_mode = 'col'
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue