[test] Fix/fix testcase (#5770)

* [fix] branch for fixing test cases;

* [fix] fix test_analyzer & test_auto_parallel;

* [fix] remove local change about MoE;

* [fix] remove local MoE change;
duanjunwen committed 6 months ago via GitHub (commit 1b76564e16, parent 3f2be80530)

@@ -469,4 +469,4 @@ class ActivationCheckpointCodeGen(CodeGen):
 {wrap_stmts}
 {prologue}
 {code}"""
-return PythonCode(fn_code, globals_)
+return PythonCode(fn_code, globals_, {})

@@ -859,7 +859,7 @@ if CODEGEN_AVAILABLE:
 {wrap_stmts}
 {prologue}
 {code}"""
-return PythonCode(fn_code, globals_)
+return PythonCode(fn_code, globals_, {})
 else:
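
Context for the two PythonCode changes above: newer torch.fx releases define PythonCode as a three-field dataclass (the generated source, its globals, and a line-number map), so both code generators now pass a third argument. A minimal sketch of the assumed upstream shape, for orientation only:

    # Assumed shape of torch.fx.graph.PythonCode in newer PyTorch releases
    # (field names taken from upstream; verify against your torch version).
    from dataclasses import dataclass
    from typing import Any, Dict, Optional

    @dataclass
    class PythonCode:
        src: str                                          # generated forward() source
        globals: Dict[str, Any]                           # globals referenced by src
        _lineno_map: Optional[Dict[int, Optional[int]]]   # src line -> node index

    # The fix passes an empty dict because the checkpoint codegen does not
    # track a line-number mapping: PythonCode(fn_code, globals_, {})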

@@ -2,7 +2,7 @@ import torch
 import torch.nn as nn
 from transformers import BertConfig, BertLMHeadModel, GPT2Config, GPT2LMHeadModel
-from tests.components_to_test.registry import non_distributed_component_funcs
+# from tests.components_to_test.registry import non_distributed_component_funcs
 class GPTLMModel(nn.Module):
@@ -55,7 +55,7 @@ class BertLMModel(nn.Module):
 return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=True)[0]
-@non_distributed_component_funcs.register(name="bert_")
+# @non_distributed_component_funcs.register(name="bert_")
 def get_bert_components():
 vocab_size = 1024
 seq_len = 64
@@ -74,7 +74,7 @@ def get_bert_components():
 return bert_model_builder, bert_data_gen
-@non_distributed_component_funcs.register(name="gpt2_")
+# @non_distributed_component_funcs.register(name="gpt2_")
 def get_gpt2_components():
 vocab_size = 1024
 seq_len = 8
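
Note on the two decorator changes above: the non_distributed_component_funcs registry was removed from the test tree, so the bert_ and gpt2_ builders are no longer registered by name and are assumed to be invoked directly. An illustrative sketch of that direct use (return values follow the file above; call arguments are assumptions):

    # Illustrative only: fetch the GPT-2 test component without the registry.
    model_builder, data_gen = get_gpt2_components()
    model = model_builder()     # builds the small GPT2LMHeadModel test model
    data = data_gen()           # produces the matching test inputs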

@@ -10,11 +10,14 @@ from colossalai.auto_parallel.offload.amp_optimizer import AMPOptimizer
 from colossalai.auto_parallel.offload.mem_optimize import memory_optimize
 from colossalai.auto_parallel.offload.solver import NOT_NVML
 from colossalai.fx.profiler import parameter_size
+from colossalai.legacy.zero.gemini.colo_init_context import ColoInitContext
 from colossalai.nn.optimizer import HybridAdam
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
-from colossalai.zero import ColoInitContext, zero_model_wrapper, zero_optim_wrapper
+from colossalai.utils import set_seed
+from colossalai.zero import zero_model_wrapper, zero_optim_wrapper
 from tests.test_auto_parallel.test_offload.model_utils import *
-from tests.test_tensor.common_utils import set_seed
+# from tests.test_tensor.common_utils import set_seed
 @parameterize("model_name", ["gpt2_"])
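
The import hunk above reflects two relocations: ColoInitContext now lives under colossalai.legacy.zero.gemini, and set_seed comes from colossalai.utils rather than the test-only helper in tests.test_tensor.common_utils. A hedged sketch of how the offload test is assumed to use the relocated imports (wrapper arguments are illustrative, not the exact test code):

    import torch
    from colossalai.legacy.zero.gemini.colo_init_context import ColoInitContext
    from colossalai.nn.optimizer import HybridAdam
    from colossalai.utils import set_seed
    from colossalai.zero import zero_model_wrapper, zero_optim_wrapper

    set_seed(42)                                    # reproducible parameter init
    with ColoInitContext(device=torch.device("cpu")):
        model = model_builder()                     # builder taken from model_utils
    optimizer = HybridAdam(model.parameters(), lr=1e-3)
    model = zero_model_wrapper(model, zero_stage=3)        # stage choice is an assumption
    optimizer = zero_optim_wrapper(model, optimizer)       # exact kwargs omitted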
