mirror of https://github.com/hpcaitech/ColossalAI
[polish] polish singleton and global context (#500)
parent 9ec1ce6ab1
commit a445e118cf
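The hunks below replace two hand-rolled singletons (MoeContext and ParallelContext, each carrying a private __instance field and a get_instance() static method) with a shared SingletonMeta metaclass plus module-level instances (MOE_CONTEXT, global_context). The file colossalai/context/singleton_meta.py itself is not part of this diff; the sketch below is only a minimal illustration of the standard metaclass-based singleton pattern such a class typically implements, and the DemoContext class is hypothetical.

class SingletonMeta(type):
    """Metaclass that builds each class at most once and caches the instance."""

    _instances = {}

    def __call__(cls, *args, **kwargs):
        if cls not in cls._instances:
            # First construction: create and cache the instance.
            cls._instances[cls] = super().__call__(*args, **kwargs)
        # Any later "construction" returns the cached instance.
        return cls._instances[cls]


class DemoContext(metaclass=SingletonMeta):
    def __init__(self):
        self.world_size = 1


assert DemoContext() is DemoContext()  # both calls yield the same object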
@@ -1,6 +1,6 @@
 from .config import Config, ConfigException
 from .parallel_context import ParallelContext
-from .moe_context import MoeContext
 from .parallel_mode import ParallelMode
+from .moe_context import MOE_CONTEXT
 from .process_group_initializer import *
 from .random import *

@@ -1,6 +1,9 @@
 import torch
 import torch.distributed as dist
-from .parallel_mode import ParallelMode
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta
+
 from typing import Tuple


@@ -56,17 +59,10 @@ class MoeParallelInfo:
         self.dp_group = group


-class MoeContext:
+class MoeContext(metaclass=SingletonMeta):
     """MoE parallel context manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """
-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if MoeContext.__instance is None:
-            MoeContext.__instance = MoeContext()
-        return MoeContext.__instance

     def __init__(self):
         self.world_size = 1

@@ -160,3 +156,6 @@ class MoeContext:

     def get_loss(self):
         return self.aux_loss
+
+
+MOE_CONTEXT = MoeContext()

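With MoeContext now constructed once at module level (MOE_CONTEXT = MoeContext() above), the remaining hunks simply switch call sites from "from colossalai.core import MOE_CONTEXT" to the new canonical path. A minimal usage sketch, assuming SingletonMeta caches the first instance as in the illustration above:

from colossalai.context.moe_context import MOE_CONTEXT, MoeContext

# Any later attempt to build a MoeContext resolves to the same shared object.
assert MoeContext() is MOE_CONTEXT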
@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER

 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
+from colossalai.context.singleton_meta import SingletonMeta


-class ParallelContext:
+class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.

     """

-    __instance = None
-
-    @staticmethod
-    def get_instance():
-        if ParallelContext.__instance is None:
-            ParallelContext()
-        return ParallelContext.__instance
-
     def __init__(self):
-        # create a singleton instance
-        if ParallelContext.__instance is not None:
-            raise Exception(
-                'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
-        else:
-            ParallelContext.__instance = self
-
         # distributed settings
         self._global_ranks = dict()
         self._local_ranks = dict()

@@ -510,3 +496,6 @@ class ParallelContext:

     def set_virtual_pipeline_parallel_rank(self, rank):
         self.virtual_pipeline_parallel_rank = rank
+
+
+global_context = ParallelContext()

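The old ParallelContext.__init__ raised an Exception when a second instance was created; with the metaclass that guard is no longer needed, since repeated construction returns the cached instance that global_context above points to. A quick illustration under that assumption:

from colossalai.context.parallel_context import ParallelContext, global_context

assert ParallelContext() is global_context  # no exception, just the shared singleton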
@@ -1,7 +1,4 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

-from colossalai.context import ParallelContext, MoeContext
-
-global_context = ParallelContext.get_instance()
-MOE_CONTEXT = MoeContext.get_instance()
+from colossalai.context.parallel_context import global_context

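colossalai.core now just re-exports the instance created in colossalai.context.parallel_context, so the widely used alias "from colossalai.core import global_context as gpc" (still present in the hunks below) keeps resolving to the same object. An illustrative check:

from colossalai.core import global_context as gpc
from colossalai.context.parallel_context import global_context

assert gpc is global_context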
@@ -1,9 +1,10 @@
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
 from .utils import bucket_allreduce
+from colossalai.context.moe_context import MOE_CONTEXT


 @GRADIENT_HANDLER.register_module

@@ -19,7 +19,9 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
 from colossalai.engine.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from typing import Type


@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts

@@ -1,6 +1,6 @@
 import torch
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts


@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT


 @LOSSES.register_module

@@ -1,4 +1,4 @@
-from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
 import torch

@@ -1,6 +1,7 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
 from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from typing import List

@@ -1,12 +1,15 @@
 import torch
-import colossalai
-import copy
-import pytest
 import torch.multiprocessing as mp
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
+
+import colossalai
 from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
+
+from tests.components_to_test.registry import non_distributed_component_funcs
+
+import copy
+import pytest
 from functools import partial


@@ -7,7 +7,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.testing import assert_equal_in_group

@@ -8,7 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT

 BATCH_SIZE = 16
 NUM_EXPERTS = 4

@@ -6,7 +6,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group
