import torch
from torch.fx import GraphModule
from colossalai.fx.passes.adding_split_node_pass import split_with_split_nodes_pass, balanced_split_pass
from colossalai.fx import ColoTracer
from colossalai.pipeline.middleware import Partition, PartitionInputVal, PartitionOutputVal, Topo
from colossalai.pipeline.middleware.adaptor import get_fx_topology
import random
import numpy as np

MANUAL_SEED = 0
# Seed Python's, NumPy's and torch's RNGs (in that order) so that model
# initialisation and any sampled data are reproducible across runs.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed):
    _seed_fn(MANUAL_SEED)
del _seed_fn

class MLP(torch.nn.Module):
    """A stack of bias-free square ``Linear`` layers applied sequentially.

    Args:
        config: mapping with keys ``'dim'`` (input/output feature size of
            every layer) and ``'layers'`` (number of layers). Both keys are
            required; omitting ``config`` raises ``KeyError``.
    """

    def __init__(self, config=None):
        super().__init__()
        # Avoid a shared mutable default; a missing config still fails on the
        # key lookups below, exactly as an empty dict would.
        config = {} if config is None else config
        dim = config['dim']
        layers = config['layers']
        self.layers = torch.nn.ModuleList(
            torch.nn.Linear(dim, dim, bias=False) for _ in range(layers)
        )

    def forward(self, x):
        """Apply every layer in order and return the final activation."""
        for layer in self.layers:
            x = layer(x)
        return x

def split_model_and_get_DAG(model, data_gen, num_stages=2):
    """Trace ``model``, split it into pipeline stages and build its topology.

    Args:
        model: the ``torch.nn.Module`` to split; it is switched to eval mode.
        data_gen: zero-argument callable returning a dict of sample inputs
            (keyword name -> value); every value must support ``.to('meta')``
            (i.e. tensors), since tracing runs on the ``meta`` device.
        num_stages: stage count passed to ``balanced_split_pass``; defaults
            to 2.

    Returns:
        ``(top_module, topo)``: the split top-level ``GraphModule`` and the
        topology returned by ``get_fx_topology``. The same ``topo`` object is
        also attached as ``_topo`` to every split submodule that is a
        ``torch.fx.GraphModule``.

    Raises:
        RuntimeError: if tracing fails; the original exception is chained.
    """
    model.eval()

    # generate input sample
    kwargs = data_gen()

    # tracing model
    tracer = ColoTracer()
    try:
        meta_args = {k: v.to('meta') for k, v in kwargs.items()}
        graph = tracer.trace(root=model, meta_args=meta_args)
    except Exception as e:
        # Chain the cause so the original traceback is preserved.
        raise RuntimeError(f"Failed to trace {model.__class__.__name__}, error: {e}") from e
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # apply transform passes
    annotated_model = balanced_split_pass(gm, num_stages)
    top_module, split_submodules = split_with_split_nodes_pass(annotated_model)

    topo = get_fx_topology(top_module)
    for submodule in split_submodules:
        if isinstance(submodule, torch.fx.GraphModule):
            submodule._topo = topo

    # Return topo directly: reading it back from split_submodules[0] would fail
    # whenever that entry is not a GraphModule (it never received ``_topo``).
    return top_module, topo

def check_input(top_module, input_partition: Partition):
    """Check the input partition against the placeholders of ``top_module``.

    The i-th ``placeholder`` node (in graph order) is paired with the i-th
    output value of ``input_partition``. The value's ``.get()`` result must
    have as many entries as the placeholder has users. Each entry is
    presumably a (partition, offset) target; verify against
    ``PartitionOutputVal``. The number of placeholders must equal the
    number of output values.
    """
    output_vals = input_partition.get_output_vals()
    placeholders = (n for n in top_module.graph.nodes if n.op == 'placeholder')

    seen = 0
    for idx, placeholder in enumerate(placeholders):
        targets = output_vals[idx].get()
        assert len(targets) == len(placeholder.users)
        seen = idx + 1

    assert seen == len(output_vals)
        
def check_submod(top_module, part_id, mid_partition: Partition):
    """Check a mid partition against its ``submod`` call node in ``top_module``.

    ``submod*`` nodes are numbered in graph order starting from 2, so the
    first ``submod`` node corresponds to ``part_id == 2``. This presumably
    mirrors the mid-partition ids produced by ``get_fx_topology``; TODO
    confirm. The partition must have one input value per argument of the
    node and one output value per user of the node.

    Raises:
        AssertionError: if no ``submod`` node matches ``part_id`` or the
            input/output counts disagree.
    """
    partition_input = mid_partition.get_input_vals()
    partition_output = mid_partition.get_output_vals()

    # Only submod nodes are candidates. Comparing the counter on every node
    # would let a non-submod node be selected for ids below 2.
    submod_nodes = (node for node in top_module.graph.nodes if node.name.startswith('submod'))
    cur_node = next(
        (node for node_id, node in enumerate(submod_nodes, start=2) if node_id == part_id),
        None,
    )
    assert cur_node is not None, f"no submod node found for partition id {part_id}"

    assert len(partition_input) == len(cur_node.args)
    assert len(partition_output) == len(cur_node.users)

def check_topo(top_module, topo: Topo):
    """Validate ``topo`` against the split ``top_module``.

    Runs ``check_input`` on the topology's input partition, then
    ``check_submod`` on every mid partition, keyed by its partition id.
    """
    in_part, mid_parts = topo.get_input_partition(), topo.get_mid_partitions()

    check_input(top_module, in_part)
    for partition_id, partition in mid_parts.items():
        check_submod(top_module, partition_id, partition)