mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552)
parent
73e9eb13b7
commit
bd2d789832
|
@ -92,8 +92,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
|
|||
|
||||
if not weight.has_compute_spec(): # No Model Parallel Applied
|
||||
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
|
||||
return ColoTensor.from_torch_tensor(
|
||||
tensor=F.embedding_bag(input_tensor,
|
||||
return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
|
||||
weight,
|
||||
offsets=offsets,
|
||||
max_norm=max_norm,
|
||||
|
|
Loading…
Reference in New Issue