#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import torch.nn as nn
from colossalai.registry import LAYERS

from .conv import conv1x1


@LAYERS.register_module
class ResLayer(nn.Module):
    """A stack of residual blocks assembled into one ``nn.Sequential``.

    The block class and the normalisation-layer class are both looked up by
    name in the ``LAYERS`` registry. The block class must expose an
    ``expansion`` attribute; the output channel count of the stack is
    ``planes * expansion``.

    Args:
        block_type: Registry name of the residual block class.
        norm_layer_type: Registry name of the normalisation layer class.
        inplanes: Channel count fed into the first block.
        planes: Base channel count handed to every block.
        blocks: Total number of blocks in the stack (the first block plus
            ``blocks - 1`` follow-up blocks).
        groups: Forwarded to every block as ``groups``.
        base_width: Forwarded to every block as ``base_width``.
        stride: Stride handed to the first block and to the shortcut conv.
        dilation: Dilation handed to the first block.
        dilate: When true, ``stride`` is folded into the dilation and the
            stride actually used is reset to 1.

    Note:
        Building the stack rewrites some attributes: after construction
        ``self.inplanes`` equals ``planes * expansion``, and if ``dilate`` is
        set, ``self.dilation`` has been multiplied by the original stride and
        ``self.stride`` is 1.
    """

    def __init__(self,
                 block_type: str,
                 norm_layer_type: str,
                 inplanes: int,
                 planes: int,
                 blocks: int,
                 groups: int,
                 base_width: int,
                 stride: int = 1,
                 dilation: int = 1,
                 dilate: bool = False,
                 ):
        super().__init__()

        # Resolve the building-block classes from the registry.
        self.block = LAYERS.get_module(block_type)
        self.norm_layer = LAYERS.get_module(norm_layer_type)

        # Channel / depth configuration.
        self.inplanes = inplanes
        self.planes = planes
        self.blocks = blocks

        # Per-block convolution configuration.
        self.groups = groups
        self.dilation = dilation
        self.base_width = base_width
        self.dilate = dilate
        self.stride = stride

        # Must come last: _make_layer reads (and updates) the fields above.
        self.layer = self._make_layer()

    def _make_layer(self):
        """Build the block stack described by the current attributes.

        Returns:
            An ``nn.Sequential`` holding ``self.blocks`` block instances
            (a single block when ``self.blocks`` is 1 or less).
        """
        block_cls = self.block
        norm_cls = self.norm_layer

        # The first block keeps the dilation that was in effect on entry.
        entry_dilation = self.dilation
        if self.dilate:
            # Swap the spatial stride for an equivalent increase in dilation.
            self.dilation *= self.stride
            self.stride = 1

        out_channels = self.planes * block_cls.expansion

        # A projection shortcut is needed whenever the first block changes
        # the spatial stride or the channel count.
        shortcut = None
        if not (self.stride == 1 and self.inplanes == out_channels):
            shortcut = nn.Sequential(
                conv1x1(self.inplanes, out_channels, self.stride),
                norm_cls(out_channels),
            )

        first_block = block_cls(self.inplanes, self.planes, self.stride,
                                shortcut, self.groups, self.base_width,
                                entry_dilation, norm_cls)

        # Every later block consumes the expanded channel count.
        self.inplanes = out_channels
        later_blocks = [
            block_cls(self.inplanes,
                      self.planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation,
                      norm_layer=norm_cls)
            for _ in range(1, self.blocks)
        ]

        return nn.Sequential(first_block, *later_blocks)

    def forward(self, x):
        """Run ``x`` through the block stack and return the result."""
        return self.layer(x)