@@ -11,42 +11,13 @@ from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils.model.colo_init_context import ColoInitContext
-from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
-    ComputeSpec, ColoTensor, DistSpecManager, ProcessGroup, ReplicaSpec
+from colossalai.tensor import ColoTensor, ProcessGroup
 from colossalai.nn.optimizer import ColoOptimizer

 from tests.components_to_test.registry import non_distributed_component_funcs
+from _utils import split_param_row_tp1d, split_param_col_tp1d


-def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col_linear(weight, pg):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_row_embedding(weight, pg):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col_embedding(weight, pg):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_process_group(pg)
-        weight.set_tensor_spec(*spec)
-
-
 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
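
Note: every init_1d_* helper deleted above applied the same three-step recipe
(build a ShardSpec/ComputeSpec pair, bind the process group, set the spec),
differing only in the shard dim; that is also why ColoTensorSpec, ReplicaSpec,
DistSpecManager and friends drop out of the import list. The _utils
replacements themselves are not shown in this diff, but reconstructed from the
deleted bodies they plausibly look like this sketch:

from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern, DistSpecManager

def _split_param_tp1d(param, pg, dim):
    # Shard `param` along `dim` across the tensor-parallel ranks of `pg`,
    # mirroring the recipe of the deleted init_1d_* helpers.
    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        param.set_process_group(pg)
        param.set_tensor_spec(*spec)

def split_param_row_tp1d(param, pg):
    _split_param_tp1d(param, pg, 0)     # shard dim 0: split by rows

def split_param_col_tp1d(param, pg):
    _split_param_tp1d(param, pg, -1)    # shard dim -1: split by columns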
@@ -79,19 +50,16 @@ def run_1d_hybrid_tp(model_name):

             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                init_1d_row_linear(p, pg)
+                split_param_col_tp1d(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             elif 'word_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p, pg)
+                split_param_row_tp1d(p, pg)
             # num_class = seq_len = 512 | (512, 8)
             elif 'position_embeddings' in name and 'weight' in name:
-                init_1d_row_embedding(p, pg)
+                split_param_row_tp1d(p, pg)
             # num_class = type_vocab_size = 2 | (2, 8)
             elif 'token_type_embeddings' in name and 'weight' in name:
-                init_1d_col_embedding(p, pg)
-            elif p.process_group.tp_world_size() == 1:
-                with DistSpecManager.no_grad():
-                    p.redistribute(ReplicaSpec(), pg)
+                split_param_col_tp1d(p, pg)

     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
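
Note the naming flip in the hunk above: the old helpers were named after the
parallel-linear style (init_1d_row_linear sharded dim -1 of the weight), while
the new utilities are named after the tensor dim being split, so "row linear"
maps to split_param_col_tp1d and vice versa. The old fallback branch that
redistributed an unsharded parameter via ReplicaSpec is dropped, presumably
because the _utils helpers handle the process-group binding themselves. The
shape comments make the splits concrete; a plain-torch illustration (not the
ColossalAI API), assuming tp_world_size = 4:

import torch

tp = 4
word_emb = torch.empty(30524, 8)      # (vocab_size, hidden)
token_type_emb = torch.empty(2, 8)    # (type_vocab_size, hidden)

row_shard = word_emb.chunk(tp, dim=0)[0]         # split_param_row_tp1d
col_shard = token_type_emb.chunk(tp, dim=-1)[0]  # split_param_col_tp1d
assert row_shard.shape == (30524 // tp, 8)       # (7631, 8) per rank
assert col_shard.shape == (2, 8 // tp)           # (2, 2) per rank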
@@ -99,13 +67,13 @@ def run_1d_hybrid_tp(model_name):
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                init_1d_col_embedding(p, pg)
+                split_param_col_tp1d(p, pg)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p, pg)
+                split_param_row_tp1d(p, pg)
             if 'proj2' in name and 'weight' in name:
-                init_1d_row_linear(p, pg)
+                split_param_col_tp1d(p, pg)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                init_1d_col_linear(p, pg)
+                split_param_row_tp1d(p, pg)

     model = model.cuda()
     model.train()
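
For simple_net the new calls encode the classic Megatron-style 1D pairing:
proj1 (weight and bias) is split along dim 0, i.e. a column-parallel layer
whose bias is sharded with it, while proj2 is split along dim -1 for the
weight only, i.e. a row-parallel layer. Chaining the two keeps the hidden
activation sharded between them, with a single reduction after proj2. A
plain-torch check of that algebra, with hypothetical shapes:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(3, 8)
w1, b1 = torch.randn(16, 8), torch.randn(16)   # proj1: column-parallel
w2 = torch.randn(4, 16)                        # proj2: row-parallel
full = F.linear(F.linear(x, w1, b1), w2)

tp = 2
partial = sum(
    F.linear(F.linear(x, w1.chunk(tp, 0)[r], b1.chunk(tp, 0)[r]),
             w2.chunk(tp, -1)[r])
    for r in range(tp))
assert torch.allclose(full, partial, atol=1e-5)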
@@ -327,9 +295,9 @@ def _run_pretrain_load():


 def run_model_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['bert']:
+    for name in ['bert', 'simple_net']:
         run_1d_row_tp(name)
-    for name in ['bert']:
+    for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)

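With both loops widened to ['bert', 'simple_net'], every registered component
now exercises both the row-TP and hybrid-TP layouts. The spawn wrapper that
drives run_model_dist sits outside this diff; judging from the free_port and
rerun_if_address_is_in_use imports it presumably follows the usual pattern of
these tests (the decorators are real, the world sizes are an assumption):

import pytest
import torch.multiprocessing as mp
from functools import partial
from colossalai.utils import free_port
from colossalai.testing import rerun_if_address_is_in_use

@pytest.mark.dist
@pytest.mark.parametrize('world_size', [1, 4])
@rerun_if_address_is_in_use()
def test_model(world_size):
    # mp.spawn supplies the rank; a fresh free port avoids EADDRINUSE,
    # and the decorator retries the test if a stale port is hit anyway.
    run_func = partial(run_model_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)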