@@ -16,6 +16,7 @@ from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
 from colossalai.nn.optimizer import ColoOptimizer
 from tests.components_to_test.registry import non_distributed_component_funcs
 from _utils import split_param_row_tp1d, split_param_col_tp1d


 def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
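    # --- Editor's sketch: assumed colossalai.tensor API, not taken from this patch ---
    # A TP-1D "row" init typically pins the parameter to the tensor-parallel process
    # group and shards one of its dimensions (here the last dim, i.e. the input
    # features of a Linear weight) across pg.tp_world_size() ranks:
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    weight.set_process_group(pg)
    weight.set_tensor_spec(*spec)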
@@ -50,7 +51,9 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -59,14 +62,15 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
+        optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
         # Make two models have the same init params
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
             p2.data.copy_(p1.data)
     else:
         model_torch = None
         optimizer_torch = None
     rank = torch.distributed.get_rank()
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
@@ -75,8 +79,8 @@ def run_1d_hybrid_tp(model_name):
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
-            # TODO(jiaruifang) has bug if open the following 2 comments
-            # if 'classifier' in name and 'weight' in name:
-            #     init_1d_row_linear(p, pg)
+            if 'classifier' in name and 'weight' in name:
+                init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
@@ -86,6 +90,8 @@ def run_1d_hybrid_tp(model_name):
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
+            if p.process_group.tp_world_size() == 1:
+                p.set_process_group(pg)
     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
@@ -101,13 +107,18 @@ def run_1d_hybrid_tp(model_name):
                 init_1d_col_linear(p, pg)
     model = model.cuda()
     model.train()
     if rank == 0:
         model_torch.train()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     for i, (data, label) in enumerate(train_dataloader):
         model.eval()
         # Zero grad
         colo_optimizer.zero_grad()
         if rank == 0:
             model_torch.eval()
-            colo_optimizer_torch.zero_grad()
+            optimizer_torch.zero_grad()
         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -123,7 +134,7 @@ def run_1d_hybrid_tp(model_name):
             output = model(data, label)
             loss = output
         # For reference
         # Test output
         if rank == 0:
             if criterion:
                 output_torch = model_torch(data)
@@ -131,17 +142,14 @@ def run_1d_hybrid_tp(model_name):
             else:
                 output_torch = model_torch(data, label)
                 loss_torch = output_torch
-        if rank == 0:
-            with torch.no_grad():
-                assert torch.allclose(loss, loss_torch, rtol=1e-2)
+            assert torch.allclose(loss, loss_torch, rtol=1e-2)
         loss.backward()
         colo_optimizer.step()
         if rank == 0:
             loss_torch.backward()
-            colo_optimizer_torch.step()
+            optimizer_torch.step()
             with torch.no_grad():
                 # check param
@@ -231,14 +239,19 @@ def run_1d_row_tp(model_name: str):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-    # A naive way to set spec for all weights in Linear
-    for name, p in model.named_parameters():
-        if not isinstance(p, ColoTensor):
-            continue
-        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            init_1d_row_linear(p, pg)
-        if 'embed' in name and 'weight' in name:
-            init_1d_row_embedding(p, pg)
+    for mo_name, module in model.named_modules():
+        # print(mo_name)
+        for pa_name, param in module.named_parameters(recurse=False):
+            # print('\t', pa_name, param.shape)
+            if not isinstance(param, ColoTensor):
+                continue
+            if 'weight' in pa_name:
+                if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
+                    split_param_row_tp1d(param, pg)
+                elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
+                    split_param_col_tp1d(param, pg)
     model = model.cuda()
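    # --- Editor's sketch: plain PyTorch, hypothetical helper name, not taken from this patch ---
    # What the split helpers above roughly boil down to: a "row" TP-1D split keeps one
    # block of rows (dim 0) per rank, a "col" split keeps one block of columns (last dim)
    # per rank; each rank holds only its own chunk of the 2-D weight.
    def _naive_tp1d_split(weight, rank, world_size, dim):
        return torch.chunk(weight, world_size, dim=dim)[rank].contiguous()
    # e.g. _naive_tp1d_split(w, rank, world_size, dim=0)   -> row shard
    #      _naive_tp1d_split(w, rank, world_size, dim=-1)  -> column shard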
@@ -313,9 +326,9 @@ def _run_pretrain_load():
 def run_model_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['simple_net']:
+    for name in ['bert']:
         run_1d_row_tp(name)
-    for name in ['simple_net']:
+    for name in ['bert']:
         run_1d_hybrid_tp(name)
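
For context: distributed tests like run_model_dist are usually driven by a small launcher that spawns one process per rank and hands every worker the same rendezvous port. A minimal sketch of such a launcher follows; the function name, world size, and port value are illustrative and not taken from this patch.

from functools import partial

import torch.multiprocessing as mp


def launch_test(world_size: int = 4, port: int = 29500):
    # Spawn `world_size` worker processes. Each worker receives its rank as the first
    # argument and runs run_model_dist(rank, world_size, port), which in turn calls
    # colossalai.launch() with the NCCL backend as shown in the hunk above.
    run_func = partial(run_model_dist, world_size=world_size, port=port)
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    launch_test()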