From a445e118cfdf47a8fd1c915b7eee6faa5b883f3d Mon Sep 17 00:00:00 2001 From: Jiarui Fang Date: Wed, 23 Mar 2022 18:03:39 +0800 Subject: [PATCH] [polish] polish singleton and global context (#500) --- colossalai/context/__init__.py | 2 +- colossalai/context/moe_context.py | 17 +++++++-------- colossalai/context/parallel_context.py | 21 +++++-------------- .../commons => context}/singleton_meta.py | 0 colossalai/core.py | 5 +---- .../gradient_handler/_moe_gradient_handler.py | 3 ++- colossalai/initialize.py | 4 +++- colossalai/nn/layer/moe/experts.py | 2 +- colossalai/nn/layer/moe/layers.py | 2 +- colossalai/nn/layer/moe/utils.py | 2 +- colossalai/nn/loss/loss_moe.py | 2 +- .../memory_tracer/model_data_memtracer.py | 2 +- colossalai/utils/moe.py | 3 ++- model_zoo/moe/models.py | 2 +- tests/test_amp/test_naive_fp16.py | 13 +++++++----- tests/test_moe/test_grad_handler.py | 2 +- tests/test_moe/test_kernel.py | 2 +- tests/test_moe/test_moe_group.py | 2 +- 18 files changed, 39 insertions(+), 47 deletions(-) rename colossalai/{utils/commons => context}/singleton_meta.py (100%) diff --git a/colossalai/context/__init__.py b/colossalai/context/__init__.py index e5b600b68..50178b5fa 100644 --- a/colossalai/context/__init__.py +++ b/colossalai/context/__init__.py @@ -1,6 +1,6 @@ from .config import Config, ConfigException from .parallel_context import ParallelContext -from .moe_context import MoeContext from .parallel_mode import ParallelMode +from .moe_context import MOE_CONTEXT from .process_group_initializer import * from .random import * diff --git a/colossalai/context/moe_context.py b/colossalai/context/moe_context.py index f71236bb0..23eec6186 100644 --- a/colossalai/context/moe_context.py +++ b/colossalai/context/moe_context.py @@ -1,6 +1,9 @@ import torch import torch.distributed as dist -from .parallel_mode import ParallelMode + +from colossalai.context.parallel_mode import ParallelMode +from colossalai.context.singleton_meta import SingletonMeta + from typing import Tuple @@ -56,17 +59,10 @@ class MoeParallelInfo: self.dp_group = group -class MoeContext: +class MoeContext(metaclass=SingletonMeta): """MoE parallel context manager. This class manages different parallel groups in MoE context and MoE loss in training. """ - __instance = None - - @staticmethod - def get_instance(): - if MoeContext.__instance is None: - MoeContext.__instance = MoeContext() - return MoeContext.__instance def __init__(self): self.world_size = 1 @@ -160,3 +156,6 @@ class MoeContext: def get_loss(self): return self.aux_loss + + +MOE_CONTEXT = MoeContext() diff --git a/colossalai/context/parallel_context.py b/colossalai/context/parallel_context.py index e461b88e9..f69ec66b6 100644 --- a/colossalai/context/parallel_context.py +++ b/colossalai/context/parallel_context.py @@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER from .parallel_mode import ParallelMode from .random import add_seed, get_seeds, set_mode +from colossalai.context.singleton_meta import SingletonMeta -class ParallelContext: +class ParallelContext(metaclass=SingletonMeta): """This class provides interface functions for users to get the parallel context, such as the global rank, the local rank, the world size, etc. of each device. 
""" - __instance = None - - @staticmethod - def get_instance(): - if ParallelContext.__instance is None: - ParallelContext() - return ParallelContext.__instance - def __init__(self): - # create a singleton instance - if ParallelContext.__instance is not None: - raise Exception( - 'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context') - else: - ParallelContext.__instance = self - # distributed settings self._global_ranks = dict() self._local_ranks = dict() @@ -510,3 +496,6 @@ class ParallelContext: def set_virtual_pipeline_parallel_rank(self, rank): self.virtual_pipeline_parallel_rank = rank + + +global_context = ParallelContext() diff --git a/colossalai/utils/commons/singleton_meta.py b/colossalai/context/singleton_meta.py similarity index 100% rename from colossalai/utils/commons/singleton_meta.py rename to colossalai/context/singleton_meta.py diff --git a/colossalai/core.py b/colossalai/core.py index a2d3f57a7..4ae054d46 100644 --- a/colossalai/core.py +++ b/colossalai/core.py @@ -1,7 +1,4 @@ #!/usr/bin/env python # -*- encoding: utf-8 -*- -from colossalai.context import ParallelContext, MoeContext - -global_context = ParallelContext.get_instance() -MOE_CONTEXT = MoeContext.get_instance() +from colossalai.context.parallel_context import global_context diff --git a/colossalai/engine/gradient_handler/_moe_gradient_handler.py b/colossalai/engine/gradient_handler/_moe_gradient_handler.py index 8b2a58842..f65be3869 100644 --- a/colossalai/engine/gradient_handler/_moe_gradient_handler.py +++ b/colossalai/engine/gradient_handler/_moe_gradient_handler.py @@ -1,9 +1,10 @@ -from colossalai.core import global_context as gpc, MOE_CONTEXT +from colossalai.core import global_context as gpc from colossalai.registry import GRADIENT_HANDLER from colossalai.utils.moe import get_moe_epsize_param_dict from ._base_gradient_handler import BaseGradientHandler from ...context.parallel_mode import ParallelMode from .utils import bucket_allreduce +from colossalai.context.moe_context import MOE_CONTEXT @GRADIENT_HANDLER.register_module diff --git a/colossalai/initialize.py b/colossalai/initialize.py index 6211e694d..b9b01d5d0 100644 --- a/colossalai/initialize.py +++ b/colossalai/initialize.py @@ -19,7 +19,9 @@ from colossalai.amp import AMP_TYPE, convert_to_amp from colossalai.amp.naive_amp import NaiveAMPModel from colossalai.builder.builder import build_gradient_handler from colossalai.context import Config, ConfigException, ParallelMode -from colossalai.core import global_context as gpc, MOE_CONTEXT +from colossalai.core import global_context as gpc + +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.engine import Engine from colossalai.engine.ophooks import BaseOpHook from colossalai.logging import get_dist_logger diff --git a/colossalai/nn/layer/moe/experts.py b/colossalai/nn/layer/moe/experts.py index cfecbe4ba..c589595c1 100644 --- a/colossalai/nn/layer/moe/experts.py +++ b/colossalai/nn/layer/moe/experts.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn from colossalai.context import ParallelMode, seed from colossalai.utils import get_current_device -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from typing import Type diff --git a/colossalai/nn/layer/moe/layers.py b/colossalai/nn/layer/moe/layers.py index ebd8b4f79..93b2c0e67 100644 --- a/colossalai/nn/layer/moe/layers.py +++ b/colossalai/nn/layer/moe/layers.py @@ -4,7 +4,7 @@ import torch import torch.nn as nn import 
torch.nn.functional as F import torch.distributed as dist -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum from .experts import MoeExperts, Experts diff --git a/colossalai/nn/layer/moe/utils.py b/colossalai/nn/layer/moe/utils.py index ad9c99621..2a938a414 100644 --- a/colossalai/nn/layer/moe/utils.py +++ b/colossalai/nn/layer/moe/utils.py @@ -1,6 +1,6 @@ import torch from colossalai.utils import get_current_device -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from .experts import FFNExperts, TPExperts diff --git a/colossalai/nn/loss/loss_moe.py b/colossalai/nn/loss/loss_moe.py index 4c9c0fac8..b995647f3 100644 --- a/colossalai/nn/loss/loss_moe.py +++ b/colossalai/nn/loss/loss_moe.py @@ -1,7 +1,7 @@ import torch.nn as nn from colossalai.registry import LOSSES from torch.nn.modules.loss import _Loss -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT @LOSSES.register_module diff --git a/colossalai/utils/memory_tracer/model_data_memtracer.py b/colossalai/utils/memory_tracer/model_data_memtracer.py index 20bda8dbb..e4a70abee 100644 --- a/colossalai/utils/memory_tracer/model_data_memtracer.py +++ b/colossalai/utils/memory_tracer/model_data_memtracer.py @@ -1,4 +1,4 @@ -from colossalai.utils.commons.singleton_meta import SingletonMeta +from colossalai.context.singleton_meta import SingletonMeta from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage import torch diff --git a/colossalai/utils/moe.py b/colossalai/utils/moe.py index 9b1f5f976..3618e28b2 100644 --- a/colossalai/utils/moe.py +++ b/colossalai/utils/moe.py @@ -1,6 +1,7 @@ import torch.nn as nn import torch.distributed as dist -from colossalai.core import global_context as gpc, MOE_CONTEXT +from colossalai.core import global_context as gpc +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.context import ParallelMode from .common import is_using_ddp from typing import Dict, List diff --git a/model_zoo/moe/models.py b/model_zoo/moe/models.py index e9659a347..9dc273d08 100644 --- a/model_zoo/moe/models.py +++ b/model_zoo/moe/models.py @@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \ from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule from .util import moe_sa_args, moe_mlp_args from ..helper import TransformerLayer -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils import get_current_device from typing import List diff --git a/tests/test_amp/test_naive_fp16.py b/tests/test_amp/test_naive_fp16.py index c3554f8ca..8704da098 100644 --- a/tests/test_amp/test_naive_fp16.py +++ b/tests/test_amp/test_naive_fp16.py @@ -1,12 +1,15 @@ import torch -import colossalai -import copy -import pytest import torch.multiprocessing as mp -from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp -from tests.components_to_test.registry import non_distributed_component_funcs + +import colossalai from colossalai.testing import assert_close_loose from colossalai.utils import free_port +from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp + +from tests.components_to_test.registry import non_distributed_component_funcs + +import copy 
+import pytest from functools import partial diff --git a/tests/test_moe/test_grad_handler.py b/tests/test_moe/test_grad_handler.py index 08122a55f..cef782fea 100644 --- a/tests/test_moe/test_grad_handler.py +++ b/tests/test_moe/test_grad_handler.py @@ -7,7 +7,7 @@ import torch.distributed as dist import colossalai from colossalai.utils import free_port, get_current_device from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils.moe import sync_moe_model_param from colossalai.engine.gradient_handler import MoeGradientHandler from colossalai.testing import assert_equal_in_group diff --git a/tests/test_moe/test_kernel.py b/tests/test_moe/test_kernel.py index 12a362bbb..5fba61945 100644 --- a/tests/test_moe/test_kernel.py +++ b/tests/test_moe/test_kernel.py @@ -8,7 +8,7 @@ from colossalai.context import ParallelMode from colossalai.core import global_context as gpc from colossalai.utils import free_port, get_current_device from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT BATCH_SIZE = 16 NUM_EXPERTS = 4 diff --git a/tests/test_moe/test_moe_group.py b/tests/test_moe/test_moe_group.py index 613d421f5..fb8274520 100644 --- a/tests/test_moe/test_moe_group.py +++ b/tests/test_moe/test_moe_group.py @@ -6,7 +6,7 @@ import torch.distributed as dist import colossalai from colossalai.utils import free_port, get_current_device from colossalai.nn.layer.moe import Experts -from colossalai.core import MOE_CONTEXT +from colossalai.context.moe_context import MOE_CONTEXT from colossalai.utils.moe import sync_moe_model_param from colossalai.testing import assert_equal_in_group
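
Note on the singleton change: colossalai/utils/commons/singleton_meta.py is moved to colossalai/context/singleton_meta.py as a pure rename (similarity index 100%), so its contents never appear in this diff. The sketch below is only an illustration of the conventional metaclass-based singleton that ParallelContext(metaclass=SingletonMeta) and MoeContext(metaclass=SingletonMeta) now rely on; apart from the SingletonMeta name, the file and class names in the sketch are made up for the example and are not part of the patch.

# singleton_sketch.py -- illustrative only, NOT the contents of
# colossalai/context/singleton_meta.py (that file is renamed unchanged).
class SingletonMeta(type):
    """Metaclass that returns one cached instance per class."""
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on the first call; every later call
        # (even with different arguments) returns the cached object.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class DemoContext(metaclass=SingletonMeta):
    """Stand-in for ParallelContext / MoeContext in this sketch."""

    def __init__(self):
        self.world_size = 1


if __name__ == "__main__":
    a = DemoContext()
    b = DemoContext()
    assert a is b  # repeated construction yields the same object

With a metaclass like this in place, the patch can drop the hand-written get_instance() helpers and the "ParallelContext is a singleton class" guard in __init__, and instead expose plain module-level instances: global_context = ParallelContext() in parallel_context.py and MOE_CONTEXT = MoeContext() in moe_context.py. Constructing either class a second time can no longer create a duplicate context, so MOE_CONTEXT is now imported directly from colossalai.context.moe_context throughout the codebase, while colossalai.core keeps re-exporting global_context for existing call sites.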