fix(core/context): use dummy mode to generate random numbers in model construction (#266)

* change mode to dummy in model construction and restore to data when done

* add comments

* move set_mode(.DATA) to initialize_model(.)
pull/286/head
Wenwen Qu 2023-09-06 14:34:11 +08:00 committed by GitHub
parent ff181bc5f8
commit 7f687bf4b3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 16 additions and 3 deletions

View File

@ -36,7 +36,7 @@ class Config(dict):
config (dict): The dict object to be wrapped.
"""
def __init__(self, config: dict = None):
def __init__(self, config: dict = None): # pylint: disable=W0231
if config is not None:
for k, v in config.items():
self._add_item(k, v)
@ -100,7 +100,7 @@ class Config(dict):
module_name = filepath.stem
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
module = source_file.load_module() # pylint: disable=W4902,E1120
module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
# load into config
config = Config()
@ -526,6 +526,7 @@ class ParallelContext(metaclass=SingletonMeta):
if dpseed_with_tpoffset:
dp_seed = seed + pipeline_offset * 1024
add_seed(ParallelMode.DATA, dp_seed)
add_seed(ParallelMode.DUMMY, dp_seed)
# model parallel seeds are different across ranks
if self.is_initialized(ParallelMode.TENSOR):
@ -533,7 +534,11 @@ class ParallelContext(metaclass=SingletonMeta):
tp_seed = seed + tp_rank + pipeline_offset * 1024
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DATA)
# we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode
# during model construction), this is because the random state will be different in different tensor parallel
# device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
# additional random operations during the RowParallelLinear module building process.
set_mode(ParallelMode.DUMMY)
seeds = get_seeds()
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])

View File

@ -35,6 +35,9 @@ class ParallelMode(Enum):
# runntime network test
NETTEST = "nettest"
# dummy mode, only used during mode construction
DUMMY = "dummy"
class ProcessGroupInitializer(ABC):
"""An object, knowing the parallelism configuration, that initializes parallel groups.

View File

@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader
from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.context.random import set_mode
from internlm.core.naive_amp import NaiveAMPModel
from internlm.core.trainer import TrainState
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
@ -80,6 +81,10 @@ def initialize_model():
# the same across tensor parallelism.
sync_model_param_within_tp(model)
# Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
# state in the same dp group are all the same.
set_mode(ParallelMode.DATA)
return model