mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish tests/test_layers/test_sequence/checks_seq/check_layer_seq.py code style (#1723)
parent
7e62af28a0
commit
ff373a11eb
|
@ -12,15 +12,10 @@ def check_selfattention():
|
||||||
BATCH = 4
|
BATCH = 4
|
||||||
HIDDEN_SIZE = 16
|
HIDDEN_SIZE = 16
|
||||||
|
|
||||||
layer = TransformerSelfAttentionRing(
|
layer = TransformerSelfAttentionRing(16, 8, 8, 0.1)
|
||||||
16,
|
|
||||||
8,
|
|
||||||
8,
|
|
||||||
0.1
|
|
||||||
)
|
|
||||||
layer = layer.to(get_current_device())
|
layer = layer.to(get_current_device())
|
||||||
|
|
||||||
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
|
hidden_states = torch.rand(SUB_SEQ_LENGTH, BATCH, HIDDEN_SIZE).to(get_current_device())
|
||||||
attention_mask = torch.randint(low=0, high=2, size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(
|
attention_mask = torch.randint(low=0, high=2,
|
||||||
get_current_device())
|
size=(BATCH, 1, 1, 1, SUB_SEQ_LENGTH * WORLD_SIZE)).to(get_current_device())
|
||||||
out = layer(hidden_states, attention_mask)
|
out = layer(hidden_states, attention_mask)
|
||||||
|
|
Loading…
Reference in New Issue