diff --git a/configs/7B_sft.py b/configs/7B_sft.py
index 060d3cd..70e692e 100644
--- a/configs/7B_sft.py
+++ b/configs/7B_sft.py
@@ -110,6 +110,7 @@ model = dict(
     dtype="torch.bfloat16",
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
+    use_flash_attn=True,
 )
 """
 zero1 parallel:
diff --git a/internlm/core/scheduler/no_pipeline_scheduler.py b/internlm/core/scheduler/no_pipeline_scheduler.py
index 6133e2f..cdf3edc 100644
--- a/internlm/core/scheduler/no_pipeline_scheduler.py
+++ b/internlm/core/scheduler/no_pipeline_scheduler.py
@@ -3,7 +3,6 @@
 
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine
 
-import inspect
 from typing import Any, Callable, Iterable
 
 import torch
@@ -36,15 +35,6 @@ class NonPipelineScheduler(BaseScheduler):
 
     """
 
     def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
-        # check that non-pipeline schedule data process func only takes in one parameter
-        # which is the batch data
-        if data_process_func:
-            sig = inspect.signature(data_process_func)
-            assert len(sig.parameters) == 1, (
-                "The data_process_func only takes in one parameter for NonPipelineSchedule, "
-                "which is a tuple of tensors for the current batch, "
-                "i.e. data_process_func(dataloader_output)."
-            )
         self._grad_accum_size = gradient_accumulation_size
         self._grad_accum_batch_size = 1  # static batch size for flash attetion.
@@ -72,6 +62,12 @@ class NonPipelineScheduler(BaseScheduler):
             data=data, label=label, offset=self._grad_accum_offset, micro_bsz=self._grad_accum_batch_size
         )
         self._grad_accum_offset += self._grad_accum_batch_size
+
+        if self.data_process_func:
+            _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
+            _label = self.data_process_func(_label, _data["cu_seqlens"])
+            _data.pop("cu_seqlens")
+            _data.pop("indexes")
 
         return _data, _label
 
@@ -152,12 +148,7 @@ class NonPipelineScheduler(BaseScheduler):
             batch_size == self._grad_accum_size
         ), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"
 
-        if self.data_process_func:
-            data, label = self.data_process_func(batch_data)
-        else:
-            # if not batch data process func is given,
-            # then we regard the batch data as a simple tuple of (data, label)
-            data, label = batch_data
+        data, label = batch_data
 
         loss = 0 if return_loss else None
         outputs = []
diff --git a/internlm/core/scheduler/pipeline_scheduler.py b/internlm/core/scheduler/pipeline_scheduler.py
index c0eb04b..1ce35ee 100644
--- a/internlm/core/scheduler/pipeline_scheduler.py
+++ b/internlm/core/scheduler/pipeline_scheduler.py
@@ -30,10 +30,16 @@ def get_tensor_shape():
         return None
 
     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
-        tensor_shape = (
-            gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-            gpc.config.HIDDEN_SIZE,
-        )
+        if gpc.config.model.use_flash_attn:
+            tensor_shape = (
+                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                gpc.config.HIDDEN_SIZE,
+            )
+        else:
+            tensor_shape = (
+                gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
+                gpc.config.HIDDEN_SIZE,
+            )
         return tensor_shape
     else:
         return None
@@ -122,14 +128,21 @@ class PipelineScheduler(BaseScheduler):
         self.microbatch_size = self.batch_size // self.num_microbatches
 
     def load_micro_batch(self):
-        mciro_batch_data, micro_batch_label = self._load_micro_batch(
+        micro_batch_data, micro_batch_label = self._load_micro_batch(
            data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
         )
         self.microbatch_offset += self.microbatch_size
 
-        # unpack data process
-        # TODO by xyt
-        return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
+        if self.data_process_func:
+            micro_batch_data["input_ids"] = self.data_process_func(
+                micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
+            )
+            micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
+
+            micro_batch_data.pop("cu_seqlens")
+            micro_batch_data.pop("indexes")
+
+        return move_to_device(micro_batch_data), move_to_device(micro_batch_label)
 
     def pre_processing(self, engine):
         model = engine.model
@@ -204,11 +217,12 @@ class PipelineScheduler(BaseScheduler):
             pipeline stage.
         """
         micro_batch_data, micro_batch_label = self.load_micro_batch()
-
         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)
+
         timer("fwd").start()
         output_obj = self._call_engine(engine.model, data)
         timer("fwd").stop()
+
         if gpc.is_last_rank(ParallelMode.PIPELINE):
             timer("post_fn").start()
             post_func = kwargs.get("post_fn")
@@ -295,6 +309,7 @@ class PipelineScheduler(BaseScheduler):
         assert (
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."
+
         self.load_batch(engine, data_iter)
         num_warmup_microbatches = (
             gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
@@ -841,4 +856,4 @@ class InterleavedPipelineScheduler(PipelineScheduler):
             output, label = pack_return_tensors(return_tensors)
             return output, label, accum_loss
         else:
-            return None, None, accum_loss
+            return None, None, accum_loss
\ No newline at end of file
diff --git a/internlm/data/packed_dataset.py b/internlm/data/packed_dataset.py
index e9151bf..25a41f5 100644
--- a/internlm/data/packed_dataset.py
+++ b/internlm/data/packed_dataset.py
@@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
         out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
         return out
 
+    def cal_pos_unpack(self, index):
+        if index == 0:
+            pre_pos = 0
+        else:
+            pre_pos = index * gpc.config.data["micro_bsz"]
+
+        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        return pre_pos, pos
+
+    def build_unpack(self, index):
+
+        pre_pos, pos = self.cal_pos_unpack(index)
+
+        pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
+
+        while pre_pos < pos and pre_pos < len(self.dataset):
+            sample_idx = self.sample_indices[pre_pos]
+            sample = self.dataset[sample_idx]
+            length = min(len(sample["tokens"]), self.max_length_per_sample)
+            chunk = sample["tokens"][0:length]
+            pack.extend(chunk)
+            _labels = deepcopy(chunk)
+            _labels = list(_labels[1:]) + [-100]
+            assert len(_labels) == len(chunk), (_labels, chunk)
+            labels.extend(_labels)
+            type_ids.extend([sample.get("type_id", 0)] * len(chunk))
+            cu_seqlens.append(cu_seqlens[-1] + len(chunk))
+            indexes.extend(list(range(length)))
+            pre_pos = pre_pos + 1
+
+        if cu_seqlens[-1] != self.packed_length:
+            pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
+            labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
+            type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
+            indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
+            cu_seqlens.append(self.packed_length)
+
+        assert len(pack) == self.packed_length
+
+        out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
+        return out
+
     def __getitem__(self, item: int) -> Dict:
         """Given the index, it returns a dict as
         {
@@ -154,8 +196,11 @@ class PackedDataset(torch.utils.data.Dataset):
         }
         """
 
-        pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
-        return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
+        if gpc.config.model.use_flash_attn:
+            pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
+            return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
+
+        return self.build_unpack(item)
 
 
 class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
diff --git a/internlm/data/utils.py b/internlm/data/utils.py
index b003469..4d9c775 100644
--- a/internlm/data/utils.py
+++ b/internlm/data/utils.py
@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-
 
+import torch
+
+from internlm.core.context import global_context as gpc
+
 DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
 
 
@@ -13,3 +17,29 @@ def get_dataset_type_id(path):
             match_idxes.append(idx)
     assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
     return match_idxes[0]
+
+def unpack_data(input_ids, cu_seqlens):
+    """
+    input_ids: (n, packed_length)
+    Return:
+    output: (batch_size, max_length)
+    """
+
+    bsz = input_ids.shape[0]
+
+    num_sequence = gpc.config.data["micro_bsz"]
+
+    outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
+
+    for i in range(bsz):
+        output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
+        cu_seqlens_slice = cu_seqlens[i]
+        for j in range(num_sequence):
+            seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
+            output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
+        outputs[i] = output
+
+    if bsz == 1:
+        outputs = outputs.squeeze(0)
+
+    return outputs
\ No newline at end of file
diff --git a/internlm/initialize/initialize_trainer.py b/internlm/initialize/initialize_trainer.py
index 7a5bbc6..801012a 100644
--- a/internlm/initialize/initialize_trainer.py
+++ b/internlm/initialize/initialize_trainer.py
@@ -22,6 +22,7 @@ from internlm.core.scheduler.pipeline_scheduler import (
     get_tensor_shape,
 )
 from internlm.core.trainer import Trainer
+from internlm.data.utils import unpack_data
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
 from internlm.utils.common import get_current_device
@@ -77,9 +78,17 @@ def initialize_trainer(
 
     # initialize scheduler for trainer
     scheduler = None
+    if gpc.config.model.use_flash_attn:
+        data_fn = None
+    else:
+        data_fn = unpack_data
     if gpc.is_using_pp():
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
+        # if gpc.config.model.use_flash_attn:
+        #     tensor_shape = get_tensor_shape()
+        # else:
+        #     tensor_shape = None
         use_interleaved = (
             hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
         )
@@ -96,13 +105,14 @@ def initialize_trainer(
             )
         else:
             scheduler = PipelineScheduler(
+                data_process_func=data_fn,
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
             )
     else:
-        scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
+        scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)
 
     # initialize engine for trainer
     engine = Engine(
diff --git a/internlm/model/metrics.py b/internlm/model/metrics.py
index b6a2d45..1749aa2 100644
--- a/internlm/model/metrics.py
+++ b/internlm/model/metrics.py
@@ -51,7 +51,10 @@ class AccPerplex:
         return self.update(logits, labels, type_ids=self.type_ids)
 
     def update(self, logits, labels, type_ids=None):
-        micro_bsz = labels.size(0)
+        if gpc.config.model.use_flash_attn:
+            micro_bsz = labels.size(0)
+        else:
+            micro_bsz = 1
         if type_ids is not None:
             type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
             self.batch_shift += 1
diff --git a/internlm/model/modeling_internlm.py b/internlm/model/modeling_internlm.py
index 2e8181c..11340a4 100644
--- a/internlm/model/modeling_internlm.py
+++ b/internlm/model/modeling_internlm.py
@@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         device (Optional[Union[str, torch.device]]): The device will be used.
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
+        use_flash_attn (bool): Whether use flash-attn. True by default.
     """
 
     def __init__(
@@ -68,12 +69,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         dropout_selective_checkpoint: bool = True,
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
+        use_flash_attn: bool = True,
     ):
         super().__init__()
         self.checkpoint = checkpoint
         # dropout selective checkpoint can only be enabled when checkpoint is disabled.
         self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
         self.layer_idx = layer_idx
+        self.use_flash_attn = use_flash_attn
 
         head_dim = hidden_size // num_attention_heads
         self.mixer = MHA(
@@ -86,7 +89,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             layer_idx=layer_idx,
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
-            use_flash_attn=True,
+            use_flash_attn=use_flash_attn,
             sequence_parallel=False,
             device=device,
             dtype=dtype,
@@ -244,6 +247,7 @@ class PackedFlashInternLm1D(nn.Module):
         device (Optional[Union[str, torch.device]]): The device will be used. None by default.
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.
     """
 
@@ -273,9 +277,11 @@ class PackedFlashInternLm1D(nn.Module):
         dropout_selective_checkpoint: bool = True,
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
+        use_flash_attn: bool = True,
     ):
         super().__init__()
 
+        self.use_flash_attn = use_flash_attn
         if checkpoint_fraction <= 0:
             checkpoint = False
         if not checkpoint:
@@ -322,6 +328,7 @@ class PackedFlashInternLm1D(nn.Module):
                     dropout_selective_checkpoint=dropout_selective_checkpoint,
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
+                    use_flash_attn=use_flash_attn,
                 )
                 for lid in range(num_layers)
             ]
@@ -358,7 +365,7 @@ class PackedFlashInternLm1D(nn.Module):
         if isinstance(cu_seqlens, list):
             assert len(cu_seqlens) == 1
             cu_seqlens = cu_seqlens[0].to(hidden_states.device)
-
+
         if cu_seqlens is not None:
             cu_seqlens = cu_seqlens.squeeze(0)
             hidden_states = hidden_states.squeeze(0)  # If cu_seqlens is passed in,it indicated a packed state,
@@ -456,6 +463,7 @@ def build_model_with_cfg(
     dropout_selective_checkpoint=True,
     use_scaled_init: bool = True,
    use_swiglu: bool = True,
+    use_flash_attn: bool = True,
 ):
     """
     Builde model with config
@@ -485,6 +493,7 @@ def build_model_with_cfg(
         dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.
     """
 
@@ -507,6 +516,7 @@ def build_model_with_cfg(
         dropout_selective_checkpoint=dropout_selective_checkpoint,
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
+        use_flash_attn=use_flash_attn,
     )
 
     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
diff --git a/internlm/model/multi_head_attention.py b/internlm/model/multi_head_attention.py
index 88ce759..7513563 100644
--- a/internlm/model/multi_head_attention.py
+++ b/internlm/model/multi_head_attention.py
@@ -43,6 +43,7 @@ class MHA(nn.Module):
                              of x will be done before doing the matmul.
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.
 
     """
 
@@ -57,7 +58,7 @@ class MHA(nn.Module):
         layer_idx: int = None,
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
-        use_flash_attn: bool = False,
+        use_flash_attn: bool = True,
         sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,
diff --git a/train.py b/train.py
index 17db3f3..bca8b54 100644
--- a/train.py
+++ b/train.py
@@ -25,7 +25,7 @@ from internlm.data.packed_dataset import (
     PackedDatasetWithoutCuSeqlen,
     get_packed_dataset_without_short_length,
 )
-from internlm.data.utils import DATASET_TYPE_IDS_MAP
+from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -471,6 +471,10 @@ def main(args):
             # zero the grads of parameters
             trainer.zero_grad()
             type_ids = batch[0].pop("type_ids", None)
+            # process data
+            # if use_flash_attn is False, we need to unpack type_ids
+            if not gpc.config.model.use_flash_attn:
+                type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
            if type_ids is not None:
                 metric.set_current_type_ids(type_ids=type_ids)