Use dummy mode to generate random numbers during model construction

pull/182/head
Wenwen Qu 2023-09-08 17:56:42 +08:00
parent 1baa7b41f0
commit cd6b28b073
3 changed files with 17 additions and 4 deletions

View File

@ -36,7 +36,7 @@ class Config(dict):
config (dict): The dict object to be wrapped.
"""
def __init__(self, config: dict = None):
def __init__(self, config: dict = None): # pylint: disable=W0231
if config is not None:
for k, v in config.items():
self._add_item(k, v)
@ -100,7 +100,7 @@ class Config(dict):
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module() # pylint: disable=W4902,E1120
module = source_file.load_module()  # pylint: disable=W4902,E1120,W1505
# load into config
config = Config()
@ -527,6 +527,7 @@ class ParallelContext(metaclass=SingletonMeta):
if dpseed_with_tpoffset:
dp_seed = seed + pipeline_offset * 1024
add_seed(ParallelMode.DATA, dp_seed)
add_seed(ParallelMode.DUMMY, dp_seed)
# model parallel seeds are different across ranks
if self.is_initialized(ParallelMode.TENSOR):
@ -534,7 +535,11 @@ class ParallelContext(metaclass=SingletonMeta):
tp_seed = seed + tp_rank + pipeline_offset * 1024
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
# We do not set the random state mode to ParallelMode.DATA until the model is built (instead, we use a dummy
# mode during model construction). This is because the random state would otherwise differ on different tensor
# parallel devices of the same data parallel group: the device with tp_rank = 0 performs additional random
# operations while building the RowParallelLinear module.
set_mode(ParallelMode.DUMMY)
seeds = get_seeds()
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])

View File

@ -37,6 +37,9 @@ class ParallelMode(Enum):
# expert data parallel
EXPERT_DATA = "expert_data"
# dummy mode, only used during model construction
DUMMY = "dummy"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.

View File

@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.random import set_mode
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
@ -81,6 +82,10 @@ def initialize_model():
# the same across tensor parallelism.
sync_model_param_within_tp(model)
# Change the random state mode to ParallelMode.DATA after the model is built, guaranteeing that the random
# states within the same dp group are all the same.
set_mode(ParallelMode.DATA)
return model
@ -414,7 +419,7 @@ def record_current_batch_training_metrics(
"step": batch_count,
"lr": lr,
"num_consumed_tokens": train_state.num_consumed_tokens,
"loss": loss.item(),
"loss": loss.item() - moe_loss.item(),
"flops": tflops,
"tgs": tk_per_gpu,
"acc": acc_perplex["acc"],