mirror of https://github.com/hpcaitech/ColossalAI
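"""Compare Elixir's MTensor memory tracing against the CUDA allocator.

Each test runs an op (mm / addmm / bmm) twice: first on plain CUDA tensors,
recording the temporary memory reported by the allocator, then on
MTensor-wrapped inputs, recording the peak tracked by the tracer. The two
measurements must match, the corresponding op cache must be populated, and a
repeated invocation must report the same temporary memory.

The tests only run when the ``ELX`` environment flag is set (see
``run_on_environment_flag``).
"""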
import pytest
import torch

from colossalai.elixir.tracer.memory_tracer import MTensor
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from colossalai.testing import run_on_environment_flag


def op_mm(x, y):
    u = torch.matmul(x, y)
    return u.shape


def op_addmm(x, y, z):
    u = torch.addmm(x, y, z)
    return u.shape


def op_bmm(x, y):
    u = torch.bmm(x, y)
    return u.shape
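

# Measurement pattern shared by the tests below, as a minimal sketch (all
# names come from the imports above; `op` stands for any of the helpers):
#
#     a = MTensor(torch.randn(..., device='cuda'))
#     b = MTensor(torch.randn(..., device='cuda'))
#     pre = get_cuda_allocated()                    # footprint of the inputs
#     MTensor.reset_peak_memory()
#     out_shape = op(a, b)
#     temp = MTensor.current_peak_memory() - pre    # op's temporary memory
#
# The same quantity is first measured with plain tensors through torch.cuda's
# allocator statistics, and the two results must match.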


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    # Baseline: run mm on plain CUDA tensors and record its temporary
    # memory, i.e. peak usage beyond the inputs' footprint.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_mm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y

    assert get_cuda_allocated() == 0
    # Repeat with MTensor-wrapped inputs so the tracer observes the op.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_mm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    # The tracer must agree with the CUDA allocator, and the op cache
    # must now hold an entry for this op.
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(mm_cache.temp_memory) > 0

    # Run the op again; the reported temporary memory must not change.
    MTensor.reset_peak_memory()
    op2_z_size = op_mm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    # Baseline: measure addmm's temporary memory with plain CUDA tensors.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    u = torch.randn(size1[-1], dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_addmm(u, x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y
    del u

    assert get_cuda_allocated() == 0
    # Repeat with MTensor-wrapped inputs so the tracer observes the op.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_addmm(u, x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    # The tracer must agree with the CUDA allocator, and the op cache
    # must now hold an entry for this op.
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(addmm_cache.temp_memory) > 0

    # Run the op again; the reported temporary memory must not change.
    MTensor.reset_peak_memory()
    op2_z_size = op_addmm(u, x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0

    # Baseline: measure bmm's temporary memory with plain CUDA tensors.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()

    torch_z_size = op_bmm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc

    del x
    del y

    assert get_cuda_allocated() == 0
    # Repeat with MTensor-wrapped inputs so the tracer observes the op.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()

    MTensor.reset_peak_memory()
    op1_z_size = op_bmm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    # The tracer must agree with the CUDA allocator, and the op cache
    # must now hold an entry for this op.
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    assert len(bmm_cache.temp_memory) > 0

    bmm_cache.print()

    # Run the op again; the reported temporary memory must not change.
    MTensor.reset_peak_memory()
    op2_z_size = op_bmm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc

    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc


if __name__ == '__main__':
    test_addmm(dtype=torch.float)