#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
import fcntl
import os
import socket
import time
from enum import Enum
from typing import Dict

import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.trainer import TrainState
from internlm.monitor import send_alert_message
from internlm.solver.optimizer import HybridZeroOptimizer
from internlm.utils.common import get_current_device
from internlm.utils.logger import get_logger
from internlm.utils.megatron_timers import megatron_timer as timer
from internlm.utils.storage_manager import (
    get_fns,
    get_storage_manager,
    llm_load,
    llm_save,
)

logger = get_logger(__file__)


class CheckpointType(Enum):
    NORMAL_CHECKPOINT = 1
    SNAPSHOT_CHECKPOINT = 2


def get_model_topology(model):
    """
    Returns:
        {
            '{name}': {'dim': int}
        }
        where name is the name of the module, and all parameters under this module are
        concatenated along the dimension 'dim'.
    """

    from flash_attn.modules.embedding import VocabParallelEmbedding

    topos = {}
    for name, module in model.named_modules():
        # Modules that do not meet this condition are shared across the various tp/dp ranks, and it is
        # necessary to assert that their parameters are consistent.
        if isinstance(module, VocabParallelEmbedding):
            topos[name] = {"dim": 0}
    return topos
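
# Illustrative note (not part of the original code; the module name is hypothetical): for a model whose
# vocab embedding is a flash_attn VocabParallelEmbedding registered as "embedding", the call
#     get_model_topology(model)
# would be expected to return {"embedding": {"dim": 0}}, i.e. that module's tensor-parallel shards are
# concatenated along dim 0 when merged.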


def save_model_checkpoint(folder, model):
    """
    Save the model according to the relationship between tp and dp. Each tensor-parallel shard is saved
    separately rather than being gathered first, which amounts to actual sharding on disk. The saved
    weights are laid out as:
    - folder
        - model_tp{tp_rank}_pp{pp_rank}.pt

    If the tp setting used later is inconsistent with the one used when saving, the weights need to be
    converted before loading.

    Args:
        folder: The folder to save the model.
        model: The model to be saved.
    """

    states = model.state_dict()
    topo = get_model_topology(model)

    if folder is not None:
        dp_size = gpc.get_world_size(ParallelMode.DATA)
        tp_size = gpc.get_world_size(ParallelMode.TENSOR)
        dp_rank = gpc.get_local_rank(ParallelMode.DATA)
        tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
        pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

        # TODO: in theory the pp level should also be considered, but pp ranks generally live on
        # different machines, so they will not write to the same machine even if pp is ignored here.
        should_save_rank_pair = set()  # (tp_rank, dp_rank)
        for i in range(tp_size):
            should_save_rank_pair.add((i, i % dp_size))

        if (tp_rank, dp_rank) in should_save_rank_pair:
            fn = f"model_tp{tp_rank}_pp{pp_rank}.pt"
            fp = os.path.join(folder, fn)
            llm_save(fp, saved_obj=states)
            topo_fn = f"topo_tp{tp_rank}_pp{pp_rank}.json"
            topo_fp = os.path.join(folder, topo_fn)
            llm_save(topo_fp, saved_obj=topo)

    torch.distributed.barrier()
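
# A minimal, self-contained sketch (not used by the module; the sizes are hypothetical) of the
# de-duplication rule applied in save_model_checkpoint: each tp shard is replicated across dp ranks,
# so only one (tp_rank, dp_rank) pair per tp shard writes model_tp{tp}_pp{pp}.pt.
def _example_should_save_rank_pairs(tp_size=4, dp_size=2):
    # Mirrors the set construction above: tp rank i is written by dp rank i % dp_size.
    pairs = {(i, i % dp_size) for i in range(tp_size)}
    return sorted(pairs)  # e.g. [(0, 0), (1, 1), (2, 0), (3, 1)]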


def load_model_checkpoint(folder, model):
    """
    There should be weights with names similar to the following under the folder:
    - folder
        - model_tp{tp_rank}_pp{pp_rank}.pt

    If the tp setting used now is inconsistent with the one used when saving, the weights need to be
    converted before loading.
    """

    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)

    fns = get_fns(folder)
    max_pp, max_tp = 0, 0
    for fn in fns:
        if fn.startswith("model_t") and not fn.endswith(".md5"):
            segments = os.path.splitext(fn)[0].split("_")
            max_pp = max(max_pp, int(segments[-1][2:]))
            max_tp = max(max_tp, int(segments[-2][2:]))

    assert (
        pp_size == max_pp + 1
    ), f"The weights were saved for {max_pp+1} pipeline stages, while the current run has {pp_size}"
    assert (
        tp_size == max_tp + 1
    ), f"The weights were saved for tensor parallel size {max_tp+1}, while the current run has {tp_size}"

    should_load_name = f"model_tp{tp_rank}_pp{pp_rank}.pt"
    fp = os.path.join(folder, should_load_name)
    states = llm_load(fp, map_location=get_current_device())

    missing_k, unexpected_keys = model.load_state_dict(states, strict=False)
    if len(missing_k) != 0:
        logger.warning(f"Warning: missing keys {missing_k}")
    if len(unexpected_keys) != 0:
        logger.warning(f"Warning: unexpected keys {unexpected_keys}")

    # Avoid CUDA OOM. Ref: https://discuss.pytorch.org/t/load-state-dict-causes-memory-leak/36189/11
    del states
    torch.cuda.empty_cache()
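
# Illustrative only (not called anywhere): how load_model_checkpoint recovers the tp/pp layout from a
# shard filename such as "model_tp1_pp0.pt" (the filename here is a made-up example).
def _example_parse_model_shard_name(fn="model_tp1_pp0.pt"):
    stem = os.path.splitext(fn)[0]  # "model_tp1_pp0"
    segments = stem.split("_")      # ["model", "tp1", "pp0"]
    pp = int(segments[-1][2:])      # 0
    tp = int(segments[-2][2:])      # 1
    return tp, pp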


def save_optimizer_checkpoint(optim, state_path):
    """Store the state of the optimizer to the local file system or remote OSS.

    Args:
        optim (Optimizer): the optimizer whose state will be saved.
        state_path (str): the path where the optimizer state is stored.
    """

    # TODO: sanity check for optimizer type
    zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
    tp_rank = gpc.get_local_rank(ParallelMode.TENSOR)
    pp_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)
    fp = f"optimizer_tp{tp_rank}_pp{pp_rank}_zo{zero_rank}.pt"

    states = optim.state_dict()
    if isinstance(optim, HybridZeroOptimizer):
        if gpc.get_global_rank() < optim.zero_world_size * tp_size * pp_size:
            llm_save(os.path.join(state_path, fp), states)
            if "zero_devide_optim_plan" in states:
                params_per_rank_id_dict = states.pop("zero_devide_optim_plan")
                fp_meta = os.path.join(state_path, optim.rank_unique_id)
                llm_save(fp_meta, params_per_rank_id_dict)
    else:
        llm_save(os.path.join(state_path, fp), states)


def load_optimizer_checkpoint(folder, optim):
    """Load the optimizer state from the local file system or remote
    object storage service (OSS).

    Args:
        optim (Optimizer): the optimizer to load the state into.
        folder (str): the FS/OSS path where the optimizer state is stored.
    """

    fns = get_fns(folder)
    max_tp, max_pp, max_zero = 0, 0, 0
    for fn in fns:
        if fn.startswith("optimizer_") and not fn.endswith(".md5"):
            _, tp, pp, zero = os.path.splitext(fn)[0].split("_")
            max_zero = max(max_zero, int(zero[2:]))
            max_tp = max(max_tp, int(tp[2:]))
            max_pp = max(max_pp, int(pp[2:]))

    zero_size = gpc.get_world_size(ParallelMode.ZERO1)
    zero_rank = gpc.get_local_rank(ParallelMode.ZERO1)
    tp_size = gpc.get_world_size(ParallelMode.TENSOR)
    pp_size = gpc.get_world_size(ParallelMode.PIPELINE)

    assert (
        zero_size == max_zero + 1
    ), f"The optimizer states were saved for a zero1 size of {max_zero+1}, while the current run has {zero_size}."
    assert (
        pp_size == max_pp + 1
    ), f"The optimizer states were saved for {max_pp+1} pipeline stages, while the current run has {pp_size}"
    assert (
        tp_size == max_tp + 1
    ), f"The optimizer states were saved for tensor parallel size {max_tp+1}, while the current run has {tp_size}"

    fp = f"optimizer_tp{gpc.get_local_rank(ParallelMode.TENSOR)}_"
    fp += f"pp{gpc.get_local_rank(ParallelMode.PIPELINE)}_"
    fp += f"zo{zero_rank}.pt"
    states = llm_load(os.path.join(folder, fp), map_location=get_current_device())

    if isinstance(optim, HybridZeroOptimizer):
        fp_meta = os.path.join(folder, optim.rank_unique_id)
        try:
            zero_devide_optim_plan = llm_load(fp_meta)
            states.update({"zero_devide_optim_plan": zero_devide_optim_plan})
        except Exception as e:
            logger.warning(
                f"Failed to read zero optimizer split file '{fp_meta}' because of '{e}'. "
                f"Please check whether the ckpt being loaded was saved with the HybridZeroOptimizer."
            )

    optim.load_state_dict(states)
    del states
    torch.cuda.empty_cache()
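
# Illustrative only (not called anywhere): the optimizer shard naming scheme used above, and how the
# tp/pp/zero ranks are recovered from a name such as "optimizer_tp0_pp1_zo3.pt" (a made-up example).
def _example_parse_optimizer_shard_name(fn="optimizer_tp0_pp1_zo3.pt"):
    _, tp, pp, zero = os.path.splitext(fn)[0].split("_")
    return int(tp[2:]), int(pp[2:]), int(zero[2:])  # (0, 1, 3)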


def load_sampler(ckpt_path: str, sampler):
    sampler_states = llm_load(os.path.join(ckpt_path, "sampler.pt"))
    sampler.load_state_dict(sampler_states)
    if gpc.is_rank_for_log():
        pstate = copy.deepcopy(sampler_states)
        pstate.pop("indices")
        pstate.pop("rng_state")
        logger.info(f"reload sampler_states:{pstate}")
    torch.cuda.empty_cache()


def load_context(ckpt_path: str, train_dl, train_state: TrainState):
    context_stuffs = llm_load(os.path.join(ckpt_path, "context.pt"))
    train_state.load_state_dict(context_stuffs, train_dl)
    if gpc.is_rank_for_log():
        logger.info(f"reload train_state:{train_state}")
    torch.cuda.empty_cache()


def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, learning_rate, train_state: TrainState):
    scheduler_states = llm_load(os.path.join(ckpt_path, "schedulder.pt"))
    if learning_rate != scheduler_states["base_lrs"][0] and gpc.is_rank_for_log():
        logger.warning(
            f"Using new learning rate {learning_rate} to replace old learning rate "
            f"{scheduler_states['base_lrs'][0]}."
        )

    base_lrs = copy.deepcopy(scheduler_states["base_lrs"])
    scheduler_states["base_lrs"] = [learning_rate] * len(scheduler_states["base_lrs"])
    if "after_scheduler_dict" in scheduler_states:
        scheduler_states["after_scheduler_dict"]["base_lrs"] = [learning_rate] * len(
            scheduler_states["after_scheduler_dict"]["base_lrs"]
        )

    lr_scheduler.load_state_dict(scheduler_states)
    lr_scheduler.last_epoch = train_state.step_count + 1

    ratios = [learning_rate / lr for lr in base_lrs]
    for idx, param_group in enumerate(optimizer.param_groups):
        param_group["lr"] = param_group["lr"] * ratios[idx]
    torch.cuda.empty_cache()

    if gpc.is_rank_for_log():
        logger.info(f"reload load_scheduler:{lr_scheduler}")
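
# A small numeric sketch (hypothetical values, not used by the module) of the rescaling done in
# load_scheduler: if the ckpt was saved with base_lr=1e-4 and the run restarts with learning_rate=2e-4,
# every optimizer param group lr is multiplied by 2e-4 / 1e-4 = 2.0, preserving the warmup/decay shape.
def _example_lr_rescale(old_base_lr=1e-4, new_lr=2e-4, current_group_lr=5e-5):
    ratio = new_lr / old_base_lr
    return current_group_lr * ratio  # 1e-4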


class CheckpointManager:
    """StorageManagerContext"""

    def __init__(self, ckpt_config, model, model_config, feishu_address=None) -> None:
        """
        CheckpointManager is used to decide when to store a ckpt. If the asynchronous
        upload mode is used, wait_async_upload_finish must be called at the end of the program to wait
        for the asynchronous ckpt upload to complete.

        Args:
            ckpt_config (dict): model checkpoint config.
            model (nn.module): model obj.
            model_config (dict): model config.
            feishu_address (str): the address used to send alert messages.
        """
        self.enable_save_ckpt = ckpt_config.enable_save_ckpt
        self.checkpoint_every = ckpt_config.checkpoint_every
        self.save_ckpt_folder = ckpt_config.save_ckpt_folder
        self.snapshot_ckpt_folder = ckpt_config.snapshot_ckpt_folder
        self.oss_snapshot_freq: int = ckpt_config.oss_snapshot_freq
        self.stop_file_path = ckpt_config.stop_file_path
        self.load_model_only_folder = ckpt_config.load_model_only_folder
        self.feishu_address = feishu_address
        self.storage_manager = get_storage_manager()
        self.snapshot_counter = 0
        self.load_optimizer = gpc.config.ckpt.load_optimizer

        self.model = model
        self.model_config = model_config

        if self.stop_file_path and gpc.get_global_rank() == 0:
            dir_path = os.path.dirname(self.stop_file_path)
            if dir_path != "" and not os.path.exists(dir_path):
                os.makedirs(dir_path)
            with open(self.stop_file_path, "w", encoding="utf-8") as f:
                f.write("0")

        if not ckpt_config.load_given_ckpt:
            latest_ckpt_path = self.query_lastest_ckpt()
            self.load_ckpt_folder = latest_ckpt_path if latest_ckpt_path is not None else ckpt_config.load_ckpt_folder
        else:
            self.load_ckpt_folder = ckpt_config.load_ckpt_folder
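
    # A hedged sketch (field names taken from the attribute reads in __init__ above; the concrete
    # values are hypothetical) of the ckpt config this manager expects:
    #
    #     ckpt = dict(
    #         enable_save_ckpt=True,
    #         checkpoint_every=50,
    #         save_ckpt_folder="local:/path/to/ckpts",
    #         snapshot_ckpt_folder="local:/path/to/ckpts/snapshot",
    #         oss_snapshot_freq=10,
    #         stop_file_path="/path/to/stop_file.txt",
    #         load_model_only_folder=None,
    #         load_given_ckpt=False,
    #         load_ckpt_folder=None,
    #         load_optimizer=True,  # read via gpc.config.ckpt.load_optimizer
    #     )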

    def quit_signal_handler(self, train_state):
        """
        Exit signal detection function: if an exit step is written into the file at 'stop_file_path',
        all ranks will save a ckpt and, depending on the sign, exit.
        A negative integer step means "save ckpt".
        A positive integer step means "save ckpt and quit".

        Args:
            train_state (TrainState):
        Returns:
            tuple: (now_break, now_save_ckpt, save_type), i.e. whether to quit, whether to save a ckpt
            now, and the checkpoint type to use.
        """
        now_break, now_save_ckpt, save_type = False, False, CheckpointType.NORMAL_CHECKPOINT

        if self.stop_file_path is None:
            logger.warning("stop_file_path is not set")
            return now_break, now_save_ckpt, save_type

        with open(self.stop_file_path, "a+", encoding="utf-8") as f:
            fcntl.flock(f, fcntl.LOCK_EX)
            f.seek(0)
            msg = f.read()
            fcntl.flock(f, fcntl.LOCK_UN)
            action_step = int(msg)

        if action_step < 0 and abs(action_step) == train_state.step_count:
            now_save_ckpt = True

        if action_step > 0 and action_step == train_state.step_count:
            now_break, now_save_ckpt = True, True

        if action_step != 0 and gpc.is_rank_for_log():
            msg = "Stop" if action_step > 0 else "Save"
            action_step = abs(action_step)
            if train_state.step_count <= action_step:
                if self.feishu_address:
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"training will {msg} at step_count {action_step}! "
                        f"now step_count is {train_state.step_count}",
                    )

        return now_break, now_save_ckpt, save_type
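
    # Illustrative only: the stop-file protocol read above. Writing "-100" into stop_file_path asks
    # every rank to save a ckpt when step_count reaches 100; writing "100" asks them to save and then
    # quit; "0" is a no-op. A hypothetical way to trigger it from another process:
    #
    #     with open(stop_file_path, "w", encoding="utf-8") as f:
    #         f.write("100")  # the target step is a made-up example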

    def try_save_checkpoint(self, train_state):
        if not self.enable_save_ckpt:
            return False

        save_ckpts, save_type = False, CheckpointType.NORMAL_CHECKPOINT
        if self.oss_snapshot_freq > 1 and train_state.step_count % self.oss_snapshot_freq == 0:
            save_ckpts, save_type = True, CheckpointType.SNAPSHOT_CHECKPOINT
        if train_state.step_count % self.checkpoint_every == 0:
            save_ckpts, save_type = True, CheckpointType.NORMAL_CHECKPOINT
        now_break, signal_save_ckpts, signal_save_type = self.quit_signal_handler(train_state)
        if save_ckpts is False:
            save_ckpts = signal_save_ckpts
            save_type = signal_save_type

        if save_ckpts:
            # Wait for the previous round of asynchronous upload storage to complete.
            self.storage_manager.wait()
            if save_type == CheckpointType.SNAPSHOT_CHECKPOINT:
                # Snapshot number; only two snapshots are kept, written alternately.
                self.snapshot_counter = (self.snapshot_counter + 1) % 2
                save_ckpt_folder = os.path.join(self.snapshot_ckpt_folder, f"{self.snapshot_counter}")
            else:
                save_ckpt_folder = os.path.join(self.save_ckpt_folder, str(train_state.step_count))

            self.save_checkpoint(
                folder=save_ckpt_folder,
                model=self.model,
                optimizer=self.optimizer,
                scheduler=self.lr_scheduler,
                train_state=train_state,
                model_config=self.model_config,
            )

        return now_break

    def wait_async_upload_finish(self):
        """Wait for all checkpoint uploads to be completed."""
        self.storage_manager.wait()
        torch.distributed.barrier()

    def query_latest_snapshot_step_boto3(self):
        """query_latest_snapshot_step_boto3

        Returns:
            Tuple(str, int): the path and step of the latest ckpt; if none is found, (None, None) is returned.
        """
        ckpt_list = self.storage_manager.get_fns(self.save_ckpt_folder)
        if len(ckpt_list) == 0:
            return None, None

        # Find the newest step folder that contains a '.step' integrity marker.
        max_normal_step = 0
        ckpt_list = list(map(lambda a: int(a.strip("/")) if a.strip("/").isdigit() else 0, ckpt_list))
        ckpt_list.sort(reverse=True)
        for ckpt in ckpt_list:
            fns_list = self.storage_manager.get_fns(os.path.join(self.save_ckpt_folder, str(ckpt)))
            for fn in fns_list:
                if fn.endswith(".step"):
                    max_normal_step = ckpt
                    break
            if max_normal_step != 0:
                break

        load_normal_ckpt_path = os.path.join(self.save_ckpt_folder, str(max_normal_step))

        snapshot_path_0 = os.path.join(self.save_ckpt_folder, "snapshot", "0")
        snapshot_path_1 = os.path.join(self.save_ckpt_folder, "snapshot", "1")
        ckpt_list_1 = self.storage_manager.get_fns(snapshot_path_0)
        ckpt_list_2 = self.storage_manager.get_fns(snapshot_path_1)
        max_step_0, max_step_1 = 0, 0
        for ckpt in ckpt_list_1:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_0 = max(max_step_0, int(ckpt.split(".")[0]))
        for ckpt in ckpt_list_2:
            ckpt = ckpt.strip("/")
            if ckpt.endswith(".step"):
                max_step_1 = max(max_step_1, int(ckpt.split(".")[0]))

        snap_load_path = snapshot_path_0 if max_step_0 > max_step_1 else snapshot_path_1
        snap_step = max(max_step_0, max_step_1)
        load_path = snap_load_path if snap_step > max_normal_step else load_normal_ckpt_path
        load_step = max(snap_step, max_normal_step)
        return load_path, load_step

    def query_latest_snapshot_step_local(self):
        max_step, max_step_path = 0, None
        for root, _, files in os.walk(self.save_ckpt_folder, followlinks=True):
            for fn in files:
                fn = fn.strip("/")
                if fn.endswith(".step"):
                    # We assume that both normal ckpt and snapshot ckpt will store the '.step' file
                    # as an integrity flag.
                    step = int(fn.rsplit(".", maxsplit=1)[0])
                    if max_step < step:
                        max_step = step
                        max_step_path = root

        return max_step_path, max_step
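
    # Illustrative only: both normal and snapshot ckpt folders are expected to end with a
    # "{step_count}.step" marker (written at the end of save_checkpoint), so a folder containing
    # "1500.step" (a made-up value) resolves here to max_step=1500 with max_step_path pointing at it:
    #
    #     int("1500.step".rsplit(".", maxsplit=1)[0])  # -> 1500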

    def query_lastest_ckpt(self):
        latest_checkpoint = None
        # If training was restarted automatically, force the latest ckpt/snapshot to be read.
        if self.save_ckpt_folder:
            if self.save_ckpt_folder.startswith("boto3"):
                latest_checkpoint, step = self.query_latest_snapshot_step_boto3()
            elif self.save_ckpt_folder.startswith("local"):
                latest_checkpoint, step = self.query_latest_snapshot_step_local()
            else:
                latest_checkpoint, step = None, 0

            if latest_checkpoint is not None:
                if gpc.is_rank_for_log():
                    logger.info(f"Found latest ckpt : {latest_checkpoint}, step: {step}")
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Auto restart resume from ckpt-path: '{latest_checkpoint}', step : {step}",
                    )
            else:
                if gpc.is_rank_for_log():
                    send_alert_message(
                        address=self.feishu_address,
                        message=f"Can't find snapshot checkpoint, use default load-ckpt path: {latest_checkpoint}",
                    )

        return latest_checkpoint

    def try_load_model(self, current_time=""):
        model_load_path = None

        if self.load_ckpt_folder and self.load_model_only_folder:
            raise ValueError(
                "Error: both the load_ckpt_folder and load_model_only_folder paths are set. "
                "If you only need to load model weights (for example, when starting an SFT task for the "
                "first time), set load_model_only_folder. If you need to resume training from a ckpt, "
                "set load_ckpt_folder or use its default value "
                "(with the default value, internlm will try to load the latest ckpt from save_ckpt_folder)."
            )

        if self.load_ckpt_folder:
            logger.info(
                f"===========Resume training from `{self.load_ckpt_folder}` {current_time} on host:"
                f"{socket.gethostname()}==========="
            )
            model_load_path = self.load_ckpt_folder
        elif self.load_model_only_folder:
            logger.info(
                f"===========SFT training from `{self.load_model_only_folder}` {current_time} on host:"
                f"{socket.gethostname()}==========="
            )
            model_load_path = self.load_model_only_folder
        else:
            logger.info(
                f"===========New Run {current_time} on host:{socket.gethostname()},rank={gpc.get_global_rank()},"
                f"tp={gpc.get_local_rank(ParallelMode.TENSOR)},pp={gpc.get_local_rank(ParallelMode.PIPELINE)},"
                f"dp={gpc.get_local_rank(ParallelMode.DATA)}==========="
            )

        # Loading model weights must be done before zero is initialized.
        if model_load_path is not None:
            load_model_checkpoint(folder=model_load_path, model=self.model)

    def try_resume_traing(self, lr_scheduler, optimizer, lr, train_state, train_dl):
        """Attempt to restore the training state from the last ckpt.

        Args:
            lr_scheduler (_LRScheduler): lr_scheduler object.
            optimizer (Optimizer): optimizer object.
            lr (float): learning rate.
            train_state (dict): training states.
            train_dl (DataLoader): training dataloader object.
        """
        if self.load_ckpt_folder is not None:
            # load lr scheduler states.
            load_scheduler(self.load_ckpt_folder, lr_scheduler, optimizer, lr, train_state)
            # load training states.
            load_context(self.load_ckpt_folder, train_dl, train_state)
            # load dataloader sampler states.
            load_sampler(self.load_ckpt_folder, train_dl.batch_sampler)
            # load optimizer states.
            if self.load_optimizer:
                load_optimizer_checkpoint(self.load_ckpt_folder, optimizer)

        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

    def save_checkpoint(self, folder, model, optimizer, scheduler, train_state: TrainState, model_config: Dict = None):
        """
        Save checkpoint to the given folder path.
        """

        start = time.time()
        self.set_save_folder(folder, train_state.step_count)
        torch.distributed.barrier()
        if gpc.is_rank_for_log():
            logger.info(f"Saving checkpoint to `{folder}` at batch count:{train_state.step_count}...")

        timer("save-model").start()
        save_model_checkpoint(folder=folder, model=model)
        timer("save-model").stop()

        timer("save-optimizer").start()
        save_optimizer_checkpoint(optim=optimizer, state_path=folder)
        timer("save-optimizer").stop()

        if gpc.is_rank_for_log():
            scheduler_states = scheduler.state_dict()
            llm_save(os.path.join(folder, "schedulder.pt"), saved_obj=scheduler_states)

            sampler_state = train_state.batch_sampler.state_dict()
            llm_save(os.path.join(folder, "sampler.pt"), saved_obj=sampler_state)
            llm_save(os.path.join(folder, "context.pt"), saved_obj=train_state.state_dict())

            if model_config is not None:
                llm_save(os.path.join(folder, "model_config.pt"), saved_obj=model_config)

        torch.distributed.barrier()

        if gpc.is_rank_for_log():
            timer.log(["save-model", "save-optimizer"], logger=logger)
            logger.info(f"Step: {train_state.step_count}, rank 0 save ckpt use {time.time() - start:.3f} s")
            if self.storage_manager.async_mode is False:
                llm_save(
                    os.path.join(folder, f"{train_state.step_count}.step"),
                    saved_obj=dict({"step": train_state.step_count}),
                )
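
    # Illustrative only: after save_checkpoint finishes, a (non-snapshot) checkpoint folder is expected
    # to contain roughly the following files (rank placeholders shown symbolically):
    #
    #     model_tp{tp}_pp{pp}.pt / topo_tp{tp}_pp{pp}.json   # sharded weights and their topology
    #     optimizer_tp{tp}_pp{pp}_zo{zo}.pt                  # sharded optimizer states
    #     schedulder.pt, sampler.pt, context.pt              # scheduler / sampler / train state
    #     model_config.pt                                    # optional model config
    #     {step_count}.step                                  # integrity marker for resume lookup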

    def set_save_folder(self, folder, step):
        self.storage_manager.latest_save_folder = folder
        self.storage_manager.latest_save_step = step