[polish] polish singleton and global context (#500)

pull/504/head
Jiarui Fang 2022-03-23 18:03:39 +08:00 committed by GitHub
parent 9ec1ce6ab1
commit a445e118cf
18 changed files with 39 additions and 47 deletions

View File

@@ -1,6 +1,6 @@
 from .config import Config, ConfigException
 from .parallel_context import ParallelContext
-from .moe_context import MoeContext
 from .parallel_mode import ParallelMode
+from .moe_context import MOE_CONTEXT
 from .process_group_initializer import *
 from .random import *

View File

@@ -1,6 +1,9 @@
 import torch
 import torch.distributed as dist
-from .parallel_mode import ParallelMode
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.context.singleton_meta import SingletonMeta
 from typing import Tuple
@@ -56,17 +59,10 @@ class MoeParallelInfo:
         self.dp_group = group
-class MoeContext:
+class MoeContext(metaclass=SingletonMeta):
     """MoE parallel context manager. This class manages different
     parallel groups in MoE context and MoE loss in training.
     """
-    __instance = None
-    @staticmethod
-    def get_instance():
-        if MoeContext.__instance is None:
-            MoeContext.__instance = MoeContext()
-        return MoeContext.__instance
     def __init__(self):
         self.world_size = 1
@@ -160,3 +156,6 @@ class MoeContext:
     def get_loss(self):
         return self.aux_loss
+MOE_CONTEXT = MoeContext()
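The new colossalai.context.singleton_meta module itself is not shown in this commit. As a rough sketch, assuming it follows the usual metaclass recipe with a per-class instance cache (names such as _instances are illustrative, not taken from the repository):

class SingletonMeta(type):
    # One cached instance per class that uses this metaclass.
    _instances = {}

    def __call__(cls, *args, **kwargs):
        # Construct on the first call; afterwards always hand back the cached object.
        if cls not in cls._instances:
            cls._instances[cls] = super().__call__(*args, **kwargs)
        return cls._instances[cls]

With such a metaclass, the module-level MOE_CONTEXT = MoeContext() and any later MoeContext() call refer to the same object, which is why the hand-rolled get_instance() helper above can be dropped.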

View File

@@ -15,30 +15,16 @@ from colossalai.registry import DIST_GROUP_INITIALIZER
 from .parallel_mode import ParallelMode
 from .random import add_seed, get_seeds, set_mode
+from colossalai.context.singleton_meta import SingletonMeta
-class ParallelContext:
+class ParallelContext(metaclass=SingletonMeta):
     """This class provides interface functions for users to get the parallel context,
     such as the global rank, the local rank, the world size, etc. of each device.
     """
-    __instance = None
-    @staticmethod
-    def get_instance():
-        if ParallelContext.__instance is None:
-            ParallelContext()
-        return ParallelContext.__instance
     def __init__(self):
-        # create a singleton instance
-        if ParallelContext.__instance is not None:
-            raise Exception(
-                'ParallelContext is a singleton class, you should get the instance by colossalai.core.global_context')
-        else:
-            ParallelContext.__instance = self
         # distributed settings
         self._global_ranks = dict()
         self._local_ranks = dict()
@@ -510,3 +496,6 @@ class ParallelContext:
     def set_virtual_pipeline_parallel_rank(self, rank):
         self.virtual_pipeline_parallel_rank = rank
+global_context = ParallelContext()
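Because ParallelContext now uses the same metaclass, the module-level global_context defined here is the single shared instance; constructing the class again simply returns it. A quick illustration, not part of the diff:

from colossalai.context.parallel_context import ParallelContext, global_context

# Repeated construction yields the cached singleton, so both names point at one object.
assert ParallelContext() is global_context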

View File

@@ -1,7 +1,4 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
-from colossalai.context import ParallelContext, MoeContext
-global_context = ParallelContext.get_instance()
-MOE_CONTEXT = MoeContext.get_instance()
+from colossalai.context.parallel_context import global_context
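The colossalai.core module is reduced to a thin re-export, so existing call sites that do from colossalai.core import global_context as gpc keep working unchanged, while MoE code now imports its context from the defining module, as the remaining hunks show:

# Unchanged access path for the global parallel context:
from colossalai.core import global_context as gpc

# New canonical import for the MoE context singleton:
from colossalai.context.moe_context import MOE_CONTEXT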

View File

@@ -1,9 +1,10 @@
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
 from colossalai.registry import GRADIENT_HANDLER
 from colossalai.utils.moe import get_moe_epsize_param_dict
 from ._base_gradient_handler import BaseGradientHandler
 from ...context.parallel_mode import ParallelMode
 from .utils import bucket_allreduce
+from colossalai.context.moe_context import MOE_CONTEXT
 @GRADIENT_HANDLER.register_module

View File

@@ -19,7 +19,9 @@ from colossalai.amp import AMP_TYPE, convert_to_amp
 from colossalai.amp.naive_amp import NaiveAMPModel
 from colossalai.builder.builder import build_gradient_handler
 from colossalai.context import Config, ConfigException, ParallelMode
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.engine import Engine
 from colossalai.engine.ophooks import BaseOpHook
 from colossalai.logging import get_dist_logger

View File

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 from colossalai.context import ParallelMode, seed
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from typing import Type

View File

@@ -4,7 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 import torch.distributed as dist
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from ._operation import COL_MOE_KERNEL_FLAG, AllToAll, AllGather, ReduceScatter, MoeDispatch, MoeCombine, moe_cumsum
 from .experts import MoeExperts, Experts

View File

@@ -1,6 +1,6 @@
 import torch
 from colossalai.utils import get_current_device
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from .experts import FFNExperts, TPExperts

View File

@@ -1,7 +1,7 @@
 import torch.nn as nn
 from colossalai.registry import LOSSES
 from torch.nn.modules.loss import _Loss
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 @LOSSES.register_module

View File

@@ -1,4 +1,4 @@
-from colossalai.utils.commons.singleton_meta import SingletonMeta
+from colossalai.context.singleton_meta import SingletonMeta
 from colossalai.utils.memory_tracer.commons import col_tensor_mem_usage
 import torch

View File

@@ -1,6 +1,7 @@
 import torch.nn as nn
 import torch.distributed as dist
-from colossalai.core import global_context as gpc, MOE_CONTEXT
+from colossalai.core import global_context as gpc
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.context import ParallelMode
 from .common import is_using_ddp
 from typing import Dict, List

View File

@@ -7,7 +7,7 @@ from colossalai.nn.layer import VanillaPatchEmbedding, VanillaClassifier, \
 from colossalai.nn.layer.moe import build_ffn_experts, MoeLayer, Top2Router, NormalNoiseGenerator, MoeModule
 from .util import moe_sa_args, moe_mlp_args
 from ..helper import TransformerLayer
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils import get_current_device
 from typing import List

View File

@@ -1,12 +1,15 @@
 import torch
-import colossalai
-import copy
-import pytest
 import torch.multiprocessing as mp
-from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
-from tests.components_to_test.registry import non_distributed_component_funcs
+import colossalai
 from colossalai.testing import assert_close_loose
 from colossalai.utils import free_port
+from colossalai.amp import convert_to_naive_amp, convert_to_apex_amp
+from tests.components_to_test.registry import non_distributed_component_funcs
+import copy
+import pytest
 from functools import partial

View File

@@ -7,7 +7,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, UniformNoiseGenerator, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.engine.gradient_handler import MoeGradientHandler
 from colossalai.testing import assert_equal_in_group

View File

@@ -8,7 +8,7 @@ from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Top1Router, Top2Router, MoeLayer, Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 BATCH_SIZE = 16
 NUM_EXPERTS = 4

View File

@@ -6,7 +6,7 @@ import torch.distributed as dist
 import colossalai
 from colossalai.utils import free_port, get_current_device
 from colossalai.nn.layer.moe import Experts
-from colossalai.core import MOE_CONTEXT
+from colossalai.context.moe_context import MOE_CONTEXT
 from colossalai.utils.moe import sync_moe_model_param
 from colossalai.testing import assert_equal_in_group