[NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552)

pull/1550/head
Maruyama_Aya 2022-09-08 14:42:02 +08:00 committed by Frank Lee
parent 73e9eb13b7
commit bd2d789832
1 changed files with 14 additions and 15 deletions

View File

@ -92,8 +92,7 @@ def colo_embedding_bag(input_tensor: GeneralTensor,
if not weight.has_compute_spec(): # No Model Parallel Applied
assert weight.is_replicate(), 'Invalid weight spec for native embedding op'
return ColoTensor.from_torch_tensor(
tensor=F.embedding_bag(input_tensor,
return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor,
weight,
offsets=offsets,
max_norm=max_norm,