#!/usr/bin/env python
# -*- encoding: utf-8 -*-

from abc import ABC, abstractmethod

import torch.nn as nn

from colossalai.builder import build_layer

class ModelFromConfig(nn.Module, ABC):
    """Base class for models whose layers are built from the config
    dicts in ``self.layers_cfg`` via :meth:`build_from_cfg`."""

    def __init__(self):
        super().__init__()
        self.layers = nn.ModuleList()
        self.layers_cfg = []

    def build_from_cfg(self, start=None, end=None):
        """Instantiate the layers described by ``self.layers_cfg[start:end]``
        and append them to ``self.layers``."""
        assert hasattr(self, 'layers_cfg'), \
            'Cannot find attribute layers_cfg from the module, please check the ' \
            'spelling and if you have initialized this variable'
        if start is None:
            start = 0
        if end is None:
            end = len(self.layers_cfg)
        for cfg in self.layers_cfg[start:end]:
            layer = build_layer(cfg)
            self.layers.append(layer)
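
    # Note (illustrative assumption, not from the original file): ``build_layer``
    # is ColossalAI's registry-based builder, so each cfg dict conventionally
    # carries a ``type`` key naming a registered layer class plus that layer's
    # constructor kwargs, e.g.
    #     {'type': 'Linear', 'in_features': 256, 'out_features': 128}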

    @abstractmethod
    def init_weights(self):
        """Initialize model parameters; must be implemented by subclasses."""
        pass

    def state_dict_for_save_checkpoint(self, destination=None, prefix='',
                                       keep_vars=False):
        """Use this function to override the state dict for
        saving checkpoints."""
        return self.state_dict(destination=destination, prefix=prefix,
                               keep_vars=keep_vars)
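

# ---------------------------------------------------------------------------
# Usage sketch (illustrative addition, not part of the original file).
# ``MLPFromConfig`` and its layer config dicts are hypothetical: they assume
# registry-style configs that ``build_layer`` can resolve, which may differ
# from the layer types actually registered in your ColossalAI version.
# ---------------------------------------------------------------------------
class MLPFromConfig(ModelFromConfig):
    """Minimal example subclass building a two-layer MLP from configs."""

    def __init__(self):
        super().__init__()
        # Hypothetical configs: each dict names a layer type and supplies
        # its constructor kwargs.
        self.layers_cfg = [
            {'type': 'Linear', 'in_features': 256, 'out_features': 128},
            {'type': 'Linear', 'in_features': 128, 'out_features': 10},
        ]
        self.build_from_cfg()  # instantiate the configs into self.layers
        self.init_weights()

    def init_weights(self):
        # Simple Xavier initialization for any Linear layers that were built.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x):
        # Run the configured layers sequentially.
        for layer in self.layers:
            x = layer(x)
        return x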