Merge 'upstream/develop' into doc/add_moe__doc

pull/411/head
Qu Wenwen 2023-10-18 14:06:49 +08:00
commit 3421d1197a
9 changed files with 310 additions and 22 deletions

View File

@ -4,7 +4,7 @@ DO_ALERT = False
SEQ_LEN = 2048
HIDDEN_SIZE = 4096
NUM_ATTENTION_HEAD = 32
MLP_RATIO = 8 / 3
MLP_RATIO = 4 / 3
NUM_LAYER = 32
VOCAB_SIZE = 103168
@ -30,6 +30,14 @@ ckpt = dict(
# 2. the 'content means what states will be loaded, support: "model", "sampler", "optimizer", "scheduler", "all"
# 3. the ckpt_type means the type of checkpoint to be loaded, now only 'normal' type is supported.
load_ckpt_info=dict(path=MODEL_ONLY_FOLDER, content=("model",), ckpt_type="internlm"),
# 'auto_resume' is designed to automatically load the latest checkpoint from 'save_ckpt_folder' when encountering
# training interruptions/hangs caused by hardware failures, using a scheduling system (such as k8s/slurm)
# with an automatic restart mechanism upon training reboot.
# Please be aware that if `auto_resume` is not set (its default value is True), it will not load the checkpoint
# path specified in `load_ckpt_info` by default.
# If you want to initialize your model weights from another model, you must set `auto_resume` to False.
# If you want to train from scratch, please set `auto_resume` to False and 'load_ckpt_info' to None.
auto_resume=True,
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
@ -43,7 +51,7 @@ data = dict(
# micro_num means the number of micro_batch contained in one gradient update
micro_num=4,
# packed_length = micro_bsz * SEQ_LEN
micro_bsz=1,
micro_bsz=2,
# defaults to the value of micro_num
valid_micro_num=4,
# defaults to 0, means disable evaluate
@ -81,8 +89,8 @@ grad_scaler = dict(
hybrid_zero_optimizer = dict(
# Enable low_level_optimzer overlap_communication
overlap_sync_grad=True,
overlap_sync_param=True,
overlap_sync_grad=False,
overlap_sync_param=False,
# bucket size for nccl communication params
reduce_bucket_size=512 * 1024 * 1024,
# grad clipping
@ -133,7 +141,7 @@ model = dict(
layer_norm_epsilon=1e-5,
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
num_experts=4,
num_experts=8,
moe_use_residual=False,
moe_gate_k=2,
)
@ -151,7 +159,7 @@ tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=2,
tensor=1,
pipeline=dict(size=1, interleaved_overlap=True),
sequence_parallel=False,
)

View File

@ -39,7 +39,7 @@ CheckpointManager
load_ckpt_folder=dict(path="local:/mnt/mfs/ckpt", content=["all",], ckpt_type="internlm"),
auto_resume=False, # disable auto-resume, internlm will load model checkpoint from the path of 'load_ckpt_folder'.
checkpoint_every=CHECKPOINT_EVERY,
async_upload=True, # async ckpt upload. (only work for boto3 ckpt)
async_upload=True, # async ckpt upload. (only work for boto3 and volc ckpt)
async_upload_tmp_folder="/dev/shm/internlm_tmp_ckpt/", # path for temporarily files during asynchronous upload.
oss_snapshot_freq=int(CHECKPOINT_EVERY / 2), # snapshot ckpt save frequency.
)
@ -67,7 +67,9 @@ InternLM对config中出现的所有存储路径都遵循以下的路径格式约
1. 如果需要使用boto3的路径需要在运行前提前导入 ``S3_ACCESS_KEY_ID````S3_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
2. bucket的endpoint一般分为Inside IP和Outside IP如果可以尽量使用inside IP会获得更佳的存储速度。
2. 如果需要使用volc的路径需要在运行前提前导入 ``VOLC_ACCESS_KEY_ID````VOLC_SECRET_ACCESS_KEY_ID`` 这两个环境变量。
3. bucket的endpoint一般分为Inside IP和Outside IP如果可以尽量使用inside IP会获得更佳的存储速度。
@ -114,7 +116,7 @@ config.ckpt 中相关的参数:
- ``async_upload_tmp_folder``: 异步上传临时存储路径。参数类型 ``str/None``, 默认值为 ``/dev/shm/{JOB_NAME}_tmp_ckpt/``
需要注意的是异步上传功能仅在backend为boto3时才会有效果bcakend为local时只支持同步存储。
需要注意的是异步上传功能仅在backend为boto3或volc时才会有效果bcakend为local时只支持同步存储。
``async_upload_tmp_folder`` 设置的的原则为尽量设置为计算节点的local目录这样才可以获得最佳的异步上传速度一般来说建议为 ``/dev/shm````/nvme`` 下的路径,如果使用同步上传,则该路径可不给。

Binary file not shown.

Before

Width:  |  Height:  |  Size: 153 KiB

After

Width:  |  Height:  |  Size: 212 KiB

View File

@ -79,6 +79,10 @@ class FSDPadaptOptimizer(BaseOptimizer):
def _compute_norm_with_fsdp_flatten(self, group_id):
params = [p for p in self._fp16_param_groups[group_id] if p.untyped_storage().size() != 0]
gradients = [p.grad for p in params if p.untyped_storage().size() != 0]
norm_group = 0
if len(params) <= 0 or len(gradients) <= 0:
return norm_group
norm_group = compute_norm(gradients=gradients, parameters=params, last_stage=True)
return norm_group
@ -126,6 +130,8 @@ class FSDPadaptOptimizer(BaseOptimizer):
# create gradient for fp32 params
for group_idx in range(len(self.param_groups)):
if len(self._fp32_param_tensor_groups[group_idx]) <= 0:
continue
dtype = self._fp32_param_tensor_groups[group_idx][0].dtype
fp16_params = [p for p in self._fp16_param_groups[group_idx] if p.untyped_storage().size() != 0]
grad_fp32 = [p.grad.to(dtype) for p in fp16_params]

View File

@ -556,6 +556,18 @@ def load_optimizer_checkpoint(folder, optim):
f"Please check whether loading ckpts are saved with the HybridZeroOptimizer."
)
# compatible with old code that only have one param group, need to align with both parameter groups
if len(states["base_optim_states"]["param_groups"]) == 1:
for group in optim.param_groups:
# for new added empty group, since it has no params, just create it fakely
if len(group["params"]) == 0:
states["base_optim_states"]["param_groups"].append(group)
# for origin group, create new added attributes in recent updates
else:
saved_group = states["base_optim_states"]["param_groups"][0]
saved_group["dp_mode"] = group["dp_mode"]
saved_group["dtype"] = group["dtype"]
optim.load_state_dict(states)
del states
torch.cuda.empty_cache()
@ -598,6 +610,10 @@ def load_scheduler(ckpt_path: str, lr_scheduler, optimizer, train_state: TrainSt
lr_scheduler.load_state_dict(scheduler_states)
lr_scheduler.last_epoch = train_state.step_count + 1
# compatible with old code that only have one param group
if len(base_lrs) == 1:
base_lrs = base_lrs * len(optimizer.param_groups)
ratios = [learning_rate / lr for lr in base_lrs]
for idx, param_group in enumerate(optimizer.param_groups):
param_group["lr"] = param_group["lr"] * ratios[idx]

View File

@ -424,7 +424,9 @@ class SimpleMemoryProfiler:
layer_name, output.element_size() * output.nelement(), flush=False
)
def _activation_trace_hook_forward(self, chunk_id: int, model: Any, inputs: Any, output: torch.Tensor) -> None:
def _activation_trace_hook_forward(
self, chunk_id: int, model: Any, inputs: Any, output: Any # pylint: disable=W0613
) -> None:
"""
Hook function to trace the activation memory usage for a forward pass.
@ -437,7 +439,6 @@ class SimpleMemoryProfiler:
None
"""
del model, inputs
assert isinstance(output, torch.Tensor), f"invalid output type: {type(output)}"
if self._stoped:
return

View File

@ -25,6 +25,7 @@ from internlm.utils.logger import get_logger
try:
import boto3
import botocore
import tos
except ImportError:
pass
@ -32,6 +33,7 @@ except ImportError:
logger = get_logger(__file__)
boto3_url_re = re.compile(r"([^\.]+)\.([\d\.]+)")
volc_url_re = re.compile(r"^(.*?)\.(.*)$")
MB = 1024**2
@ -122,6 +124,47 @@ local_nvme_path: {self.local_nvme_path}"
return meta.client, meta.bucket_name, meta.file_path
class VolcMetaInfo:
"""Volc meta info for save/load etc."""
def __init__(
self,
is_async,
handler: StorageClient,
bucket_name: str,
endpoint: str,
region: str,
file_path: str,
async_upload_fn: callable,
local_nvme_path=None,
) -> None:
# all need info.
self.client = handler
self.bucket_name = bucket_name
self.file_path = file_path
# only save need info.
self.local_nvme_path = local_nvme_path
self.is_async = is_async
self.endpoint = endpoint
self.region = region
self.async_upload_fn = async_upload_fn
def __str__(self) -> str:
return f"is_async: {self.is_async}, bucket_name:{self.bucket_name}, endpoint:{self.endpoint}, \
region:{self.region}, local_nvme_path: {self.local_nvme_path}"
@staticmethod
def unpack_volc_save_meta(meta):
if meta.is_async:
return meta.client, meta.bucket_name, meta.file_path, meta.local_nvme_path
else:
return meta.client, meta.bucket_name, meta.file_path
@staticmethod
def unpack_volc_nosave_meta(meta):
return meta.client, meta.bucket_name, meta.file_path
class LocalMetaInfo:
"""Local meta info for save/load etc."""
@ -139,18 +182,22 @@ class LocalMetaInfo:
return (meta.file_path,)
def unpack_save_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
def unpack_save_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_save_meta(meta)
elif isinstance(meta, VolcMetaInfo):
return VolcMetaInfo.unpack_volc_save_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_save_meta(meta)
else:
raise ValueError(f"unkonwn meta info: {type(meta)}")
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, LocalMetaInfo]):
def unpack_nosave_meta(meta: Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]):
if isinstance(meta, Boto3MetaInfo):
return Boto3MetaInfo.unpack_boto3_nosave_meta(meta)
elif isinstance(meta, VolcMetaInfo):
return VolcMetaInfo.unpack_volc_nosave_meta(meta)
elif isinstance(meta, LocalMetaInfo):
return LocalMetaInfo.unpack_local_nosave_meta(meta)
else:
@ -170,6 +217,10 @@ def try_get_storage_backend(path: str):
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of boto3.")
return "boto3", path
elif path.startswith("vc:"):
if gpc.is_rank_for_log():
logger.warning(f"path: '{path}' not start with backend prefix, guess it is the backend of volc.")
return "volc", path
else:
sre = path.split(":", maxsplit=1)
if len(sre) == 1:
@ -312,6 +363,143 @@ class Boto3Client(StorageClient):
raise NotImplementedError("boto3 not support delete_obj")
class VolcClient(StorageClient):
"""
VolcClient
"""
def __init__(
self,
endpoint: str,
region: str,
) -> None:
"""Volc object/file storage management class
Args:
access_key (str): Volc access key ID.
secret_key (str): Volc secret access key.
endpoint (str): Volc tos endpoint.
region (str): Volc tos region.
"""
super().__init__(tos)
try:
access_key = os.environ["VOLC_ACCESS_KEY_ID"]
secret_key = os.environ["VOLC_SECRET_ACCESS_KEY_ID"]
except KeyError as exc:
raise RuntimeError(
"Please set 'VOLC_ACCESS_KEY_ID' and 'VOLC_SECRET_ACCESS_KEY_ID'",
"using environment variable!",
) from exc
self.client = self.handler.TosClientV2(access_key, secret_key, endpoint, region)
@staticmethod
def sync_upload_fileobj(handler, bucket_name: str, fp: str, saved_obj=None, **kwargs):
assert saved_obj is not None, "saved_obj is None!"
try:
with io.BytesIO() as f:
torch.save(saved_obj, f, **kwargs)
f.seek(0)
handler.client.put_object(bucket_name, fp, content=f)
except handler.handler.exceptions.TosClientError as exc:
raise RuntimeError(
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
) from exc
except handler.handler.exceptions.TosServerError as exc:
raise RuntimeError(
f"Volc Network Error: fail with server error, code: {exec.code}",
f"error with request id: {exec.request_id}",
f"error with message: {exec.message}",
f"error with http code: {exec.status_code}",
) from exc
@staticmethod
def load(handler, bucket_name: str, fp: str, **kwargs) -> Dict:
"""
Args:
fp (str): Path to save, eg. vc://opennlplab/model_weights/xxx/ddd.pt
"""
try:
object_stream = handler.client.get_object(bucket_name, fp)
buffer = io.BytesIO(object_stream.read())
states = torch.load(buffer, **kwargs)
except handler.handler.exceptions.TosClientError as exc:
raise RuntimeError(
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
) from exc
except handler.handler.exceptions.TosServerError as exc:
raise RuntimeError(
f"Volc Network Error: fail with server error, code: {exec.code}",
f"error with request id: {exec.request_id}",
f"error with message: {exec.message}",
f"error with http code: {exec.status_code}",
) from exc
return states
@staticmethod
def assert_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
assert len(list(handler.client.list_objects_type2(bucket_name, prefix=fp).contents)) > 0, fp
@staticmethod
def is_fp_exists(handler, bucket_name: str, fp: str): # pylint: disable=W0613
re = handler.client.list_objects_type2(bucket_name, prefix=fp)
if hasattr(re, "contents"):
return len(list(re.contents)) > 0
else:
return False
@staticmethod
def get_fns(handler, bucket_name: str, fp: str):
if VolcClient.is_fp_exists(handler, bucket_name, fp):
folder_name_list = []
result = handler.client.list_objects_type2(bucket_name, prefix=fp)
if hasattr(result, "contents"):
for iterm in result.contents:
pth = iterm.key
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
while result.is_truncated:
result = handler.client.list_objects_type2(
bucket_name, prefix=fp, continuation_token=result.next_continuation_token
)
if hasattr(result, "contents"):
for iterm in result.contents:
pth = iterm.key
folder_name_list.append(pth.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if gpc.is_rank_for_log():
logger.warning(f"'{fp}' not found!")
return None
@staticmethod
def async_upload_fileobj(handler, bucket_name: str, fp: str, local_nvme_path: str):
try:
handler.client.put_object_from_file(bucket_name, fp, local_nvme_path)
except handler.handler.exceptions.TosClientError as exc:
raise RuntimeError(
f"Volc Network Error: fail with client error, message:{exc.message}, cause: {exc.cause}"
) from exc
except handler.handler.exceptions.TosServerError as exc:
raise RuntimeError(
f"Volc Network Error: fail with server error, code: {exec.code}",
f"error with request id: {exec.request_id}",
f"error with message: {exec.message}",
f"error with http code: {exec.status_code}",
) from exc
except Exception as e:
raise e
@staticmethod
def delete_obj(handler, fp: str):
raise NotImplementedError("volc not support delete_obj")
class LocalClient(StorageClient):
"""
Storage Client for local NFS.
@ -388,8 +576,35 @@ def get_boto3_meta(fp: str, tmp_local_folder: str, is_async: bool) -> Boto3MetaI
)
def get_volc_meta(fp: str, tmp_local_folder: str, is_async: bool) -> VolcMetaInfo:
assert fp.startswith("vc://"), f"Path '{fp}' is not a volc url"
parts = fp.lstrip("vc://").split(os.path.sep)
match = volc_url_re.match(parts[0])
assert match is not None, f"url '{fp}' is not a valid volc url"
bucket_name, endpoint = match.group(1), match.group(2)
temp_part = endpoint.split(".")
endpoint = ".".join(temp_part[1:])
region = temp_part[1].split("-")
region = "-".join(region[1:])
if is_async:
tmp_step_file = get_tmp_file_name(tmp_local_folder, fp)
else:
tmp_step_file = None
return VolcMetaInfo(
is_async=is_async,
handler=None,
bucket_name=bucket_name,
endpoint=endpoint,
region=region,
file_path=os.path.sep.join(parts[1:]),
async_upload_fn=VolcClient.async_upload_fileobj,
local_nvme_path=tmp_step_file,
)
def get_local_meta(fp: str) -> LocalMetaInfo:
assert not fp.startswith("s3://"), f"Path '{fp}' is not a local path"
assert not fp.startswith("s3://") and not fp.startswith("vc://"), f"Path '{fp}' is not a local path"
return LocalMetaInfo(fp)
@ -430,10 +645,11 @@ class StorageManager(metaclass=SingletonMeta):
TODO: add a thread to poll the asynchronous storage state.
"""
BACKEND_TYPE = {"boto3", "local"}
BACKEND_TYPE = {"boto3", "local", "volc"}
BACKEND_INIT_METHOD = {
"boto3": Boto3Client,
"local": LocalClient,
"volc": VolcClient,
}
CLI_DICT = {}
@ -476,11 +692,12 @@ class StorageManager(metaclass=SingletonMeta):
logger.error(f'tmp_local_folder only have "{free_size}" GB free space, less then 100 GB!')
raise RuntimeError(f"Insufficient temporary storage space on {socket.gethostname()}")
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, LocalMetaInfo]:
def _get_client(self, path: str, async_mode: bool = False) -> Union[Boto3MetaInfo, VolcMetaInfo, LocalMetaInfo]:
"""
example:
local:/path/to/checkpoint
boto3:s3://model_weights/0331/120bi
volc:vc://model_weights/0331/120bi
Args:
path (str): _description_
@ -507,10 +724,29 @@ class StorageManager(metaclass=SingletonMeta):
the proxy may make boto3 unavailable or affect performance."
)
self.has_warning = True
elif backend == "volc":
meta_info = get_volc_meta(path, self.tmp_local_folder, async_mode)
backend_key = backend + ":" + meta_info.endpoint
init_args = (
meta_info.endpoint,
meta_info.region,
)
if (
"http_proxy" in os.environ
or "https_proxy" in os.environ
or "HTTP_PROXY" in os.environ
or "HTTPS_PROXY" in os.environ
):
if not self.has_warning and gpc.is_rank_for_log():
logger.warning(
"HTTP/HTTPS proxy is detected when using volc, incorrectly setting \
the proxy may make volc unavailable or affect performance."
)
self.has_warning = True
assert backend in StorageManager.BACKEND_TYPE, f"Unkown backend: {backend}"
# boto3 backend need special treatment.
# boto3 and volc backend need special treatment.
if backend_key not in StorageManager.CLI_DICT:
StorageManager.CLI_DICT.update({backend_key: StorageManager.BACKEND_INIT_METHOD[backend](*init_args)})
@ -527,11 +763,10 @@ class StorageManager(metaclass=SingletonMeta):
return meta.client.get_fns(*unpack_nosave_meta(meta))
def save(self, save_path: str, to_save_obj: Any, async_upload=None, **kwargs):
if async_upload is None:
async_upload = self.async_mode
if not save_path.startswith("boto3:"):
if not save_path.startswith("boto3:") and not save_path.startswith("volc:"):
async_upload = False
meta = self._get_client(save_path, async_upload)
@ -554,6 +789,7 @@ class StorageManager(metaclass=SingletonMeta):
def load(self, load_path: str, **kwargs) -> Any:
self.wait()
meta = self._get_client(path=load_path)
return meta.client.load(*unpack_nosave_meta(meta), **kwargs)
def delete_obj(self, fp: str):

View File

@ -42,8 +42,8 @@ def init_tb_writer(
# dir of the last task by 'make_launch_script.sh'.
# If we load ckpt, 'resume_tb_folder' will be overwritten as the
# reloaded 'train_state.resume_tb_folder'.s
if resume_tb_folder is not None:
assert len(resume_tb_folder) > 0 and resume_tb_folder != "/"
if resume_tb_folder is not None and len(resume_tb_folder) > 0:
assert resume_tb_folder != "/"
if not os.path.exists(resume_tb_folder):
logger.error(
f"Can't found resume_tb_folder{resume_tb_folder}, \

View File

@ -6,9 +6,9 @@ import torch
from internlm.core.context.parallel_context import Config
from internlm.initialize.launch import get_config_value
from tests.test_utils.common_fixture import ( # noqa # pylint: disable=unused-import
ASYNC_TMP_FOLDER,
BOTO_SAVE_PATH,
LOCAL_SAVE_PATH,
VOLC_SAVE_PATH,
del_tmp_file,
init_dist_and_model,
reset_singletons,
@ -48,6 +48,22 @@ ckpt_config_list = [
save_folder=LOCAL_SAVE_PATH,
test_id=3,
),
# async volc
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=ASYNC_TMP_FOLDER,
async_upload=True,
save_folder=VOLC_SAVE_PATH,
test_id=4,
),
# sync volc
dict(
enable_save_ckpt=True,
async_upload_tmp_folder=None,
async_upload=False,
save_folder=VOLC_SAVE_PATH,
test_id=5,
),
]
@ -97,6 +113,9 @@ internlm_ckpt_path = [
("/mnt/ckpt/", "local", "/mnt/ckpt/"),
("./ckpt/", "local", "./ckpt/"),
("s3://oss_bucket/", "boto3", "s3://oss_bucket/"),
("volc:vc://oss_bucket/", "volc", "vc://oss_bucket/"),
("volc:oss_bucket/", "volc", "oss_bucket/"),
("vc://oss_bucket/", "volc", "vc://oss_bucket/"),
]