remove micro_bsz

pull/517/head
gaoyang07 2023-11-25 22:44:20 +08:00
parent 06e8301861
commit fdbdfcff34
8 changed files with 33 additions and 42 deletions

View File

@@ -156,10 +156,6 @@ class ParallelContext(metaclass=SingletonMeta):
     def config(self):
         return self._config
 
-    @property
-    def micro_bsz(self):
-        return self._config.data.micro_bsz
-
     @property
     def micro_num(self):
         return self._config.data.micro_num

View File

@@ -30,22 +30,22 @@ def get_tensor_shape():
     if not gpc.is_initialized(ParallelMode.PIPELINE):
         return None
 
-    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
+    if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config, "HIDDEN_SIZE"):
         if gpc.config.model.use_flash_attn:
             if gpc.config.parallel.sequence_parallel:
                 sequence_world_size = gpc.get_world_size(ParallelMode.TENSOR)
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"] // sequence_world_size,
+                    gpc.config.data["packed_length"] // sequence_world_size,
                     gpc.config.HIDDEN_SIZE,
                 )
             else:
                 tensor_shape = (
-                    gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                    gpc.config.data["packed_length"],
                     gpc.config.HIDDEN_SIZE,
                 )
         else:
             tensor_shape = (
-                gpc.config.data["micro_bsz"],
+                gpc.config.data["packed_length"] // gpc.config.SEQ_LEN,
                 gpc.config.SEQ_LEN,
                 gpc.config.HIDDEN_SIZE,
             )
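The three shapes are unchanged because `args_sanity_check` (later in this commit) enforces `packed_length == seq_len * micro_bsz`. A minimal sketch of the equivalence, with made-up sizes that are not taken from the repo:

seq_len, micro_bsz, hidden_size = 2048, 2, 4096    # illustrative values only
packed_length = seq_len * micro_bsz                # invariant enforced at startup

# flash-attention path: (tokens, hidden)
assert (seq_len * micro_bsz, hidden_size) == (packed_length, hidden_size)

# non-flash path: (sequences, seq_len, hidden)
assert (micro_bsz, seq_len, hidden_size) == (packed_length // seq_len, seq_len, hidden_size)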

View File

@@ -187,7 +187,6 @@ class StaticBatchSampler:
             each increment. For example, "192 24 8" means that the batch size
             starts at 192 and increases by 24 every 8 steps. Defaults to
             "6 2 8", which corresponds to a batch size of 2 for the first 6 steps.
-        micro_bsz (int): The micro-batch size. Defaults to 2.
         seed (int): The random seed for shuffling the indices. Defaults to 0.
         drop_last (bool): If True, drop the last incomplete batch. Currently only supports True. Defaults to True.
         data_rank (int): The rank of the current process in the data parallel group. Defaults to 0.
@@ -199,40 +198,32 @@ class StaticBatchSampler:
         datasets,
         batch_size=192,
         rampup_batch_size="6 2 8",
-        micro_bsz=2,
         seed=0,
         drop_last=True,
         data_rank=0,
         data_world_size=1,
     ):
         assert drop_last is True, "Currently only support drop last"
+        self.raw_rampup_batch_size = rampup_batch_size
+
         if rampup_batch_size:
             # In the process increase to batch_size
             start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
         else:
             start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1
-        self.raw_rampup_batch_size = rampup_batch_size
         self.start_bsz = start_bsz
         self.bsz_incre = bsz_incre
         self.incre_every = incre_every
         if gpc.is_initialized(ParallelMode.PIPELINE):
             assert (
                 batch_size - self.start_bsz
             ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}"
-        assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.start_bsz % micro_bsz == 0
-        ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})"
-        assert (
-            self.bsz_incre % micro_bsz == 0
-        ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})"
 
         self.batch_size = batch_size
         self.epoch = 0
         self.seed = seed
         self.rng = np.random.RandomState(seed)
         self.batch_count = 0
-        self.micro_bsz = micro_bsz
         self.data_rank = data_rank
         self.data_world_size = data_world_size
         self.num_consumed_samples_in_epoch = 0
@@ -343,7 +334,6 @@ Vs. self.num_samples: {self.num_samples}"
             self.datasets,
             self.batch_size,
             self.raw_rampup_batch_size,
-            self.micro_bsz,
             self.seed,
             drop_last=True,
             data_rank=self.data_rank,
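With `micro_bsz` removed, the sampler is sized purely in packed samples, so the old divisibility asserts disappear and only the ramp-up bookkeeping remains. A stand-alone sketch of that parsing, using the same example defaults shown above:

batch_size, rampup_batch_size = 192, "6 2 8"       # example values, as in the signature above
start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split())
# under pipeline parallelism the ramp still has to land exactly on batch_size
assert (batch_size - start_bsz) % bsz_incre == 0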

View File

@@ -149,9 +149,9 @@ class PackedDataset(torch.utils.data.Dataset):
         if index == 0:
             pre_pos = 0
         else:
-            pre_pos = index * gpc.config.data["micro_bsz"]
+            pre_pos = index * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
-        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        pos = (index + 1) * gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
         return pre_pos, pos
 
     def build_unpack(self, index):
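Since `*` and `//` evaluate left to right and `packed_length` is an exact multiple of `SEQ_LEN`, the new expressions reproduce the old `micro_bsz`-based positions with no rounding error. A quick self-contained check with made-up sizes:

SEQ_LEN, micro_bsz = 2048, 2                       # illustrative values only
packed_length = SEQ_LEN * micro_bsz
for index in range(8):
    assert index * packed_length // SEQ_LEN == index * micro_bsz
    assert (index + 1) * packed_length // SEQ_LEN == (index + 1) * micro_bsz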

View File

@@ -33,7 +33,7 @@ def unpack_data(input_ids, cu_seqlens):
     bsz = input_ids.shape[0]
 
-    num_sequence = gpc.config.data["micro_bsz"]
+    num_sequence = gpc.config.data["packed_length"] // gpc.config.SEQ_LEN
 
     outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)

View File

@@ -4,6 +4,7 @@
 import argparse
 import gc
 import os
+import warnings
 from pathlib import Path
 from typing import Dict, Union
@@ -100,12 +101,21 @@ def args_sanity_check():
     data = gpc.config.data
 
     assert data.seq_len is not None, "'seq_len' must be given a value"
-    assert data.micro_bsz is not None, "'micro_bsz' must be given a value"
-
-    if "packed_length" in data and gpc.is_rank_for_log():
-        logger.warning("packed_length would be ignored and will be setted as seq_len * micro_bsz.")
-
-    data._add_item("packed_length", data.seq_len * data.micro_bsz)
+    if "micro_bsz" in data:
+        warnings.warn(
+            "The parameter `micro_bsz` will be deprecated in the future, please use `packed_length` instead, "
+            "and set `packed_length` to `seq_len * micro_bsz`.",
+            DeprecationWarning,
+            stacklevel=2,
+        )
+        if data.get("packed_length", None) is None:
+            data._add_item("packed_length", data.seq_len * data.micro_bsz)
+        else:
+            assert (
+                data.packed_length == data.seq_len * data.micro_bsz
+            ), "'packed_length' must be equal to 'seq_len * micro_bsz'"
+    else:
+        assert data.packed_length is not None, "'packed_length' must be given a value"
 
     if "micro_num" not in data:
         data._add_item("micro_num", 1)
@@ -154,7 +164,6 @@ def args_sanity_check():
         logger.info("+" * 15 + " Data Info " + "+" * 15)  # pylint: disable=W1201
         logger.info(f"seq_len: {data.seq_len}")
         logger.info(f"micro_num: {data.micro_num}")
-        logger.info(f"micro_bsz: {data.micro_bsz}")
         logger.info(f"packed_length: {data.packed_length}")
         logger.info(f"pack_sample_into_one: {data.pack_sample_into_one}")
         logger.info(f"min_length: {data.min_length}")
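For reference, the two data-config styles this check now accepts; the field values below are illustrative, not repo defaults:

# legacy style: micro_bsz still works but triggers the DeprecationWarning above,
# and packed_length is derived as seq_len * micro_bsz
data = dict(seq_len=2048, micro_num=4, micro_bsz=2)

# new style: give packed_length directly (it must equal seq_len * micro_bsz
# whenever micro_bsz is also present)
data = dict(seq_len=2048, micro_num=4, packed_length=4096)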

View File

@@ -246,7 +246,6 @@ def get_train_data_loader(num_worker: int = 0, dataset_generate_func: Optional[C
         train_ds.datasets if isinstance(train_ds, ConcatDataset) else [train_ds],
         batch_size=data_cfg.micro_num,
         rampup_batch_size=data_cfg.rampup_batch_size,
-        micro_bsz=data_cfg.micro_bsz,
         seed=1024,
         drop_last=True,
         data_rank=gpc.get_local_rank(ParallelMode.DATA),
@@ -302,10 +301,9 @@ def get_validation_data_loader(
         else:
             # making the batch_size of validate larger can speed up the evaluation, but it should not be too large,
             # otherwise too much data may be dropped
-            batch_size = min(
-                data_cfg.valid_micro_num * data_cfg.micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA)
-            )
-            batch_size = batch_size // data_cfg.micro_bsz * data_cfg.micro_bsz
+            micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
+            batch_size = min(data_cfg.valid_micro_num * micro_bsz, len(ds) // gpc.get_world_size(ParallelMode.DATA))
+            batch_size = batch_size // micro_bsz * micro_bsz
 
         if batch_size == 0 and gpc.is_rank_for_log():
             logger.info(f"skip validate {val_name}.")
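The validation batch size is now derived from `packed_length` and rounded down to a multiple of the reconstructed `micro_bsz`. A worked example with made-up numbers:

packed_length, SEQ_LEN, valid_micro_num, shard_len = 4096, 2048, 4, 7   # illustrative only
micro_bsz = packed_length // SEQ_LEN                       # 2 sequences per packed sample
batch_size = min(valid_micro_num * micro_bsz, shard_len)   # min(8, 7) -> 7
batch_size = batch_size // micro_bsz * micro_bsz           # round down to 6
assert batch_size == 6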

View File

@@ -96,13 +96,12 @@ def evaluate_on_val_dls(
         ):
             moe_loss = None
             with torch.inference_mode():
+                micro_bsz = data_cfg.packed_length // gpc.config.SEQ_LEN
                 if gpc.is_using_pp():
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    num_microbatches = total_val_bsz // data_cfg.micro_bsz
-                    tensor_shape = torch.Size(
-                        [data_cfg.micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE]
-                    )
+                    assert total_val_bsz % micro_bsz == 0
+                    num_microbatches = total_val_bsz // micro_bsz
+                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1], gpc.config.HIDDEN_SIZE])
 
                     with switch_evaluation_pipeline_scheduler(
                         trainer=trainer,
@@ -121,8 +120,7 @@
                     )
                 else:
                     total_val_bsz = len(batch[1])
-                    assert total_val_bsz % data_cfg.micro_bsz == 0
-                    grad_accum_size = total_val_bsz // data_cfg.micro_bsz
+                    grad_accum_size = total_val_bsz // micro_bsz
                     with switch_evaluation_no_pipeline_scheduler(
                         trainer=trainer,
                         grad_accum_size=grad_accum_size,
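The same derived `micro_bsz` feeds both evaluation paths: it sets `num_microbatches` for the pipeline scheduler and `grad_accum_size` otherwise. A toy calculation with illustrative values:

packed_length, SEQ_LEN = 4096, 2048                # illustrative values only
micro_bsz = packed_length // SEQ_LEN               # 2
total_val_bsz = 6                                  # sequences in one validation batch
num_microbatches = total_val_bsz // micro_bsz      # 3 micro-batches under pipeline parallelism
grad_accum_size = total_val_bsz // micro_bsz       # 3 accumulation steps otherwise
assert num_microbatches == grad_accum_size == 3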