mirror of https://github.com/hpcaitech/ColossalAI
remove vocab_size args
parent
b12e9a3275
commit
dba59354d7
|
@ -15,15 +15,7 @@ class DistCrossEntropy(Function):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def forward(
|
def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
|
||||||
ctx,
|
|
||||||
vocab_logits: torch.Tensor,
|
|
||||||
target: torch.Tensor,
|
|
||||||
ignore_index: int,
|
|
||||||
process_group: ProcessGroup,
|
|
||||||
vocab_size: int,
|
|
||||||
dtype=torch.float32,
|
|
||||||
):
|
|
||||||
r"""
|
r"""
|
||||||
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
Calculate the cross entropy loss before gather, the origin loss function is as follows:
|
||||||
loss = -log(exp(x[class])/sum(exp(x[i]))
|
loss = -log(exp(x[class])/sum(exp(x[i]))
|
||||||
|
@ -35,7 +27,7 @@ class DistCrossEntropy(Function):
|
||||||
Args:
|
Args:
|
||||||
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
|
||||||
[batch_size, seq_len, vocab_size]
|
[batch_size, seq_len, vocab_size]
|
||||||
target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
|
labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
|
||||||
[batch_size, seq_len]
|
[batch_size, seq_len]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
|
@ -49,21 +41,15 @@ class DistCrossEntropy(Function):
|
||||||
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
|
||||||
|
|
||||||
# mask the target in the local device
|
# mask the target in the local device
|
||||||
|
partition_vocab_size = vocab_logits.size()[-1]
|
||||||
rank = dist.get_rank(group=process_group)
|
rank = dist.get_rank(group=process_group)
|
||||||
world_size = dist.get_world_size(group=process_group)
|
world_size = dist.get_world_size(group=process_group)
|
||||||
if vocab_size == None:
|
global_vocab_size = partition_vocab_size * world_size
|
||||||
partition_vocab_size = vocab_logits.size()[-1]
|
|
||||||
global_vocab_size = partition_vocab_size * world_size
|
|
||||||
else:
|
|
||||||
global_vocab_size = vocab_size
|
|
||||||
partition_vocab_size = global_vocab_size // world_size
|
|
||||||
|
|
||||||
# [down, up) => false, other device and -100 => true
|
# [down, up) => false, other device and -100 => true
|
||||||
delta = (global_vocab_size + world_size - 1) // world_size
|
delta = (global_vocab_size + world_size - 1) // world_size
|
||||||
down_threshold = rank * delta
|
down_threshold = rank * delta
|
||||||
up_threshold = down_threshold + delta
|
up_threshold = down_threshold + delta
|
||||||
if up_threshold > global_vocab_size:
|
|
||||||
up_threshold = global_vocab_size
|
|
||||||
mask = (target < down_threshold) | (target >= up_threshold)
|
mask = (target < down_threshold) | (target >= up_threshold)
|
||||||
masked_target = target.clone() - down_threshold
|
masked_target = target.clone() - down_threshold
|
||||||
masked_target[mask] = 0
|
masked_target[mask] = 0
|
||||||
|
@ -71,8 +57,7 @@ class DistCrossEntropy(Function):
|
||||||
# reshape the logits and target
|
# reshape the logits and target
|
||||||
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
|
# reshape the vocab_logits to [bath_size * seq_len, vocab_size]
|
||||||
# reshape the labels to [bath_size * seq_len]
|
# reshape the labels to [bath_size * seq_len]
|
||||||
self_vocab_size = vocab_logits.size()[-1]
|
logits_2d = vocab_logits.view(-1, partition_vocab_size)
|
||||||
logits_2d = vocab_logits.view(-1, self_vocab_size)
|
|
||||||
masked_target_1d = masked_target.view(-1)
|
masked_target_1d = masked_target.view(-1)
|
||||||
|
|
||||||
# extract the x[class] and set the x[other device] to zero
|
# extract the x[class] and set the x[other device] to zero
|
||||||
|
@ -87,7 +72,7 @@ class DistCrossEntropy(Function):
|
||||||
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
|
dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
|
||||||
exp_logits = vocab_logits
|
exp_logits = vocab_logits
|
||||||
torch.exp(vocab_logits, out=exp_logits)
|
torch.exp(vocab_logits, out=exp_logits)
|
||||||
sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
|
sum_exp_logits = torch.sum(exp_logits, dim=-1)
|
||||||
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
|
dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
|
||||||
|
|
||||||
# calculate the loss
|
# calculate the loss
|
||||||
|
@ -98,10 +83,9 @@ class DistCrossEntropy(Function):
|
||||||
loss = torch.sum(loss).div_(num_non_zero)
|
loss = torch.sum(loss).div_(num_non_zero)
|
||||||
|
|
||||||
# calculate the softmax
|
# calculate the softmax
|
||||||
exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
|
exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
|
||||||
exp_logits[target == ignore_index] = 0.0
|
exp_logits[target == ignore_index] = 0.0
|
||||||
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
ctx.save_for_backward(exp_logits, mask, masked_target_1d)
|
||||||
ctx.dtype = dtype
|
|
||||||
|
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
|
@ -116,19 +100,14 @@ class DistCrossEntropy(Function):
|
||||||
partion_vocab_size = grad_logits.shape[-1]
|
partion_vocab_size = grad_logits.shape[-1]
|
||||||
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
|
grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
|
||||||
|
|
||||||
update = 1.0 - mask.view(-1).float().to(ctx.dtype)
|
update = 1.0 - mask.view(-1).float()
|
||||||
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
|
||||||
|
|
||||||
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
grad_logits.mul_(grad_output.unsqueeze(dim=-1))
|
||||||
return grad_logits, None, None, None, None, None
|
return grad_logits, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def cross_entropy_1d(
|
def cross_entropy_1d(
|
||||||
vocab_logits: torch.Tensor,
|
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
|
||||||
labels: torch.Tensor,
|
|
||||||
ignore_index: int = -100,
|
|
||||||
process_group: ProcessGroup = None,
|
|
||||||
vocab_size: int = None,
|
|
||||||
dtype: torch.dtype = None,
|
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
|
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
|
||||||
|
|
|
@ -320,8 +320,6 @@ class LlamaPipelineForwards:
|
||||||
shift_logits,
|
shift_logits,
|
||||||
shift_labels,
|
shift_labels,
|
||||||
process_group=shard_config.tensor_parallel_process_group,
|
process_group=shard_config.tensor_parallel_process_group,
|
||||||
vocab_size=self.lm_head.out_features,
|
|
||||||
dtype=self.model.dtype,
|
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
shift_logits = shift_logits.view(-1, self.config.vocab_size)
|
||||||
|
|
Loading…
Reference in New Issue