mirror of https://github.com/hpcaitech/ColossalAI

[tensorparallel] fixed tp layers (#1938)
parent cf68cc92ac
commit e52f9d9109
@@ -77,12 +77,11 @@ class Linear1D(ColossalaiModule):
                  weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                  bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1)):
         parallel_input = get_parallel_input()
-        if not parallel_input:
+        if not parallel_input and not gather_output:
             layer = Linear1D_Col(in_features,
                                  out_features,
                                  bias=bias,
                                  dtype=dtype,
-                                 gather_output=gather_output,
                                  skip_bias_add=skip_bias_add,
                                  weight_initializer=weight_initializer,
                                  bias_initializer=bias_initializer)
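For context on the condition change above: in a 1D column-parallel linear, the weight is split along the output dimension, so each rank computes only a slice of the activation unless the slices are explicitly gathered, which is what `gather_output` controls. The single-process sketch below illustrates that relationship; the shapes and the `world_size` variable are illustrative, and none of this is ColossalAI's API.

    import torch

    torch.manual_seed(0)
    world_size = 4                                     # pretend tensor-parallel group size
    batch, in_features, out_features = 2, 8, 12

    x = torch.randn(batch, in_features)
    weight = torch.randn(out_features, in_features)    # full weight, as nn.Linear stores it
    bias = torch.randn(out_features)

    # Column parallelism: split the output dimension across ranks.
    w_shards = weight.chunk(world_size, dim=0)
    b_shards = bias.chunk(world_size, dim=0)

    # Each "rank" produces only its slice of the output.
    partial_outputs = [x @ w.t() + b for w, b in zip(w_shards, b_shards)]

    # gather_output=True corresponds to concatenating the slices back together.
    gathered = torch.cat(partial_outputs, dim=-1)
    full = x @ weight.t() + bias
    print(torch.allclose(gathered, full))              # True: gathering recovers the unsharded output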
@@ -4,13 +4,15 @@
 from typing import Optional, Tuple

 import torch
-from colossalai.communication import (all_gather, all_reduce, broadcast, reduce, reduce_scatter)
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.core import global_context as gpc
 from torch import Tensor
 from torch.cuda.amp import custom_bwd, custom_fwd
-from ._utils import get_parallel_mode_from_env, push_async_grad
+
+from colossalai.communication import all_gather, all_reduce, broadcast, reduce, reduce_scatter
 from colossalai.constants import INPUT_GROUP_3D, WEIGHT_GROUP_3D
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.core import global_context as gpc
+
+from ._utils import get_parallel_mode_from_env, push_async_grad


 class _Linear3D(torch.autograd.Function):
@@ -44,18 +46,17 @@ class _Linear3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

         input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
         input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
         weight_grad, op = reduce_scatter(weight_grad, -1, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

         input_op.wait()

         return input_grad, weight_grad, None, None, None, None
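Setting the collectives (all_gather, reduce_scatter, push_async_grad) aside, the local arithmetic in this backward is the ordinary linear-layer gradient. A minimal single-device sketch checking the two matmuls above against autograd; the tensor shapes are illustrative.

    import torch

    torch.manual_seed(0)
    x = torch.randn(2, 3, 8, dtype=torch.float64, requires_grad=True)
    w = torch.randn(8, 12, dtype=torch.float64, requires_grad=True)

    y = torch.matmul(x, w)
    dy = torch.randn_like(y)
    y.backward(dy)

    # Same expressions as the backward above, minus the communication calls.
    with torch.no_grad():
        input_grad = torch.matmul(dy, w.transpose(0, 1))
        weight_grad = torch.matmul(x.reshape(-1, x.shape[-1]).transpose(0, 1),
                                   dy.reshape(-1, dy.shape[-1]))

    print(torch.allclose(input_grad, x.grad))    # True
    print(torch.allclose(weight_grad, w.grad))   # True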
@@ -129,25 +130,24 @@ class _Classifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            weight_grad = torch.matmul(
-                output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
+        weight_grad = torch.matmul(
+            output_grad.reshape(-1, output_grad.shape[-1]).transpose(0, 1), input_.reshape(-1, input_.shape[-1]))
         weight_grad = reduce(weight_grad, ctx.src_rank, ctx.input_parallel_mode)
         if gpc.get_local_rank(ctx.input_parallel_mode) == gpc.get_local_rank(ctx.output_parallel_mode):
             weight_grad, op = all_reduce(weight_grad, ctx.weight_parallel_mode, async_op=True)
             weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)
         else:
             weight_grad = None

         if ctx.use_bias:
             bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
             bias_grad = all_reduce(bias_grad, ctx.input_parallel_mode)
             bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
             bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         else:
             bias_grad = None

         input_grad = torch.matmul(output_grad, weight)

         return input_grad, weight_grad, bias_grad, None, None, None, None, None
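The bias gradient above sums the output gradient over every dimension except the last (the class dimension), via `tuple(range(len(shape))[:-1])`. A short illustration of that indexing idiom with made-up shapes:

    import torch

    dy = torch.randn(4, 16, 10)                  # e.g. (batch, sequence, num_classes)
    dims = tuple(range(len(dy.shape))[:-1])      # (0, 1): all leading dimensions
    bias_grad = torch.sum(dy, dim=dims)

    print(dims, bias_grad.shape)                 # (0, 1) torch.Size([10])
    print(torch.allclose(bias_grad, dy.sum(dim=(0, 1))))   # True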
@@ -224,25 +224,24 @@ class _VocabParallelClassifier3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         input_, weight = ctx.saved_tensors
-        with torch.no_grad():
-            output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)
+        output_grad = all_gather(output_grad, 0, ctx.output_parallel_mode)

         input_grad = torch.matmul(output_grad, weight.transpose(0, 1))
         input_grad, input_op = reduce_scatter(input_grad, 0, ctx.input_parallel_mode, async_op=True)

         weight_grad = torch.matmul(
             input_.reshape(-1, input_.shape[-1]).transpose(0, 1), output_grad.reshape(-1, output_grad.shape[-1]))
         weight_grad, op = reduce_scatter(weight_grad.transpose(0, 1), 0, ctx.weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

         if ctx.use_bias:
             bias_grad = torch.sum(output_grad, dim=tuple(range(len(output_grad.shape))[:-1]))
             bias_grad, op = all_reduce(bias_grad, ctx.weight_parallel_mode, async_op=True)
             bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         else:
             bias_grad = None

         input_op.wait()

         return input_grad, weight_grad, bias_grad, None, None, None, None, None
@@ -281,6 +280,30 @@ def vocab_parallel_classifier_3d(
     )


+@torch.jit.script
+def norm_forward(x, mean, sqr_mean, weight, bias, eps):
+    mu = x - mean
+    var = sqr_mean - mean**2
+    sigma = torch.sqrt(var + eps)
+    z = mu / sigma
+    output = weight * z + bias
+
+    return output, mu, sigma
+
+
+@torch.jit.script
+def norm_backward(grad, mu, sigma, weight):
+    # dbias, dweight = grad, grad * mu / sigma
+    dz = grad * weight
+    dmu = dz / sigma
+    dvar = dz * mu * (-0.5) * sigma**(-3)
+    dmean = -dmu
+    dvar = torch.sum(dvar, -1, keepdim=True)
+    dmean = torch.sum(dmean, -1, keepdim=True)
+
+    return dmu, dmean, dvar
+
+
 class _Layernorm3D(torch.autograd.Function):

     @staticmethod
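Once `mean` and `sqr_mean` cover the full hidden dimension (which the all-reduce in the layer guarantees), `norm_forward` reduces to standard layer normalization with a biased variance. A single-device sketch of that equivalence, re-implementing the same arithmetic inline instead of calling the jit-scripted helper; the sizes are illustrative.

    import torch
    import torch.nn.functional as F

    torch.manual_seed(0)
    n, eps = 32, 1e-5
    x = torch.randn(4, n, dtype=torch.float64)
    weight = torch.randn(n, dtype=torch.float64)
    bias = torch.randn(n, dtype=torch.float64)

    # Same arithmetic as norm_forward, with full (non-partitioned) statistics.
    mean = x.mean(dim=-1, keepdim=True)
    sqr_mean = (x ** 2).mean(dim=-1, keepdim=True)
    mu = x - mean
    sigma = torch.sqrt(sqr_mean - mean ** 2 + eps)
    out = weight * (mu / sigma) + bias

    ref = F.layer_norm(x, (n,), weight, bias, eps)
    print(torch.allclose(out, ref))   # True: E[x^2] - E[x]^2 is the biased variance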
@@ -294,27 +317,21 @@ class _Layernorm3D(torch.autograd.Function):
         bias_id: int,
         normalized_shape: int,
         eps: float,
-        input_parallel_mode: ParallelMode,
-        weight_parallel_mode: ParallelMode,
         output_parallel_mode: ParallelMode,
         input_x_weight_parallel_mode: ParallelMode,
     ) -> Tensor:
         ctx.weight_id = weight_id
         ctx.bias_id = bias_id

-        mean = all_reduce(torch.sum(input_, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        mu = input_ - mean
-        var = all_reduce(torch.sum(mu**2, dim=-1, keepdim=True), output_parallel_mode) / normalized_shape
-        sigma = torch.sqrt(var + eps)
+        sum_ = torch.sum(input_, dim=-1, keepdim=True)
+        sqr_sum = torch.sum(input_**2, dim=-1, keepdim=True)
+        mean, sqr_mean = all_reduce(torch.stack((sum_, sqr_sum)), output_parallel_mode) / normalized_shape
+
+        output, mu, sigma = norm_forward(input_, mean, sqr_mean, weight, bias, eps)

         ctx.save_for_backward(mu, sigma, weight)

-        z = mu / sigma
-        output = weight * z + bias
-
         ctx.normalized_shape = normalized_shape
-        ctx.input_parallel_mode = input_parallel_mode
-        ctx.weight_parallel_mode = weight_parallel_mode
         ctx.output_parallel_mode = output_parallel_mode
         ctx.input_x_weight_parallel_mode = input_x_weight_parallel_mode

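The rewritten forward issues a single collective because the two statistics are stacked into one tensor and the variance is recovered afterwards as E[x²] − E[x]². The sketch below simulates the all-reduce over hidden-dimension partitions with a plain Python sum over chunks; `depth` and the shapes are illustrative and no distributed setup is involved.

    import torch

    torch.manual_seed(0)
    depth, n = 4, 32                       # pretend the hidden dim is split across `depth` ranks
    x = torch.randn(6, n, dtype=torch.float64)
    chunks = x.chunk(depth, dim=-1)        # each "rank" holds n // depth features

    # Per-rank partial statistics, stacked so one collective can carry both.
    partials = [torch.stack((c.sum(-1, keepdim=True), (c ** 2).sum(-1, keepdim=True)))
                for c in chunks]
    mean, sqr_mean = sum(partials) / n     # sum(...) stands in for the all_reduce

    var = sqr_mean - mean ** 2
    print(torch.allclose(mean, x.mean(-1, keepdim=True)))                 # True
    print(torch.allclose(var, x.var(-1, unbiased=False, keepdim=True)))   # True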
@@ -324,23 +341,18 @@ class _Layernorm3D(torch.autograd.Function):
     @custom_bwd
     def backward(ctx, output_grad: Tensor) -> Tuple[Tensor, ...]:
         mu, sigma, weight = ctx.saved_tensors
-        with torch.no_grad():

         bias_grad, weight_grad = output_grad, output_grad * mu / sigma
         bias_grad = torch.sum(bias_grad, dim=tuple(range(len(bias_grad.shape))[:-1]))
         bias_grad, op = all_reduce(bias_grad, ctx.input_x_weight_parallel_mode, async_op=True)
         bias_grad = push_async_grad(op, bias_grad, ctx.bias_id)
         weight_grad = torch.sum(weight_grad, dim=tuple(range(len(weight_grad.shape))[:-1]))
         weight_grad, op = all_reduce(weight_grad, ctx.input_x_weight_parallel_mode, async_op=True)
         weight_grad = push_async_grad(op, weight_grad, ctx.weight_id)

-        dz = output_grad * weight
-        dvar = dz * mu * (-0.5) * sigma**(-3)
-        dvar = all_reduce(torch.sum(dvar, dim=-1, keepdim=True), ctx.output_parallel_mode)
-        dmean = dz * (-1 / sigma) + dvar * -2 * mu / ctx.normalized_shape
-        dmean = all_reduce(torch.sum(dmean, dim=-1, keepdim=True), ctx.output_parallel_mode)
-
-        input_grad = dz / sigma + dvar * 2 * mu / ctx.normalized_shape + dmean / ctx.normalized_shape
+        dmu, dmean, dvar = norm_backward(output_grad, mu, sigma, weight)
+        dvar, dmean = all_reduce(torch.stack((dvar, dmean)), ctx.output_parallel_mode)
+        input_grad = dmu + (dmean + 2 * dvar * mu) / ctx.normalized_shape

         return input_grad, weight_grad, bias_grad, None, None, None, None, None, None, None, None
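On a single device, where the all-reduce is an identity and `normalized_shape` is the whole hidden size, the input gradient assembled above, `dmu + (dmean + 2 * dvar * mu) / n`, should agree with autograd through the same normalization arithmetic. A minimal check under those assumptions, with illustrative sizes:

    import torch

    torch.manual_seed(0)
    n, eps = 32, 1e-5
    x = torch.randn(4, n, dtype=torch.float64, requires_grad=True)
    weight = torch.randn(n, dtype=torch.float64)
    bias = torch.randn(n, dtype=torch.float64)

    # Reference gradient via autograd through the same normalization.
    mean = x.mean(-1, keepdim=True)
    sqr_mean = (x ** 2).mean(-1, keepdim=True)
    mu = x - mean
    sigma = torch.sqrt(sqr_mean - mean ** 2 + eps)
    y = weight * (mu / sigma) + bias
    dy = torch.randn_like(y)
    y.backward(dy)

    # Manual gradient, mirroring norm_backward plus the combination above.
    with torch.no_grad():
        dz = dy * weight
        dmu = dz / sigma
        dvar = torch.sum(dz * mu * (-0.5) * sigma ** (-3), dim=-1, keepdim=True)
        dmean = torch.sum(-dmu, dim=-1, keepdim=True)
        input_grad = dmu + (dmean + 2 * dvar * mu) / n

    print(torch.allclose(input_grad, x.grad))   # True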
@@ -351,8 +363,6 @@ def layernorm_3d(
    bias: Tensor,
    normalized_shape: int,
    eps: float,
-   input_parallel_mode: ParallelMode,
-   weight_parallel_mode: ParallelMode,
    output_parallel_mode: ParallelMode,
    input_x_weight_parallel_mode: ParallelMode,
 ) -> Tensor:
@@ -368,9 +378,8 @@ def layernorm_3d(
            If a single integer is used, it is treated as a singleton list, and this module will
            normalize over the last dimension which is expected to be of that specific size.
        eps (float): a value added to the denominator for numerical stability
-       input_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input parallel mode.
-       weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): weight parallel mode.
        output_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): output parallel mode.
+       input_x_weight_parallel_mode (:class:`colossalai.context.parallel_mode.ParallelMode`): input x weight parallel mode.

    Note:
        The parallel_mode should be concluded in ``ParallelMode``. More details about ``ParallelMode`` could be found
@@ -384,8 +393,6 @@ def layernorm_3d(
        id(bias),
        normalized_shape,
        eps,
-       input_parallel_mode,
-       weight_parallel_mode,
        output_parallel_mode,
        input_x_weight_parallel_mode,
    )
@@ -5,6 +5,9 @@ from typing import Callable
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from torch import Tensor
+from torch.nn import Parameter
+
 from colossalai.communication import all_reduce, broadcast
 from colossalai.constants import INPUT_GROUP_3D, INPUT_X_WEIGHT_3D, OUTPUT_GROUP_3D, OUTPUT_X_WEIGHT_3D, WEIGHT_GROUP_3D
 from colossalai.context import ParallelMode, seed
@@ -13,16 +16,25 @@ from colossalai.global_variables import tensor_parallel_env as env
 from colossalai.nn import init as init
 from colossalai.nn.layer.base_layer import ParallelLayer
 from colossalai.registry import LAYERS
-from colossalai.utils.checkpointing import (broadcast_state_dict, gather_tensor_parallel_state_dict,
-                                            partition_tensor_parallel_state_dict)
+from colossalai.utils.checkpointing import (
+    broadcast_state_dict,
+    gather_tensor_parallel_state_dict,
+    partition_tensor_parallel_state_dict,
+)
 from colossalai.utils.cuda import get_current_device
-from torch import Tensor
-from torch.nn import Parameter

 from ..utils import divide, set_tensor_parallel_attribute_by_partition, to_2tuple
-from ._operation import (all_gather_tensor_3d, classifier_3d, vocab_parallel_classifier_3d, layernorm_3d, linear_3d,
-                         reduce_scatter_tensor_3d, split_tensor_3d, split_batch_3d)
-from ._utils import get_depth_from_env, get_parallel_mode_from_env, swap_in_out_group, register_async_grad_hook
+from ._operation import (
+    all_gather_tensor_3d,
+    classifier_3d,
+    layernorm_3d,
+    linear_3d,
+    reduce_scatter_tensor_3d,
+    split_batch_3d,
+    split_tensor_3d,
+    vocab_parallel_classifier_3d,
+)
+from ._utils import get_depth_from_env, get_parallel_mode_from_env, register_async_grad_hook, swap_in_out_group


 @LAYERS.register_module
|
@ -144,8 +156,6 @@ class LayerNorm3D(ParallelLayer):
|
||||||
self.bias,
|
self.bias,
|
||||||
self.normalized_shape,
|
self.normalized_shape,
|
||||||
self.variance_epsilon,
|
self.variance_epsilon,
|
||||||
self.input_parallel_mode,
|
|
||||||
self.weight_parallel_mode,
|
|
||||||
self.output_parallel_mode,
|
self.output_parallel_mode,
|
||||||
self.input_x_weight_parallel_mode,
|
self.input_x_weight_parallel_mode,
|
||||||
)
|
)
|
||||||
|
@@ -900,7 +910,7 @@ class PatchEmbedding3D(ParallelLayer):
                                  weight_parallel_mode=self.weight_parallel_mode)
         output = F.conv2d(input_, self.weight, self.bias, stride=self.patch_size)
         if self.flatten:
             output = output.flatten(2).transpose(1, 2)    # BCHW -> BNC

         cls_token = self.cls_token.expand(output.shape[0], -1, -1)
         output = torch.cat((cls_token, output), dim=1)
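The patch-embedding forward in this last hunk follows the usual ViT recipe: a strided convolution cuts the image into patch embeddings, the feature map is flattened from BCHW to BNC, and a class token is prepended. A minimal non-parallel sketch of the same flow; the sizes and names are illustrative, not the layer's actual configuration.

    import torch
    import torch.nn as nn

    img_size, patch_size, in_chans, embed_dim = 224, 16, 3, 96   # illustrative sizes

    proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
    cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    x = torch.randn(2, in_chans, img_size, img_size)
    out = proj(x)                            # (B, C, H/ps, W/ps)
    out = out.flatten(2).transpose(1, 2)     # BCHW -> BNC, N = (224 // 16) ** 2 = 196 patches
    out = torch.cat((cls_token.expand(out.shape[0], -1, -1), out), dim=1)
    print(out.shape)                         # torch.Size([2, 197, 96])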