mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] test model unittest hotfix (#1281)
parent: e56731e916
commit: 79fe7b027a
@@ -12,7 +12,7 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
 from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup
+    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup, ReplicaSpec
 from colossalai.nn.optimizer import ColoOptimizer
 
 from tests.components_to_test.registry import non_distributed_component_funcs
@@ -76,22 +76,23 @@ def run_1d_hybrid_tp(model_name):
         for name, p in model.named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             # print(name)
 
             # num_class = type_vocab_size = 2 | (8, 2)
             # TODO(jiaruifang) has bug if open the following 2 comments
             if 'classifier' in name and 'weight' in name:
                 init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
-            if 'word_embeddings' in name and 'weight' in name:
+            elif 'word_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
             # num_class = seq_len = 512 | (512, 8)
-            if 'position_embeddings' in name and 'weight' in name:
+            elif 'position_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
-            if 'token_type_embeddings' in name and 'weight' in name:
+            elif 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
-            if p.process_group.tp_world_size() == 1:
-                p.set_process_group(pg)
+            elif p.process_group.tp_world_size() == 1:
+                with DistSpecManager.no_grad():
+                    p.redistribute(ReplicaSpec(), pg)
 
     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
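In short, the hotfix makes the per-parameter spec dispatch mutually exclusive (independent `if`s become an `if`/`elif` chain) and, for parameters that no earlier branch sharded, swaps the bare `p.set_process_group(pg)` for an explicit redistribution to a `ReplicaSpec` on the tensor-parallel group. A minimal sketch of the resulting control flow follows; the `set_bert_tp_specs` name is hypothetical, while the `init_1d_*` helpers are the ones defined in this test file:

# Sketch of the fixed dispatch; set_bert_tp_specs is a hypothetical wrapper.
from colossalai.tensor import ColoTensor, DistSpecManager, ReplicaSpec

def set_bert_tp_specs(model, pg):
    for name, p in model.named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        # With elif, exactly one branch handles each parameter.
        if 'classifier' in name and 'weight' in name:
            init_1d_row_linear(p, pg)
        elif 'word_embeddings' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)
        elif 'position_embeddings' in name and 'weight' in name:
            init_1d_row_embedding(p, pg)
        elif 'token_type_embeddings' in name and 'weight' in name:
            init_1d_col_embedding(p, pg)
        elif p.process_group.tp_world_size() == 1:
            # The old code only called p.set_process_group(pg); the fix
            # redistributes the tensor to a ReplicaSpec on pg instead,
            # with no_grad() keeping the conversion out of autograd.
            with DistSpecManager.no_grad():
                p.redistribute(ReplicaSpec(), pg)

The elif chain makes the branches mutually exclusive by construction, so a parameter sharded by one of the earlier branches can never also reach the replication fallback.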