You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
ColossalAI/colossalai/global_variables.py

37 lines
1.0 KiB

class MoeEnv:
    """Moe environment variables.

    Stores the data/model parallel sizes configured for Moe and accumulates
    an auxiliary loss value across calls to :meth:`add_loss`.
    """

    def __init__(self):
        # Both sizes stay None until setup() has been called.
        self.data_parallel_size = None
        self.model_parallel_size = None
        # None until reset_loss() or the first add_loss() call.
        self.aux_loss = None

    def setup(self, moe_model_size: int) -> None:
        """Derive the Moe parallel sizes from the global context.

        Args:
            moe_model_size: Moe model parallel size; the global data parallel
                size must be a multiple of it.

        Raises:
            NotImplementedError: If tensor or pipeline parallelism is enabled
                (either size is greater than 1) in the global context.
            AssertionError: If the global data parallel size is not divisible
                by ``moe_model_size``.
        """
        # Imported at call time rather than module import time
        # (presumably to avoid a circular import -- TODO confirm).
        from .core import global_context as gpc

        if gpc.tensor_parallel_size > 1 or gpc.pipeline_parallel_size > 1:
            raise NotImplementedError("Moe is not compatible with tensor or pipeline parallel")

        # Raised explicitly (instead of a bare `assert`) so the check is not
        # stripped under `python -O`; the exception type is kept unchanged.
        if gpc.data_parallel_size % moe_model_size != 0:
            raise AssertionError(
                "The size of data parallel needs to be divided by moe model parallel size")

        self.data_parallel_size = gpc.data_parallel_size // moe_model_size
        self.model_parallel_size = moe_model_size

    def is_initialized(self) -> bool:
        """Return True once setup() has stored a model parallel size."""
        return self.model_parallel_size is not None

    def reset_loss(self) -> None:
        """Reset the accumulated auxiliary loss to 0."""
        self.aux_loss = 0

    def add_loss(self, loss) -> None:
        """Add ``loss`` to the accumulated auxiliary loss.

        If nothing has been accumulated yet (``aux_loss`` is still None), the
        accumulator is started from 0 first, exactly as reset_loss() would do,
        instead of failing with a TypeError on ``None += loss``.
        """
        if self.aux_loss is None:
            # Start from 0 rather than aliasing `loss`, so that later in-place
            # additions never mutate the caller's object.
            self.aux_loss = 0
        self.aux_loss += loss

    def get_loss(self):
        """Return the accumulated auxiliary loss (None if never set)."""
        return self.aux_loss


# Module-level MoeEnv instance.
moe_env = MoeEnv()