mirror of https://github.com/hpcaitech/ColossalAI
[NFC] polish colossalai/nn/layer/parallel_3d/layers.py code style (#966)
parent
955463e542
commit
fb5bc6cb28
|
@ -53,8 +53,8 @@ class LayerNorm3D(ParallelLayer):
|
||||||
self.weight = Parameter(
|
self.weight = Parameter(
|
||||||
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
|
torch.ones(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
|
||||||
if bias:
|
if bias:
|
||||||
self.bias = Parameter(torch.zeros(self.normalized_shape_per_partition,
|
self.bias = Parameter(
|
||||||
device=get_current_device(), dtype=dtype))
|
torch.zeros(self.normalized_shape_per_partition, device=get_current_device(), dtype=dtype))
|
||||||
else:
|
else:
|
||||||
self.bias = None
|
self.bias = None
|
||||||
self.variance_epsilon = eps
|
self.variance_epsilon = eps
|
||||||
|
@ -854,7 +854,7 @@ class PatchEmbedding3D(ParallelLayer):
|
||||||
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
input_ = split_tensor_3d(input_, 0, self.input_parallel_mode)
|
||||||
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
|
||||||
if self.flatten:
|
if self.flatten:
|
||||||
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
output = output.flatten(2).transpose(1, 2) # BCHW -> BNC
|
||||||
|
|
||||||
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
cls_token = self.cls_token.expand(output.shape[0], -1, -1)
|
||||||
output = torch.cat((cls_token, output), dim=1)
|
output = torch.cat((cls_token, output), dim=1)
|
||||||
|
|
Loading…
Reference in New Issue