# Mirror of https://github.com/hpcaitech/ColossalAI (Python, 37 lines, 1.1 KiB)
import abc
|
|
|
|
import torch.nn as nn
|
|
|
|
|
|
class BaseEmbeddingBag(abc.ABC, nn.Module):
    """Shared constructor state for embedding-bag style modules.

    This class only validates ``padding_idx`` and records the remaining
    configuration as attributes; it creates no parameters and defines no
    forward pass itself. The argument names appear to mirror
    ``torch.nn.EmbeddingBag`` -- presumably subclasses consume the stored
    attributes (NOTE(review): confirm against the subclasses).

    Although it derives from ``abc.ABC``, no abstract methods are declared
    in this class, so it can be instantiated directly.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        mode="mean",
        include_last_offset=False,
    ):
        """Validate ``padding_idx`` and store all configuration on ``self``.

        Args:
            num_embeddings: Stored as ``self.num_embeddings``; also bounds the
                accepted range of ``padding_idx``.
            embedding_dim: Stored unchanged as ``self.embedding_dim``.
            padding_idx: ``None`` or an int. A negative value is interpreted
                Python-style (counted from the end) and normalized to
                ``num_embeddings + padding_idx`` before being stored.
            max_norm: Stored unchanged as ``self.max_norm``.
            norm_type: Stored unchanged as ``self.norm_type``.
            scale_grad_by_freq: Stored unchanged as ``self.scale_grad_by_freq``.
            sparse: Stored unchanged as ``self.sparse``.
            mode: Stored unchanged as ``self.mode`` (not validated here).
            include_last_offset: Stored unchanged as
                ``self.include_last_offset``.

        Raises:
            AssertionError: If ``padding_idx`` is positive and
                ``>= num_embeddings``, or negative and ``< -num_embeddings``.
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        if padding_idx is not None:
            # Explicit raises instead of ``assert`` statements so the range
            # check still runs under ``python -O``; AssertionError (and the
            # exact message) are kept so existing callers catching it still work.
            if padding_idx > 0:
                if padding_idx >= self.num_embeddings:
                    raise AssertionError("Padding_idx must be within num_embeddings")
            elif padding_idx < 0:
                if padding_idx < -self.num_embeddings:
                    raise AssertionError("Padding_idx must be within num_embeddings")
                # Normalize a negative index to its non-negative equivalent.
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx

        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Specific to embedding bag
        self.mode = mode
        self.include_last_offset = include_last_offset
|