mirror of https://github.com/hpcaitech/ColossalAI
[FAW] export FAW in _ops (#1438)
parent
9056677b13
commit
30b4dd17c0
|
@ -3,13 +3,10 @@ from .linear import ColoLinear
|
||||||
from .embedding import ColoEmbedding
|
from .embedding import ColoEmbedding
|
||||||
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
||||||
|
|
||||||
|
from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'ColoModule',
|
'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
||||||
'register_colo_module',
|
'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
|
||||||
'is_colo_module',
|
'LimitBuffIndexCopyer'
|
||||||
'get_colo_module',
|
|
||||||
'init_colo_module',
|
|
||||||
'check_colo_module',
|
|
||||||
'ColoLinear',
|
|
||||||
'ColoEmbedding',
|
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,9 +5,9 @@ from typing import List, Optional, Iterator, Tuple
|
||||||
from .base_embedding import BaseEmbeddingBag
|
from .base_embedding import BaseEmbeddingBag
|
||||||
from .cache_mgr import CachedParamMgr
|
from .cache_mgr import CachedParamMgr
|
||||||
from torch.nn.parameter import Parameter
|
from torch.nn.parameter import Parameter
|
||||||
from .._utils import dual_all_to_all
|
from colossalai.nn._ops._utils import dual_all_to_all
|
||||||
|
|
||||||
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
|
||||||
|
|
||||||
|
|
||||||
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:
|
|
@ -1,15 +1,17 @@
|
||||||
import pytest
|
import pytest
|
||||||
from functools import partial
|
from functools import partial
|
||||||
import torch
|
|
||||||
import torch.multiprocessing as mp
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
import colossalai
|
import colossalai
|
||||||
from colossalai.utils import free_port
|
from colossalai.utils import free_port
|
||||||
from colossalai.testing import rerun_if_address_is_in_use
|
from colossalai.testing import rerun_if_address_is_in_use
|
||||||
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
|
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
|
||||||
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
|
||||||
|
|
||||||
NUM_EMBED, EMBED_DIM = 10, 8
|
NUM_EMBED, EMBED_DIM = 10, 8
|
||||||
BATCH_SIZE = 8
|
BATCH_SIZE = 8
|
Loading…
Reference in New Issue