pull/407/head
yingtongxiong 2023-10-08 13:20:29 +08:00
parent e5a2909af0
commit bf475b6940
5 changed files with 21 additions and 8 deletions

@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
@@ -55,7 +55,7 @@ data = dict(
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
-    valid_every=50,
+    valid_every=1000,
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",

@@ -202,10 +202,10 @@ class NonPipelineScheduler(BaseScheduler):
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)

         if not return_output_label:
             outputs, labels = None, None
         # Compatible for non-moe
         if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss

@@ -28,9 +28,20 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda

 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
+# reduce_scatter_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce

+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), op=op, group=process_group, async_op=async_op
+    )
+    return output, handle

 class ScaleColumnParallelLinear(nn.Linear):
     """
@@ -279,15 +290,15 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
         # do all-gather for weight before backward
-        weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
         handle_weight.wait()
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
-                grad_input = F.linear(grad_output, weight.t())
+                grad_input = F.linear(grad_output, total_weight.t())
             else:
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                                         grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
         # if process_group is not None:
         #     import pdb; pdb.set_trace()
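
This hunk binds the gathered result to total_weight instead of rebinding weight, so the rank-local shard keeps its original name while the full matrix feeds the gradient matmuls. For context, a sketch of how all_gather_raw and the new reduce_scatter_raw helper typically pair up in an FSDP-style linear backward; the weight-gradient path is not shown in this hunk, so everything beyond those two helpers is an illustrative assumption:

    # Sketch only: backward of y = x @ weight.t(), with weight sharded along dim 0 across process_group.
    total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
    handle_weight.wait()                                   # need the full weight before the matmuls
    grad_input = grad_output @ total_weight                # same result as F.linear(grad_output, total_weight.t())
    grad_weight_full = grad_output.t() @ x                 # every rank holds the full [out, in] gradient
    grad_weight, handle_grad = reduce_scatter_raw(grad_weight_full, process_group, async_op=True)
    handle_grad.wait()                                     # keep only this rank's shard, summed across ranks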

@@ -372,6 +372,7 @@ class PackedFlashInternLm1D(nn.Module):
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
+        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:

@@ -254,6 +254,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
+            moe_loss=None,
             grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,