# Copyright (c) InternLM. All rights reserved.
import argparse
import json
import os

import torch
from einops import rearrange
from tqdm import tqdm
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer


def save_conifg(config, tgt):
    """Save an InternLM2 ``config`` as a plain ``LlamaConfig`` under ``tgt``.

    Keys tied to the original (remote-code) model such as ``auto_map``,
    ``model_type`` and ``architectures`` are dropped. InternLM2's ``bias``
    flag is renamed to LLaMA's ``attention_bias``, and ``architectures`` is
    pinned to ``['LlamaForCausalLM']``. Presumably this lets the converted
    checkpoint load as a stock LLaMA model.

    The misspelled function name is kept for backward compatibility.

    Args:
        config: InternLM2 config object; only ``to_dict()`` is used.
        tgt: directory handed to ``LlamaConfig.save_pretrained``.
    """
    config_dict = config.to_dict()
    unnecessary_keys = (
        '_name_or_path',
        'auto_map',
        'transformers_version',
        'model_type',
        'architectures',
        'tokenizer_class',
        'attn_implementation',
    )
    for k in unnecessary_keys:
        config_dict.pop(k, None)
    # InternLM2 names the attention-bias switch ``bias``; LLaMA names it
    # ``attention_bias``. Only rename when the key exists, so a config
    # without ``bias`` does not raise KeyError and does not clobber an
    # ``attention_bias`` that is already present.
    if 'bias' in config_dict:
        config_dict['attention_bias'] = config_dict.pop('bias')
    config_dict['architectures'] = ['LlamaForCausalLM']
    llama_config = LlamaConfig(**config_dict)
    llama_config.save_pretrained(tgt)


def _split_wqkv(wqkv, num_key_value_groups, head_dim):
    """Split an InternLM2 fused ``wqkv`` weight into LLaMA q/k/v weights.

    The fused rows are grouped as ``(h gs d)``. Each of the ``h`` groups
    holds ``gs = num_key_value_groups + 2`` slots of ``head_dim`` rows:
    ``num_key_value_groups`` query slots, then one key slot, then one value
    slot.

    Returns:
        Tuple ``(wq, wk, wv)`` of 2-D tensors. Each is cloned so none of
        them can be a view onto the fused tensor's storage, which
        ``torch.save`` would otherwise serialize in full.
    """
    grouped = rearrange(
        wqkv,
        '(h gs d) dim -> h gs d dim',
        gs=2 + num_key_value_groups,
        d=head_dim,
    )
    parts = torch.split(grouped, [num_key_value_groups, 1, 1], dim=1)
    return tuple(
        rearrange(part, 'h gs d dim -> (h gs d) dim').clone()
        for part in parts)


def _rename_key(k):
    """Map a (non-``wqkv``) InternLM2 parameter name to its LLaMA name."""
    # The first matching sub-string wins. 'attention_norm' does not contain
    # 'attention.wo', so the order below is safe.
    renames = (
        ('attention.wo', 'self_attn.o_proj'),
        ('feed_forward.w1', 'mlp.gate_proj'),
        ('feed_forward.w2', 'mlp.down_proj'),
        ('feed_forward.w3', 'mlp.up_proj'),
        ('attention_norm', 'input_layernorm'),
        ('ffn_norm', 'post_attention_layernorm'),
    )
    for old, new in renames:
        if old in k:
            return k.replace(old, new)
    if 'tok_embeddings' in k:
        return 'model.embed_tokens.weight'
    # Match 'output' as a whole dotted component (e.g. 'output.weight'),
    # not as an arbitrary sub-string of some other parameter name.
    if 'output' in k.split('.'):
        return 'lm_head.weight'
    return k


def _convert_shard(states, num_key_value_groups, head_dim):
    """Return a new state dict holding one shard in LLaMA naming/layout."""
    llama_states = {}
    for k, v in states.items():
        if 'wqkv' in k:
            # Keep everything before 'attention' (the per-layer prefix).
            prefix = k.split('attention')[0]
            wq, wk, wv = _split_wqkv(v, num_key_value_groups, head_dim)
            llama_states[prefix + 'self_attn.q_proj.weight'] = wq
            llama_states[prefix + 'self_attn.k_proj.weight'] = wk
            llama_states[prefix + 'self_attn.v_proj.weight'] = wv
        else:
            llama_states[_rename_key(k)] = v
    return llama_states


def convert(src, tgt):
    """Convert InternLM2 huggingface checkpoints to Llama-style.

    Every ``*.bin`` shard in ``src`` is rewritten into LLaMA parameter
    naming and layout, then saved under the same file name in ``tgt``. The
    shard index (if present), the tokenizer and the config are written
    alongside.

    Args:
        src: directory containing the InternLM2 HF checkpoint.
        tgt: output directory; created if it does not exist.

    Raises:
        ValueError: if the model config enables bias terms, which this
            converter does not handle.
    """
    print('Convert InternLM2 huggingface checkpoints to Llama...')

    config = AutoConfig.from_pretrained(src, trust_remote_code=True)
    # Raise rather than assert: asserts are stripped under ``python -O``.
    if config.bias:
        raise ValueError('Cannot convert InternLM Model with bias to LLaMA.')

    head_dim = config.hidden_size // config.num_attention_heads
    num_key_value_groups = (config.num_attention_heads
                            // config.num_key_value_heads)

    # The shard index maps parameter names to shard files. Its weight_map
    # is rebuilt below using the new (LLaMA) parameter names.
    index_file = os.path.join(src, 'pytorch_model.bin.index.json')
    if os.path.exists(index_file):
        with open(index_file) as fp:
            index_dict = json.load(fp)
        index_dict['weight_map'] = {}
    else:
        index_dict = None

    os.makedirs(tgt, exist_ok=True)
    for filename in tqdm(sorted(os.listdir(src))):
        if not filename.endswith('.bin'):
            continue
        # map_location='cpu': without it, tensors saved from a GPU are
        # restored onto (and require) that same device.
        states = torch.load(os.path.join(src, filename), map_location='cpu')
        llama_states = _convert_shard(states, num_key_value_groups, head_dim)
        del states

        if index_dict is not None:
            for k in llama_states:
                index_dict['weight_map'][k] = filename
        print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)
        torch.save(llama_states, os.path.join(tgt, filename))
        # Release this shard before loading the next one, so at most one
        # shard's tensors are resident at a time.
        del llama_states

    print('Saving config and tokenizer...')
    # index.json
    if index_dict is not None:
        with open(os.path.join(tgt, 'pytorch_model.bin.index.json'),
                  'w') as fp:
            json.dump(index_dict, fp, indent=2)
    # tokenizer: drop the remote-code ``auto_map`` entry, presumably so the
    # saved tokenizer config does not point back at InternLM2's classes.
    tokenizer = LlamaTokenizer.from_pretrained(src)
    tokenizer.init_kwargs.pop('auto_map', None)
    tokenizer.save_pretrained(tgt)
    # config
    save_conifg(config, tgt)
    print('Done!')


def parse_args():
    """Parse the converter's command-line arguments.

    Both ``--src`` and ``--tgt`` are required. If one is missing, argparse
    prints a usage error and exits, instead of ``None`` reaching
    ``convert`` and failing later with an unrelated ``TypeError``.

    Returns:
        argparse.Namespace with ``src`` (input checkpoint folder) and
        ``tgt`` (output folder).
    """
    parser = argparse.ArgumentParser(
        description='Convert InternLM2 huggingface checkpoints to Llama-style.')
    parser.add_argument('--src', type=str, required=True, help='Input folder')
    parser.add_argument('--tgt', type=str, required=True, help='Output folder')
    return parser.parse_args()


if __name__ == '__main__':
    # Script entry point: read --src/--tgt from the command line and run
    # the conversion.
    cli_args = parse_args()
    convert(src=cli_args.src, tgt=cli_args.tgt)