mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish tests/test_layers/test_3d/checks_3d/check_layer_3d.py code style (#1731)
parent
ff373a11eb
commit
754aa7c81f
|
@ -784,7 +784,7 @@ def check_loss():
|
|||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
|
@ -837,7 +837,7 @@ def check_vocab_parallel_loss():
|
|||
|
||||
out_shape = (BATCH_SIZE, NUM_CLASSES)
|
||||
out_master = torch.randn(out_shape, dtype=dtype, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE, ), dtype=torch.long, device=device)
|
||||
target_master = torch.randint(NUM_CLASSES, (BATCH_SIZE,), dtype=torch.long, device=device)
|
||||
torch.distributed.broadcast(out_master, src=0)
|
||||
torch.distributed.broadcast(target_master, src=0)
|
||||
out = torch.chunk(out_master, DEPTH, dim=0)[i]
|
||||
|
|
Loading…
Reference in New Issue