Merge branch 'develop' of https://github.com/InternLM/InternLM into hf_llama

pull/539/head
lijiaxing 2023-12-13 14:54:37 +08:00
commit d7555e8216
12 changed files with 172 additions and 60 deletions

View File

@ -145,18 +145,18 @@ model = dict(
moe_use_residual=False,
moe_gate_k=2,
)
"""
zero1 parallel:
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
# zero1 parallel:
# 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
# so parameters will be divided within the range of dp.
# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
# pipeline parallel (dict):
# 1. size: int, the size of pipeline parallel.
# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
# tensor parallel: tensor parallel size, usually the number of GPUs per node.
parallel = dict(
zero1=dict(size=-1, fsdp=False),
tensor=1,
@ -176,4 +176,8 @@ monitor = dict(
),
)
model_type = "INTERNLM_MoE"
model_type = "INTERNLM_MoE"
# metric_dtype can be "fp32" or other string
# only when set to "fp32" will use fp32 to calc in metrics
# metric_dtype = "fp32"

View File

@ -146,18 +146,18 @@ model = dict(
use_flash_attn=True,
num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used.
)
"""
zero1 parallel:
1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
so parameters will be divided within the range of dp.
2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
pipeline parallel (dict):
1. size: int, the size of pipeline parallel.
2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
tensor parallel: tensor parallel size, usually the number of GPUs per node.
"""
# zero1 parallel:
# 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
# so parameters will be divided within the range of dp.
# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
# pipeline parallel (dict):
# 1. size: int, the size of pipeline parallel.
# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
# tensor parallel: tensor parallel size, usually the number of GPUs per node.
parallel = dict(
zero1=dict(size=8, fsdp=False),
tensor=1,
@ -177,3 +177,7 @@ monitor = dict(
alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
),
)
# metric_dtype can be "fp32" or other string
# only when set to "fp32" will use fp32 to calc in metrics
# metric_dtype = "fp32"

View File

@ -185,6 +185,11 @@ class Engine:
if to_gpu:
batch_data = move_to_device(batch_data)
# For packed-dataset, batch_data is (micro_num, micro_bsz*seq_len),
# therefore 'batch_size' is equal to 'micro_num'
# For nopacked-dataset, batch_data is (micro_num*micro_bsz, seq_len),
# therefore 'batch_size' is equal to 'micro_num*micro_bsz'
batch_size = get_batch_size(batch_data)
return batch_data, batch_size

View File

@ -4,7 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, Optional
import torch
@ -36,10 +36,18 @@ class BaseScheduler(ABC):
"""
pass
def _load_micro_batch(self, data, label, offset):
def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int):
"""
For pp, it will cut one fully batch into micro batch in pipeline concept.
For nopp, it will cut one fully batch into small batch based on gradient accumulate size.
A special case is that pp uses a 'non-packed-dateset' (such as evaluation dataset),
so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'.
In all other cases 'bsz_stride' should be equal to 1.
"""
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
micro_batch_label = label[offset : offset + 1]
micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
micro_batch_label = label[offset : offset + bsz_stride]
return micro_batch_data, micro_batch_label

View File

@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler):
label (Any): The label to be loaded.
"""
_data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset)
_data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1)
self._grad_accum_offset += 1
if self.data_process_func:
@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data, actual_batch_size = engine.load_batch(data_iter)
batch_data, actual_batch_size = engine.load_batch(data_iter) # actual_batch_size is micro_num
self._grad_accum_size = actual_batch_size # Rampup or variable bsz size.

View File

@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.common import (
check_data_is_packed,
get_current_device,
move_to_device,
)
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler):
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def load_batch(self, engine, data_iter):
# Pipeline schedule just puts data in memory
# Pipeline schedule just puts data in memory,
batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
self.num_microbatches = actual_batch_size # Rampup or variable bsz size.
# Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed,
# because internlm's current train dataset is packed, even using dummy data.
# The unpack operation is performed in load_micro_batch().
if check_data_is_packed(batch_data):
micro_num = actual_batch_size
else:
micro_num = actual_batch_size // gpc.config.data["micro_bsz"]
self.microbatch_offset = 0
self.batch_size = actual_batch_size
self.batch_data, self.batch_label = batch_data
self.bsz_stride = self.batch_size // micro_num
# 'num_microbatches' is no longer an initialization parameter,
# but is determined on the fly by the Scheduler.
self.num_microbatches = micro_num # Rampup or variable bsz size.
def load_micro_batch(self):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride
)
if self.data_process_func:
micro_batch_data["input_ids"] = self.data_process_func(
@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler):
micro_batch_data.pop("indexes")
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset += 1
self.microbatch_offset += self.bsz_stride
return move_to_device(micro_batch_data)
@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data=self.batch_data,
label=self.batch_label,
offset=self.microbatch_offset[model_chunk_id],
bsz_stride=self.bsz_stride,
)
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset[model_chunk_id] += 1
self.microbatch_offset[model_chunk_id] += self.bsz_stride
return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id):

View File

@ -26,7 +26,11 @@ class AccPerplex:
self.device = device
self.right = torch.Tensor([0]).to(device=device)
self.total = torch.Tensor([0]).to(device=device)
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None
if self.metric_dtype is not None:
self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
else:
self.total_log_probs = torch.Tensor([0]).to(device=device)
self.tp_pg = tp_pg
self.dp_pg = dp_pg
self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)
@ -128,8 +132,9 @@ class AccPerplex:
# All reduce is needed to get the chunks from other GPUs.
torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)
predicted_logits = predicted_logits.to(dtype=torch.float)
shift_logits = shift_logits.to(dtype=torch.float)
if self.metric_dtype is not None:
predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
shift_logits = shift_logits.to(dtype=self.metric_dtype)
pred_exp_logits = torch.exp(predicted_logits)
# Sum of exponential of logits along vocab dimension across all GPUs.

View File

@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
# flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
self.skip_grad_reduce = False
# reduction hook is only used if overlapping communication
# if it is stage 1 without overlapping, no hook will be attached
if self._overlap_sync_grad:
self._attach_reduction_hook()
self._attach_reduction_hook()
@property
def zero_local_rank(self):
@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):
# if sequence_parallel is True,
# the grad of norm should be all-reduce across the tp process group
if gpc.config.parallel.sequence_parallel is True:
if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
accum_grad_obj_sp = get_grad_accumulate_object(param)
accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
if (
gpc.config.parallel.sequence_parallel is True
and hasattr(param, IS_SEQUENCE_PARALLEL)
and getattr(param, IS_SEQUENCE_PARALLEL) is True
):
accum_grad_obj.register_hook(reduce_grad_hook_sp)
accum_grad_obj.register_hook(reduce_grad_hook)
if self._overlap_sync_grad:
accum_grad_obj.register_hook(reduce_grad_hook)
_define_and_attach(param, reduce_rank)

View File

@ -110,6 +110,17 @@ def get_batch_size(data):
return data[list(data.keys())[0]].size(0)
def check_data_is_packed(data):
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (list, tuple)):
if isinstance(data[0], dict):
return "indexes" in data[0]
return False
elif isinstance(data, dict):
return "indexes" in data[0]
def filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}

View File

@ -1016,7 +1016,8 @@ now step_count is {train_state.step_count}",
torch.distributed.barrier()
def query_latest_snapshot_step_boto3(self):
"""query_latest_snapshot_step_boto3
"""Query the latest snapshot step from the storage backend.
Currently, we only support the following storage backends: boto3, oss2 and volc.
Returns:
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
"""
@ -1074,6 +1075,7 @@ now step_count is {train_state.step_count}",
return load_path, max(snap_step, max_normal_step)
def query_latest_snapshot_step_local(self):
"""Query the latest snapshot step from the local file system."""
max_step, max_step_path = 0, None
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
@ -1090,18 +1092,22 @@ now step_count is {train_state.step_count}",
return max_step_path, max_step
def query_lastest_ckpt(self):
"""Query the latest ckpt via the storage backend."""
latest_ckpt, step = None, -1
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder:
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
if backend == "boto3":
if backend in ["boto3", "oss2", "volc"]:
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
latest_ckpt = ":".join(["boto3", latest_ckpt])
elif backend == "local":
latest_ckpt, step = self.query_latest_snapshot_step_local()
if latest_ckpt and not latest_ckpt.startswith("local:"):
latest_ckpt = ":".join(["local", latest_ckpt])
else:
raise NotImplementedError(
f"Unsupported backend: {backend}, " "Currently only support `boto3`, `oss2`, `volc` and `local`"
)
if latest_ckpt and not latest_ckpt.startswith(backend + ":"):
latest_ckpt = ":".join([backend, latest_ckpt])
if gpc.is_rank_for_log():
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")

View File

@ -739,10 +739,9 @@ class AliClient(StorageClient):
if AliClient.is_fp_exists(handler, fp):
folder_name_list = []
for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp):
folder_name_list.append(obj.key.split("/")[-1])
folder_name_list.append(obj.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
return list(set(folder_name_list))
else:
if is_rank_for_log():
logger.warning(f"'{fp}' not found!")

View File

@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc
# from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.train import get_train_data_loader, load_new_batch
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
load_new_batch,
)
from internlm.utils.evaluation import (
switch_evaluation_no_pipeline_scheduler,
switch_evaluation_pipeline_scheduler,
)
# from internlm.core.context.parallel_context import global_context as gpc
from tests.test_core.utils import build_environment, init_model_and_optim
@ -20,7 +28,7 @@ use_flash_attens = [True, False]
answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
test_case_group = [
# format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len
# (1, "1 1 1", True, answers[0], 1, 8),
(1, "1 1 1", True, answers[0], 1, 8),
(4, "1 1 4", True, answers[1], 1, 8),
(4, None, True, answers[2], 1, 8),
(8, "2 2 2", True, answers[3], 1, 8),
@ -28,6 +36,11 @@ test_case_group = [
]
class DummyTrainer:
def __init__(self, scheduler) -> None:
self.schedule = scheduler
def do_warmup(args):
rank, worldsize, init_config, should_sccuess, answer = args
build_environment(rank, worldsize, init_config)
@ -44,9 +57,11 @@ def do_warmup(args):
)
scheduler.pre_processing(engine)
engine.train()
trainer = DummyTrainer(scheduler)
try:
train_dl, _ = get_train_data_loader(num_worker=0)
val_dls = get_validation_data_loader(num_worker=0)
except Exception as e:
assert should_sccuess is False, f"{e}"
else:
@ -105,6 +120,38 @@ def do_warmup(args):
tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"
# test no-packed datasets.
for _, val_dl in val_dls.items():
for _, batch in enumerate(val_dl):
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
assert total_val_bsz % micro_bsz == 0
num_microbatches = total_val_bsz // micro_bsz
tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]]) # toy model hidden size is 8.
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[],
):
scheduler.forward_backward_step(
engine, batch, forward_only=True, return_loss=False, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
assert total_val_bsz % micro_bsz == 0
grad_accum_size = total_val_bsz // micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
metric_hook_list=[],
):
scheduler.forward_backward_step(
engine, batch, forward_only=True, return_loss=False, return_output_label=False
)
@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
@pytest.mark.parametrize("group_case", test_case_group)
@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
sequence_parallel=False,
tensor=1,
),
data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
data=dict(
train_folder=None,
valid_folder=None,
valid_micro_num=4,
pack_sample_into_one=False,
min_length=0,
total_steps=8,
),
model=dict(
dtype=torch.bfloat16,
),