[FAW] export FAW in _ops (#1438)

pull/1441/head
Jiarui Fang 2022-08-11 13:43:24 +08:00 committed by GitHub
parent 9056677b13
commit 30b4dd17c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 13 additions and 14 deletions

View File

@ -5,4 +5,4 @@ from .loss import colo_cross_entropy
from .embedding import colo_embedding from .embedding import colo_embedding
from .addmm import colo_addmm from .addmm import colo_addmm
from .embedding_bag import colo_embedding_bag from .embedding_bag import colo_embedding_bag
from .view import colo_view from .view import colo_view

View File

@ -3,13 +3,10 @@ from .linear import ColoLinear
from .embedding import ColoEmbedding from .embedding import ColoEmbedding
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
__all__ = [ __all__ = [
'ColoModule', 'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
'register_colo_module', 'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
'is_colo_module', 'LimitBuffIndexCopyer'
'get_colo_module',
'init_colo_module',
'check_colo_module',
'ColoLinear',
'ColoEmbedding',
] ]

View File

@ -5,9 +5,9 @@ from typing import List, Optional, Iterator, Tuple
from .base_embedding import BaseEmbeddingBag from .base_embedding import BaseEmbeddingBag
from .cache_mgr import CachedParamMgr from .cache_mgr import CachedParamMgr
from torch.nn.parameter import Parameter from torch.nn.parameter import Parameter
from .._utils import dual_all_to_all from colossalai.nn._ops._utils import dual_all_to_all
from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]: def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:

View File

@ -1,15 +1,17 @@
import pytest import pytest
from functools import partial from functools import partial
import torch
import torch.multiprocessing as mp
import numpy as np import numpy as np
import random import random
import torch
import torch.multiprocessing as mp
import colossalai import colossalai
from colossalai.utils import free_port from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use from colossalai.testing import rerun_if_address_is_in_use
from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
NUM_EMBED, EMBED_DIM = 10, 8 NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8 BATCH_SIZE = 8