import torch

from colossalai.utils.model.utils import call_to_str

class LayerSpec:
    """A lazy specification of a ``torch.nn.Module``.

    Instead of instantiating a module eagerly, ``LayerSpec`` stores the module
    class together with the positional and keyword arguments required to
    construct it, so that the module can be built on demand via ``build()``.
    """
    def __init__(self, typename, *module_args, **module_kwargs):
        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        self.children = None
        self._param_count = 0

        # Fail fast: only torch.nn.Module subclasses can be specified.
        if not issubclass(typename, torch.nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')
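
    # Example (illustrative): ``LayerSpec(torch.nn.Linear, 16, 32, bias=False)``
    # records the class and its arguments without allocating any parameters
    # until ``build()`` is called.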

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        return self._param_count

    def build(self):
        """Build the stored specification."""
        # Recursively build any positional arguments that are themselves
        # LayerSpec instances before passing them to the constructor.
        recovered_args = []
        for obj in self.module_args:
            if isinstance(obj, LayerSpec):
                obj = obj.build()
            recovered_args.append(obj)
        recovered_args = tuple(recovered_args)

        # Do the same for keyword arguments.
        recovered_kwargs = {}
        for k, v in self.module_kwargs.items():
            if isinstance(v, LayerSpec):
                v = v.build()
            recovered_kwargs[k] = v

        return self.typename(*recovered_args, **recovered_kwargs)
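
    # Because ``build`` recurses into LayerSpec arguments, nested specs are
    # materialized bottom-up: already-built child modules are passed to the
    # parent's constructor.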

    def set_children(self, children):
        self.children = children

    def count_params(self):
        # Materialize the layer once to count its parameters. Note that this
        # constructs (and then discards) a full module instance.
        self._param_count = 0
        layer = self.build()
        for param in layer.parameters():
            self._param_count += param.numel()
        return self._param_count

    def reset_param_count(self):
        self._param_count = 0
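

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only): describe a ``torch.nn.Linear`` lazily,
# count its parameters, then materialize it with ``build()``. The module
# choice and sizes are demonstration values, not fixtures from this project.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    spec = LayerSpec(torch.nn.Linear, 4, 8, bias=True)
    print(spec)                  # repr rendered via call_to_str
    print(spec.count_params())   # 4 * 8 weights + 8 biases = 40
    layer = spec.build()         # an actual torch.nn.Linear instance
    print(type(layer).__name__)  # Linear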