Mirror of https://github.com/hpcaitech/ColossalAI — 37 lines, 1.1 KiB, Python
|
import abc
|
||
|
import torch.nn as nn
|
||
|
|
||
|
|
||
|
class BaseEmbeddingBag(abc.ABC, nn.Module):
    """Abstract base that validates and records embedding-bag configuration.

    The constructor only stores its arguments as attributes. This class
    creates no parameters or buffers and defines no forward pass. ``abc.ABC``
    is a base class, but no ``@abstractmethod`` is declared here, so the
    class can still be instantiated directly.

    NOTE(review): the argument names mirror ``torch.nn.EmbeddingBag``.
    Presumably subclasses give them the same meaning -- confirm against the
    subclasses.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.,
        scale_grad_by_freq=False,
        sparse=False,
        mode='mean',
        include_last_offset=False,
    ):
        """Validate ``padding_idx`` and store the configuration on ``self``.

        Args:
            num_embeddings: Stored as ``self.num_embeddings``. Also the bound
                used to validate ``padding_idx``.
            embedding_dim: Stored as ``self.embedding_dim``.
            padding_idx: Optional index. A negative value counts back from
                ``num_embeddings`` and is normalized to
                ``num_embeddings + padding_idx`` before being stored.
            max_norm: Stored unchanged.
            norm_type: Stored unchanged (default ``2.``).
            scale_grad_by_freq: Stored unchanged.
            sparse: Stored unchanged.
            mode: Stored unchanged (default ``'mean'``).
            include_last_offset: Stored unchanged.

        Raises:
            AssertionError: if ``padding_idx`` lies outside
                ``[-num_embeddings, num_embeddings)``. The check uses
                ``assert``, so it is skipped when Python runs with ``-O``.
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        out_of_range_msg = 'Padding_idx must be within num_embeddings'
        if padding_idx is not None:
            if padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, out_of_range_msg
                # Wrap a negative index around, Python-style: -1 -> last row.
                padding_idx = self.num_embeddings + padding_idx
            elif padding_idx > 0:
                assert padding_idx < self.num_embeddings, out_of_range_msg
            # padding_idx == 0 is accepted as-is, without a bounds check.
        self.padding_idx = padding_idx

        # General embedding options: stored as given, not validated here.
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Options specific to the embedding-bag variant, also stored as given.
        self.mode = mode
        self.include_last_offset = include_last_offset