mirror of https://github.com/InternLM/InternLM

commit bf475b6940 ("debug")
parent e5a2909af0
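The "debug" commit touches five areas: the model config (NUM_LAYER cut from 32 to 4 and the evaluation interval raised from 50 to 1000 steps), context in NonPipelineScheduler around the MoE-compatible return, the fused dense module (a local reduce_scatter_raw replacing the flash_attn import, plus a weight → total_weight rename in the FSDP backward), a temporary pdb breakpoint in PackedFlashInternLm1D.forward, and train.py (an explicit moe_loss=None added to a call inside main).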
@@ -5,7 +5,7 @@ SEQ_LEN = 2048
 HIDDEN_SIZE = 4096
 NUM_ATTENTION_HEAD = 32
 MLP_RATIO = 8 / 3
-NUM_LAYER = 32
+NUM_LAYER = 4
 VOCAB_SIZE = 103168
 
 MODEL_ONLY_FOLDER = "local:llm_ckpts/xxxx"
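The only change here is NUM_LAYER: 32 → 4, which trades the full-size model for a much smaller one that builds and iterates quickly during debugging. A rough, back-of-the-envelope sketch of the effect on parameter count; it assumes a SwiGLU-style MLP and untied input/output embeddings and ignores norms, biases, and any rounding of the MLP width, so it is only indicative:

HIDDEN_SIZE = 4096
MLP_RATIO = 8 / 3
VOCAB_SIZE = 103168

def approx_params(num_layer):
    attn = 4 * HIDDEN_SIZE * HIDDEN_SIZE                  # Wq, Wk, Wv, Wo
    mlp = 3 * HIDDEN_SIZE * int(HIDDEN_SIZE * MLP_RATIO)  # gate, up, down projections
    embed = 2 * VOCAB_SIZE * HIDDEN_SIZE                  # token embedding + output head
    return num_layer * (attn + mlp) + embed

print(f"{approx_params(32) / 1e9:.1f}B")  # ~7.3B with the original 32 layers
print(f"{approx_params(4) / 1e9:.1f}B")   # ~1.7B with the debug value of 4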
@@ -55,7 +55,7 @@ data = dict(
     # defaults to the value of micro_num
     valid_micro_num=4,
     # defaults to 0, means disable evaluate
-    valid_every=50,
+    valid_every=1000,
     pack_sample_into_one=False,
     total_steps=50000,
     skip_batches="",
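Raising valid_every from 50 to 1000 makes evaluation run every 1000 training steps instead of every 50, so debug runs spend less time in validation; per the comment above it, a value of 0 would disable evaluation entirely. A tiny illustration of that gating (the loop and names below are hypothetical, not from the repo):

valid_every = 1000

for step in range(1, 5001):
    # evaluate only when the interval is positive and the step hits a multiple of it
    if valid_every > 0 and step % valid_every == 0:
        print(f"run evaluation at step {step}")  # fires at 1000, 2000, 3000, 4000, 5000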
@@ -202,10 +202,10 @@ class NonPipelineScheduler(BaseScheduler):
             if return_output_label:
                 outputs.append(_output)
                 labels.append(_label)
 
         if not return_output_label:
             outputs, labels = None, None
 
         # Compatible for non-moe
         if hasattr(gpc.config.model, "num_experts"):
             return outputs, labels, loss, moe_loss
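These context lines show how the scheduler stays compatible across MoE and dense configs: when gpc.config.model defines num_experts, the step returns a fourth moe_loss value, and presumably the dense path (outside this hunk) returns without it. A toy illustration of that gate, with simplified names that are not from the repo:

class _ModelCfg:
    pass

def step_return(outputs, labels, loss, moe_loss, model_cfg):
    if hasattr(model_cfg, "num_experts"):   # MoE-aware config: include moe_loss
        return outputs, labels, loss, moe_loss
    return outputs, labels, loss            # dense model: three values only

cfg = _ModelCfg()
print(len(step_return([], [], 0.0, 0.0, cfg)))  # 3
cfg.num_experts = 4
print(len(step_return([], [], 0.0, 0.0, cfg)))  # 4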
@@ -28,9 +28,20 @@ from torch.cuda.amp import custom_bwd, custom_fwd
 import fused_dense_lib as fused_dense_cuda
 
 from flash_attn.ops.activations import gelu_bwd, relu_bwd, sqrelu_fwd, sqrelu_bwd
-from flash_attn.utils.distributed import all_gather_raw, reduce_scatter_raw, all_reduce_raw
+from flash_attn.utils.distributed import all_gather_raw, all_reduce_raw
+# reduce_scatter_raw
 from flash_attn.utils.distributed import reduce_scatter, all_reduce
 
+def reduce_scatter_raw(input_: Tensor, process_group: ProcessGroup, async_op: bool = False, op=torch.distributed.ReduceOp.SUM):
+    world_size = torch.distributed.get_world_size(process_group)
+    assert input_.shape[0] % world_size == 0
+    output = torch.empty(
+        input_.shape[0] // world_size, *input_.shape[1:], dtype=input_.dtype, device=input_.device
+    )
+    handle = torch.distributed.reduce_scatter_tensor(
+        output, input_.contiguous(), op=op, group=process_group, async_op=async_op
+    )
+    return output, handle
 
 class ScaleColumnParallelLinear(nn.Linear):
     """
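The import of reduce_scatter_raw from flash_attn.utils.distributed is replaced by a local copy built on torch.distributed.reduce_scatter_tensor: each rank receives the summed 1/world_size slice of the input along dim 0, plus a handle for asynchronous completion. A usage sketch under stated assumptions (a process group is already initialized, the leading dimension is divisible by the world size, and only the async path is shown, since the handle is None when async_op=False); it calls the reduce_scatter_raw defined above:

import torch
import torch.distributed as dist

def reduce_scatter_example(grad_full: torch.Tensor, process_group: dist.ProcessGroup):
    # launch the reduce-scatter asynchronously ...
    grad_shard, handle = reduce_scatter_raw(grad_full, process_group, async_op=True)
    # ... unrelated work (e.g. another matmul) could overlap with communication here ...
    handle.wait()  # each rank now holds its reduced slice of shape [dim0 // world_size, ...]
    return grad_shard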
@@ -279,15 +290,15 @@ class FusedDenseFunc_fsdp(torch.autograd.Function):
         grad_output = grad_output.reshape(batch_dim, grad_output.shape[-1])
 
         # do all-gather for weight before backward
-        weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
+        total_weight, handle_weight = all_gather_raw(weight, process_group, async_op=True)
         handle_weight.wait()
 
         if ctx.needs_input_grad[0]:
             if not ctx.return_residual:
-                grad_input = F.linear(grad_output, weight.t())
+                grad_input = F.linear(grad_output, total_weight.t())
             else:
                 grad_input = torch.addmm(grad_input.reshape(batch_dim, grad_input.shape[-1]),
-                                         grad_output, weight)
+                                         grad_output, total_weight)
             grad_input = grad_input.reshape(*batch_shape, grad_input.shape[-1])
             # if process_group is not None:
             #     import pdb; pdb.set_trace()
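The rename from weight to total_weight keeps the all-gathered full weight in its own variable instead of shadowing weight, so the original (sharded) weight remains available later in backward, e.g. when its gradient is formed and scattered back to shards. A minimal, self-contained sketch of that gather-for-backward pattern, not the repo's exact code; it reuses flash_attn's all_gather_raw and the reduce_scatter_raw defined earlier in this file, and assumes the weight is sharded along its output dimension:

import torch
import torch.nn.functional as F
from flash_attn.utils.distributed import all_gather_raw

def fsdp_linear_backward(grad_output, x, weight_shard, process_group):
    # gather the full weight only for the duration of backward
    total_weight, handle_w = all_gather_raw(weight_shard, process_group, async_op=True)
    handle_w.wait()
    grad_input = F.linear(grad_output, total_weight.t())  # = grad_output @ total_weight
    # full weight gradient, then scatter it back so each rank keeps only its shard
    grad_total_weight = grad_output.t().mm(x)
    grad_weight_shard, handle_g = reduce_scatter_raw(grad_total_weight, process_group, async_op=True)
    handle_g.wait()
    return grad_input, grad_weight_shard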
@@ -372,6 +372,7 @@ class PackedFlashInternLm1D(nn.Module):
 
     def forward(self, hidden_states=None, cu_seqlens=None, input_ids=None, indexes=None, inference_params=None):
         # attention_mask: compute attention on the places where the value is 1
+        import pdb; pdb.set_trace()
         if hasattr(self, "embedding"):
             hidden_states = self.embedding(input_ids)
             if self.embed_grad_scale != 1:
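This adds a temporary breakpoint at the top of PackedFlashInternLm1D.forward. Under multi-rank training a bare pdb.set_trace() fires on every process; a common guard (not part of this commit, shown only as a sketch) is to break on a single rank, letting the other ranks block at their next collective:

import torch.distributed as dist

# break only on rank 0 so the remaining ranks are not stuck at their own prompts
if not dist.is_initialized() or dist.get_rank() == 0:
    import pdb; pdb.set_trace()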
train.py (+1)
@@ -254,6 +254,7 @@ def main(args):
             trainer=trainer,
             start_time=start_time,
             loss=loss,
+            moe_loss=None,
             grad_norm=grad_norm_groups,
             metric=metric,
             update_panel=uniscale_logger is not None,
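The call site's name falls outside this hunk, but the keyword arguments suggest the per-batch training-metrics logger in main; it gains an explicit moe_loss=None so the dense training path matches a signature that now also carries an MoE auxiliary loss, mirroring the num_experts handling in the scheduler hunk above. A hypothetical sketch of a callee that tolerates the None (names here are illustrative, not the repo's):

def record_metrics(loss, moe_loss=None, grad_norm=None, **extra):
    line = f"loss={loss:.4f}"
    if moe_loss is not None:          # dense models simply pass None
        line += f" moe_loss={moe_loss:.4f}"
    if grad_norm is not None:
        line += f" grad_norm={grad_norm}"
    print(line)

record_metrics(loss=2.3512, moe_loss=None, grad_norm=[1.08])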