Merge branch 'dev/zero-offload' into offload

pull/5844/head
Wang Binluo 2024-06-20 17:07:24 +08:00 committed by GitHub
commit e893f88a4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 5 additions and 6 deletions

View File

@@ -110,4 +110,6 @@ class DistCrossEntropy(Function):
def cross_entropy_1d( def cross_entropy_1d(
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
) -> torch.Tensor: ) -> torch.Tensor:
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group) return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)

View File

@@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
from transformers.utils import logging from transformers.utils import logging
from colossalai.pipeline.stage_manager import PipelineStageManager from colossalai.pipeline.stage_manager import PipelineStageManager
from colossalai.shardformer.layer._operation import ( from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
gather_forward_split_backward,
split_forward_gather_backward,
)
from colossalai.shardformer.shard import ShardConfig from colossalai.shardformer.shard import ShardConfig
from ..layer import ColoAttention, cross_entropy_1d from ..layer import ColoAttention, cross_entropy_1d