# mirror of https://github.com/InternLM/InternLM
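"""
Test the warmup / batch-size rampup behaviour of the training data loaders.

Each case spawns an 8-process toy distributed environment, builds the train and
validation data loaders, and checks that the number of micro-batches (and hence
tokens) consumed per step follows the configured ``rampup_batch_size`` schedule,
with and without flash attention and pipeline parallelism.
"""
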
import multiprocessing as mp

import numpy as np
import pytest
import torch

from internlm.core.context import ParallelMode
from internlm.core.context import global_context as gpc

# from internlm.core.context import ParallelMode
from internlm.core.context.parallel_context import Config
from internlm.core.trainer import TrainState
from internlm.train import (
    get_train_data_loader,
    get_validation_data_loader,
    load_new_batch,
)
from internlm.utils.evaluation import (
    switch_evaluation_no_pipeline_scheduler,
    switch_evaluation_pipeline_scheduler,
)

# from internlm.core.context.parallel_context import global_context as gpc
from tests.test_core.utils import build_environment, init_model_and_optim

micro_bszs = [1, 2]
use_flash_attens = [True, False]
answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
test_case_group = [
    # format: micro_num, rampup_batch_size, should_succeed, answer, pp size, seq len
    (1, "1 1 1", True, answers[0], 1, 8),
    (4, "1 1 4", True, answers[1], 1, 8),
    (4, None, True, answers[2], 1, 8),
    (8, "2 2 2", True, answers[3], 1, 8),
    (8, "2 2 2", True, answers[3], 2, 8),
]
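
# Each list in `answers` gives the expected number of micro-batches loaded at each of
# the 8 training steps for the matching case above: e.g. micro_num=8 with
# rampup_batch_size="2 2 2" is expected to ramp up as 2, 2, 4, 4, 6, 6, 8, 8
# micro-batches, while rampup_batch_size=None keeps the full micro_num from step 0.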


class DummyTrainer:
    def __init__(self, scheduler) -> None:
        self.schedule = scheduler
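

# Per-rank worker: builds a toy distributed environment, constructs the train/val
# data loaders, and checks that the batch shapes and token counts produced during
# warmup match the expected `answer` schedule for this case.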
def do_warmup(args):
    rank, worldsize, init_config, should_succeed, answer = args
    build_environment(rank, worldsize, init_config)
    gpc.config.model.num_chunks = 1 if gpc.get_world_size(ParallelMode.PIPELINE) == 1 else 2
    engine, scheduler = init_model_and_optim(
        8,
        gpc.config.model.num_chunks,
        torch.bfloat16,
        init_config.data.micro_num,
        True,
        tensor_shape=[1, 8],  # can't use get_tensor_shape because we use a toy model.
        init_optim=False,
        embedding=True,
    )
    scheduler.pre_processing(engine)
    engine.train()
    trainer = DummyTrainer(scheduler)

    try:
        train_dl, _ = get_train_data_loader(num_worker=0)
        val_dls = get_validation_data_loader(num_worker=0)
    except Exception as e:
        assert should_succeed is False, f"{e}"
        return  # expected failure: nothing further to check for this case.
    else:
        assert should_succeed is True

    # initialize and resume train state
    train_state = TrainState(gpc.config, train_dl.batch_sampler)
    # transfer the train data loader into train data iterator
    train_iter = iter(train_dl)

    micro_bsz = gpc.config.data.micro_bsz
    seq_len = gpc.config.data.seq_len

    consumed_token = 0  # tokens consumed so far
    packed_length = micro_bsz * seq_len
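
    # Run `total_steps` warmup iterations: at step i the loader is expected to yield
    # `answer[i]` micro-batches, i.e. `answer[i] * micro_bsz * seq_len` tokens.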
    for i in range(init_config.data.total_steps):
        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
        type_ids_shape = batch[0]["type_ids"].shape
        input_ids_shape = batch[0]["input_ids"].shape

        tokens_num = np.prod(input_ids_shape)

        # If flash attention is not used, 'type_ids' is unpacked when load_new_batch is called.
        # However, 'input_ids' is unpacked in the pp/no-pp engine.
        if not init_config.model.use_flash_attn:
            assert type_ids_shape == torch.Size(
                [answer[i], micro_bsz, seq_len]
            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], micro_bsz, seq_len])}"
        else:
            assert type_ids_shape == torch.Size(
                [answer[i], packed_length]
            ), f"iter:{i}, type_ids_shape: {type_ids_shape} != {torch.Size([answer[i], packed_length])}"

        assert input_ids_shape == torch.Size(
            [answer[i], packed_length]
        ), f"iter:{i}, input_ids_shape: {input_ids_shape} != {torch.Size([answer[i], packed_length])}"

        if gpc.get_global_rank() == 0:
            print(
                f"iter:{i}",
                f"pp size: {gpc.get_world_size(ParallelMode.PIPELINE)}",
                f"use_flash_attn:{gpc.config.model.use_flash_attn}",
                f"micro_bsz:{micro_bsz}",
                f"input shape: {batch[0]['type_ids'].shape}",
                f"rampup_batch_size: {gpc.config.data.rampup_batch_size}",
                f"tokens_num: {tokens_num}",
                flush=True,
            )

        consumed_token += tokens_num
        batch[0].pop("type_ids", None)
        batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)

        scheduler.forward_backward_step(engine, batch, forward_only=True, return_loss=False, return_output_label=False)
        assert (
            tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
        ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"

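    # During evaluation the scheduler is temporarily switched (pipeline or
    # non-pipeline variant) so a forward-only pass can run with a micro-batch
    # count derived from each validation batch size.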
    # test non-packed datasets.
    for _, val_dl in val_dls.items():
        for _, batch in enumerate(val_dl):
            if gpc.is_using_pp():
                total_val_bsz = len(batch[1])
                batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
                assert total_val_bsz % micro_bsz == 0
                num_microbatches = total_val_bsz // micro_bsz
                tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]])  # toy model hidden size is 8.
                with switch_evaluation_pipeline_scheduler(
                    trainer=trainer,
                    num_microbatches=num_microbatches,
                    tensor_shape=tensor_shape,
                    metric_hook_list=[],
                ):
                    scheduler.forward_backward_step(
                        engine, batch, forward_only=True, return_loss=False, return_output_label=False
                    )
            else:
                total_val_bsz = len(batch[1])
                batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
                assert total_val_bsz % micro_bsz == 0
                grad_accum_size = total_val_bsz // micro_bsz
                with switch_evaluation_no_pipeline_scheduler(
                    trainer=trainer,
                    grad_accum_size=grad_accum_size,
                    metric_hook_list=[],
                ):
                    scheduler.forward_backward_step(
                        engine, batch, forward_only=True, return_loss=False, return_output_label=False
                    )
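

# The test runs the full cross product of micro_bsz x flash-attention x case group;
# each combination launches its own 8-process group via `do_warmup`.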
@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
@pytest.mark.parametrize("group_case", test_case_group)
@pytest.mark.parametrize("micro_bsz_case", micro_bszs)
def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
    ctx = mp.get_context("spawn")
    # print(pp_size_case, use_flash_atten_case, group_case, micro_bsz_case, flush=True)

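    # Minimal toy config: train_folder/valid_folder are None, so no real dataset
    # files are required; per-case fields (seq_len, pipeline size, flash attention,
    # micro_bsz, micro_num, rampup_batch_size) are patched in below.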
    config = Config(
        dict(
            parallel=dict(
                zero1=dict(size=1, fsdp=False),
                pipeline=dict(size=1, interleaved_overlap=False),
                sequence_parallel=False,
                tensor=1,
            ),
            data=dict(
                train_folder=None,
                valid_folder=None,
                valid_micro_num=4,
                pack_sample_into_one=False,
                min_length=0,
                total_steps=8,
            ),
            model=dict(
                dtype=torch.bfloat16,
            ),
            adam=dict(lr=1e-4),
            resume_tb_folder=None,
            tensorboard_folder=None,
        )
    )

    config.data.seq_len = group_case[5]
    config.parallel.pipeline.size = group_case[4]
    config.model.use_flash_attn = use_flash_atten_case
    config.data.micro_bsz = micro_bsz_case
    config.data.micro_num = group_case[0]
    config.data.gradient_accumulation = config.data.micro_num
    config.data.rampup_batch_size = group_case[1]
    config.data.packed_length = micro_bsz_case * config.data.seq_len
    should_succeed = group_case[2]
    answer = group_case[3]

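    # Spawn 8 worker processes to emulate a world size of 8; every rank runs
    # `do_warmup` with the same config and expected answer.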
    with ctx.Pool(processes=8) as pool:
        pool.map(do_warmup, [[rank, 8, config, should_succeed, answer] for rank in range(8)])
        pool.close()
        pool.join()
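

# A minimal sketch of how this test might be invoked (the path below is a placeholder
# for this file's location in the repository):
#     pytest -s -q <path/to/this_test_file.py>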