diff --git a/configs/7B_MoE4_sft.py b/configs/7B_MoE4_sft.py index 92a93d0..cc94cdc 100644 --- a/configs/7B_MoE4_sft.py +++ b/configs/7B_MoE4_sft.py @@ -145,18 +145,18 @@ model = dict( moe_use_residual=False, moe_gate_k=2, ) -""" -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. -""" + +# zero1 parallel: +# 1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group, +# so parameters will be divided within the range of dp. +# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. +# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. +# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +# pipeline parallel (dict): +# 1. size: int, the size of pipeline parallel. +# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +# tensor parallel: tensor parallel size, usually the number of GPUs per node. + parallel = dict( zero1=dict(size=-1, fsdp=False), tensor=1, @@ -176,4 +176,8 @@ monitor = dict( ), ) -model_type = "INTERNLM_MoE" \ No newline at end of file +model_type = "INTERNLM_MoE" + +# metric_dtype can be "fp32" or any other string +# only when it is set to "fp32" are metrics calculated in fp32 +# metric_dtype = "fp32" diff --git a/configs/7B_sft.py b/configs/7B_sft.py index 7d945b4..c0a9bc8 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -146,18 +146,18 @@ model = dict( use_flash_attn=True, num_chunks=1, # if num_chunks > 1, interleaved pipeline scheduler is used. ) -""" -zero1 parallel: - 1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group, - so parameters will be divided within the range of dp. - 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. - 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size. - For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. -pipeline parallel (dict): - 1. size: int, the size of pipeline parallel. - 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. -tensor parallel: tensor parallel size, usually the number of GPUs per node. -""" + +# zero1 parallel: +# 1. if zero1 <= 0, the size of the zero process group is equal to the size of the dp process group, +# so parameters will be divided within the range of dp. +# 2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters. +# 3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
+# For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8. +# pipeline parallel (dict): +# 1. size: int, the size of pipeline parallel. +# 2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler. +# tensor parallel: tensor parallel size, usually the number of GPUs per node. + parallel = dict( zero1=dict(size=8, fsdp=False), tensor=1, @@ -177,3 +177,7 @@ monitor = dict( alert_file_path=f"llm_alter/{JOB_NAME}_alert.log", ), ) + +# metric_dtype can be "fp32" or any other string +# only when it is set to "fp32" are metrics calculated in fp32 +# metric_dtype = "fp32" diff --git a/internlm/core/engine.py b/internlm/core/engine.py index a372b9e..eb33e35 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -185,6 +185,11 @@ class Engine: if to_gpu: batch_data = move_to_device(batch_data) + + # For a packed dataset, batch_data is (micro_num, micro_bsz*seq_len), + # so 'batch_size' is equal to 'micro_num'. + # For a non-packed dataset, batch_data is (micro_num*micro_bsz, seq_len), + # so 'batch_size' is equal to 'micro_num*micro_bsz'. batch_size = get_batch_size(batch_data) return batch_data, batch_size diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 14c3457..6e19425 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -4,7 +4,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, Optional import torch @@ -36,10 +36,18 @@ class BaseScheduler(ABC): """ pass - def _load_micro_batch(self, data, label, offset): + def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int): + """ + For pipeline parallelism (pp), this cuts one full batch into the micro batches consumed by the pipeline schedule. + For non-pipeline runs (nopp), it cuts one full batch into smaller batches according to the gradient accumulation size. + + A special case is when pp uses a non-packed dataset (such as an evaluation dataset): + the batch data is then unpacked and 'bsz_stride' is equal to 'micro_bsz'. + In all other cases 'bsz_stride' should be equal to 1. + """ assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()} - micro_batch_label = label[offset : offset + 1] + micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()} + micro_batch_label = label[offset : offset + bsz_stride] return micro_batch_data, micro_batch_label diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 24d94ef..79a6f62 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler): label (Any): The label to be loaded. """ - _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset) + _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1) self._grad_accum_offset += 1 if self.data_process_func: @@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler): forward_only or return_loss ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
- batch_data, actual_batch_size = engine.load_batch(data_iter) + batch_data, actual_batch_size = engine.load_batch(data_iter) # actual_batch_size is micro_num self._grad_accum_size = actual_batch_size # Rampup or variable bsz size. diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 550584e..5b864ff 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel -from internlm.utils.common import get_current_device, move_to_device +from internlm.utils.common import ( + check_data_is_packed, + get_current_device, + move_to_device, +) from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout @@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler): raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") def load_batch(self, engine, data_iter): - # Pipeline schedule just puts data in memory + # Pipeline schedule just puts data in memory. batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False) - self.num_microbatches = actual_batch_size # Rampup or variable bsz size. + # Even if 'use_flash_attn' is False, the data seen when 'load_batch' is called is still packed, + # because InternLM's current training dataset is packed, even when using dummy data. + # The unpack operation is performed in load_micro_batch(). + if check_data_is_packed(batch_data): + micro_num = actual_batch_size + else: + micro_num = actual_batch_size // gpc.config.data["micro_bsz"] + self.microbatch_offset = 0 self.batch_size = actual_batch_size self.batch_data, self.batch_label = batch_data + self.bsz_stride = self.batch_size // micro_num + # 'num_microbatches' is no longer an initialization parameter, + # but is determined on the fly by the scheduler. + self.num_microbatches = micro_num # Rampup or variable bsz size.
def load_micro_batch(self): micro_batch_data, micro_batch_label = self._load_micro_batch( - data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset + data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride ) if self.data_process_func: micro_batch_data["input_ids"] = self.data_process_func( @@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler): micro_batch_data.pop("indexes") micro_batch_data["label"] = micro_batch_label - self.microbatch_offset += 1 + self.microbatch_offset += self.bsz_stride return move_to_device(micro_batch_data) @@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset[model_chunk_id], + bsz_stride=self.bsz_stride, ) micro_batch_data["label"] = micro_batch_label - self.microbatch_offset[model_chunk_id] += 1 + self.microbatch_offset[model_chunk_id] += self.bsz_stride return move_to_device(micro_batch_data) def _forward_step(self, engine, chunk_id): diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py index 1f54d06..704d2d6 100644 --- a/internlm/model/metrics.py +++ b/internlm/model/metrics.py @@ -26,7 +26,11 @@ class AccPerplex: self.device = device self.right = torch.Tensor([0]).to(device=device) self.total = torch.Tensor([0]).to(device=device) - self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float) + self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None + if self.metric_dtype is not None: + self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype) + else: + self.total_log_probs = torch.Tensor([0]).to(device=device) self.tp_pg = tp_pg self.dp_pg = dp_pg self.tp_local_rank = torch.distributed.get_rank(self.tp_pg) @@ -128,8 +132,9 @@ class AccPerplex: # All reduce is needed to get the chunks from other GPUs. torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg) - predicted_logits = predicted_logits.to(dtype=torch.float) - shift_logits = shift_logits.to(dtype=torch.float) + if self.metric_dtype is not None: + predicted_logits = predicted_logits.to(dtype=self.metric_dtype) + shift_logits = shift_logits.to(dtype=self.metric_dtype) pred_exp_logits = torch.exp(predicted_logits) # Sum of exponential of logits along vocab dimension across all GPUs. diff --git a/internlm/solver/optimizer/hybrid_zero_optim.py b/internlm/solver/optimizer/hybrid_zero_optim.py index 01b40ab..eb7aae3 100644 --- a/internlm/solver/optimizer/hybrid_zero_optim.py +++ b/internlm/solver/optimizer/hybrid_zero_optim.py @@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer): # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled. 
self.skip_grad_reduce = False - # reduction hook is only used if overlapping communication - # if it is stage 1 without overlapping, no hook will be attached - if self._overlap_sync_grad: - self._attach_reduction_hook() + self._attach_reduction_hook() @property def zero_local_rank(self): @@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer): # if sequence_parallel is True, # the grad of norm should be all-reduce across the tp process group - if gpc.config.parallel.sequence_parallel is True: - if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True: - accum_grad_obj_sp = get_grad_accumulate_object(param) - accum_grad_obj_sp.register_hook(reduce_grad_hook_sp) + if ( + gpc.config.parallel.sequence_parallel is True + and hasattr(param, IS_SEQUENCE_PARALLEL) + and getattr(param, IS_SEQUENCE_PARALLEL) is True + ): + accum_grad_obj.register_hook(reduce_grad_hook_sp) - accum_grad_obj.register_hook(reduce_grad_hook) + if self._overlap_sync_grad: + accum_grad_obj.register_hook(reduce_grad_hook) _define_and_attach(param, reduce_rank) diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 6c9cc68..a20b61d 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -110,6 +110,17 @@ def get_batch_size(data): return data[list(data.keys())[0]].size(0) +def check_data_is_packed(data): + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return "indexes" in data[0] + return False + elif isinstance(data, dict): + return "indexes" in data + + def filter_kwargs(func, kwargs): sig = inspect.signature(func) return {k: v for k, v in kwargs.items() if k in sig.parameters} diff --git a/internlm/utils/model_checkpoint.py b/internlm/utils/model_checkpoint.py index c4b0c3c..bf0b9e9 100644 --- a/internlm/utils/model_checkpoint.py +++ b/internlm/utils/model_checkpoint.py @@ -1016,7 +1016,8 @@ now step_count is {train_state.step_count}", torch.distributed.barrier() def query_latest_snapshot_step_boto3(self): - """query_latest_snapshot_step_boto3 + """Query the latest snapshot step from the storage backend. + Currently, we only support the following storage backends: boto3, oss2 and volc. Returns: Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return. """ @@ -1074,6 +1075,7 @@ now step_count is {train_state.step_count}", return load_path, max(snap_step, max_normal_step) def query_latest_snapshot_step_local(self): + """Query the latest snapshot step from the local file system.""" max_step, max_step_path = 0, None save_ckpt_folder = self.save_ckpt_folder.split(":")[1] for root, _, files in os.walk(save_ckpt_folder, followlinks=True): @@ -1090,18 +1092,22 @@ now step_count is {train_state.step_count}", return max_step_path, max_step def query_lastest_ckpt(self): + """Query the latest ckpt via the storage backend.""" latest_ckpt, step = None, -1 # Training was automatically restarted by the process, forcing the latest snapshot to be read.
if self.save_ckpt_folder: backend, _ = try_get_storage_backend(self.save_ckpt_folder) - if backend == "boto3": + if backend in ["boto3", "oss2", "volc"]: latest_ckpt, step = self.query_latest_snapshot_step_boto3() - if latest_ckpt and not latest_ckpt.startswith("boto3:"): - latest_ckpt = ":".join(["boto3", latest_ckpt]) elif backend == "local": latest_ckpt, step = self.query_latest_snapshot_step_local() - if latest_ckpt and not latest_ckpt.startswith("local:"): - latest_ckpt = ":".join(["local", latest_ckpt]) + else: + raise NotImplementedError( + f"Unsupported backend: {backend}, " "Currently only support `boto3`, `oss2`, `volc` and `local`" + ) + + if latest_ckpt and not latest_ckpt.startswith(backend + ":"): + latest_ckpt = ":".join([backend, latest_ckpt]) if gpc.is_rank_for_log(): logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...") diff --git a/internlm/utils/storage_manager.py b/internlm/utils/storage_manager.py index 151af04..53a4e37 100644 --- a/internlm/utils/storage_manager.py +++ b/internlm/utils/storage_manager.py @@ -739,10 +739,9 @@ class AliClient(StorageClient): if AliClient.is_fp_exists(handler, fp): folder_name_list = [] for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp): - folder_name_list.append(obj.key.split("/")[-1]) + folder_name_list.append(obj.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0]) return list(set(folder_name_list)) - else: if is_rank_for_log(): logger.warning(f"'{fp}' not found!") diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 2ad10c0..eb835b2 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc # from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState -from internlm.train import get_train_data_loader, load_new_batch +from internlm.train import ( + get_train_data_loader, + get_validation_data_loader, + load_new_batch, +) +from internlm.utils.evaluation import ( + switch_evaluation_no_pipeline_scheduler, + switch_evaluation_pipeline_scheduler, +) # from internlm.core.context.parallel_context import global_context as gpc from tests.test_core.utils import build_environment, init_model_and_optim @@ -20,7 +28,7 @@ use_flash_attens = [True, False] answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]] test_case_group = [ # format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len - # (1, "1 1 1", True, answers[0], 1, 8), + (1, "1 1 1", True, answers[0], 1, 8), (4, "1 1 4", True, answers[1], 1, 8), (4, None, True, answers[2], 1, 8), (8, "2 2 2", True, answers[3], 1, 8), @@ -28,6 +36,11 @@ test_case_group = [ ] +class DummyTrainer: + def __init__(self, scheduler) -> None: + self.schedule = scheduler + + def do_warmup(args): rank, worldsize, init_config, should_sccuess, answer = args build_environment(rank, worldsize, init_config) @@ -44,9 +57,11 @@ def do_warmup(args): ) scheduler.pre_processing(engine) engine.train() + trainer = DummyTrainer(scheduler) try: train_dl, _ = get_train_data_loader(num_worker=0) + val_dls = get_validation_data_loader(num_worker=0) except Exception as e: assert should_sccuess is False, f"{e}" else: @@ -105,6 +120,38 @@ def do_warmup(args): tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz ), f"{tokens_num} == {answer[i] * 
gpc.config.data.seq_len * micro_bsz}" + # test no-packed datasets. + for _, val_dl in val_dls.items(): + for _, batch in enumerate(val_dl): + if gpc.is_using_pp(): + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + num_microbatches = total_val_bsz // micro_bsz + tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]]) # toy model hidden size is 8. + with switch_evaluation_pipeline_scheduler( + trainer=trainer, + num_microbatches=num_microbatches, + tensor_shape=tensor_shape, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + else: + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + grad_accum_size = total_val_bsz // micro_bsz + with switch_evaluation_no_pipeline_scheduler( + trainer=trainer, + grad_accum_size=grad_accum_size, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + @pytest.mark.parametrize("use_flash_atten_case", use_flash_attens) @pytest.mark.parametrize("group_case", test_case_group) @@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): sequence_parallel=False, tensor=1, ), - data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8), + data=dict( + train_folder=None, + valid_folder=None, + valid_micro_num=4, + pack_sample_into_one=False, + min_length=0, + total_steps=8, + ), model=dict( dtype=torch.bfloat16, ),
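
The micro-batch slicing this patch introduces can be summarized with a small, self-contained Python sketch. It is illustrative only and not part of the patch; the toy sizes (micro_num=4, micro_bsz=2, seq_len=8) and the helper name slice_micro_batches are assumptions made for the example, mirroring the comments in engine.py and the bsz_stride logic in base_scheduler.py.

# Illustrative sketch (not part of the patch): how a full batch is sliced into
# micro batches with 'bsz_stride', mirroring BaseScheduler._load_micro_batch.
import torch

micro_num, micro_bsz, seq_len = 4, 2, 8  # toy sizes, chosen for illustration only


def slice_micro_batches(data, label, bsz_stride):
    """Yield (micro_batch_data, micro_batch_label) slices of width 'bsz_stride'."""
    offset = 0
    while offset < label.size(0):
        micro_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
        micro_label = label[offset : offset + bsz_stride]
        yield micro_data, micro_label
        offset += bsz_stride


# Packed batch: shape (micro_num, micro_bsz * seq_len); the presence of 'indexes'
# marks it as packed, so bsz_stride == 1 and each row becomes one micro batch.
packed = {
    "input_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.long),
    "indexes": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.long),
}
packed_label = torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.long)
assert sum(1 for _ in slice_micro_batches(packed, packed_label, bsz_stride=1)) == micro_num

# Non-packed (e.g. evaluation) batch: shape (micro_num * micro_bsz, seq_len);
# bsz_stride == micro_bsz, so the number of micro batches is still micro_num.
unpacked = {"input_ids": torch.zeros(micro_num * micro_bsz, seq_len, dtype=torch.long)}
unpacked_label = torch.zeros(micro_num * micro_bsz, seq_len, dtype=torch.long)
assert sum(1 for _ in slice_micro_batches(unpacked, unpacked_label, bsz_stride=micro_bsz)) == micro_num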