From 10b3df65c8278bf403309d042711d20a7965a67d Mon Sep 17 00:00:00 2001
From: Jiarui Fang
Date: Wed, 10 Aug 2022 14:31:53 +0800
Subject: [PATCH] [FAW] move coloparam setting in test code. (#1429)

---
 .../cache_embedding/parallel_freq_aware_embedding.py | 3 ---
 tests/test_tensor/ops/test_cache_embedding.py        | 9 ++++++---
 2 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
index 4400d6fc2..083076532 100644
--- a/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
+++ b/colossalai/nn/_ops/cache_embedding/parallel_freq_aware_embedding.py
@@ -67,9 +67,6 @@ class ParallelFreqAwareEmbeddingBag(BaseEmbeddingBag):
             self.init_parameters()
         else:
             assert isinstance(_weight, ColoParameter), "initialized weight must in type of ColoParameter"
-            _weight.process_group = ProcessGroup(tp_degree=self.world_size)
-            _weight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[self.world_size]),
-                                    ComputeSpec(ComputePattern.TP1D))
             self._weight = _weight
 
     @property
diff --git a/tests/test_tensor/ops/test_cache_embedding.py b/tests/test_tensor/ops/test_cache_embedding.py
index ac5b3bc40..8471975df 100644
--- a/tests/test_tensor/ops/test_cache_embedding.py
+++ b/tests/test_tensor/ops/test_cache_embedding.py
@@ -8,11 +8,9 @@ import random
 import colossalai
 from colossalai.utils import free_port
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import ColoParameter
+from colossalai.tensor import ColoParameter, ProcessGroup, ShardSpec, ComputePattern, ComputeSpec
 from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag, ParallelFreqAwareEmbeddingBag
-from colossalai.nn._ops.cache_embedding import CachedParamMgr, FreqAwareEmbeddingBag
-
 
 NUM_EMBED, EMBED_DIM = 10, 8
 BATCH_SIZE = 8
 
@@ -161,6 +159,11 @@ def run_parallel_freq_aware_embed(rank, world_size):
     weight = torch.rand(num_embed, embed_dim)
     coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)
 
+    # initialize the tensor spec for the embedding weight parameter,
+    # which is a ColoParameter.
+    coloweight.process_group = ProcessGroup(tp_degree=world_size)
+    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]), ComputeSpec(ComputePattern.TP1D))
+
     model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                           include_last_offset=True,
                                                           freeze=False,
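
After this change, ParallelFreqAwareEmbeddingBag no longer overwrites the spec of a
user-supplied _weight, so callers that pass a pre-built ColoParameter must attach the
process group and shard spec themselves, as the updated test above does. A minimal
caller-side sketch, assuming the distributed environment has already been initialized
(e.g. via colossalai.launch inside a spawned worker); the sizes are illustrative, and
the remaining from_pretrained keyword arguments, truncated in the hunk above, are
omitted:

    import torch
    from colossalai.tensor import (ColoParameter, ProcessGroup, ShardSpec,
                                   ComputePattern, ComputeSpec)
    from colossalai.nn._ops.cache_embedding import ParallelFreqAwareEmbeddingBag

    # Illustrative sizes: the weight is sharded along its last dimension
    # across the tensor-parallel group, so embed_dim should be divisible
    # by world_size.
    world_size = 2
    num_embed, embed_dim = 10, 8

    weight = torch.rand(num_embed, embed_dim)
    coloweight = ColoParameter(weight.clone().detach().cpu(), requires_grad=False)

    # The caller now owns the spec setup that previously happened inside
    # ParallelFreqAwareEmbeddingBag: bind the parameter to a tensor-parallel
    # process group and shard its last dim with a TP1D compute pattern.
    coloweight.process_group = ProcessGroup(tp_degree=world_size)
    coloweight.set_tensor_spec(ShardSpec(dims=[-1], num_partitions=[world_size]),
                               ComputeSpec(ComputePattern.TP1D))

    model = ParallelFreqAwareEmbeddingBag.from_pretrained(coloweight,
                                                          include_last_offset=True,
                                                          freeze=False)

Keeping the spec setup at the call site means the constructor no longer silently
clobbers a spec the caller may have already configured on the parameter, at the cost
of making correct sharding the caller's responsibility.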