[NFC] polish tests/test_layers/test_3d/checks_3d/check_layer_3d.py code style (#1731)

pull/1743/head
Xue Fuzhao 2022-10-18 20:05:22 +08:00 committed by Frank Lee
parent ff373a11eb
commit 754aa7c81f
1 changed files with 2 additions and 2 deletions

View File

@ -784,7 +784,7 @@ def check_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]
@ -837,7 +837,7 @@ def check_vocab_parallel_loss():
out_shape = (BATCH_SIZE, NUM_CLASSES)
out_master = torch.randn(out_shape, dtype=dtype, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
torch.distributed.broadcast(out_master, src=0)
torch.distributed.broadcast(target_master, src=0)
out = torch.chunk(out_master, DEPTH, dim=0)[i]