From 0bfc86205e823c4482e1bd48cd25045e6e2b4809 Mon Sep 17 00:00:00 2001 From: Guoteng <32697156+SolenoidWGT@users.noreply.github.com> Date: Thu, 16 Nov 2023 19:51:01 +0800 Subject: [PATCH] feat(train): support_rampup_batch_size and fix bugs (#493) --- configs/7B_sft.py | 7 +- internlm/core/context/parallel_context.py | 12 + internlm/core/scheduler/base_scheduler.py | 6 +- .../core/scheduler/no_pipeline_scheduler.py | 27 ++- internlm/core/scheduler/pipeline_scheduler.py | 17 +- internlm/initialize/launch.py | 18 +- internlm/utils/evaluation.py | 7 +- tests/test_core/__init__.py | 0 tests/test_core/test_pipeline.py | 187 +--------------- tests/test_core/utils.py | 207 ++++++++++++++++++ tests/test_data/test_batch_sampler.py | 143 ++++++++++++ 11 files changed, 421 insertions(+), 210 deletions(-) create mode 100644 tests/test_core/__init__.py create mode 100644 tests/test_core/utils.py create mode 100644 tests/test_data/test_batch_sampler.py diff --git a/configs/7B_sft.py b/configs/7B_sft.py index ca2adc6..0218a0b 100644 --- a/configs/7B_sft.py +++ b/configs/7B_sft.py @@ -59,7 +59,12 @@ data = dict( pack_sample_into_one=False, total_steps=50000, skip_batches="", - rampup_batch_size="", + # rampup_batch_size (str): A string with three space-separated integers representing the + # starting batch size, the increment, and the number of steps between + # each increment. For example, "192 24 8" means that the batch size (micro_num) + # starts at 192 and increases by 24 every 8 steps. Defaults to None. + # (IMPORTANT): The interval step size is 'micro_bsz'. + rampup_batch_size=None, # Datasets with less than 50 rows will be discarded min_length=50, # train_folder=TRAIN_FOLDER, diff --git a/internlm/core/context/parallel_context.py b/internlm/core/context/parallel_context.py index 633dfe4..db356a1 100644 --- a/internlm/core/context/parallel_context.py +++ b/internlm/core/context/parallel_context.py @@ -156,6 +156,18 @@ class ParallelContext(metaclass=SingletonMeta): def config(self): return self._config + @property + def micro_bsz(self): + return self._config.data.micro_bsz + + @property + def micro_num(self): + return self._config.data.micro_num + + @property + def grad_accum_num(self): + return self._config.data.gradient_accumulation + @property def expert_parallel_group_names(self): return self._expert_parallel_group_names diff --git a/internlm/core/scheduler/base_scheduler.py b/internlm/core/scheduler/base_scheduler.py index 20b4460..14c3457 100644 --- a/internlm/core/scheduler/base_scheduler.py +++ b/internlm/core/scheduler/base_scheduler.py @@ -36,10 +36,10 @@ class BaseScheduler(ABC): """ pass - def _load_micro_batch(self, data, label, offset, micro_bsz): + def _load_micro_batch(self, data, label, offset): assert isinstance(data, dict) and isinstance(label, torch.Tensor) - micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()} - micro_batch_label = label[offset : offset + micro_bsz] + micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()} + micro_batch_label = label[offset : offset + 1] return micro_batch_data, micro_batch_label diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py index 56661d8..24d94ef 100644 --- a/internlm/core/scheduler/no_pipeline_scheduler.py +++ b/internlm/core/scheduler/no_pipeline_scheduler.py @@ -10,10 +10,13 @@ import torch from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.utils.common import 
conditional_context +from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook +logger = get_logger(__file__) + class NonPipelineScheduler(BaseScheduler): """A helper schedule class for no pipeline parallelism running environment. @@ -69,10 +72,8 @@ class NonPipelineScheduler(BaseScheduler): label (Any): The label to be loaded. """ - _data, _label = self._load_micro_batch( - data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size - ) - self._grad_accum_offset += self._grad_accum_batch_size + _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset) + self._grad_accum_offset += 1 if self.data_process_func: _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"]) @@ -135,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler): self._call_hooks("after_backward", None) if not return_loss: - loss = None + loss, moe_loss = None, None return output, loss, moe_loss @@ -166,12 +167,9 @@ class NonPipelineScheduler(BaseScheduler): forward_only or return_loss ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False." - batch_data, batch_size = engine.load_batch(data_iter) + batch_data, actual_batch_size = engine.load_batch(data_iter) - assert ( - batch_size % self._grad_accum_size == 0 - ), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}" - self._grad_accum_batch_size = batch_size // self._grad_accum_size + self._grad_accum_size = actual_batch_size # Rampup or variable bsz size. data, label = batch_data @@ -184,10 +182,11 @@ class NonPipelineScheduler(BaseScheduler): self._grad_accum_offset = 0 for _current_accum_step in range(self._grad_accum_size): - if _current_accum_step == self._grad_accum_size - 1: - engine.optimizer.skip_grad_reduce = False - else: - engine.optimizer.skip_grad_reduce = True + if engine.optimizer is not None: + if _current_accum_step == self._grad_accum_size - 1: + engine.optimizer.skip_grad_reduce = False + else: + engine.optimizer.skip_grad_reduce = True _data, _label = self._load_accum_batch(data, label) diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py index efc9187..c851789 100644 --- a/internlm/core/scheduler/pipeline_scheduler.py +++ b/internlm/core/scheduler/pipeline_scheduler.py @@ -15,10 +15,13 @@ from internlm.core.context import global_context as gpc from internlm.core.engine import Engine from internlm.core.naive_amp import NaiveAMPModel from internlm.utils.common import get_current_device, move_to_device +from internlm.utils.logger import get_logger from internlm.utils.timeout import llm_timeout from .base_scheduler import BaseScheduler, SchedulerHook +logger = get_logger(__file__) + def get_tensor_shape(): if hasattr(gpc.config, "TENSOR_SHAPE"): @@ -184,17 +187,16 @@ class PipelineScheduler(BaseScheduler): def load_batch(self, engine, data_iter): # Pipeline schedule just puts data in memory - batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False) - assert batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches" + batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False) + self.num_microbatches = actual_batch_size # Rampup or variable bsz size. 
         self.microbatch_offset = 0
-        self.batch_size = batch_size
+        self.batch_size = actual_batch_size
         self.batch_data, self.batch_label = batch_data
-        self.microbatch_size = self.batch_size // self.num_microbatches
 
     def load_micro_batch(self):
         micro_batch_data, micro_batch_label = self._load_micro_batch(
-            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
+            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
         )
         if self.data_process_func:
             micro_batch_data["input_ids"] = self.data_process_func(
@@ -206,7 +208,7 @@
             micro_batch_data.pop("indexes")
 
         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset += self.microbatch_size
+        self.microbatch_offset += 1
 
         return move_to_device(micro_batch_data)
 
@@ -785,10 +787,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             data=self.batch_data,
             label=self.batch_label,
             offset=self.microbatch_offset[model_chunk_id],
-            micro_bsz=self.microbatch_size,
         )
         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset[model_chunk_id] += self.microbatch_size
+        self.microbatch_offset[model_chunk_id] += 1
         return move_to_device(micro_batch_data)
 
     def _forward_step(self, engine, chunk_id):
diff --git a/internlm/initialize/launch.py b/internlm/initialize/launch.py
index 6614db0..06f225a 100644
--- a/internlm/initialize/launch.py
+++ b/internlm/initialize/launch.py
@@ -109,9 +109,15 @@ def args_sanity_check():
     if "micro_num" not in data:
         data._add_item("micro_num", 1)
 
-    data._add_item("gradient_accumulation", data.micro_num)
-    if gpc.is_rank_for_log():
-        logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
+    if "gradient_accumulation" not in data:
+        data._add_item("gradient_accumulation", data.micro_num)
+        if gpc.is_rank_for_log():
+            logger.info(f"gradient_accumulation size will be set to {data.micro_num}.")
+    else:
+        if pp == 1:
+            assert (
+                data.gradient_accumulation == data.micro_num
+            ), "without pipeline parallelism, 'gradient_accumulation' should equal 'micro_num'"
 
     # batch_size should be equal with micro_num, should not use it directly
     data._add_item("batch_size", data.micro_num)
@@ -136,6 +142,11 @@
 
     if "diag_outlier_ratio" not in data:
         data._add_item("diag_outlier_ratio", 1.1)
+
+    if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0:
+        bsz = data.micro_num
+        data._add_item("rampup_batch_size", f"{bsz} {bsz} 1")
+
     data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
 
     if gpc.is_rank_for_log():
@@ -148,6 +159,7 @@
         logger.info(f"min_length: {data.min_length}")
         logger.info(f"valid_micro_num: {data.valid_micro_num}")
         logger.info(f"valid_every: {data.valid_every}")
+        logger.info(f"rampup_batch_size: {data.rampup_batch_size}")
 
     # processing the checkpoint config
     ckpt = gpc.config.ckpt
diff --git a/internlm/utils/evaluation.py b/internlm/utils/evaluation.py
index 6a55fa5..22d998b 100644
--- a/internlm/utils/evaluation.py
+++ b/internlm/utils/evaluation.py
@@ -11,22 +11,19 @@ from internlm.model.metrics import AccPerplex
 
 
 @contextmanager
-def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
+def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, metric_hook_list):
     if not gpc.is_using_pp():
         prev_data_process_func = trainer.schedule.data_process_func
         prev_grad_accum_size = trainer.schedule._grad_accum_size
-        prev_grad_accum_batch_size = 
trainer.schedule._grad_accum_batch_size prev_metric_hooks = trainer.schedule._hooks try: trainer.schedule.data_process_func = None trainer.schedule._grad_accum_size = grad_accum_size - trainer.schedule._grad_accum_batch_size = grad_accum_batch_size trainer.schedule._hooks = metric_hook_list yield finally: trainer.schedule.data_process_func = prev_data_process_func trainer.schedule._grad_accum_size = prev_grad_accum_size - trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size trainer.schedule._hooks = prev_metric_hooks @@ -126,11 +123,9 @@ def evaluate_on_val_dls( total_val_bsz = len(batch[1]) assert total_val_bsz % data_cfg.micro_bsz == 0 grad_accum_size = total_val_bsz // data_cfg.micro_bsz - grad_accum_batch_size = data_cfg.micro_bsz with switch_evaluation_no_pipeline_scheduler( trainer=trainer, grad_accum_size=grad_accum_size, - grad_accum_batch_size=grad_accum_batch_size, metric_hook_list=[val_sche_metric_hook], ): if hasattr(gpc.config.model, "num_experts"): diff --git a/tests/test_core/__init__.py b/tests/test_core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_core/test_pipeline.py b/tests/test_core/test_pipeline.py index ce9dc98..4b37f61 100644 --- a/tests/test_core/test_pipeline.py +++ b/tests/test_core/test_pipeline.py @@ -1,59 +1,19 @@ import multiprocessing as mp -import random -import numpy as np import pytest import torch -from torch import nn -from torch.testing import assert_close -import internlm from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc from internlm.core.context.parallel_context import Config -from internlm.core.engine import Engine -from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler -from internlm.core.scheduler import ( - InterleavedPipelineScheduler, - PipelineScheduler, - SchedulerMetricHook, +from tests.test_core.utils import ( + MlpModel, + MyLoss, + build_environment, + init_model_and_optim, + loose_close, + seed_all, ) -from internlm.solver.pipeline_utils import partition_uniform -from internlm.train import initialize_optimizer - - -class MlpModel(nn.Module): - """ - Custom model - """ - - def __init__(self, start, end, model_type=None): - super().__init__() - self.part = [start, end] - self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)]) - self.model_type = model_type - - def forward(self, hidden_states=None, input_ids=None): - if self.model_type != "torch" and self.part[0] != 0: - input_ids = hidden_states - - for i in range(self.part[1] - self.part[0]): - input_ids = self.blocks[i](input_ids) - return input_ids - - -class MyLoss(nn.Module): - """ - Custom loss - """ - - def __init__(self): - super().__init__() - - def forward(self, logits, labels): - loss = torch.nn.MSELoss(reduction="sum") - return loss(logits, labels) - config = Config( dict( @@ -116,71 +76,6 @@ config = Config( ) -def build_environment(rank, world_size): - import os - - os.environ["RANK"] = str(rank) - os.environ["LOCAL_RANK"] = str(rank) - os.environ["WORLD_SIZE"] = str(world_size) - os.environ["MASTER_ADDR"] = "127.0.0.1" - os.environ["MASTER_PORT"] = "33333" - torch.cuda.empty_cache() - # launcher="torch" - internlm.launch_from_torch(config=config, seed=1024) - - -def loose_close(a, b, dtype: torch.dtype = torch.float32): - - if dtype is torch.float32: - rtol = 1.3e-6 - atol = 1e-5 - elif dtype is torch.bfloat16: - rtol = 2e-2 - atol = 2e-2 - - if isinstance(a, torch.Tensor): - a = a.detach().to(dtype) - b = 
b.detach().to(dtype) - - assert_close(a, b, rtol=rtol, atol=atol) - - -def seed_all(seed, cuda_deterministic=False): - random.seed(seed) - np.random.seed(seed) - torch.manual_seed(seed) - if torch.cuda.is_available(): - torch.cuda.manual_seed(seed) - torch.cuda.manual_seed_all(seed) - if cuda_deterministic: # slower, more reproducible - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - else: - torch.backends.cudnn.deterministic = False - torch.backends.cudnn.benchmark = True - - -def _build_generic_model_1d(num_layers, num_chunks): - pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) - pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) - - all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) - parts = all_parts[pipeline_rank] - if gpc.is_rank_for_log(): - print(f"The layer sharding is {all_parts}.", flush=True) - - models = [] - for start, end in parts: - models.append(MlpModel(start, end).cuda()) - torch.distributed.barrier() - if len(models) == 1: - model = models[0] - else: - model = nn.ModuleList(models) - - return model - - def exam_pipeline_parallel(args): # init rank, world_size, micro_num, num_chunks, interleaved_overlap = args @@ -188,76 +83,18 @@ def exam_pipeline_parallel(args): config.model.num_chunks = num_chunks config.parallel.pipeline.interleaved_overlap = interleaved_overlap - build_environment(rank, world_size) + build_environment(rank, world_size, config) device = torch.device(f"cuda:{rank}") dtype = config.model["dtype"] + seq_len = gpc.config.data.seq_len # set seed seed_all(1024) - # pp model - pp_model = _build_generic_model_1d(num_layers=32, num_chunks=num_chunks) - pp_model = pp_model.to(dtype) - - # pp scheduler - scheduler_hooks = [ - SchedulerMetricHook(skip=True), - ] - - seq_len = gpc.config.data.seq_len - gpc.config.NUM_MICRO_BATCHES = micro_num - communication_overlap = interleaved_overlap - - if num_chunks == 1: - # noninterleaved pp - scheduler = PipelineScheduler( - data_process_func=None, - num_microbatches=micro_num, - dtype=dtype, - tensor_shape=[1, 8], - scatter_gather_tensors=False, - scheduler_hooks=scheduler_hooks, - ) - else: - # interleaved pp - if micro_num < gpc.get_world_size(ParallelMode.PIPELINE): - try: - scheduler = InterleavedPipelineScheduler( - num_microbatches=micro_num, - num_chunks=gpc.config.model.num_chunks, - dtype=dtype, - tensor_shape=[1, 8], - scatter_gather_tensors=False, - scheduler_hooks=scheduler_hooks, - communication_overlap=communication_overlap, - ) - except AssertionError: - return - else: - raise RuntimeError("Error: AssertionError should occur when micro_num < Pipeline parrallel world size") - else: - scheduler = InterleavedPipelineScheduler( - num_microbatches=micro_num, - num_chunks=gpc.config.model.num_chunks, - dtype=dtype, - tensor_shape=[1, 8], - scatter_gather_tensors=False, - scheduler_hooks=scheduler_hooks, - communication_overlap=communication_overlap, - ) - - # pp optimizer and engine - optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model) - engine = Engine( - model=pp_model, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - beta2_scheduler=beta2_scheduler, - criterion=MyLoss().to(dtype), - gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)], - clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0), - ) + engine, scheduler = init_model_and_optim(32, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape=[1, 8]) + if scheduler is None: + return 
     scheduler.pre_processing(engine)
     engine.train()
 
diff --git a/tests/test_core/utils.py b/tests/test_core/utils.py
new file mode 100644
index 0000000..6f66a15
--- /dev/null
+++ b/tests/test_core/utils.py
@@ -0,0 +1,207 @@
+import random
+
+import numpy as np
+import torch
+from torch import nn
+from torch.testing import assert_close
+
+import internlm
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.engine import Engine
+from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
+from internlm.core.scheduler import (
+    InterleavedPipelineScheduler,
+    NonPipelineScheduler,
+    PipelineScheduler,
+    SchedulerMetricHook,
+)
+from internlm.solver.pipeline_utils import partition_uniform
+from internlm.train import initialize_optimizer
+
+
+class MlpModel(nn.Module):
+    """
+    Custom model
+    """
+
+    def __init__(self, start, end, model_type=None, embedding=False):
+        super().__init__()
+        self.part = [start, end]
+        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
+        self.model_type = model_type
+        self.embedding = embedding
+
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+    ):  # pylint: disable=W0613
+        if self.model_type != "torch" and self.part[0] != 0:
+            input_ids = hidden_states
+
+        # Simulate Embedding.
+        if self.embedding:
+            if len(input_ids.shape) == 2:
+                input_ids = input_ids.view(-1, 8)
+            elif len(input_ids.shape) == 3:
+                input_ids = input_ids.view(input_ids.shape[0], -1, 8)
+
+        for i in range(self.part[1] - self.part[0]):
+            input_ids = self.blocks[i](input_ids)
+
+        return input_ids
+
+
+class MyLoss(nn.Module):
+    """
+    Custom loss
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, logits, labels):
+        loss = torch.nn.MSELoss(reduction="sum")
+        return loss(logits, labels)
+
+
+def init_model_and_optim(
+    num_layers, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape, init_optim=True, embedding=False
+):
+    # pp model
+    pp_model = _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, embedding=embedding)
+    pp_model = pp_model.to(dtype)
+
+    # pp scheduler
+    scheduler_hooks = [
+        SchedulerMetricHook(skip=True),
+    ]
+
+    if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+        if num_chunks == 1:
+            # noninterleaved pp
+            scheduler = PipelineScheduler(
+                data_process_func=None,
+                num_microbatches=micro_num,
+                dtype=dtype,
+                tensor_shape=tensor_shape,
+                scatter_gather_tensors=False,
+                scheduler_hooks=scheduler_hooks,
+            )
+        else:
+            # interleaved pp
+            if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
+                try:
+                    scheduler = InterleavedPipelineScheduler(
+                        num_microbatches=micro_num,
+                        num_chunks=gpc.config.model.num_chunks,
+                        dtype=dtype,
+                        tensor_shape=tensor_shape,
+                        scatter_gather_tensors=False,
+                        scheduler_hooks=scheduler_hooks,
+                        communication_overlap=interleaved_overlap,
+                    )
+                except AssertionError as e:
+                    print(f"AssertionError: {e}", flush=True)
+                    return None, None
+                else:
+                    raise RuntimeError(
+                        "Error: AssertionError should occur when micro_num < Pipeline parallel world size"
+                    )
+            else:
+                scheduler = InterleavedPipelineScheduler(
+                    num_microbatches=micro_num,
+                    num_chunks=gpc.config.model.num_chunks,
+                    dtype=dtype,
+                    tensor_shape=tensor_shape,
+                    scatter_gather_tensors=False,
+                    scheduler_hooks=scheduler_hooks,
+                    communication_overlap=interleaved_overlap,
+                )
+    else:
+        scheduler = NonPipelineScheduler(
+            data_process_func=None,
gradient_accumulation_size=gpc.config.data.gradient_accumulation, + scheduler_hooks=scheduler_hooks, + ) + + # pp optimizer and engine + if init_optim: + optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model) + else: + optimizer, beta2_scheduler, lr_scheduler = None, None, None + + engine = Engine( + model=pp_model, + optimizer=optimizer, + lr_scheduler=lr_scheduler, + beta2_scheduler=beta2_scheduler, + criterion=MyLoss().to(dtype), + gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)], + clip_grad_norm=0.0, + ) + return engine, scheduler + + +def build_environment(rank, world_size, config): + import os + + os.environ["RANK"] = str(rank) + os.environ["LOCAL_RANK"] = str(rank) + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["MASTER_ADDR"] = "127.0.0.1" + os.environ["MASTER_PORT"] = "33333" + torch.cuda.empty_cache() + # launcher="torch" + internlm.launch_from_torch(config=config, seed=1024) + + +def loose_close(a, b, dtype: torch.dtype = torch.float32): + + if dtype is torch.float32: + rtol = 1.3e-6 + atol = 1e-5 + elif dtype is torch.bfloat16: + rtol = 2e-2 + atol = 2e-2 + + if isinstance(a, torch.Tensor): + a = a.detach().to(dtype) + b = b.detach().to(dtype) + + assert_close(a, b, rtol=rtol, atol=atol) + + +def seed_all(seed, cuda_deterministic=False): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + if cuda_deterministic: # slower, more reproducible + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + else: + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + + +def _build_generic_model_1d(num_layers, num_chunks, embedding=False): + pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE) + pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE) + + all_parts = partition_uniform(num_layers, pipeline_size, num_chunks) + parts = all_parts[pipeline_rank] + if gpc.is_rank_for_log(): + print(f"The layer sharding is {all_parts}.", flush=True) + + models = [] + for start, end in parts: + models.append(MlpModel(start, end, embedding=embedding).cuda()) + torch.distributed.barrier() + if len(models) == 1: + model = models[0] + else: + model = nn.ModuleList(models) + + return model diff --git a/tests/test_data/test_batch_sampler.py b/tests/test_data/test_batch_sampler.py new file mode 100644 index 0000000..9110ca2 --- /dev/null +++ b/tests/test_data/test_batch_sampler.py @@ -0,0 +1,143 @@ +import multiprocessing as mp + +import numpy as np +import pytest +import torch + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc + +# from internlm.core.context import ParallelMode +from internlm.core.context.parallel_context import Config +from internlm.core.trainer import TrainState +from internlm.train import get_train_data_loader, load_new_batch + +# from internlm.core.context.parallel_context import global_context as gpc +from tests.test_core.utils import build_environment, init_model_and_optim + +micro_bszs = [1, 2] +use_flash_attens = [True, False] +answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]] +test_case_group = [ + # format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len + # (1, "1 1 1", True, answers[0], 1, 8), + (4, "1 1 4", True, answers[1], 1, 8), + (4, None, True, answers[2], 1, 8), + (8, "2 2 2", True, answers[3], 1, 8), 
+    (8, "2 2 2", True, answers[3], 2, 8),
+]
+
+
+def do_warmup(args):
+    rank, worldsize, init_config, should_success, answer = args
+    build_environment(rank, worldsize, init_config)
+    gpc.config.model.num_chunks = 1 if gpc.get_world_size(ParallelMode.PIPELINE) == 1 else 2
+    engine, scheduler = init_model_and_optim(
+        8,
+        gpc.config.model.num_chunks,
+        torch.bfloat16,
+        init_config.data.micro_num,
+        True,
+        tensor_shape=[1, 8],  # can't use get_tensor_shape because we use a toy model.
+        init_optim=False,
+        embedding=True,
+    )
+    scheduler.pre_processing(engine)
+    engine.train()
+
+    try:
+        train_dl, _ = get_train_data_loader(num_worker=0)
+    except Exception as e:
+        assert should_success is False, f"{e}"
+    else:
+        assert should_success is True
+
+    # initialize and resume train state
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)
+    # transfer the train data loader into train data iterator
+    train_iter = iter(train_dl)
+
+    micro_bsz = gpc.config.data.micro_bsz
+    seq_len = gpc.config.data.seq_len
+
+    consumed_token = 0  # number of tokens consumed so far
+    packed_length = micro_bsz * seq_len
+    for i in range(init_config.data.total_steps):
+        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
+        input_shape = batch[0]["type_ids"].shape
+        tokens_num = np.prod(input_shape)
+
+        if not init_config.model.use_flash_attn:
+            if answer[i] > 1:
+                assert input_shape == torch.Size(
+                    [answer[i], micro_bsz, seq_len]
+                ), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, seq_len]}"
+            else:
+                assert input_shape == torch.Size(
+                    [micro_bsz, seq_len]
+                ), f"iter:{i}, {input_shape} != {[micro_bsz, seq_len]}"
+        else:
+            assert input_shape == torch.Size(
+                [answer[i], packed_length]
+            ), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
+
+        if gpc.get_global_rank() == 0:
+            print(
+                f"iter:{i}",
+                f"pp size: {gpc.get_world_size(ParallelMode.PIPELINE)}",
+                f"use_flash_attn:{gpc.config.model.use_flash_attn}",
+                f"micro_bsz:{micro_bsz}",
+                f"input shape: {batch[0]['type_ids'].shape}",
+                f"rampup_batch_size: {gpc.config.data.rampup_batch_size}",
+                f"tokens_num: {tokens_num}",
+                flush=True,
+            )
+
+        consumed_token += tokens_num
+        batch[0].pop("type_ids", None)
+        batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
+
+        scheduler.forward_backward_step(engine, batch, forward_only=True, return_loss=False, return_output_label=False)
+        assert (
+            tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
+        ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"
+
+
+@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
+@pytest.mark.parametrize("group_case", test_case_group)
+@pytest.mark.parametrize("micro_bsz_case", micro_bszs)
+def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
+    ctx = mp.get_context("spawn")
+    # print(pp_size_case, use_flash_atten_case, group_case, micro_bsz_case, flush=True)
+
+    config = Config(
+        dict(
+            parallel=dict(
+                zero1=dict(size=1, fsdp=False),
+                pipeline=dict(size=1, interleaved_overlap=False),
+                sequence_parallel=False,
+                tensor=1,
+            ),
+            data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
+            model=dict(
+                dtype=torch.bfloat16,
+            ),
+            adam=dict(lr=1e-4),
+            resume_tb_folder=None,
+            tensorboard_folder=None,
+        )
+    )
+
+    config.data.seq_len = group_case[5]
+    config.parallel.pipeline.size = group_case[4]
+    config.model.use_flash_attn = use_flash_atten_case
+    config.data.micro_bsz = micro_bsz_case
+    config.data.micro_num = group_case[0]
+    config.data.gradient_accumulation = config.data.micro_num
+    config.data.rampup_batch_size = group_case[1]
+    config.data.packed_length = micro_bsz_case * config.data.seq_len
+    should_success = group_case[2]
+    answer = group_case[3]
+
+    with ctx.Pool(processes=8) as pool:
+        pool.map(do_warmup, [[rank, 8, config, should_success, answer] for rank in range(8)])
+        pool.close()
+        pool.join()
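
Note on the ramp-up semantics introduced by this patch: the new comment in configs/7B_sft.py describes rampup_batch_size as a "start increment interval" string, e.g. "192 24 8" means micro_num starts at 192 and grows by 24 every 8 steps, and launch.py now falls back to "{micro_num} {micro_num} 1" (effectively no ramp-up) when the field is empty. The snippet below is only a minimal sketch of that interpretation, consistent with the expectations in tests/test_data/test_batch_sampler.py; it is not the batch-sampler code added by this patch, and the helper name, the step argument, and the target_micro_num cap are hypothetical.

# Hypothetical sketch (not part of this patch): interpret a "start increment interval"
# ramp-up string and return the micro_num to use at a given training step.
def rampup_micro_num(rampup_batch_size: str, step: int, target_micro_num: int) -> int:
    start, increment, interval = map(int, rampup_batch_size.split())
    # grow by `increment` once every `interval` steps, never exceeding the configured micro_num
    return min(start + (step // interval) * increment, target_micro_num)

# Example: with "192 24 8", steps 0-7 use 192 micro-batches, steps 8-15 use 216, and so on.
assert rampup_micro_num("192 24 8", 0, 512) == 192
assert rampup_micro_num("192 24 8", 8, 512) == 216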