remove micro_bsz

pull/517/head
gaoyang07 2023-11-25 22:44:20 +08:00
parent 06e8301861
commit fdbdfcff34
8 changed files with 33 additions and 42 deletions

View File

@@ -156,10 +156,6 @@ class ParallelContext(metaclass=SingletonMeta):
     def config(self):
         return self._config
 
-    @property
-    def micro_bsz(self):
-        return self._config.data.micro_bsz
-
     @property
     def micro_num(self):
         return self._config.data.micro_num
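With the `micro_bsz` property gone, call sites can no longer read `gpc.micro_bsz` and must derive the value from the remaining config fields. A minimal sketch of the substitution applied throughout this commit (a hypothetical helper, not part of the diff; it assumes `packed_length` and `SEQ_LEN` are both set, which `args_sanity_check` below enforces):

    # Hypothetical helper illustrating the replacement pattern; not part of this commit.
    def infer_micro_bsz(config) -> int:
        # packed_length is defined as seq_len * micro_bsz, so the old value
        # is recovered by integer division.
        packed_length = config.data["packed_length"]
        seq_len = config.SEQ_LEN
        assert packed_length % seq_len == 0, "packed_length must be a multiple of SEQ_LEN"
        return packed_length // seq_len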

View File

@@ -30,22 +30,22 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
+    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.data["packed_length"] // sequence_world_size,
                     gpc.config.HIDDEN_SIZE,
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.data["packed_length"],
                     gpc.config.HIDDEN_SIZE,
                 )
         else:
             tensor_shape = (
-                gpc.config.data["micro_bsz"],
+                gpc.config.data["packed_length"] // gpc.config.SEQ_LEN,
                 gpc.config.SEQ_LEN,
                 gpc.config.HIDDEN_SIZE,
             )
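The new shapes are numerically identical to the old ones whenever `packed_length == SEQ_LEN * micro_bsz`, which the updated `args_sanity_check` guarantees. A quick check with made-up numbers (2048-token sequences, four sequences per pack; not taken from any real config):

    SEQ_LEN, micro_bsz, HIDDEN_SIZE = 2048, 4, 4096
    packed_length = SEQ_LEN * micro_bsz  # 8192

    # flash-attention path: old leading dim SEQ_LEN * micro_bsz equals the new packed_length
    assert SEQ_LEN * micro_bsz == packed_length == 8192

    # non-flash path: the old leading dim micro_bsz is recovered by division
    assert packed_length // SEQ_LEN == micro_bsz == 4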

View File

@@ -187,7 +187,6 @@ class StaticBatchSampler:
             each increment. For example, "192 24 8" means that the batch size
             starts at 192 and increases by 24 every 8 steps. Defaults to
             "6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
-        micro_bsz (int): The micro-batch size. Defaults to 2.
         seed (int): The random seed for shuffling the indices. Defaults to 0.
         drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
         data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
@@ -199,40 +198,32 @@
         datasets,
         batch_size=192,
         rampup_batch_size="6 2 8",
-        micro_bsz=2,
         seed=0,
         drop_last=True,
         data_rank=0,
         data_world_size=1,
     ):
         assert drop_last is True, "Currently only support drop last"
+        self.raw_rampup_batch_size = rampup_batch_size
         if rampup_batch_size:
             # In the process increase to batch_size
             start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
         else:
             start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
-        self.raw_rampup_batch_size = rampup_batch_size
         self.start_bsz = start_bsz
         self.bsz_incre = bsz_incre
         self.incre_every = incre_every
         if gpc.is_initialized(ParallelMode.PIPELINE):
             assert (
                 batch_size - self.start_bsz
             ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-            assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
-            assert (
-                self.start_bsz % micro_bsz == 0
-            ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
-            assert (
-                self.bsz_incre % micro_bsz == 0
-            ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
 
         self.batch_size = batch_size
         self.epoch = 0
         self.seed = seed
         self.rng = np.random.RandomState(seed)
         self.batch_count = 0
-        self.micro_bsz = micro_bsz
         self.data_rank = data_rank
         self.data_world_size = data_world_size
         self.num_consumed_samples_in_epoch = 0
@@ -343,7 +334,6 @@ Vs. self.num_samples: {self.num_samples}"
             self.datasets,
             self.batch_size,
             self.raw_rampup_batch_size,
-            self.micro_bsz,
             self.seed,
             drop_last=True,
             data_rank=self.data_rank,
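After this change the sampler counts batches purely in packed micro-batches; the per-micro-batch sample count no longer appears in its interface, so the dropped divisibility asserts have nothing left to check. A minimal sketch of constructing the sampler post-change (the dataset and the numbers are placeholders, mirroring the `get_train_data_loader` call further down):

    # Illustrative only: `train_ds` stands in for whatever datasets the caller provides.
    sampler = StaticBatchSampler(
        train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
        batch_size=8,               # packed micro-batches per step (micro_num)
        rampup_batch_size="2 2 8",  # start at 2, add 2 every 8 steps
        seed=1024,
        drop_last=True,
        data_rank=0,
        data_world_size=1,
    )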

View File

@@ -149,9 +149,9 @@ class PackedDataset(torch.utils.data.Dataset):
         if index == 0:
             pre_pos = 0
         else:
-            pre_pos = index * gpc.config.data["micro_bsz"]
+            pre_pos = index * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
-        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        pos = (index + 1) * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
         return pre_pos, pos
 
     def build_unpack(self, index):
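A worked check of the new index arithmetic with assumed sizes (seq_len 2048, packed_length 8192, i.e. four sequences per pack), showing it matches the old `micro_bsz`-based positions:

    SEQ_LEN, packed_length = 2048, 8192
    index = 3

    pre_pos = index * packed_length // SEQ_LEN    # 12, same as index * micro_bsz with micro_bsz = 4
    pos = (index + 1) * packed_length // SEQ_LEN  # 16
    assert (pre_pos, pos) == (12, 16)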

View File

@@ -33,7 +33,7 @@ def unpack_data(input_ids, cu_seqlens):
     bsz = input_ids.shape[0]
 
-    num_sequence = gpc.config.data["micro_bsz"]
+    num_sequence = gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
     outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
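With the same assumed sizes, the preallocated output keeps its old shape of one row per sequence in the pack; only the way `num_sequence` is derived changes. A small standalone sketch:

    import torch

    bsz, seq_len, packed_length = 2, 2048, 8192
    num_sequence = packed_length // seq_len  # 4, formerly micro_bsz

    outputs = torch.zeros(bsz, num_sequence, seq_len, dtype=torch.long)
    assert outputs.shape == (2, 4, 2048)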

View File

@@ -4,6 +4,7 @@
 import argparse
 import gc
 import os
+import warnings
 from pathlib import Path
 from typing import Dict, Union
@@ -100,12 +101,21 @@ def args_sanity_check():
     data = gpc.config.data
 
     assert data.seq_len is not None, "'seq_len' must be given a value"
-    assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
-
-    if "packed_length" in data and gpc.is_rank_for_log():
-        logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
-
-    data._add_item("packed_length", data.seq_len * data.micro_bsz)
+    if "micro_bsz" in data:
+        warnings.warn(
+            "The parameter `micro_bsz` will be deprecated in the future, please use `packed_length` instead, "
+            "and set `packed_length` to `seq_len * micro_bsz`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if data.get("packed_length", None) is None:
+            data._add_item("packed_length", data.seq_len * data.micro_bsz)
+        else:
+            assert (
+                data.packed_length == data.seq_len * data.micro_bsz
+            ), "'packed_length' must be equal to 'seq_len * micro_bsz'"
+    else:
+        assert data.packed_length is not None, "'packed_length' must be given a value"
 
     if "micro_num" not in data:
         data._add_item("micro_num", 1)
@@ -154,7 +164,6 @@ def args_sanity_check():
         logger.info("+" * 15 + " Data Info " + "+" * 15)  # pylint: disable=W1201
         logger.info(f"seq_len: {data.seq_len}")
         logger.info(f"micro_num: {data.micro_num}")
-        logger.info(f"micro_bsz: {data.micro_bsz}")
         logger.info(f"packed_length: {data.packed_length}")
         logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
         logger.info(f"min_length: {data.min_length}")

View File

@@ -246,7 +246,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
             train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
             batch_size=data_cfg.micro_num,
             rampup_batch_size=data_cfg.rampup_batch_size,
-            micro_bsz=data_cfg.micro_bsz,
             seed=1024,
             drop_last=True,
             data_rank=gpc.get_local_rank(ParallelMode.DATA),
@@ -302,10 +301,9 @@ def get_validation_data_loader(
         else:
             # making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
             # otherwise too much data may be dropped
-            batch_size = min(
-                data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
-            )
-            batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
+            micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
+            batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
+            batch_size = batch_size // micro_bsz * micro_bsz
 
         if batch_size == 0 and gpc.is_rank_for_log():
             logger.info(f"skip validate {val_name}.")

View File

@@ -96,13 +96,12 @@ def evaluate_on_val_dls(
         ):
             moe_loss = None
             with torch.inference_mode():
+                micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
                 if gpc.is_using_pp():
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                    assert total_val_bsz % micro_bsz == 0
+                    num_microbatches = total_val_bsz // micro_bsz
+                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,
@@ -121,8 +120,7 @@ def evaluate_on_val_dls(
                         )
                 else:
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
+                    grad_accum_size = total_val_bsz // micro_bsz
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,
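The evaluation path derives the same quantities from the locally recovered `micro_bsz`. A small check with assumed numbers (a validation batch of 16 packed samples, packed_length 8192, SEQ_LEN 2048):

    packed_length, SEQ_LEN = 8192, 2048
    total_val_bsz = 16

    micro_bsz = packed_length // SEQ_LEN  # 4
    assert total_val_bsz % micro_bsz == 0

    num_microbatches = total_val_bsz // micro_bsz  # pipeline branch: 4 micro-batches
    grad_accum_size = total_val_bsz // micro_bsz   # non-pipeline branch: accumulate over 4 steps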