mirror of https://github.com/InternLM/InternLM
feat(*): support not-flash-attn for pp and no-pp (#145)
* support not-flash-attention for no-pp
* support pipeline
* modify the config
* refactor the code
* remove some unnecessary code
parent 8b1717a05d
commit 5ee651c2f1
@@ -110,6 +110,7 @@ model = dict(
     dtype="torch.bfloat16",
     norm_type="rmsnorm",
     layer_norm_epsilon=1e-5,
+    use_flash_attn=True,
 )
 """
 zero1 parallel:
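For orientation: the flag lands in the model section of the training config, and the rest of this commit branches on `gpc.config.model.use_flash_attn`. A minimal sketch of a config exercising the new fallback (all values besides `use_flash_attn` are illustrative, not taken from this commit):

```python
# Hypothetical training-config snippet; only use_flash_attn is the new field.
model = dict(
    dtype="torch.bfloat16",
    norm_type="rmsnorm",
    layer_norm_epsilon=1e-5,
    use_flash_attn=False,  # False routes training through the new non-flash path
)
```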
@@ -3,7 +3,6 @@
 # adopted from https://github.com/hpcaitech/ColossalAI/blob/main/colossalai/engine

-import inspect
 from typing import Any, Callable, Iterable

 import torch
@@ -36,15 +35,6 @@ class NonPipelineScheduler(BaseScheduler):
     """

     def __init__(self, data_process_func: Callable = None, gradient_accumulation_size: int = 1):
-        # check that non-pipeline schedule data process func only takes in one parameter
-        # which is the batch data
-        if data_process_func:
-            sig = inspect.signature(data_process_func)
-            assert len(sig.parameters) == 1, (
-                "The data_process_func only takes in one parameter for NonPipelineSchedule, "
-                "which is a tuple of tensors for the current batch, "
-                "i.e. data_process_func(dataloader_output)."
-            )
-
         self._grad_accum_size = gradient_accumulation_size
         self._grad_accum_batch_size = 1  # static batch size for flash attetion.
@@ -73,6 +63,12 @@ class NonPipelineScheduler(BaseScheduler):
         )
         self._grad_accum_offset += self._grad_accum_batch_size

+        if self.data_process_func:
+            _data["input_ids"] = self.data_process_func(_data["input_ids"], _data["cu_seqlens"])
+            _label = self.data_process_func(_label, _data["cu_seqlens"])
+            _data.pop("cu_seqlens")
+            _data.pop("indexes")
+
         return _data, _label

     def _train_one_batch(
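The new branch assumes `data_process_func` maps packed tokens plus `cu_seqlens` to a padded batch, after which the pack-only keys are dropped so the model sees ordinary inputs. A hedged sketch of that flow with a stand-in function (`fake_unpack` and the tensors are illustrative; the real function is `unpack_data`, added later in this diff):

```python
import torch

def fake_unpack(packed, cu_seqlens):
    # Stand-in with the same call signature as the real unpack function;
    # the real version pads each sequence out to seq_len.
    return packed.view(1, -1)

_data = {
    "input_ids": torch.arange(6),           # two sequences packed together
    "cu_seqlens": torch.tensor([0, 4, 6]),  # sequence boundaries
    "indexes": torch.arange(6),             # per-token positions
}
_label = torch.arange(1, 7)

# Mirrors the scheduler's new non-flash branch:
_data["input_ids"] = fake_unpack(_data["input_ids"], _data["cu_seqlens"])
_label = fake_unpack(_label, _data["cu_seqlens"])
_data.pop("cu_seqlens")
_data.pop("indexes")
print(sorted(_data.keys()))  # ['input_ids'] -- only model inputs remain
```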
@@ -152,11 +148,6 @@ class NonPipelineScheduler(BaseScheduler):
             batch_size == self._grad_accum_size
         ), f"batch_size:{batch_size} must be equal to gradient accumulation steps:{self._grad_accum_size}"

-        if self.data_process_func:
-            data, label = self.data_process_func(batch_data)
-        else:
-            # if not batch data process func is given,
-            # then we regard the batch data as a simple tuple of (data, label)
-            data, label = batch_data
+        data, label = batch_data

         loss = 0 if return_loss else None
@@ -30,10 +30,16 @@ def get_tensor_shape():
         return None

     if hasattr(gpc.config, "SEQ_LEN") and hasattr(gpc.config.data, "micro_bsz") and hasattr(gpc.config, "HIDDEN_SIZE"):
-        tensor_shape = (
-            gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
-            gpc.config.HIDDEN_SIZE,
-        )
+        if gpc.config.model.use_flash_attn:
+            tensor_shape = (
+                gpc.config.SEQ_LEN * gpc.config.data["micro_bsz"],
+                gpc.config.HIDDEN_SIZE,
+            )
+        else:
+            tensor_shape = (
+                gpc.config.data["micro_bsz"], gpc.config.SEQ_LEN,
+                gpc.config.HIDDEN_SIZE,
+            )
         return tensor_shape
     else:
         return None
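The two branches describe the same number of elements per pipeline send/recv, just laid out differently: flash-attn works on one packed token dimension, while the fallback keeps batch and sequence dimensions separate. A quick check of the arithmetic under assumed config values (SEQ_LEN=2048, micro_bsz=2, HIDDEN_SIZE=4096 are illustrative):

```python
SEQ_LEN, micro_bsz, HIDDEN_SIZE = 2048, 2, 4096  # illustrative config values

flash_shape = (SEQ_LEN * micro_bsz, HIDDEN_SIZE)    # packed: (tokens, hidden)
no_flash_shape = (micro_bsz, SEQ_LEN, HIDDEN_SIZE)  # padded: (batch, seq, hidden)

print(flash_shape)     # (4096, 4096)
print(no_flash_shape)  # (2, 2048, 4096)

# Same element count either way, so pipeline buffers stay the same size:
assert flash_shape[0] * flash_shape[1] == no_flash_shape[0] * no_flash_shape[1] * no_flash_shape[2]
```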
@@ -122,14 +128,21 @@ class PipelineScheduler(BaseScheduler):
         self.microbatch_size = self.batch_size // self.num_microbatches

     def load_micro_batch(self):
-        mciro_batch_data, micro_batch_label = self._load_micro_batch(
+        micro_batch_data, micro_batch_label = self._load_micro_batch(
             data=self.batch_data, label=self.batch_label, offset=self.microbatch_offset, micro_bsz=self.microbatch_size
         )
         self.microbatch_offset += self.microbatch_size

-        # unpack data process
-        # TODO by xyt
-        return move_to_device(mciro_batch_data), move_to_device(micro_batch_label)
+        if self.data_process_func:
+            micro_batch_data["input_ids"] = self.data_process_func(
+                micro_batch_data["input_ids"], micro_batch_data["cu_seqlens"]
+            )
+            micro_batch_label = self.data_process_func(micro_batch_label, micro_batch_data["cu_seqlens"])
+
+            micro_batch_data.pop("cu_seqlens")
+            micro_batch_data.pop("indexes")
+
+        return move_to_device(micro_batch_data), move_to_device(micro_batch_label)

     def pre_processing(self, engine):
         model = engine.model
@@ -204,11 +217,12 @@ class PipelineScheduler(BaseScheduler):
         pipeline stage.
         """
         micro_batch_data, micro_batch_label = self.load_micro_batch()

         data, label = self._get_data_label_for_current_step(input_obj, micro_batch_data, micro_batch_label)

         timer("fwd").start()
         output_obj = self._call_engine(engine.model, data)
         timer("fwd").stop()

         if gpc.is_last_rank(ParallelMode.PIPELINE):
             timer("post_fn").start()
             post_func = kwargs.get("post_fn")
@@ -295,6 +309,7 @@ class PipelineScheduler(BaseScheduler):
         assert (
             forward_only or return_loss
         ), "The argument 'return_loss' has to be True when 'forward_only' is False, but got False."

         self.load_batch(engine, data_iter)
         num_warmup_microbatches = (
             gpc.get_world_size(ParallelMode.PIPELINE) - gpc.get_local_rank(ParallelMode.PIPELINE) - 1
@@ -144,6 +144,48 @@ class PackedDataset(torch.utils.data.Dataset):
         out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
         return out

+    def cal_pos_unpack(self, index):
+        if index == 0:
+            pre_pos = 0
+        else:
+            pre_pos = index * gpc.config.data["micro_bsz"]
+
+        pos = (index + 1) * gpc.config.data["micro_bsz"]
+        return pre_pos, pos
+
+    def build_unpack(self, index):
+
+        pre_pos, pos = self.cal_pos_unpack(index)
+
+        pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], []
+
+        while pre_pos < pos and pre_pos < len(self.dataset):
+            sample_idx = self.sample_indices[pre_pos]
+            sample = self.dataset[sample_idx]
+            length = min(len(sample["tokens"]), self.max_length_per_sample)
+            chunk = sample["tokens"][0:length]
+            pack.extend(chunk)
+            _labels = deepcopy(chunk)
+            _labels = list(_labels[1:]) + [-100]
+            assert len(_labels) == len(chunk), (_labels, chunk)
+            labels.extend(_labels)
+            type_ids.extend([sample.get("type_id", 0)] * len(chunk))
+            cu_seqlens.append(cu_seqlens[-1] + len(chunk))
+            indexes.extend(list(range(length)))
+            pre_pos = pre_pos + 1
+
+        if cu_seqlens[-1] != self.packed_length:
+            pack = pack + [0] * (self.packed_length - cu_seqlens[-1])
+            labels = labels + [0] * (self.packed_length - cu_seqlens[-1])
+            type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1])
+            indexes.extend(list(range(self.packed_length - cu_seqlens[-1])))
+            cu_seqlens.append(self.packed_length)
+
+        assert len(pack) == self.packed_length
+
+        out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids}
+        return out
+
     def __getitem__(self, item: int) -> Dict:
         """Given the index, it returns a dict as
         {
@@ -154,9 +196,12 @@ class PackedDataset(torch.utils.data.Dataset):
         }
         """
-        pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
-        return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
+
+        if gpc.config.model.use_flash_attn:
+            pos_before, token_id_before, pos_after, token_id_after = self.mapping(item)
+            return self.build_pack(pos_before, token_id_before, pos_after, token_id_after)
+
+        return self.build_unpack(item)


 class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset):
     """
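Taken together: `cal_pos_unpack` maps each `__getitem__` index to a window of `micro_bsz` consecutive samples, and `build_unpack` assembles exactly those samples (labels are the tokens shifted left by one, with -100 closing each sequence), so the non-flash path still yields `cu_seqlens`/`indexes` for the scheduler to strip later. A hedged trace of the index arithmetic (micro_bsz=2 is illustrative):

```python
def cal_pos_unpack(index, micro_bsz=2):  # micro_bsz value is illustrative
    # Same arithmetic as the method above, without the gpc config lookup.
    pre_pos = 0 if index == 0 else index * micro_bsz
    pos = (index + 1) * micro_bsz
    return pre_pos, pos

# Each dataset index covers micro_bsz consecutive underlying samples:
for idx in range(3):
    print(idx, cal_pos_unpack(idx))
# 0 (0, 2) -> samples 0..1
# 1 (2, 4) -> samples 2..3
# 2 (4, 6) -> samples 4..5
```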
@@ -1,6 +1,10 @@
 #!/usr/bin/env python
 # -*- encoding: utf-8 -*-

+import torch
+
+from internlm.core.context import global_context as gpc
+
 DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2, "ja": 3, "ar": 4, "kaoshi": 5}
@@ -13,3 +17,29 @@ def get_dataset_type_id(path):
         match_idxes.append(idx)
     assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}"
     return match_idxes[0]
+
+
+def unpack_data(input_ids, cu_seqlens):
+    """
+    input_ids: (n, packed_length)
+    Return:
+    output: (batch_size, max_length)
+    """
+
+    bsz = input_ids.shape[0]
+
+    num_sequence = gpc.config.data["micro_bsz"]
+
+    outputs = torch.zeros(bsz, num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
+
+    for i in range(bsz):
+        output = torch.zeros(num_sequence, gpc.config.data.seq_len, device=input_ids.device, dtype=input_ids.dtype)
+        cu_seqlens_slice = cu_seqlens[i]
+        for j in range(num_sequence):
+            seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j]
+            output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]]
+        outputs[i] = output
+
+    if bsz == 1:
+        outputs = outputs.squeeze(0)
+
+    return outputs
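A worked toy example of `unpack_data`'s core loop, rewritten to run without `gpc` (seq_len=4 and micro_bsz=2 stand in for the config values). Note the committed code indexes `input_ids[0, ...]` inside the loop rather than `input_ids[i, ...]`, which only matters when n > 1; this sketch uses row i:

```python
import torch

seq_len, micro_bsz = 4, 2                       # illustrative config values
input_ids = torch.tensor([[1, 2, 3, 4, 5, 6]])  # n=1, packed_length=6
cu_seqlens = torch.tensor([[0, 4, 6]])          # two sequences: lengths 4 and 2

bsz = input_ids.shape[0]
outputs = torch.zeros(bsz, micro_bsz, seq_len, dtype=input_ids.dtype)
for i in range(bsz):
    boundaries = cu_seqlens[i]
    for j in range(micro_bsz):
        length = boundaries[j + 1] - boundaries[j]
        outputs[i, j, :length] = input_ids[i, boundaries[j] : boundaries[j + 1]]

print(outputs.squeeze(0))  # bsz == 1, so the batch dim is squeezed as above
# tensor([[1, 2, 3, 4],
#         [5, 6, 0, 0]])
```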
@@ -22,6 +22,7 @@ from internlm.core.scheduler.pipeline_scheduler import (
     get_tensor_shape,
 )
 from internlm.core.trainer import Trainer
+from internlm.data.utils import unpack_data
 from internlm.solver.beta2_scheduler import Beta2Scheduler
 from internlm.solver.optimizer.hybrid_zero_optim import BaseOptimizer
 from internlm.utils.common import get_current_device
@@ -77,9 +78,17 @@ def initialize_trainer(

     # initialize scheduler for trainer
     scheduler = None
+    if gpc.config.model.use_flash_attn:
+        data_fn = None
+    else:
+        data_fn = unpack_data
     if gpc.is_using_pp():
         gpc.config.NUM_MICRO_BATCHES = gpc.config.data.micro_num
         tensor_shape = get_tensor_shape()
+        # if gpc.config.model.use_flash_attn:
+        #     tensor_shape = get_tensor_shape()
+        # else:
+        #     tensor_shape = None
         use_interleaved = (
             hasattr(gpc.config, "model") and hasattr(gpc.config.model, "num_chunks") and gpc.config.model.num_chunks > 1
         )
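This wiring is the crux of the commit: `data_fn` stays `None` when flash-attn can consume packed sequences directly, and becomes `unpack_data` otherwise; the same function is then handed to both scheduler types below. A condensed sketch of the selection logic (the `config` object is a stand-in for `gpc.config`):

```python
from types import SimpleNamespace

def pick_data_fn(config, unpack_fn):
    # flash-attn consumes packed sequences natively; otherwise unpack first
    return None if config.model.use_flash_attn else unpack_fn

config = SimpleNamespace(model=SimpleNamespace(use_flash_attn=False))
data_fn = pick_data_fn(config, unpack_fn=lambda ids, cu_seqlens: ids)
print(data_fn is not None)  # True -- the non-flash path gets an unpack step
```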
@@ -96,13 +105,14 @@ def initialize_trainer(
             )
         else:
             scheduler = PipelineScheduler(
+                data_process_func=data_fn,
                 num_microbatches=gpc.config.NUM_MICRO_BATCHES,
                 dtype=gpc.config.model["dtype"],
                 tensor_shape=tensor_shape,
                 scatter_gather_tensors=scatter_gather,
             )
     else:
-        scheduler = NonPipelineScheduler(gradient_accumulation_size=gpc.config.data.gradient_accumulation)
+        scheduler = NonPipelineScheduler(data_process_func=data_fn, gradient_accumulation_size=gpc.config.data.gradient_accumulation)

     # initialize engine for trainer
     engine = Engine(
@@ -51,7 +51,10 @@ class AccPerplex:
         return self.update(logits, labels, type_ids=self.type_ids)

     def update(self, logits, labels, type_ids=None):
-        micro_bsz = labels.size(0)
+        if gpc.config.model.use_flash_attn:
+            micro_bsz = labels.size(0)
+        else:
+            micro_bsz = 1
         if type_ids is not None:
             type_ids = type_ids[self.batch_shift * micro_bsz : (self.batch_shift + 1) * micro_bsz].view(-1)
             self.batch_shift += 1
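The `micro_bsz` switch keeps the metric's sliding window consistent: with flash-attn, labels arrive as `micro_bsz` packed rows per step; without it, each step contributes a single row, so the window width drops to 1. A toy trace of the slicing (shapes are illustrative):

```python
import torch

type_ids = torch.arange(8).view(4, 2)  # 4 rows of per-token type ids

def window(batch_shift, micro_bsz):
    # Same slicing as AccPerplex.update
    return type_ids[batch_shift * micro_bsz : (batch_shift + 1) * micro_bsz].view(-1)

print(window(0, 2))  # rows 0..1 -> tensor([0, 1, 2, 3])
print(window(1, 2))  # rows 2..3 -> tensor([4, 5, 6, 7])
print(window(1, 1))  # micro_bsz=1: one row per step -> tensor([2, 3])
```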
@@ -49,6 +49,7 @@ class PackedFlashBaseLayer1D(nn.Module):
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         device (Optional[Union[str, torch.device]]): The device will be used.
         norm_type (str): Use RMS norm or layernorm."rmsnorm" by default.
+        use_flash_attn (bool): Whether use flash-attn. True by default.
     """

     def __init__(
@@ -68,12 +69,14 @@ class PackedFlashBaseLayer1D(nn.Module):
         dropout_selective_checkpoint: bool = True,
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
+        use_flash_attn: bool = True,
     ):
         super().__init__()
         self.checkpoint = checkpoint
         # dropout selective checkpoint can only be enabled when checkpoint is disabled.
         self.dropout_selective_checkpoint = dropout_selective_checkpoint is True and checkpoint is False
         self.layer_idx = layer_idx
+        self.use_flash_attn = use_flash_attn

         head_dim = hidden_size // num_attention_heads
         self.mixer = MHA(
@@ -86,7 +89,7 @@ class PackedFlashBaseLayer1D(nn.Module):
             layer_idx=layer_idx,
             rotary_emb_dim=head_dim,
             rotary_emb_scale_base=0,
-            use_flash_attn=True,
+            use_flash_attn=use_flash_attn,
             sequence_parallel=False,
             device=device,
             dtype=dtype,
@@ -244,6 +247,7 @@ class PackedFlashInternLm1D(nn.Module):
         device (Optional[Union[str, torch.device]]): The device will be used. None by default.
         residual_in_fp32 (bool): Whether to use residual in fp32. False by default.
         norm_type (str): Normalization type. Use RMSNorm or LayerNorm. "rmsnorm" by default.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.

     """
@@ -273,9 +277,11 @@ class PackedFlashInternLm1D(nn.Module):
         dropout_selective_checkpoint: bool = True,
         use_scaled_init: bool = True,
         use_swiglu: bool = True,
+        use_flash_attn: bool = True,
     ):
         super().__init__()

+        self.use_flash_attn = use_flash_attn
         if checkpoint_fraction <= 0:
             checkpoint = False
         if not checkpoint:
@@ -322,6 +328,7 @@ class PackedFlashInternLm1D(nn.Module):
                     dropout_selective_checkpoint=dropout_selective_checkpoint,
                     use_scaled_init=use_scaled_init,
                     use_swiglu=use_swiglu,
+                    use_flash_attn=use_flash_attn,
                 )
                 for lid in range(num_layers)
             ]
@@ -456,6 +463,7 @@ def build_model_with_cfg(
     dropout_selective_checkpoint=True,
     use_scaled_init: bool = True,
     use_swiglu: bool = True,
+    use_flash_attn: bool = True,
 ):
     """
     Builde model with config
@@ -485,6 +493,7 @@ def build_model_with_cfg(
         dropout_selective_checkpoint (bool): It can only be enabled when checkpoint is disabled. True by default.
         use_scaled_init (bool): Whether to use scaled init. True by default.
         use_swiglu (bool): Whether to use swiglu. True by default.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.

     """
@@ -507,6 +516,7 @@ def build_model_with_cfg(
         dropout_selective_checkpoint=dropout_selective_checkpoint,
         use_scaled_init=use_scaled_init,
         use_swiglu=use_swiglu,
+        use_flash_attn=use_flash_attn,
     )

     return _build_generic_model_1d(num_layers=num_layers, num_chunks=num_chunks, **cfg)
@@ -43,6 +43,7 @@ class MHA(nn.Module):
                     of x will be done before doing the matmul.
         device (Optional[Union[str, torch.device]]): The device will be used.
         dtype (Optional[torch.dtype]): The type of data.
+        use_flash_attn (bool): Whether to use flash-attn. True by default.

     """
@@ -57,7 +58,7 @@ class MHA(nn.Module):
         layer_idx: int = None,
         rotary_emb_dim: int = 0,
         rotary_emb_scale_base: int = 0,
-        use_flash_attn: bool = False,
+        use_flash_attn: bool = True,
         sequence_parallel: bool = True,
         device: Optional[torch.device] = None,
         dtype: Optional[torch.dtype] = None,

train.py
@@ -25,7 +25,7 @@ from internlm.data.packed_dataset import (
     PackedDatasetWithoutCuSeqlen,
     get_packed_dataset_without_short_length,
 )
-from internlm.data.utils import DATASET_TYPE_IDS_MAP
+from internlm.data.utils import DATASET_TYPE_IDS_MAP, unpack_data
 from internlm.model.loss import FlashGPTLMLoss
 from internlm.model.metrics import AccPerplex
 from internlm.solver.beta2_scheduler import Beta2Scheduler
@@ -471,6 +471,10 @@ def main(args):
             # zero the grads of parameters
             trainer.zero_grad()
             type_ids = batch[0].pop("type_ids", None)
+            # process data
+            # if use_flash_attn is False, we need to unpack type_ids
+            if not gpc.config.model.use_flash_attn:
+                type_ids = unpack_data(type_ids, batch[0]["cu_seqlens"])
             if type_ids is not None:
                 metric.set_current_type_ids(type_ids=type_ids)