import importlib
import os
import time
from abc import abstractmethod
from pathlib import Path
from typing import List

from .base_extension import _Extension

__all__ = ["_CppExtension"]


class _CppExtension(_Extension):
    """
    Base class for C++ extensions that support both AOT and JIT builds.

    Subclasses describe the kernel by implementing ``sources_files``,
    ``include_dirs`` and ``cxx_flags``. This class turns that description into
    either a ``torch.utils.cpp_extension.CppExtension`` (ahead-of-time build)
    or a module compiled and loaded at runtime (just-in-time build).
    """

    def __init__(self, name: str, priority: int = 1):
        """
        Args:
            name: name of the op; used as the JIT module name and as the last
                component of the pre-built import path.
            priority: forwarded unchanged to ``_Extension.__init__``.
        """
        super().__init__(name, support_aot=True, support_jit=True, priority=priority)

        # we store the op as an attribute to avoid repeated building and loading;
        # it is filled in by ``load`` on its first successful call
        self.cached_op = None

        # build-related variables
        self.prebuilt_module_path = "colossalai._C"
        self.prebuilt_import_path = f"{self.prebuilt_module_path}.{self.name}"
        self.version_dependent_macros = ["-DVERSION_GE_1_1", "-DVERSION_GE_1_3", "-DVERSION_GE_1_5"]

    def csrc_abs_path(self, path):
        """
        Return the absolute path of ``path`` inside the ``csrc`` folder of the
        ``extensions`` directory.
        """
        return os.path.join(self.relative_to_abs_path("csrc"), path)

    def relative_to_abs_path(self, code_path: str) -> str:
        """
        This function takes in a path relative to the ``extensions`` directory
        and returns the corresponding absolute path.

        Args:
            code_path: path relative to the ``extensions`` directory.

        Returns:
            The joined path as a string.

        Raises:
            RuntimeError: if neither this file nor any of its ancestors is
                named ``extensions``.
        """
        this_file = Path(__file__)

        # Walk from this file up towards the filesystem root and anchor on the
        # first path component called "extensions". ``parents`` is finite, so
        # unlike an open-ended ``while True`` walk this cannot spin forever at
        # the root (whose parent is itself).
        for candidate in (this_file, *this_file.parents):
            if candidate.name == "extensions":
                return str(candidate.joinpath(code_path))

        raise RuntimeError(
            f"Cannot locate the 'extensions' directory among the parents of {this_file}"
        )

    def strip_empty_entries(self, args):
        """
        Drop any empty strings from the list of compile and link flags
        """
        return [x for x in args if x]

    def import_op(self):
        """
        This function will import the op module by its string name.

        Raises:
            ImportError: if the pre-built module cannot be imported.
        """
        return importlib.import_module(self.prebuilt_import_path)

    def build_aot(self) -> "CppExtension":
        """
        Describe this op as a ``CppExtension`` for an ahead-of-time build.

        Empty entries are removed from the sources, include directories and
        compile flags before they are handed to ``CppExtension``.
        """
        from torch.utils.cpp_extension import CppExtension

        return CppExtension(
            name=self.prebuilt_import_path,
            sources=self.strip_empty_entries(self.sources_files()),
            include_dirs=self.strip_empty_entries(self.include_dirs()),
            extra_compile_args=self.strip_empty_entries(self.cxx_flags()),
        )

    def build_jit(self):
        """
        Compile (if needed) and load the op at runtime.

        The build directory is obtained from
        ``_Extension.get_jit_extension_folder_path()`` and created when
        missing. Progress and timing messages are printed to stdout.

        Returns:
            The module object returned by ``torch.utils.cpp_extension.load``.
        """
        from torch.utils.cpp_extension import load

        build_directory = Path(_Extension.get_jit_extension_folder_path())
        build_directory.mkdir(parents=True, exist_ok=True)

        # check if the kernel has been built: the JIT build leaves a shared
        # library named after the module (``.so``, or ``.pyd`` on Windows);
        # object files are named after the individual sources, not the op.
        compiled_before = any(
            build_directory.joinpath(f"{self.name}{suffix}").exists() for suffix in (".so", ".pyd")
        )

        # load the kernel
        if compiled_before:
            print(f"[extension] Loading the JIT-built {self.name} kernel during runtime now")
        else:
            print(f"[extension] Compiling the JIT {self.name} kernel during runtime now")

        # perf_counter is monotonic, so the reported duration cannot be skewed
        # by wall-clock adjustments during a long compilation
        build_start = time.perf_counter()
        op_kernel = load(
            name=self.name,
            sources=self.strip_empty_entries(self.sources_files()),
            extra_include_paths=self.strip_empty_entries(self.include_dirs()),
            # filtered the same way as in ``build_aot``
            extra_cflags=self.strip_empty_entries(self.cxx_flags()),
            extra_ldflags=[],
            build_directory=str(build_directory),
        )
        build_duration = time.perf_counter() - build_start

        if compiled_before:
            print(f"[extension] Time taken to load {self.name} op: {build_duration} seconds")
        else:
            print(f"[extension] Time taken to compile {self.name} op: {build_duration} seconds")

        return op_kernel

    # functions that must be overridden by subclasses: begin
    @abstractmethod
    def sources_files(self) -> List[str]:
        """
        This function should return a list of source files for extensions.
        """

    @abstractmethod
    def include_dirs(self) -> List[str]:
        """
        This function should return a list of include files for extensions.
        """

    @abstractmethod
    def cxx_flags(self) -> List[str]:
        """
        This function should return a list of cxx compilation flags for extensions.
        """

    # functions that must be overridden by subclasses: end

    def load(self):
        """
        Return the op module, importing the pre-built one when available and
        falling back to a JIT build otherwise.

        The result is stored in ``self.cached_op`` so that subsequent calls
        return it directly without importing or building again.
        """
        if self.cached_op is not None:
            return self.cached_op

        try:
            op_kernel = self.import_op()
        except ImportError:
            # if import error occurs, it means that the kernel is not pre-built
            # so we build it jit (ModuleNotFoundError is a subclass of ImportError)
            op_kernel = self.build_jit()

        self.cached_op = op_kernel
        return op_kernel