fix(pp): fix no-packed dataset load micro batch error (#538)

* fix(pp): fix no-packed dataset load micro batch error

* fix based on comment
pull/539/head
Guoteng 2023-12-13 14:48:32 +08:00 committed by GitHub
parent 432bd5ee9f
commit 5ecb6aa712
6 changed files with 109 additions and 15 deletions


@@ -185,6 +185,11 @@ class Engine:
if to_gpu:
batch_data = move_to_device(batch_data)
# For a packed dataset, batch_data is (micro_num, micro_bsz*seq_len),
# therefore 'batch_size' is equal to 'micro_num'.
# For a non-packed dataset, batch_data is (micro_num*micro_bsz, seq_len),
# therefore 'batch_size' is equal to 'micro_num*micro_bsz'.
batch_size = get_batch_size(batch_data)
return batch_data, batch_size
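As a point of reference, here is a minimal standalone sketch of the two batch layouts this comment distinguishes; the values of micro_num, micro_bsz and seq_len are illustrative toy numbers, not taken from a real config:

import torch
from internlm.utils.common import get_batch_size

micro_num, micro_bsz, seq_len = 4, 2, 8

# Packed dataset: one row per micro-batch, samples concatenated along the sequence dim.
packed = {"input_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.long)}
# Non-packed dataset: one row per sample.
non_packed = {"input_ids": torch.zeros(micro_num * micro_bsz, seq_len, dtype=torch.long)}

assert get_batch_size(packed) == micro_num                  # batch_size == micro_num
assert get_batch_size(non_packed) == micro_num * micro_bsz  # batch_size == micro_num*micro_bsz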


@@ -4,7 +4,7 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
from abc import ABC, abstractmethod
from typing import Any, Callable, Iterable, Optional
from typing import Any, Callable, Dict, Iterable, Optional
import torch
@@ -36,10 +36,18 @@ class BaseScheduler(ABC):
"""
pass
def _load_micro_batch(self, data, label, offset):
def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int):
"""
For pp, this cuts one full batch into micro-batches according to the pipeline schedule.
For nopp, it cuts one full batch into small batches based on the gradient accumulation size.
A special case is that pp uses a non-packed dataset (such as an evaluation dataset),
so the batch data is unpacked and 'bsz_stride' is equal to 'micro_bsz'.
In all other cases 'bsz_stride' should be equal to 1.
"""
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
micro_batch_label = label[offset : offset + 1]
micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
micro_batch_label = label[offset : offset + bsz_stride]
return micro_batch_data, micro_batch_label
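For illustration, a standalone sketch of the slicing this method performs; the toy tensors and the local slice_micro_batch helper are assumptions for the example, not the scheduler API itself:

import torch

def slice_micro_batch(data, label, offset, bsz_stride):
    # Same slicing as _load_micro_batch: take bsz_stride rows starting at offset.
    micro_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
    micro_label = label[offset : offset + bsz_stride]
    return micro_data, micro_label

data = {"input_ids": torch.arange(8).reshape(8, 1)}  # 8 rows along the batch dimension
label = torch.arange(8)

# Packed case: bsz_stride == 1, each call consumes one packed row.
d, _ = slice_micro_batch(data, label, offset=0, bsz_stride=1)
assert d["input_ids"].shape[0] == 1

# Non-packed case (e.g. evaluation): bsz_stride == micro_bsz, each call consumes micro_bsz rows.
d, _ = slice_micro_batch(data, label, offset=2, bsz_stride=2)
assert d["input_ids"].shape[0] == 2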


@@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler):
label (Any): The label to be loaded.
"""
_data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset)
_data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1)
self._grad_accum_offset += 1
if self.data_process_func:
@@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler):
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
batch_data, actual_batch_size = engine.load_batch(data_iter)
batch_data, actual_batch_size = engine.load_batch(data_iter) # actual_batch_size is micro_num
self._grad_accum_size = actual_batch_size # Rampup or variable bsz size.


@@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc
from internlm.core.engine import Engine
from internlm.core.naive_amp import NaiveAMPModel
from internlm.utils.common import get_current_device, move_to_device
from internlm.utils.common import (
check_data_is_packed,
get_current_device,
move_to_device,
)
from internlm.utils.logger import get_logger
from internlm.utils.timeout import llm_timeout
@@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler):
raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")
def load_batch(self, engine, data_iter):
# Pipeline schedule just puts data in memory
# Pipeline schedule just puts data in memory,
batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
self.num_microbatches = actual_batch_size # Rampup or variable bsz size.
# Even if 'use_flash_attn' is False, the data seen when 'load_batch' is called is still packed,
# because InternLM's current train dataset is packed, even when using dummy data.
# The unpacking is performed in load_micro_batch().
if check_data_is_packed(batch_data):
micro_num = actual_batch_size
else:
micro_num = actual_batch_size // gpc.config.data["micro_bsz"]
self.microbatch_offset = 0
self.batch_size = actual_batch_size
self.batch_data, self.batch_label = batch_data
self.bsz_stride = self.batch_size // micro_num
# 'num_microbatches' is no longer an initialization parameter,
# but is determined on the fly by the Scheduler.
self.num_microbatches = micro_num # Rampup or variable bsz size.
def load_micro_batch(self):
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride
)
if self.data_process_func:
micro_batch_data["input_ids"] = self.data_process_func(
@@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler):
micro_batch_data.pop("indexes")
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset += 1
self.microbatch_offset += self.bsz_stride
return move_to_device(micro_batch_data)
@@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
data=self.batch_data,
label=self.batch_label,
offset=self.microbatch_offset[model_chunk_id],
bsz_stride=self.bsz_stride,
)
micro_batch_data["label"] = micro_batch_label
self.microbatch_offset[model_chunk_id] += 1
self.microbatch_offset[model_chunk_id] += self.bsz_stride
return move_to_device(micro_batch_data)
def _forward_step(self, engine, chunk_id):
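The derivation in load_batch above boils down to: micro_num equals the batch size itself for packed data, or the batch size divided by micro_bsz for non-packed data, and bsz_stride is then however many rows each micro-batch spans along dim 0. A minimal sketch of that arithmetic; the helper name derive_microbatching and the toy values are assumptions for illustration only:

def derive_microbatching(actual_batch_size, data_is_packed, micro_bsz):
    # Packed data: one row per micro-batch, so batch_size already equals micro_num.
    # Non-packed data: each micro-batch spans micro_bsz rows.
    micro_num = actual_batch_size if data_is_packed else actual_batch_size // micro_bsz
    bsz_stride = actual_batch_size // micro_num  # 1 if packed, micro_bsz if not
    return micro_num, bsz_stride

assert derive_microbatching(4, True, micro_bsz=2) == (4, 1)   # packed training batch
assert derive_microbatching(8, False, micro_bsz=2) == (4, 2)  # non-packed evaluation batch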


@@ -110,6 +110,17 @@ def get_batch_size(data):
return data[list(data.keys())[0]].size(0)
def check_data_is_packed(data):
if isinstance(data, torch.Tensor):
return False
elif isinstance(data, (list, tuple)):
if isinstance(data[0], dict):
return "indexes" in data[0]
return False
elif isinstance(data, dict):
return "indexes" in data[0]
def filter_kwargs(func, kwargs):
sig = inspect.signature(func)
return {k: v for k, v in kwargs.items() if k in sig.parameters}
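Illustrative usage of check_data_is_packed; the toy (data, label) batches below are assumptions for the example, and only the presence of the 'indexes' key in the data dict marks it as packed:

import torch
from internlm.utils.common import check_data_is_packed

packed_data = {
    "input_ids": torch.zeros(4, 16, dtype=torch.long),
    "indexes": torch.zeros(4, 16, dtype=torch.long),  # present only for packed samples
}
non_packed_data = {"input_ids": torch.zeros(8, 8, dtype=torch.long)}
label = torch.zeros(4, dtype=torch.long)

assert check_data_is_packed((packed_data, label)) is True       # packed train-style batch
assert check_data_is_packed((non_packed_data, label)) is False  # non-packed eval-style batch
assert check_data_is_packed(torch.zeros(3)) is False            # bare tensors are never packed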


@@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc
# from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.train import get_train_data_loader, load_new_batch
from internlm.train import (
get_train_data_loader,
get_validation_data_loader,
load_new_batch,
)
from internlm.utils.evaluation import (
switch_evaluation_no_pipeline_scheduler,
switch_evaluation_pipeline_scheduler,
)
# from internlm.core.context.parallel_context import global_context as gpc
from tests.test_core.utils import build_environment, init_model_and_optim
@@ -20,7 +28,7 @@ use_flash_attens = [True, False]
answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
test_case_group = [
# format: micro_nums, rampup_batch_size, should succeed, answer, pp size, seq len
# (1, "1 1 1", True, answers[0], 1, 8),
(1, "1 1 1", True, answers[0], 1, 8),
(4, "1 1 4", True, answers[1], 1, 8),
(4, None, True, answers[2], 1, 8),
(8, "2 2 2", True, answers[3], 1, 8),
@@ -28,6 +36,11 @@ test_case_group = [
]
class DummyTrainer:
def __init__(self, scheduler) -> None:
self.schedule = scheduler
def do_warmup(args):
rank, worldsize, init_config, should_sccuess, answer = args
build_environment(rank, worldsize, init_config)
@@ -44,9 +57,11 @@ def do_warmup(args):
)
scheduler.pre_processing(engine)
engine.train()
trainer = DummyTrainer(scheduler)
try:
train_dl, _ = get_train_data_loader(num_worker=0)
val_dls = get_validation_data_loader(num_worker=0)
except Exception as e:
assert should_sccuess is False, f"{e}"
else:
@@ -105,6 +120,38 @@ def do_warmup(args):
tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"
# test no-packed datasets.
for _, val_dl in val_dls.items():
for _, batch in enumerate(val_dl):
if gpc.is_using_pp():
total_val_bsz = len(batch[1])
batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
assert total_val_bsz % micro_bsz == 0
num_microbatches = total_val_bsz // micro_bsz
tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]]) # toy model hidden size is 8.
with switch_evaluation_pipeline_scheduler(
trainer=trainer,
num_microbatches=num_microbatches,
tensor_shape=tensor_shape,
metric_hook_list=[],
):
scheduler.forward_backward_step(
engine, batch, forward_only=True, return_loss=False, return_output_label=False
)
else:
total_val_bsz = len(batch[1])
batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
assert total_val_bsz % micro_bsz == 0
grad_accum_size = total_val_bsz // micro_bsz
with switch_evaluation_no_pipeline_scheduler(
trainer=trainer,
grad_accum_size=grad_accum_size,
metric_hook_list=[],
):
scheduler.forward_backward_step(
engine, batch, forward_only=True, return_loss=False, return_output_label=False
)
@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
@pytest.mark.parametrize("group_case", test_case_group)
@@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
sequence_parallel=False,
tensor=1,
),
data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
data=dict(
train_folder=None,
valid_folder=None,
valid_micro_num=4,
pack_sample_into_one=False,
min_length=0,
total_steps=8,
),
model=dict(
dtype=torch.bfloat16,
),