mirror of https://github.com/hpcaitech/ColossalAI
[plugin]fix 3d checkpoint load when booster boost without optimizer. (#5135)
* fix 3d checkpoint load when booster boost without optimizer
* test ci
* revert ci
* fix
parent f6731db67c
commit 2a2ec49aa7
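
For context, the scenario this commit targets is boosting a model under the 3D (tensor + pipeline + data parallel) HybridParallelPlugin without passing an optimizer, and then loading a sharded checkpoint into the boosted module. The sketch below illustrates that flow only; the parallel sizes, toy GPT-2 config, and checkpoint path are illustrative, and it assumes the standard Booster API rather than reproducing code from this commit.

```python
# Minimal sketch (not from this commit): boost a model with no optimizer under
# HybridParallelPlugin, then load a sharded checkpoint. Parallel sizes, the
# small GPT-2 config, and the checkpoint path are illustrative; run with
# torchrun on a world size divisible by tp_size * pp_size.
import colossalai
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin
from transformers import GPT2Config, GPT2LMHeadModel

colossalai.launch_from_torch(config={})

plugin = HybridParallelPlugin(tp_size=2, pp_size=2)
booster = Booster(plugin=plugin)

model = GPT2LMHeadModel(GPT2Config(n_layer=4))

# No optimizer is passed to boost(); the returned module is the wrapped
# HybridParallelModule that the checkpoint loader operates on.
model, *_ = booster.boost(model)

# Loading previously saved shards into the boosted, optimizer-less module is
# the code path this commit fixes.
booster.load_model(model, "./ckpt_3d")
```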
@@ -21,7 +21,7 @@ from torch.utils.data.distributed import DistributedSampler
 from colossalai.amp.naive_amp.mixed_precision_optimizer import MixedPrecisionOptimizer
 from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
 from colossalai.cluster import ProcessGroupMesh
-from colossalai.interface import ModelWrapper, OptimizerWrapper
+from colossalai.interface import ModelWrapper, OptimizerWrapper, AMPModelMixin
 from colossalai.pipeline.schedule import InterleavedSchedule, OneForwardOneBackwardSchedule
 from colossalai.pipeline.stage_manager import PipelineStageManager
 from colossalai.shardformer import ShardConfig, ShardFormer
@@ -42,7 +42,7 @@ def _convert_floating_point(x, dtype: torch.dtype = torch.float16):
     return x


-class HybridParallelModule(ModelWrapper):
+class HybridParallelModule(ModelWrapper, AMPModelMixin):
     def __init__(
         self,
         module: Module,
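
With the inheritance change above, the wrapped module always exposes the AMP hook interface, even when booster.boost() is called without an optimizer and no mixed-precision state exists, so checkpoint loading no longer trips over a missing hook. The toy sketch below shows the general mixin pattern; the class and hook names are hypothetical stand-ins, not ColossalAI's actual AMPModelMixin source.

```python
# Toy illustration of the mixin pattern (class and hook names are
# hypothetical, not ColossalAI's AMPModelMixin implementation).
import torch.nn as nn


class ToyAMPMixin:
    def update_master_params(self) -> None:
        # Default no-op: a wrapper with no mixed-precision state still
        # satisfies callers that invoke the hook unconditionally, e.g.
        # during checkpoint load.
        pass


class ToyParallelModule(nn.Module, ToyAMPMixin):
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

    def forward(self, *args, **kwargs):
        return self.module(*args, **kwargs)


wrapped = ToyParallelModule(nn.Linear(8, 8))
wrapped.update_master_params()  # safe even though no optimizer/AMP state was set up
```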
@@ -116,6 +116,9 @@ def check_gemini_plugin(
             "transformers_falcon_for_sequence_classification",
             "transformers_falcon_for_token_classification",
             "transformers_falcon_for_question_answering",
+            "transformers_gptj_lm",  # lead to OOM when running in ci
+            "transformers_gptj_for_question_answering",
+            "transformers_gptj_for_sequence_classification",
         ]:
             continue

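
The added entries extend the skip list that check_gemini_plugin consults while iterating over the model zoo, since the GPT-J variants exhaust CI memory. A condensed sketch of that pattern follows; the registry and helper names are hypothetical stand-ins for the real test harness.

```python
# Condensed sketch of the skip-list pattern (registry and helper names are
# hypothetical stand-ins for the actual test harness).
from typing import Callable, Dict

import torch.nn as nn

SKIPPED_MODELS = {
    "transformers_gptj_lm",                          # leads to OOM in CI
    "transformers_gptj_for_question_answering",
    "transformers_gptj_for_sequence_classification",
}


def check_plugin(model_zoo: Dict[str, Callable[[], nn.Module]]) -> None:
    for name, build_model in model_zoo.items():
        if name in SKIPPED_MODELS:
            continue  # skip models known to exceed CI memory limits
        model = build_model()
        # ... boost `model` with the plugin and run forward/backward checks ...
```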