import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(
    model_fn,
    enable_fused_normalization=False,
    enable_tensor_parallelism=False,
    enable_flash_attention=False,
    enable_jit_fused=False,
):
    """Build a model and a ShardFormer-processed deep copy of it.

    Args:
        model_fn: Zero-argument callable that returns a freshly created model.
        enable_fused_normalization: Forwarded to ``ShardConfig`` unchanged.
        enable_tensor_parallelism: Forwarded to ``ShardConfig`` unchanged.
        enable_flash_attention: Forwarded to ``ShardConfig`` unchanged.
        enable_jit_fused: Forwarded to ``ShardConfig`` unchanged.

    Returns:
        A 2-tuple ``(original, sharded)`` holding the results of calling
        ``.cuda()`` on the model produced by ``model_fn`` and on the model
        returned by ``ShardFormer.optimize`` respectively.
    """
    reference_model = model_fn()

    # Collect the feature switches once and hand them to the shard config.
    feature_flags = {
        "enable_fused_normalization": enable_fused_normalization,
        "enable_tensor_parallelism": enable_tensor_parallelism,
        "enable_flash_attention": enable_flash_attention,
        "enable_jit_fused": enable_jit_fused,
    }
    config = ShardConfig(**feature_flags)

    # Shard a deep copy so the reference model is not the object handed to
    # the ShardFormer.
    replica = copy.deepcopy(reference_model)
    sharded_model, _ = ShardFormer(shard_config=config).optimize(replica)

    return reference_model.cuda(), sharded_model.cuda()


def run_infer(original_model, sharded_model, data_gen_fn, output_transform_fn):
    """Run one forward pass through both models on the same inputs.

    Args:
        original_model: Callable invoked with the generated inputs as
            keyword arguments.
        sharded_model: Callable invoked with the same keyword arguments,
            after ``original_model``.
        data_gen_fn: Zero-argument callable returning a mapping of input
            names to values; every value must provide a ``.cuda()`` method.
        output_transform_fn: Callable applied to each model's raw output.

    Returns:
        A 2-tuple ``(org_output, shard_output)`` of the transformed outputs.
    """
    # Generate one batch and move each entry to CUDA.
    inputs = {name: value.cuda() for name, value in data_gen_fn().items()}

    def _forward(model):
        # Forward pass followed by the caller-supplied output transform.
        return output_transform_fn(model(**inputs))

    # Tuple elements are evaluated left to right, so the original model
    # runs (and is transformed) before the sharded one.
    return _forward(original_model), _forward(sharded_model)