[FAW] export FAW in _ops (#1438)

Jiarui Fang 2022-08-11 13:43:24 +08:00 committed by GitHub
parent 9056677b13
commit 30b4dd17c0
9 changed files with 13 additions and 14 deletions


@@ -5,4 +5,4 @@ from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
 from .embedding_bag import colo_embedding_bag
-from .view import colo_view
+from .view import colo_view


@@ -3,13 +3,10 @@ from .linear import ColoLinear
 from .embedding import ColoEmbedding
 from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
+from .cache_embedding import FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag, CachedParamMgr, LimitBuffIndexCopyer
 __all__ = [
-    'ColoModule',
-    'register_colo_module',
-    'is_colo_module',
-    'get_colo_module',
-    'init_colo_module',
-    'check_colo_module',
-    'ColoLinear',
-    'ColoEmbedding',
+    'ColoModule', 'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
+    'ColoLinear', 'ColoEmbedding', 'FreqAwareEmbeddingBag', 'ParallelFreqAwareEmbeddingBag', 'CachedParamMgr',
+    'LimitBuffIndexCopyer'
 ]
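
Note: with the `from .cache_embedding import ...` line added above, the FAW (freq-aware) classes become part of this package's public API via `__all__`. A minimal sketch of the resulting import surface, assuming this `__init__.py` belongs to `colossalai.nn.parallel.layers` (the file path is not shown in this extract, but the updated test below imports from that module):

```python
# Sketch of the new public import path for the FAW classes.
# Assumes the __init__.py above is colossalai/nn/parallel/layers/__init__.py,
# matching the import used by the updated test further down.
from colossalai.nn.parallel.layers import (
    CachedParamMgr,
    FreqAwareEmbeddingBag,
    LimitBuffIndexCopyer,
    ParallelFreqAwareEmbeddingBag,
)
```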


@@ -5,9 +5,9 @@ from typing import List, Optional, Iterator, Tuple
 from .base_embedding import BaseEmbeddingBag
 from .cache_mgr import CachedParamMgr
 from torch.nn.parameter import Parameter
-from .._utils import dual_all_to_all
-from colossalai.tensor import ColoParameter, ShardSpec, ComputeSpec, ComputePattern, ProcessGroup, ColoTensorSpec
+from colossalai.nn._ops._utils import dual_all_to_all
+from colossalai.tensor import ColoParameter, ShardSpec, ComputePattern, ProcessGroup, ColoTensorSpec
 def get_partition(embedding_dim, rank, world_size) -> Tuple[int, int, bool]:


@@ -1,15 +1,17 @@
 import pytest
 from functools import partial
-import torch
-import torch.multiprocessing as mp
 import numpy as np
 import random
+import torch
+import torch.multiprocessing as mp
 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
-from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
+from colossalai.nn.parallel.layers import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
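
Note: the test now imports the FAW classes from `colossalai.nn.parallel.layers` instead of `colossalai.nn._ops.cache_embedding`. For context on the shapes the test constants imply, here is a standalone sketch using plain `torch.nn.EmbeddingBag` with the same `NUM_EMBED`, `EMBED_DIM`, and `BATCH_SIZE`; it assumes (not shown in this diff) that `FreqAwareEmbeddingBag` accepts a compatible `(indices, offsets)` forward, so only stock PyTorch is needed to run it:

```python
# Shape sketch with the test's constants, using plain PyTorch only.
# FreqAwareEmbeddingBag is assumed to take a compatible (indices, offsets)
# forward; that assumption is not verified by the hunk above.
import torch

NUM_EMBED, EMBED_DIM = 10, 8
BATCH_SIZE = 8

bag = torch.nn.EmbeddingBag(NUM_EMBED, EMBED_DIM, mode='mean')
indices = torch.randint(0, NUM_EMBED, (BATCH_SIZE * 2,))   # two ids per sample
offsets = torch.arange(0, BATCH_SIZE * 2, 2)               # bag boundaries
out = bag(indices, offsets)
assert out.shape == (BATCH_SIZE, EMBED_DIM)
```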