From 5ecb6aa7124a2e06ce718a8831ceeb3e62f071c9 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Wed, 13 Dec 2023 14:48:32 +0800 Subject: [PATCH] fix(pp): fix no-packed dataset load micro batch error (#538) * fix(pp): fix no-packed dataset load micro batch error * fix based on comment --- internlm/core/engine.py | 5 ++ internlm/core/scheduler/base_scheduler.py | 16 +++-- .../core/scheduler/no_pipeline_scheduler.py | 4 +- internlm/core/scheduler/pipeline_scheduler.py | 28 +++++++-- internlm/utils/common.py | 11 ++++ tests/test_data/test_batch_sampler.py | 60 ++++++++++++++++++- 6 files changed, 109 insertions(+), 15 deletions(-) diff --git a/internlm/core/engine.py b/internlm/core/engine.py index a372b9e..eb33e35 100644 --- a/internlm/core/engine.py +++ b/internlm/core/engine.py @@ -185,6 +185,11 @@ class Engine: if to_gpu: batch_data = move_to_device(batch_data) + + # For packed-dataset, batch_data is (micro_num, micro_bsz*seq_len), + # therefore 'batch_size' is equal to 'micro_num' + # For nopacked-dataset, batch_data is (micro_num*micro_bsz, seq_len), + # therefore 'batch_size' is equal to 'micro_num*micro_bsz' batch_size = get_batch_size(batch_data) return batch_data, batch_size diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 14c3457..6e19425 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -4,7 +4,7 @@ # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine from abc import ABC, abstractmethod -from typing import Any, Callable, Iterable, Optional +from typing import Any, Callable, Dict, Iterable, Optional import torch @@ -36,10 +36,18 @@ class BaseScheduler(ABC): """ pass - def _load_micro_batch(self, data, label, offset): + def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int): + """ + For pp, it will cut one fully batch into micro batch in pipeline concept. + For nopp, it will cut one fully batch into small batch based on gradient accumulate size. + + A special case is that pp uses a 'non-packed-dateset' (such as evaluation dataset), + so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'. + In all other cases 'bsz_stride' should be equal to 1. + """ assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()} - micro_batch_label = label[offset : offset + 1] + micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()} + micro_batch_label = label[offset : offset + bsz_stride] return micro_batch_data, micro_batch_label diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 24d94ef..79a6f62 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler): label (Any): The label to be loaded. """ - _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset) + _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1) self._grad_accum_offset += 1 if self.data_process_func: @@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler): forward_only or return_loss ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." - batch_data, actual_batch_size = engine.load_batch(data_iter) + batch_data, actual_batch_size = engine.load_batch(data_iter) # actual_batch_size is micro_num self._grad_accum_size = actual_batch_size # Rampup or variable bsz size. diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index 550584e..5b864ff 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel -from internlm.utils.common import get_current_device, move_to_device +from internlm.utils.common import ( + check_data_is_packed, + get_current_device, + move_to_device, +) from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout @@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler): raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}") def load_batch(self, engine, data_iter): - # Pipeline schedule just puts data in memory + # Pipeline schedule just puts data in memory, batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False) - self.num_microbatches = actual_batch_size # Rampup or variable bsz size. + # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed, + # because internlm's current train dataset is packed, even using dummy data. + # The unpack operation is performed in load_micro_batch(). + if check_data_is_packed(batch_data): + micro_num = actual_batch_size + else: + micro_num = actual_batch_size // gpc.config.data["micro_bsz"] + self.microbatch_offset = 0 self.batch_size = actual_batch_size self.batch_data, self.batch_label = batch_data + self.bsz_stride = self.batch_size // micro_num + # 'num_microbatches' is no longer an initialization parameter, + # but is determined on the fly by the Scheduler. + self.num_microbatches = micro_num # Rampup or variable bsz size. def load_micro_batch(self): micro_batch_data, micro_batch_label = self._load_micro_batch( - data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset + data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride ) if self.data_process_func: micro_batch_data["input_ids"] = self.data_process_func( @@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler): micro_batch_data.pop("indexes") micro_batch_data["label"] = micro_batch_label - self.microbatch_offset += 1 + self.microbatch_offset += self.bsz_stride return move_to_device(micro_batch_data) @@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler): data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset[model_chunk_id], + bsz_stride=self.bsz_stride, ) micro_batch_data["label"] = micro_batch_label - self.microbatch_offset[model_chunk_id] += 1 + self.microbatch_offset[model_chunk_id] += self.bsz_stride return move_to_device(micro_batch_data) def _forward_step(self, engine, chunk_id): diff --git a/internlm/utils/common.py b/internlm/utils/common.py index 6c9cc68..a20b61d 100644 --- a/internlm/utils/common.py +++ b/internlm/utils/common.py @@ -110,6 +110,17 @@ def get_batch_size(data): return data[list(data.keys())[0]].size(0) +def check_data_is_packed(data): + if isinstance(data, torch.Tensor): + return False + elif isinstance(data, (list, tuple)): + if isinstance(data[0], dict): + return "indexes" in data[0] + return False + elif isinstance(data, dict): + return "indexes" in data[0] + + def filter_kwargs(func, kwargs): sig = inspect.signature(func) return {k: v for k, v in kwargs.items() if k in sig.parameters} diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py index 2ad10c0..eb835b2 100644 --- a/tests/test_data/test_batch_sampler.py +++ b/tests/test_data/test_batch_sampler.py @@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc # from internlm.core.context import ParallelMode from internlm.core.context.parallel_context import Config from internlm.core.trainer import TrainState -from internlm.train import get_train_data_loader, load_new_batch +from internlm.train import ( + get_train_data_loader, + get_validation_data_loader, + load_new_batch, +) +from internlm.utils.evaluation import ( + switch_evaluation_no_pipeline_scheduler, + switch_evaluation_pipeline_scheduler, +) # from internlm.core.context.parallel_context import global_context as gpc from tests.test_core.utils import build_environment, init_model_and_optim @@ -20,7 +28,7 @@ use_flash_attens = [True, False] answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]] test_case_group = [ # format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len - # (1, "1 1 1", True, answers[0], 1, 8), + (1, "1 1 1", True, answers[0], 1, 8), (4, "1 1 4", True, answers[1], 1, 8), (4, None, True, answers[2], 1, 8), (8, "2 2 2", True, answers[3], 1, 8), @@ -28,6 +36,11 @@ test_case_group = [ ] +class DummyTrainer: + def __init__(self, scheduler) -> None: + self.schedule = scheduler + + def do_warmup(args): rank, worldsize, init_config, should_sccuess, answer = args build_environment(rank, worldsize, init_config) @@ -44,9 +57,11 @@ def do_warmup(args): ) scheduler.pre_processing(engine) engine.train() + trainer = DummyTrainer(scheduler) try: train_dl, _ = get_train_data_loader(num_worker=0) + val_dls = get_validation_data_loader(num_worker=0) except Exception as e: assert should_sccuess is False, f"{e}" else: @@ -105,6 +120,38 @@ def do_warmup(args): tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}" + # test no-packed datasets. + for _, val_dl in val_dls.items(): + for _, batch in enumerate(val_dl): + if gpc.is_using_pp(): + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + num_microbatches = total_val_bsz // micro_bsz + tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]]) # toy model hidden size is 8. + with switch_evaluation_pipeline_scheduler( + trainer=trainer, + num_microbatches=num_microbatches, + tensor_shape=tensor_shape, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + else: + total_val_bsz = len(batch[1]) + batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16) + assert total_val_bsz % micro_bsz == 0 + grad_accum_size = total_val_bsz // micro_bsz + with switch_evaluation_no_pipeline_scheduler( + trainer=trainer, + grad_accum_size=grad_accum_size, + metric_hook_list=[], + ): + scheduler.forward_backward_step( + engine, batch, forward_only=True, return_loss=False, return_output_label=False + ) + @pytest.mark.parametrize("use_flash_atten_case", use_flash_attens) @pytest.mark.parametrize("group_case", test_case_group) @@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case): sequence_parallel=False, tensor=1, ), - data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8), + data=dict( + train_folder=None, + valid_folder=None, + valid_micro_num=4, + pack_sample_into_one=False, + min_length=0, + total_steps=8, + ), model=dict( dtype=torch.bfloat16, ),