mirror of https://github.com/hpcaitech/ColossalAI
[FAW] export FAW in _ops (#1438)
parent
9056677b13
commit
30b4dd17c0
|
@ -5,4 +5,4 @@ from .loss import colo_cross_entropy
|
|||
from .embedding import colo_embedding
|
||||
from .addmm import colo_addmm
|
||||
from .embedding_bag import colo_embedding_bag
|
||||
from .view import colo_view
|
||||
from .view import colo_view
|
|
@ -3,13 +3,10 @@ from .linear import ColoLinear
|
|||
from .embedding import ColoEmbedding
|
||||
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
||||
|
||||
from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
|
||||
|
||||
__all__ = [
|
||||
'ColoModule',
|
||||
'register_colo_module',
|
||||
'is_colo_module',
|
||||
'get_colo_module',
|
||||
'init_colo_module',
|
||||
'check_colo_module',
|
||||
'ColoLinear',
|
||||
'ColoEmbedding',
|
||||
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
||||
'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
|
||||
'LimitBuffIndexCopyer'
|
||||
]
|
||||
|
|
|
@ -5,9 +5,9 @@ from typing import List, Optional, Iterator, Tuple
|
|||
from .base_embedding import BaseEmbeddingBag
|
||||
from .cache_mgr import CachedParamMgr
|
||||
from torch.nn.parameter import Parameter
|
||||
from .._utils import dual_all_to_all
|
||||
from colossalai.nn._ops._utils import dual_all_to_all
|
||||
|
||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
||||
|
||||
|
||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
|
@ -1,15 +1,17 @@
|
|||
import pytest
|
||||
from functools import partial
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
import colossalai
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
|
||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||
|
||||
NUM_EMBED, EMBED_DIM = 10, 8
|
||||
BATCH_SIZE = 8
|
Loading…
Reference in New Issue