mirror of https://github.com/InternLM/InternLM
feat(*): support not-flash-attn for pp and no-pp (#145)
* support not flash attention for no-pp * support pipeline * modify the config * refactor the code * refactor the code * remove some unnecessary codepull/158/head
parent
8b1717a05d
commit
5ee651c2f1
|
@ -110,6 +110,7 @@ model = dict(
|
|||
dtype="torch.bfloat16",
|
||||
norm_type="rmsnorm",
|
||||
layer_norm_epsilon=1e-5,
|
||||
use_flash_attn=True,
|
||||
)
|
||||
"""
|
||||
zero1 parallel:
|
||||
|
|
|
@ -3,7 +3,6 @@
|
|||
|
||||
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
|
||||
|
||||
import inspect
|
||||
from typing import Any, Callable, Iterable
|
||||
|
||||
import torch
|
||||
|
@ -36,15 +35,6 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
"""
|
||||
|
||||
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
|
||||
# check that non-pipeline schedule data process func only takes in one parameter
|
||||
# which is the batch data
|
||||
if data_process_func:
|
||||
sig = inspect.signature(data_process_func)
|
||||
assert len(sig.parameters) == 1, (
|
||||
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
|
||||
"which is a tuple of tensors for the current batch, "
|
||||
"i.e. data_process_func(dataloader_output)."
|
||||
)
|
||||
|
||||
self._grad_accum_size = gradient_accumulation_size
|
||||
self._grad_accum_batch_size = 1 # static batch size for flash attetion.
|
||||
|
@ -72,6 +62,12 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
|
||||
)
|
||||
self._grad_accum_offset += self._grad_accum_batch_size
|
||||
|
||||
if self.data_process_func:
|
||||
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
|
||||
_label = self.data_process_func(_label, _data["cu_seqlens"])
|
||||
_data.pop("cu_seqlens")
|
||||
_data.pop("indexes")
|
||||
|
||||
return _data, _label
|
||||
|
||||
|
@ -152,12 +148,7 @@ class NonPipelineScheduler(BaseScheduler):
|
|||
batch_size == self._grad_accum_size
|
||||
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
|
||||
|
||||
if self.data_process_func:
|
||||
data, label = self.data_process_func(batch_data)
|
||||
else:
|
||||
# if not batch data process func is given,
|
||||
# then we regard the batch data as a simple tuple of (data, label)
|
||||
data, label = batch_data
|
||||
data, label = batch_data
|
||||
|
||||
loss = 0 if return_loss else None
|
||||
outputs = []
|
||||
|
|
|
@ -30,10 +30,16 @@ def get_tensor_shape():
|
|||
return None
|
||||
|
||||
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
if gpc.config.model.use_flash_attn:
|
||||
tensor_shape = (
|
||||
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
else:
|
||||
tensor_shape = (
|
||||
gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
|
||||
gpc.config.HIDDEN_SIZE,
|
||||
)
|
||||
return tensor_shape
|
||||
else:
|
||||
return None
|
||||
|
@ -122,14 +128,21 @@ class PipelineScheduler(BaseScheduler):
|
|||
self.microbatch_size = self.batch_size // self.num_microbatches
|
||||
|
||||
def load_micro_batch(self):
|
||||
mciro_batch_data, micro_batch_label = self._load_micro_batch(
|
||||
micro_batch_data, micro_batch_label = self._load_micro_batch(
|
||||
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
|
||||
)
|
||||
self.microbatch_offset += self.microbatch_size
|
||||
|
||||
# unpack data process
|
||||
# TODO by xyt
|
||||
return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
|
||||
if self.data_process_func:
|
||||
micro_batch_data["input_ids"] = self.data_process_func(
|
||||
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
|
||||
)
|
||||
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
|
||||
|
||||
micro_batch_data.pop("cu_seqlens")
|
||||
micro_batch_data.pop("indexes")
|
||||
|
||||
return move_to_device(micro_batch_data), move_to_device(micro_batch_label)
|
||||
|
||||
def pre_processing(self, engine):
|
||||
model = engine.model
|
||||
|
@ -204,11 +217,12 @@ class PipelineScheduler(BaseScheduler):
|
|||
pipeline stage.
|
||||
"""
|
||||
micro_batch_data, micro_batch_label = self.load_micro_batch()
|
||||
|
||||
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)
|
||||
|
||||
timer("fwd").start()
|
||||
output_obj = self._call_engine(engine.model, data)
|
||||
timer("fwd").stop()
|
||||
|
||||
if gpc.is_last_rank(ParallelMode.PIPELINE):
|
||||
timer("post_fn").start()
|
||||
post_func = kwargs.get("post_fn")
|
||||
|
@ -295,6 +309,7 @@ class PipelineScheduler(BaseScheduler):
|
|||
assert (
|
||||
forward_only or return_loss
|
||||
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
|
||||
|
||||
self.load_batch(engine, data_iter)
|
||||
num_warmup_microbatches = (
|
||||
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
|
||||
|
@ -841,4 +856,4 @@ class InterleavedPipelineScheduler(PipelineScheduler):
|
|||
output, label = pack_return_tensors(return_tensors)
|
||||
return output, label, accum_loss
|
||||
else:
|
||||
return None, None, accum_loss
|
||||
return None, None, accum_loss
|
|
@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
|
|||
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
||||
return out
|
||||
|
||||
def cal_pos_unpack(self, index):
|
||||
if index == 0:
|
||||
pre_pos = 0
|
||||
else:
|
||||
pre_pos = index * gpc.config.data["micro_bsz"]
|
||||
|
||||
pos = (index + 1) * gpc.config.data["micro_bsz"]
|
||||
return pre_pos, pos
|
||||
|
||||
def build_unpack(self, index):
|
||||
|
||||
pre_pos, pos = self.cal_pos_unpack(index)
|
||||
|
||||
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
|
||||
|
||||
while pre_pos < pos and pre_pos < len(self.dataset):
|
||||
sample_idx = self.sample_indices[pre_pos]
|
||||
sample = self.dataset[sample_idx]
|
||||
length = min(len(sample["tokens"]), self.max_length_per_sample)
|
||||
chunk = sample["tokens"][0:length]
|
||||
pack.extend(chunk)
|
||||
_labels = deepcopy(chunk)
|
||||
_labels = list(_labels[1:]) + [-100]
|
||||
assert len(_labels) == len(chunk), (_labels, chunk)
|
||||
labels.extend(_labels)
|
||||
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
|
||||
cu_seqlens.append(cu_seqlens[-1] + len(chunk))
|
||||
indexes.extend(list(range(length)))
|
||||
pre_pos = pre_pos + 1
|
||||
|
||||
if cu_seqlens[-1] != self.packed_length:
|
||||
pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
|
||||
indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
|
||||
cu_seqlens.append(self.packed_length)
|
||||
|
||||
assert len(pack) == self.packed_length
|
||||
|
||||
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
|
||||
return out
|
||||
|
||||
def __getitem__(self, item: int) -> Dict:
|
||||
"""Given the index, it returns a dict as
|
||||
{
|
||||
|
@ -154,8 +196,11 @@ class PackedDataset(torch.utils.data.Dataset):
|
|||
}
|
||||
"""
|
||||
|
||||
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
||||
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
||||
if gpc.config.model.use_flash_attn:
|
||||
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
|
||||
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
|
||||
|
||||
return self.build_unpack(item)
|
||||
|
||||
|
||||
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
#!/usr/bin/env python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
|
||||
from internlm.core.context import global_context as gpc
|
||||
|
||||
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
|
||||
|
||||
|
||||
|
@ -13,3 +17,29 @@ def get_dataset_type_id(path):
|
|||
match_idxes.append(idx)
|
||||
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
|
||||
return match_idxes[0]
|
||||
|
||||
def unpack_data(input_ids, cu_seqlens):
|
||||
"""
|
||||
input_ids: (n, packed_length)
|
||||
Return:
|
||||
output: (batch_size, max_length)
|
||||
"""
|
||||
|
||||
bsz = input_ids.shape[0]
|
||||
|
||||
num_sequence = gpc.config.data["micro_bsz"]
|
||||
|
||||
outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
||||
|
||||
for i in range(bsz):
|
||||
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
|
||||
cu_seqlens_slice = cu_seqlens[i]
|
||||
for j in range(num_sequence):
|
||||
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
|
||||
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
|
||||
outputs[i] = output
|
||||
|
||||
if bsz == 1:
|
||||
outputs = outputs.squeeze(0)
|
||||
|
||||
return outputs
|
|
@ -22,6 +22,7 @@ from internlm.core.scheduler.pipeline_scheduler import (
|
|||
get_tensor_shape,
|
||||
)
|
||||
from internlm.core.trainer import Trainer
|
||||
from internlm.data.utils import unpack_data
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
|
||||
from internlm.utils.common import get_current_device
|
||||
|
@ -77,9 +78,17 @@ def initialize_trainer(
|
|||
|
||||
# initialize scheduler for trainer
|
||||
scheduler = None
|
||||
if gpc.config.model.use_flash_attn:
|
||||
data_fn = None
|
||||
else:
|
||||
data_fn = unpack_data
|
||||
if gpc.is_using_pp():
|
||||
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
|
||||
tensor_shape = get_tensor_shape()
|
||||
# if gpc.config.model.use_flash_attn:
|
||||
# tensor_shape = get_tensor_shape()
|
||||
# else:
|
||||
# tensor_shape = None
|
||||
use_interleaved = (
|
||||
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
|
||||
)
|
||||
|
@ -96,13 +105,14 @@ def initialize_trainer(
|
|||
)
|
||||
else:
|
||||
scheduler = PipelineScheduler(
|
||||
data_process_func=data_fn,
|
||||
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
|
||||
dtype=gpc.config.model["dtype"],
|
||||
tensor_shape=tensor_shape,
|
||||
scatter_gather_tensors=scatter_gather,
|
||||
)
|
||||
else:
|
||||
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)
|
||||
|
||||
# initialize engine for trainer
|
||||
engine = Engine(
|
||||
|
|
|
@ -51,7 +51,10 @@ class AccPerplex:
|
|||
return self.update(logits, labels, type_ids=self.type_ids)
|
||||
|
||||
def update(self, logits, labels, type_ids=None):
|
||||
micro_bsz = labels.size(0)
|
||||
if gpc.config.model.use_flash_attn:
|
||||
micro_bsz = labels.size(0)
|
||||
else:
|
||||
micro_bsz = 1
|
||||
if type_ids is not None:
|
||||
type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
|
||||
self.batch_shift += 1
|
||||
|
|
|
@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether use flash-attn. True by default.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -68,12 +69,14 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
self.checkpoint = checkpoint
|
||||
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
|
||||
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
|
||||
self.layer_idx = layer_idx
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
head_dim = hidden_size // num_attention_heads
|
||||
self.mixer = MHA(
|
||||
|
@ -86,7 +89,7 @@ class PackedFlashBaseLayer1D(nn.Module):
|
|||
layer_idx=layer_idx,
|
||||
rotary_emb_dim=head_dim,
|
||||
rotary_emb_scale_base=0,
|
||||
use_flash_attn=True,
|
||||
use_flash_attn=use_flash_attn,
|
||||
sequence_parallel=False,
|
||||
device=device,
|
||||
dtype=dtype,
|
||||
|
@ -244,6 +247,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
|
||||
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
|
||||
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -273,9 +277,11 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
dropout_selective_checkpoint: bool = True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.use_flash_attn = use_flash_attn
|
||||
if checkpoint_fraction <= 0:
|
||||
checkpoint = False
|
||||
if not checkpoint:
|
||||
|
@ -322,6 +328,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
for lid in range(num_layers)
|
||||
]
|
||||
|
@ -358,7 +365,7 @@ class PackedFlashInternLm1D(nn.Module):
|
|||
if isinstance(cu_seqlens, list):
|
||||
assert len(cu_seqlens) == 1
|
||||
cu_seqlens = cu_seqlens[0].to(hidden_states.device)
|
||||
|
||||
|
||||
if cu_seqlens is not None:
|
||||
cu_seqlens = cu_seqlens.squeeze(0)
|
||||
hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in,it indicated a packed state,
|
||||
|
@ -456,6 +463,7 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint=True,
|
||||
use_scaled_init: bool = True,
|
||||
use_swiglu: bool = True,
|
||||
use_flash_attn: bool = True,
|
||||
):
|
||||
"""
|
||||
Builde model with config
|
||||
|
@ -485,6 +493,7 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
|
||||
use_scaled_init (bool): Whether to use scaled init. True by default.
|
||||
use_swiglu (bool): Whether to use swiglu. True by default.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -507,6 +516,7 @@ def build_model_with_cfg(
|
|||
dropout_selective_checkpoint=dropout_selective_checkpoint,
|
||||
use_scaled_init=use_scaled_init,
|
||||
use_swiglu=use_swiglu,
|
||||
use_flash_attn=use_flash_attn,
|
||||
)
|
||||
|
||||
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
|
||||
|
|
|
@ -43,6 +43,7 @@ class MHA(nn.Module):
|
|||
of x will be done before doing the matmul.
|
||||
device (Optional[Union[str, torch.device]]): The device will be used.
|
||||
dtype (Optional[torch.dtype]): The type of data.
|
||||
use_flash_attn (bool): Whether to use flash-attn. True by default.
|
||||
|
||||
"""
|
||||
|
||||
|
@ -57,7 +58,7 @@ class MHA(nn.Module):
|
|||
layer_idx: int = None,
|
||||
rotary_emb_dim: int = 0,
|
||||
rotary_emb_scale_base: int = 0,
|
||||
use_flash_attn: bool = False,
|
||||
use_flash_attn: bool = True,
|
||||
sequence_parallel: bool = True,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
|
|
6
train.py
6
train.py
|
@ -25,7 +25,7 @@ from internlm.data.packed_dataset import (
|
|||
PackedDatasetWithoutCuSeqlen,
|
||||
get_packed_dataset_without_short_length,
|
||||
)
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP
|
||||
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
|
||||
from internlm.model.loss import FlashGPTLMLoss
|
||||
from internlm.model.metrics import AccPerplex
|
||||
from internlm.solver.beta2_scheduler import Beta2Scheduler
|
||||
|
@ -471,6 +471,10 @@ def main(args):
|
|||
# zero the grads of parameters
|
||||
trainer.zero_grad()
|
||||
type_ids = batch[0].pop("type_ids", None)
|
||||
# process data
|
||||
# if use_flash_attn is False, we need to unpack type_ids
|
||||
if not gpc.config.model.use_flash_attn:
|
||||
type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
|
||||
if type_ids is not None:
|
||||
metric.set_current_type_ids(type_ids=type_ids)
|
||||
|
||||
|
|
Loading…
Reference in New Issue