import torch
from colossalai.utils.model.utils import call_to_str

class LayerSpec:
    """Deferred specification of a ``torch.nn.Module``.

    Records a module class together with the positional and keyword arguments
    needed to construct it, so the module can be instantiated later via
    :meth:`build`. Arguments (positional or keyword) that are themselves
    ``LayerSpec`` instances are built recursively when the outer spec is built.

    Args:
        typename: A subclass of ``torch.nn.Module`` to instantiate.
        *module_args: Positional arguments forwarded to ``typename``.
        **module_kwargs: Keyword arguments forwarded to ``typename``.

    Raises:
        RuntimeError: If ``typename`` is not a subclass of ``torch.nn.Module``
            (including when it is not a class at all).
    """

    def __init__(self, typename, *module_args, **module_kwargs):
        # Validate before storing anything. ``issubclass`` itself raises
        # TypeError for non-class arguments (e.g. a plain function), so check
        # for a class first to report every bad ``typename`` the same way.
        if not (isinstance(typename, type) and issubclass(typename, torch.nn.Module)):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        # Assigned externally through ``set_children``; not read by this class.
        self.children = None
        # Cached result of the most recent ``count_params`` call (0 until then).
        self._param_count = 0

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        """Parameter count cached by the last :meth:`count_params` call."""
        return self._param_count

    @staticmethod
    def _materialize(obj):
        """Build ``obj`` if it is a ``LayerSpec``; otherwise return it unchanged."""
        return obj.build() if isinstance(obj, LayerSpec) else obj

    def build(self):
        """Build the stored specification.

        Nested ``LayerSpec`` arguments are built first: positional arguments in
        order, then keyword arguments in insertion order.

        Returns:
            A new ``typename`` instance. Each call constructs a fresh module.
        """
        args = tuple(self._materialize(obj) for obj in self.module_args)
        kwargs = {key: self._materialize(val) for key, val in self.module_kwargs.items()}
        return self.typename(*args, **kwargs)

    def set_children(self, children):
        """Store ``children`` on the spec as-is."""
        self.children = children

    def count_params(self):
        """Build a throwaway instance and count its parameter elements.

        The total of ``numel()`` over ``layer.parameters()`` is cached in
        :attr:`param_count` and also returned.
        """
        # Reset first so a failing build leaves the cache at 0.
        self._param_count = 0
        layer = self.build()
        self._param_count = sum(param.numel() for param in layer.parameters())
        return self._param_count

    def reset_param_count(self):
        """Clear the cached parameter count back to 0."""
        self._param_count = 0