From ff373a11ebd0f1038611640500ce7b28b56c8ac4 Mon Sep 17 00:00:00 2001
From: xyupeng <99191637+xyupeng@users.noreply.github.com>
Date: Tue, 18 Oct 2022 14:16:40 +0800
Subject: [PATCH] [NFC] polish tests/test_layers/test_sequence/checks_seq/check_layer_seq.py code style (#1723)

---
 .../test_sequence/checks_seq/check_layer_seq.py | 11 +++--------
 1 file changed, 3 insertions(+), 8 deletions(-)

diff --git a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
index 156e60333..2b7b999d4 100644
--- a/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
+++ b/tests/test_layers/test_sequence/checks_seq/check_layer_seq.py
@@ -12,15 +12,10 @@ def check_selfattention():
     BATCH = 4
     HIDDEN_SIZE = 16
 
-    layer = TransformerSelfAttentionRing(
-        16,
-        8,
-        8,
-        0.1
-    )
+    layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
     layer = layer.to(get_current_device())
 
     hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
-    attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
-        get_current_device())
+    attention_mask = torch.randint(low=0, high=2,
+                                   size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
     out = layer(hidden_states, attention_mask)