mirror of https://github.com/InternLM/InternLM
bind seed

parent 2b984ffa58
commit c53667d70c
@@ -27,6 +27,7 @@ else:
     get_numa = True
 
 logger = get_logger(__file__)
+GLOBAL_SEED = 1024
 
 
 def get_default_parser():
@@ -531,6 +532,9 @@ def initialize_distributed_env(
     else:
         assert launcher in ["slurm", "torch"], "launcher only support slurm or torch"
 
+    global GLOBAL_SEED
+    GLOBAL_SEED = seed
+
     if args_check:
         args_sanity_check()
 
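Note on the two hunks above, which touch the module that owns GLOBAL_SEED (internlm.initialize.launch, judging by the import added further down): the assignment inside initialize_distributed_env needs the global statement, otherwise it would only create a function-local name and leave the module-level GLOBAL_SEED (defined as 1024 in the first hunk) unchanged. A minimal sketch of that Python behavior, using an illustrative SEED name rather than the repository's:

SEED = 1024  # module-level default, analogous to GLOBAL_SEED


def bind_seed_without_global(seed):
    SEED = seed  # rebinds a function-local name only; the module-level SEED is untouched


def bind_seed(seed):
    global SEED  # needed because we assign to the module-level name
    SEED = seed


bind_seed_without_global(42)
print(SEED)  # 1024

bind_seed(42)
print(SEED)  # 42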
@@ -2,8 +2,11 @@
 # -*- encoding: utf-8 -*-
 
 import math
+import random
+from functools import wraps
 from typing import Optional
 
+import numpy as np
 import torch
 from flash_attn.modules.embedding import ParallelGPT2Embeddings
 from flash_attn.modules.mlp import ParallelFusedMLP
@@ -12,6 +15,7 @@ from torch import nn
 from internlm.core.context import IS_SEQUENCE_PARALLEL, IS_TENSOR_PARALLEL, ParallelMode
 from internlm.core.context.parallel_context import global_context as gpc
 from internlm.initialize.initialize_tensor import normal_, scaled_init_method_normal
+from internlm.initialize.launch import GLOBAL_SEED
 from internlm.model.embedding import Embedding1D
 from internlm.model.linear import (
     FeedForward,
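Worth noting about the import added above: a from-import copies the binding that exists when the model module is first imported, so the GLOBAL_SEED that the fix_seed decorator added further down reads is the value captured at that moment (the 1024 default, unless initialize_distributed_env has already rebound it by then); whether that matters in practice depends on the import order in the training entry point. A small self-contained illustration of that from-import behavior, using a throwaway module built on the fly rather than the real one:

import sys
import types

# Toy stand-in for a launch-style module that owns the seed (names are illustrative).
toy_launch = types.ModuleType("toy_launch")
exec(
    "GLOBAL_SEED = 1024\n"
    "def bind_seed(seed):\n"
    "    global GLOBAL_SEED\n"
    "    GLOBAL_SEED = seed\n",
    toy_launch.__dict__,
)
sys.modules["toy_launch"] = toy_launch

from toy_launch import GLOBAL_SEED  # copies the binding made at this point

toy_launch.bind_seed(42)  # later rebinding inside the owning module

print(GLOBAL_SEED)             # 1024 -- the imported copy keeps the old value
print(toy_launch.GLOBAL_SEED)  # 42   -- attribute access sees the rebinding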
@@ -81,6 +85,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         self.use_flash_attn = use_flash_attn
 
         head_dim = hidden_size // num_attention_heads
+
         self.mixer = MHA(
             embed_dim=hidden_size,
             num_heads=num_attention_heads,
@@ -410,6 +415,18 @@ class PackedFlashInternLm1D(nn.Module):
         return hidden_states
 
 
+def fix_seed(func):
+    @wraps(func)
+    def wrapper(*args, **kwargs):
+        seed = GLOBAL_SEED
+        random.seed(seed)
+        np.random.seed(seed)
+        torch.manual_seed(seed)
+        func(*args, **kwargs)
+
+    return wrapper
+
+
 def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"), **kwargs):
     """
     build generic model 1d
@@ -429,6 +446,7 @@ def _build_generic_model_1d(num_layers, num_chunks, device=torch.device("cuda"),
         logger.info(f"The layer sharding is {all_parts}.")
 
     models = []
+    PackedFlashInternLm1D.__init__ = fix_seed(PackedFlashInternLm1D.__init__)
 
     for start, end in parts:
         kwargs["num_layers"] = end - start
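Taken together, the new fix_seed decorator and the rebinding of PackedFlashInternLm1D.__init__ mean that every pipeline chunk created inside _build_generic_model_1d starts construction from an identical RNG state, since Python, NumPy and PyTorch are all reseeded on each call. A rough, self-contained sketch of the same pattern on a toy nn.Module; the class name, hidden size, and literal seed are made up for the example:

import random
from functools import wraps

import numpy as np
import torch
from torch import nn

GLOBAL_SEED = 1024  # stands in for the seed bound at launch time


def fix_seed(func):
    @wraps(func)
    def wrapper(*args, **kwargs):
        # Reseed all three RNGs before the wrapped constructor runs.
        random.seed(GLOBAL_SEED)
        np.random.seed(GLOBAL_SEED)
        torch.manual_seed(GLOBAL_SEED)
        func(*args, **kwargs)  # return value discarded, as in the patch (fine for __init__)

    return wrapper


class ToyChunk(nn.Module):
    # Illustrative stand-in for one pipeline chunk of the model.
    def __init__(self, hidden_size=8):
        super().__init__()
        self.proj = nn.Linear(hidden_size, hidden_size)


# Same monkey-patching trick as in _build_generic_model_1d:
ToyChunk.__init__ = fix_seed(ToyChunk.__init__)

# Chunks built one after another now get identical initial weights.
a, b = ToyChunk(), ToyChunk()
assert torch.equal(a.proj.weight, b.proj.weight)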