mirror of https://github.com/InternLM/InternLM
feat(train): support_rampup_batch_size and fix bugs (#493)
parent 4a6987d5e7
commit 0bfc86205e
@@ -59,7 +59,12 @@ data = dict(
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",
-    rampup_batch_size="",
+    # rampup_batch_size (str): A string with three space-separated integers representing the
+    #                          starting batch size, the increment, and the number of steps between
+    #                          each increment. For example, "192 24 8" means that the batch size (micro_num)
+    #                          starts at 192 and increases by 24 every 8 steps. Defaults to None.
+    #                          (IMPORTANT): The interval step size is 'micro_bsz'.
+    rampup_batch_size=None,
     # Datasets with less than 50 rows will be discarded
     min_length=50,
     # train_folder=TRAIN_FOLDER,
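For orientation, here is a minimal standalone sketch of one plausible reading of that three-integer string: start at the first value and add the second value every "interval" steps, capped at the configured micro_num. This is illustrative only, not InternLM's batch sampler, and the helper name rampup_micro_num is made up; under this reading it does reproduce the schedules asserted by the new test added at the end of this diff.

def rampup_micro_num(rampup_batch_size: str, target_micro_num: int, step: int) -> int:
    # "192 24 8" -> start at 192 micro batches, add 24 every 8 steps, never exceed the target.
    start, increment, interval = map(int, rampup_batch_size.split())
    return min(start + (step // interval) * increment, target_micro_num)

# e.g. "2 2 2" with a target micro_num of 8 gives 2, 2, 4, 4, 6, 6, 8, 8 over the first 8 steps.
print([rampup_micro_num("2 2 2", 8, s) for s in range(8)])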
@@ -156,6 +156,18 @@ class ParallelContext(metaclass=SingletonMeta):
     def config(self):
         return self._config

+    @property
+    def micro_bsz(self):
+        return self._config.data.micro_bsz
+
+    @property
+    def micro_num(self):
+        return self._config.data.micro_num
+
+    @property
+    def grad_accum_num(self):
+        return self._config.data.gradient_accumulation
+
     @property
     def expert_parallel_group_names(self):
         return self._expert_parallel_group_names
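These convenience properties expose the data-config fields directly on the global context. A small usage sketch, assuming an already initialized training context:

from internlm.core.context import global_context as gpc

# Schedulers and samplers can read these without reaching into gpc.config.data:
num_micro_steps = gpc.micro_num        # same as gpc.config.data.micro_num
grad_accum_steps = gpc.grad_accum_num  # same as gpc.config.data.gradient_accumulation
tokens_per_micro_batch = gpc.micro_bsz * gpc.config.data.seq_len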
@@ -36,10 +36,10 @@ class BaseScheduler(ABC):
         """
         pass

-    def _load_micro_batch(self, data, label, offset, micro_bsz):
+    def _load_micro_batch(self, data, label, offset):
         assert isinstance(data, dict) and isinstance(label, torch.Tensor)
-        micro_batch_data = {k: v[offset : offset + micro_bsz] for k, v in data.items()}
-        micro_batch_label = label[offset : offset + micro_bsz]
+        micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
+        micro_batch_label = label[offset : offset + 1]

         return micro_batch_data, micro_batch_label
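After this change each call to _load_micro_batch slices exactly one element along the batch dimension and the caller advances the offset by one; with packed samples the micro batch size is already folded into each element, so the extra micro_bsz argument is no longer needed. A toy illustration with made-up tensors:

import torch

# Two packed samples, each already holding micro_bsz * seq_len tokens (here 1 * 8).
data = {"input_ids": torch.arange(16).reshape(2, 8)}
label = torch.arange(16).reshape(2, 8)

offset = 0
micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
micro_batch_label = label[offset : offset + 1]
assert micro_batch_data["input_ids"].shape == (1, 8) and micro_batch_label.shape == (1, 8)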
@@ -10,10 +10,13 @@ import torch
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.utils.common import conditional_context
+from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

+logger = get_logger(__file__)
+

 class NonPipelineScheduler(BaseScheduler):
     """A helper schedule class for no pipeline parallelism running environment.
@@ -69,10 +72,8 @@ class NonPipelineScheduler(BaseScheduler):
             label (Any): The label to be loaded.
         """

-        _data, _label = self._load_micro_batch(
-            data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
-        )
-        self._grad_accum_offset += self._grad_accum_batch_size
+        _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset)
+        self._grad_accum_offset += 1

         if self.data_process_func:
             _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
@@ -135,7 +136,7 @@ class NonPipelineScheduler(BaseScheduler):
         self._call_hooks("after_backward", None)

         if not return_loss:
-            loss = None
+            loss, moe_loss = None, None

         return output, loss, moe_loss
@@ -166,12 +167,9 @@ class NonPipelineScheduler(BaseScheduler):
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

-        batch_data, batch_size = engine.load_batch(data_iter)
+        batch_data, actual_batch_size = engine.load_batch(data_iter)

-        assert (
-            batch_size % self._grad_accum_size == 0
-        ), f"batch_size:{batch_size} must be an integer multiple of gradient accumulation steps:{self._grad_accum_size}"
-        self._grad_accum_batch_size = batch_size // self._grad_accum_size
+        self._grad_accum_size = actual_batch_size  # Rampup or variable bsz size.

         data, label = batch_data
@@ -184,6 +182,7 @@ class NonPipelineScheduler(BaseScheduler):
         self._grad_accum_offset = 0

         for _current_accum_step in range(self._grad_accum_size):
+            if engine.optimizer is not None:
                 if _current_accum_step == self._grad_accum_size - 1:
                     engine.optimizer.skip_grad_reduce = False
                 else:
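The accumulation loop now guards the optimizer flag (the evaluation path reuses this scheduler with optimizer=None) and takes its step count from whatever batch size was actually loaded. A simplified, self-contained sketch of that control flow; this is not the scheduler itself, and run_step and train_fn are made-up names:

from types import SimpleNamespace

def run_step(engine, micro_batches, train_fn):
    # With batch-size ramp-up, the number of accumulation steps is simply the
    # number of micro-batches the loader returned for this training step.
    grad_accum_size = len(micro_batches)
    for step, micro_batch in enumerate(micro_batches):
        if engine.optimizer is not None:
            # Gradients are only reduced on the final accumulation step.
            engine.optimizer.skip_grad_reduce = step != grad_accum_size - 1
        train_fn(micro_batch)

engine = SimpleNamespace(optimizer=SimpleNamespace(skip_grad_reduce=True))
run_step(engine, micro_batches=[{"input_ids": [1]}, {"input_ids": [2]}], train_fn=lambda b: None)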
@@ -15,10 +15,13 @@ from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
 from internlm.utils.common import get_current_device, move_to_device
+from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout

 from .base_scheduler import BaseScheduler, SchedulerHook

+logger = get_logger(__file__)
+

 def get_tensor_shape():
     if hasattr(gpc.config, "TENSOR_SHAPE"):
@@ -184,17 +187,16 @@ class PipelineScheduler(BaseScheduler):

     def load_batch(self, engine, data_iter):
         # Pipeline schedule just puts data in memory
-        batch_data, batch_size = engine.load_batch(data_iter, to_gpu=False)
-        assert batch_size % self.num_microbatches == 0, "Batch size should divided by the number of microbatches"
+        batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)

+        self.num_microbatches = actual_batch_size  # Rampup or variable bsz size.
         self.microbatch_offset = 0
-        self.batch_size = batch_size
+        self.batch_size = actual_batch_size
         self.batch_data, self.batch_label = batch_data
-        self.microbatch_size = self.batch_size // self.num_microbatches

     def load_micro_batch(self):
         micro_batch_data, micro_batch_label = self._load_micro_batch(
-            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
+            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
         )
         if self.data_process_func:
             micro_batch_data["input_ids"] = self.data_process_func(
@@ -206,7 +208,7 @@ class PipelineScheduler(BaseScheduler):
             micro_batch_data.pop("indexes")

         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset += self.microbatch_size
+        self.microbatch_offset += 1

         return move_to_device(micro_batch_data)
@@ -785,10 +787,9 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             data=self.batch_data,
             label=self.batch_label,
             offset=self.microbatch_offset[model_chunk_id],
-            micro_bsz=self.microbatch_size,
         )
         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset[model_chunk_id] += self.microbatch_size
+        self.microbatch_offset[model_chunk_id] += 1
         return move_to_device(micro_batch_data)

     def _forward_step(self, engine, chunk_id):
@@ -109,9 +109,15 @@ def args_sanity_check():
     if "micro_num" not in data:
         data._add_item("micro_num", 1)

+    if "gradient_accumulation" not in data:
         data._add_item("gradient_accumulation", data.micro_num)
         if gpc.is_rank_for_log():
             logger.info(f"gradient_accumulation size will be setted to {data.micro_num}.")
+    else:
+        if pp == 1:
+            assert (
+                data.gradient_accumulation == data.micro_num
+            ), "for nopp 'gradient_accumulation' should equal with 'micro_num'"

     # batch_size should be equal with micro_num, should not use it directly
     data._add_item("batch_size", data.micro_num)
@@ -136,6 +142,11 @@ def args_sanity_check():

     if "diag_outlier_ratio" not in data:
         data._add_item("diag_outlier_ratio", 1.1)

+    if "rampup_batch_size" not in data or not data.rampup_batch_size or len(data.rampup_batch_size) == 0:
+        bsz = data.micro_num
+        data._add_item("rampup_batch_size", f"{bsz} {bsz} 1")
+
     data.diag_outlier_ratio = max(1, data.diag_outlier_ratio)

     if gpc.is_rank_for_log():
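When no ramp-up is configured, the sanity check falls back to "{micro_num} {micro_num} 1", which (under the same reading as the sketch near the top of this diff) is a constant schedule: every step already trains with the full micro_num.

micro_num = 4
rampup_batch_size = f"{micro_num} {micro_num} 1"  # default built above when nothing is configured
start, increment, interval = map(int, rampup_batch_size.split())
schedule = [min(start + (step // interval) * increment, micro_num) for step in range(8)]
assert schedule == [micro_num] * 8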
@@ -148,6 +159,7 @@ def args_sanity_check():
         logger.info(f"min_length: {data.min_length}")
         logger.info(f"valid_micro_num: {data.valid_micro_num}")
         logger.info(f"valid_every: {data.valid_every}")
+        logger.info(f"rampup_batch_size: {data.rampup_batch_size}")

     # processing the checkpoint config
     ckpt = gpc.config.ckpt
@@ -11,22 +11,19 @@ from internlm.model.metrics import AccPerplex


 @contextmanager
-def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, grad_accum_batch_size, metric_hook_list):
+def switch_evaluation_no_pipeline_scheduler(trainer, grad_accum_size, metric_hook_list):
     if not gpc.is_using_pp():
         prev_data_process_func = trainer.schedule.data_process_func
         prev_grad_accum_size = trainer.schedule._grad_accum_size
-        prev_grad_accum_batch_size = trainer.schedule._grad_accum_batch_size
         prev_metric_hooks = trainer.schedule._hooks
         try:
             trainer.schedule.data_process_func = None
             trainer.schedule._grad_accum_size = grad_accum_size
-            trainer.schedule._grad_accum_batch_size = grad_accum_batch_size
             trainer.schedule._hooks = metric_hook_list
             yield
         finally:
             trainer.schedule.data_process_func = prev_data_process_func
             trainer.schedule._grad_accum_size = prev_grad_accum_size
-            trainer.schedule._grad_accum_batch_size = prev_grad_accum_batch_size
             trainer.schedule._hooks = prev_metric_hooks
@@ -126,11 +123,9 @@ def evaluate_on_val_dls(
                 total_val_bsz = len(batch[1])
                 assert total_val_bsz % data_cfg.micro_bsz == 0
                 grad_accum_size = total_val_bsz // data_cfg.micro_bsz
-                grad_accum_batch_size = data_cfg.micro_bsz
                 with switch_evaluation_no_pipeline_scheduler(
                     trainer=trainer,
                     grad_accum_size=grad_accum_size,
-                    grad_accum_batch_size=grad_accum_batch_size,
                     metric_hook_list=[val_sche_metric_hook],
                 ):
                     if hasattr(gpc.config.model, "num_experts"):
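A quick arithmetic check of the evaluation path above: the validation scheduler now only needs the number of accumulation micro-steps, since each micro-step consumes one packed element. With illustrative numbers (not taken from the repo):

total_val_bsz, micro_bsz = 16, 2   # hypothetical validation batch and micro batch size
assert total_val_bsz % micro_bsz == 0
grad_accum_size = total_val_bsz // micro_bsz
assert grad_accum_size == 8        # 8 micro-steps, no separate per-micro-batch size to patch in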
@@ -1,59 +1,19 @@
 import multiprocessing as mp
-import random

-import numpy as np
 import pytest
 import torch
-from torch import nn
-from torch.testing import assert_close

-import internlm
 from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.context.parallel_context import Config
-from internlm.core.engine import Engine
-from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
-from internlm.core.scheduler import (
-    InterleavedPipelineScheduler,
-    PipelineScheduler,
-    SchedulerMetricHook,
+from tests.test_core.utils import (
+    MlpModel,
+    MyLoss,
+    build_environment,
+    init_model_and_optim,
+    loose_close,
+    seed_all,
 )
-from internlm.solver.pipeline_utils import partition_uniform
-from internlm.train import initialize_optimizer
-
-
-class MlpModel(nn.Module):
-    """
-    Custom model
-    """
-
-    def __init__(self, start, end, model_type=None):
-        super().__init__()
-        self.part = [start, end]
-        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
-        self.model_type = model_type
-
-    def forward(self, hidden_states=None, input_ids=None):
-        if self.model_type != "torch" and self.part[0] != 0:
-            input_ids = hidden_states
-
-        for i in range(self.part[1] - self.part[0]):
-            input_ids = self.blocks[i](input_ids)
-        return input_ids
-
-
-class MyLoss(nn.Module):
-    """
-    Custom loss
-    """
-
-    def __init__(self):
-        super().__init__()
-
-    def forward(self, logits, labels):
-        loss = torch.nn.MSELoss(reduction="sum")
-        return loss(logits, labels)
-

 config = Config(
     dict(
|
@ -116,71 +76,6 @@ config = Config(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def build_environment(rank, world_size):
|
|
||||||
import os
|
|
||||||
|
|
||||||
os.environ["RANK"] = str(rank)
|
|
||||||
os.environ["LOCAL_RANK"] = str(rank)
|
|
||||||
os.environ["WORLD_SIZE"] = str(world_size)
|
|
||||||
os.environ["MASTER_ADDR"] = "127.0.0.1"
|
|
||||||
os.environ["MASTER_PORT"] = "33333"
|
|
||||||
torch.cuda.empty_cache()
|
|
||||||
# launcher="torch"
|
|
||||||
internlm.launch_from_torch(config=config, seed=1024)
|
|
||||||
|
|
||||||
|
|
||||||
def loose_close(a, b, dtype: torch.dtype = torch.float32):
|
|
||||||
|
|
||||||
if dtype is torch.float32:
|
|
||||||
rtol = 1.3e-6
|
|
||||||
atol = 1e-5
|
|
||||||
elif dtype is torch.bfloat16:
|
|
||||||
rtol = 2e-2
|
|
||||||
atol = 2e-2
|
|
||||||
|
|
||||||
if isinstance(a, torch.Tensor):
|
|
||||||
a = a.detach().to(dtype)
|
|
||||||
b = b.detach().to(dtype)
|
|
||||||
|
|
||||||
assert_close(a, b, rtol=rtol, atol=atol)
|
|
||||||
|
|
||||||
|
|
||||||
def seed_all(seed, cuda_deterministic=False):
|
|
||||||
random.seed(seed)
|
|
||||||
np.random.seed(seed)
|
|
||||||
torch.manual_seed(seed)
|
|
||||||
if torch.cuda.is_available():
|
|
||||||
torch.cuda.manual_seed(seed)
|
|
||||||
torch.cuda.manual_seed_all(seed)
|
|
||||||
if cuda_deterministic: # slower, more reproducible
|
|
||||||
torch.backends.cudnn.deterministic = True
|
|
||||||
torch.backends.cudnn.benchmark = False
|
|
||||||
else:
|
|
||||||
torch.backends.cudnn.deterministic = False
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
|
|
||||||
|
|
||||||
def _build_generic_model_1d(num_layers, num_chunks):
|
|
||||||
pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
|
|
||||||
pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
|
|
||||||
|
|
||||||
all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
|
|
||||||
parts = all_parts[pipeline_rank]
|
|
||||||
if gpc.is_rank_for_log():
|
|
||||||
print(f"The layer sharding is {all_parts}.", flush=True)
|
|
||||||
|
|
||||||
models = []
|
|
||||||
for start, end in parts:
|
|
||||||
models.append(MlpModel(start, end).cuda())
|
|
||||||
torch.distributed.barrier()
|
|
||||||
if len(models) == 1:
|
|
||||||
model = models[0]
|
|
||||||
else:
|
|
||||||
model = nn.ModuleList(models)
|
|
||||||
|
|
||||||
return model
|
|
||||||
|
|
||||||
|
|
||||||
def exam_pipeline_parallel(args):
|
def exam_pipeline_parallel(args):
|
||||||
# init
|
# init
|
||||||
rank, world_size, micro_num, num_chunks, interleaved_overlap = args
|
rank, world_size, micro_num, num_chunks, interleaved_overlap = args
|
||||||
|
@ -188,76 +83,18 @@ def exam_pipeline_parallel(args):
|
||||||
config.model.num_chunks = num_chunks
|
config.model.num_chunks = num_chunks
|
||||||
config.parallel.pipeline.interleaved_overlap = interleaved_overlap
|
config.parallel.pipeline.interleaved_overlap = interleaved_overlap
|
||||||
|
|
||||||
build_environment(rank, world_size)
|
build_environment(rank, world_size, config)
|
||||||
|
|
||||||
device = torch.device(f"cuda:{rank}")
|
device = torch.device(f"cuda:{rank}")
|
||||||
dtype = config.model["dtype"]
|
dtype = config.model["dtype"]
|
||||||
|
seq_len = gpc.config.data.seq_len
|
||||||
|
|
||||||
# set seed
|
# set seed
|
||||||
seed_all(1024)
|
seed_all(1024)
|
||||||
|
|
||||||
# pp model
|
engine, scheduler = init_model_and_optim(32, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape=[1, 8])
|
||||||
pp_model = _build_generic_model_1d(num_layers=32, num_chunks=num_chunks)
|
if scheduler is None:
|
||||||
pp_model = pp_model.to(dtype)
|
|
||||||
|
|
||||||
# pp scheduler
|
|
||||||
scheduler_hooks = [
|
|
||||||
SchedulerMetricHook(skip=True),
|
|
||||||
]
|
|
||||||
|
|
||||||
seq_len = gpc.config.data.seq_len
|
|
||||||
gpc.config.NUM_MICRO_BATCHES = micro_num
|
|
||||||
communication_overlap = interleaved_overlap
|
|
||||||
|
|
||||||
if num_chunks == 1:
|
|
||||||
# noninterleaved pp
|
|
||||||
scheduler = PipelineScheduler(
|
|
||||||
data_process_func=None,
|
|
||||||
num_microbatches=micro_num,
|
|
||||||
dtype=dtype,
|
|
||||||
tensor_shape=[1, 8],
|
|
||||||
scatter_gather_tensors=False,
|
|
||||||
scheduler_hooks=scheduler_hooks,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# interleaved pp
|
|
||||||
if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
|
|
||||||
try:
|
|
||||||
scheduler = InterleavedPipelineScheduler(
|
|
||||||
num_microbatches=micro_num,
|
|
||||||
num_chunks=gpc.config.model.num_chunks,
|
|
||||||
dtype=dtype,
|
|
||||||
tensor_shape=[1, 8],
|
|
||||||
scatter_gather_tensors=False,
|
|
||||||
scheduler_hooks=scheduler_hooks,
|
|
||||||
communication_overlap=communication_overlap,
|
|
||||||
)
|
|
||||||
except AssertionError:
|
|
||||||
return
|
return
|
||||||
else:
|
|
||||||
raise RuntimeError("Error: AssertionError should occur when micro_num < Pipeline parrallel world size")
|
|
||||||
else:
|
|
||||||
scheduler = InterleavedPipelineScheduler(
|
|
||||||
num_microbatches=micro_num,
|
|
||||||
num_chunks=gpc.config.model.num_chunks,
|
|
||||||
dtype=dtype,
|
|
||||||
tensor_shape=[1, 8],
|
|
||||||
scatter_gather_tensors=False,
|
|
||||||
scheduler_hooks=scheduler_hooks,
|
|
||||||
communication_overlap=communication_overlap,
|
|
||||||
)
|
|
||||||
|
|
||||||
# pp optimizer and engine
|
|
||||||
optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
|
|
||||||
engine = Engine(
|
|
||||||
model=pp_model,
|
|
||||||
optimizer=optimizer,
|
|
||||||
lr_scheduler=lr_scheduler,
|
|
||||||
beta2_scheduler=beta2_scheduler,
|
|
||||||
criterion=MyLoss().to(dtype),
|
|
||||||
gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
|
|
||||||
clip_grad_norm=gpc.config.hybrid_zero_optimizer.get("clip_grad_norm", 0.0),
|
|
||||||
)
|
|
||||||
|
|
||||||
scheduler.pre_processing(engine)
|
scheduler.pre_processing(engine)
|
||||||
engine.train()
|
engine.train()
|
||||||
|
|
|
@@ -0,0 +1,207 @@
+import random
+
+import numpy as np
+import torch
+from torch import nn
+from torch.testing import assert_close
+
+import internlm
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.engine import Engine
+from internlm.core.gradient_handler import PipelineSharedModuleGradientHandler
+from internlm.core.scheduler import (
+    InterleavedPipelineScheduler,
+    NonPipelineScheduler,
+    PipelineScheduler,
+    SchedulerMetricHook,
+)
+from internlm.solver.pipeline_utils import partition_uniform
+from internlm.train import initialize_optimizer
+
+
+class MlpModel(nn.Module):
+    """
+    Custom model
+    """
+
+    def __init__(self, start, end, model_type=None, embedding=False):
+        super().__init__()
+        self.part = [start, end]
+        self.blocks = nn.ModuleList([nn.Linear(8, 8, bias=False) for lid in range(end - start)])
+        self.model_type = model_type
+        self.embedding = embedding
+
+    def forward(
+        self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None
+    ):  # pylint: disable=W0613
+        if self.model_type != "torch" and self.part[0] != 0:
+            input_ids = hidden_states
+
+        # Simulate Embedding.
+        if self.embedding:
+            if len(input_ids.shape) == 2:
+                input_ids = input_ids.view(-1, 8)
+            elif len(input_ids.shape) == 3:
+                input_ids = input_ids.view(input_ids.shape[0], -1, 8)
+
+        for i in range(self.part[1] - self.part[0]):
+            input_ids = self.blocks[i](input_ids)
+
+        return input_ids
+
+
+class MyLoss(nn.Module):
+    """
+    Custom loss
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, logits, labels):
+        loss = torch.nn.MSELoss(reduction="sum")
+        return loss(logits, labels)
+
+
+def init_model_and_optim(
+    num_layers, num_chunks, dtype, micro_num, interleaved_overlap, tensor_shape, init_optim=True, embedding=False
+):
+    # pp model
+    pp_model = _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, embedding=embedding)
+    pp_model = pp_model.to(dtype)
+
+    # pp scheduler
+    scheduler_hooks = [
+        SchedulerMetricHook(skip=True),
+    ]
+
+    if gpc.get_world_size(ParallelMode.PIPELINE) > 1:
+        if num_chunks == 1:
+            # noninterleaved pp
+            scheduler = PipelineScheduler(
+                data_process_func=None,
+                num_microbatches=micro_num,
+                dtype=dtype,
+                tensor_shape=tensor_shape,
+                scatter_gather_tensors=False,
+                scheduler_hooks=scheduler_hooks,
+            )
+        else:
+            # interleaved pp
+            if micro_num < gpc.get_world_size(ParallelMode.PIPELINE):
+                try:
+                    scheduler = InterleavedPipelineScheduler(
+                        num_microbatches=micro_num,
+                        num_chunks=gpc.config.model.num_chunks,
+                        dtype=dtype,
+                        tensor_shape=tensor_shape,
+                        scatter_gather_tensors=False,
+                        scheduler_hooks=scheduler_hooks,
+                        communication_overlap=interleaved_overlap,
+                    )
+                except AssertionError as e:
+                    print(f"AssertionError: {e}", flush=True)
+                    return None, None
+                else:
+                    raise RuntimeError(
+                        "Error: AssertionError should occur when micro_num < Pipeline parallel world size"
+                    )
+            else:
+                scheduler = InterleavedPipelineScheduler(
+                    num_microbatches=micro_num,
+                    num_chunks=gpc.config.model.num_chunks,
+                    dtype=dtype,
+                    tensor_shape=tensor_shape,
+                    scatter_gather_tensors=False,
+                    scheduler_hooks=scheduler_hooks,
+                    communication_overlap=interleaved_overlap,
+                )
+    else:
+        scheduler = NonPipelineScheduler(
+            data_process_func=None,
+            gradient_accumulation_size=gpc.config.data.gradient_accumulation,
+            scheduler_hooks=scheduler_hooks,
+        )
+
+    # pp optimizer and engine
+    if init_optim:
+        optimizer, beta2_scheduler, lr_scheduler = initialize_optimizer(model=pp_model)
+    else:
+        optimizer, beta2_scheduler, lr_scheduler = None, None, None
+
+    engine = Engine(
+        model=pp_model,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        beta2_scheduler=beta2_scheduler,
+        criterion=MyLoss().to(dtype),
+        gradient_handlers=[PipelineSharedModuleGradientHandler(model=pp_model, optimizer=optimizer)],
+        clip_grad_norm=0.0,
+    )
+    return engine, scheduler
+
+
+def build_environment(rank, world_size, config):
+    import os
+
+    os.environ["RANK"] = str(rank)
+    os.environ["LOCAL_RANK"] = str(rank)
+    os.environ["WORLD_SIZE"] = str(world_size)
+    os.environ["MASTER_ADDR"] = "127.0.0.1"
+    os.environ["MASTER_PORT"] = "33333"
+    torch.cuda.empty_cache()
+    # launcher="torch"
+    internlm.launch_from_torch(config=config, seed=1024)
+
+
+def loose_close(a, b, dtype: torch.dtype = torch.float32):
+
+    if dtype is torch.float32:
+        rtol = 1.3e-6
+        atol = 1e-5
+    elif dtype is torch.bfloat16:
+        rtol = 2e-2
+        atol = 2e-2
+
+    if isinstance(a, torch.Tensor):
+        a = a.detach().to(dtype)
+        b = b.detach().to(dtype)
+
+    assert_close(a, b, rtol=rtol, atol=atol)
+
+
+def seed_all(seed, cuda_deterministic=False):
+    random.seed(seed)
+    np.random.seed(seed)
+    torch.manual_seed(seed)
+    if torch.cuda.is_available():
+        torch.cuda.manual_seed(seed)
+        torch.cuda.manual_seed_all(seed)
+    if cuda_deterministic:  # slower, more reproducible
+        torch.backends.cudnn.deterministic = True
+        torch.backends.cudnn.benchmark = False
+    else:
+        torch.backends.cudnn.deterministic = False
+        torch.backends.cudnn.benchmark = True
+
+
+def _build_generic_model_1d(num_layers, num_chunks, embedding=False):
+    pipeline_size = gpc.get_world_size(ParallelMode.PIPELINE)
+    pipeline_rank = gpc.get_local_rank(ParallelMode.PIPELINE)
+
+    all_parts = partition_uniform(num_layers, pipeline_size, num_chunks)
+    parts = all_parts[pipeline_rank]
+    if gpc.is_rank_for_log():
+        print(f"The layer sharding is {all_parts}.", flush=True)
+
+    models = []
+    for start, end in parts:
+        models.append(MlpModel(start, end, embedding=embedding).cuda())
+    torch.distributed.barrier()
+    if len(models) == 1:
+        model = models[0]
+    else:
+        model = nn.ModuleList(models)
+
+    return model
@@ -0,0 +1,143 @@
+import multiprocessing as mp
+
+import numpy as np
+import pytest
+import torch
+
+from internlm.core.context import ParallelMode
+from internlm.core.context import global_context as gpc
+from internlm.core.context.parallel_context import Config
+from internlm.core.trainer import TrainState
+from internlm.train import get_train_data_loader, load_new_batch
+
+from tests.test_core.utils import build_environment, init_model_and_optim
+
+micro_bszs = [1, 2]
+use_flash_attens = [True, False]
+answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
+test_case_group = [
+    # format: micro_num, rampup_batch_size, should_success, answer, pp size, seq len
+    # (1, "1 1 1", True, answers[0], 1, 8),
+    (4, "1 1 4", True, answers[1], 1, 8),
+    (4, None, True, answers[2], 1, 8),
+    (8, "2 2 2", True, answers[3], 1, 8),
+    (8, "2 2 2", True, answers[3], 2, 8),
+]
+
+
+def do_warmup(args):
+    rank, worldsize, init_config, should_success, answer = args
+    build_environment(rank, worldsize, init_config)
+    gpc.config.model.num_chunks = 1 if gpc.get_world_size(ParallelMode.PIPELINE) == 1 else 2
+    engine, scheduler = init_model_and_optim(
+        8,
+        gpc.config.model.num_chunks,
+        torch.bfloat16,
+        init_config.data.micro_num,
+        True,
+        tensor_shape=[1, 8],  # can't use get_tensor_shape because we use a toy model.
+        init_optim=False,
+        embedding=True,
+    )
+    scheduler.pre_processing(engine)
+    engine.train()
+
+    try:
+        train_dl, _ = get_train_data_loader(num_worker=0)
+    except Exception as e:
+        assert should_success is False, f"{e}"
+    else:
+        assert should_success is True
+
+    # initialize and resume train state
+    train_state = TrainState(gpc.config, train_dl.batch_sampler)
+    # transfer the train data loader into train data iterator
+    train_iter = iter(train_dl)
+
+    micro_bsz = gpc.config.data.micro_bsz
+    sql = gpc.config.data.seq_len
+
+    consumed_token = 0  # tokens consumed so far
+    packed_length = micro_bsz * sql
+    for i in range(init_config.data.total_steps):
+        batch, train_iter = load_new_batch(train_dl=train_dl, train_iter=train_iter, train_state=train_state)
+        input_shape = batch[0]["type_ids"].shape
+        tokens_num = np.prod(input_shape)
+
+        if not init_config.model.use_flash_attn:
+            if answer[i] > 1:
+                assert input_shape == torch.Size(
+                    [answer[i], micro_bsz, sql]
+                ), f"iter:{i}, {input_shape} != {[answer[i], micro_bsz, sql]}"
+            else:
+                assert input_shape == torch.Size([micro_bsz, sql]), f"iter:{i}, {input_shape} != {[micro_bsz, sql]}"
+        else:
+            assert input_shape == torch.Size(
+                [answer[i], packed_length]
+            ), f"iter:{i}, {input_shape} != {torch.Size([answer[i], packed_length])}"
+
+        if gpc.get_global_rank() == 0:
+            print(
+                f"iter:{i}",
+                f"pp size: {gpc.get_world_size(ParallelMode.PIPELINE)}",
+                f"use_flash_attn:{gpc.config.model.use_flash_attn}",
+                f"micro_bsz:{micro_bsz}",
+                f"input shape: {batch[0]['type_ids'].shape}",
+                f"rampup_batch_size: {gpc.config.data.rampup_batch_size}",
+                f"tokens_num: {tokens_num}",
+                flush=True,
+            )
+
+        consumed_token += tokens_num
+        batch[0].pop("type_ids", None)
+        batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
+
+        scheduler.forward_backward_step(engine, batch, forward_only=True, return_loss=False, return_output_label=False)
+        assert (
+            tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
+        ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"
+
+
+@pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
+@pytest.mark.parametrize("group_case", test_case_group)
+@pytest.mark.parametrize("micro_bsz_case", micro_bszs)
+def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
+    ctx = mp.get_context("spawn")
+
+    config = Config(
+        dict(
+            parallel=dict(
+                zero1=dict(size=1, fsdp=False),
+                pipeline=dict(size=1, interleaved_overlap=False),
+                sequence_parallel=False,
+                tensor=1,
+            ),
+            data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
+            model=dict(
+                dtype=torch.bfloat16,
+            ),
+            adam=dict(lr=1e-4),
+            resume_tb_folder=None,
+            tensorboard_folder=None,
+        )
+    )
+
+    config.data.seq_len = group_case[5]
+    config.parallel.pipeline.size = group_case[4]
+    config.model.use_flash_attn = use_flash_atten_case
+    config.data.micro_bsz = micro_bsz_case
+    config.data.micro_num = group_case[0]
+    config.data.gradient_accumulation = config.data.micro_num
+    config.data.rampup_batch_size = group_case[1]
+    config.data.packed_length = micro_bsz_case * config.data.seq_len
+    should_success = group_case[2]
+    answer = group_case[3]
+
+    with ctx.Pool(processes=8) as pool:
+        pool.map(do_warmup, [[rank, 8, config, should_success, answer] for rank in range(8)])
+        pool.close()
+        pool.join()