diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 2af18d677..c4cf3fb85 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -15,15 +15,7 @@ class DistCrossEntropy(Function):
     """
 
     @staticmethod
-    def forward(
-        ctx,
-        vocab_logits: torch.Tensor,
-        target: torch.Tensor,
-        ignore_index: int,
-        process_group: ProcessGroup,
-        vocab_size: int,
-        dtype=torch.float32,
-    ):
+    def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: int, process_group: ProcessGroup):
         r"""
         Calculate the cross entropy loss before gather, the origin loss function is as follows:
         loss = -log(exp(x[class])/sum(exp(x[i]))
@@ -35,7 +27,7 @@ class DistCrossEntropy(Function):
         Args:
             vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
               [batch_size, seq_len, vocab_size]
-            target (:class:`torch.Tensor`): The labels of the vocabulary, shape is
+            labels (:class:`torch.Tensor`): The labels of the vocabulary, shape is
               [batch_size, seq_len]
 
         Returns:
@@ -49,21 +41,15 @@ class DistCrossEntropy(Function):
         vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
 
         # mask the target in the local device
+        partition_vocab_size = vocab_logits.size()[-1]
         rank = dist.get_rank(group=process_group)
         world_size = dist.get_world_size(group=process_group)
-        if vocab_size == None:
-            partition_vocab_size = vocab_logits.size()[-1]
-            global_vocab_size = partition_vocab_size * world_size
-        else:
-            global_vocab_size = vocab_size
-            partition_vocab_size = global_vocab_size // world_size
+        global_vocab_size = partition_vocab_size * world_size
 
         # [down, up) => false, other device and -100 => true
         delta = (global_vocab_size + world_size - 1) // world_size
         down_threshold = rank * delta
         up_threshold = down_threshold + delta
-        if up_threshold > global_vocab_size:
-            up_threshold = global_vocab_size
         mask = (target < down_threshold) | (target >= up_threshold)
         masked_target = target.clone() - down_threshold
         masked_target[mask] = 0
@@ -71,8 +57,7 @@ class DistCrossEntropy(Function):
         # reshape the logits and target
         # reshape the vocab_logits to [bath_size * seq_len, vocab_size]
         # reshape the labels to [bath_size * seq_len]
-        self_vocab_size = vocab_logits.size()[-1]
-        logits_2d = vocab_logits.view(-1, self_vocab_size)
+        logits_2d = vocab_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
 
         # extract the x[class] and set the x[other device] to zero
@@ -87,7 +72,7 @@ class DistCrossEntropy(Function):
         dist.all_reduce(pred_logits, op=dist.ReduceOp.SUM, group=process_group)
         exp_logits = vocab_logits
         torch.exp(vocab_logits, out=exp_logits)
-        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1)
         dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
 
         # calculate the loss
@@ -98,10 +83,9 @@ class DistCrossEntropy(Function):
         loss = torch.sum(loss).div_(num_non_zero)
 
         # calculate the softmax
-        exp_logits = exp_logits.div(sum_exp_logits.unsqueeze(dim=-1)).to(dtype)
+        exp_logits.div_(sum_exp_logits.unsqueeze(dim=-1))
         exp_logits[target == ignore_index] = 0.0
         ctx.save_for_backward(exp_logits, mask, masked_target_1d)
-        ctx.dtype = dtype
 
         return loss
 
@@ -116,19 +100,14 @@ class DistCrossEntropy(Function):
         partion_vocab_size = grad_logits.shape[-1]
         grad_logits_2d = grad_logits.view(-1, partion_vocab_size)
 
-        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+        update = 1.0 - mask.view(-1).float()
         grad_logits_2d[torch.arange(0, grad_logits_2d.shape[0]), masked_target_1d] -= update
 
         grad_logits.mul_(grad_output.unsqueeze(dim=-1))
-        return grad_logits, None, None, None, None, None
+        return grad_logits, None, None, None
 
 
 def cross_entropy_1d(
-    vocab_logits: torch.Tensor,
-    labels: torch.Tensor,
-    ignore_index: int = -100,
-    process_group: ProcessGroup = None,
-    vocab_size: int = None,
-    dtype: torch.dtype = None,
+    vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
 ) -> torch.Tensor:
-    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype)
\ No newline at end of file
+    return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 6502457f2..b9342df04 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -320,8 +320,6 @@ class LlamaPipelineForwards:
                         shift_logits,
                         shift_labels,
                         process_group=shard_config.tensor_parallel_process_group,
-                        vocab_size=self.lm_head.out_features,
-                        dtype=self.model.dtype,
                     )
                 else:
                     shift_logits = shift_logits.view(-1, self.config.vocab_size)
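
For reference, a minimal usage sketch of cross_entropy_1d with the simplified signature this diff leaves in place (vocab_logits, labels, ignore_index, process_group). This is not part of the diff: the single-process gloo group, the concrete tensor sizes, and the direct import path colossalai.shardformer.layer.loss are assumptions made for illustration. In real tensor-parallel use, vocab_logits holds only this rank's vocabulary shard and process_group is the tensor-parallel group, as in the llama.py call site above.

    # sketch.py -- illustrative only, assumes ColossalAI is installed
    import os

    import torch
    import torch.distributed as dist

    from colossalai.shardformer.layer.loss import cross_entropy_1d

    if __name__ == "__main__":
        # Single-process process group so the all-reduces inside the kernel are no-ops.
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29505")
        dist.init_process_group(backend="gloo", rank=0, world_size=1)

        # Each rank holds a shard of the vocabulary: [batch_size, seq_len, vocab_size // tp_size]
        batch_size, seq_len, vocab_per_rank = 2, 8, 128
        vocab_logits = torch.randn(batch_size, seq_len, vocab_per_rank)
        labels = torch.randint(0, vocab_per_rank, (batch_size, seq_len))
        labels[:, 0] = -100  # positions marked with ignore_index are excluded from the loss

        # process_group=None falls back to the default group; in a sharded model this would be
        # the tensor-parallel group (e.g. shard_config.tensor_parallel_process_group).
        loss = cross_entropy_1d(vocab_logits, labels, ignore_index=-100, process_group=None)
        print(loss)

        dist.destroy_process_group()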