from collections import namedtuple

import torch
import torchvision
import torchvision.models as tm
from packaging import version

from ..registry import ModelAttribute, model_zoo

def data_gen_fn():
    """Generate the default keyword inputs used by the registrations in this module.

    Returns:
        dict: one entry, ``x``, holding a ``torch.rand`` tensor of shape
        ``(4, 3, 224, 224)``.
    """
    return dict(x=torch.rand(4, 3, 224, 224))


def output_transform_fn(x):
    """Wrap a model output in a dict under the ``output`` key.

    Args:
        x: the value returned by the model; it is stored unchanged.

    Returns:
        dict: ``{'output': x}``.
    """
    return dict(output=x)

# special data gen fn
def inception_v3_data_gen_fn():
    """Generate keyword inputs for the ``torchvision_inception_v3`` registration.

    Returns:
        dict: one entry, ``x``, holding a ``torch.rand`` tensor of shape
        ``(4, 3, 299, 299)``.
    """
    return dict(x=torch.rand(4, 3, 299, 299))


# special model fn
def swin_s():
    """Build the Swin Transformer registered as ``torchvision_swin_s``.

    No weights are requested (``None`` is passed through
    ``Swin_T_Weights.verify``) and stochastic depth is turned off so the
    model is deterministic.

    NOTE(review): the weights enum used here is ``Swin_T_Weights``, which
    suggests this configuration follows torchvision's ``swin_t`` builder
    rather than ``swin_s`` -- confirm against the torchvision source.

    Returns:
        The model object returned by torchvision's ``_swin_transformer``.
    """
    # Imported inside the function so that importing this module does not
    # itself require torchvision.models.swin_transformer.
    from torchvision.models.swin_transformer import Swin_T_Weights, _swin_transformer

    checked_weights = Swin_T_Weights.verify(None)

    architecture = dict(
        patch_size=[4, 4],
        embed_dim=96,
        depths=[2, 2, 6, 2],
        num_heads=[3, 6, 12, 24],
        window_size=[7, 7],
        # the value this was adapted from is 0.2; 0 keeps the model deterministic
        stochastic_depth_prob=0,
    )

    return _swin_transformer(**architecture, weights=checked_weights, progress=True)


# special output transform fn
def google_net_output_transform_fn(x):
    """Normalize a GoogLeNet output into a dict with a single ``output`` entry.

    Args:
        x: the model output. When it is a ``torchvision.models.GoogLeNetOutputs``
            instance, its elements are added together with ``sum``; any other
            value is stored unchanged.

    Returns:
        dict: ``{'output': <summed or original value>}``.
    """
    if isinstance(x, torchvision.models.GoogLeNetOutputs):
        return dict(output=sum(x))
    return dict(output=x)


def swin_s_output_output_transform_fn(x):
    """Normalize a Swin output into a dict.

    Args:
        x: the model output. A tuple is expanded into one entry per element,
            keyed ``output0``, ``output1``, ...; any other value is stored
            unchanged under ``output``.

    Returns:
        dict: the keyed outputs described above.
    """
    if isinstance(x, tuple):
        return {f'output{idx}': val for idx, val in enumerate(x)}
    return dict(output=x)


def inception_v3_output_transform_fn(x):
    """Normalize an Inception v3 output into a dict with a single ``output`` entry.

    Args:
        x: the model output. When it is a ``torchvision.models.InceptionOutputs``
            instance, its elements are added together with ``sum``; any other
            value is stored unchanged.

    Returns:
        dict: ``{'output': <summed or original value>}``.
    """
    if isinstance(x, torchvision.models.InceptionOutputs):
        return dict(output=sum(x))
    return dict(output=x)

# Registrations that are not guarded by a torchvision version check.
# Each row is (registry name, model constructor, keyword overrides). Rows with
# an empty override dict are registered with the module-level `data_gen_fn`
# and `output_transform_fn`; overrides replace or extend those keywords.
_default_register_kwargs = dict(data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn)

_unguarded_models = (
    ('torchvision_alexnet', tm.alexnet, {}),
    ('torchvision_densenet121', tm.densenet121, {}),
    ('torchvision_efficientnet_b0', tm.efficientnet_b0,
     dict(model_attribute=ModelAttribute(has_stochastic_depth_prob=True))),
    ('torchvision_googlenet', tm.googlenet, dict(output_transform_fn=google_net_output_transform_fn)),
    ('torchvision_inception_v3', tm.inception_v3,
     dict(data_gen_fn=inception_v3_data_gen_fn, output_transform_fn=inception_v3_output_transform_fn)),
    ('torchvision_mobilenet_v2', tm.mobilenet_v2, {}),
    ('torchvision_mobilenet_v3_small', tm.mobilenet_v3_small, {}),
    ('torchvision_mnasnet0_5', tm.mnasnet0_5, {}),
    ('torchvision_resnet18', tm.resnet18, {}),
    ('torchvision_regnet_x_16gf', tm.regnet_x_16gf, {}),
    ('torchvision_resnext50_32x4d', tm.resnext50_32x4d, {}),
    ('torchvision_shufflenet_v2_x0_5', tm.shufflenet_v2_x0_5, {}),
    ('torchvision_squeezenet1_0', tm.squeezenet1_0, {}),
    ('torchvision_vgg11', tm.vgg11, {}),
    ('torchvision_wide_resnet50_2', tm.wide_resnet50_2, {}),
)

# Register in table order so the registry is populated in the same sequence.
for _zoo_name, _model_fn, _overrides in _unguarded_models:
    model_zoo.register(name=_zoo_name, model_fn=_model_fn, **{**_default_register_kwargs, **_overrides})

# vit_b_16 and convnext_base are registered only when the installed
# torchvision is 0.12.0 or newer.
_installed_tv = version.parse(torchvision.__version__)
if _installed_tv >= version.parse('0.12.0'):
    _shared_kwargs = dict(data_gen_fn=data_gen_fn, output_transform_fn=output_transform_fn)

    model_zoo.register(name='torchvision_vit_b_16', model_fn=tm.vit_b_16, **_shared_kwargs)
    model_zoo.register(
        name='torchvision_convnext_base',
        model_fn=tm.convnext_base,
        model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
        **_shared_kwargs,
    )

# swin_s and efficientnet_v2_s are registered only when the installed
# torchvision is 0.13.0 or newer.
_min_tv_for_0_13_models = version.parse('0.13.0')
if version.parse(torchvision.__version__) >= _min_tv_for_0_13_models:
    # swin_s uses the locally defined builder and its own output transform.
    model_zoo.register(name='torchvision_swin_s',
                       model_fn=swin_s,
                       data_gen_fn=data_gen_fn,
                       output_transform_fn=swin_s_output_output_transform_fn)

    _efficientnet_v2_s_kwargs = dict(
        name='torchvision_efficientnet_v2_s',
        model_fn=tm.efficientnet_v2_s,
        data_gen_fn=data_gen_fn,
        output_transform_fn=output_transform_fn,
        model_attribute=ModelAttribute(has_stochastic_depth_prob=True),
    )
    model_zoo.register(**_efficientnet_v2_s_kwargs)