import torch
import torch.nn as nn
from torch.optim import Adam, AdamW

from colossalai.nn.optimizer.hybrid_adam import HybridAdam
from colossalai.testing import parameterize

RE = 1024  # number of optimizer steps to run for each configuration


@parameterize('adamw', [False, True])
@parameterize('device', ['cpu', 'cuda:0'])
@parameterize('p_dtype', [torch.float])
@parameterize('g_dtype', [torch.float, torch.half])
def test_adam(adamw, device, p_dtype, g_dtype):
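    """Check that HybridAdam matches torch.optim.Adam / AdamW on identical
    parameters and gradients, across devices and gradient dtypes."""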
    # Restore the RNG state so p and p_copy start from identical random values.
    rng_state = torch.get_rng_state()
    p = nn.Parameter(torch.rand(64).to(device, p_dtype))
    torch.set_rng_state(rng_state)
    p_copy = nn.Parameter(torch.rand(64).to(device).float())

    # Pair HybridAdam with the matching torch reference optimizer.
    if adamw:
        optim = HybridAdam([p], lr=1e-3, adamw_mode=True)
        torch_optim = AdamW([p_copy], lr=1e-3)
    else:
        optim = HybridAdam([p], lr=1e-3)
        torch_optim = Adam([p_copy], lr=1e-3)

    print(f"adaw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}")
    for i in range(RE):
        # Feed both optimizers the same random gradient; HybridAdam receives it in g_dtype.
        p.grad = torch.rand(64).to(device, p_dtype)
        p_copy.grad = p.grad.clone().float()
        p.grad.data = p.grad.data.to(g_dtype)

        optim.step()
        torch_optim.step()

        # Skip the comparison for this step if either update produced NaNs.
        if torch.isnan(p.data).any() or torch.isnan(p_copy.data).any():
            continue
        assert torch.allclose(p.data, p_copy.data, 1e-4, 1e-2), \
            f"adamw mode {adamw}, device {device}, p_dtype {p_dtype}, g_dtype {g_dtype}"