from typing import Callable

from torch import nn
from transformers import GPT2Config, GPT2LMHeadModel


# Define the model based on Hugging Face transformers' GPT2LMHeadModel.
class GPTLMModel(nn.Module):
    """Wrapper around ``GPT2LMHeadModel`` whose forward pass returns only the LM logits.

    A ``GPT2Config`` is assembled from the constructor arguments and the
    wrapped ``GPT2LMHeadModel`` is constructed directly from that config.

    Args:
        hidden_size: Forwarded to ``GPT2Config`` as ``n_embd``.
        num_layers: Forwarded to ``GPT2Config`` as ``n_layer``.
        num_attention_heads: Forwarded to ``GPT2Config`` as ``n_head``.
        max_seq_len: Forwarded to ``GPT2Config`` as both ``n_positions`` and ``n_ctx``.
        vocab_size: Forwarded to ``GPT2Config`` as ``vocab_size``.
        checkpoint: When truthy, gradient checkpointing is enabled on the
            wrapped model and ``forward`` calls it with ``use_cache=False``.
    """

    def __init__(self,
                 hidden_size=768,
                 num_layers=12,
                 num_attention_heads=12,
                 max_seq_len=1024,
                 vocab_size=50257,
                 checkpoint=False):
        super().__init__()
        self.checkpoint = checkpoint

        # Translate this wrapper's argument names into GPT2Config's field names.
        config_kwargs = {
            "n_embd": hidden_size,
            "n_layer": num_layers,
            "n_head": num_attention_heads,
            "n_positions": max_seq_len,
            "n_ctx": max_seq_len,
            "vocab_size": vocab_size,
        }
        self.config = GPT2Config(**config_kwargs)
        self.model = GPT2LMHeadModel(self.config)

        if checkpoint:
            self.model.gradient_checkpointing_enable()

    def forward(self, input_ids, attention_mask):
        """Run the wrapped model and return only its first output (the LM logits)."""
        # use_cache is switched off whenever gradient checkpointing was requested.
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=not self.checkpoint,
        )
        return outputs[0]


def gpt2_medium(checkpoint=False):
    """Build a ``GPTLMModel`` with hidden size 1024, 24 layers and 16 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    arch = dict(hidden_size=1024, num_layers=24, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_xl(checkpoint=True):
    """Build a ``GPTLMModel`` with hidden size 1600, 48 layers and 32 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    # NOTE(review): the released GPT-2 XL checkpoint is usually described with
    # 25 heads; 32 is what this file has always used -- confirm it is intentional.
    arch = dict(hidden_size=1600, num_layers=48, num_attention_heads=32)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_10b(checkpoint=True):
    """Build a ``GPTLMModel`` with hidden size 4096, 50 layers and 16 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    arch = dict(hidden_size=4096, num_layers=50, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_14b(checkpoint=True):
    """Build a ``GPTLMModel`` with hidden size 4096, 70 layers and 16 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    arch = dict(hidden_size=4096, num_layers=70, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_20b(checkpoint=True):
    """Build a ``GPTLMModel`` with hidden size 8192, 25 layers and 16 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    arch = dict(hidden_size=8192, num_layers=25, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def gpt2_24b(checkpoint=True):
    """Build a ``GPTLMModel`` with hidden size 8192, 30 layers and 16 attention heads.

    ``checkpoint`` is forwarded unchanged to ``GPTLMModel``; all other
    constructor arguments keep ``GPTLMModel``'s defaults.
    """
    arch = dict(hidden_size=8192, num_layers=30, num_attention_heads=16)
    return GPTLMModel(checkpoint=checkpoint, **arch)


def model_builder(model_size: str) -> Callable[..., nn.Module]:
    """Return the factory function registered under ``model_size``.

    The factory is returned uncalled, so the caller decides when to build the
    model and which ``checkpoint`` value to pass.

    Args:
        model_size: One of ``"gpt2_medium"``, ``"gpt2_xl"``, ``"gpt2_10b"``,
            ``"gpt2_14b"``, ``"gpt2_20b"`` or ``"gpt2_24b"``.

    Returns:
        The module-level factory function with the same name as ``model_size``.

    Raises:
        TypeError: If ``model_size`` is not one of the supported names. The
            exception type is kept as ``TypeError`` so existing callers that
            catch it keep working.
    """
    builders = {
        "gpt2_medium": gpt2_medium,
        "gpt2_xl": gpt2_xl,
        "gpt2_10b": gpt2_10b,
        "gpt2_14b": gpt2_14b,
        "gpt2_20b": gpt2_20b,
        "gpt2_24b": gpt2_24b,
    }
    try:
        return builders[model_size]
    except KeyError:
        supported = ", ".join(builders)
        # `from None` hides the internal KeyError; it is an implementation detail.
        raise TypeError(
            f"model_builder: unknown model size {model_size!r}; expected one of: {supported}"
        ) from None


# Only the lookup helper is exported by a star-import of this module; the model
# class and the per-size factories remain importable by explicit name.
__all__ = ['model_builder']