From bd2d7898325f4245fd831280f9f1ca5b539348de Mon Sep 17 00:00:00 2001 From: Maruyama_Aya <38985202+MaruyamaAya@users.noreply.github.com> Date: Thu, 8 Sep 2022 14:42:02 +0800 Subject: [PATCH] [NFC] polish colossalai/nn/_ops/embedding_bag.py code style (#1552) --- colossalai/nn/_ops/embedding_bag.py | 29 ++++++++++++++--------------- 1 file changed, 14 insertions(+), 15 deletions(-) diff --git a/colossalai/nn/_ops/embedding_bag.py b/colossalai/nn/_ops/embedding_bag.py index cdab44856..0e8aa8fec 100644 --- a/colossalai/nn/_ops/embedding_bag.py +++ b/colossalai/nn/_ops/embedding_bag.py @@ -90,22 +90,21 @@ def colo_embedding_bag(input_tensor: GeneralTensor, # Handle differen parallel actions. - if not weight.has_compute_spec(): # No Model Parallel Applied + if not weight.has_compute_spec(): # No Model Parallel Applied assert weight.is_replicate(), 'Invalid weight spec for native embedding op' - return ColoTensor.from_torch_tensor( - tensor=F.embedding_bag(input_tensor, - weight, - offsets=offsets, - max_norm=max_norm, - norm_type=norm_type, - scale_grad_by_freq=scale_grad_by_freq, - mode=mode, - sparse=sparse, - per_sample_weights=per_sample_weights, - include_last_offset=include_last_offset, - padding_idx=padding_idx), - spec=ColoTensorSpec(weight.get_process_group())) - elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied + return ColoTensor.from_torch_tensor(tensor=F.embedding_bag(input_tensor, + weight, + offsets=offsets, + max_norm=max_norm, + norm_type=norm_type, + scale_grad_by_freq=scale_grad_by_freq, + mode=mode, + sparse=sparse, + per_sample_weights=per_sample_weights, + include_last_offset=include_last_offset, + padding_idx=padding_idx), + spec=ColoTensorSpec(weight.get_process_group())) + elif weight.has_compute_pattern(ComputePattern.TP1D): # Single Model Parallel Applied if weight.is_shard_1dcol(): tp_mode = 'col' else: