From fdbdfcff34de5b420cfddd3f7f86d856e3fea4f1 Mon Sep 17 00:00:00 2001
From: gaoyang07
Date: Sat, 25 Nov 2023 22:44:20 +0800
Subject: [PATCH] remove micro_bsz

---
 internlm/core/context/parallel_context.py     |  4 ----
 internlm/core/scheduler/pipeline_scheduler.py |  8 ++++----
 internlm/data/batch_sampler.py                | 14 ++------------
 internlm/data/packed_dataset.py               |  4 ++--
 internlm/data/utils.py                        |  2 +-
 internlm/initialize/launch.py                 | 23 ++++++++++++++-------
 internlm/train/training_internlm.py           |  8 +++-----
 internlm/utils/evaluation.py                  | 12 +++++-------
 8 files changed, 33 insertions(+), 42 deletions(-)

diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py
index db356a1..0e1ee45 100644
--- a/internlm/core/context/parallel_context.py
+++ b/internlm/core/context/parallel_context.py
@@ -156,10 +156,6 @@ class ParallelContext(metaclass=SingletonMeta):
     def config(self):
         return self._config
 
-    @property
-    def micro_bsz(self):
-        return self._config.data.micro_bsz
-
     @property
     def micro_num(self):
         return self._config.data.micro_num
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index c851789..a85ca1c 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -30,22 +30,22 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
+    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.data["packed_length"] // sequence_world_size,
                     gpc.config.HIDDEN_SIZE,
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.data["packed_length"],
                     gpc.config.HIDDEN_SIZE,
                 )
         else:
             tensor_shape = (
-                gpc.config.data["micro_bsz"],
+                gpc.config.data["packed_length"] // gpc.config.SEQ_LEN,
                 gpc.config.SEQ_LEN,
                 gpc.config.HIDDEN_SIZE,
             )
diff --git a/internlm/data/batch_sampler.py b/internlm/data/batch_sampler.py
index 16fd6fc..eb3791c 100644
--- a/internlm/data/batch_sampler.py
+++ b/internlm/data/batch_sampler.py
@@ -187,7 +187,6 @@ class StaticBatchSampler:
             each increment. For example, "192 24 8" means that the batch size
             starts at 192 and increases by 24 every 8 steps. Defaults to
             "6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
-        micro_bsz (int): The micro-batch size. Defaults to 2.
         seed (int): The random seed for shuffling the indices. Defaults to 0.
         drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
         data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
@@ -199,40 +198,32 @@ class StaticBatchSampler:
         datasets,
         batch_size=192,
         rampup_batch_size="6 2 8",
-        micro_bsz=2,
         seed=0,
         drop_last=True,
         data_rank=0,
         data_world_size=1,
     ):
         assert drop_last is True, "Currently only support drop last"
+        self.raw_rampup_batch_size = rampup_batch_size
         if rampup_batch_size:
             # In the process increase to batch_size
             start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
         else:
             start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
-        self.raw_rampup_batch_size = rampup_batch_size
         self.start_bsz = start_bsz
         self.bsz_incre = bsz_incre
         self.incre_every = incre_every
+
         if gpc.is_initialized(ParallelMode.PIPELINE):
             assert (
                 batch_size - self.start_bsz
             ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.start_bsz % micro_bsz == 0
-        ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.bsz_incre % micro_bsz == 0
-        ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
 
         self.batch_size = batch_size
         self.epoch = 0
         self.seed = seed
         self.rng = np.random.RandomState(seed)
         self.batch_count = 0
-        self.micro_bsz = micro_bsz
         self.data_rank = data_rank
         self.data_world_size = data_world_size
         self.num_consumed_samples_in_epoch = 0
@@ -343,7 +334,6 @@ Vs. self.num_samples: {self.num_samples}"
             self.datasets,
             self.batch_size,
             self.raw_rampup_batch_size,
-            self.micro_bsz,
             self.seed,
             drop_last=True,
             data_rank=self.data_rank,
diff --git a/internlm/data/packed_dataset.py b/internlm/data/packed_dataset.py
index af4c34a..04e8727 100644
--- a/internlm/data/packed_dataset.py
+++ b/internlm/data/packed_dataset.py
@@ -149,9 +149,9 @@ class PackedDataset(torch.utils.data.Dataset):
         if index == 0:
             pre_pos = 0
         else:
-            pre_pos = index * gpc.config.data["micro_bsz"]
+            pre_pos = index * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
-        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        pos = (index + 1) * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
         return pre_pos, pos
 
     def build_unpack(self, index):
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index fbcb6f7..444d9a3 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -33,7 +33,7 @@ def unpack_data(input_ids, cu_seqlens):
 
     bsz = input_ids.shape[0]
 
-    num_sequence = gpc.config.data["micro_bsz"]
+    num_sequence = gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
     outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
 
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index e96d2d9..fa3ae79 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -4,6 +4,7 @@
 import argparse
 import gc
 import os
+import warnings
 from pathlib import Path
 from typing import Dict, Union
 
@@ -100,12 +101,21 @@ def args_sanity_check():
 
     data = gpc.config.data
     assert data.seq_len is not None, "'seq_len' must be given a value"
-    assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
-
-    if "packed_length" in data and gpc.is_rank_for_log():
-        logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
-
-    data._add_item("packed_length", data.seq_len * data.micro_bsz)
+    if "micro_bsz" in data:
+        warnings.warn(
+            "The parameter `micro_bsz` will be deprecated in the future, please use `packed_length` instead, "
+            "and set `packed_length` to `seq_len * micro_bsz`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if data.get("packed_length", None) is None:
+            data._add_item("packed_length", data.seq_len * data.micro_bsz)
+        else:
+            assert (
+                data.packed_length == data.seq_len * data.micro_bsz
+            ), "'packed_length' must be equal to 'seq_len * micro_bsz'"
+    else:
+        assert data.packed_length is not None, "'packed_length' must be given a value"
 
     if "micro_num" not in data:
         data._add_item("micro_num", 1)
@@ -154,7 +164,6 @@ def args_sanity_check():
         logger.info("+" * 15 + " Data Info " + "+" * 15)  # pylint: disable=W1201
         logger.info(f"seq_len: {data.seq_len}")
         logger.info(f"micro_num: {data.micro_num}")
-        logger.info(f"micro_bsz: {data.micro_bsz}")
         logger.info(f"packed_length: {data.packed_length}")
         logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
         logger.info(f"min_length: {data.min_length}")
diff --git a/internlm/train/training_internlm.py b/internlm/train/training_internlm.py
index 1e36a21..067ce75 100644
--- a/internlm/train/training_internlm.py
+++ b/internlm/train/training_internlm.py
@@ -246,7 +246,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
         rampup_batch_size=data_cfg.rampup_batch_size,
-        micro_bsz=data_cfg.micro_bsz,
         seed=1024,
         drop_last=True,
         data_rank=gpc.get_local_rank(ParallelMode.DATA),
@@ -302,10 +301,9 @@ def get_validation_data_loader(
         else:
             # making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
             # otherwise too much data may be dropped
-            batch_size = min(
-                data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
-            )
-            batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
+            micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
+            batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
+            batch_size = batch_size // micro_bsz * micro_bsz
 
             if batch_size == 0 and gpc.is_rank_for_log():
                 logger.info(f"skip validate {val_name}.")
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 22d998b..3e1a311 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -96,13 +96,12 @@ def evaluate_on_val_dls(
         ):
             moe_loss = None
             with torch.inference_mode():
+                micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
                 if gpc.is_using_pp():
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                    assert total_val_bsz % micro_bsz == 0
+                    num_microbatches = total_val_bsz // micro_bsz
+                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,
@@ -121,8 +120,7 @@ def evaluate_on_val_dls(
                     )
                 else:
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
+                    grad_accum_size = total_val_bsz // micro_bsz
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,
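
Note on migrating existing configs: after this patch, `packed_length` is the field the code reads from `gpc.config.data`, and the per-microbatch sample count that used to be `micro_bsz` is recovered as `packed_length // SEQ_LEN` wherever it is still needed (pipeline tensor shapes, data unpacking, validation batch sizing). The sketch below is a minimal illustration of the intended migration, assuming the usual python-dict config style; the variable names and concrete values (SEQ_LEN = 2048, MICRO_BSZ = 2, micro_num = 4) are illustrative assumptions, not values taken from this patch.

# Hypothetical config migration sketch (illustrative values, not from this patch).
SEQ_LEN = 2048   # sequence length of one sample
MICRO_BSZ = 2    # old-style micro batch size, now deprecated

# Old style (still accepted, but now raises a DeprecationWarning in args_sanity_check):
#   data = dict(seq_len=SEQ_LEN, micro_bsz=MICRO_BSZ, micro_num=4, ...)

# New style: specify packed_length directly, equal to seq_len * micro_bsz.
data = dict(
    seq_len=SEQ_LEN,
    packed_length=SEQ_LEN * MICRO_BSZ,  # replaces micro_bsz
    micro_num=4,
)

# Code that previously read gpc.config.data["micro_bsz"] now derives the same value:
micro_bsz = data["packed_length"] // SEQ_LEN
assert micro_bsz == MICRO_BSZ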