mirror of https://github.com/InternLM/InternLM
fix(fsdp): fix conflicts
commit aedd88e5a7
@ -1,7 +1,8 @@
|
|||
JOB_NAME = "7b_train"
|
||||
DO_ALERT = False
|
||||
|
||||
SEQ_LEN = 2048
|
||||
HIDDEN_SIZE = 4096
|
||||
SEQ_LEN = 256
|
||||
HIDDEN_SIZE = 512
|
||||
NUM_ATTENTION_HEAD = 32
|
||||
MLP_RATIO = 8 / 3
|
||||
NUM_LAYER = 32
|
||||
|
|
@ -22,14 +23,22 @@ CHECKPOINT_EVERY = 20
|
|||
ckpt = dict(
|
||||
enable_save_ckpt=False, # enable ckpt save.
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save training ckpt.
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
# load_ckpt_folder=LOAD_CKPT_FOLDER, # Ckpt path to resume training(load weights and scheduler/context states).
|
||||
load_given_ckpt = False,
|
||||
# load_model_only_folder=MODEL_ONLY_FOLDER, # Path to initialize with given model weights.
|
||||
load_optimizer=True, # Whether to load optimizer states when continuing training.
|
||||
|
||||
# load_ckpt_folder= dict(path=MODEL_ONLY_FOLDER, content=["model"], ckpt_type="normal"),
|
||||
load_ckpt_folder="local:llm_ckpts/",
|
||||
# 'load_ckpt_info' setting guide:
|
||||
# 1. the 'path' indicate ckpt path,
|
||||
# 2. the 'content' means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
|
||||
# 3. the 'ckpt_type' means the type of checkpoint to be loaded, now only 'normal' type is supported.
|
||||
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
|
||||
|
||||
checkpoint_every=CHECKPOINT_EVERY,
|
||||
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
|
||||
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporary files during asynchronous upload.
|
||||
snapshot_ckpt_folder="/".join([SAVE_CKPT_FOLDER, "snapshot"]), # directory for snapshot ckpt storage path.
|
||||
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
|
||||
)
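For reference, a minimal sketch of the two checkpoint-loading styles this hunk reconciles (the legacy string form is converted into the new `load_ckpt_info` dict by the compatibility check in `internlm/initialize/legacy/launch.py` later in this diff; folder values are placeholders):

```python
# Legacy style: a plain resume folder plus separate flags.
ckpt_legacy = dict(
    load_ckpt_folder="local:llm_ckpts/",  # resume folder
    load_optimizer=True,                  # also reload optimizer states
)

# New style: one load_ckpt_info dict describing path, contents and loader type.
ckpt_new = dict(
    load_ckpt_info=dict(
        path=MODEL_ONLY_FOLDER,    # defined earlier in the config
        content=("model",),        # any of "model", "sampler", "optimizer", "scheduler", "all"
        ckpt_type="internlm",      # loader registered for this checkpoint format
    ),
)
```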
|
||||
|
||||
|
|
@ -46,7 +55,7 @@ data = dict(
|
|||
# defaults to 0, means disable evaluate
|
||||
valid_every=50,
|
||||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
total_steps=30,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
|
|
@ -145,8 +154,17 @@ parallel = dict(
|
|||
pipeline=dict(size=1, interleaved_overlap=True),
|
||||
tensor=1,
|
||||
sequence_parallel=False,
|
||||
use_fsdp = True,
|
||||
use_fsdp=True,
|
||||
)
|
||||
|
||||
cudnn_deterministic = False
|
||||
cudnn_benchmark = False
|
||||
|
||||
monitor = dict(
|
||||
# feishu alert configs
|
||||
alert=dict(
|
||||
enable_feishu_alert=DO_ALERT,
|
||||
feishu_alert_address=None, # feishu webhook to send alert message
|
||||
light_monitor_address=None, # light_monitor address to send heartbeat
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -112,19 +112,19 @@ If you want to load a model checkpoint when starting the training, you can confi
|
|||
|
||||
```python
|
||||
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
|
||||
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
|
||||
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
|
||||
ckpt = dict(
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
|
||||
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
|
||||
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights, only load model weights without loading optimizer weights, training will start from the first step
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load the weights of the model and optimizer for resuming training, training will resume from the specified step
|
||||
load_optimizer=True, # Whether to load optimizer weights when resuming training, default value is True
|
||||
# When resuming training from a checkpoint:
|
||||
# (1) 'path' is the path of the loaded checkpoint.
|
||||
# (2) 'content' indicates which state will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
|
||||
# (3) 'ckpt_type' indicates which type ckpt will be loaded, currently supported: "internlm"
|
||||
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
|
||||
)
|
||||
```
|
||||
|
||||
Note:
|
||||
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
|
||||
- If the path starts with `local:`, it means the file is stored in the local file system. If it starts with `boto3:`, it means the file is stored in the remote OSS.
|
||||
|
||||
The configuration for the model is as follows:
|
||||
|
|
|
|||
|
|
@ -101,18 +101,17 @@ data = dict(
|
|||
If you want to load a model checkpoint when starting the training, you can configure it as follows:
|
||||
```python
|
||||
SAVE_CKPT_FOLDER = "local:/path/to/save/ckpt"
|
||||
MODEL_ONLY_FOLDER = "local:/path/to/load/init/model/ckpt"
|
||||
LOAD_CKPT_FOLDER = "local:/path/to/load/resume/ckpt"
|
||||
ckpt = dict(
|
||||
save_ckpt_folder=SAVE_CKPT_FOLDER, # Path to save the model and optimizer checkpoints
|
||||
checkpoint_every=float("inf"), # Save a checkpoint every specified number of steps, default value is inf
|
||||
load_model_only_folder=MODEL_ONLY_FOLDER, # Path to load the initial model weights; only model weights are loaded, not optimizer states, and training starts from the first step
|
||||
load_ckpt_folder=LOAD_CKPT_FOLDER, # Path to load model and optimizer states when resuming training; training resumes from the saved step
|
||||
load_optimizer=True, # Whether to load optimizer states when resuming training, default value is True
|
||||
# When resuming training, 'path' is the checkpoint path from which model and optimizer states are loaded; training resumes from the saved step
|
||||
# 'content' indicates which states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
|
||||
# 'ckpt_type' indicates the checkpoint type to be loaded, currently supported: "internlm"
|
||||
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
|
||||
)
|
||||
```
|
||||
Note:
|
||||
- `load_model_only_folder` and `load_ckpt_folder` cannot be set at the same time.
|
||||
- If the path starts with `local:`, the file is stored in the local file system; if it starts with `boto3:`, it is stored in the remote OSS.
|
||||
|
||||
The key model parameters are configured as follows:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ import torch.distributed as dist
|
|||
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
|
||||
|
||||
from . import process_group_initializer as pgroup_initializer
|
||||
from .process_group_initializer import ParallelMode
|
||||
|
|
@ -36,7 +37,7 @@ class Config(dict):
|
|||
config (dict): The dict object to be wrapped.
|
||||
"""
|
||||
|
||||
def __init__(self, config: dict = None):
|
||||
def __init__(self, config: dict = None): # pylint: disable=W0231
|
||||
if config is not None:
|
||||
for k, v in config.items():
|
||||
self._add_item(k, v)
|
||||
|
|
@ -100,7 +101,7 @@ class Config(dict):
|
|||
|
||||
module_name = filepath.stem
|
||||
source_file = SourceFileLoader(fullname=str(module_name), path=str(filepath))
|
||||
module = source_file.load_module() # pylint: disable=W4902,E1120
|
||||
module = source_file.load_module() # pylint: disable=W4902,E1120,W1505
|
||||
|
||||
# load into config
|
||||
config = Config()
|
||||
|
|
@ -374,12 +375,22 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
"""
|
||||
# initialize the default process group
|
||||
init_method = f"tcp://[{host}]:{port}"
|
||||
dist.init_process_group(rank=rank, world_size=world_size, backend=backend, init_method=init_method)
|
||||
dist.init_process_group(
|
||||
rank=rank,
|
||||
world_size=world_size,
|
||||
backend=backend,
|
||||
init_method=init_method,
|
||||
timeout=LLM_NCCL_TIMEOUT,
|
||||
)
|
||||
|
||||
# None will give the default global process group for pytorch dist operations
|
||||
ranks = list(range(world_size))
|
||||
if use_cpu:
|
||||
cpu_group = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else None
|
||||
cpu_group = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else None
|
||||
)
|
||||
else:
|
||||
cpu_group = None
|
||||
self._register_dist(rank, world_size, dist.GroupMember.WORLD, cpu_group, ranks, ParallelMode.GLOBAL)
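The hunks above thread `LLM_NCCL_TIMEOUT` into every `init_process_group`/`new_group` call so stuck collectives fail instead of hanging forever. A minimal sketch of how such a constant could be defined and reused, assuming a `timedelta`-based default (the actual definition in `internlm.utils.timeout` may differ):

```python
import os
from datetime import timedelta

import torch.distributed as dist

# Hypothetical definition; the real module may pick a different default or env var name.
LLM_NCCL_TIMEOUT = timedelta(seconds=int(os.getenv("LLM_NCCL_TIMEOUT", "1800")))


def new_group_with_timeout(ranks, backend=None):
    # Every group created this way aborts its collectives after LLM_NCCL_TIMEOUT.
    return dist.new_group(ranks, backend=backend, timeout=LLM_NCCL_TIMEOUT)
```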
|
||||
|
|
@ -528,6 +539,7 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
if dpseed_with_tpoffset:
|
||||
dp_seed = seed + pipeline_offset * 1024
|
||||
add_seed(ParallelMode.DATA, dp_seed)
|
||||
add_seed(ParallelMode.DUMMY, dp_seed)
|
||||
|
||||
# model parallel seeds are different across ranks
|
||||
if self.is_initialized(ParallelMode.TENSOR):
|
||||
|
|
@ -535,7 +547,11 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
tp_seed = seed + tp_rank + pipeline_offset * 1024
|
||||
add_seed(ParallelMode.TENSOR, tp_seed)
|
||||
|
||||
set_mode(ParallelMode.DATA)
|
||||
# we do not set the random state mode to ParallelMode.DATA until model is built (instead, we use a dummy mode
|
||||
# during model construction), this is because the random state will be different in different tensor parallel
|
||||
# device of the same data parallel group. The underlying reason is that the device of tp_rank = 0 will perform
|
||||
# additional random operations during the RowParallelLinear module building process.
|
||||
set_mode(ParallelMode.DUMMY)
|
||||
|
||||
seeds = get_seeds()
|
||||
seed_str = ", ".join([f"{k}: {v}" for k, v in seeds.items()])
|
||||
|
|
|
|||
|
|
@ -9,6 +9,8 @@ from enum import Enum
|
|||
|
||||
import torch.distributed as dist
|
||||
|
||||
from internlm.utils.timeout import LLM_NCCL_TIMEOUT
|
||||
|
||||
|
||||
# parallel modes
|
||||
class ParallelMode(Enum):
|
||||
|
|
@ -40,6 +42,9 @@ class ParallelMode(Enum):
|
|||
# then manual communication only happens between inter-fsdp-modules, while intra-modules reduction is done by fsdp
|
||||
ZERO3_DP = "zero3_dp"
|
||||
|
||||
# dummy mode, only used during model construction
|
||||
DUMMY = "dummy"
|
||||
|
||||
|
||||
class ProcessGroupInitializer(ABC):
|
||||
"""An object, knowing the parallelism configuration, that initializes parallel groups.
|
||||
|
|
@ -111,9 +116,13 @@ class Initializer_Data(ProcessGroupInitializer):
|
|||
|
||||
for i in range(self.rank_num_per_dp_group):
|
||||
ranks = [i + j * self.rank_num_per_dp_group for j in range(self.data_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
@ -163,9 +172,13 @@ class Initializer_Model(ProcessGroupInitializer):
|
|||
|
||||
for i in range(self.num_group):
|
||||
ranks = [i * self.rank_num_per_group + j for j in range(self.rank_num_per_group)]
|
||||
group = dist.new_group(ranks)
|
||||
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
@ -223,9 +236,13 @@ class Initializer_Pipeline(ProcessGroupInitializer):
|
|||
)
|
||||
)
|
||||
pipe_group_size = len(ranks)
|
||||
pipe_group = dist.new_group(ranks)
|
||||
pipe_group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else pipe_group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else pipe_group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
@ -273,9 +290,13 @@ class Initializer_Tensor(ProcessGroupInitializer):
|
|||
|
||||
for i in range(self.num_tensor_parallel_group):
|
||||
ranks = [i * self.tensor_parallel_size + j for j in range(self.tensor_parallel_size)]
|
||||
group = dist.new_group(ranks)
|
||||
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
@ -329,9 +350,13 @@ class Initializer_Zero1(ProcessGroupInitializer):
|
|||
i + (j * self.zero1_parallel_size + k) * self.rank_num_per_dp_group
|
||||
for k in range(self.zero1_parallel_size)
|
||||
]
|
||||
group = dist.new_group(ranks)
|
||||
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
@ -378,9 +403,13 @@ class Initializer_Nettest(ProcessGroupInitializer):
|
|||
rank = i * self.nettest_parallel_size + j
|
||||
if rank < self.world_size:
|
||||
ranks.append(rank)
|
||||
group = dist.new_group(ranks)
|
||||
group = dist.new_group(ranks, timeout=LLM_NCCL_TIMEOUT)
|
||||
if use_cpu:
|
||||
group_cpu = dist.new_group(ranks, backend="gloo") if dist.get_backend() != "gloo" else group
|
||||
group_cpu = (
|
||||
dist.new_group(ranks, backend="gloo", timeout=LLM_NCCL_TIMEOUT)
|
||||
if dist.get_backend() != "gloo"
|
||||
else group
|
||||
)
|
||||
else:
|
||||
group_cpu = None
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import torch
|
|||
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
|
|
@ -126,6 +127,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, loss
|
||||
|
||||
@llm_timeout(func_name="nopp_forward_backward_step")
|
||||
def forward_backward_step(
|
||||
self,
|
||||
engine: Engine,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from internlm.core.engine import Engine
|
|||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.common import get_current_device, move_to_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
|
|
@ -592,6 +593,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
return output, label, accum_loss
|
||||
|
||||
@llm_timeout(func_name="nointerleaved_forward_backward_step")
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
"""Runs non-interleaved 1F1B schedule, with communication between pipeline stages.
|
||||
Returns a tuple with losses if the last stage, an empty tuple otherwise.
|
||||
|
|
@ -1247,6 +1249,7 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
# 3. Cooldown
|
||||
self._run_cooldown_loop(engine, num_microsteps, num_1f1b_micropairs=num_1f1b_micropairs)
|
||||
|
||||
@llm_timeout(func_name="interleaved_forward_backward_step")
|
||||
def forward_backward_step(self, engine, data_iter, forward_only=False, return_loss=True, return_output_label=True):
|
||||
"""Run interleaved 1F1B schedule (model split into model chunks), with
|
||||
communication between pipeline stages as needed.
|
||||
|
|
|
|||
|
|
@ -23,7 +23,15 @@ class TrainState:
|
|||
train_dl (DataLoader): The DataLoader object used for training.
|
||||
"""
|
||||
|
||||
def __init__(self, config) -> None:
|
||||
def __init__(self, config, batch_sampler) -> None:
|
||||
"""
|
||||
Args:
|
||||
config (Config): internlm config
|
||||
batch_sampler (torch.utils.data.Sampler): Because the dataloader loading is
|
||||
asynchronous and prefetched, the batch_sampler state maintained inside the
|
||||
dataloader is ahead of the actual training progress, so we copy the
|
||||
batch_sampler as the anchor point of ckpt reload.
|
||||
"""
|
||||
# The number of batches produced by the data iterator
|
||||
self.batch_count: int = 0
|
||||
# Used to store the number of samples consumed in the current epoch
|
||||
|
|
@ -43,9 +51,20 @@ class TrainState:
|
|||
|
||||
self.tensorboard_folder = config.tensorboard_folder
|
||||
|
||||
def init_batch_sampler(self, train_dl):
|
||||
# Copy of the batch sampler from the DataLoader
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
# learning rate
|
||||
self.lr = config.adam.lr
|
||||
|
||||
# sampler state
|
||||
if batch_sampler:
|
||||
self.init_batch_sampler(batch_sampler)
|
||||
|
||||
def init_batch_sampler(self, batch_sampler):
|
||||
"""
|
||||
Args:
|
||||
batch_sampler (torch.utils.data.Sampler): sampler.
|
||||
"""
|
||||
# make a copy of batch_sampler.
|
||||
self.batch_sampler = batch_sampler.copy()
|
||||
# Iterator for the batch sampler
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
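Since `TrainState` now receives the batch sampler directly, the construction site changes as well; a short usage sketch (variable names are illustrative):

```python
# Copy the sampler state at construction time so checkpoints record the true
# training progress rather than the prefetched position of the dataloader.
train_state = TrainState(gpc.config, train_dl.batch_sampler)
```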
|
||||
|
||||
|
|
@ -61,26 +80,22 @@ class TrainState:
|
|||
|
||||
return json.dumps(info, indent=4, sort_keys=True)
|
||||
|
||||
def load_state_dict(self, other_stuffs, train_dl):
|
||||
def load_state_dict(self, other_stuffs):
|
||||
"""
|
||||
Resumes training from a checkpoint.
|
||||
|
||||
Args:
|
||||
other_stuffs (dict): Other information needed to resume training.
|
||||
train_dl (DataLoader): The DataLoader object used for training.
|
||||
"""
|
||||
|
||||
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
|
||||
self.num_consumed_samples_in_epoch = other_stuffs["num_consumed_samples_in_epoch"]
|
||||
self.num_consumed_tokens = other_stuffs["num_consumed_tokens"]
|
||||
self.inf_nan_skip_batches = other_stuffs["inf_nan_skip_batches"]
|
||||
# compatible with previous checkpoints without this parameter
|
||||
self.step_count = other_stuffs.get("step_count", other_stuffs["batch_count"]) + 1
|
||||
|
||||
# track the actual updates of sampler when using weighted sampling
|
||||
if hasattr(self, "batch_sampler"):
|
||||
self.batch_sampler = train_dl.batch_sampler.copy()
|
||||
self.batch_sampler_iter = iter(self.batch_sampler)
|
||||
# Because the ckpt save occurs after updating 'step_count',
|
||||
# there is no need to increment 'step_count' here (Does our step count start from 0 ?),
|
||||
# However, 'batch_count' is updated before ckpt storage, so it needs to be incremented by 1 when resuming.
|
||||
self.batch_count = other_stuffs["batch_count"] + 1 # here you need to shift a batch backward
|
||||
self.step_count = other_stuffs.get("step_count", self.batch_count)
|
||||
|
||||
# resume tensorboard from older tensorboard_folder
|
||||
self.resume_tb_folder = other_stuffs.get("tensorboard_folder", None)
|
||||
|
|
|
|||
|
|
@ -10,9 +10,10 @@ import torch
|
|||
|
||||
from internlm.core.context import Config
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.monitor import initialize_light_monitor
|
||||
from internlm.utils.common import get_master_node
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.storage_manager import init_storage_manager
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
|
@ -122,7 +123,7 @@ def args_sanity_check():
|
|||
# processing the checkpoint config
|
||||
ckpt = gpc.config.ckpt
|
||||
if "enable_save_ckpt" not in ckpt:
|
||||
ckpt._add_item("enable_save_ckpt", False)
|
||||
ckpt._add_item("enable_save_ckpt", True)
|
||||
|
||||
# Saving checkpoint args.
|
||||
if ckpt.enable_save_ckpt:
|
||||
|
|
@ -148,9 +149,6 @@ def args_sanity_check():
|
|||
if not ckpt.async_upload:
|
||||
ckpt._add_item("async_upload_tmp_folder", None)
|
||||
|
||||
if "snapshot_ckpt_folder" not in ckpt:
|
||||
ckpt._add_item("snapshot_ckpt_folder", os.path.join(ckpt.save_ckpt_folder, "snapshot"))
|
||||
|
||||
if "oss_snapshot_freq" not in ckpt:
|
||||
ckpt._add_item("oss_snapshot_freq", float("inf")) # if oss_snapshot_freq not given, we disable.
|
||||
else:
|
||||
|
|
@ -160,44 +158,23 @@ def args_sanity_check():
|
|||
ckpt._add_item("async_upload", False)
|
||||
ckpt._add_item("async_upload_tmp_folder", None)
|
||||
ckpt._add_item("snapshot_ckpt_folder", None)
|
||||
ckpt._add_item("snapshot_ckpt_folder", None)
|
||||
|
||||
# Loading checkpoint args.
|
||||
if "load_model_only_folder" not in ckpt:
|
||||
ckpt._add_item("load_model_only_folder", None)
|
||||
|
||||
if "load_ckpt_folder" not in ckpt:
|
||||
ckpt._add_item("load_ckpt_folder", None)
|
||||
|
||||
if "load_optimizer" not in ckpt:
|
||||
ckpt._add_item("load_optimizer", True)
|
||||
|
||||
if "stop_file_path" not in ckpt:
|
||||
ckpt._add_item("stop_file_path", None)
|
||||
|
||||
if "load_given_ckpt" not in ckpt:
|
||||
# If 'load_given_ckpt' is not given, we set it to False, so internlm can have opportunity
|
||||
if "auto_resume" not in ckpt:
|
||||
# If 'auto_resume' is not given, we set it to True, so internlm can have opportunity
|
||||
# to auto-load latest checkpoint.
|
||||
ckpt._add_item("load_given_ckpt", False)
|
||||
|
||||
if ckpt.load_given_ckpt:
|
||||
# Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
|
||||
if ckpt.load_ckpt_folder and ckpt.load_model_only_folder:
|
||||
logger.warning(
|
||||
"Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
|
||||
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
||||
)
|
||||
ckpt.load_model_only_folder = None
|
||||
ckpt._add_item("auto_resume", True)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info("+" * 15 + " Ckpt Info " + "+" * 15) # pylint: disable=W1201
|
||||
logger.info(f"is enable save ckpt: {ckpt.enable_save_ckpt}")
|
||||
logger.info(f"save_ckpt_folder: {ckpt.save_ckpt_folder}")
|
||||
logger.info(f"checkpoint_every: {ckpt.checkpoint_every}")
|
||||
logger.info(f"load_given_ckpt: {ckpt.load_given_ckpt}")
|
||||
|
||||
# initialization storage manager
|
||||
init_storage_manager(ckpt)
|
||||
|
||||
# tensorboard writer config
|
||||
if "enable_tb" not in gpc.config:
|
||||
|
|
@ -288,9 +265,22 @@ and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
|||
gpc.config.parallel.sequence_parallel is True and gpc.config.model.use_flash_attn is False
|
||||
), "sequence parallel does not support use_flash_attn=False"
|
||||
|
||||
# feishu webhook address for alerting
|
||||
if "alert_address" not in gpc.config:
|
||||
gpc.config._add_item("alert_address", None)
|
||||
# monitoring default config
|
||||
monitor_default_config = {
|
||||
"alert_address": None, # compatible with old alert config
|
||||
"monitor": { # new monitoring config
|
||||
"alert": {"enable_feishu_alert": False, "feishu_alert_address": None, "light_monitor_address": None}
|
||||
},
|
||||
}
|
||||
|
||||
for key, value in monitor_default_config.items():
|
||||
if key not in gpc.config:
|
||||
gpc.config._add_item(key, value)
|
||||
|
||||
alert = gpc.config.monitor.alert
|
||||
|
||||
if alert.enable_feishu_alert and not alert.feishu_alert_address and gpc.is_rank_for_log():
|
||||
logger.warning("alert is enable but alert_address is not set")
|
||||
|
||||
optim_ckpt = gpc.config.hybrid_zero_optimizer
|
||||
if "zero_overlap_communication" in optim_ckpt:
|
||||
|
|
@ -437,6 +427,7 @@ def launch_from_torch(
|
|||
)
|
||||
|
||||
|
||||
@llm_timeout(func_name="initialize_distributed_env")
|
||||
def initialize_distributed_env(
|
||||
config: str,
|
||||
launcher: str = "slurm",
|
||||
|
|
@ -470,3 +461,20 @@ def initialize_distributed_env(
|
|||
|
||||
if args_check:
|
||||
args_sanity_check()
|
||||
|
||||
# init light monitor client
|
||||
alert_config = gpc.config.monitor.alert
|
||||
if alert_config.enable_feishu_alert and gpc.is_rank_for_log():
|
||||
light_monitor_address = alert_config.light_monitor_address
|
||||
if light_monitor_address:
|
||||
initialize_light_monitor(light_monitor_address)
|
||||
else:
|
||||
logger.warning("monitor address is none, monitor could not be used!")
|
||||
|
||||
|
||||
def get_config_value(config, key, defalut):
|
||||
try:
|
||||
value = config[key]
|
||||
except KeyError:
|
||||
value = defalut
|
||||
return value
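`get_config_value` is the small helper used by the reworked `CheckpointManager` below to read optional config fields; a usage sketch:

```python
ckpt_cfg = gpc.config.ckpt
# Fall back to a default when the field is absent from the user config.
oss_snapshot_freq = get_config_value(ckpt_cfg, "oss_snapshot_freq", 50)
stop_file_path = get_config_value(ckpt_cfg, "stop_file_path", None)
```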
|
||||
|
|
|
|||
|
|
@ -0,0 +1,40 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from internlm.initialize.launch import get_config_value
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def auto_resume_sanity_check(ckpt_config):
|
||||
load_given_ckpt = get_config_value(ckpt_config, "load_given_ckpt", None)
|
||||
if load_given_ckpt is None:
|
||||
return True # default value is True
|
||||
else:
|
||||
return not load_given_ckpt
|
||||
|
||||
|
||||
def ckpt_info_sanity_check(ckpt_config):
|
||||
load_ckpt_folder = get_config_value(ckpt_config, "load_ckpt_folder", None)
|
||||
|
||||
load_model_only_folder = get_config_value(ckpt_config, "load_model_only_folder", None)
|
||||
|
||||
if load_model_only_folder is not None:
|
||||
assert (
|
||||
load_ckpt_folder is None
|
||||
), "Detect 'load_ckpt_folder' and 'load_model_only_folder' set at the same time, \
|
||||
and 'load_given_ckpt' is True, so internlm will load from 'load_ckpt_folder'"
|
||||
return dict(path=load_model_only_folder, content=("model",), ckpt_type="internlm")
|
||||
else:
|
||||
load_optimizer = get_config_value(ckpt_config, "load_optimizer", True)
|
||||
|
||||
if isinstance(load_ckpt_folder, str):
|
||||
if load_optimizer:
|
||||
return dict(path=load_ckpt_folder, content=("model", "sampler", "optimizer"), ckpt_type="internlm")
|
||||
else:
|
||||
return dict(path=load_ckpt_folder, content=("model", "sampler"), ckpt_type="internlm")
|
||||
elif load_ckpt_folder is None:
|
||||
return None
|
||||
else:
|
||||
assert f"Unsupport data type:'{type(load_ckpt_folder)}' for config.ckpt arg: 'load_ckpt_folder'"
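Taken together, this legacy check maps old-style ckpt fields onto the new `load_ckpt_info` dict; a usage sketch of the conversion (folder value is a placeholder):

```python
legacy_cfg = dict(load_ckpt_folder="local:llm_ckpts/", load_optimizer=True)
load_ckpt_info = ckpt_info_sanity_check(legacy_cfg)
# -> dict(path="local:llm_ckpts/", content=("model", "sampler", "optimizer"), ckpt_type="internlm")
```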
|
||||
|
|
@ -9,7 +9,7 @@ from flash_attn.ops.fused_dense import ColumnParallelLinear, RowParallelLinear
|
|||
from flash_attn.utils.distributed import all_reduce, reduce_scatter
|
||||
from torch import nn
|
||||
|
||||
from internlm.core.context import IS_TENSOR_PARALLEL, ParallelMode
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.model.utils import fused_dense_func_torch
|
||||
|
||||
|
|
@ -195,12 +195,6 @@ class FeedForward(nn.Module):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
# need to assign tp attribute so that colossalai know it is tensor parallel module
|
||||
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
for name in ["w1", "w2", "w3"]:
|
||||
for param in getattr(self, name).parameters():
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.w3(F.silu(self.w1(x)) * self.w2(x))
|
||||
|
|
|
|||
|
|
@ -127,6 +127,9 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
device=device,
|
||||
dtype=dtype,
|
||||
)
|
||||
for _, param in self.mlp.named_parameters():
|
||||
if gpc.get_world_size(ParallelMode.TENSOR) > 1:
|
||||
setattr(param, IS_TENSOR_PARALLEL, True)
|
||||
self.dropout2 = nn.Dropout(drop_rate)
|
||||
self.use_swiglu = use_swiglu
|
||||
self.use_scaled_init = use_scaled_init
|
||||
|
|
|
|||
|
|
@ -1,4 +1,11 @@
|
|||
from .alert import initialize_light_monitor, send_heartbeat
|
||||
from .monitor import initialize_monitor_manager, send_alert_message
|
||||
from .utils import set_env_var
|
||||
|
||||
__all__ = ["send_alert_message", "initialize_monitor_manager", "set_env_var"]
|
||||
__all__ = [
|
||||
"send_alert_message",
|
||||
"initialize_monitor_manager",
|
||||
"set_env_var",
|
||||
"initialize_light_monitor",
|
||||
"send_heartbeat",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,59 @@
|
|||
import json
|
||||
import math
|
||||
import os
|
||||
import re
|
||||
import time
|
||||
from typing import Dict
|
||||
|
||||
import requests
|
||||
|
||||
from internlm.utils.logger import get_logger
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def initialize_light_monitor(monitor_address: str = None):
|
||||
try:
|
||||
from uniscale_monitoring import init_monitor
|
||||
|
||||
init_monitor(monitor_address)
|
||||
except Exception as e:
|
||||
logger.warning(f"init monitor meet error: {e}")
|
||||
|
||||
|
||||
def send_heartbeat(msg_type: str, msg: Dict):
|
||||
def nan2none(v):
|
||||
if isinstance(v, float) and math.isnan(v):
|
||||
return None
|
||||
return v
|
||||
|
||||
try:
|
||||
from uniscale_monitoring import send_meta
|
||||
|
||||
data = {}
|
||||
for k, v in msg.items():
|
||||
if isinstance(v, Dict):
|
||||
for k1, v1 in v.items():
|
||||
new_k = f"{k}_{k1}".split(" ")[0]
|
||||
new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
|
||||
data[new_k] = nan2none(v1)
|
||||
else:
|
||||
new_k = k.split(" ")[0]
|
||||
new_k = re.sub(r"[^a-zA-Z0-9_]", "_", new_k)
|
||||
data[new_k] = nan2none(v)
|
||||
|
||||
if os.getenv("CLUSTER_NAME"):
|
||||
data.update({"cluster": os.getenv("CLUSTER_NAME")})
|
||||
if msg_type == "train_metrics":
|
||||
data.update({"msg_type": "train_metrics"})
|
||||
elif msg_type == "init_time":
|
||||
data.update({"msg_type": "init_time"})
|
||||
elif msg_type == "stage_time":
|
||||
data.update({"msg_type": "stage_time"})
|
||||
send_meta(data, timeout=0.1)
|
||||
except Exception as e:
|
||||
logger.warning(f"send heartbeat meet error: {e}")
|
||||
|
||||
|
||||
def send_feishu_msg_with_webhook(webhook: str, title: str, message: str):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -226,9 +226,7 @@ def initialize_monitor_manager(job_name: str = None, alert_address: str = None):
|
|||
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} is starting.")
|
||||
yield
|
||||
finally:
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address, message=f"Training in {socket.gethostname()} completed."
|
||||
)
|
||||
send_alert_message(address=alert_address, message=f"Training in {socket.gethostname()} completed.")
|
||||
monitor_manager.stop_monitor()
|
||||
else:
|
||||
yield
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer
|
||||
from .hybrid_zero_optim import HybridZeroOptimizer, FSDPadaptOptimizer, reload_zero_fp32_buff
|
||||
|
||||
__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer"]
|
||||
__all__ = ["HybridZeroOptimizer", "FSDPadaptOptimizer", "reload_zero_fp32_buff"]
|
||||
|
|
@ -32,6 +32,7 @@ from internlm.solver.optimizer.utils import (
|
|||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from .utils import compute_norm
|
||||
|
||||
|
|
@ -329,6 +330,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
self._param_store = ParameterStore(ParallelMode.ZERO1)
|
||||
self._grad_store = GradientStore(ParallelMode.DATA)
|
||||
self._bucket_store = BucketStore(ParallelMode.DATA)
|
||||
self._bucket_in_progress = []
|
||||
|
||||
# fp16 and fp32 params for mixed precision training
|
||||
self._fp16_param_groups = dict()
|
||||
|
|
@ -338,6 +340,8 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# self._overlap_communication = overlap_communication
|
||||
self._reduce_bucket_size = reduce_bucket_size
|
||||
|
||||
self._comm_bcast_stream = torch.cuda.Stream()
|
||||
|
||||
# gradient scaler
|
||||
self.grad_scaler = DynamicGradScaler(
|
||||
initial_scale=initial_scale,
|
||||
|
|
@ -436,13 +440,6 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
|
||||
self.skip_grad_reduce = False
|
||||
|
||||
# initialize communication stream for
|
||||
# communication-computation overlapping
|
||||
if self._overlap_sync_grad:
|
||||
self._comm_stream = torch.cuda.Stream()
|
||||
else:
|
||||
self._comm_stream = torch.cuda.current_stream()
|
||||
|
||||
# reduction hook is only used if overlapping communication
|
||||
# if it is stage 1 without overlapping, no hook will be attached
|
||||
if self._overlap_sync_grad:
|
||||
|
|
@ -588,34 +585,41 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
def _reduce_grads_by_rank(self, reduce_rank, grads, bucket_size):
|
||||
grad_buckets_by_dtype = split_half_float_double(grads)
|
||||
|
||||
next_bucket_list = []
|
||||
# add parameters into bucket for reduction
|
||||
for tensor_list in grad_buckets_by_dtype:
|
||||
param_bucket = TensorBucket(size=bucket_size)
|
||||
for tensor in tensor_list:
|
||||
param_bucket.add_to_bucket(tensor, allow_oversize=True)
|
||||
if param_bucket.is_full_or_oversized():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
param_bucket.empty()
|
||||
if not param_bucket.is_empty():
|
||||
self._reduce_and_copy(bucket=param_bucket, reduce_rank=reduce_rank)
|
||||
next_bucket_list.append(param_bucket)
|
||||
|
||||
# wait for the completion of the previous bucket list reduction, and do unflatten_and_copy()
|
||||
# here we can also overlap the communication with some memcpy operation caused by bucket.flatten()
|
||||
for bucket in self._bucket_in_progress:
|
||||
bucket.commu_handle.wait()
|
||||
bucket.unflatten_and_copy()
|
||||
bucket.empty()
|
||||
self._bucket_in_progress = []
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# after the completion of bucket list reduction, add new buckets into _bucket_in_progress
|
||||
self._bucket_in_progress = next_bucket_list.copy()
|
||||
|
||||
def _reduce_and_copy(self, bucket: TensorBucket, reduce_rank):
|
||||
if self._overlap_sync_grad:
|
||||
self._comm_stream.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
# flatten the tensors and do allreduce
|
||||
bucket.flatten()
|
||||
bucket.commu_handle = reduce_tensor(
|
||||
tensor=bucket.get_flat_tensor(),
|
||||
dtype=None,
|
||||
dst_rank=reduce_rank,
|
||||
parallel_mode=ParallelMode.DATA,
|
||||
)
|
||||
|
||||
with torch.cuda.stream(self._comm_stream):
|
||||
flat = bucket.flatten()
|
||||
reduced_flat = reduce_tensor(
|
||||
tensor=flat,
|
||||
dtype=self.dtype,
|
||||
dst_rank=reduce_rank,
|
||||
parallel_mode=ParallelMode.DATA,
|
||||
)
|
||||
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._zero_local_rank:
|
||||
bucket.unflatten_and_copy(reduced_flat)
|
||||
# update the reduced tensor
|
||||
if reduce_rank is None or reduce_rank == self._zero_local_rank:
|
||||
bucket.set_unflatten_and_copy_flag(flag=True)
|
||||
|
||||
def _has_inf_or_nan(self, tensor):
|
||||
try:
|
||||
|
|
@ -711,6 +715,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
return norm
|
||||
|
||||
@llm_timeout(func_name="optim_step")
|
||||
def step(self, closure=None):
|
||||
"""Performs a single optimization step.
|
||||
|
||||
|
|
@ -739,10 +744,13 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
groups_norms.append(self._compute_norm_with_stage(group_id=group_id))
|
||||
|
||||
# clear reduced grads
|
||||
if self._overlap_sync_grad:
|
||||
# grads in the last bucket is reduced
|
||||
self._comm_stream.synchronize()
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
# grads in the last bucket is reduced
|
||||
for bucket in self._bucket_in_progress:
|
||||
bucket.commu_handle.wait()
|
||||
bucket.unflatten_and_copy()
|
||||
bucket.empty()
|
||||
self._bucket_in_progress = []
|
||||
self._param_store.clear_grads_of_previous_reduced_params()
|
||||
|
||||
# compute norm for gradients in the last bucket
|
||||
total_norms = {}
|
||||
|
|
@ -783,7 +791,7 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if gpc.is_rank_for_log():
|
||||
logger.warning("Overflow occurs, please check it.")
|
||||
send_alert_message(
|
||||
address=gpc.config.alert_address,
|
||||
address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
message="Overflow occurs, please check it.",
|
||||
)
|
||||
self._grad_store._averaged_gradients = dict()
|
||||
|
|
@ -829,7 +837,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
if gpc.config.model.dtype is not torch.float32:
|
||||
if len(single_grad_partition_groups) != 0 and self._clip_grad_norm > 0:
|
||||
self._unscale_and_clip_grads(
|
||||
single_grad_partition_groups, list(global_norm_groups.values()), loss_scale
|
||||
single_grad_partition_groups,
|
||||
list(global_norm_groups.values()),
|
||||
loss_scale,
|
||||
)
|
||||
|
||||
# update the parameters
|
||||
|
|
@ -850,7 +860,9 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
fp32_param = self._fp32_flat_param_groups_of_current_rank[group_id]
|
||||
fp16_param.data.copy_(fp32_param)
|
||||
|
||||
self.broadcast_params()
|
||||
torch.cuda.synchronize()
|
||||
with torch.cuda.stream(self._comm_bcast_stream):
|
||||
self.broadcast_params()
|
||||
|
||||
timer("step").stop()
|
||||
|
||||
|
|
@ -976,3 +988,17 @@ class HybridZeroOptimizer(BaseOptimizer):
|
|||
|
||||
if "zero_devide_optim_plan" in states:
|
||||
self.params_per_rank_id_dict = states["zero_devide_optim_plan"]
|
||||
|
||||
|
||||
def reload_zero_fp32_buff(optimizer):
|
||||
# If we use the AMP optimizer, we need to update its fp32 buffer with the newly loaded weight values.
|
||||
# Otherwise, we must ensure that model weights are loaded before zero is initialized.
|
||||
if isinstance(optimizer, HybridZeroOptimizer):
|
||||
for group_id, param_group in enumerate(optimizer.optim.param_groups):
|
||||
if optimizer.param_group_has_params[group_id]:
|
||||
# flatten fp16 params have already been updated by 'load_model_checkpoint'
|
||||
fp16_flat_current_rank = optimizer._param_store.get_flat_fp16_param_by_rank_group(
|
||||
optimizer._zero_local_rank, group_id
|
||||
)
|
||||
# param_group["params"] is fp32 flatten optimizer states of this zero rank.
|
||||
param_group["params"][0].data.copy_(fp16_flat_current_rank.float())
|
||||
|
|
|
|||
|
|
@ -249,11 +249,17 @@ class ParameterStore(BaseStore):
|
|||
if not last_bucket:
|
||||
if group_id not in self._former_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._former_bucket_reduced_param[group_id], self._former_bucket_reduced_grad[group_id]
|
||||
return (
|
||||
self._former_bucket_reduced_param[group_id],
|
||||
self._former_bucket_reduced_grad[group_id],
|
||||
)
|
||||
else:
|
||||
if group_id not in self._last_bucket_reduced_param:
|
||||
return [], []
|
||||
return self._last_bucket_reduced_param[group_id], self._last_bucket_reduced_grad[group_id]
|
||||
return (
|
||||
self._last_bucket_reduced_param[group_id],
|
||||
self._last_bucket_reduced_grad[group_id],
|
||||
)
|
||||
|
||||
def reset_reduced_data_for_compute_norm(self):
|
||||
self._former_bucket_reduced_param = {}
|
||||
|
|
@ -277,6 +283,9 @@ class TensorBucket:
|
|||
self._max_size = size
|
||||
self._current_size = 0
|
||||
self._bucket = []
|
||||
self._flat_tensor = None
|
||||
self._unflatten_and_copy_flag = False
|
||||
self.commu_handle = None
|
||||
|
||||
@property
|
||||
def max_size(self):
|
||||
|
|
@ -292,6 +301,15 @@ class TensorBucket:
|
|||
def is_empty(self):
|
||||
return len(self._bucket) == 0
|
||||
|
||||
def set_unflatten_and_copy_flag(self, flag):
|
||||
self._unflatten_and_copy_flag = flag
|
||||
|
||||
def get_unflatten_and_copy_flag(self):
|
||||
return self._unflatten_and_copy_flag
|
||||
|
||||
def get_flat_tensor(self):
|
||||
return self._flat_tensor
|
||||
|
||||
def add_to_bucket(self, tensor, allow_oversize=False):
|
||||
tensor_size = tensor.numel()
|
||||
|
||||
|
|
@ -312,11 +330,14 @@ class TensorBucket:
|
|||
def empty(self):
|
||||
self._bucket = []
|
||||
self._size = 0
|
||||
self._flat_tensor = None
|
||||
self.commu_handle = None
|
||||
|
||||
def flatten(self):
|
||||
return _flatten_dense_tensors(self._bucket)
|
||||
self._flat_tensor = _flatten_dense_tensors(self._bucket)
|
||||
|
||||
def unflatten_and_copy(self, flat_tensor):
|
||||
unflattened_tensor_list = _unflatten_dense_tensors(flat_tensor, self._bucket)
|
||||
for old, new in zip(self._bucket, unflattened_tensor_list):
|
||||
old.copy_(new)
|
||||
def unflatten_and_copy(self):
|
||||
if self._unflatten_and_copy_flag:
|
||||
unflattened_tensor_list = _unflatten_dense_tensors(self._flat_tensor, self._bucket)
|
||||
for old, new in zip(self._bucket, unflattened_tensor_list):
|
||||
old.copy_(new)
|
||||
|
|
|
|||
|
|
@ -95,37 +95,34 @@ def reduce_tensor(tensor, dtype=None, dst_rank=None, parallel_mode=ParallelMode.
|
|||
:type parallel_mode: ParallelMode, optional
|
||||
"""
|
||||
# use the original dtype
|
||||
if dtype is None:
|
||||
dtype = tensor.dtype
|
||||
# if dtype is None:
|
||||
assert dtype is None
|
||||
dtype = tensor.dtype
|
||||
|
||||
# cast the data to specified dtype for reduce/all-reduce
|
||||
if tensor.dtype != dtype:
|
||||
tensor_to_reduce = tensor.to(dtype)
|
||||
else:
|
||||
tensor_to_reduce = tensor
|
||||
# if tensor.dtype != dtype:
|
||||
# tensor_to_reduce = tensor.to(dtype)
|
||||
# else:
|
||||
# tensor_to_reduce = tensor
|
||||
|
||||
world_size = gpc.get_world_size(parallel_mode)
|
||||
# world_size = gpc.get_world_size(parallel_mode)
|
||||
# tensor.div_(world_size)
|
||||
group = gpc.get_group(parallel_mode)
|
||||
tensor_to_reduce.div_(world_size)
|
||||
|
||||
# if rank is None, all reduce will be used
|
||||
# else, reduce is used
|
||||
use_all_reduce = dst_rank is None
|
||||
|
||||
if use_all_reduce:
|
||||
dist.all_reduce(tensor_to_reduce, group=group)
|
||||
handle = dist.all_reduce(tensor=tensor, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True)
|
||||
else:
|
||||
ranks_in_group = gpc.get_ranks_in_group(parallel_mode)
|
||||
global_rank = ranks_in_group[dst_rank]
|
||||
dist.reduce(tensor=tensor_to_reduce, dst=global_rank, group=group)
|
||||
handle = dist.reduce(
|
||||
tensor=tensor, dst=global_rank, group=group, op=torch.distributed.ReduceOp.AVG, async_op=True
|
||||
)
|
||||
|
||||
# recover the original dtype
|
||||
if tensor.dtype != dtype and tensor is not tensor_to_reduce:
|
||||
local_rank = gpc.get_local_rank(parallel_mode)
|
||||
if use_all_reduce or dst_rank == local_rank:
|
||||
tensor.copy_(tensor_to_reduce)
|
||||
|
||||
return tensor
|
||||
return handle
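The rewritten `reduce_tensor` now returns the communication handle instead of the tensor and replaces the manual `div_(world_size)` with `ReduceOp.AVG`, so callers must wait on the handle before reading the result; a minimal usage sketch (`flat_grad` is a placeholder):

```python
handle = reduce_tensor(tensor=flat_grad, dtype=None, dst_rank=None, parallel_mode=ParallelMode.DATA)
# ... overlap other work here ...
handle.wait()  # flat_grad now holds the data-parallel average (requires NCCL support for AVG)
```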
|
||||
|
||||
|
||||
def has_inf_or_nan(tensor):
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ from torch.utils.data import ConcatDataset, DataLoader
|
|||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.random import set_mode
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.data.batch_sampler import StaticBatchSampler, get_dpsampler_dataloader
|
||||
|
|
@ -24,7 +25,7 @@ from internlm.data.packed_dataset import (
|
|||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.monitor import set_env_var
|
||||
from internlm.monitor import send_heartbeat, set_env_var
|
||||
from internlm.monitor.monitor import monitor_manager as mm
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.lr_scheduler import FineTuneCosineAnnealingWarmupLR
|
||||
|
|
@ -39,6 +40,7 @@ from internlm.utils.parallel import (
|
|||
sync_model_param_within_tp,
|
||||
)
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
||||
from torch.distributed.fsdp.fully_sharded_data_parallel import (
|
||||
|
|
@ -54,6 +56,7 @@ from internlm.model.modeling_internlm import PackedFlashBaseLayer1D, PackedFlash
|
|||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
@llm_timeout(func_name="initialize_model")
|
||||
def initialize_model():
|
||||
"""
|
||||
Initialize model with Automatic Mixed Precision.
|
||||
|
|
@ -93,6 +96,10 @@ def initialize_model():
|
|||
# the same across tensor parallelism.
|
||||
sync_model_param_within_tp(model)
|
||||
|
||||
# Change random state mode to ParallelMode.DATA after model is built, guaranteeing the random
|
||||
# state in the same dp group are all the same.
|
||||
set_mode(ParallelMode.DATA)
|
||||
|
||||
return model
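Together with the seed hunks earlier in this diff, random-state handling now follows a two-phase pattern; a condensed, non-literal sketch of the order of operations:

```python
# In ParallelContext.set_seed: register the states, but build the model under DUMMY,
# because tp_rank 0 consumes extra random numbers while constructing RowParallelLinear.
add_seed(ParallelMode.DATA, dp_seed)
add_seed(ParallelMode.DUMMY, dp_seed)
add_seed(ParallelMode.TENSOR, tp_seed)
set_mode(ParallelMode.DUMMY)

model = build_model()  # placeholder for the actual model construction

# In initialize_model: switch to DATA only after the model is built, so every rank in
# the same data-parallel group ends up with an identical random state.
set_mode(ParallelMode.DATA)
```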
|
||||
|
||||
|
||||
|
|
@ -114,6 +121,7 @@ def warp_FSDP_model(model: Union[nn.Module, nn.ModuleList]):
|
|||
return model
|
||||
|
||||
|
||||
@llm_timeout(func_name="initialize_optimizer")
|
||||
def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
||||
"""
|
||||
Initialize optimizer.
|
||||
|
|
@ -158,6 +166,7 @@ def initialize_optimizer(model: Union[nn.Module, nn.ModuleList]):
|
|||
return optimizer, beta2_scheduler, lr_scheduler
|
||||
|
||||
|
||||
@llm_timeout(func_name="get_train_data_loader")
|
||||
def get_train_data_loader(
|
||||
num_worker: int = 0, dataset_generate_func: Callable = None, train_sampler=None, train_collate_fn=None
|
||||
):
|
||||
|
|
@ -237,6 +246,7 @@ def get_train_data_loader(
|
|||
return train_dl, dataset_types
|
||||
|
||||
|
||||
@llm_timeout(func_name="get_validation_data_loader")
|
||||
def get_validation_data_loader(
|
||||
num_worker: int = 0, dataset_generate_func: Callable = None, val_collate_fn=None, dataloader_func=None
|
||||
):
|
||||
|
|
@ -298,6 +308,7 @@ def get_validation_data_loader(
|
|||
return val_dls
|
||||
|
||||
|
||||
@llm_timeout(func_name="load_new_batch")
|
||||
def load_new_batch(train_dl: DataLoader, train_iter: Iterable, train_state: TrainState):
|
||||
"""
|
||||
Load and return the new batch data based on training data loader.
|
||||
|
|
@ -355,6 +366,7 @@ def initialize_llm_profile(profiling: bool = False, start_time: str = None):
|
|||
)
|
||||
|
||||
|
||||
@llm_timeout(func_name="record_current_batch_training_metrics")
|
||||
def record_current_batch_training_metrics(
|
||||
get_tflops_func,
|
||||
logger,
|
||||
|
|
@ -440,6 +452,9 @@ def record_current_batch_training_metrics(
|
|||
else:
|
||||
writer.add_scalar(key=key, value=value, step=train_state.step_count)
|
||||
|
||||
if gpc.config.monitor.alert.get("light_monitor_address", None) and batch_count % 50 == 0:
|
||||
send_heartbeat("train_metrics", infos)
|
||||
|
||||
if update_panel:
|
||||
# metrics shown with dashboard panels
|
||||
panel_metrics = {
|
||||
|
|
@ -465,4 +480,8 @@ def record_current_batch_training_metrics(
|
|||
logger.info(line)
|
||||
|
||||
# if loss spike occurs, send alert info to feishu
|
||||
mm.monitor_loss_spike(alert_address=gpc.config.alert_address, step_count=batch_count, cur_step_loss=loss.item())
|
||||
mm.monitor_loss_spike(
|
||||
alert_address=gpc.config.monitor.alert.feishu_alert_address,
|
||||
step_count=batch_count,
|
||||
cur_step_loss=loss.item(),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ def initialize_uniscale_logger(
|
|||
job_name and launch_time and file_name
|
||||
), "If file_path is None, job_name, launch_time and file_name must be setted."
|
||||
log_file_name = file_name
|
||||
log_folder = os.path.join(job_name, launch_time, "logs")
|
||||
log_folder = os.path.join("RUN", job_name, launch_time, "logs")
|
||||
log_dir = os.path.join(log_folder, log_file_name)
|
||||
file_path = log_dir
|
||||
|
||||
|
|
|
|||
|
|
@ -3,37 +3,136 @@
|
|||
|
||||
import copy
|
||||
import fcntl
|
||||
import inspect
|
||||
import os
|
||||
import socket
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Dict
|
||||
from typing import Callable, Dict, Union
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.initialize.launch import get_config_value
|
||||
from internlm.initialize.legacy.launch import (
|
||||
auto_resume_sanity_check,
|
||||
ckpt_info_sanity_check,
|
||||
)
|
||||
from internlm.monitor import send_alert_message
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer
|
||||
from internlm.solver.optimizer import HybridZeroOptimizer, reload_zero_fp32_buff
|
||||
from internlm.utils.common import get_current_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.megatron_timers import megatron_timer as timer
|
||||
from internlm.utils.storage_manager import (
|
||||
get_fns,
|
||||
get_storage_manager,
|
||||
init_storage_manager,
|
||||
llm_load,
|
||||
llm_save,
|
||||
try_get_storage_backend,
|
||||
)
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class CheckpointType(Enum):
|
||||
class CheckpointSaveType(Enum):
|
||||
NORMAL_CHECKPOINT = 1
|
||||
SNAPSHOT_CHECKPOINT = 2
|
||||
|
||||
|
||||
class CheckpointLoadType(Enum):
|
||||
INTERNLM = "internlm"
|
||||
|
||||
|
||||
# The load method implemented by internlm by default does not use string representation types,
|
||||
# but uses enumeration types defined in advance.
|
||||
LOAD_TYPE_DICT = {
|
||||
"internlm": CheckpointLoadType.INTERNLM,
|
||||
}
|
||||
|
||||
|
||||
class CheckpointLoadContent:
|
||||
MODEL = "model"
|
||||
SAMPLER = "sampler"
|
||||
OPIMIZER = "optimizer"
|
||||
SCHEDULAER = "scheduler"
|
||||
|
||||
|
||||
class CheckpointLoadMethod:
|
||||
"""The registration class of the checkpoint loading method,
|
||||
users can define their own custom ckpt loading methods."""
|
||||
|
||||
LOAD_FUNC_SIG = None
|
||||
LOAD_TYPE_FUNC = {}
|
||||
|
||||
@staticmethod
|
||||
def convet_load_type(load_type: str) -> Union[CheckpointLoadType, str]:
|
||||
if load_type.lower() in LOAD_TYPE_DICT:
|
||||
# The ckpt load method implemented by internlm by default.
|
||||
return LOAD_TYPE_DICT[load_type.lower()]
|
||||
else:
|
||||
# If it is a user-defined field, we do not do any conversion and represent it as a string.
|
||||
return load_type
|
||||
|
||||
@staticmethod
|
||||
def register_ckpt_load_type(load_type: Union[str, CheckpointLoadType], load_func: Callable):
|
||||
if load_type in CheckpointLoadMethod.LOAD_TYPE_FUNC:
|
||||
logger.warning(f"{load_type} has aleady been registed!")
|
||||
return
|
||||
|
||||
CheckpointLoadMethod.LOAD_TYPE_FUNC.update({load_type: load_func})
|
||||
|
||||
if load_type == CheckpointLoadType.INTERNLM:
|
||||
CheckpointLoadMethod.LOAD_FUNC_SIG = inspect.signature(load_func)
|
||||
else:
|
||||
if inspect.signature(load_func) != CheckpointLoadMethod.LOAD_FUNC_SIG:
|
||||
logger.warning(
|
||||
f"registe load model ckpt signature is not same with: {CheckpointLoadMethod.LOAD_FUNC_SIG}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_ckpt_load_type_func(load_type: Union[str, CheckpointLoadType]):
|
||||
return CheckpointLoadMethod.LOAD_TYPE_FUNC[load_type]
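Custom checkpoint formats can be plugged in through `CheckpointLoadMethod`; a hedged sketch of registering one (the `"hf"` name and loader body are purely illustrative, and the signature must match `try_load_internlm_ckpt`):

```python
def try_load_hf_ckpt(ckpt_mm, load_info, train_state: TrainState):
    # Illustrative loader: read weights from load_info["path"] into ckpt_mm.model.
    ...
    return "model, "


CheckpointLoadMethod.register_ckpt_load_type("hf", try_load_hf_ckpt)
# The matching config entry would then be:
# load_ckpt_info = dict(path="local:hf_ckpt/", content=("model",), ckpt_type="hf")
```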
|
||||
|
||||
|
||||
class CheckpointLoadMask:
|
||||
"""
|
||||
According to the content field in the incoming ckpt_info, decide which components to load.
|
||||
"""
|
||||
|
||||
LOAD_CONTENT_DICT = {
|
||||
"model": CheckpointLoadContent.MODEL,
|
||||
"sampler": CheckpointLoadContent.SAMPLER,
|
||||
"optimizer": CheckpointLoadContent.OPIMIZER,
|
||||
"scheduler": CheckpointLoadContent.SCHEDULAER,
|
||||
}
|
||||
|
||||
def __init__(self, content: tuple) -> None:
|
||||
self.load_set = set(map(lambda x: x.lower(), content))
|
||||
if "all" in self.load_set:
|
||||
self.load_set = set(CheckpointLoadMask.LOAD_CONTENT_DICT.values())
|
||||
else:
|
||||
self.load_set = set(map(lambda x: CheckpointLoadMask.LOAD_CONTENT_DICT[x.lower()], content))
|
||||
|
||||
def need_load(self, content: CheckpointLoadContent):
|
||||
return content in self.load_set
|
||||
|
||||
def not_only_load(self, content: CheckpointLoadContent):
|
||||
return content in self.load_set and len(self.load_set) > 1
|
||||
|
||||
def only_load(self, content: CheckpointLoadContent):
|
||||
return set((content,)) == self.load_set
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.load_set}."
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.load_set}."
|
||||
|
||||
|
||||
def get_model_topology(model):
|
||||
"""
|
||||
Returns:
|
||||
|
|
@ -75,6 +174,66 @@ def get_state_dict(model):
|
|||
return states
|
||||
|
||||
|
||||
def try_load_internlm_ckpt(ckpt_mm, load_info, train_state: TrainState):
|
||||
load_content_str = ""
|
||||
load_ckpt_folder = load_info["path"]
|
||||
load_content: CheckpointLoadMask = load_info["content"]
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Try load_ckpt_folder: {load_ckpt_folder}")
|
||||
|
||||
if load_content.need_load(CheckpointLoadContent.MODEL):
|
||||
load_model_checkpoint(folder=load_ckpt_folder, model=ckpt_mm.model)
|
||||
load_content_str += f"{CheckpointLoadContent.MODEL}, "
|
||||
|
||||
if load_content.not_only_load(CheckpointLoadContent.MODEL):
|
||||
# load training states.
|
||||
load_context(load_ckpt_folder, train_state)
|
||||
|
||||
# load optimzier states.
|
||||
if load_content.need_load(CheckpointLoadContent.OPIMIZER):
|
||||
load_optimizer_checkpoint(load_ckpt_folder, ckpt_mm.optimizer)
|
||||
load_content_str += f"{CheckpointLoadContent.OPIMIZER}, "
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("CheckpointManager has no 'optimizer', skip reload optim checkpoint!")
|
||||
|
||||
# load lr scheduler states.
|
||||
if load_content.need_load(CheckpointLoadContent.SCHEDULAER):
|
||||
if ckpt_mm.lr_scheduler:
|
||||
load_scheduler(load_ckpt_folder, ckpt_mm.lr_scheduler, ckpt_mm.optimizer, train_state)
|
||||
load_content_str += f"{CheckpointLoadContent.SCHEDULAER}, "
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("CheckpointManager has no 'lr_scheduler', skip reload lr_scheduler checkpoint!")
|
||||
|
||||
# load dataloader sampler states.
|
||||
if load_content.need_load(CheckpointLoadContent.SAMPLER):
|
||||
if hasattr(train_state, "batch_sampler") and not isinstance(
|
||||
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
|
||||
):
|
||||
load_sampler(load_ckpt_folder, ckpt_mm.train_dl.batch_sampler)
|
||||
# track the actual updates of sampler when using weighted sampling
|
||||
train_state.init_batch_sampler(ckpt_mm.train_dl.batch_sampler)
|
||||
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning("CheckpointManager skip reload 'batch_sampler'")
|
||||
|
||||
# reload data state dict.
|
||||
if hasattr(train_state, "data_state_dict"):
|
||||
ckpt_mm.train_dl.dataset.load_state_dict(
|
||||
llm_load(os.path.join(load_ckpt_folder, "sampler_0.pt")), ckpt_path=load_ckpt_folder
|
||||
)
|
||||
load_content_str += f"{CheckpointLoadContent.SAMPLER}, "
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
"CheckpointManager has no 'data_state_dict', skip reload data_state_dict checkpoint!"
|
||||
)
|
||||
return load_content_str
|
||||
|
||||
|
||||
def save_model_checkpoint(folder, model):
|
||||
"""
|
||||
Save the model according to the relationship between tp and dp. The principle is that the data of each tp
|
||||
|
|
@ -257,15 +416,16 @@ def load_sampler(ckpt_path: str, sampler):
|
|||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_context(ckpt_path: str, train_dl, train_state: TrainState):
|
||||
def load_context(ckpt_path: str, train_state: TrainState):
|
||||
context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
|
||||
train_state.load_state_dict(context_stuffs, train_dl)
|
||||
train_state.load_state_dict(context_stuffs)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"reload train_state:{train_state}")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
|
||||
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
|
||||
def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainState):
|
||||
learning_rate = train_state.lr
|
||||
scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
|
||||
if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
|
||||
logger.warning(
|
||||
|
|
@@ -294,7 +454,17 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train
class CheckpointManager:
    """StorageManagerContext"""

    def __init__(self, ckpt_config, model, model_config=None, model_config_file=None, feishu_address=None) -> None:
    def __init__(
        self,
        ckpt_config,
        model,
        train_dl=None,
        optimizer=None,
        lr_scheduler=None,
        model_config=None,
        model_config_file=None,
        feishu_address=None,
    ) -> None:
        """
        CheckpointManager is used to decide when to store ckpt. If it is an asynchronous
        upload mode, you must call wait_async_upload_finish at the end of the program to wait
@@ -307,22 +477,44 @@ class CheckpointManager:
            lr_scheduler (object): lr_scheduler obj.
            model_config (dict): model config.
        """
        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
        self.checkpoint_every = ckpt_config.checkpoint_every
        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
        self.stop_file_path = ckpt_config.stop_file_path
        self.load_model_only_folder = ckpt_config.load_model_only_folder
        self.enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
        self.checkpoint_every = get_config_value(ckpt_config, "checkpoint_every", 100)
        self.save_ckpt_folder = get_config_value(ckpt_config, "save_ckpt_folder", None)
        self.oss_snapshot_freq: int = get_config_value(ckpt_config, "oss_snapshot_freq", 50)
        self.stop_file_path = get_config_value(ckpt_config, "stop_file_path", None)
        if self.save_ckpt_folder:
            self.snapshot_ckpt_folder = get_config_value(
                ckpt_config, "snapshot_ckpt_folder", os.path.join(self.save_ckpt_folder, "snapshot")
            )
            self.async_upload_tmp_folder = get_config_value(
                ckpt_config, "async_upload_tmp_folder", "/dev/shm/internlm_tmp_ckpt/"
            )
        else:
            self.snapshot_ckpt_folder = None
            self.async_upload_tmp_folder = None

        self.async_upload = get_config_value(ckpt_config, "async_upload", False)

        # initialization storage manager
        init_storage_manager(self.enable_save_ckpt, self.async_upload_tmp_folder, self.async_upload)

        self.feishu_address = feishu_address
        self.storage_manager = get_storage_manager()
        self.snapshot_counter = 0
        self.load_optimizer = gpc.config.ckpt.load_optimizer

        self.model = model
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler
        self.train_dl = train_dl
        self.model_config = model_config
        self.model_config_file = model_config_file

        # Register default internlm ckpt load type.
        self.defalut_load_type_func = {CheckpointLoadType.INTERNLM: try_load_internlm_ckpt}
        for ckpt_load_type in CheckpointLoadType:
            CheckpointLoadMethod.register_ckpt_load_type(ckpt_load_type, self.defalut_load_type_func[ckpt_load_type])

        # Init alert file.
        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
            if dir_path != "" and not os.path.exists(dir_path):
@@ -330,21 +522,35 @@ class CheckpointManager:
                with open(self.stop_file_path, "w", encoding="utf-8") as f:
                    f.write("0")

        if ckpt_config.load_given_ckpt is False:
            # Priority: load_given_ckpt(True) > latest_checkpoint > load_model_only_folder
            latest_ckpt_path = self.query_lastest_ckpt()
            if latest_ckpt_path:
                self.load_ckpt_folder = latest_ckpt_path
            else:
                # At this time, we have to load model init weights and train from step 0.
                self.load_ckpt_folder = self.load_model_only_folder
        else:
            self.load_ckpt_folder = ckpt_config.load_ckpt_folder
        self.load_ckpt_info = get_config_value(ckpt_config, "load_ckpt_info", None)
        if self.load_ckpt_info is None:  # (legacy): Try Compatible with old interfaces
            self.load_ckpt_info = ckpt_info_sanity_check(ckpt_config)

        if gpc.is_rank_for_log():
            logger.info(f"load_ckpt_folder will set to :'{self.load_ckpt_folder}'")
            if self.stop_file_path is None:
                logger.warning("no set stop_file_path, quit_signal_handler is disable")
        # Auto-reload latest checkpoint, it will overwrite the setting of 'load_ckpt_info'.
        self.auto_resume = get_config_value(ckpt_config, "auto_resume", None)
        if self.auto_resume is None:  # (legacy): Try Compatible with old interfaces
            self.auto_resume = auto_resume_sanity_check(ckpt_config)
        if self.auto_resume:
            self.load_ckpt_info = self.query_lastest_ckpt()

        if self.stop_file_path is None and gpc.is_rank_for_log():
            logger.warning("no set stop_file_path, quit_signal_handler is disable")

        # convert to internal representation
        if self.load_ckpt_info:
            assert (
                "path" in self.load_ckpt_info
                and "content" in self.load_ckpt_info
                and "ckpt_type" in self.load_ckpt_info
            ), "please set content in ckpt setting, eg: ckpt = dict(path='', content=['model'], ckpt_type='internlm')"

            # replace load_ckpt
            self.load_ckpt_info["content"] = CheckpointLoadMask(self.load_ckpt_info["content"])
            self.load_ckpt_info["ckpt_type"] = CheckpointLoadMethod.convet_load_type(self.load_ckpt_info["ckpt_type"])

        # test storage setting is ok.
        if self.enable_save_ckpt:
            self.try_ping_storage()
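Pieced together from the defaults read above and the "New interface format" test configs later in this diff, a minimal construction of the manager might look like the sketch below; the save folder is a placeholder and `model`/`optimizer`/`lr_scheduler`/`train_dl` stand for already-initialized training objects.

```python
from internlm.core.context.parallel_context import Config

# Sketch only: mirrors the new-interface entries in ckpt_config_list further
# down; the local folder name is made up.
ckpt_config = Config(
    dict(
        enable_save_ckpt=True,
        save_ckpt_folder="local:checkpoints/",
        checkpoint_every=50,
        oss_snapshot_freq=25,      # snapshot checkpoints between full checkpoints
        async_upload=False,        # async upload only applies to boto3 paths
        stop_file_path=None,
        auto_resume=True,          # pick up the latest checkpoint under save_ckpt_folder
    )
)

ckpt_manager = CheckpointManager(
    ckpt_config,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    train_dl=train_dl,
)
```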
    def quit_signal_handler(self, train_state) -> bool:
        """

@@ -358,7 +564,7 @@ class CheckpointManager:
        Returns:
            bool: whether to quit.
        """
        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT
        now_break, now_save_ckpt, save_type = False, False, CheckpointSaveType.NORMAL_CHECKPOINT

        if self.stop_file_path is None:
            return now_break, now_save_ckpt, save_type
@@ -389,24 +595,29 @@ now step_count is {train_state.step_count}",

        return now_break, now_save_ckpt, save_type

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
    def is_now_to_save_ckpt(self, train_state) -> (bool, CheckpointSaveType, bool):
        save_ckpts, save_type, now_break = False, CheckpointSaveType.NORMAL_CHECKPOINT, False
        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
            save_ckpts, save_type = True, CheckpointSaveType.SNAPSHOT_CHECKPOINT
        if train_state.step_count % self.checkpoint_every == 0:
            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
            save_ckpts, save_type = True, CheckpointSaveType.NORMAL_CHECKPOINT
        now_break, singal_save_ckpts, singal_save_type = self.quit_signal_handler(train_state)
        if save_ckpts is False:
            save_ckpts = singal_save_ckpts
            save_type = singal_save_type

        return save_ckpts, save_type, now_break

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type, now_break = self.is_now_to_save_ckpt(train_state)

        if save_ckpts:
            # Wait for the previous round of asynchronous upload storage to complete.
            self.storage_manager.wait()
            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
            if save_type == CheckpointSaveType.SNAPSHOT_CHECKPOINT:
                # Snapshot number, with only two snapshots written alternately.
                self.snapshot_counter = (self.snapshot_counter + 1) % 2
                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
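As a rough walkthrough of the decision logic above, using the intervals configured by the tests at the end of this diff (checkpoint_every=4, oss_snapshot_freq=2) and assuming no quit signal is raised:

```python
# Hypothetical trace of is_now_to_save_ckpt(train_state) by step_count:
#   step_count = 2 -> (True,  CheckpointSaveType.SNAPSHOT_CHECKPOINT, False)
#   step_count = 3 -> (False, CheckpointSaveType.NORMAL_CHECKPOINT,   False)  # no save this step
#   step_count = 4 -> (True,  CheckpointSaveType.NORMAL_CHECKPOINT,   False)  # full checkpoint wins
#   step_count = 6 -> (True,  CheckpointSaveType.SNAPSHOT_CHECKPOINT, False)
```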
@ -436,7 +647,7 @@ now step_count is {train_state.step_count}",
|
|||
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
|
||||
"""
|
||||
ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
|
||||
if len(ckpt_list) == 0:
|
||||
if ckpt_list is None or len(ckpt_list) == 0:
|
||||
return None, None
|
||||
|
||||
max_normal_step = 0
|
||||
|
|
@ -459,14 +670,16 @@ now step_count is {train_state.step_count}",
|
|||
ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
|
||||
ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
|
||||
max_step_0, max_step_1 = 0, 0
|
||||
for ckpt in ckpt_list_1:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
|
||||
for ckpt in ckpt_list_2:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
|
||||
if ckpt_list_1:
|
||||
for ckpt in ckpt_list_1:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
|
||||
if ckpt_list_2:
|
||||
for ckpt in ckpt_list_2:
|
||||
ckpt = ckpt.strip("/")
|
||||
if ckpt.endswith(".step"):
|
||||
max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))
|
||||
|
||||
snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
|
||||
snap_step = max(max_step_0, max_step_1)
|
||||
|
|
@ -476,11 +689,12 @@ now step_count is {train_state.step_count}",
|
|||
|
||||
def query_latest_snapshot_step_local(self):
|
||||
max_step, max_step_path = 0, None
|
||||
for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
|
||||
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
|
||||
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
|
||||
for fn in files:
|
||||
fn = fn.strip("/")
|
||||
if fn.endswith(".step"):
|
||||
# We assume that both normal ckpt and snapshot ckpt will store the '.step' file
|
||||
# We assume that both internlm ckpt and snapshot ckpt will store the '.step' file
|
||||
# as an integrity flag.
|
||||
step = int(fn.rsplit(".", maxsplit=1)[0])
|
||||
if max_step < step:
|
||||
|
|
@ -490,100 +704,55 @@ now step_count is {train_state.step_count}",
|
|||
return max_step_path, max_step
|
||||
|
||||
def query_lastest_ckpt(self):
|
||||
latest_checkpoint = None
|
||||
latest_ckpt, step = None, -1
|
||||
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
|
||||
if self.save_ckpt_folder:
|
||||
if self.save_ckpt_folder.startswith("boto3"):
|
||||
latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
|
||||
elif self.save_ckpt_folder.startswith("local"):
|
||||
latest_checkpoint, step = self.query_latest_snapshot_step_local()
|
||||
else:
|
||||
latest_checkpoint, step = None, 0
|
||||
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
|
||||
if backend == "boto3":
|
||||
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
|
||||
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
|
||||
latest_ckpt = ":".join(["boto3", latest_ckpt])
|
||||
elif backend == "local":
|
||||
latest_ckpt, step = self.query_latest_snapshot_step_local()
|
||||
if latest_ckpt and not latest_ckpt.startswith("local:"):
|
||||
latest_ckpt = ":".join(["local", latest_ckpt])
|
||||
|
||||
if latest_checkpoint is not None:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
|
||||
send_alert_message(
|
||||
address=self.feishu_address,
|
||||
message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
|
||||
)
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
send_alert_message(
|
||||
address=self.feishu_address,
|
||||
message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
|
||||
)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")
|
||||
|
||||
return latest_checkpoint
|
||||
return dict(path=latest_ckpt, content=("all",), ckpt_type="internlm")
|
||||
|
||||
def try_load_model(self, current_time=""):
|
||||
model_load_path = None
|
||||
def try_resume_training(self, train_state: TrainState, current_time=""):
|
||||
|
||||
if self.load_ckpt_folder and self.load_model_only_folder:
|
||||
raise ValueError(
|
||||
"Error, try to use both load_ckpt_folder and load_model_only_folder paths, \
|
||||
if you only need to load model weights (for example starting an SFT task for the first time), \
|
||||
set load_model_only_folder path, if you need to resume training from ckpt, \
|
||||
set load_ckpt_folder or use default value \
|
||||
(if is the default value, internlm will try to load the latest ckpt from save_ckpt_folder)"
|
||||
)
|
||||
|
||||
if self.load_ckpt_folder:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = self.load_ckpt_folder
|
||||
elif self.load_model_only_folder:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========Load Model from `{self.load_model_only_folder}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
model_load_path = self.load_model_only_folder
|
||||
else:
|
||||
if self.load_ckpt_info is None or self.load_ckpt_info["path"] is None:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(
|
||||
f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
|
||||
f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
|
||||
f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
|
||||
)
|
||||
else:
|
||||
load_path = self.load_ckpt_info["path"]
|
||||
load_content = self.load_ckpt_info["content"]
|
||||
load_type = self.load_ckpt_info["ckpt_type"]
|
||||
|
||||
# Loading model weights must be done before zero is initialized.
|
||||
if model_load_path is not None:
|
||||
load_model_checkpoint(folder=model_load_path, model=self.model)
|
||||
load_func = CheckpointLoadMethod.get_ckpt_load_type_func(load_type)
|
||||
load_content_str = load_func(self, self.load_ckpt_info, train_state)
|
||||
|
||||
def try_resume_training(self, lr_scheduler, optimizer, lr, train_state, train_dl):
|
||||
"""Attempt to restore the training state of the last ckpt.
|
||||
# If we only load model weight, we need rewrite zero optim's fp32 buffer.
|
||||
if load_content.only_load(CheckpointLoadContent.MODEL) and isinstance(self.optimizer, HybridZeroOptimizer):
|
||||
reload_zero_fp32_buff(self.optimizer)
|
||||
|
||||
Args:
|
||||
lr_scheduler (_LRScheduler): lr_scheduler object.
|
||||
optimizer (Optimizer): optimizer object.
|
||||
lr (float): learning rate.
|
||||
train_state (dict): traing states.
|
||||
train_dl (DataLoader): traning dataloader object
|
||||
"""
|
||||
if self.load_ckpt_folder is not None:
|
||||
# load optimzier states.
|
||||
if self.load_optimizer:
|
||||
load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)
|
||||
# load lr scheduler states.
|
||||
load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
|
||||
# load training states.
|
||||
load_context(self.load_ckpt_folder, train_dl, train_state)
|
||||
# load dataloader sampler states.
|
||||
if hasattr(train_state, "batch_sampler") and not isinstance(
|
||||
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
|
||||
):
|
||||
load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
|
||||
if hasattr(train_state, "data_state_dict"):
|
||||
train_dl.dataset.load_state_dict(
|
||||
llm_load(os.path.join(self.load_ckpt_folder, "sampler_0.pt")), ckpt_path=self.load_ckpt_folder
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"load_ckpt_info : {self.load_ckpt_info}")
|
||||
logger.info(
|
||||
f"===========Resume training from `{load_path}` {current_time} on host:"
|
||||
f"{socket.gethostname()}==========="
|
||||
)
|
||||
self.optimizer = optimizer
|
||||
self.lr_scheduler = lr_scheduler
|
||||
if load_content_str:
|
||||
logger.info(f"===========Load contents are: {load_content_str}")
|
||||
|
||||
@llm_timeout(func_name="save_checkpoint")
|
||||
def save_checkpoint(
|
||||
self,
|
||||
folder,
|
||||
|
|
@ -624,8 +793,10 @@ set load_ckpt_folder or use default value \
|
|||
)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
scheduler_states = scheduler.state_dict()
|
||||
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
|
||||
if scheduler:
|
||||
scheduler_states = scheduler.state_dict()
|
||||
llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)
|
||||
|
||||
if hasattr(train_state, "batch_sampler") and not isinstance(
|
||||
train_state.batch_sampler, torch.utils.data.sampler.BatchSampler
|
||||
):
|
||||
|
|
@ -655,3 +826,12 @@ set load_ckpt_folder or use default value \
|
|||
def set_save_folder(self, folder, step):
|
||||
self.storage_manager.latest_save_folder = folder
|
||||
self.storage_manager.latest_save_step = step
|
||||
|
||||
def try_ping_storage(self):
|
||||
if gpc.get_global_rank() % 8 == 0:
|
||||
buff = torch.ones((1, 64, 64), dtype=torch.bfloat16)
|
||||
test_fn = os.path.join(self.save_ckpt_folder, f"pings/{socket.gethostname()}.ping")
|
||||
self.storage_manager.save(test_fn, buff)
|
||||
self.storage_manager.wait()
|
||||
self.storage_manager.load(test_fn)
|
||||
del buff
|
||||
|
|
|
|||
|
|
@@ -46,12 +46,12 @@ def get_fns(fp: str):
    return storage_manager.get_fns(fp)


def llm_load(fp: str, *args, **kwargs):
    return storage_manager.load(fp, *args, **kwargs)
def llm_load(fp: str, **kwargs):
    return storage_manager.load(fp, **kwargs)


def llm_save(save_path: str, saved_obj: Any, *args, **kwargs):
    storage_manager.save(save_path, *args, saved_obj=saved_obj, **kwargs)
def llm_save(save_path: str, saved_obj: Any, **kwargs):
    storage_manager.save(save_path, to_save_obj=saved_obj, **kwargs)
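The unit test near the end of this diff exercises these wrappers; a trimmed-down round trip, with a made-up local path, looks roughly like this:

```python
import torch

# Assumes init_storage_manager(...) has already been called, as in
# test_storage_mm_save_load later in this diff; the path below is illustrative.
save_fn = "local:local_ckpt/test.pt"

tobj = torch.rand(64, 64)
llm_save(save_fn, tobj)                         # serialized via the storage manager
loaded = llm_load(save_fn, map_location="cpu")  # kwargs are forwarded to torch.load
assert torch.equal(loaded, tobj)
```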
class StorageClient:
|
||||
|
|
@ -63,19 +63,23 @@ class StorageClient:
|
|||
self.handler = handler
|
||||
|
||||
@staticmethod
|
||||
def load(client, load_path: str, *args, **kwargs):
|
||||
def load(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(*args, saved_obj=None, **kwargs):
|
||||
def sync_upload_fileobj(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(client):
|
||||
def async_upload_fileobj(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_fns(client):
|
||||
def assert_fp_exists(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def get_fns(*args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
|
|
@ -92,40 +96,65 @@ class Boto3MetaInfo:
|
|||
async_upload_fn: callable,
|
||||
local_nvme_path=None,
|
||||
) -> None:
|
||||
self.is_async = is_async
|
||||
# all need info.
|
||||
self.client = handler
|
||||
self.bucket_name = bucket_name
|
||||
self.endpoint = endpoint
|
||||
self.file_path = file_path
|
||||
self.async_upload_fn = async_upload_fn
|
||||
# only save need info.
|
||||
self.local_nvme_path = local_nvme_path
|
||||
self.is_async = is_async
|
||||
self.endpoint = endpoint
|
||||
self.async_upload_fn = async_upload_fn
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
|
||||
local_nvme_path: {self.local_nvme_path}"
|
||||
|
||||
@staticmethod
|
||||
def unpack_boto3_save_meta(meta):
|
||||
if meta.is_async:
|
||||
return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
|
||||
else:
|
||||
return meta.client, meta.bucket_name, meta.file_path
|
||||
|
||||
@staticmethod
|
||||
def unpack_boto3_nosave_meta(meta):
|
||||
return meta.client, meta.bucket_name, meta.file_path
|
||||
|
||||
|
||||
class LocalMetaInfo:
|
||||
"""Local meta info for save/load etc."""
|
||||
|
||||
def __init__(self, handler: StorageClient, dest_path: str) -> None:
|
||||
self.is_async = False
|
||||
self.client = handler
|
||||
self.dest_path = dest_path
|
||||
def __init__(self, file_path: str) -> None:
|
||||
self.file_path = file_path
|
||||
self.async_upload_fn = None
|
||||
self.is_async = False
|
||||
|
||||
@staticmethod
|
||||
def unpack_local_save_meta(meta):
|
||||
return (meta.file_path,)
|
||||
|
||||
@staticmethod
|
||||
def unpack_local_nosave_meta(meta):
|
||||
return (meta.file_path,)
|
||||
|
||||
|
||||
def unpack_meta(meta):
|
||||
args = []
|
||||
is_async = meta.is_async
|
||||
for k, v in meta.__dict__.items():
|
||||
if k in ("endpoint", "async_upload_fn", "is_async"):
|
||||
continue
|
||||
if not is_async and k in ("local_nvme_path",):
|
||||
continue
|
||||
args.append(v)
|
||||
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
|
||||
if isinstance(meta, Boto3MetaInfo):
|
||||
return Boto3MetaInfo.unpack_boto3_save_meta(meta)
|
||||
elif isinstance(meta, LocalMetaInfo):
|
||||
return LocalMetaInfo.unpack_local_save_meta(meta)
|
||||
else:
|
||||
raise ValueError(f"unkonwn meta info: {type(meta)}")
|
||||
|
||||
return args
|
||||
|
||||
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
|
||||
if isinstance(meta, Boto3MetaInfo):
|
||||
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
|
||||
elif isinstance(meta, LocalMetaInfo):
|
||||
return LocalMetaInfo.unpack_local_nosave_meta(meta)
|
||||
else:
|
||||
raise ValueError(f"unkonwn meta info: {type(meta)}")
|
||||
|
||||
|
||||
def compute_file_md5_by_chunk(file_name: str):
|
||||
|
|
@@ -136,6 +165,22 @@ def compute_file_md5_by_chunk(file_name: str):
    return hash_md5.hexdigest()


def try_get_storage_backend(path: str):
    sre = path.split(":", maxsplit=1)
    if len(sre) == 1:
        if path.startswith("s3:"):
            backend = "boto3"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
        else:
            backend = "local"
            if gpc.is_rank_for_log():
                logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of local.")
        return backend, sre
    else:
        return sre[0], sre[1]  # (backend_prefix, splited_path)
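A quick sketch of how the prefix guessing above behaves; the return values are inferred from the function body, and note that the second element differs in type between the two branches:

```python
# Paths with an explicit prefix are split on the first colon:
try_get_storage_backend("local:/mnt/ckpt")        # -> ("local", "/mnt/ckpt")
try_get_storage_backend("boto3:s3://bucket/run")  # -> ("boto3", "s3://bucket/run")

# Without a prefix the backend is guessed (with a warning) and the whole path
# comes back as a one-element list rather than a string:
try_get_storage_backend("s3://bucket/run")        # -> ("boto3", ["s3://bucket/run"])
try_get_storage_backend("/mnt/ckpt")              # -> ("local", ["/mnt/ckpt"])
```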
class Boto3Client(StorageClient):
|
||||
"""
|
||||
Boto3Client
|
||||
|
|
@ -189,13 +234,11 @@ class Boto3Client(StorageClient):
|
|||
)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(
|
||||
handler, bucket_name: str, fp: str, local_nvme_path: str, *args, saved_obj=None, **kwargs
|
||||
): # pylint: disable=W0613
|
||||
def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
|
||||
assert saved_obj is not None, "saved_obj is None!"
|
||||
try:
|
||||
with io.BytesIO() as f:
|
||||
torch.save(saved_obj, f, *args, **kwargs)
|
||||
torch.save(saved_obj, f, **kwargs)
|
||||
f.seek(0)
|
||||
handler.client.upload_fileobj(f, bucket_name, fp, Config=handler.config)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
|
|
@ -204,14 +247,7 @@ class Boto3Client(StorageClient):
|
|||
) from exc
|
||||
|
||||
@staticmethod
|
||||
def load(
|
||||
handler,
|
||||
bucket_name: str,
|
||||
fp: str,
|
||||
local_nvme_path: str, # pylint: disable=W0613
|
||||
*args,
|
||||
**kwargs,
|
||||
) -> Dict:
|
||||
def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
|
||||
"""
|
||||
Args:
|
||||
fp (str): Path to save, eg. s3://opennlplab/model_weights/xxx/ddd.pt
|
||||
|
|
@ -220,7 +256,7 @@ class Boto3Client(StorageClient):
|
|||
with io.BytesIO() as f:
|
||||
handler.client.download_fileobj(bucket_name, fp, f, Config=handler.config)
|
||||
f.seek(0)
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
states = torch.load(f, **kwargs)
|
||||
except handler.botocore.exceptions.EndpointConnectionError as exc:
|
||||
raise RuntimeError(
|
||||
f"Boto3 Network Error: Please Check your Internet Connection in {socket.gethostname()}"
|
||||
|
|
@ -228,24 +264,37 @@ class Boto3Client(StorageClient):
|
|||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(handler, bucket_name: str, fp: str, local_nvme_path: str): # pylint: disable=W0613
|
||||
def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
|
||||
assert len(list(handler.client.list_objects(Bucket=bucket_name, Prefix=fp)["Contents"])) > 0, fp
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, bucket_name: str, fp: str, local_nvme_path: str, *args, **kwargs): # pylint: disable=W0613
|
||||
def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
|
||||
re = handler.client.list_objects(Bucket=bucket_name, Prefix=fp)
|
||||
if "Contents" in re:
|
||||
return len(list(re["Contents"])) > 0
|
||||
else:
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, bucket_name: str, fp: str):
|
||||
"""
|
||||
Ref: https://stackoverflow.com/questions/54314563/
|
||||
how-to-get-more-than-1000-objects-from-s3-by-using-list-objects-v2
|
||||
"""
|
||||
paginator = handler.client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
|
||||
folder_name_list = []
|
||||
for page in pages:
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
pth: str = obj["Key"]
|
||||
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
||||
return list(set(folder_name_list))
|
||||
if Boto3Client.is_fp_exists(handler, bucket_name, fp):
|
||||
paginator = handler.client.get_paginator("list_objects_v2")
|
||||
pages = paginator.paginate(Bucket=bucket_name, Prefix=fp)
|
||||
folder_name_list = []
|
||||
for page in pages:
|
||||
if "Contents" in page:
|
||||
for obj in page["Contents"]:
|
||||
pth: str = obj["Key"]
|
||||
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
|
||||
return list(set(folder_name_list))
|
||||
else:
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"'{fp}' not found!")
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
|
||||
|
|
@ -273,37 +322,35 @@ class LocalClient(StorageClient):
|
|||
super().__init__(None)
|
||||
|
||||
@staticmethod
|
||||
def sync_upload_fileobj(handler, fp: str, *args, saved_obj=None, **kwargs):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def sync_upload_fileobj(fp: str, saved_obj=None, **kwargs):
|
||||
assert saved_obj is not None
|
||||
fp_dirname = os.path.dirname(fp)
|
||||
if not os.path.exists(fp_dirname):
|
||||
os.makedirs(fp_dirname, exist_ok=True)
|
||||
torch.save(saved_obj, fp, *args, **kwargs)
|
||||
torch.save(saved_obj, fp, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def load(handler, fp: str, *args, **kwargs): # pylint: disable=W0613
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(fp), f"{fp} is not found!"
|
||||
with open(fp, "rb") as f:
|
||||
states = torch.load(f, *args, **kwargs)
|
||||
def load(load_path: str, **kwargs):
|
||||
assert os.path.exists(load_path), f"{load_path} is not found!"
|
||||
with open(load_path, "rb") as f:
|
||||
states = torch.load(f, **kwargs)
|
||||
return states
|
||||
|
||||
@staticmethod
|
||||
def assert_fp_exists(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def assert_fp_exists(folder):
|
||||
assert os.path.exists(folder), folder
|
||||
|
||||
@staticmethod
|
||||
def get_fns(handler, folder):
|
||||
assert isinstance(handler, LocalClient)
|
||||
assert os.path.exists(folder), f"folder '{folder}' not exists!"
|
||||
fns = os.listdir(folder)
|
||||
return fns
|
||||
def get_fns(folder):
|
||||
if not os.path.exists(folder):
|
||||
if gpc.is_rank_for_log():
|
||||
logger.warning(f"'{folder}' not found!")
|
||||
return None
|
||||
else:
|
||||
return os.listdir(folder)
|
||||
|
||||
@staticmethod
|
||||
def delete_obj(handler, fp: str):
|
||||
assert isinstance(handler, LocalClient)
|
||||
def delete_obj(fp: str):
|
||||
if not os.path.isdir(fp):
|
||||
os.remove(fp)
|
||||
|
||||
|
|
@ -327,7 +374,10 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
|
|||
assert match is not None, f"url '{fp}' is not a valid boto3 url"
|
||||
bucket_name, endpoint = match.group(1), match.group(2)
|
||||
endpoint = "http://" + endpoint + ":80"
|
||||
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
|
||||
if is_async:
|
||||
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
|
||||
else:
|
||||
tmp_step_file = None
|
||||
return Boto3MetaInfo(
|
||||
is_async=is_async,
|
||||
handler=None,
|
||||
|
|
@ -341,7 +391,7 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
|
|||
|
||||
def get_local_meta(fp: str) -> LocalMetaInfo:
|
||||
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
|
||||
return LocalMetaInfo(None, fp)
|
||||
return LocalMetaInfo(fp)
|
||||
|
||||
|
||||
def get_mount_point_free_size(path: str):
|
||||
|
|
@ -427,7 +477,7 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
|
||||
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
|
||||
|
||||
def _get_client(self, path=str) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
|
||||
"""
|
||||
example:
|
||||
local:/path/to/checkpoint
|
||||
|
|
@ -436,17 +486,14 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
Args:
|
||||
path (str): _description_
|
||||
"""
|
||||
try:
|
||||
backend, path = path.split(":", maxsplit=1)
|
||||
except Exception as exc:
|
||||
raise AttributeError(f"Given path '{path}' is not startwith backend prefix:'local/boto3'") from exc
|
||||
backend, path = try_get_storage_backend(path)
|
||||
|
||||
init_args = (None,)
|
||||
if backend == "local":
|
||||
meta_info = get_local_meta(path)
|
||||
backend_key = backend
|
||||
elif backend == "boto3":
|
||||
meta_info = get_boto3_meta(path, self.tmp_local_folder, self.async_mode)
|
||||
meta_info = get_boto3_meta(path, self.tmp_local_folder, async_mode)
|
||||
backend_key = backend + ":" + meta_info.endpoint
|
||||
init_args = (meta_info.endpoint,)
|
||||
if (
|
||||
|
|
@ -474,17 +521,22 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
|
||||
def assert_fp_exists(self, folder) -> None:
|
||||
meta = self._get_client(path=folder)
|
||||
meta.client.assert_fp_exists(*unpack_meta(meta))
|
||||
meta.client.assert_fp_exists(*unpack_nosave_meta(meta))
|
||||
|
||||
def get_fns(self, folder) -> List[str]:
|
||||
meta = self._get_client(path=folder)
|
||||
return meta.client.get_fns(*unpack_meta(meta))
|
||||
return meta.client.get_fns(*unpack_nosave_meta(meta))
|
||||
|
||||
def save(self, save_path: str, saved_obj: Any, *args, async_upload=None, **kwargs):
|
||||
meta = self._get_client(path=save_path)
|
||||
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
|
||||
|
||||
if async_upload is None:
|
||||
async_upload = self.async_mode
|
||||
|
||||
if not save_path.startswith("boto3:"):
|
||||
async_upload = False
|
||||
|
||||
meta = self._get_client(save_path, async_upload)
|
||||
|
||||
if async_upload:
|
||||
assert (
|
||||
self.tmp_local_folder
|
||||
|
|
@ -492,22 +544,22 @@ class StorageManager(metaclass=SingletonMeta):
|
|||
tmp_step_file = meta.local_nvme_path
|
||||
self._to_be_del_files.append(tmp_step_file)
|
||||
with open(tmp_step_file, "wb") as f:
|
||||
torch.save(saved_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.async_executor(meta.async_upload_fn, *unpack_meta(meta))
|
||||
torch.save(to_save_obj, f, pickle_protocol=pickle.HIGHEST_PROTOCOL)
|
||||
self.async_executor(meta.async_upload_fn, *unpack_save_meta(meta))
|
||||
os.chmod(tmp_step_file, stat.S_IRWXU | stat.S_IRWXG | stat.S_IRWXO)
|
||||
self.async_task_peeding = True
|
||||
else:
|
||||
meta.client.sync_upload_fileobj(*unpack_meta(meta), *args, saved_obj=saved_obj, **kwargs)
|
||||
meta.client.sync_upload_fileobj(*unpack_save_meta(meta), saved_obj=to_save_obj, **kwargs)
|
||||
self.upload_count += 1
|
||||
|
||||
def load(self, load_path: str, *args, **kwargs) -> Any:
|
||||
def load(self, load_path: str, **kwargs) -> Any:
|
||||
self.wait()
|
||||
meta = self._get_client(path=load_path)
|
||||
return meta.client.load(*unpack_meta(meta), *args, **kwargs)
|
||||
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
|
||||
|
||||
def delete_obj(self, fp: str):
|
||||
meta = self._get_client(path=fp)
|
||||
meta.client.delete_obj(*unpack_meta(meta))
|
||||
meta.client.delete_obj(*unpack_nosave_meta(meta))
|
||||
|
||||
def _del_tmp_folder(self):
|
||||
for fp in self._to_be_del_files:
|
||||
|
|
@@ -594,23 +646,24 @@ class StorageManager(metaclass=SingletonMeta):

        if gpc.is_rank_for_log():
            self.upload_count += 1
            if self.async_mode:
            if self.async_mode and self.latest_save_folder:
                self.save(
                    os.path.join(self.latest_save_folder, f"{self.latest_save_step}.step"),
                    saved_obj=dict({"step": self.latest_save_step}),
                    to_save_obj=dict({"step": self.latest_save_step}),
                    async_upload=False,
                )
                self.latest_save_folder = None


storage_manager: StorageManager = None


def init_storage_manager(ckpt_config):
def init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload):
    global storage_manager
    storage_manager = StorageManager(
        ckpt_config.enable_save_ckpt,
        tmp_local_folder=ckpt_config.async_upload_tmp_folder,
        async_mode=ckpt_config.async_upload,
        enable_save_ckpt,
        tmp_local_folder=async_upload_tmp_folder,
        async_mode=async_upload,
    )
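The new signature is exercised directly in test_storage_mm_save_load further down; pulled out of that test (with `ckpt_config` assumed to be a dict-like config object), initialization now reads:

```python
from internlm.initialize.launch import get_config_value

# Read the three fields from the checkpoint config, then initialize the
# process-wide storage manager singleton.
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
async_upload = get_config_value(ckpt_config, "async_upload", False)

init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
```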
@@ -1,4 +1,13 @@
import datetime
import os
import signal
import socket
import traceback
from functools import wraps

from internlm.utils.logger import get_logger

logger = get_logger(__file__)


class Timeout:

@@ -24,3 +33,81 @@ class Timeout:

    def __exit__(self, error_type, value, traceback):
        signal.alarm(0)


ENABLE_TIMEOUT = os.getenv("INTERNLM_ENABLE_TIMEOUT", None)


timeout_threshold_dict = {
    "initialize_distributed_env": 120,
    "nopp_forward_backward_step": 360,
    "initialize_model": 10,
    "initialize_optimizer": 20,
    "optim_step": 30,
    "get_train_data_loader": 600,
    "get_validation_data_loader": 60,
    "load_new_batch": 10,
    "record_current_batch_training_metrics": 10,
    "save_checkpoint": 1200,
    "interleaved_forward_backward_step": 600,
    "nointerleaved_forward_backward_step": 600,
}

if ENABLE_TIMEOUT is not None:
    os.environ["NCCL_ASYNC_ERROR_HANDLING"] = "1"
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=int(os.getenv("NCCL_TIMEOUT", str(60))))
else:
    timeout_threshold_dict = dict.fromkeys(timeout_threshold_dict.keys(), 0)
    LLM_NCCL_TIMEOUT = datetime.timedelta(seconds=1800)


def try_get_gpc_rank():
    try:
        from internlm.core.context import global_context as gpc

        rank = gpc.get_global_rank()
    except:  # noqa # pylint: disable=bare-except
        rank = "unknown"

    return f"host-{socket.gethostname()}-rank-{rank}"


def llm_timeout(seconds=0, func_name=None):
    """Timeout decorator. Note that this decorator cannot be reentrant,
    otherwise the signal will be reset.

    Args:
        seconds (int, optional): timeout threshold. Defaults to 300.
        func_name (str, optional): the name of the function being watched for timeout.
    """

    def decorator(func):
        nonlocal func_name
        if func_name is None:
            func_name = func.__name__

        @wraps(func)
        def wrapper(*args, **kwargs):
            def _handle_timeout(signum, frame):
                raise TimeoutError

            nonlocal seconds
            seconds = timeout_threshold_dict.get(func_name, seconds)

            if seconds > 0:
                signal.signal(signal.SIGALRM, _handle_timeout)
                signal.alarm(seconds)

            try:
                result = func(*args, **kwargs)
            except TimeoutError as e:
                logger.error(f"TimeoutError at {try_get_gpc_rank()}: {func_name}\\n {traceback.format_exc()}")
                raise e
            finally:
                signal.alarm(0)

            return result

        return wrapper

    return decorator
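The decorator is used in two ways in this diff: with an explicit threshold, as in the timeout tests below, and with only a `func_name` so the threshold comes from `timeout_threshold_dict`, as on `save_checkpoint` above. A small sketch (the second function name is hypothetical):

```python
import time

# Explicit 2-second limit, as in the fake_timeout_func test case below. The
# explicit value is used because "fake_timeout_func" is not in timeout_threshold_dict.
@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
    time.sleep(10)  # raises TimeoutError after ~2s

# Threshold looked up from timeout_threshold_dict ("save_checkpoint": 1200) when
# INTERNLM_ENABLE_TIMEOUT is set, otherwise 0 (disabled); mirrors the decorator
# on CheckpointManager.save_checkpoint above.
@llm_timeout(func_name="save_checkpoint")
def my_save_step():
    ...
```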
@ -0,0 +1,181 @@
|
|||
import os
|
||||
import shutil
|
||||
from subprocess import PIPE, STDOUT, Popen
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.utils.common import SingletonMeta
|
||||
|
||||
OSS_NAME = os.environ["OSS_BUCKET_NAME"]
|
||||
OSS_IP = os.environ["OSS_IP"]
|
||||
USER = os.environ["USER"]
|
||||
JOB_NAME = "CI_TEST"
|
||||
LOCAL_SAVE_PATH = "local:local_ckpt"
|
||||
|
||||
BOTO_SAVE_PATH = f"boto3:s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}"
|
||||
BOTO_SAVE_PATH_NO_PRFIX = f"s3://{OSS_NAME}.{OSS_IP}/{USER}/{JOB_NAME}/"
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
|
||||
|
||||
# 1B
|
||||
init_config = Config(
|
||||
dict(
|
||||
parallel=dict(zero1=1, pipeline=dict(size=1, interleaved_overlap=False), sequence_parallel=False, tensor=1),
|
||||
model_type="INTERNLM",
|
||||
adam=dict(
|
||||
lr=1e-4,
|
||||
),
|
||||
data=dict(seq_len=2048, micro_num=1, micro_bsz=1, pack_sample_into_one=False, min_length=0, total_steps=9999),
|
||||
model=dict(
|
||||
checkpoint=False,
|
||||
num_attention_heads=2,
|
||||
embed_split_hidden=True,
|
||||
vocab_size=103168,
|
||||
embed_grad_scale=1,
|
||||
parallel_output=True,
|
||||
hidden_size=1024,
|
||||
num_layers=2,
|
||||
mlp_ratio=1,
|
||||
apply_post_layer_norm=False,
|
||||
dtype=torch.bfloat16,
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
num_chunks=1,
|
||||
),
|
||||
resume_tb_folder="",
|
||||
tensorboard_folder="",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def init_naive_model():
|
||||
# let MODEL_INITIALIZER to work
|
||||
import internlm.model.modeling_internlm # noqa # pylint: disable=unused-import
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.registry import MODEL_INITIALIZER
|
||||
|
||||
model = MODEL_INITIALIZER.get_module(module_name=gpc.config.model_type)(**(init_config.model))
|
||||
model = NaiveAMPModel(
|
||||
model=model,
|
||||
output_to_fp32=False,
|
||||
dtype=torch.bfloat16,
|
||||
sync_buffer=False,
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
def init_naive_optim(model):
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": 0.01}],
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.95),
|
||||
eps=1e-8,
|
||||
)
|
||||
return naive_optimizer
|
||||
|
||||
|
||||
def init_hybrid_optim(model):
|
||||
naive_optimizer = torch.optim.AdamW(
|
||||
params=[{"params": model.parameters(), "weight_decay": 0.01}],
|
||||
lr=1e-4,
|
||||
betas=(0.9, 0.95),
|
||||
eps=1e-8,
|
||||
)
|
||||
optimizer = HybridZeroOptimizer(
|
||||
naive_optimizer,
|
||||
grad_scal_cfg=Config(
|
||||
dict(
|
||||
fp16=dict(
|
||||
initial_scale=2**16,
|
||||
min_scale=1,
|
||||
growth_interval=1000,
|
||||
),
|
||||
growth_factor=2,
|
||||
backoff_factor=0.5,
|
||||
max_scale=2**24,
|
||||
hysteresis=2,
|
||||
)
|
||||
),
|
||||
zero_cfg=Config(
|
||||
dict(
|
||||
overlap_sync_grad=False,
|
||||
overlap_sync_param=False,
|
||||
reduce_bucket_size=512 * 1024 * 1024,
|
||||
clip_grad_norm=1.0,
|
||||
)
|
||||
),
|
||||
param_bcast_sync_handler=None,
|
||||
)
|
||||
return optimizer
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True, scope="function")
|
||||
def reset_singletons():
|
||||
SingletonMeta._instances = {}
|
||||
|
||||
|
||||
def reset_seed():
|
||||
from internlm.core.context.random import _SEED_MANAGER
|
||||
|
||||
_SEED_MANAGER.reset()
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def init_dist_and_model(rank=0, world_size=1):
|
||||
from internlm.initialize import initialize_distributed_env
|
||||
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "12377"
|
||||
initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)
|
||||
|
||||
# setup
|
||||
print("set up", flush=True)
|
||||
model = init_naive_model()
|
||||
# opim = init_naive_optim(model)
|
||||
opim = init_hybrid_optim(model)
|
||||
|
||||
yield model, opim
|
||||
|
||||
# teardown
|
||||
del model, opim
|
||||
print("teardown", flush=True)
|
||||
gpc.destroy()
|
||||
reset_seed()
|
||||
|
||||
|
||||
def enter_flag(text):
|
||||
print(f"{text} begin!", flush=True)
|
||||
yield
|
||||
print(f"{text} end!", flush=True)
|
||||
|
||||
|
||||
def del_tmp_file():
|
||||
try:
|
||||
shutil.rmtree(ASYNC_TMP_FOLDER, ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
shutil.rmtree(LOCAL_SAVE_PATH.split(":")[1], ignore_errors=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
try:
|
||||
cmd = r"/mnt/petrelfs/share/sensesync --dryrun --deleteSrc cp " + BOTO_SAVE_PATH_NO_PRFIX + " / "
|
||||
with Popen(cmd, stdout=PIPE, stderr=STDOUT, shell=True) as output:
|
||||
results, presults = "", ""
|
||||
for line in iter(output.stdout.readline, b""):
|
||||
results += str(line.rstrip())
|
||||
presults += line.rstrip().decode() + "\n"
|
||||
print(presults, flush=True)
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
|
@ -0,0 +1,247 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import HybridZeroOptimizer
|
||||
from internlm.utils.common import SingletonMeta
|
||||
from internlm.utils.model_checkpoint import CheckpointManager
|
||||
from internlm.utils.storage_manager import wait_async_upload_finish
|
||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||
ASYNC_TMP_FOLDER,
|
||||
BOTO_SAVE_PATH,
|
||||
LOCAL_SAVE_PATH,
|
||||
del_tmp_file,
|
||||
init_dist_and_model,
|
||||
reset_singletons,
|
||||
)
|
||||
|
||||
TOTAL_STEP = 6
|
||||
|
||||
CKPT_EVERY = 4
|
||||
SNPASHOT_EVERY = 2
|
||||
|
||||
|
||||
ckpt_config_list = [
|
||||
# Old interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||
load_optimizer=True,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([BOTO_SAVE_PATH, "snapshot"]),
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_model_only_folder=None,
|
||||
load_given_ckpt=False,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=True,
|
||||
),
|
||||
# Old interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||
load_optimizer=True,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=False,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
snapshot_ckpt_folder="/".join([LOCAL_SAVE_PATH, "snapshot"]),
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_model_only_folder=None,
|
||||
load_given_ckpt=False,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=True,
|
||||
),
|
||||
# New interface format
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=BOTO_SAVE_PATH,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
is_old_api=False,
|
||||
auto_resume=True,
|
||||
),
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
save_ckpt_folder=LOCAL_SAVE_PATH,
|
||||
checkpoint_every=CKPT_EVERY,
|
||||
async_upload=False,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
oss_snapshot_freq=SNPASHOT_EVERY,
|
||||
stop_file_path=None,
|
||||
load_ckpt_folder=None,
|
||||
is_old_api=False,
|
||||
auto_resume=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def overwrite_optim_state(optim, set_value):
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||
p.data.fill_(set_value)
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
)
|
||||
fp16_p.fill_(set_value)
|
||||
else:
|
||||
for group in optim.param_groups:
|
||||
for p in group["params"]:
|
||||
# p.copy_(torch.full_like(p, set_value, dtype=p.dtype))
|
||||
p.data.fill_(set_value)
|
||||
|
||||
|
||||
def compare_optim_state(optim1, optim2):
|
||||
re = True
|
||||
if isinstance(optim1, HybridZeroOptimizer):
|
||||
fp32_buff1 = optim1._fp32_flat_param_groups_of_current_rank
|
||||
fp32_buff2 = optim2._fp32_flat_param_groups_of_current_rank
|
||||
for group_id_1, group_id_2 in zip(fp32_buff1, fp32_buff2):
|
||||
re &= group_id_1 == group_id_2
|
||||
if optim1.zero_local_rank not in optim1.param_group_no_params_ranks[group_id_1]:
|
||||
re &= torch.equal(fp32_buff1[group_id_1], fp32_buff1[group_id_2])
|
||||
else:
|
||||
for group1, group2 in zip(optim1.param_groups, optim2.param_groups):
|
||||
for p1, p2 in zip(group1["params"], group2["params"]):
|
||||
re &= torch.equal(p1, p2)
|
||||
return re
|
||||
|
||||
|
||||
def compare_optim_value(optim, value):
|
||||
re = True
|
||||
if isinstance(optim, HybridZeroOptimizer):
|
||||
for group_id, p in optim._fp32_flat_param_groups_of_current_rank.items():
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
for group_id in range(len(optim._fp16_param_groups)):
|
||||
if optim._zero_local_rank not in optim.param_group_no_params_ranks[group_id]:
|
||||
fp16_p = optim._param_store.get_flat_fp16_param_by_rank_group(
|
||||
rank=optim._zero_local_rank, group_id=group_id
|
||||
)
|
||||
re &= torch.equal(fp16_p, torch.full_like(fp16_p, value, dtype=fp16_p.dtype))
|
||||
else:
|
||||
for group in optim.param_groups:
|
||||
for p in group["params"]:
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
return re
|
||||
|
||||
|
||||
def overwrite_model_value(model, value):
|
||||
for p in model.parameters():
|
||||
# p.copy_(torch.full_like(p, value, dtype=p.dtype))
|
||||
p.data.fill_(value)
|
||||
|
||||
|
||||
def compare_model_value(model, value):
|
||||
re = True
|
||||
for p in model.parameters():
|
||||
re &= torch.equal(p, torch.full_like(p, value, dtype=p.dtype))
|
||||
return re
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def del_tmp():
|
||||
del_tmp_file()
|
||||
yield
|
||||
del_tmp_file()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_ckpt_mm(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
||||
from internlm.utils.model_checkpoint import CheckpointLoadMask, CheckpointLoadType
|
||||
|
||||
ckpt_config = Config(ckpt_config)
|
||||
assert ckpt_config.checkpoint_every < TOTAL_STEP
|
||||
assert ckpt_config.oss_snapshot_freq < TOTAL_STEP
|
||||
|
||||
model, opim = init_dist_and_model
|
||||
train_state = TrainState(gpc.config, None)
|
||||
if isinstance(opim, HybridZeroOptimizer):
|
||||
print("Is HybridZeroOptimizer!", flush=True)
|
||||
else:
|
||||
print("Is naive Adam!", flush=True)
|
||||
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
latest_ckpt_step = None
|
||||
for i in range(TOTAL_STEP + 1):
|
||||
overwrite_model_value(model, i)
|
||||
overwrite_optim_state(opim, i)
|
||||
|
||||
train_state.batch_count = i
|
||||
train_state.step_count += 1
|
||||
|
||||
save_ckpts, _, _ = ckpt_mm.is_now_to_save_ckpt(train_state)
|
||||
if save_ckpts:
|
||||
latest_ckpt_step = i
|
||||
|
||||
ckpt_mm.try_save_checkpoint(train_state)
|
||||
|
||||
wait_async_upload_finish()
|
||||
latest_ckpt_info = ckpt_mm.query_lastest_ckpt()
|
||||
assert latest_ckpt_info is not None
|
||||
latest_ckpt = latest_ckpt_info["path"]
|
||||
if ckpt_mm.save_ckpt_folder.startswith("local"):
|
||||
assert latest_ckpt == "local:local_ckpt/snapshot/0", latest_ckpt
|
||||
else:
|
||||
assert latest_ckpt == f"{BOTO_SAVE_PATH}/snapshot/0", latest_ckpt
|
||||
|
||||
del ckpt_mm
|
||||
SingletonMeta._instances = {}
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
ckpt_mm.try_resume_training(train_state)
|
||||
assert latest_ckpt_step == 5
|
||||
assert train_state.step_count == 6
|
||||
assert train_state.batch_count == 6
|
||||
assert compare_optim_value(ckpt_mm.optimizer, latest_ckpt_step), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
||||
assert compare_model_value(ckpt_mm.model, latest_ckpt_step), list(ckpt_mm.model.parameters())[0][0]
|
||||
|
||||
if ckpt_mm.save_ckpt_folder.startswith("local:"):
|
||||
ckpt_mm.load_ckpt_info = dict(
|
||||
path=os.path.join(LOCAL_SAVE_PATH, "4"),
|
||||
content=CheckpointLoadMask(("all",)),
|
||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||
)
|
||||
else:
|
||||
ckpt_mm.load_ckpt_info = dict(
|
||||
path=os.path.join(BOTO_SAVE_PATH, "4"),
|
||||
content=CheckpointLoadMask(("all",)),
|
||||
ckpt_type=CheckpointLoadType.INTERNLM,
|
||||
)
|
||||
|
||||
ckpt_mm.try_resume_training(train_state)
|
||||
|
||||
assert train_state.step_count == 4
|
||||
assert train_state.batch_count == 4
|
||||
assert compare_optim_value(ckpt_mm.optimizer, 3), ckpt_mm.optimizer.param_groups[0]["params"][0]
|
||||
assert compare_model_value(ckpt_mm.model, 3), list(ckpt_mm.model.parameters())[0][0]
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_ckpt_mm_ping(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-import
|
||||
ckpt_config = Config(ckpt_config)
|
||||
|
||||
model, opim = init_dist_and_model
|
||||
SingletonMeta._instances = {}
|
||||
ckpt_mm = CheckpointManager(ckpt_config, model=model, optimizer=opim)
|
||||
ckpt_mm.try_ping_storage()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main()
|
||||
|
|
@ -0,0 +1,89 @@
|
|||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.initialize.launch import get_config_value
|
||||
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
|
||||
ASYNC_TMP_FOLDER,
|
||||
BOTO_SAVE_PATH,
|
||||
LOCAL_SAVE_PATH,
|
||||
del_tmp_file,
|
||||
init_dist_and_model,
|
||||
reset_singletons,
|
||||
)
|
||||
|
||||
ASYNC_TMP_FOLDER = "./async_tmp_folder"
|
||||
ckpt_config_list = [
|
||||
# async boto
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
async_upload=True,
|
||||
save_folder=BOTO_SAVE_PATH,
|
||||
test_id=0,
|
||||
),
|
||||
# sync local
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=None,
|
||||
async_upload=False,
|
||||
save_folder=LOCAL_SAVE_PATH,
|
||||
test_id=1,
|
||||
),
|
||||
# sync boto
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=None,
|
||||
async_upload=False,
|
||||
save_folder=BOTO_SAVE_PATH,
|
||||
test_id=2,
|
||||
),
|
||||
# async local
|
||||
dict(
|
||||
enable_save_ckpt=True,
|
||||
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
|
||||
async_upload=True,
|
||||
save_folder=LOCAL_SAVE_PATH,
|
||||
test_id=3,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def del_tmp():
|
||||
del_tmp_file()
|
||||
yield
|
||||
del_tmp_file()
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("del_tmp")
|
||||
@pytest.mark.usefixtures("reset_singletons")
|
||||
@pytest.mark.parametrize("ckpt_config", ckpt_config_list)
|
||||
def test_storage_mm_save_load(ckpt_config, init_dist_and_model): # noqa # pylint: disable=unused-argument
|
||||
from internlm.utils.storage_manager import (
|
||||
check_folder,
|
||||
get_fns,
|
||||
init_storage_manager,
|
||||
llm_load,
|
||||
llm_save,
|
||||
wait_async_upload_finish,
|
||||
)
|
||||
|
||||
ckpt_config = Config(ckpt_config)
|
||||
enable_save_ckpt = get_config_value(ckpt_config, "enable_save_ckpt", False)
|
||||
async_upload_tmp_folder = get_config_value(ckpt_config, "async_upload_tmp_folder", False)
|
||||
async_upload = get_config_value(ckpt_config, "async_upload", False)
|
||||
|
||||
init_storage_manager(enable_save_ckpt, async_upload_tmp_folder, async_upload)
|
||||
|
||||
tobj = torch.rand(64, 64)
|
||||
save_fn = os.path.join(ckpt_config.save_folder, "test.pt")
|
||||
llm_save(save_fn, tobj)
|
||||
if ckpt_config.test_id == 0:
|
||||
wait_async_upload_finish()
|
||||
check_folder(save_fn)
|
||||
assert get_fns(ckpt_config.save_folder)[0] == "test.pt"
|
||||
load_obj = llm_load(save_fn, map_location="cpu")
|
||||
assert 0 == ((load_obj != tobj).sum())
|
||||
|
|
@@ -0,0 +1,119 @@
import fcntl
import os
import time
from multiprocessing import Process

import pytest
import torch
import torch.distributed as dist

os.environ["INTERNLM_ENABLE_TIMEOUT"] = "1"  # noqa # pylint: disable=wrong-import-position
os.environ["NCCL_TIMEOUT"] = "5"
from internlm.utils.timeout import llm_timeout
from tests.test_utils.common_fixture import (  # noqa # pylint: disable=unused-import
    init_config,
)

WORLD_SIZE = 2


@llm_timeout(2, "fake_timeout_func")
def fake_timeout_func():
    time.sleep(10)


@llm_timeout(10, "nccl_timeout_func")
def nccl_timeout_func(rank):
    # see: https://github.com/pytorch/pytorch/issues/104506#issuecomment-1679762880
    # 'NCCL_ASYNC_ERROR_HANDLING' cannot take effect on the first collective communication.
    buff = torch.ones([64, 64]).cuda(rank)
    dist.all_reduce(buff)  # lazy communicator init
    torch.cuda.synchronize()
    if rank == 0:
        dist.all_reduce(buff)
        torch.cuda.synchronize()  # main thread will hang here.
    else:
        time.sleep(9999)


@llm_timeout(10, "try_file_lock")
def try_file_lock(rank, stop_file_path):
    if rank == 1:
        time.sleep(5)

    with open(stop_file_path, "r", encoding="utf-8") as f:
        fcntl.flock(f, fcntl.LOCK_EX)  # rank 1 hangs here.
        if rank == 0:
            time.sleep(99999)  # rank 0 hangs here.
        f.seek(0)
        f.read()
        fcntl.flock(f, fcntl.LOCK_UN)


def local_timeout(rank, _):

    try:
        fake_timeout_func()
    except TimeoutError as e:
        print(f"local_timeout, rank:{rank}, e:{e}", flush=True)
    else:
        assert False, "It should timeout!"


def gpc_timeout(rank, world_size):

    from internlm.initialize import initialize_distributed_env

    os.environ["RANK"] = str(rank)
    os.environ["LOCAL_RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)
    os.environ["MASTER_ADDR"] = "127.0.0.1"
    os.environ["MASTER_PORT"] = "12377"
    initialize_distributed_env(config=init_config, launcher="torch", master_port=12377, args_check=False)

    try:
        nccl_timeout_func(rank)
    except TimeoutError as e:
        print(f"gpc_timeout, rank:{rank}, e:{e}", flush=True)
        time.sleep(5)  # wait for rank 0 to be killed
    else:
        time.sleep(5)  # give the watchdog some time to kill rank 0.
        assert False, "It should timeout!"


def file_lock_timeout(rank, _, stop_file_path):
    if rank == 0:
        with open(stop_file_path, "w"):
            pass
    try:
        try_file_lock(rank, stop_file_path)
    except TimeoutError as e:
        print(e, flush=True)
    else:
        assert False, "It should timeout!"
    finally:
        if rank == 0:
            os.remove(stop_file_path)


timeout_func_list = [(gpc_timeout, 2, None), (local_timeout, 1, None), (file_lock_timeout, 2, "test_lock.log")]


@pytest.mark.parametrize("timeout_func_and_args", timeout_func_list)
def test_timeout(timeout_func_and_args):
    timeout_func, world_size, other_args = timeout_func_and_args
    procs = []
    for i in range(world_size):
        if other_args is None:
            args = (i, world_size)
        else:
            args = (i, world_size, other_args)
        proc = Process(target=timeout_func, args=args)
        proc.start()
        procs.append(proc)

    for proc in procs:
        proc.join(15)
        if proc.is_alive():
            proc.terminate()
            proc.join()
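For context, llm_timeout is the decorator exercised throughout this test file: it is expected to raise TimeoutError when the wrapped call exceeds the given number of seconds, and it is enabled via the INTERNLM_ENABLE_TIMEOUT environment variable set at the top of the file. The sketch below shows one common way to build such a guard with SIGALRM on Unix; it is an illustrative approximation, not the actual implementation in internlm/utils/timeout.py.

import functools
import signal
import time


# Illustrative SIGALRM-based timeout decorator (an assumption, not InternLM's real llm_timeout).
def llm_timeout_sketch(seconds, func_name):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            def _handler(signum, frame):
                raise TimeoutError(f"{func_name} timed out after {seconds}s")

            old_handler = signal.signal(signal.SIGALRM, _handler)
            signal.alarm(seconds)  # arm the alarm
            try:
                return func(*args, **kwargs)
            finally:
                signal.alarm(0)  # disarm the alarm
                signal.signal(signal.SIGALRM, old_handler)

        return wrapper

    return decorator


@llm_timeout_sketch(1, "slow_func")
def slow_func():
    time.sleep(5)  # calling slow_func() raises TimeoutError after roughly 1 second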
train.py (42 changes)
@@ -36,7 +36,6 @@ from internlm.utils.common import (
    parse_args,
)
from internlm.utils.evaluation import evaluate_on_val_dls
from internlm.utils.gputest import bench_gpu, bench_net
from internlm.utils.logger import get_logger, initialize_uniscale_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.model_checkpoint import CheckpointManager
@@ -74,7 +73,6 @@ def main(args):
    total_steps = gpc.config.data.total_steps
    valid_every = gpc.config.data.valid_every
    label_smoothing = gpc.config.loss.label_smoothing
    lr = gpc.config.adam.lr

    get_tflops_func = partial(
        get_megatron_flops,
@@ -97,21 +95,11 @@ def main(args):
    # initialize custom llm logger
    uniscale_logger = initialize_llm_logger(start_time=current_time)

    # initialize and resume train state
    train_state = TrainState(gpc.config)

    # initialize model
    model = initialize_model()

    with open(args.config, "r") as f:
        config_lines = f.readlines()
    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.alert_address,
    )

    # initialize loss function
    criterion = FlashGPTLMLoss(parallel_output=True, label_smoothing=label_smoothing)
@@ -119,18 +107,28 @@ def main(args):
    # initialize the train and validation data loader
    train_dl, dataset_types = get_train_data_loader(num_worker=4)
    val_dls = get_validation_data_loader()
    train_state.init_batch_sampler(train_dl)

    # Loading model weights must be done before zero is initialized.
    ckpt_manager.try_load_model(current_time)
    # initialize and resume train state
    train_state = TrainState(gpc.config, train_dl.batch_sampler)

    # if fsdp is enabled, wrap the model
    model = warp_FSDP_model(model)

    optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=model)

    ckpt_manager = CheckpointManager(
        ckpt_config=gpc.config.ckpt,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dl=train_dl,
        model_config=gpc.config.model,
        model_config_file="".join(config_lines),
        feishu_address=gpc.config.monitor.alert.feishu_alert_address,
    )

    # Loading other persistent training states.
    ckpt_manager.try_resume_training(lr_scheduler, optimizer, lr, train_state, train_dl)
    ckpt_manager.try_resume_training(train_state, current_time)

    # initialize custom llm writer
    writer = Writer(
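The warp_FSDP_model(model) call above is where the model gets sharded when parallel.use_fsdp is enabled in the config. As a hedged sketch only (the helper below is illustrative and not the repository's actual wrapper), wrapping a module with PyTorch's FSDP typically looks like this:

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def wrap_fsdp_sketch(model: nn.Module, use_fsdp: bool) -> nn.Module:
    # Illustrative only; assumes torch.distributed is already initialized,
    # as initialize_distributed_env does earlier in main().
    if not use_fsdp:
        return model
    # FSDP shards parameters, gradients, and optimizer state across data-parallel ranks.
    return FSDP(model)

Note that the hunk above also moves CheckpointManager construction after optimizer creation, so the manager can be handed the optimizer and lr_scheduler that were built from the FSDP-wrapped model.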
@@ -201,8 +199,6 @@ def main(args):
    for batch_count in range(train_state.batch_count, total_steps):
        if batch_count % 50 == 0:
            torch.cuda.empty_cache()
            bench_gpu()
            bench_net()

        start_time = time.time()
        timer("one-batch").start()
@@ -245,7 +241,7 @@ def main(args):
        if -1 in grad_norm_groups.values() and gpc.is_rank_for_log():  # -1 encodes a specific failure case
            logger.warning(f"Warning: skip parameter update at step {batch_count}.")
            send_alert_message(
                address=gpc.config.alert_address,
                address=gpc.config.monitor.alert.feishu_alert_address,
                message=f"Warning: skip parameter update at step {batch_count}.",
            )
@@ -305,11 +301,15 @@ if __name__ == "__main__":
    assert hasattr(gpc, "config") and gpc.config is not None

    # initialize monitor manager context
    with initialize_monitor_manager(job_name=gpc.config.JOB_NAME, alert_address=gpc.config.alert_address):
    with initialize_monitor_manager(
        job_name=gpc.config.JOB_NAME, alert_address=gpc.config.monitor.alert.feishu_alert_address
    ):
        try:
            main(args)
        except Exception:
            logger.error(
                f"Raise exception from {hostname} with rank id: {gpc.get_global_rank()}\n{traceback.format_exc()}",
            )
            mm.monitor_exception(alert_address=gpc.config.alert_address, excp_info=traceback.format_exc())
            mm.monitor_exception(
                alert_address=gpc.config.monitor.alert.feishu_alert_address, excp_info=traceback.format_exc()
            )
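The last two hunks switch alerting from the old flat gpc.config.alert_address to the nested monitor.alert block introduced by this commit. The sketch below uses a hypothetical AttrDict stand-in rather than InternLM's actual Config class to show how such nested attribute access maps onto the config dict:

# Minimal sketch of attribute-style access over nested dicts; a hypothetical
# AttrDict stand-in, not InternLM's actual Config class.
class AttrDict(dict):
    def __getattr__(self, name):
        value = self[name]
        return AttrDict(value) if isinstance(value, dict) else value


config = AttrDict({"monitor": {"alert": {"feishu_alert_address": None}}})

# The hunks above replace gpc.config.alert_address with this nested lookup:
assert config.monitor.alert.feishu_alert_address is None

This mirrors the lookups used in the send_alert_message, initialize_monitor_manager, and monitor_exception calls above.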