feat(*): support not-flash-attn for pp and no-pp (#145)

* support non-flash-attention for no-pp

* support pipeline

* modify the config

* refactor the code

* refactor the code

* remove some unnecessary code
ytxiong 2023-07-28 16:13:04 +08:00 committed by GitHub
parent 8b1717a05d
commit 5ee651c2f1
10 changed files with 144 additions and 34 deletions


@@ -110,6 +110,7 @@ model = dict(
dtype="torch.bfloat16",
norm_type="rmsnorm",
layer_norm_epsilon=1e-5,
use_flash_attn=True,
)
"""
zero1 parallel:
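For orientation, a minimal sketch of how the new switch is driven from a user config; only use_flash_attn is introduced by this change, the surrounding fields mirror the snippet above, and the exact file layout is an assumption:

model = dict(
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=False,  # set to False to select the new non-flash-attention path
)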


@@ -3,7 +3,6 @@
# adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
import inspect
from typing import Any, Callable, Iterable
import torch
@@ -36,15 +35,6 @@ class NonPipelineScheduler(BaseScheduler):
"""
def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
# check that non-pipeline schedule data process func only takes in one parameter
# which is the batch data
if data_process_func:
sig = inspect.signature(data_process_func)
assert len(sig.parameters) == 1, (
"The data_process_func only takes in one parameter for NonPipelineSchedule, "
"which is a tuple of tensors for the current batch, "
"i.e. data_process_func(dataloader_output)."
)
self._grad_accum_size = gradient_accumulation_size
self._grad_accum_batch_size = 1 # static batch size for flash attention.
@@ -72,6 +62,12 @@ class NonPipelineScheduler(BaseScheduler):
data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
)
self._grad_accum_offset += self._grad_accum_batch_size
if self.data_process_func:
_data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
_label = self.data_process_func(_label, _data["cu_seqlens"])
_data.pop("cu_seqlens")
_data.pop("indexes")
return _data, _label
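A minimal sketch of the contract the scheduler now assumes for data_process_func: it is applied per micro-batch to input_ids and to the labels together with that micro-batch's cu_seqlens, and the packed-only keys are dropped afterwards (the PipelineScheduler.load_micro_batch hunk below applies the same steps). toy_unpack and all tensor values here are hypothetical stand-ins; the real hook wired in by this PR is unpack_data, shown later in this diff:

import torch

def toy_unpack(packed, cu_seqlens):
    # Hypothetical stand-in for the real unpack_data: split the single packed row
    # at the cu_seqlens boundaries and right-pad the pieces to equal length.
    bounds = cu_seqlens.tolist()
    pieces = [packed[0, s:e] for s, e in zip(bounds[:-1], bounds[1:])]
    return torch.nn.utils.rnn.pad_sequence(pieces, batch_first=True)

_data = {
    "input_ids": torch.tensor([[11, 12, 13, 21, 22]]),  # one packed row holding two sequences
    "cu_seqlens": torch.tensor([0, 3, 5]),
    "indexes": torch.tensor([[0, 1, 2, 0, 1]]),
}
_label = torch.tensor([[12, 13, -100, 22, -100]])

# Same steps as in _load_micro_batch above:
_data["input_ids"] = toy_unpack(_data["input_ids"], _data["cu_seqlens"])
_label = toy_unpack(_label, _data["cu_seqlens"])
_data.pop("cu_seqlens")
_data.pop("indexes")
print(_data["input_ids"].shape, _label.shape)  # torch.Size([2, 3]) torch.Size([2, 3])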
@@ -152,12 +148,7 @@ class NonPipelineScheduler(BaseScheduler):
batch_size == self._grad_accum_size
), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
if self.data_process_func:
data, label = self.data_process_func(batch_data)
else:
# if not batch data process func is given,
# then we regard the batch data as a simple tuple of (data, label)
data, label = batch_data
data, label = batch_data
loss = 0 if return_loss else None
outputs = []


@@ -30,10 +30,16 @@ def get_tensor_shape():
return None
if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
gpc.config.HIDDEN_SIZE,
)
if gpc.config.model.use_flash_attn:
tensor_shape = (
gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
gpc.config.HIDDEN_SIZE,
)
else:
tensor_shape = (
gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
gpc.config.HIDDEN_SIZE,
)
return tensor_shape
else:
return None
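Since get_tensor_shape feeds the pipeline scheduler's tensor_shape argument (see the initialize_trainer hunk below), this branch changes the shape of the activation tensors exchanged between stages. A quick sketch with assumed numbers (micro_bsz=2, SEQ_LEN=2048, HIDDEN_SIZE=4096; not values taken from any config in this repo):

micro_bsz, seq_len, hidden_size = 2, 2048, 4096  # assumed example values

flash_tensor_shape = (seq_len * micro_bsz, hidden_size)    # packed tokens: (4096, 4096)
no_flash_tensor_shape = (micro_bsz, seq_len, hidden_size)  # padded sequences: (2, 2048, 4096)
print(flash_tensor_shape, no_flash_tensor_shape)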
@@ -122,14 +128,21 @@ class PipelineScheduler(BaseScheduler):
self.microbatch_size = self.batch_size // self.num_microbatches
def load_micro_batch(self):
mciro_batch_data, micro_batch_label = self._load_micro_batch(
micro_batch_data, micro_batch_label = self._load_micro_batch(
data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
)
self.microbatch_offset += self.microbatch_size
# unpack data process
# TODO by xyt
return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
if self.data_process_func:
micro_batch_data["input_ids"] = self.data_process_func(
micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
)
micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
micro_batch_data.pop("cu_seqlens")
micro_batch_data.pop("indexes")
return move_to_device(micro_batch_data), move_to_device(micro_batch_label)
def pre_processing(self, engine):
model = engine.model
@@ -204,11 +217,12 @@ class PipelineScheduler(BaseScheduler):
pipeline stage.
"""
micro_batch_data, micro_batch_label = self.load_micro_batch()
data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)
timer("fwd").start()
output_obj = self._call_engine(engine.model, data)
timer("fwd").stop()
if gpc.is_last_rank(ParallelMode.PIPELINE):
timer("post_fn").start()
post_func = kwargs.get("post_fn")
@@ -295,6 +309,7 @@ class PipelineScheduler(BaseScheduler):
assert (
forward_only or return_loss
), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
self.load_batch(engine, data_iter)
num_warmup_microbatches = (
gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
@@ -841,4 +856,4 @@ class InterleavedPipelineScheduler(PipelineScheduler):
output, label = pack_return_tensors(return_tensors)
return output, label, accum_loss
else:
return None, None, accum_loss
return None, None, accum_loss


@@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
return out
def cal_pos_unpack(self, index):
if index == 0:
pre_pos = 0
else:
pre_pos = index * gpc.config.data["micro_bsz"]
pos = (index + 1) * gpc.config.data["micro_bsz"]
return pre_pos, pos
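In other words, when flash-attn is off each dataset index addresses a fixed window of micro_bsz consecutive samples rather than a token range of the packed stream. A quick sketch of the mapping, with micro_bsz=2 assumed in place of gpc.config.data["micro_bsz"]:

def cal_pos_unpack(index, micro_bsz=2):
    # Same arithmetic as above, written without the global config object.
    pre_pos = 0 if index == 0 else index * micro_bsz
    pos = (index + 1) * micro_bsz
    return pre_pos, pos

print([cal_pos_unpack(i) for i in range(4)])
# [(0, 2), (2, 4), (4, 6), (6, 8)]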
def build_unpack(self, index):
pre_pos, pos = self.cal_pos_unpack(index)
pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
while pre_pos < pos and pre_pos < len(self.dataset):
sample_idx = self.sample_indices[pre_pos]
sample = self.dataset[sample_idx]
length = min(len(sample["tokens"]), self.max_length_per_sample)
chunk = sample["tokens"][0:length]
pack.extend(chunk)
_labels = deepcopy(chunk)
_labels = list(_labels[1:]) + [-100]
assert len(_labels) == len(chunk), (_labels, chunk)
labels.extend(_labels)
type_ids.extend([sample.get("type_id", 0)] * len(chunk))
cu_seqlens.append(cu_seqlens[-1] + len(chunk))
indexes.extend(list(range(length)))
pre_pos = pre_pos + 1
if cu_seqlens[-1] != self.packed_length:
pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
cu_seqlens.append(self.packed_length)
assert len(pack) == self.packed_length
out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
return out
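A worked example of what build_unpack produces, assuming packed_length=8 and two samples of 3 and 2 tokens (type_ids and the gpc/config plumbing omitted): the samples are concatenated, the shift-by-one labels get a -100 at each sample boundary, and the tail is zero-padded up to the packed length, which is recorded as a final cu_seqlens entry.

from copy import deepcopy

packed_length = 8                    # assumed; equals micro_bsz * seq_len in the real dataset
samples = [[11, 12, 13], [21, 22]]   # token ids of the two samples in this window

pack, cu_seqlens, indexes, labels = [], [0], [], []
for tokens in samples:
    pack.extend(tokens)
    labels.extend(list(deepcopy(tokens))[1:] + [-100])  # next-token labels, -100 at sample end
    cu_seqlens.append(cu_seqlens[-1] + len(tokens))
    indexes.extend(range(len(tokens)))

# zero-pad the tail up to packed_length, exactly as build_unpack does
pad = packed_length - cu_seqlens[-1]
pack += [0] * pad
labels += [0] * pad
indexes.extend(range(pad))
cu_seqlens.append(packed_length)

print(pack)        # [11, 12, 13, 21, 22, 0, 0, 0]
print(labels)      # [12, 13, -100, 22, -100, 0, 0, 0]
print(cu_seqlens)  # [0, 3, 5, 8]
print(indexes)     # [0, 1, 2, 0, 1, 0, 1, 2]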
def __getitem__(self, item: int) -> Dict:
"""Given the index, it returns a dict as
{
@@ -154,8 +196,11 @@ class PackedDataset(torch.utils.data.Dataset):
}
"""
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
if gpc.config.model.use_flash_attn:
pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
return self.build_unpack(item)
class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):


@@ -1,6 +1,10 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
import torch
from internlm.core.context import global_context as gpc
DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
@@ -13,3 +17,29 @@ def get_dataset_type_id(path):
match_idxes.append(idx)
assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
return match_idxes[0]
def unpack_data(input_ids, cu_seqlens):
"""
input_ids: (n, packed_length)
Return:
output: (batch_size, max_length)
"""
bsz = input_ids.shape[0]
num_sequence = gpc.config.data["micro_bsz"]
outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
for i in range(bsz):
output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
cu_seqlens_slice = cu_seqlens[i]
for j in range(num_sequence):
seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
outputs[i] = output
if bsz == 1:
outputs = outputs.squeeze(0)
return outputs
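A standalone usage sketch of this unpacking step on the kind of batch built above. Because unpack_data reads micro_bsz and seq_len from gpc.config, the loop is inlined here with assumed values (micro_bsz=2, seq_len=4) instead of calling the function directly:

import torch

micro_bsz, seq_len = 2, 4  # assumed stand-ins for gpc.config.data["micro_bsz"] / seq_len

input_ids = torch.tensor([[11, 12, 13, 21, 22, 0, 0, 0]])  # (bsz=1, packed_length)
cu_seqlens = torch.tensor([[0, 3, 5, 8]])                  # (bsz=1, boundaries incl. padding)

bsz = input_ids.shape[0]
outputs = torch.zeros(bsz, micro_bsz, seq_len, dtype=input_ids.dtype)
for i in range(bsz):
    cu = cu_seqlens[i].tolist()
    for j in range(micro_bsz):
        length = cu[j + 1] - cu[j]
        outputs[i, j, :length] = input_ids[i, cu[j] : cu[j + 1]]
if bsz == 1:
    outputs = outputs.squeeze(0)  # same squeeze as unpack_data
print(outputs)  # -> [[11, 12, 13, 0], [21, 22, 0, 0]]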


@@ -22,6 +22,7 @@ from internlm.core.scheduler.pipeline_scheduler import (
get_tensor_shape,
)
from internlm.core.trainer import Trainer
from internlm.data.utils import unpack_data
from internlm.solver.beta2_scheduler import Beta2Scheduler
from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
from internlm.utils.common import get_current_device
@@ -77,9 +78,17 @@ def initialize_trainer(
# initialize scheduler for trainer
scheduler = None
if gpc.config.model.use_flash_attn:
data_fn = None
else:
data_fn = unpack_data
if gpc.is_using_pp():
gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
tensor_shape = get_tensor_shape()
# if gpc.config.model.use_flash_attn:
# tensor_shape = get_tensor_shape()
# else:
# tensor_shape = None
use_interleaved = (
hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
)
@@ -96,13 +105,14 @@ def initialize_trainer(
)
else:
scheduler = PipelineScheduler(
data_process_func=data_fn,
num_microbatches=gpc.config.NUM_MICRO_BATCHES,
dtype=gpc.config.model["dtype"],
tensor_shape=tensor_shape,
scatter_gather_tensors=scatter_gather,
)
else:
scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)
# initialize engine for trainer
engine = Engine(


@@ -51,7 +51,10 @@ class AccPerplex:
return self.update(logits, labels, type_ids=self.type_ids)
def update(self, logits, labels, type_ids=None):
micro_bsz = labels.size(0)
if gpc.config.model.use_flash_attn:
micro_bsz = labels.size(0)
else:
micro_bsz = 1
if type_ids is not None:
type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
self.batch_shift += 1


@@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
device (Optional[Union[str, torch.device]]): The device will be used.
norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
def __init__(
@@ -68,12 +69,14 @@
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
):
super().__init__()
self.checkpoint = checkpoint
# dropout selective checkpoint can only be enabled when checkpoint is disabled.
self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
self.layer_idx = layer_idx
self.use_flash_attn = use_flash_attn
head_dim = hidden_size // num_attention_heads
self.mixer = MHA(
@@ -86,7 +89,7 @@
layer_idx=layer_idx,
rotary_emb_dim=head_dim,
rotary_emb_scale_base=0,
use_flash_attn=True,
use_flash_attn=use_flash_attn,
sequence_parallel=False,
device=device,
dtype=dtype,
@@ -244,6 +247,7 @@ class PackedFlashInternLm1D(nn.Module):
device (Optional[Union[str, torch.device]]): The device will be used. None by default.
residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@@ -273,9 +277,11 @@
dropout_selective_checkpoint: bool = True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
):
super().__init__()
self.use_flash_attn = use_flash_attn
if checkpoint_fraction <= 0:
checkpoint = False
if not checkpoint:
@@ -322,6 +328,7 @@
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
)
for lid in range(num_layers)
]
@@ -358,7 +365,7 @@
if isinstance(cu_seqlens, list):
assert len(cu_seqlens) == 1
cu_seqlens = cu_seqlens[0].to(hidden_states.device)
if cu_seqlens is not None:
cu_seqlens = cu_seqlens.squeeze(0)
hidden_states = hidden_states.squeeze(0) # If cu_seqlens is passed in, it indicates a packed state
@@ -456,6 +463,7 @@ def build_model_with_cfg(
dropout_selective_checkpoint=True,
use_scaled_init: bool = True,
use_swiglu: bool = True,
use_flash_attn: bool = True,
):
"""
Build model with config
@@ -485,6 +493,7 @@
dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
use_scaled_init (bool): Whether to use scaled init. True by default.
use_swiglu (bool): Whether to use swiglu. True by default.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@@ -507,6 +516,7 @@
dropout_selective_checkpoint=dropout_selective_checkpoint,
use_scaled_init=use_scaled_init,
use_swiglu=use_swiglu,
use_flash_attn=use_flash_attn,
)
return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)


@@ -43,6 +43,7 @@ class MHA(nn.Module):
of x will be done before doing the matmul.
device (Optional[Union[str, torch.device]]): The device will be used.
dtype (Optional[torch.dtype]): The type of data.
use_flash_attn (bool): Whether to use flash-attn. True by default.
"""
@@ -57,7 +58,7 @@
layer_idx: int = None,
rotary_emb_dim: int = 0,
rotary_emb_scale_base: int = 0,
use_flash_attn: bool = False,
use_flash_attn: bool = True,
sequence_parallel: bool = True,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,


@@ -25,7 +25,7 @@ from internlm.data.packed_dataset import (
PackedDatasetWithoutCuSeqlen,
get_packed_dataset_without_short_length,
)
from internlm.data.utils import DATASET_TYPE_IDS_MAP
from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
from internlm.model.loss import FlashGPTLMLoss
from internlm.model.metrics import AccPerplex
from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -471,6 +471,10 @@ def main(args):
# zero the grads of parameters
trainer.zero_grad()
type_ids = batch[0].pop("type_ids", None)
# process data
# if use_flash_attn is False, we need to unpack type_ids
if not gpc.config.model.use_flash_attn:
type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
if type_ids is not None:
metric.set_current_type_ids(type_ids=type_ids)