mirror of https://github.com/hpcaitech/ColossalAI
Merge branch 'dev/zero-offload' into offload
commit
e893f88a4f
|
@ -110,4 +110,6 @@ class DistCrossEntropy(Function):
|
||||||
def cross_entropy_1d(
|
def cross_entropy_1d(
|
||||||
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
|
vocab_logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100, process_group: ProcessGroup = None
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
|
return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group)
|
||||||
|
|
||||||
|
|
|
@ -24,10 +24,7 @@ from transformers.models.llama.modeling_llama import (
|
||||||
from transformers.utils import logging
|
from transformers.utils import logging
|
||||||
|
|
||||||
from colossalai.pipeline.stage_manager import PipelineStageManager
|
from colossalai.pipeline.stage_manager import PipelineStageManager
|
||||||
from colossalai.shardformer.layer._operation import (
|
from colossalai.shardformer.layer._operation import gather_forward_split_backward, split_forward_gather_backward
|
||||||
gather_forward_split_backward,
|
|
||||||
split_forward_gather_backward,
|
|
||||||
)
|
|
||||||
from colossalai.shardformer.shard import ShardConfig
|
from colossalai.shardformer.shard import ShardConfig
|
||||||
|
|
||||||
from ..layer import ColoAttention, cross_entropy_1d
|
from ..layer import ColoAttention, cross_entropy_1d
|
||||||
|
|
Loading…
Reference in New Issue