mirror of https://github.com/hpcaitech/ColossalAI
import torch

from colossalai.utils.model.utils import call_to_str


class LayerSpec:
    """A lazy specification for a ``torch.nn.Module``.

    Stores the module class together with the positional and keyword
    arguments needed to construct it, so that instantiation can be
    deferred until ``build`` is called.
    """
    def __init__(self, typename, *module_args, **module_kwargs):
        # Validate up front so we never hold a spec for a non-module type.
        if not issubclass(typename, torch.nn.Module):
            raise RuntimeError('LayerSpec only supports torch.nn.Module types.')

        self.typename = typename
        self.module_args = module_args
        self.module_kwargs = module_kwargs
        self.children = None
        self._param_count = 0

    def __repr__(self):
        return call_to_str(self.typename.__name__, self.module_args, self.module_kwargs)

    @property
    def param_count(self):
        return self._param_count

    def build(self):
        """Build the stored specification.

        Nested ``LayerSpec`` arguments are built recursively before being
        passed to the module constructor.
        """
        recovered_args = []
        for obj in self.module_args:
            if isinstance(obj, LayerSpec):
                obj = obj.build()
            recovered_args.append(obj)
        recovered_args = tuple(recovered_args)

        recovered_kwargs = {}
        for k, v in self.module_kwargs.items():
            if isinstance(v, LayerSpec):
                v = v.build()
            recovered_kwargs[k] = v

        return self.typename(*recovered_args, **recovered_kwargs)

    def set_children(self, children):
        self.children = children

    def count_params(self):
        """Count the parameters of the specified module.

        Note that this builds (and discards) an instance of the module,
        which may be expensive for large layers.
        """
        self._param_count = 0
        layer = self.build()
        for param in layer.parameters():
            self._param_count += param.numel()
        return self._param_count

    def reset_param_count(self):
        self._param_count = 0
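
# Usage sketch (not part of the original file): a minimal, hedged example of
# how LayerSpec might be used, assuming only the class defined above and a
# working torch install. The exact __repr__ output depends on call_to_str.
if __name__ == '__main__':
    # Describe a Linear layer without instantiating it yet.
    spec = LayerSpec(torch.nn.Linear, 1024, 1024, bias=True)
    print(spec)                   # e.g. Linear(1024, 1024, bias=True)
    print(spec.count_params())    # 1024 * 1024 + 1024 = 1049600

    # The module itself is only created when build() is called.
    layer = spec.build()
    assert isinstance(layer, torch.nn.Linear)

    # Nested specs: build() recursively builds LayerSpec arguments, so a
    # container module can be described entirely out of specs.
    block = LayerSpec(torch.nn.Sequential,
                      LayerSpec(torch.nn.Linear, 1024, 4096),
                      LayerSpec(torch.nn.ReLU),
                      LayerSpec(torch.nn.Linear, 4096, 1024))
    model = block.build()         # a fully instantiated nn.Sequential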