mirror of https://github.com/hpcaitech/ColossalAI
64 lines
2.1 KiB
Python
64 lines
2.1 KiB
Python
#!/usr/bin/env python
|
|
# -*- encoding: utf-8 -*-
|
|
|
|
import torch.nn as nn
|
|
|
|
from colossalai.registry import LAYERS
|
|
from .conv import conv1x1
|
|
|
|
|
|
@LAYERS.register_module
class ResLayer(nn.Module):
    """A sequential stack of ``blocks`` residual blocks.

    The block class and the normalization-layer class are looked up by name
    in the ``LAYERS`` registry. The block class must expose an ``expansion``
    attribute and accept the arguments passed in :meth:`_make_layer`.

    NOTE(review): the construction logic appears to mirror torchvision's
    ``ResNet._make_layer`` -- confirm before relying on that equivalence.

    Args:
        block_type: registry name of the residual block class.
        norm_layer_type: registry name of the normalization layer class.
        inplanes: channel count passed to the first block as its input width.
        planes: base channel count; each block is built with this value and
            the stack's output width is ``planes * block.expansion``.
        blocks: number of residual blocks in the stack.
        groups: forwarded to every block as ``groups``.
        base_width: forwarded to every block as ``base_width``.
        stride: stride given to the first block (and to the shortcut
            projection); forced to 1 when ``dilate`` is true.
        dilation: starting dilation; the first block always receives this
            value.
        dilate: when true, ``dilation`` is multiplied by ``stride`` for the
            blocks after the first one and ``stride`` is reset to 1.
    """

    def __init__(self,
                 block_type: str,
                 norm_layer_type: str,
                 inplanes: int,
                 planes: int,
                 blocks: int,
                 groups: int,
                 base_width: int,
                 stride: int = 1,
                 dilation: int = 1,
                 dilate: bool = False,
                 ):
        super().__init__()

        # Resolve the registry names into the classes to instantiate.
        self.block = LAYERS.get_module(block_type)
        self.norm_layer = LAYERS.get_module(norm_layer_type)

        # Channel / depth configuration.
        self.inplanes = inplanes
        self.planes = planes
        self.blocks = blocks

        # Per-block convolution configuration.
        self.groups = groups
        self.base_width = base_width
        self.stride = stride
        self.dilation = dilation
        self.dilate = dilate

        # Must come last: _make_layer reads (and updates) the attributes above.
        self.layer = self._make_layer()

    def _make_layer(self):
        """Build and return the ``nn.Sequential`` holding all residual blocks.

        Side effects: updates ``self.dilation`` and ``self.stride`` when
        ``self.dilate`` is set, and sets ``self.inplanes`` to the stack's
        output width (``planes * block.expansion``).
        """
        block_cls = self.block
        norm_cls = self.norm_layer

        # The first block keeps the dilation that was in effect on entry.
        first_dilation = self.dilation
        if self.dilate:
            # Trade the stride for dilation: later blocks dilate instead.
            self.dilation *= self.stride
            self.stride = 1

        out_width = self.planes * block_cls.expansion

        # A projection shortcut is only built when the first block changes
        # the stride or the channel count.
        if self.stride == 1 and self.inplanes == out_width:
            shortcut = None
        else:
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_width, self.stride),
                norm_cls(out_width),
            )

        head = block_cls(self.inplanes, self.planes, self.stride, shortcut,
                         self.groups, self.base_width, first_dilation, norm_cls)

        # Every block after the first consumes the expanded width.
        self.inplanes = out_width
        tail = [
            block_cls(self.inplanes, self.planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_cls)
            for _ in range(1, self.blocks)
        ]

        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Apply the stacked residual blocks to ``x`` and return the result."""
        return self.layer(x)
|