Browse Source

remove vocab_size args

pull/5844/head
wangbluo 5 months ago
parent
commit
dba59354d7
  1. 41
      colossalai/shardformer/layer/loss.py
  2. 2
      colossalai/shardformer/modeling/llama.py

41
colossalai/shardformer/layer/loss.py

@ -15,15 +15,7 @@ class DistCrossEntropy(Function):
"""
@staticmethod
def forward(
ctx,
vocab_logits: torch.Tensor,
target: torch.Tensor,
ignore_index: int,
process_group: ProcessGroup,
vocab_size: int,
dtype=torch.float32,
):
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
r"""
Calculate the cross entropy loss before gather, the origin loss function is as follows:
loss = -log(exp(x[class])/sum(exp(x[i]))
@ -35,7 +27,7 @@ class DistCrossEntropy(Function):
Args:
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
[batch_size, seq_len, vocab_size]
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
[batch_size, seq_len]
Returns:
@ -49,21 +41,15 @@ class DistCrossEntropy(Function):
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
# mask the target in the local device
partition_vocab_size = vocab_logits.size()[-1]
rank = dist.get_rank(group=process_group)
world_size = dist.get_world_size(group=process_group)
if vocab_size == None:
partition_vocab_size = vocab_logits.size()[-1]
global_vocab_size = partition_vocab_size * world_size
else:
global_vocab_size = vocab_size
partition_vocab_size = global_vocab_size // world_size
# [down, up) => false, other device and -100 => true
delta = (global_vocab_size + world_size - 1) // world_size
down_threshold = rank * delta
up_threshold = down_threshold + delta
if up_threshold > global_vocab_size:
up_threshold = global_vocab_size
mask = (target < down_threshold) | (target >= up_threshold)
masked_target = target.clone() - down_threshold
masked_target[mask] = 0
@ -71,8 +57,7 @@ class DistCrossEntropy(Function):
# reshape the logits and target
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
# reshape the labels to [bath_size * seq_len]
self_vocab_size = vocab_logits.size()[-1]
logits_2d = vocab_logits.view(-1, self_vocab_size)
logits_2d = vocab_logits.view(-1, partition_vocab_size)
masked_target_1d = masked_target.view(-1)
# extract the x[class] and set the x[other device] to zero
@ -87,7 +72,7 @@ class DistCrossEntropy(Function):
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
exp_logits = vocab_logits
torch.exp(vocab_logits, out=exp_logits)
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
sum_exp_logits = torch.sum(exp_logits, dim=-1)
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
# calculate the loss
@ -98,10 +83,9 @@ class DistCrossEntropy(Function):
loss = torch.sum(loss).div_(num_non_zero)
# calculate the softmax
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
exp_logits[target == ignore_index] = 0.0
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
ctx.dtype = dtype
return loss
@ -116,19 +100,14 @@ class DistCrossEntropy(Function):
partion_vocab_size = grad_logits.shape[-1]
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
update = 1.0 - mask.view(-1).float()
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
return grad_logits, None, None, None, None, None
return grad_logits, None, None, None
def cross_entropy_1d(
vocab_logits: torch.Tensor,
labels: torch.Tensor,
ignore_index: int = -100,
process_group: ProcessGroup = None,
vocab_size: int = None,
dtype: torch.dtype = None,
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)

2
colossalai/shardformer/modeling/llama.py

@ -320,8 +320,6 @@ class LlamaPipelineForwards:
shift_logits,
shift_labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=self.lm_head.out_features,
dtype=self.model.dtype,
)
else:
shift_logits = shift_logits.view(-1, self.config.vocab_size)

Loading…
Cancel
Save