import torch
import torch.nn as nn
from colossalai.utils.checkpoint_io.meta import ParamDistMeta
from colossalai.utils.checkpoint_io.utils import build_checkpoints
from torch.optim import Adam


class DummyModel(nn.Module):
    """Minimal module used as a checkpointing fixture: one 20 -> 1 linear layer."""

    def __init__(self) -> None:
        super().__init__()
        # The single layer contributes the 'fc.weight' and 'fc.bias' parameters.
        self.fc = nn.Linear(in_features=20, out_features=1)


def test_global_model():
    """Check a model-only build with a size argument of 0.

    The test expects one model checkpoint, no optimizer checkpoints, no
    ``dist_meta``, and a checkpoint whose entries equal ``model.state_dict()``.
    """
    net = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, net)

    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None

    expected = net.state_dict()
    saved = model_ckpts[0]
    # Same parameter names on both sides, then same values per name.
    assert set(expected.keys()) == set(saved.keys())
    for name in expected:
        assert torch.equal(expected[name], saved[name])


def test_global_model_shard():
    """Check a model-only build with a size argument of 80.

    The test expects the model state to be split across two checkpoints that
    share no keys, together cover every key of ``model.state_dict()``, and
    hold tensors equal to the originals.
    """
    net = DummyModel()
    model_ckpts, optim_ckpts, meta = build_checkpoints(80, net)

    assert len(model_ckpts) == 2
    assert len(optim_ckpts) == 0
    assert meta['dist_meta'] is None

    expected = net.state_dict()
    first_keys = set(model_ckpts[0].keys())
    second_keys = set(model_ckpts[1].keys())
    # The shards must partition the original key set: full coverage, no overlap.
    assert set(expected.keys()) == first_keys | second_keys
    assert first_keys.isdisjoint(second_keys)

    for shard in model_ckpts:
        for name, tensor in expected.items():
            if name not in shard:
                continue
            assert torch.equal(tensor, shard[name])


def test_global_optimizer():
    """Check a model + optimizer build with a size argument of 0.

    After a single Adam step, the test expects ``build_checkpoints`` to return
    one optimizer checkpoint whose per-parameter state and ``param_groups``
    equal ``optimizer.state_dict()``, plus ``param_to_os`` and ``paired_os``
    metadata describing that state.
    """
    model = DummyModel()
    # Assign gradients so that the optimizer step creates per-parameter state.
    for p in model.parameters():
        p.grad = torch.rand_like(p)
    optimizer = Adam(model.parameters(), lr=1e-3)
    optimizer.step()
    model_checkpoints, optimizer_checkpoints, meta = build_checkpoints(0, model, optimizer)
    assert len(optimizer_checkpoints) == 1
    # Parameter names are expected to map onto the optimizer's state indices.
    assert meta['param_to_os'] == {'fc.weight': 0, 'fc.bias': 1}
    # Every optimizer state entry except 'step' must be flagged as paired.
    for state in meta['paired_os'].values():
        for k, is_paired in state.items():
            if k == 'step':
                assert not is_paired
            else:
                assert is_paired
    orig_state_dict = optimizer.state_dict()
    state_dict = optimizer_checkpoints[0]
    for k, orig_state in orig_state_dict['state'].items():
        state = state_dict['state'][k]
        for v1, v2 in zip(orig_state.values(), state.values()):
            if isinstance(v2, torch.Tensor):
                assert torch.equal(v1, v2)
            else:
                # Compare the original value with the saved one; the previous
                # `v2 == v2` check was a tautology that could never fail.
                assert v1 == v2
    assert orig_state_dict['param_groups'] == state_dict['param_groups']


def test_global_optimizer_shard():
    """Check a model + optimizer build with a size argument of 80.

    The test expects two optimizer checkpoints: only the first carries
    ``param_groups``, their ``state`` keys are disjoint and together cover the
    original optimizer state, and every saved value equals the original.
    """
    net = DummyModel()
    for param in net.parameters():
        param.grad = torch.rand_like(param)
    adam = Adam(net.parameters(), lr=1e-3)
    adam.step()

    model_ckpts, optim_ckpts, meta = build_checkpoints(80, net, adam)
    assert len(optim_ckpts) == 2
    first, second = optim_ckpts[0], optim_ckpts[1]
    assert 'param_groups' in first and 'param_groups' not in second

    expected = adam.state_dict()
    first_ids = set(first['state'].keys())
    second_ids = set(second['state'].keys())
    # The two shards must partition the original state indices.
    assert set(expected['state'].keys()) == first_ids | second_ids
    assert first_ids.isdisjoint(second_ids)

    for idx, expected_state in expected['state'].items():
        # Each index lives in exactly one shard; look it up where it is stored.
        if idx in first['state']:
            saved_state = first['state'][idx]
        else:
            saved_state = second['state'][idx]
        for want, got in zip(expected_state.values(), saved_state.values()):
            if isinstance(got, torch.Tensor):
                assert torch.equal(want, got)
            else:
                assert want == got

    assert expected['param_groups'] == first['param_groups']


def test_dist_model_optimizer():
    """Check ``build_checkpoints`` when a ``dist_meta`` mapping is supplied.

    Two builds are made from the same model and optimizer with different
    ``ParamDistMeta`` values. Each must return the mapping unchanged in
    ``meta['dist_meta']`` and yield exactly one model checkpoint and one
    optimizer checkpoint.
    """
    net = DummyModel()
    for param in net.parameters():
        param.grad = torch.rand_like(param)
    adam = Adam(net.parameters(), lr=1e-3)
    adam.step()

    # First build: the two parameters' metas differ in their first argument.
    first_meta = {'fc.weight': ParamDistMeta(0, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, net, adam, dist_meta=first_meta)
    assert first_meta == meta['dist_meta']
    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 1
    saved_params = model_ckpts[0]
    assert 'fc.weight' in saved_params
    assert 'fc.bias' in saved_params
    saved_states = optim_ckpts[0]['state']
    assert 0 in saved_states
    assert 1 in saved_states

    # Second build: both parameters use identical meta arguments.
    second_meta = {'fc.weight': ParamDistMeta(1, 2, 0, 1), 'fc.bias': ParamDistMeta(1, 2, 0, 1)}
    model_ckpts, optim_ckpts, meta = build_checkpoints(0, net, adam, dist_meta=second_meta)
    assert second_meta == meta['dist_meta']
    assert len(model_ckpts) == 1
    assert len(optim_ckpts) == 1


if __name__ == '__main__':
    # Run every test in definition order when the file is executed directly.
    for _test_case in (
            test_global_model,
            test_global_model_shard,
            test_global_optimizer,
            test_global_optimizer_shard,
            test_dist_model_optimizer,
    ):
        _test_case()