import torch
import torch.nn as nn

# Only the Accelerator class is exported by ``from ... import *``.
__all__ = ["Accelerator"]

# Device types that ``Accelerator`` accepts. Other back-ends ('xpu', 'npu',
# 'tpu') are planned but not supported yet.
_supported_devices = ['cpu', 'cuda']


class Accelerator:
    """
    Accelerator is an abstraction for the hardware device that is used to run the model.

    Args:
        device (str): The device to be used. Must be one of ``_supported_devices``
            (currently 'cpu' and 'cuda').

    Raises:
        ValueError: If ``device`` is not a supported device type.
    """

    def __init__(self, device: str):
        self.device = device

        # An explicit raise (rather than ``assert``) keeps this validation active
        # when Python runs with -O, which strips assert statements.
        if self.device not in _supported_devices:
            raise ValueError(
                f"Device {self.device} is not supported yet, supported devices include {_supported_devices}"
            )

    def bind(self):
        """
        Set the default device for the current process.

        For 'cpu' this is a no-op. For 'cuda' the process is bound to the
        currently selected CUDA device.

        Raises:
            ValueError: If ``self.device`` is not a supported device type.
        """
        if self.device == 'cpu':
            # Nothing to bind for the CPU device.
            return
        if self.device == 'cuda':
            # TODO(FrankLeeeee): use global environment to check if it is a dist job
            # if is_distributed:
            #     local_rank = EnvTable().get_local_rank()
            #     torch.cuda.set_device(torch.device(f'cuda:{local_rank}'))
            # ``torch.cuda.set_device`` raises ValueError for a torch.device without an
            # index (e.g. ``torch.device('cuda')``), so pass an explicit device index.
            torch.cuda.set_device(torch.cuda.current_device())
            return
        raise ValueError(f"Device {self.device} is not supported yet")

    def configure_model(self, model: nn.Module) -> nn.Module:
        """
        Move the model to the device.

        Args:
            model (nn.Module): The model to be moved.

        Returns:
            nn.Module: The model placed on ``self.device``. ``nn.Module.to`` moves
            parameters in place and returns the same module object.
        """
        return model.to(torch.device(self.device))