# ColossalAI/tests/test_elixir/test_tracer/test_op_cache.py
import pytest
import torch
from colossalai.elixir.tracer.memory_tracer import MTensor
from colossalai.elixir.tracer.memory_tracer.op_cache import addmm_cache, bmm_cache, mm_cache
from colossalai.elixir.tracer.utils import get_cuda_allocated, get_cuda_max_allocated
from colossalai.testing import run_on_environment_flag
def op_mm(x, y):
    """Run ``torch.matmul(x, y)`` and return only the output's shape."""
    return torch.matmul(x, y).shape
def op_addmm(x, y, z):
    """Run ``torch.addmm(x, y, z)`` (x + y @ z) and return only the output's shape."""
    return torch.addmm(x, y, z).shape
def op_bmm(x, y):
    """Run the batched matmul ``torch.bmm(x, y)`` and return only the output's shape."""
    return torch.bmm(x, y).shape
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_mm(dtype, size0=(4, 256), size1=(256, 1024)):
    """Verify MTensor tracing of ``torch.matmul`` against a plain-torch baseline.

    Checks that the traced run reproduces the baseline's output shape,
    pre-op allocation, and temporary peak memory, and that a second run
    (served from the mm op cache) reports the same numbers.

    Requires an empty CUDA allocator (asserts allocated == 0) and only runs
    when the ELX environment flag is set.
    """
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0
    # Baseline: run the op on plain CUDA tensors and record the extra peak
    # memory consumed during the matmul, beyond the input allocations.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()
    torch_z_size = op_mm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
    # Free the baseline inputs so the traced run starts from a clean allocator.
    del x
    del y
    assert get_cuda_allocated() == 0
    # First traced run: MTensor must report the same shape, pre-op
    # allocation, and temporary peak as the plain-torch baseline.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()
    MTensor.reset_peak_memory()
    op1_z_size = op_mm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    # The first traced call should have populated the mm op cache.
    assert len(mm_cache.temp_memory) > 0
    # Second traced run: presumably served from the cache entry above —
    # the reported numbers must not drift from the first run.
    MTensor.reset_peak_memory()
    op2_z_size = op_mm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_addmm(dtype, size0=(4, 16), size1=(16, 64)):
    """Verify MTensor tracing of ``torch.addmm`` against a plain-torch baseline.

    Checks that the traced run reproduces the baseline's output shape,
    pre-op allocation, and temporary peak memory, and that a second run
    (served from the addmm op cache) reports the same numbers.

    Requires an empty CUDA allocator (asserts allocated == 0) and only runs
    when the ELX environment flag is set.
    """
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0
    # Baseline: plain CUDA tensors; u is the bias row (last dim of the
    # matmul output) fed to addmm as its first argument.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    u = torch.randn(size1[-1], dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()
    torch_z_size = op_addmm(u, x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
    # Free the baseline inputs so the traced run starts from a clean allocator.
    del x
    del y
    del u
    assert get_cuda_allocated() == 0
    # First traced run: MTensor must report the same shape, pre-op
    # allocation, and temporary peak as the plain-torch baseline.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    u = MTensor(torch.randn(size1[-1], dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()
    MTensor.reset_peak_memory()
    op1_z_size = op_addmm(u, x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    # The first traced call should have populated the addmm op cache.
    assert len(addmm_cache.temp_memory) > 0
    # Second traced run: presumably served from the cache entry above —
    # the reported numbers must not drift from the first run.
    MTensor.reset_peak_memory()
    op2_z_size = op_addmm(u, x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc
@pytest.mark.parametrize('dtype', [torch.float, torch.float16, torch.bfloat16])
@run_on_environment_flag('ELX')
def test_bmm(dtype, size0=(10, 4, 15), size1=(10, 15, 64)):
    """Verify MTensor tracing of ``torch.bmm`` against a plain-torch baseline.

    Checks that the traced run reproduces the baseline's output shape,
    pre-op allocation, and temporary peak memory, and that a second run
    (served from the bmm op cache) reports the same numbers.

    Requires an empty CUDA allocator (asserts allocated == 0) and only runs
    when the ELX environment flag is set.
    """
    torch.cuda.reset_peak_memory_stats()
    assert get_cuda_allocated() == 0
    # Baseline: run batched matmul on plain CUDA tensors and record the
    # extra peak memory consumed during the op, beyond the inputs.
    x = torch.randn(size0, dtype=dtype, device='cuda')
    y = torch.randn(size1, dtype=dtype, device='cuda')
    torch_pre_alc = get_cuda_allocated()
    torch_z_size = op_bmm(x, y)
    torch_temp_alc = get_cuda_max_allocated() - torch_pre_alc
    # Free the baseline inputs so the traced run starts from a clean allocator.
    del x
    del y
    assert get_cuda_allocated() == 0
    # First traced run: MTensor must report the same shape, pre-op
    # allocation, and temporary peak as the plain-torch baseline.
    x = MTensor(torch.randn(size0, dtype=dtype, device='cuda'))
    y = MTensor(torch.randn(size1, dtype=dtype, device='cuda'))
    op1_pre_alc = get_cuda_allocated()
    MTensor.reset_peak_memory()
    op1_z_size = op_bmm(x, y)
    op1_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op1_z_size
    assert torch_pre_alc == op1_pre_alc
    assert torch_temp_alc == op1_temp_alc
    # The first traced call should have populated the bmm op cache;
    # dump its contents for debugging.
    assert len(bmm_cache.temp_memory) > 0
    bmm_cache.print()
    # Second traced run: presumably served from the cache entry above —
    # the reported numbers must not drift from the first run.
    MTensor.reset_peak_memory()
    op2_z_size = op_bmm(x, y)
    op2_temp_alc = MTensor.current_peak_memory() - op1_pre_alc
    assert torch_z_size == op2_z_size
    assert torch_temp_alc == op2_temp_alc
if __name__ == '__main__':
    # Manual smoke run (the pytest entry points above are gated behind the
    # ELX environment flag). Exercise all three traced ops, not just addmm,
    # so a direct invocation covers the mm, addmm, and bmm caches alike.
    test_mm(dtype=torch.float)
    test_addmm(dtype=torch.float)
    test_bmm(dtype=torch.float)