mirror of https://github.com/hpcaitech/ColossalAI
[release] update version (#6195)
* [release] update version * fix test * fix test pull/6210/head v0.4.8
parent
24dee8f0b7
commit
9379cbd668
|
@ -5,7 +5,6 @@ from .bloom import *
|
|||
from .chatglm2 import *
|
||||
from .command import *
|
||||
from .deepseek import *
|
||||
from .deepseek_v3 import *
|
||||
from .falcon import *
|
||||
from .gpt import *
|
||||
from .gptj import *
|
||||
|
|
|
@ -5,8 +5,6 @@ import torch
|
|||
import transformers
|
||||
from transformers import AutoConfig
|
||||
|
||||
from ..registry import ModelAttribute, model_zoo
|
||||
|
||||
# ===============================
|
||||
# Register single-sentence Mixtral
|
||||
# ===============================
|
||||
|
@ -75,13 +73,3 @@ def init_deepseek():
|
|||
if m.__class__.__name__ == "DeepseekV3MoE":
|
||||
m.moe_infer = MethodType(m.moe_infer.__wrapped__, m)
|
||||
return model
|
||||
|
||||
|
||||
model_zoo.register(
|
||||
name="transformers_deepseek_v3",
|
||||
model_fn=init_deepseek,
|
||||
data_gen_fn=data_gen_for_lm,
|
||||
output_transform_fn=output_transform_fn,
|
||||
loss_fn=loss_fn_for_lm,
|
||||
model_attribute=ModelAttribute(has_control_flow=True),
|
||||
)
|
||||
|
|
|
@ -11,7 +11,12 @@ from colossalai.booster.plugin import MoeHybridParallelPlugin
|
|||
from colossalai.booster.plugin.moe_hybrid_parallel_plugin import MoeHybridParallelPlugin
|
||||
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
|
||||
from colossalai.testing.random import seed_all
|
||||
from tests.kit.model_zoo import model_zoo
|
||||
from tests.kit.model_zoo.transformers.deepseek_v3 import (
|
||||
data_gen_for_lm,
|
||||
init_deepseek,
|
||||
loss_fn_for_lm,
|
||||
output_transform_fn,
|
||||
)
|
||||
from tests.test_shardformer.test_model._utils import (
|
||||
build_model_from_hybrid_plugin,
|
||||
run_forward_backward_with_hybrid_plugin,
|
||||
|
@ -74,16 +79,13 @@ def run_deepseek_v3_test(config: Tuple[int, ...]):
|
|||
find_unused_parameters=True,
|
||||
)
|
||||
|
||||
sub_model_zoo = model_zoo.get_sub_registry("transformers_deepseek_v3")
|
||||
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
|
||||
|
||||
check_forward_backward(
|
||||
model_fn,
|
||||
data_gen_fn,
|
||||
output_transform_fn,
|
||||
loss_fn,
|
||||
plugin_config,
|
||||
)
|
||||
check_forward_backward(
|
||||
init_deepseek,
|
||||
data_gen_for_lm,
|
||||
output_transform_fn,
|
||||
loss_fn_for_lm,
|
||||
plugin_config,
|
||||
)
|
||||
|
||||
|
||||
def check_deepseek_v3(rank, world_size, port):
|
||||
|
|
|
@ -1 +1 @@
|
|||
0.4.7
|
||||
0.4.8
|
||||
|
|
Loading…
Reference in New Issue