ColossalAI/colossalai/legacy/nn/parallel/layers/cache_embedding/base_embedding.py

import abc

import torch.nn as nn


class BaseEmbeddingBag(abc.ABC, nn.Module):
    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.0,
        scale_grad_by_freq=False,
        sparse=False,
        mode="mean",
        include_last_offset=False,
    ):
        super(BaseEmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            # Validate padding_idx and normalize a negative index to its
            # positive equivalent, mirroring torch.nn.EmbeddingBag.
            if padding_idx > 0:
                assert padding_idx < self.num_embeddings, "Padding_idx must be within num_embeddings"
            elif padding_idx < 0:
                assert padding_idx >= -self.num_embeddings, "Padding_idx must be within num_embeddings"
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Specific to embedding bag
        self.mode = mode
        self.include_last_offset = include_last_offset
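

For context, here is a minimal sketch of how a concrete subclass might build on this base class. The NaiveEmbeddingBag name, its weight initialization, and the direct call to torch.nn.functional.embedding_bag are illustrative assumptions and not part of this file; a cached implementation would instead look up rows from a software-managed cache before the bag reduction.

import torch
import torch.nn.functional as F


class NaiveEmbeddingBag(BaseEmbeddingBag):
    # Hypothetical subclass for illustration only: keeps the full embedding
    # table in device memory and forwards to F.embedding_bag using the
    # hyperparameters stored by BaseEmbeddingBag.__init__.
    def __init__(self, num_embeddings, embedding_dim, **kwargs):
        super().__init__(num_embeddings, embedding_dim, **kwargs)
        self.weight = nn.Parameter(torch.empty(num_embeddings, embedding_dim))
        nn.init.normal_(self.weight)

    def forward(self, input, offsets=None, per_sample_weights=None):
        return F.embedding_bag(
            input,
            self.weight,
            offsets=offsets,
            max_norm=self.max_norm,
            norm_type=self.norm_type,
            scale_grad_by_freq=self.scale_grad_by_freq,
            mode=self.mode,
            sparse=self.sparse,
            per_sample_weights=per_sample_weights,
            include_last_offset=self.include_last_offset,
            padding_idx=self.padding_idx,
        )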