mirror of https://github.com/InternLM/InternLM
remove micro_bsz
parent 06e8301861
commit fdbdfcff34

@@ -156,10 +156,6 @@ class ParallelContext(metaclass=SingletonMeta):
     def config(self):
         return self._config
 
-    @property
-    def micro_bsz(self):
-        return self._config.data.micro_bsz
-
     @property
     def micro_num(self):
         return self._config.data.micro_num

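For reference: with the `micro_bsz` property gone, callers can recover the same value from the fields that remain, since `packed_length` is defined as `seq_len * micro_bsz`. A minimal sketch with a plain dict standing in for `gpc.config.data` (illustrative values, not from this commit):

data_cfg = {"seq_len": 2048, "packed_length": 4096}  # stand-in for gpc.config.data

# Equivalent of the removed property: micro_bsz == packed_length // seq_len.
micro_bsz = data_cfg["packed_length"] // data_cfg["seq_len"]
assert micro_bsz == 2
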
@@ -30,22 +30,22 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
+    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.data["packed_length"] // sequence_world_size,
                     gpc.config.HIDDEN_SIZE,
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.data["packed_length"],
                     gpc.config.HIDDEN_SIZE,
                 )
         else:
             tensor_shape = (
-                gpc.config.data["micro_bsz"],
+                gpc.config.data["packed_length"] // gpc.config.SEQ_LEN,
                 gpc.config.SEQ_LEN,
                 gpc.config.HIDDEN_SIZE,
             )

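The rewritten shapes are equivalent to the old ones whenever `packed_length == SEQ_LEN * micro_bsz`; a self-contained check with illustrative numbers (not taken from this commit):

SEQ_LEN, HIDDEN_SIZE = 2048, 4096      # illustrative config values
micro_bsz = 2
packed_length = SEQ_LEN * micro_bsz
sequence_world_size = 8

# Flash-attention path: the leading dimension is unchanged by the rewrite.
assert SEQ_LEN * micro_bsz // sequence_world_size == packed_length // sequence_world_size

# Non-flash path: the old leading dimension micro_bsz equals packed_length // SEQ_LEN.
assert packed_length // SEQ_LEN == micro_bsz
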
@@ -187,7 +187,6 @@ class StaticBatchSampler:
             each increment. For example, "192 24 8" means that the batch size
             starts at 192 and increases by 24 every 8 steps. Defaults to
             "6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
-        micro_bsz (int): The micro-batch size. Defaults to 2.
         seed (int): The random seed for shuffling the indices. Defaults to 0.
         drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
         data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.

@@ -199,40 +198,32 @@ class StaticBatchSampler:
         datasets,
         batch_size=192,
         rampup_batch_size="6 2 8",
-        micro_bsz=2,
         seed=0,
         drop_last=True,
         data_rank=0,
         data_world_size=1,
     ):
         assert drop_last is True, "Currently only support drop last"
         self.raw_rampup_batch_size = rampup_batch_size
         if rampup_batch_size:
             # In the process increase to batch_size
             start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
         else:
             start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
-        self.raw_rampup_batch_size = rampup_batch_size
         self.start_bsz = start_bsz
         self.bsz_incre = bsz_incre
         self.incre_every = incre_every
+
         if gpc.is_initialized(ParallelMode.PIPELINE):
             assert (
                 batch_size - self.start_bsz
             ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.start_bsz % micro_bsz == 0
-        ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.bsz_incre % micro_bsz == 0
-        ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
+
         self.batch_size = batch_size
         self.epoch = 0
         self.seed = seed
         self.rng = np.random.RandomState(seed)
         self.batch_count = 0
-        self.micro_bsz = micro_bsz
         self.data_rank = data_rank
         self.data_world_size = data_world_size
         self.num_consumed_samples_in_epoch = 0

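For context, `rampup_batch_size` is three space-separated integers: the starting batch size, the increment, and the number of steps between increments. A standalone sketch of the schedule those numbers imply (illustrative only; the sampler itself builds per-epoch index lists rather than yielding steps):

def rampup_schedule(batch_size=192, rampup_batch_size="6 2 8"):
    """Yield (step, batch size) pairs following the 'start increment every' rule."""
    start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
    step, bsz = 0, start_bsz
    while bsz < batch_size:
        yield step, bsz
        step += incre_every
        bsz = min(bsz + bsz_incre, batch_size)
    yield step, batch_size

# "6 2 8": start at 6 and add 2 every 8 steps until the target batch size is reached.
print(list(rampup_schedule(batch_size=12)))  # [(0, 6), (8, 8), (16, 10), (24, 12)]
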
@@ -343,7 +334,6 @@ Vs. self.num_samples: {self.num_samples}"
             self.datasets,
             self.batch_size,
             self.raw_rampup_batch_size,
-            self.micro_bsz,
             self.seed,
             drop_last=True,
             data_rank=self.data_rank,

@@ -149,9 +149,9 @@ class PackedDataset(torch.utils.data.Dataset):
         if index == 0:
             pre_pos = 0
         else:
-            pre_pos = index * gpc.config.data["micro_bsz"]
+            pre_pos = index * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
-        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        pos = (index + 1) * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
         return pre_pos, pos
 
     def build_unpack(self, index):

@@ -33,7 +33,7 @@ def unpack_data(input_ids, cu_seqlens):
 
     bsz = input_ids.shape[0]
 
-    num_sequence = gpc.config.data["micro_bsz"]
+    num_sequence = gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
     outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
 

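The replacement derives the number of unpacked sequences per sample from `packed_length` instead of `micro_bsz`. A self-contained sketch of the shape bookkeeping only (plain PyTorch, illustrative sizes; the real `unpack_data` splits along the `cu_seqlens` boundaries rather than at fixed offsets):

import torch

seq_len, packed_length = 4, 8            # illustrative; real configs are far larger
num_sequence = packed_length // seq_len  # what the new line computes

bsz = 2
packed = torch.arange(bsz * packed_length).reshape(bsz, packed_length)

# One output row per unpacked sequence: (bsz, num_sequence, seq_len).
outputs = torch.zeros(bsz, num_sequence, seq_len, dtype=packed.dtype)
for i in range(num_sequence):
    outputs[:, i] = packed[:, i * seq_len : (i + 1) * seq_len]
print(outputs.shape)  # torch.Size([2, 2, 4])
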
@@ -4,6 +4,7 @@
 import argparse
 import gc
 import os
+import warnings
 from pathlib import Path
 from typing import Dict, Union
 

@@ -100,12 +101,21 @@ def args_sanity_check():
     data = gpc.config.data
 
     assert data.seq_len is not None, "'seq_len' must be given a value"
-    assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
 
-    if "packed_length" in data and gpc.is_rank_for_log():
-        logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
 
+    if "micro_bsz" in data:
+        warnings.warn(
+            "The parameter `micro_bsz` will be deprecated in the future, please use `packed_length` instead, "
+            "and set `packed_length` to `seq_len * micro_bsz`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if data.get("packed_length", None) is None:
+            data._add_item("packed_length", data.seq_len * data.micro_bsz)
+        else:
+            assert (
+                data.packed_length == data.seq_len * data.micro_bsz
+            ), "'packed_length' must be equal to 'seq_len * micro_bsz'"
+    else:
+        assert data.packed_length is not None, "'packed_length' must be given a value"
 
     if "micro_num" not in data:
         data._add_item("micro_num", 1)

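After this change a config may either keep `micro_bsz` (now deprecated, with `packed_length` derived as `seq_len * micro_bsz`) or set `packed_length` directly. A condensed sketch of the same branching over a plain dict, not the project's helper (field names as above, values illustrative, deprecation warning omitted):

def normalize_data_cfg(data):
    """Condensed mirror of the sanity check above, using a plain dict instead of gpc.config."""
    assert data.get("seq_len") is not None, "'seq_len' must be given a value"
    if "micro_bsz" in data:
        # Deprecated path: derive or validate packed_length from micro_bsz.
        if data.get("packed_length") is None:
            data["packed_length"] = data["seq_len"] * data["micro_bsz"]
        else:
            assert data["packed_length"] == data["seq_len"] * data["micro_bsz"]
    else:
        assert data.get("packed_length") is not None, "'packed_length' must be given a value"
    return data

print(normalize_data_cfg({"seq_len": 2048, "micro_bsz": 2}))          # packed_length -> 4096
print(normalize_data_cfg({"seq_len": 2048, "packed_length": 4096}))   # passes unchanged
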
@@ -154,7 +164,6 @@ def args_sanity_check():
         logger.info("+" * 15 + " Data Info " + "+" * 15) # pylint: disable=W1201
         logger.info(f"seq_len: {data.seq_len}")
         logger.info(f"micro_num: {data.micro_num}")
-        logger.info(f"micro_bsz: {data.micro_bsz}")
         logger.info(f"packed_length: {data.packed_length}")
         logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
         logger.info(f"min_length: {data.min_length}")

@@ -246,7 +246,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
         rampup_batch_size=data_cfg.rampup_batch_size,
-        micro_bsz=data_cfg.micro_bsz,
         seed=1024,
         drop_last=True,
         data_rank=gpc.get_local_rank(ParallelMode.DATA),

@@ -302,10 +301,9 @@ def get_validation_data_loader(
         else:
             # making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
             # otherwise too much data may be dropped
-            batch_size = min(
-                data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
-            )
-            batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
+            micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
+            batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
+            batch_size = batch_size // micro_bsz * micro_bsz
 
         if batch_size == 0 and gpc.is_rank_for_log():
             logger.info(f"skip validate {val_name}.")

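The validation batch size is first capped by the dataset shard size and then floored to a multiple of the derived `micro_bsz`; a worked example with illustrative numbers (not from this commit):

valid_micro_num, packed_length, seq_len = 4, 4096, 2048
dataset_len, data_world_size = 60, 8     # stand-ins for len(ds) and the DATA group size

micro_bsz = packed_length // seq_len                                            # 2
batch_size = min(valid_micro_num * micro_bsz, dataset_len // data_world_size)   # min(8, 7) = 7
batch_size = batch_size // micro_bsz * micro_bsz                                # floor to multiple of 2 -> 6
print(batch_size)
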
@@ -96,13 +96,12 @@ def evaluate_on_val_dls(
         ):
             moe_loss = None
             with torch.inference_mode():
+                micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
                 if gpc.is_using_pp():
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                    assert total_val_bsz % micro_bsz == 0
+                    num_microbatches = total_val_bsz // micro_bsz
+                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,

@@ -121,8 +120,7 @@ def evaluate_on_val_dls(
                     )
                 else:
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
+                    grad_accum_size = total_val_bsz // micro_bsz
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,