
[polish] polish singleton and global context (#500)

Branch: pull/504/head
Author: Jiarui Fang, committed via GitHub 3 years ago
Commit: a445e118cf
18 changed files:
  1. colossalai/context/__init__.py (2 changes)
  2. colossalai/context/moe_context.py (17 changes)
  3. colossalai/context/parallel_context.py (21 changes)
  4. colossalai/context/singleton_meta.py (0 changes; renamed)
  5. colossalai/core.py (5 changes)
  6. colossalai/engine/gradient_handler/_moe_gradient_handler.py (3 changes)
  7. colossalai/initialize.py (4 changes)
  8. colossalai/nn/layer/moe/experts.py (2 changes)
  9. colossalai/nn/layer/moe/layers.py (2 changes)
  10. colossalai/nn/layer/moe/utils.py (2 changes)
  11. colossalai/nn/loss/loss_moe.py (2 changes)
  12. colossalai/utils/memory_tracer/model_data_memtracer.py (2 changes)
  13. colossalai/utils/moe.py (3 changes)
  14. model_zoo/moe/models.py (2 changes)
  15. tests/test_amp/test_naive_fp16.py (13 changes)
  16. tests/test_moe/test_grad_handler.py (2 changes)
  17. tests/test_moe/test_kernel.py (2 changes)
  18. tests/test_moe/test_moe_group.py (2 changes)

colossalai/context/__init__.py (2 changes)

@@ -1,6 +1,6 @@
 from .config import Config, ConfigException
 from .parallel_context import ParallelContext
-from .moe_context import MoeContext
 from .parallel_mode import ParallelMode
+from .moe_context import MOE_CONTEXT
 from .process_group_initializer import *
 from .random import *

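The package now re-exports the shared MOE_CONTEXT instance instead of the MoeContext class, so downstream code imports the ready-made singleton directly. A minimal sketch of the intended import style (illustrative note, not part of this diff):

```python
# New style used throughout this commit: import the shared instance.
from colossalai.context.moe_context import MOE_CONTEXT
# The package-level re-export also works after this commit:
#   from colossalai.context import MOE_CONTEXT

# Old style removed by this commit:
#   from colossalai.context import MoeContext
#   MOE_CONTEXT = MoeContext.get_instance()
```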
colossalai/context/moe_context.py (17 changes)

@@ -1,6 +1,9 @@
 import torch
 import torch.distributed as dist
-from .parallel_mode import ParallelMode
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta
 from typing import Tuple

@@ -56,17 +59,10 @@ class MoeParallelInfo:
         self.dp_group = group

-class MoeContext:
+class MoeContext(metaclass=SingletonMeta):
     """MoE parallel context manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """
-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if MoeContext.__instance is None:
-            MoeContext.__instance = MoeContext()
-        return MoeContext.__instance

     def __init__(self):
         self.world_size = 1

@@ -160,3 +156,6 @@ class MoeContext:
     def get_loss(self):
         return self.aux_loss
+
+
+MOE_CONTEXT = MoeContext()

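The SingletonMeta metaclass itself is only moved by this commit and its body is not part of the diff; a minimal sketch of the usual pattern it names, assuming an instance cache keyed by class, to show why the module-level `MOE_CONTEXT = MoeContext()` is safe:

```python
# Sketch of a typical SingletonMeta metaclass (assumed implementation; the real
# colossalai/context/singleton_meta.py is moved unchanged and not shown here).
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Build the instance on the first call only; later calls return the
        # cached object, so every MoeContext() refers to the same context.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]


class MoeContext(metaclass=SingletonMeta):
    def __init__(self):
        self.world_size = 1


MOE_CONTEXT = MoeContext()
assert MoeContext() is MOE_CONTEXT  # repeated construction yields one instance
```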
colossalai/context/parallel_context.py (21 changes)

@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
+from colossalai.context.singleton_meta import SingletonMeta

-class ParallelContext:
+class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
     """
-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if ParallelContext.__instance is None:
-            ParallelContext()
-        return ParallelContext.__instance

     def __init__(self):
-        # create a singleton instance
-        if ParallelContext.__instance is not None:
-            raise Exception(
-                'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
-        else:
-            ParallelContext.__instance = self
         # distributed settings
         self._global_ranks = dict()
         self._local_ranks = dict()

@@ -510,3 +496,6 @@ class ParallelContext:
     def set_virtual_pipeline_parallel_rank(self, rank):
         self.virtual_pipeline_parallel_rank = rank
+
+
+global_context = ParallelContext()

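With the metaclass in place, the `get_instance()` helper and the constructor guard that raised on a second construction are no longer needed; a short illustrative sketch of the resulting behaviour, assuming the SingletonMeta semantics sketched above:

```python
# Illustrative only: under SingletonMeta, constructing ParallelContext again
# simply returns the module-level singleton created at import time.
from colossalai.context.parallel_context import ParallelContext, global_context

ctx = ParallelContext()       # no exception, no explicit get_instance() call
assert ctx is global_context  # same object as the module-level instance
```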
colossalai/utils/commons/singleton_meta.py → colossalai/context/singleton_meta.py (renamed, 0 changes)

colossalai/core.py (5 changes)

@@ -1,7 +1,4 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from colossalai.context import ParallelContext, MoeContext
-
-global_context = ParallelContext.get_instance()
-MOE_CONTEXT = MoeContext.get_instance()
+from colossalai.context.parallel_context import global_context

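`colossalai.core` keeps the familiar `global_context` name by re-exporting the instance from `parallel_context`, so the widespread `as gpc` imports elsewhere in the codebase keep working unchanged; for example (illustrative):

```python
# Both names resolve to the same singleton after this commit.
from colossalai.core import global_context as gpc
from colossalai.context.parallel_context import global_context

assert gpc is global_context
```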
colossalai/engine/gradient_handler/_moe_gradient_handler.py (3 changes)

@@ -1,9 +1,10 @@
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
 from .utils import bucket_allreduce
+from colossalai.context.moe_context import MOE_CONTEXT


 @GRADIENT_HANDLER.register_module

colossalai/initialize.py (4 changes)

@@ -19,7 +19,9 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
 from colossalai.engine.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger

colossalai/nn/layer/moe/experts.py (2 changes)

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from typing import Type

colossalai/nn/layer/moe/layers.py (2 changes)

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts

colossalai/nn/layer/moe/utils.py (2 changes)

@@ -1,6 +1,6 @@
 import torch
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts

colossalai/nn/loss/loss_moe.py (2 changes)

@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT


 @LOSSES.register_module

colossalai/utils/memory_tracer/model_data_memtracer.py (2 changes)

@@ -1,4 +1,4 @@
-from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
 import torch

colossalai/utils/moe.py (3 changes)

@@ -1,6 +1,7 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

model_zoo/moe/models.py (2 changes)

@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
 from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from typing import List

tests/test_amp/test_naive_fp16.py (13 changes)

@@ -1,12 +1,15 @@
 import torch
+import colossalai
+import copy
+import pytest
+import torch.multiprocessing as mp
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
+from tests.components_to_test.registry import non_distributed_component_funcs

-import colossalai
 from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
-import copy
-import pytest
 from functools import partial

tests/test_moe/test_grad_handler.py (2 changes)

@@ -7,7 +7,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.testing import assert_equal_in_group

tests/test_moe/test_kernel.py (2 changes)

@@ -8,7 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT

 BATCH_SIZE = 16
 NUM_EXPERTS = 4

tests/test_moe/test_moe_group.py (2 changes)

@@ -6,7 +6,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group
