# Mirror of https://github.com/hpcaitech/ColossalAI
# 108 lines, 3.4 KiB, Python
|
import math
|
||
|
from typing import Callable, Optional
|
||
|
|
||
|
from colossalai.utils import get_current_device
|
||
|
from torch import dtype, nn
|
||
|
|
||
|
from ... import init as init
|
||
|
from ..parallel_1d import *
|
||
|
from ..parallel_2d import *
|
||
|
from ..parallel_2p5d import *
|
||
|
from ..parallel_3d import *
|
||
|
from ..utils import get_tensor_parallel_mode
|
||
|
from ..vanilla import *
|
||
|
|
||
|
# Registry of embedding implementations, keyed by the tensor-parallel mode
# string. There is no 'None' entry: the non-parallel case is handled by the
# caller before this table is consulted.
_parallel_embedding = {
    '1d': Embedding1D,
    '2d': Embedding2D,
    '2.5d': Embedding2p5D,
    '3d': Embedding3D,
}
|
||
|
|
||
|
# Registry of patch-embedding implementations, keyed by the tensor-parallel
# mode string. Both the non-parallel ('None') and the '1d' modes map to the
# vanilla implementation.
_parallel_patchembedding = {
    'None': VanillaPatchEmbedding,
    '1d': VanillaPatchEmbedding,
    '2d': PatchEmbedding2D,
    '2.5d': PatchEmbedding2p5D,
    '3d': PatchEmbedding3D,
}
|
||
|
|
||
|
|
||
|
class Embedding(nn.Module):
    """Embedding layer that dispatches on the configured tensor-parallel mode.

    When ``get_tensor_parallel_mode()`` returns ``'None'`` a plain
    ``torch.nn.Embedding`` is created on ``get_current_device()`` and its
    weight is re-initialized with ``weight_initializer``. For every other mode
    the class registered in ``_parallel_embedding`` is instantiated instead and
    receives ``weight_initializer`` as a keyword argument.

    Args:
        num_embeddings: Number of rows of the embedding table.
        embedding_dim: Size of each embedding vector.
        padding_idx: Optional index of the padding entry. In the non-parallel
            case the corresponding row is kept at zero after initialization.
        dtype: Data type forwarded to the wrapped module.
        weight_initializer: Callable used to initialize the weight. In the
            non-parallel case it is invoked as
            ``weight_initializer(weight, fan_in=num_embeddings,
            fan_out=embedding_dim)``.
        *args: Extra positional arguments forwarded to the wrapped module.
        **kwargs: Extra keyword arguments forwarded to the wrapped module.
    """

    def __init__(self,
                 num_embeddings: int,
                 embedding_dim: int,
                 padding_idx: Optional[int] = None,
                 dtype: dtype = None,
                 weight_initializer: Callable = init.normal_(),
                 *args,
                 **kwargs) -> None:
        super().__init__()
        tensor_parallel = get_tensor_parallel_mode()
        if tensor_parallel == 'None':
            # Local import: the module only imports ``dtype`` and ``nn`` from torch.
            import torch

            # ``padding_idx`` is passed positionally so that any extra ``*args``
            # continue with the parameters that follow it in ``nn.Embedding``'s
            # signature instead of colliding with the ``padding_idx`` keyword.
            self.embed = nn.Embedding(num_embeddings,
                                      embedding_dim,
                                      padding_idx,
                                      *args,
                                      device=get_current_device(),
                                      dtype=dtype,
                                      **kwargs)
            weight_initializer(self.embed.weight, fan_in=num_embeddings, fan_out=embedding_dim)

            # The initializer above overwrites the whole table, including the
            # padding row that ``nn.Embedding`` zeroed at construction. That row
            # never receives gradient, so restore the zeros explicitly.
            # ``self.embed.padding_idx`` is used because nn.Embedding has
            # already normalized negative indices.
            if self.embed.padding_idx is not None:
                with torch.no_grad():
                    self.embed.weight[self.embed.padding_idx].fill_(0)
        else:
            # NOTE(review): extra positional ``*args`` are placed right after
            # ``embedding_dim``; if the parallel classes take ``padding_idx`` as
            # their third positional parameter this would collide -- confirm
            # against their signatures before relying on ``*args`` here.
            self.embed = _parallel_embedding[tensor_parallel](
                num_embeddings,
                embedding_dim,
                padding_idx=padding_idx,
                dtype=dtype,
                weight_initializer=weight_initializer,
                *args,
                **kwargs,
            )

    @property
    def weight(self):
        """Weight of the wrapped embedding module."""
        return self.embed.weight

    def forward(self, *args):
        """Forward all positional arguments to the wrapped embedding module."""
        return self.embed(*args)
|
||
|
|
||
|
|
||
|
class PatchEmbedding(nn.Module):
    """Patch-embedding wrapper selected by the current tensor-parallel mode.

    The implementation class is looked up in ``_parallel_patchembedding``
    using the value returned by ``get_tensor_parallel_mode()``; every
    constructor argument is handed over to it unchanged. The wrapped module's
    ``weight``, ``bias``, ``pos_embed`` and ``cls_token`` are re-exposed as
    read-only properties.

    Args:
        img_size: Forwarded as the first positional argument.
        patch_size: Forwarded as the second positional argument.
        in_chans: Forwarded as the third positional argument.
        embed_size: Forwarded as the fourth positional argument.
        dtype: Forwarded as the ``dtype`` keyword.
        flatten: Forwarded as the ``flatten`` keyword.
        weight_initializer: Forwarded as the ``weight_initializer`` keyword.
        bias_initializer: Forwarded as the ``bias_initializer`` keyword.
        position_embed_initializer: Forwarded as the
            ``position_embed_initializer`` keyword.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_size: int,
                 dtype: dtype = None,
                 flatten: bool = True,
                 weight_initializer: Callable = init.kaiming_uniform_(a=math.sqrt(5)),
                 bias_initializer: Callable = init.xavier_uniform_(a=1, scale=1),
                 position_embed_initializer: Callable = init.zeros_()) -> None:
        super().__init__()

        # Resolve the implementation registered for the active parallel mode.
        embed_cls = _parallel_patchembedding[get_tensor_parallel_mode()]

        # Group the three initializers so the constructor call stays compact.
        initializers = {
            'weight_initializer': weight_initializer,
            'bias_initializer': bias_initializer,
            'position_embed_initializer': position_embed_initializer,
        }
        self.embed = embed_cls(img_size,
                               patch_size,
                               in_chans,
                               embed_size,
                               dtype=dtype,
                               flatten=flatten,
                               **initializers)

    @property
    def weight(self):
        """``weight`` attribute of the wrapped module."""
        return self.embed.weight

    @property
    def bias(self):
        """``bias`` attribute of the wrapped module."""
        return self.embed.bias

    @property
    def pos_embed(self):
        """``pos_embed`` attribute of the wrapped module."""
        return self.embed.pos_embed

    @property
    def cls_token(self):
        """``cls_token`` attribute of the wrapped module."""
        return self.embed.cls_token

    def forward(self, *args):
        """Call the wrapped module with the given positional arguments."""
        return self.embed(*args)
|