# ColossalAI/tests/test_layers/test_3d/checks_3d/common.py
# (source listing: 19 lines, 307 B, Python)

#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
# Shared sizes for the 3D layer checks in this directory.
# NOTE(review): DEPTH presumably sets the side length of the 3D parallel
# mesh used by these tests; confirm against the test launcher.
DEPTH = 2
NUM_BLOCKS = 2

# Tensor dimensions shared by the layer checks (all set to 8).
BATCH_SIZE = SEQ_LENGTH = HIDDEN_SIZE = NUM_CLASSES = 8

# Input-space sizes for the image / vocabulary layers (both 16).
IMG_SIZE = VOCAB_SIZE = 16
def check_equal(A, B, *, rtol=1e-3, atol=1e-2):
    """Assert that two tensors are element-wise close and return the result.

    Closeness is decided by ``torch.allclose(A, B, rtol=rtol, atol=atol)``.

    Args:
        A: First tensor to compare.
        B: Second tensor to compare.
        rtol: Relative tolerance passed to ``torch.allclose``. The default is
            the value this helper has always used.
        atol: Absolute tolerance passed to ``torch.allclose``. The default is
            the value this helper has always used.

    Returns:
        The boolean result of ``torch.allclose``. It is always ``True``,
        because a failed comparison raises instead of returning.

    Raises:
        AssertionError: If the tensors are not close. The message contains
            both tensors so the mismatch can be inspected.
    """
    eq = torch.allclose(A, B, rtol=rtol, atol=atol)
    assert eq, f"\nA = {A}\nB = {B}"
    return eq