import abc

import torch.nn as nn


class BaseEmbeddingBag(abc.ABC, nn.Module):
    """Abstract base holding configuration shared by embedding-bag modules.

    The constructor only records its arguments as attributes (after
    validating and normalizing ``padding_idx``). It creates no weights and
    defines no ``forward``; that is left to subclasses. Argument names match
    the corresponding ``torch.nn.EmbeddingBag`` constructor arguments.

    Args:
        num_embeddings: Number of rows in the embedding table; also the
            bound used to validate ``padding_idx``.
        embedding_dim: Size of each embedding vector (stored as-is).
        padding_idx: Optional padding row index. A negative value counts from
            the end and is normalized to ``num_embeddings + padding_idx``.
            ``None`` and ``0`` are stored unchanged.
        max_norm: Stored as-is.
        norm_type: Stored as-is.
        scale_grad_by_freq: Stored as-is.
        sparse: Stored as-is.
        mode: Bag reduction mode, stored as-is (not validated here).
        include_last_offset: Stored as-is.

    Raises:
        AssertionError: If a positive ``padding_idx`` is not less than
            ``num_embeddings``, or a negative one is less than
            ``-num_embeddings``.
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.,
        scale_grad_by_freq=False,
        sparse=False,
        mode='mean',
        include_last_offset=False,
    ):
        # nn.Module.__init__ must run before any attribute is assigned.
        super(BaseEmbeddingBag, self).__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        if padding_idx is not None:
            # Explicit raises instead of ``assert`` so the range check still
            # runs under ``python -O``; AssertionError is kept so callers
            # that already catch it see no change.
            if padding_idx > 0:
                if padding_idx >= self.num_embeddings:
                    raise AssertionError('Padding_idx must be within num_embeddings')
            elif padding_idx < 0:
                if padding_idx < -self.num_embeddings:
                    raise AssertionError('Padding_idx must be within num_embeddings')
                # Convert a from-the-end index into its non-negative form.
                padding_idx = self.num_embeddings + padding_idx
        self.padding_idx = padding_idx
        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse
        # Specific to embedding bag
        self.mode = mode
        self.include_last_offset = include_last_offset