mirror of https://github.com/InternLM/InternLM

Merge branch 'develop' of https://github.com/InternLM/InternLM into hf_llama
commit d7555e8216
@@ -145,18 +145,18 @@ model = dict(
     moe_use_residual=False,
     moe_gate_k=2,
 )
-"""
-zero1 parallel:
-    1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
-        so parameters will be divided within the range of dp.
-    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
-    3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
-    For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel (dict):
-    1. size: int, the size of pipeline parallel.
-    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
-tensor parallel: tensor parallel size, usually the number of GPUs per node.
-"""
+
+# zero1 parallel:
+#     1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
+#         so parameters will be divided within the range of dp.
+#     2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+#     3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
+#     For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+# pipeline parallel (dict):
+#     1. size: int, the size of pipeline parallel.
+#     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
+# tensor parallel: tensor parallel size, usually the number of GPUs per node.
+
 parallel = dict(
     zero1=dict(size=-1, fsdp=False),
     tensor=1,
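The zero1 rules restated in the comments above can be summarised as a tiny sizing function. The sketch below is illustrative only (the helper name and the dp_world_size argument are assumptions, not InternLM APIs); it just encodes the three cases arithmetically.

def resolve_zero1_size(zero1: int, dp_world_size: int) -> int:
    """Hypothetical helper: effective ZeRO-1 group size for a given dp world size."""
    if zero1 <= 0:
        return dp_world_size  # shard parameters across the whole dp group
    if zero1 == 1:
        return 1              # ZeRO disabled: every dp rank keeps the full parameters
    assert zero1 <= dp_world_size, "zero1 must be a subset of the dp world size"
    return zero1              # e.g. zero1=8 keeps the sharding within one node


if __name__ == "__main__":
    for z in (-1, 1, 8):
        print(f"zero1={z} -> group size {resolve_zero1_size(z, dp_world_size=32)}")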
@@ -177,3 +177,7 @@ monitor = dict(
 )

 model_type = "INTERNLM_MoE"
+
+# metric_dtype can be "fp32" or other string
+# only when set to "fp32" will use fp32 to calc in metrics
+# metric_dtype = "fp32"

@@ -146,18 +146,18 @@ model = dict(
     use_flash_attn=True,
     num_chunks=1,  # if num_chunks > 1, interleaved pipeline scheduler is used.
 )
-"""
-zero1 parallel:
-    1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
-        so parameters will be divided within the range of dp.
-    2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
-    3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
-    For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
-pipeline parallel (dict):
-    1. size: int, the size of pipeline parallel.
-    2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
-tensor parallel: tensor parallel size, usually the number of GPUs per node.
-"""
+
+# zero1 parallel:
+#     1. if zero1 <= 0, The size of the zero process group is equal to the size of the dp process group,
+#         so parameters will be divided within the range of dp.
+#     2. if zero1 == 1, zero is not used, and all dp groups retain the full amount of model parameters.
+#     3. zero1 > 1 and zero1 <= dp world size, the world size of zero is a subset of dp world size.
+#     For smaller models, it is usually a better choice to split the parameters within nodes with a setting <= 8.
+# pipeline parallel (dict):
+#     1. size: int, the size of pipeline parallel.
+#     2. interleaved_overlap: bool, enable/disable communication overlap when using interleaved pipeline scheduler.
+# tensor parallel: tensor parallel size, usually the number of GPUs per node.
+
 parallel = dict(
     zero1=dict(size=8, fsdp=False),
     tensor=1,
@@ -177,3 +177,7 @@ monitor = dict(
         alert_file_path=f"llm_alter/{JOB_NAME}_alert.log",
     ),
 )
+
+# metric_dtype can be "fp32" or other string
+# only when set to "fp32" will use fp32 to calc in metrics
+# metric_dtype = "fp32"
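The commented-out metric_dtype option above is read at runtime roughly the way the AccPerplex change later in this diff does. A minimal sketch of that lookup, with a plain dict standing in for gpc.config:

import torch

config = {"metric_dtype": "fp32"}  # stand-in for gpc.config; leave the key out to keep the default
metric_dtype = torch.float if config.get("metric_dtype", None) == "fp32" else None
print(metric_dtype)  # torch.float32 when "fp32" is set, otherwise None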
@@ -185,6 +185,11 @@ class Engine:

         if to_gpu:
             batch_data = move_to_device(batch_data)
+
+        # For packed-dataset, batch_data is (micro_num, micro_bsz*seq_len),
+        # therefore 'batch_size' is equal to 'micro_num'
+        # For nopacked-dataset, batch_data is (micro_num*micro_bsz, seq_len),
+        # therefore 'batch_size' is equal to 'micro_num*micro_bsz'
         batch_size = get_batch_size(batch_data)

         return batch_data, batch_size
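The new comments distinguish packed batches of shape (micro_num, micro_bsz*seq_len) from non-packed batches of shape (micro_num*micro_bsz, seq_len). A small numeric sketch with assumed toy sizes, using the same first-dimension rule as the get_batch_size() shown further down in this diff:

import torch

# Assumed toy sizes, for illustration only.
micro_num, micro_bsz, seq_len = 4, 2, 8

packed = {"input_ids": torch.zeros(micro_num, micro_bsz * seq_len, dtype=torch.long)}
nopacked = {"input_ids": torch.zeros(micro_num * micro_bsz, seq_len, dtype=torch.long)}

def batch_size(data):
    # simplified stand-in for get_batch_size(): first dim of the first tensor in the dict
    return data[list(data.keys())[0]].size(0)

print(batch_size(packed))    # 4 -> micro_num
print(batch_size(nopacked))  # 8 -> micro_num * micro_bsz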
@@ -4,7 +4,7 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

 from abc import ABC, abstractmethod
-from typing import Any, Callable, Iterable, Optional
+from typing import Any, Callable, Dict, Iterable, Optional

 import torch

@@ -36,10 +36,18 @@ class BaseScheduler(ABC):
         """
         pass

-    def _load_micro_batch(self, data, label, offset):
+    def _load_micro_batch(self, data: Dict, label: torch.Tensor, offset: int, bsz_stride: int):
+        """
+        For pp, it will cut one fully batch into micro batch in pipeline concept.
+        For nopp, it will cut one fully batch into small batch based on gradient accumulate size.
+
+        A special case is that pp uses a 'non-packed-dateset' (such as evaluation dataset),
+        so the data of batch is unpacked and 'bsz_stride' is equal to 'micro_bsz'.
+        In all other cases 'bsz_stride' should be equal to 1.
+        """
         assert isinstance(data, dict) and isinstance(label, torch.Tensor)
-        micro_batch_data = {k: v[offset : offset + 1] for k, v in data.items()}
-        micro_batch_label = label[offset : offset + 1]
+        micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
+        micro_batch_label = label[offset : offset + bsz_stride]

         return micro_batch_data, micro_batch_label

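A standalone sketch of the new slicing behaviour with assumed toy tensors: bsz_stride=1 reproduces the old one-row-at-a-time behaviour, while a larger stride pulls micro_bsz rows per call for a non-packed evaluation batch.

import torch

def load_micro_batch(data, label, offset, bsz_stride):
    # same slicing as the patched BaseScheduler._load_micro_batch
    micro_batch_data = {k: v[offset : offset + bsz_stride] for k, v in data.items()}
    micro_batch_label = label[offset : offset + bsz_stride]
    return micro_batch_data, micro_batch_label

data = {"input_ids": torch.arange(16).reshape(8, 2)}  # assumed non-packed batch of 8 samples
label = torch.arange(8)

mb, lb = load_micro_batch(data, label, offset=0, bsz_stride=4)  # micro_bsz == 4
print(mb["input_ids"].shape, lb.shape)  # torch.Size([4, 2]) torch.Size([4])
mb, lb = load_micro_batch(data, label, offset=0, bsz_stride=1)  # packed / grad-accumulation case
print(mb["input_ids"].shape, lb.shape)  # torch.Size([1, 2]) torch.Size([1])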
@@ -72,7 +72,7 @@ class NonPipelineScheduler(BaseScheduler):
             label (Any): The label to be loaded.
         """

-        _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset)
+        _data, _label = self._load_micro_batch(data=data, label=label, offset=self._grad_accum_offset, bsz_stride=1)
         self._grad_accum_offset += 1

         if self.data_process_func:

@@ -167,7 +167,7 @@ class NonPipelineScheduler(BaseScheduler):
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

-        batch_data, actual_batch_size = engine.load_batch(data_iter)
+        batch_data, actual_batch_size = engine.load_batch(data_iter)  # actual_batch_size is micro_num

         self._grad_accum_size = actual_batch_size  # Rampup or variable bsz size.

@@ -14,7 +14,11 @@ from internlm.core.context import ParallelMode
 from internlm.core.context import global_context as gpc
 from internlm.core.engine import Engine
 from internlm.core.naive_amp import NaiveAMPModel
-from internlm.utils.common import get_current_device, move_to_device
+from internlm.utils.common import (
+    check_data_is_packed,
+    get_current_device,
+    move_to_device,
+)
 from internlm.utils.logger import get_logger
 from internlm.utils.timeout import llm_timeout

@@ -186,17 +190,28 @@ class PipelineScheduler(BaseScheduler):
         raise TypeError(f"Expected data to be of type torch.Tensor, list, tuple, or dict, but got {type(data)}")

     def load_batch(self, engine, data_iter):
-        # Pipeline schedule just puts data in memory
+        # Pipeline schedule just puts data in memory,
         batch_data, actual_batch_size = engine.load_batch(data_iter, to_gpu=False)

-        self.num_microbatches = actual_batch_size  # Rampup or variable bsz size.
+        # Even if 'use_flash_attn' is False, the data seen when the 'load_batch' is called is still packed,
+        # because internlm's current train dataset is packed, even using dummy data.
+        # The unpack operation is performed in load_micro_batch().
+        if check_data_is_packed(batch_data):
+            micro_num = actual_batch_size
+        else:
+            micro_num = actual_batch_size // gpc.config.data["micro_bsz"]
+
         self.microbatch_offset = 0
         self.batch_size = actual_batch_size
         self.batch_data, self.batch_label = batch_data
+        self.bsz_stride = self.batch_size // micro_num
+        # 'num_microbatches' is no longer an initialization parameter,
+        # but is determined on the fly by the Scheduler.
+        self.num_microbatches = micro_num  # Rampup or variable bsz size.

     def load_micro_batch(self):
         micro_batch_data, micro_batch_label = self._load_micro_batch(
-            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset
+            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, bsz_stride=self.bsz_stride
         )
         if self.data_process_func:
             micro_batch_data["input_ids"] = self.data_process_func(

@@ -208,7 +223,7 @@ class PipelineScheduler(BaseScheduler):
             micro_batch_data.pop("indexes")

         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset += 1
+        self.microbatch_offset += self.bsz_stride

         return move_to_device(micro_batch_data)

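A worked example (values assumed) of how micro_num and bsz_stride fall out of actual_batch_size in the patched load_batch: for packed data the batch dimension already equals micro_num, for non-packed data it is micro_num * micro_bsz.

micro_bsz = 4  # assumed

def derive_microbatching(actual_batch_size, is_packed):
    # mirrors the patched PipelineScheduler.load_batch
    micro_num = actual_batch_size if is_packed else actual_batch_size // micro_bsz
    bsz_stride = actual_batch_size // micro_num  # 1 for packed data, micro_bsz otherwise
    return micro_num, bsz_stride

print(derive_microbatching(8, is_packed=True))   # (8, 1): 8 micro-batches, one packed row each
print(derive_microbatching(8, is_packed=False))  # (2, 4): 2 micro-batches of micro_bsz rows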
@@ -787,9 +802,10 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             data=self.batch_data,
             label=self.batch_label,
             offset=self.microbatch_offset[model_chunk_id],
+            bsz_stride=self.bsz_stride,
         )
         micro_batch_data["label"] = micro_batch_label
-        self.microbatch_offset[model_chunk_id] += 1
+        self.microbatch_offset[model_chunk_id] += self.bsz_stride
         return move_to_device(micro_batch_data)

     def _forward_step(self, engine, chunk_id):

@@ -26,7 +26,11 @@ class AccPerplex:
         self.device = device
         self.right = torch.Tensor([0]).to(device=device)
         self.total = torch.Tensor([0]).to(device=device)
-        self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=torch.float)
+        self.metric_dtype = torch.float if gpc.config.get("metric_dtype", None) == "fp32" else None
+        if self.metric_dtype is not None:
+            self.total_log_probs = torch.Tensor([0]).to(device=device, dtype=self.metric_dtype)
+        else:
+            self.total_log_probs = torch.Tensor([0]).to(device=device)
         self.tp_pg = tp_pg
         self.dp_pg = dp_pg
         self.tp_local_rank = torch.distributed.get_rank(self.tp_pg)

@@ -128,8 +132,9 @@ class AccPerplex:
         # All reduce is needed to get the chunks from other GPUs.
         torch.distributed.all_reduce(predicted_logits, op=torch.distributed.ReduceOp.SUM, group=self.tp_pg)

-        predicted_logits = predicted_logits.to(dtype=torch.float)
-        shift_logits = shift_logits.to(dtype=torch.float)
+        if self.metric_dtype is not None:
+            predicted_logits = predicted_logits.to(dtype=self.metric_dtype)
+            shift_logits = shift_logits.to(dtype=self.metric_dtype)

         pred_exp_logits = torch.exp(predicted_logits)
         # Sum of exponential of logits along vocab dimension across all GPUs.
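A toy reproduction of the new dtype gate with assumed bf16 logits: when metric_dtype resolves to fp32 the logits are upcast before torch.exp, otherwise they keep their original dtype (the previous code always upcast).

import torch

def exp_logits(predicted_logits, metric_dtype=None):
    # optional upcast before the exponential, as in the patched AccPerplex code
    if metric_dtype is not None:
        predicted_logits = predicted_logits.to(dtype=metric_dtype)
    return torch.exp(predicted_logits)

logits = torch.randn(4, dtype=torch.bfloat16)
print(exp_logits(logits).dtype)               # torch.bfloat16: keep the model dtype (new default)
print(exp_logits(logits, torch.float).dtype)  # torch.float32: behaviour when metric_dtype = "fp32"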
@@ -219,10 +219,7 @@ class HybridZeroOptimizer(BaseOptimizer):
         # flag used to skip unnecessary gradient reduce operation when gradient accumulation is enabled.
         self.skip_grad_reduce = False

-        # reduction hook is only used if overlapping communication
-        # if it is stage 1 without overlapping, no hook will be attached
-        if self._overlap_sync_grad:
-            self._attach_reduction_hook()
+        self._attach_reduction_hook()

     @property
     def zero_local_rank(self):

@@ -321,12 +318,15 @@ class HybridZeroOptimizer(BaseOptimizer):

             # if sequence_parallel is True,
             # the grad of norm should be all-reduce across the tp process group
-            if gpc.config.parallel.sequence_parallel is True:
-                if hasattr(param, IS_SEQUENCE_PARALLEL) and getattr(param, IS_SEQUENCE_PARALLEL) is True:
-                    accum_grad_obj_sp = get_grad_accumulate_object(param)
-                    accum_grad_obj_sp.register_hook(reduce_grad_hook_sp)
+            if (
+                gpc.config.parallel.sequence_parallel is True
+                and hasattr(param, IS_SEQUENCE_PARALLEL)
+                and getattr(param, IS_SEQUENCE_PARALLEL) is True
+            ):
+                accum_grad_obj.register_hook(reduce_grad_hook_sp)

-            accum_grad_obj.register_hook(reduce_grad_hook)
+            if self._overlap_sync_grad:
+                accum_grad_obj.register_hook(reduce_grad_hook)

         _define_and_attach(param, reduce_rank)

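A condensed sketch of the new registration rule (the dummy parameter class and string hook names are illustrative, not the optimizer's real objects): the sequence-parallel reduce hook is attached only to parameters flagged IS_SEQUENCE_PARALLEL, and the ordinary reduce hook only when gradient communication overlap is enabled.

IS_SEQUENCE_PARALLEL = "is_sequence_parallel"  # attribute name, as in the diff

class DummyParam:
    """Illustrative stand-in for a model parameter."""

def hooks_to_register(param, sequence_parallel, overlap_sync_grad):
    hooks = []
    if (
        sequence_parallel is True
        and hasattr(param, IS_SEQUENCE_PARALLEL)
        and getattr(param, IS_SEQUENCE_PARALLEL) is True
    ):
        hooks.append("reduce_grad_hook_sp")  # all-reduce the norm grads across the tp group
    if overlap_sync_grad:
        hooks.append("reduce_grad_hook")     # only needed when comm/compute overlap is on
    return hooks

p = DummyParam()
setattr(p, IS_SEQUENCE_PARALLEL, True)
print(hooks_to_register(p, sequence_parallel=True, overlap_sync_grad=False))  # ['reduce_grad_hook_sp']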
@@ -110,6 +110,17 @@ def get_batch_size(data):
         return data[list(data.keys())[0]].size(0)


+def check_data_is_packed(data):
+    if isinstance(data, torch.Tensor):
+        return False
+    elif isinstance(data, (list, tuple)):
+        if isinstance(data[0], dict):
+            return "indexes" in data[0]
+        return False
+    elif isinstance(data, dict):
+        return "indexes" in data[0]
+
+
 def filter_kwargs(func, kwargs):
     sig = inspect.signature(func)
     return {k: v for k, v in kwargs.items() if k in sig.parameters}
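Usage sketch for the new helper with assumed toy batches; the (data, label) tuple path is the one the pipeline scheduler exercises, and packed data is recognised by its 'indexes' key.

import torch

def check_data_is_packed(data):
    # copy of the helper added above, repeated here so the demo is self-contained
    if isinstance(data, torch.Tensor):
        return False
    elif isinstance(data, (list, tuple)):
        if isinstance(data[0], dict):
            return "indexes" in data[0]
        return False
    elif isinstance(data, dict):
        return "indexes" in data[0]

# Assumed toy (data, label) tuples as the pipeline scheduler sees them.
packed_batch = ({"input_ids": torch.zeros(2, 16), "indexes": torch.zeros(2, 16)}, torch.zeros(2, 16))
plain_batch = ({"input_ids": torch.zeros(8, 4)}, torch.zeros(8, 4))

print(check_data_is_packed(packed_batch))  # True  -> batch dim is already micro_num
print(check_data_is_packed(plain_batch))   # False -> batch dim is micro_num * micro_bsz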
@ -1016,7 +1016,8 @@ now step_count is {train_state.step_count}",
|
||||||
torch.distributed.barrier()
|
torch.distributed.barrier()
|
||||||
|
|
||||||
def query_latest_snapshot_step_boto3(self):
|
def query_latest_snapshot_step_boto3(self):
|
||||||
"""query_latest_snapshot_step_boto3
|
"""Query the latest snapshot step from the storage backend.
|
||||||
|
Currently, we only support the following storage backends: boto3, oss2 and volc.
|
||||||
Returns:
|
Returns:
|
||||||
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
|
Tuple(str, int): path of latest ckpt and ckpt step, if not found, None will return.
|
||||||
"""
|
"""
|
||||||
|
@ -1074,6 +1075,7 @@ now step_count is {train_state.step_count}",
|
||||||
return load_path, max(snap_step, max_normal_step)
|
return load_path, max(snap_step, max_normal_step)
|
||||||
|
|
||||||
def query_latest_snapshot_step_local(self):
|
def query_latest_snapshot_step_local(self):
|
||||||
|
"""Query the latest snapshot step from the local file system."""
|
||||||
max_step, max_step_path = 0, None
|
max_step, max_step_path = 0, None
|
||||||
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
|
save_ckpt_folder = self.save_ckpt_folder.split(":")[1]
|
||||||
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
|
for root, _, files in os.walk(save_ckpt_folder, followlinks=True):
|
||||||
|
@ -1090,18 +1092,22 @@ now step_count is {train_state.step_count}",
|
||||||
return max_step_path, max_step
|
return max_step_path, max_step
|
||||||
|
|
||||||
def query_lastest_ckpt(self):
|
def query_lastest_ckpt(self):
|
||||||
|
"""Query the latest ckpt via the storage backend."""
|
||||||
latest_ckpt, step = None, -1
|
latest_ckpt, step = None, -1
|
||||||
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
|
# Training was automatically restarted by the process, forcing the latest snapshot to be read.
|
||||||
if self.save_ckpt_folder:
|
if self.save_ckpt_folder:
|
||||||
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
|
backend, _ = try_get_storage_backend(self.save_ckpt_folder)
|
||||||
if backend == "boto3":
|
if backend in ["boto3", "oss2", "volc"]:
|
||||||
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
|
latest_ckpt, step = self.query_latest_snapshot_step_boto3()
|
||||||
if latest_ckpt and not latest_ckpt.startswith("boto3:"):
|
|
||||||
latest_ckpt = ":".join(["boto3", latest_ckpt])
|
|
||||||
elif backend == "local":
|
elif backend == "local":
|
||||||
latest_ckpt, step = self.query_latest_snapshot_step_local()
|
latest_ckpt, step = self.query_latest_snapshot_step_local()
|
||||||
if latest_ckpt and not latest_ckpt.startswith("local:"):
|
else:
|
||||||
latest_ckpt = ":".join(["local", latest_ckpt])
|
raise NotImplementedError(
|
||||||
|
f"Unsupported backend: {backend}, " "Currently only support `boto3`, `oss2`, `volc` and `local`"
|
||||||
|
)
|
||||||
|
|
||||||
|
if latest_ckpt and not latest_ckpt.startswith(backend + ":"):
|
||||||
|
latest_ckpt = ":".join([backend, latest_ckpt])
|
||||||
|
|
||||||
if gpc.is_rank_for_log():
|
if gpc.is_rank_for_log():
|
||||||
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")
|
logger.info(f"Found latest ckpt {latest_ckpt if latest_ckpt else 'None'}, step: {step}...")
|
||||||
|
|
|
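A small sketch of the prefix normalisation that now closes query_lastest_ckpt, with assumed checkpoint paths: the backend prefix is prepended once, whichever backend matched.

def normalize_ckpt_path(backend, latest_ckpt):
    # mirrors the patched logic: prepend "<backend>:" only when it is missing
    if latest_ckpt and not latest_ckpt.startswith(backend + ":"):
        latest_ckpt = ":".join([backend, latest_ckpt])
    return latest_ckpt

print(normalize_ckpt_path("boto3", "s3://bucket/ckpt/snapshot/0"))  # boto3:s3://bucket/ckpt/snapshot/0
print(normalize_ckpt_path("local", "local:/data/ckpt/20"))          # unchanged, prefix already present
print(normalize_ckpt_path("volc", None))                            # None (nothing found)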
@@ -739,10 +739,9 @@ class AliClient(StorageClient):
         if AliClient.is_fp_exists(handler, fp):
             folder_name_list = []
             for obj in handler.handler.ObjectIteratorV2(handler.client, prefix=fp):
-                folder_name_list.append(obj.key.split("/")[-1])
-
+                folder_name_list.append(obj.key.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0])
             return list(set(folder_name_list))

         else:
             if is_rank_for_log():
                 logger.warning(f"'{fp}' not found!")
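A worked example (pure string handling, object keys assumed) of why the split changed: taking the last '/' component returns leaf file names, while splitting relative to the prefix returns the immediate sub-folder names under fp, which is what the checkpoint listing needs.

# Assumed OSS object keys under a checkpoint prefix, for illustration only.
fp = "model_ckpt/"
keys = [
    "model_ckpt/20/model_tp0_pp0.pt",
    "model_ckpt/20/optimizer_tp0_pp0_zo0.pt",
    "model_ckpt/snapshot/0/model_tp0_pp0.pt",
]

old_style = {k.split("/")[-1] for k in keys}
new_style = {k.split(fp, maxsplit=1)[1].strip("/").split("/", maxsplit=1)[0] for k in keys}

print(old_style)  # leaf file names, e.g. {'model_tp0_pp0.pt', 'optimizer_tp0_pp0_zo0.pt'}
print(new_style)  # immediate sub-folders: {'20', 'snapshot'}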
@@ -10,7 +10,15 @@ from internlm.core.context import global_context as gpc
 # from internlm.core.context import ParallelMode
 from internlm.core.context.parallel_context import Config
 from internlm.core.trainer import TrainState
-from internlm.train import get_train_data_loader, load_new_batch
+from internlm.train import (
+    get_train_data_loader,
+    get_validation_data_loader,
+    load_new_batch,
+)
+from internlm.utils.evaluation import (
+    switch_evaluation_no_pipeline_scheduler,
+    switch_evaluation_pipeline_scheduler,
+)

 # from internlm.core.context.parallel_context import global_context as gpc
 from tests.test_core.utils import build_environment, init_model_and_optim

@@ -20,7 +28,7 @@ use_flash_attens = [True, False]
 answers = [[1] * 8, [1, 1, 1, 1, 2, 2, 2, 2], [4] * 8, [2, 2, 4, 4, 6, 6, 8, 8]]
 test_case_group = [
     # format: micro_nums, rampup_batch_size, should sccuess, answer, pp size, sql len
-    # (1, "1 1 1", True, answers[0], 1, 8),
+    (1, "1 1 1", True, answers[0], 1, 8),
     (4, "1 1 4", True, answers[1], 1, 8),
     (4, None, True, answers[2], 1, 8),
     (8, "2 2 2", True, answers[3], 1, 8),

@@ -28,6 +36,11 @@ test_case_group = [
 ]


+class DummyTrainer:
+    def __init__(self, scheduler) -> None:
+        self.schedule = scheduler
+
+
 def do_warmup(args):
     rank, worldsize, init_config, should_sccuess, answer = args
     build_environment(rank, worldsize, init_config)

@@ -44,9 +57,11 @@ def do_warmup(args):
     )
     scheduler.pre_processing(engine)
     engine.train()
+    trainer = DummyTrainer(scheduler)

     try:
         train_dl, _ = get_train_data_loader(num_worker=0)
+        val_dls = get_validation_data_loader(num_worker=0)
     except Exception as e:
         assert should_sccuess is False, f"{e}"
     else:

@@ -105,6 +120,38 @@ def do_warmup(args):
                 tokens_num == answer[i] * gpc.config.data.seq_len * micro_bsz
             ), f"{tokens_num} == {answer[i] * gpc.config.data.seq_len * micro_bsz}"

+        # test no-packed datasets.
+        for _, val_dl in val_dls.items():
+            for _, batch in enumerate(val_dl):
+                if gpc.is_using_pp():
+                    total_val_bsz = len(batch[1])
+                    batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
+                    assert total_val_bsz % micro_bsz == 0
+                    num_microbatches = total_val_bsz // micro_bsz
+                    tensor_shape = torch.Size([micro_bsz, batch[0]["input_ids"].shape[1]])  # toy model hidden size is 8.
+                    with switch_evaluation_pipeline_scheduler(
+                        trainer=trainer,
+                        num_microbatches=num_microbatches,
+                        tensor_shape=tensor_shape,
+                        metric_hook_list=[],
+                    ):
+                        scheduler.forward_backward_step(
+                            engine, batch, forward_only=True, return_loss=False, return_output_label=False
+                        )
+                else:
+                    total_val_bsz = len(batch[1])
+                    batch[0]["input_ids"] = batch[0]["input_ids"].to(torch.bfloat16)
+                    assert total_val_bsz % micro_bsz == 0
+                    grad_accum_size = total_val_bsz // micro_bsz
+                    with switch_evaluation_no_pipeline_scheduler(
+                        trainer=trainer,
+                        grad_accum_size=grad_accum_size,
+                        metric_hook_list=[],
+                    ):
+                        scheduler.forward_backward_step(
+                            engine, batch, forward_only=True, return_loss=False, return_output_label=False
+                        )
+

 @pytest.mark.parametrize("use_flash_atten_case", use_flash_attens)
 @pytest.mark.parametrize("group_case", test_case_group)

@@ -121,7 +168,14 @@ def test_warmup(use_flash_atten_case, group_case, micro_bsz_case):
             sequence_parallel=False,
             tensor=1,
         ),
-        data=dict(train_folder=None, pack_sample_into_one=False, min_length=0, total_steps=8),
+        data=dict(
+            train_folder=None,
+            valid_folder=None,
+            valid_micro_num=4,
+            pack_sample_into_one=False,
+            min_length=0,
+            total_steps=8,
+        ),
         model=dict(
             dtype=torch.bfloat16,
         ),