[polish] polish singleton and global context (#500)

pull/504/head
Jiarui Fang 3 years ago committed by GitHub
parent 9ec1ce6ab1
commit a445e118cf

@@ -1,6 +1,6 @@
from .config import Config, ConfigException
from .parallel_context import ParallelContext
from .moe_context import MoeContext
from .parallel_mode import ParallelMode
from .moe_context import MOE_CONTEXT
from .process_group_initializer import *
from .random import *

@@ -1,6 +1,9 @@
import torch
import torch.distributed as dist
from .parallel_mode import ParallelMode
from colossalai.context.parallel_mode import ParallelMode
from colossalai.context.singleton_meta import SingletonMeta
from typing import Tuple
@@ -56,17 +59,10 @@ class MoeParallelInfo:
self.dp_group = group
class MoeContext:
class MoeContext(metaclass=SingletonMeta):
"""MoE parallel context manager. This class manages different
parallel groups in MoE context and MoE loss in training.
"""
__instance = None
@staticmethod
def get_instance():
if MoeContext.__instance is None:
MoeContext.__instance = MoeContext()
return MoeContext.__instance
def __init__(self):
self.world_size = 1
@@ -160,3 +156,6 @@ class MoeContext:
def get_loss(self):
return self.aux_loss
MOE_CONTEXT = MoeContext()
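
The `get_instance()` boilerplate removed above is superseded by the `SingletonMeta` metaclass imported from `colossalai.context.singleton_meta`. That module is not part of this diff, so the following is only a minimal sketch of the kind of metaclass the change assumes; the caching details are illustrative, not the exact Colossal-AI implementation.

# Illustrative SingletonMeta sketch; the real colossalai.context.singleton_meta
# may differ in details.
class SingletonMeta(type):
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # The first call constructs the instance; every later call returns the
        # cached one, so MoeContext() always yields the same object as MOE_CONTEXT.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

With the metaclass in place, the module-level `MOE_CONTEXT = MoeContext()` added at the end of this hunk becomes the shared instance that the rest of the codebase imports.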

@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
from .parallel_mode import ParallelMode
from .random import add_seed, get_seeds, set_mode
from colossalai.context.singleton_meta import SingletonMeta
class ParallelContext:
class ParallelContext(metaclass=SingletonMeta):
"""This class provides interface functions for users to get the parallel context,
such as the global rank, the local rank, the world size, etc. of each device.
"""
__instance = None
@staticmethod
def get_instance():
if ParallelContext.__instance is None:
ParallelContext()
return ParallelContext.__instance
def __init__(self):
# create a singleton instance
if ParallelContext.__instance is not None:
raise Exception(
'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
else:
ParallelContext.__instance = self
# distributed settings
self._global_ranks = dict()
self._local_ranks = dict()
@@ -510,3 +496,6 @@ class ParallelContext:
def set_virtual_pipeline_parallel_rank(self, rank):
self.virtual_pipeline_parallel_rank = rank
global_context = ParallelContext()
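
`ParallelContext` gets the same treatment: the `get_instance()` staticmethod and the constructor guard are dropped in favour of the metaclass, and a module-level `global_context = ParallelContext()` is created at the bottom of the file. Assuming a SingletonMeta like the sketch above, calling the class again simply hands back that instance instead of raising the old exception:

# Minimal check of the new behaviour (assumes the SingletonMeta sketch above).
from colossalai.context.parallel_context import ParallelContext, global_context

assert ParallelContext() is global_context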

@@ -1,7 +1,4 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
from colossalai.context import ParallelContext, MoeContext
global_context = ParallelContext.get_instance()
MOE_CONTEXT = MoeContext.get_instance()
from colossalai.context.parallel_context import global_context
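
For downstream code, the practical effect of this hunk and of the remaining hunks below is an import-path change: `MOE_CONTEXT` is no longer re-exported from `colossalai.core`, while `global_context` (usually aliased as `gpc`) still is. A minimal before/after taken from the imports in this diff:

# Old imports (removed by this commit):
#   from colossalai.core import global_context as gpc, MOE_CONTEXT
# New imports:
from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT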

@@ -1,9 +1,10 @@
from colossalai.core import global_context as gpc, MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.registry import GRADIENT_HANDLER
from colossalai.utils.moe import get_moe_epsize_param_dict
from ._base_gradient_handler import BaseGradientHandler
from ...context.parallel_mode import ParallelMode
from .utils import bucket_allreduce
from colossalai.context.moe_context import MOE_CONTEXT
@GRADIENT_HANDLER.register_module

@@ -19,7 +19,9 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
from colossalai.amp.naive_amp import NaiveAMPModel
from colossalai.builder.builder import build_gradient_handler
from colossalai.context import Config, ConfigException, ParallelMode
from colossalai.core import global_context as gpc, MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.engine import Engine
from colossalai.engine.ophooks import BaseOpHook
from colossalai.logging import get_dist_logger

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
from colossalai.context import ParallelMode, seed
from colossalai.utils import get_current_device
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from typing import Type

@@ -4,7 +4,7 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
from .experts import MoeExperts, Experts

@@ -1,6 +1,6 @@
import torch
from colossalai.utils import get_current_device
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from .experts import FFNExperts, TPExperts

@@ -1,7 +1,7 @@
import torch.nn as nn
from colossalai.registry import LOSSES
from torch.nn.modules.loss import _Loss
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
@LOSSES.register_module

@@ -1,4 +1,4 @@
from colossalai.utils.commons.singleton_meta import SingletonMeta
from colossalai.context.singleton_meta import SingletonMeta
from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
import torch

@@ -1,6 +1,7 @@
import torch.nn as nn
import torch.distributed as dist
from colossalai.core import global_context as gpc, MOE_CONTEXT
from colossalai.core import global_context as gpc
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.context import ParallelMode
from .common import is_using_ddp
from typing import Dict, List

@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
from .util import moe_sa_args, moe_mlp_args
from ..helper import TransformerLayer
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils import get_current_device
from typing import List

@@ -1,12 +1,15 @@
import torch
import colossalai
import copy
import pytest
import torch.multiprocessing as mp
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import colossalai
from colossalai.testing import assert_close_loose
from colossalai.utils import free_port
from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
from tests.components_to_test.registry import non_distributed_component_funcs
import copy
import pytest
from functools import partial

@@ -7,7 +7,7 @@ import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.engine.gradient_handler import MoeGradientHandler
from colossalai.testing import assert_equal_in_group

@@ -8,7 +8,7 @@ from colossalai.context import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
BATCH_SIZE = 16
NUM_EXPERTS = 4

@@ -6,7 +6,7 @@ import torch.distributed as dist
import colossalai
from colossalai.utils import free_port, get_current_device
from colossalai.nn.layer.moe import Experts
from colossalai.core import MOE_CONTEXT
from colossalai.context.moe_context import MOE_CONTEXT
from colossalai.utils.moe import sync_moe_model_param
from colossalai.testing import assert_equal_in_group
