mirror of https://github.com/InternLM/InternLM
feat(train): support_rampup_batch_size and fix bugs (#493)
parent
4a6987d5e7
commit
0bfc86205e
|
@ -59,7 +59,12 @@ data = dict(
|
|||
pack_sample_into_one=False,
|
||||
total_steps=50000,
|
||||
skip_batches="",
|
||||
rampup_batch_size="",
|
||||
# rampup_batch_size (str): A string with three space-separated integers representing the
|
||||
# starting batch size, the increment, and the number of steps between
|
||||
# each increment. For example, "192 24 8" means that the batch size (micro_num)
|
||||
# starts at 192 and increases by 24 every 8 steps. Defaults to None.
|
||||
# (IMPORTANT): The interval step size is 'micro_bsz'.
|
||||
rampup_batch_size=None,
|
||||
# Datasets with less than 50 rows will be discarded
|
||||
min_length=50,
|
||||
# train_folder=TRAIN_FOLDER,
|
||||
|
|
|
@ -156,6 +156,18 @@ class ParallelContext(metaclass=SingletonMeta):
|
|||
def config(self):
|
||||
return self._config
|
||||
|
||||
@property
|
||||
def micro_bsz(self):
|
||||
return self._config.data.micro_bsz
|
||||
|
||||
@property
|
||||
def micro_num(self):
|
||||
return self._config.data.micro_num
|
||||
|
||||
@property
|
||||
def grad_accum_num(self):
|
||||
return self._config.data.gradient_accumulation
|
||||
|
||||
@property
|
||||
def expert_parallel_group_names(self):
|
||||
return self._expert_parallel_group_names
|
||||
|
|
|
@ -36,10 +36,10 @@ class BaseScheduler(ABC):
|
|||
"""
|
||||
pass
|
||||
|
||||
def _load_micro_batch(self, data, label, offset, micro_bsz):
|
||||
def _load_micro_batch(self, data, label, offset):
|
||||
assert isinstance(data, dict) and isinstance(label, torch.Tensor)
|
||||
micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
|
||||
micro_batch_label = label[offset : offset + micro_bsz]
|
||||
micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
|
||||
micro_batch_label = label[offset : offset + 1]
|
||||
|
||||
return micro_batch_data, micro_batch_label
|
||||
|
||||
|
|
|
@ -10,10 +10,13 @@ import torch
|
|||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.utils.common import conditional_context
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
class NonPipelineScheduler(BaseScheduler):
|
||||
"""A helper schedule class for no pipeline parallelism running environment.
|
||||
|
@ -69,10 +72,8 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
label (Any): The label to be loaded.
|
||||
"""
|
||||
|
||||
_data, _label = self._load_micro_batch(
|
||||
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
|
||||
)
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
_data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset)
|
||||
self._grad_accum_offset += 1
|
||||
|
||||
if self.data_process_func:
|
||||
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
|
||||
|
@ -135,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
self._call_hooks("after_backward", None)
|
||||
|
||||
if not return_loss:
|
||||
loss = None
|
||||
loss, moe_loss = None, None
|
||||
|
||||
return output, loss, moe_loss
|
||||
|
||||
|
@ -166,12 +167,9 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
batch_data, batch_size = engine.load_batch(data_iter)
|
||||
batch_data, actual_batch_size = engine.load_batch(data_iter)
|
||||
|
||||
assert (
|
||||
batch_size % self._grad_accum_size == 0
|
||||
), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
|
||||
self._grad_accum_batch_size = batch_size // self._grad_accum_size
|
||||
self._grad_accum_size = actual_batch_size # Rampup or variable bsz size.
|
||||
|
||||
data, label = batch_data
|
||||
|
||||
|
@ -184,10 +182,11 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
self._grad_accum_offset = 0
|
||||
|
||||
for _current_accum_step in range(self._grad_accum_size):
|
||||
if _current_accum_step == self._grad_accum_size - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
else:
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
if engine.optimizer is not None:
|
||||
if _current_accum_step == self._grad_accum_size - 1:
|
||||
engine.optimizer.skip_grad_reduce = False
|
||||
else:
|
||||
engine.optimizer.skip_grad_reduce = True
|
||||
|
||||
_data, _label = self._load_accum_batch(data, label)
|
||||
|
||||
|
|
|
@ -15,10 +15,13 @@ from internlm.core.context import global_context as gpc
|
|||
from internlm.core.engine import Engine
|
||||
from internlm.core.naive_amp import NaiveAMPModel
|
||||
from internlm.utils.common import get_current_device, move_to_device
|
||||
from internlm.utils.logger import get_logger
|
||||
from internlm.utils.timeout import llm_timeout
|
||||
|
||||
from .base_scheduler import BaseScheduler, SchedulerHook
|
||||
|
||||
logger = get_logger(__file__)
|
||||
|
||||
|
||||
def get_tensor_shape():
|
||||
if hasattr(gpc.config, "TENSOR_SHAPE"):
|
||||
|
@ -184,17 +187,16 @@ class PipelineScheduler(BaseScheduler):
|
|||
|
||||
def load_batch(self, engine, data_iter):
|
||||
# Pipeline schedule just puts data in memory
|
||||
batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
|
||||
assert batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
|
||||
batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)
|
||||
|
||||
self.num_microbatches = actual_batch_size # Rampup or variable bsz size.
|
||||
self.microbatch_offset = 0
|
||||
self.batch_size = batch_size
|
||||
self.batch_size = actual_batch_size
|
||||
self.batch_data, self.batch_label = batch_data
|
||||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
|
||||
def load_micro_batch(self):
|
||||
micro_batch_data, micro_batch_label = self._load_micro_batch(
|
||||
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
|
||||
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
|
||||
)
|
||||
if self.data_process_func:
|
||||
micro_batch_data["input_ids"] = self.data_process_func(
|
||||
|
@ -206,7 +208,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
micro_batch_data.pop("indexes")
|
||||
|
||||
micro_batch_data["label"] = micro_batch_label
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
self.microbatch_offset += 1
|
||||
|
||||
return move_to_device(micro_batch_data)
|
||||
|
||||
|
@ -785,10 +787,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
data=self.batch_data,
|
||||
label=self.batch_label,
|
||||
offset=self.microbatch_offset[model_chunk_id],
|
||||
micro_bsz=self.microbatch_size,
|
||||
)
|
||||
micro_batch_data["label"] = micro_batch_label
|
||||
self.microbatch_offset[model_chunk_id] += self.microbatch_size
|
||||
self.microbatch_offset[model_chunk_id] += 1
|
||||
return move_to_device(micro_batch_data)
|
||||
|
||||
def _forward_step(self, engine, chunk_id):
|
||||
|
|
|
@ -109,9 +109,15 @@ def args_sanity_check():
|
|||
if "micro_num" not in data:
|
||||
data._add_item("micro_num", 1)
|
||||
|
||||
data._add_item("gradient_accumulation", data.micro_num)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
|
||||
if "gradient_accumulation" not in data:
|
||||
data._add_item("gradient_accumulation", data.micro_num)
|
||||
if gpc.is_rank_for_log():
|
||||
logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
|
||||
else:
|
||||
if pp == 1:
|
||||
assert (
|
||||
data.gradient_accumulation == data.micro_num
|
||||
), "for nopp 'gradient_accumulation' should equal with 'micro_num'"
|
||||
|
||||
# batch_size should be equal with micro_num, should not use it directly
|
||||
data._add_item("batch_size", data.micro_num)
|
||||
|
@ -136,6 +142,11 @@ def args_sanity_check():
|
|||
|
||||
if "diag_outlier_ratio" not in data:
|
||||
data._add_item("diag_outlier_ratio", 1.1)
|
||||
|
||||
if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0:
|
||||
bsz = data.micro_num
|
||||
data._add_item("rampup_batch_size", f"{bsz} {bsz} 1")
|
||||
|
||||
data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)
|
||||
|
||||
if gpc.is_rank_for_log():
|
||||
|
@ -148,6 +159,7 @@ def args_sanity_check():
|
|||
logger.info(f"min_length: {data.min_length}")
|
||||
logger.info(f"valid_micro_num: {data.valid_micro_num}")
|
||||
logger.info(f"valid_every: {data.valid_every}")
|
||||
logger.info(f"rampup_batch_size: {data.rampup_batch_size}")
|
||||
|
||||
# processing the checkpoint config
|
||||
ckpt = gpc.config.ckpt
|
||||
|
|
|
@ -11,22 +11,19 @@ from internlm.model.metrics import AccPerplex
|
|||
|
||||
|
||||
@contextmanager
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
|
||||
def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, metric_hook_list):
|
||||
if not gpc.is_using_pp():
|
||||
prev_data_process_func = trainer.schedule.data_process_func
|
||||
prev_grad_accum_size = trainer.schedule._grad_accum_size
|
||||
prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
|
||||
prev_metric_hooks = trainer.schedule._hooks
|
||||
try:
|
||||
trainer.schedule.data_process_func = None
|
||||
trainer.schedule._grad_accum_size = grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
|
||||
trainer.schedule._hooks = metric_hook_list
|
||||
yield
|
||||
finally:
|
||||
trainer.schedule.data_process_func = prev_data_process_func
|
||||
trainer.schedule._grad_accum_size = prev_grad_accum_size
|
||||
trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
|
||||
trainer.schedule._hooks = prev_metric_hooks
|
||||
|
||||
|
||||
|
@ -126,11 +123,9 @@ def evaluate_on_val_dls(
|
|||
total_val_bsz = len(batch[1])
|
||||
assert total_val_bsz % data_cfg.micro_bsz == 0
|
||||
grad_accum_size = total_val_bsz // data_cfg.micro_bsz
|
||||
grad_accum_batch_size = data_cfg.micro_bsz
|
||||
with switch_evaluation_no_pipeline_scheduler(
|
||||
trainer=trainer,
|
||||
grad_accum_size=grad_accum_size,
|
||||
grad_accum_batch_size=grad_accum_batch_size,
|
||||
metric_hook_list=[val_sche_metric_hook],
|
||||
):
|
||||
if hasattr(gpc.config.model, "num_experts"):
|
||||
|
|
|
@ -1,59 +1,19 @@
|
|||
import multiprocessing as mp
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.scheduler import (
|
||||
InterleavedPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
SchedulerMetricHook,
|
||||
from tests.test_core.utils import (
|
||||
MlpModel,
|
||||
MyLoss,
|
||||
build_environment,
|
||||
init_model_and_optim,
|
||||
loose_close,
|
||||
seed_all,
|
||||
)
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.train import initialize_optimizer
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
"""
|
||||
Custom model
|
||||
"""
|
||||
|
||||
def __init__(self, start, end, model_type=None):
|
||||
super().__init__()
|
||||
self.part = [start, end]
|
||||
self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
|
||||
self.model_type = model_type
|
||||
|
||||
def forward(self, hidden_states=None, input_ids=None):
|
||||
if self.model_type != "torch" and self.part[0] != 0:
|
||||
input_ids = hidden_states
|
||||
|
||||
for i in range(self.part[1] - self.part[0]):
|
||||
input_ids = self.blocks[i](input_ids)
|
||||
return input_ids
|
||||
|
||||
|
||||
class MyLoss(nn.Module):
|
||||
"""
|
||||
Custom loss
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
loss = torch.nn.MSELoss(reduction="sum")
|
||||
return loss(logits, labels)
|
||||
|
||||
|
||||
config = Config(
|
||||
dict(
|
||||
|
@ -116,71 +76,6 @@ config = Config(
|
|||
)
|
||||
|
||||
|
||||
def build_environment(rank, world_size):
|
||||
import os
|
||||
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "33333"
|
||||
torch.cuda.empty_cache()
|
||||
# launcher="torch"
|
||||
internlm.launch_from_torch(config=config, seed=1024)
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
|
||||
if dtype is torch.float32:
|
||||
rtol = 1.3e-6
|
||||
atol = 1e-5
|
||||
elif dtype is torch.bfloat16:
|
||||
rtol = 2e-2
|
||||
atol = 2e-2
|
||||
|
||||
if isinstance(a, torch.Tensor):
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype)
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def seed_all(seed, cuda_deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if cuda_deterministic: # slower, more reproducible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
else:
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def _build_generic_model_1d(num_layers, num_chunks):
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
if gpc.is_rank_for_log():
|
||||
print(f"The layer sharding is {all_parts}.", flush=True)
|
||||
|
||||
models = []
|
||||
for start, end in parts:
|
||||
models.append(MlpModel(start, end).cuda())
|
||||
torch.distributed.barrier()
|
||||
if len(models) == 1:
|
||||
model = models[0]
|
||||
else:
|
||||
model = nn.ModuleList(models)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def exam_pipeline_parallel(args):
|
||||
# init
|
||||
rank, world_size, micro_num, num_chunks, interleaved_overlap = args
|
||||
|
@ -188,76 +83,18 @@ def exam_pipeline_parallel(args):
|
|||
config.model.num_chunks = num_chunks
|
||||
config.parallel.pipeline.interleaved_overlap = interleaved_overlap
|
||||
|
||||
build_environment(rank, world_size)
|
||||
build_environment(rank, world_size, config)
|
||||
|
||||
device = torch.device(f"cuda:{rank}")
|
||||
dtype = config.model["dtype"]
|
||||
seq_len = gpc.config.data.seq_len
|
||||
|
||||
# set seed
|
||||
seed_all(1024)
|
||||
|
||||
# pp model
|
||||
pp_model = _build_generic_model_1d(num_layers=32, num_chunks=num_chunks)
|
||||
pp_model = pp_model.to(dtype)
|
||||
|
||||
# pp scheduler
|
||||
scheduler_hooks = [
|
||||
SchedulerMetricHook(skip=True),
|
||||
]
|
||||
|
||||
seq_len = gpc.config.data.seq_len
|
||||
gpc.config.NUM_MICRO_BATCHES = micro_num
|
||||
communication_overlap = interleaved_overlap
|
||||
|
||||
if num_chunks == 1:
|
||||
# noninterleaved pp
|
||||
scheduler = PipelineScheduler(
|
||||
data_process_func=None,
|
||||
num_microbatches=micro_num,
|
||||
dtype=dtype,
|
||||
tensor_shape=[1, 8],
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
else:
|
||||
# interleaved pp
|
||||
if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
|
||||
try:
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=micro_num,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=dtype,
|
||||
tensor_shape=[1, 8],
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=communication_overlap,
|
||||
)
|
||||
except AssertionError:
|
||||
return
|
||||
else:
|
||||
raise RuntimeError("Error: AssertionError should occur when micro_num < Pipeline parrallel world size")
|
||||
else:
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=micro_num,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=dtype,
|
||||
tensor_shape=[1, 8],
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=communication_overlap,
|
||||
)
|
||||
|
||||
# pp optimizer and engine
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
|
||||
engine = Engine(
|
||||
model=pp_model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
criterion=MyLoss().to(dtype),
|
||||
gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
|
||||
clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
|
||||
)
|
||||
engine, scheduler = init_model_and_optim(32, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape=[1, 8])
|
||||
if scheduler is None:
|
||||
return
|
||||
|
||||
scheduler.pre_processing(engine)
|
||||
engine.train()
|
||||
|
|
|
@ -0,0 +1,207 @@
|
|||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.testing import assert_close
|
||||
|
||||
import internlm
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
from internlm.core.engine import Engine
|
||||
from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
|
||||
from internlm.core.scheduler import (
|
||||
InterleavedPipelineScheduler,
|
||||
NonPipelineScheduler,
|
||||
PipelineScheduler,
|
||||
SchedulerMetricHook,
|
||||
)
|
||||
from internlm.solver.pipeline_utils import partition_uniform
|
||||
from internlm.train import initialize_optimizer
|
||||
|
||||
|
||||
class MlpModel(nn.Module):
|
||||
"""
|
||||
Custom model
|
||||
"""
|
||||
|
||||
def __init__(self, start, end, model_type=None, embedding=False):
|
||||
super().__init__()
|
||||
self.part = [start, end]
|
||||
self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
|
||||
self.model_type = model_type
|
||||
self.embedding = embedding
|
||||
|
||||
def forward(
|
||||
self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
|
||||
): # pylint: disable=W0613
|
||||
if self.model_type != "torch" and self.part[0] != 0:
|
||||
input_ids = hidden_states
|
||||
|
||||
# Simulate Embedding.
|
||||
if self.embedding:
|
||||
if len(input_ids.shape) == 2:
|
||||
input_ids = input_ids.view(-1, 8)
|
||||
elif len(input_ids.shape) == 3:
|
||||
input_ids = input_ids.view(input_ids.shape(0), -1, 8)
|
||||
|
||||
for i in range(self.part[1] - self.part[0]):
|
||||
input_ids = self.blocks[i](input_ids)
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
class MyLoss(nn.Module):
|
||||
"""
|
||||
Custom loss
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
def forward(self, logits, labels):
|
||||
loss = torch.nn.MSELoss(reduction="sum")
|
||||
return loss(logits, labels)
|
||||
|
||||
|
||||
def init_model_and_optim(
|
||||
num_layers, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape, init_optim=True, embedding=False
|
||||
):
|
||||
# pp model
|
||||
pp_model = _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, embedding=embedding)
|
||||
pp_model = pp_model.to(dtype)
|
||||
|
||||
# pp scheduler
|
||||
scheduler_hooks = [
|
||||
SchedulerMetricHook(skip=True),
|
||||
]
|
||||
|
||||
if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
|
||||
if num_chunks == 1:
|
||||
# noninterleaved pp
|
||||
scheduler = PipelineScheduler(
|
||||
data_process_func=None,
|
||||
num_microbatches=micro_num,
|
||||
dtype=dtype,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
else:
|
||||
# interleaved pp
|
||||
if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
|
||||
try:
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=micro_num,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=dtype,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=interleaved_overlap,
|
||||
)
|
||||
except AssertionError as e:
|
||||
print(f"AssertionError: {e}", flush=True)
|
||||
return None, None
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Error: AssertionError should occur when micro_num < Pipeline parrallel world size"
|
||||
)
|
||||
else:
|
||||
scheduler = InterleavedPipelineScheduler(
|
||||
num_microbatches=micro_num,
|
||||
num_chunks=gpc.config.model.num_chunks,
|
||||
dtype=dtype,
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=False,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
communication_overlap=interleaved_overlap,
|
||||
)
|
||||
else:
|
||||
scheduler = NonPipelineScheduler(
|
||||
data_process_func=None,
|
||||
gradient_accumulation_size=gpc.config.data.gradient_accumulation,
|
||||
scheduler_hooks=scheduler_hooks,
|
||||
)
|
||||
|
||||
# pp optimizer and engine
|
||||
if init_optim:
|
||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
|
||||
else:
|
||||
optimizer, beta2_scheduler, lr_scheduler = None, None, None
|
||||
|
||||
engine = Engine(
|
||||
model=pp_model,
|
||||
optimizer=optimizer,
|
||||
lr_scheduler=lr_scheduler,
|
||||
beta2_scheduler=beta2_scheduler,
|
||||
criterion=MyLoss().to(dtype),
|
||||
gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
|
||||
clip_grad_norm=0.0,
|
||||
)
|
||||
return engine, scheduler
|
||||
|
||||
|
||||
def build_environment(rank, world_size, config):
|
||||
import os
|
||||
|
||||
os.environ["RANK"] = str(rank)
|
||||
os.environ["LOCAL_RANK"] = str(rank)
|
||||
os.environ["WORLD_SIZE"] = str(world_size)
|
||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
||||
os.environ["MASTER_PORT"] = "33333"
|
||||
torch.cuda.empty_cache()
|
||||
# launcher="torch"
|
||||
internlm.launch_from_torch(config=config, seed=1024)
|
||||
|
||||
|
||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
||||
|
||||
if dtype is torch.float32:
|
||||
rtol = 1.3e-6
|
||||
atol = 1e-5
|
||||
elif dtype is torch.bfloat16:
|
||||
rtol = 2e-2
|
||||
atol = 2e-2
|
||||
|
||||
if isinstance(a, torch.Tensor):
|
||||
a = a.detach().to(dtype)
|
||||
b = b.detach().to(dtype)
|
||||
|
||||
assert_close(a, b, rtol=rtol, atol=atol)
|
||||
|
||||
|
||||
def seed_all(seed, cuda_deterministic=False):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if cuda_deterministic: # slower, more reproducible
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
else:
|
||||
torch.backends.cudnn.deterministic = False
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
|
||||
def _build_generic_model_1d(num_layers, num_chunks, embedding=False):
|
||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
||||
|
||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
||||
parts = all_parts[pipeline_rank]
|
||||
if gpc.is_rank_for_log():
|
||||
print(f"The layer sharding is {all_parts}.", flush=True)
|
||||
|
||||
models = []
|
||||
for start, end in parts:
|
||||
models.append(MlpModel(start, end, embedding=embedding).cuda())
|
||||
torch.distributed.barrier()
|
||||
if len(models) == 1:
|
||||
model = models[0]
|
||||
else:
|
||||
model = nn.ModuleList(models)
|
||||
|
||||
return model
|
|
@ -0,0 +1,143 @@
|
|||
import multiprocessing as mp
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from internlm.core.context import ParallelMode
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
# from internlm.core.context import ParallelMode
|
||||
from internlm.core.context.parallel_context import Config
|
||||
from internlm.core.trainer import TrainState
|
||||
from internlm.train import get_train_data_loader, load_new_batch
|
||||
|
||||
# from internlm.core.context.parallel_context import global_context as gpc
|
||||
from tests.test_core.utils import build_environment, init_model_and_optim
|
||||
|
||||
micro_bszs = [1, 2]
|
||||
use_flash_attens = [True, False]
|
||||
answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
|
||||
test_case_group = [
|
||||
# format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len
|
||||
# (1, "1 1 1", True, answers[0], 1, 8),
|
||||
(4, "1 1 4", True, answers[1], 1, 8),
|
||||
(4, None, True, answers[2], 1, 8),
|
||||
(8, "2 2 2", True, answers[3], 1, 8),
|
||||
(8, "2 2 2", True, answers[3], 2, 8),
|
||||
]
|
||||
|
||||
|
||||
def do_warmup(args):
|
||||
rank, worldsize, init_config, should_sccuess, answer = args
|
||||
build_environment(rank, worldsize, init_config)
|
||||
gpc.config.model.num_chunks = 1 if gpc.get_world_size(ParallelMode.PIPELINE) == 1 else 2
|
||||
engine, scheduler = init_model_and_optim(
|
||||
8,
|
||||
gpc.config.model.num_chunks,
|
||||
torch.bfloat16,
|
||||
init_config.data.micro_num,
|
||||
True,
|
||||
tensor_shape=[1, 8], # can't use get_tensor_shape becase we use toy model.
|
||||
init_optim=False,
|
||||
embedding=True,
|
||||
)
|
||||
scheduler.pre_processing(engine)
|
||||
engine.train()
|
||||
|
||||
try:
|
||||
train_dl, _ = get_train_data_loader(num_worker=0)
|
||||
except Exception as e:
|
||||
assert should_sccuess is False, f"{e}"
|
||||
else:
|
||||
assert should_sccuess is True
|
||||
|
||||
# initialize and resume train state
|
||||
train_state = TrainState(gpc.config, train_dl.batch_sampler)
|
||||
# transfer the train data loader into train data iterator
|
||||
train_iter = iter(train_dl)
|
||||
|
||||
micro_bsz = gpc.config.data.micro_bsz
|
||||
sql = gpc.config.data.seq_len
|
||||
|
||||
consumed_token = 0 # Token consumed
|
||||
packed_length = micro_bsz * sql
|
||||
for i in range(init_config.data.total_steps):
|
||||
batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
|
||||
input_shape = batch[0]["type_ids"].shape
|
||||
tokens_num = np.prod(input_shape)
|
||||
|
||||
if not init_config.model.use_flash_attn:
|
||||
if answer[i] > 1:
|
||||
assert input_shape == torch.Size(
|
||||
[answer[i], micro_bsz, sql]
|
||||
), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
|
||||
else:
|
||||
assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
|
||||
else:
|
||||
assert input_shape == torch.Size(
|
||||
[answer[i], packed_length]
|
||||
), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
|
||||
|
||||
if gpc.get_global_rank() == 0:
|
||||
print(
|
||||
f"iter:{i}",
|
||||
f"pp size: {gpc.get_world_size(ParallelMode.PIPELINE)}",
|
||||
f"use_flash_attn:{gpc.config.model.use_flash_attn}",
|
||||
f"micro_bsz:{micro_bsz}",
|
||||
f"input shape: {batch[0]['type_ids'].shape}",
|
||||
f"rampup_batch_size: {gpc.config.data.rampup_batch_size}",
|
||||
f"tokens_num: {tokens_num}",
|
||||
flush=True,
|
||||
)
|
||||
|
||||
consumed_token += tokens_num
|
||||
batch[0].pop("type_ids", None)
|
||||
batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
|
||||
|
||||
scheduler.forward_backward_step(engine, batch, forward_only=True, return_loss=False, return_output_label=False)
|
||||
assert (
|
||||
tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
|
||||
), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
|
||||
@pytest.mark.parametrize("group_case", test_case_group)
|
||||
@pytest.mark.parametrize("micro_bsz_case", micro_bszs)
|
||||
def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
|
||||
ctx = mp.get_context("spawn")
|
||||
# print(pp_size_case, use_flash_atten_case, group_case, micro_bsz_case, flush=True)
|
||||
|
||||
config = Config(
|
||||
dict(
|
||||
parallel=dict(
|
||||
zero1=dict(size=1, fsdp=False),
|
||||
pipeline=dict(size=1, interleaved_overlap=False),
|
||||
sequence_parallel=False,
|
||||
tensor=1,
|
||||
),
|
||||
data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
|
||||
model=dict(
|
||||
dtype=torch.bfloat16,
|
||||
),
|
||||
adam=dict(lr=1e-4),
|
||||
resume_tb_folder=None,
|
||||
tensorboard_folder=None,
|
||||
)
|
||||
)
|
||||
|
||||
config.data.seq_len = group_case[5]
|
||||
config.parallel.pipeline.size = group_case[4]
|
||||
config.model.use_flash_attn = use_flash_atten_case
|
||||
config.data.micro_bsz = micro_bsz_case
|
||||
config.data.micro_num = group_case[0]
|
||||
config.data.gradient_accumulation = config.data.micro_num
|
||||
config.data.rampup_batch_size = group_case[1]
|
||||
config.data.packed_length = micro_bsz_case * config.data.seq_len
|
||||
should_sccuess = group_case[2]
|
||||
answer = group_case[3]
|
||||
|
||||
with ctx.Pool(processes=8) as pool:
|
||||
pool.map(do_warmup, [[rank, 8, config, should_sccuess, answer] for rank in range(8)])
|
||||
pool.close()
|
||||
pool.join()
|
Loading…
Reference in New Issue