mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix bert model test in unitests (#1272)
parent 01ea68b2e6
commit abba4d84e1
@@ -42,10 +42,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
norm_type: float = 2.0,
scale_grad_by_freq: bool = False,
sparse: bool = False) -> ColoTensor:
# embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
# Find index in this shard and mask those not here
# Reduce all
pg = weight.get_process_group()
# embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim]
# get the index of current segment and mask other segments with 0

# get complete input tensor through all-gather
input_tensor = input_tensor.redistribute(ReplicaSpec())

# tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)

@@ -54,12 +54,11 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
vocab_end_index = vocab_start_index + num_embeddings_per_partition

# Build the mask.
input_mask = (input_tensor < vocab_start_index) | \
(input_tensor >= vocab_end_index)
# Mask the input.
# build the mask.
input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
# mask the input.
# TODO(jzy) masked_input may be an activation managed by ColoTensor.
masked_input = input_tensor.clone() - vocab_start_index
masked_input = input_tensor - vocab_start_index
masked_input[input_mask] = 0

partial_output = F.embedding(masked_input,
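The rewritten comments above describe the 1D-row embedding scheme. As a hedged illustration only, here is a minimal plain-torch sketch of that pattern; the function name and the explicit vocab_start/vocab_end arguments are assumptions for the example, while the real op works on ColoTensor and derives them from the process group:

import torch
import torch.distributed as dist
import torch.nn.functional as F


def embedding_1d_row_sketch(token_ids: torch.Tensor, local_weight: torch.Tensor,
                            vocab_start: int, vocab_end: int) -> torch.Tensor:
    # hypothetical sketch: each rank owns rows [vocab_start, vocab_end) of the lookup table
    out_of_range = (token_ids < vocab_start) | (token_ids >= vocab_end)
    local_ids = token_ids - vocab_start      # shift ids into the local row range
    local_ids[out_of_range] = 0              # any valid local row; its output is zeroed below
    partial = F.embedding(local_ids, local_weight)
    partial[out_of_range] = 0.0              # drop rows owned by other ranks
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(partial)             # sum the partial lookups across the TP group
    return partial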
@@ -5,6 +5,7 @@ import torch
import torch.distributed as dist
from colossalai.core import global_context as gpc
from colossalai.context import ParallelMode
from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


def set_seed(seed):

@@ -57,3 +58,18 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_si
return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
else:
raise NotImplementedError


def split_param_single_dim_tp1d(dim, param, pg):
spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
if param.process_group.tp_world_size() == 1:
param.set_process_group(pg)
param.set_tensor_spec(*spec)


def split_param_row_tp1d(param, pg):
split_param_single_dim_tp1d(0, param, pg)


def split_param_col_tp1d(param, pg):
split_param_single_dim_tp1d(-1, param, pg)
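The test files below exercise these new helpers. Roughly, the intended usage looks like this (a sketch only, assuming colossalai has already been launched so torch.distributed is initialized; the tensor shape is arbitrary):

import torch
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from _utils import split_param_row_tp1d, split_param_col_tp1d

pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
weight = ColoParameter(torch.randn(8, 4), True, ColoTensorSpec(pg))
split_param_row_tp1d(weight, pg)     # shard dim 0 across the tensor-parallel ranks
# split_param_col_tp1d(weight, pg)   # alternatively, shard the last dim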
@@ -4,12 +4,11 @@ import pytest
import torch.nn as nn
import torch.multiprocessing as mp
from colossalai.tensor import ColoTensor, ProcessGroup
from colossalai.tensor import ShardSpec
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
from colossalai.tensor import ColoTensorSpec
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from functools import partial
from _utils import tensor_shard_equal, tensor_equal
from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d


class Conv1D(nn.Module):

@@ -36,20 +35,7 @@ class Conv1D(nn.Module):
return x


def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)


def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)


def run_with_spec(spec_init_func):
def run_with_spec(spec_init_func, split_bias):
model = Conv1D(4, 16).cuda()
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)

@@ -57,7 +43,10 @@ def run_with_spec(spec_init_func):
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))

spec_init_func(weight, bias, pg)
spec_init_func(weight, pg)
if split_bias:
spec_init_func(bias, pg)

x = torch.rand(2, 16).cuda()
out = model(x)
colo_out = torch.addmm(bias, x, weight)

@@ -72,8 +61,8 @@ def run_with_spec(spec_init_func):

def run_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_row)
run_with_spec(init_1d_col)
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)


@pytest.mark.dist
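Why split_bias differs between the two calls above: Conv1D stores its weight as (in_features, out_features), so a last-dim (column) split shards the output features and the bias must be split with them, while a dim-0 (row) split shards the input features, whose partial products are reduced across ranks before the full bias is added once. A plain-torch sketch of the unsharded reference the test checks against:

import torch

x = torch.rand(2, 16)
weight = torch.rand(16, 4)           # Conv1D weight: (in_features, out_features)
bias = torch.rand(4)                 # one entry per output feature
ref = torch.addmm(bias, x, weight)   # what run_with_spec compares colo_out to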
@@ -1,5 +1,3 @@
import torch
from colossalai.tensor import ShardSpec, ColoParameter
from torch.nn import functional as F
from functools import partial


@@ -9,21 +7,17 @@ import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_col(weight, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d


def run_with_spec(spec_init_func):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.EmbeddingBag(10, 4).cuda()
weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))

spec_init_func(weight, pg)

inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
offsets = torch.tensor([0, 4]).cuda()
out = model(inputs, offsets=offsets)

@@ -38,7 +32,7 @@ def run_with_spec(spec_init_func):
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_col)
run_with_spec(split_param_col_tp1d)


@pytest.mark.dist
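For reference, the dense computation this EmbeddingBag test shards column-wise (a plain-torch sketch of the unsharded baseline, on CPU for simplicity):

import torch

bag = torch.nn.EmbeddingBag(10, 4)                 # same shape as in the test
inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9])
offsets = torch.tensor([0, 4])                     # two bags of four ids each
ref = bag(inputs, offsets=offsets)                 # mean-pooled per bag by default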
@@ -1,5 +1,3 @@
import torch
from colossalai.tensor import ColoTensor, ShardSpec
from torch.nn import functional as F
from functools import partial


@@ -9,26 +7,16 @@ import torch
import torch.multiprocessing as mp
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal


def init_1d_row(weight, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)


def init_1d_col(weight, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


def run_with_spec(spec_init_func, pg: ProcessGroup):
model = torch.nn.Embedding(12, 32).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))

spec_init_func(weight, pg)

x = torch.tensor((0, 3, 6, 9)).cuda()
out = model(x)
colo_out = F.embedding(x, weight)

@@ -44,8 +32,8 @@ def run_dist(rank, world_size, port):
# config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
pg = ProcessGroup(tp_degree=world_size)
run_with_spec(init_1d_row, pg)
run_with_spec(init_1d_col, pg)
run_with_spec(split_param_row_tp1d, pg)
run_with_spec(split_param_col_tp1d, pg)


@pytest.mark.dist
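The row-split case run here presumably exercises the colo_embedding_1Drow path whose comments were rewritten in the first hunk. The unsharded baseline the test compares against is simply (sketch, on CPU):

import torch
import torch.nn.functional as F

table = torch.randn(12, 32)        # same shape as torch.nn.Embedding(12, 32)
ids = torch.tensor((0, 3, 6, 9))
ref = F.embedding(ids, table)      # full-table lookup the sharded op must reproduce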
@@ -1,6 +1,3 @@
import torch
from colossalai.tensor import ColoTensor, ShardSpec

from functools import partial

import colossalai

@@ -10,29 +7,20 @@ import torch.multiprocessing as mp
import torch.nn.functional as F
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port
from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
from _utils import tensor_equal, tensor_shard_equal
from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


def init_1d_row(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)


def init_1d_col(weight, bias, pg: ProcessGroup):
spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
with DistSpecManager.no_grad():
weight.set_tensor_spec(*spec)
bias.set_tensor_spec(*spec)


def run_with_spec(spec_init_func):
def run_with_spec(spec_init_func, split_bias):
pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
model = torch.nn.Linear(4, 8).cuda()
weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
spec_init_func(weight, bias, pg)

spec_init_func(weight, pg)
if split_bias:
spec_init_func(bias, pg)

x = torch.rand(2, 4).cuda()
out = model(x)
colo_out = F.linear(x, weight, bias)

@@ -48,8 +36,8 @@ def run_with_spec(spec_init_func):
def run_dist(rank, world_size, port):
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
run_with_spec(init_1d_row)
run_with_spec(init_1d_col)
run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)


@pytest.mark.dist
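Note the pairing is reversed relative to the Conv1D test above: nn.Linear stores weight as (out_features, in_features), so the last-dim (column) split shards in_features and leaves the bias replicated (split_bias=False), while the dim-0 (row) split shards out_features and the bias with it (split_bias=True). The unsharded reference, as a sketch:

import torch
import torch.nn.functional as F

x = torch.rand(2, 4)
weight = torch.rand(8, 4)         # nn.Linear(4, 8) weight: (out_features, in_features)
bias = torch.rand(8)
ref = F.linear(x, weight, bias)   # what run_with_spec compares colo_out to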
@@ -16,6 +16,7 @@ from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
from colossalai.nn.optimizer import ColoOptimizer

from tests.components_to_test.registry import non_distributed_component_funcs
from _utils import split_param_row_tp1d, split_param_col_tp1d


def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):

@@ -50,7 +51,9 @@ def run_1d_hybrid_tp(model_name):
# A simple net with two stacked nn.Linear
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()

set_seed(1)
with ColoInitContext(device=get_current_device()):

@@ -59,14 +62,15 @@ def run_1d_hybrid_tp(model_name):
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()
colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)

# Make two models have the same init params
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
p2.data.copy_(p1.data)
else:
model_torch = None
optimizer_torch = None

rank = torch.distributed.get_rank()
world_size = torch.distributed.get_world_size()
pg = ProcessGroup(tp_degree=world_size)
if 'bert' == model_name:
for name, p in model.named_parameters():

@@ -75,8 +79,8 @@ def run_1d_hybrid_tp(model_name):
# print(name)
# num_class = type_vocab_size = 2 | (8, 2)
# TODO(jiaruifang) has bug if open the following 2 comments
# if 'classifier' in name and 'weight' in name:
# init_1d_row_linear(p, pg)
if 'classifier' in name and 'weight' in name:
init_1d_row_linear(p, pg)
# num_class = vocab_size = 30524 | (30524, 8)
if 'word_embeddings' in name and 'weight' in name:
init_1d_row_embedding(p, pg)

@@ -86,6 +90,8 @@ def run_1d_hybrid_tp(model_name):
# num_class = type_vocab_size = 2 | (2, 8)
if 'token_type_embeddings' in name and 'weight' in name:
init_1d_col_embedding(p, pg)
if p.process_group.tp_world_size() == 1:
p.set_process_group(pg)
elif "simple_net" == model_name:
# A naive way to set spec for all weights in Linear
for name, p in model.named_parameters():
@@ -101,13 +107,18 @@ def run_1d_hybrid_tp(model_name):
init_1d_col_linear(p, pg)

model = model.cuda()
model.train()
if rank == 0:
model_torch.train()

colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)

for i, (data, label) in enumerate(train_dataloader):
model.eval()

# Zero grad
colo_optimizer.zero_grad()
if rank == 0:
model_torch.eval()
colo_optimizer_torch.zero_grad()
optimizer_torch.zero_grad()

data = data.to(get_current_device())
label = label.to(get_current_device())

@@ -123,7 +134,7 @@ def run_1d_hybrid_tp(model_name):
output = model(data, label)
loss = output

# For reference
# Test output
if rank == 0:
if criterion:
output_torch = model_torch(data)

@@ -131,9 +142,6 @@ def run_1d_hybrid_tp(model_name):
else:
output_torch = model_torch(data, label)
loss_torch = output_torch

if rank == 0:
with torch.no_grad():
assert torch.allclose(loss, loss_torch, rtol=1e-2)

loss.backward()

@@ -141,7 +149,7 @@ def run_1d_hybrid_tp(model_name):

if rank == 0:
loss_torch.backward()
colo_optimizer_torch.step()
optimizer_torch.step()

with torch.no_grad():
# check param
@@ -231,14 +239,19 @@ def run_1d_row_tp(model_name: str):
if rank == 0:
model_torch = model_builder(checkpoint=True)
model_torch = model_torch.cuda()

# A naive way to set spec for all weights in Linear
for name, p in model.named_parameters():
if not isinstance(p, ColoTensor):
for mo_name, module in model.named_modules():
# print(mo_name)
for pa_name, param in module.named_parameters(recurse=False):
# print('\t', pa_name, param.shape)
if not isinstance(param, ColoTensor):
continue
if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
init_1d_row_linear(p, pg)
if 'embed' in name and 'weight' in name:
init_1d_row_embedding(p, pg)
if 'weight' in pa_name:
if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
split_param_row_tp1d(param, pg)
elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
split_param_col_tp1d(param, pg)

model = model.cuda()

@@ -313,9 +326,9 @@ def _run_pretrain_load():

def run_model_dist(rank, world_size, port):
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
for name in ['simple_net']:
for name in ['bert']:
run_1d_row_tp(name)
for name in ['simple_net']:
for name in ['bert']:
run_1d_hybrid_tp(name)
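For orientation, all of the test files touched here follow the same launch pattern around run_dist. A hedged sketch of that entry point (the mp.spawn wiring and the test function name are assumptions inferred from the imports shown; the real entry sits in the elided context below each hunk):

import torch.multiprocessing as mp
from functools import partial

import colossalai
from colossalai.utils import free_port


def run_dist(rank, world_size, port):
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    # ... call run_1d_row_tp / run_1d_hybrid_tp here, as in the hunk above ...


def test_model_sketch(world_size=4):
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)   # one process per tensor-parallel rank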