mirror of https://github.com/hpcaitech/ColossalAI
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.1 KiB
37 lines
1.1 KiB
import abc
|
|
import torch.nn as nn
|
|
|
|
|
|
class BaseEmbeddingBag(abc.ABC, nn.Module):
    """Shared constructor state for embedding-bag style modules.

    The constructor arguments use the same names as ``torch.nn.EmbeddingBag``.
    This base class only validates ``padding_idx`` and stores every argument
    as an attribute. It allocates no parameters and defines no ``forward``;
    presumably concrete subclasses provide those (TODO confirm against the
    subclasses).
    """

    def __init__(
        self,
        num_embeddings,
        embedding_dim,
        padding_idx=None,
        max_norm=None,
        norm_type=2.,
        scale_grad_by_freq=False,
        sparse=False,
        mode='mean',
        include_last_offset=False,
    ):
        """Validate ``padding_idx`` and record the embedding-bag configuration.

        Args:
            num_embeddings: Stored as ``self.num_embeddings``; also the bound
                used to validate ``padding_idx``.
            embedding_dim: Stored as ``self.embedding_dim``.
            padding_idx: Optional index. When given, it must lie in
                ``[-num_embeddings, num_embeddings)``. A negative value is
                normalized to ``num_embeddings + padding_idx`` before storing.
            max_norm, norm_type, scale_grad_by_freq, sparse: Stored unchanged
                as attributes of the same name.
            mode, include_last_offset: Embedding-bag specific options, stored
                unchanged as attributes of the same name.

        Raises:
            AssertionError: If ``padding_idx`` is outside the allowed range.
        """
        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim

        if padding_idx is not None:
            # An explicit check instead of ``assert`` keeps the validation
            # active under ``python -O``. AssertionError is kept so callers
            # that relied on the original ``assert`` see the same type.
            if not -num_embeddings <= padding_idx < num_embeddings:
                raise AssertionError('Padding_idx must be within num_embeddings')
            if padding_idx < 0:
                # Store the equivalent non-negative index.
                padding_idx += num_embeddings
        self.padding_idx = padding_idx

        self.max_norm = max_norm
        self.norm_type = norm_type
        self.scale_grad_by_freq = scale_grad_by_freq
        self.sparse = sparse

        # Specific to embedding bag
        self.mode = mode
        self.include_last_offset = include_last_offset
|