
[hotfix] fix bert model test in unitests (#1272)

commit abba4d84e1 (branch: pull/1279/head^2)
HELSON, 2 years ago, committed by GitHub
7 changed files (changed lines per file):

  1. colossalai/nn/_ops/embedding.py (17)
  2. tests/test_tensor/_utils/_util.py (16)
  3. tests/test_tensor/test_addmm_tp.py (29)
  4. tests/test_tensor/test_embedding_bag_tp.py (16)
  5. tests/test_tensor/test_embedding_tp.py (24)
  6. tests/test_tensor/test_linear_tp.py (32)
  7. tests/test_tensor/test_model.py (59)

colossalai/nn/_ops/embedding.py (17 changed lines)

@@ -42,10 +42,10 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
                          norm_type: float = 2.0,
                          scale_grad_by_freq: bool = False,
                          sparse: bool = False) -> ColoTensor:
-    # embedding_1Drow split the weight(lookup table) to (num_embeddings/P, embedding_dim)
-    # Find index in this shard and mask those not here
-    # Reduce all
     pg = weight.get_process_group()
+    # embedding_1Drow splits the weight(lookup table) to the shape, [num_embeddings/P, embedding_dim]
+    # get the index of current segment and mask other segments with 0
+    # get complete input tensor through all-gather
     input_tensor = input_tensor.redistribute(ReplicaSpec())

     # tensor_parallel_rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
@@ -54,12 +54,11 @@ def colo_embedding_1Drow(input_tensor: ColoTensor,
     vocab_start_index = tensor_parallel_rank * num_embeddings_per_partition
     vocab_end_index = vocab_start_index + num_embeddings_per_partition

-    # Build the mask.
-    input_mask = (input_tensor < vocab_start_index) | \
-        (input_tensor >= vocab_end_index)
-    # Mask the input.
-    masked_input = input_tensor.clone() - vocab_start_index
+    # build the mask.
+    input_mask = (input_tensor < vocab_start_index) | (input_tensor >= vocab_end_index)
+    # mask the input.
+    # TODO(jzy) masked_input may be an activation managed by ColoTensor.
+    masked_input = input_tensor - vocab_start_index
     masked_input[input_mask] = 0

     partial_output = F.embedding(masked_input,
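
The comments above describe the 1D row-parallel embedding scheme: each rank keeps num_embeddings/P rows of the lookup table, shifts the replicated input ids into its local row range, zeroes out ids owned by other ranks, and the partial embeddings are summed across ranks. Below is a minimal single-process sketch of that idea (not part of this commit; plain PyTorch, with torch.chunk and a Python loop standing in for the process group and the all-reduce):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
P, num_embeddings, embedding_dim = 2, 8, 4            # P "ranks", vocab rows, feature dim
weight = torch.randn(num_embeddings, embedding_dim)
ids = torch.tensor([[0, 3, 5, 7]])

partials = []
for rank, shard in enumerate(weight.chunk(P, dim=0)):  # [num_embeddings/P, embedding_dim] per rank
    vocab_start = rank * shard.size(0)
    vocab_end = vocab_start + shard.size(0)
    mask = (ids < vocab_start) | (ids >= vocab_end)     # ids owned by other ranks
    local_ids = ids - vocab_start                       # shift into the local row range
    local_ids[mask] = 0
    partial = F.embedding(local_ids, shard)
    partial[mask] = 0.0                                 # zero rows that other ranks own
    partials.append(partial)

# the distributed version all-reduces the partial outputs; here we just sum them
assert torch.allclose(sum(partials), F.embedding(ids, weight))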

tests/test_tensor/_utils/_util.py (16 changed lines)

@@ -5,6 +5,7 @@ import torch
 import torch.distributed as dist
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+from colossalai.tensor import ShardSpec, ComputeSpec, ComputePattern


 def set_seed(seed):
@@ -57,3 +58,18 @@ def tensor_shard_equal(tensor: torch.Tensor, shard: torch.Tensor, rank, world_size):
         return tensor_equal(tensor.chunk(world_size, dim)[rank], shard)
     else:
         raise NotImplementedError
+
+
+def split_param_single_dim_tp1d(dim, param, pg):
+    spec = (ShardSpec([dim], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
+    if param.process_group.tp_world_size() == 1:
+        param.set_process_group(pg)
+    param.set_tensor_spec(*spec)
+
+
+def split_param_row_tp1d(param, pg):
+    split_param_single_dim_tp1d(0, param, pg)
+
+
+def split_param_col_tp1d(param, pg):
+    split_param_single_dim_tp1d(-1, param, pg)
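
For orientation: ShardSpec([dim], [pg.tp_world_size()]) declares a 1D split of the parameter along one dimension, one chunk per tensor-parallel rank, so split_param_row_tp1d shards along dim 0 and split_param_col_tp1d along the last dim. A framework-free sketch of the resulting local shapes (hypothetical sizes, torch.chunk standing in for the ShardSpec machinery):

import torch

tp_world_size = 4                                  # hypothetical tensor-parallel degree
weight = torch.empty(8, 16)                        # e.g. an nn.Linear(16, 8) weight

row_shards = weight.chunk(tp_world_size, dim=0)    # split_param_row_tp1d: shard dim 0
col_shards = weight.chunk(tp_world_size, dim=-1)   # split_param_col_tp1d: shard dim -1

print(row_shards[0].shape)   # torch.Size([2, 16]) -> each rank keeps 8/4 output rows
print(col_shards[0].shape)   # torch.Size([8, 4])  -> each rank keeps 16/4 input columns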

tests/test_tensor/test_addmm_tp.py (29 changed lines)

@@ -4,12 +4,11 @@ import pytest
 import torch.nn as nn
 import torch.multiprocessing as mp
 from colossalai.tensor import ColoTensor, ProcessGroup
-from colossalai.tensor import ShardSpec
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager
+from colossalai.tensor import ColoTensorSpec
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from functools import partial
-from _utils import tensor_shard_equal, tensor_equal
+from _utils import tensor_shard_equal, tensor_equal, split_param_row_tp1d, split_param_col_tp1d


 class Conv1D(nn.Module):
@@ -36,20 +35,7 @@ class Conv1D(nn.Module):
         return x


-def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
-        bias.set_tensor_spec(*spec)
-
-
-def run_with_spec(spec_init_func):
+def run_with_spec(spec_init_func, split_bias):
     model = Conv1D(4, 16).cuda()
     world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
@@ -57,7 +43,10 @@ def run_with_spec(spec_init_func):
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
-    spec_init_func(weight, bias, pg)
+    spec_init_func(weight, pg)
+    if split_bias:
+        spec_init_func(bias, pg)
     x = torch.rand(2, 16).cuda()
     out = model(x)
     colo_out = torch.addmm(bias, x, weight)
@@ -72,8 +61,8 @@ def run_with_spec(spec_init_func):
 def run_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row)
-    run_with_spec(init_1d_col)
+    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=False)
+    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=True)


 @pytest.mark.dist
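
Background on the module under test: the Conv1D layer here follows the GPT-2 convention of storing its weight as [in_features, out_features], which is why the test can compare model(x) directly against torch.addmm(bias, x, weight). A single-process sketch of that equality (not from this commit; the Conv1D body below is an assumption, since the class body is truncated in this view):

import torch
import torch.nn as nn


class Conv1D(nn.Module):
    """GPT-2-style linear layer: weight stored as [nx, nf] (in, out)."""

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        self.weight = nn.Parameter(torch.randn(nx, nf) * 0.02)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return x.view(size_out)


model = Conv1D(4, 16)           # same toy sizes as the test
x = torch.rand(2, 16)
assert torch.allclose(model(x), torch.addmm(model.bias, x, model.weight))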

tests/test_tensor/test_embedding_bag_tp.py (16 changed lines)

@@ -1,5 +1,3 @@
 import torch
-from colossalai.tensor import ShardSpec, ColoParameter
 from torch.nn import functional as F
 from functools import partial
@@ -9,21 +7,17 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
-from _utils import tensor_equal, tensor_shard_equal
-
-
-def init_1d_col(weight, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
+from colossalai.tensor import ColoParameter, ColoTensorSpec, ProcessGroup
+from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d


 def run_with_spec(spec_init_func):
     pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
     model = torch.nn.EmbeddingBag(10, 4).cuda()
     weight = ColoParameter(model.weight.clone(), True, ColoTensorSpec(pg))
     spec_init_func(weight, pg)
     inputs = torch.tensor([1, 2, 4, 5, 4, 3, 2, 9]).cuda()
     offsets = torch.tensor([0, 4]).cuda()
     out = model(inputs, offsets=offsets)
@@ -38,7 +32,7 @@ def run_with_spec(spec_init_func):
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_col)
+    run_with_spec(split_param_col_tp1d)


 @pytest.mark.dist

tests/test_tensor/test_embedding_tp.py (24 changed lines)

@@ -1,5 +1,3 @@
 import torch
-from colossalai.tensor import ColoTensor, ShardSpec
 from torch.nn import functional as F
 from functools import partial
@@ -9,26 +7,16 @@ import torch
 import torch.multiprocessing as mp
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
-from _utils import tensor_equal, tensor_shard_equal
-
-
-def init_1d_row(weight, pg: ProcessGroup):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
-
-
-def init_1d_col(weight, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
+from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


 def run_with_spec(spec_init_func, pg: ProcessGroup):
     model = torch.nn.Embedding(12, 32).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
     spec_init_func(weight, pg)
     x = torch.tensor((0, 3, 6, 9)).cuda()
     out = model(x)
     colo_out = F.embedding(x, weight)
@@ -44,8 +32,8 @@ def run_dist(rank, world_size, port):
     # config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     pg = ProcessGroup(tp_degree=world_size)
-    run_with_spec(init_1d_row, pg)
-    run_with_spec(init_1d_col, pg)
+    run_with_spec(split_param_row_tp1d, pg)
+    run_with_spec(split_param_col_tp1d, pg)


 @pytest.mark.dist

tests/test_tensor/test_linear_tp.py (32 changed lines)

@@ -1,6 +1,3 @@
 import torch
-from colossalai.tensor import ColoTensor, ShardSpec
 from functools import partial
 import colossalai
@@ -10,29 +7,20 @@ import torch.multiprocessing as mp
 import torch.nn.functional as F
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
-from colossalai.tensor import ColoTensorSpec, ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup
-from _utils import tensor_equal, tensor_shard_equal
-
-
-def init_1d_row(weight, bias, pg: ProcessGroup):
-    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
+from colossalai.tensor import ColoTensorSpec, ProcessGroup, ColoTensor
+from _utils import tensor_equal, tensor_shard_equal, split_param_col_tp1d, split_param_row_tp1d


-def init_1d_col(weight, bias, pg: ProcessGroup):
-    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
-    with DistSpecManager.no_grad():
-        weight.set_tensor_spec(*spec)
-        bias.set_tensor_spec(*spec)
-
-
-def run_with_spec(spec_init_func):
+def run_with_spec(spec_init_func, split_bias):
     pg = ProcessGroup(tp_degree=torch.distributed.get_world_size())
     model = torch.nn.Linear(4, 8).cuda()
     weight = ColoTensor(torch.nn.Parameter(model.weight.detach()), ColoTensorSpec(pg))
     bias = ColoTensor(torch.nn.Parameter(model.bias.detach()), ColoTensorSpec(pg))
-    spec_init_func(weight, bias, pg)
+    spec_init_func(weight, pg)
+    if split_bias:
+        spec_init_func(bias, pg)
     x = torch.rand(2, 4).cuda()
     out = model(x)
     colo_out = F.linear(x, weight, bias)
@@ -48,8 +36,8 @@ def run_with_spec(spec_init_func):
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    run_with_spec(init_1d_row)
-    run_with_spec(init_1d_col)
+    run_with_spec(spec_init_func=split_param_col_tp1d, split_bias=False)
+    run_with_spec(spec_init_func=split_param_row_tp1d, split_bias=True)


 @pytest.mark.dist
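
Why the bias is only split together with a dim-0 (row) weight split: an nn.Linear stores its weight as [out_features, in_features], so a split along dim -1 partitions the input features and each rank holds a partial sum over the full output (the bias must stay whole and be added once), while a split along dim 0 partitions the output features and each rank needs the matching slice of the bias. A single-process sketch of both equalities (not from this commit):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
tp = 2                                     # hypothetical tensor-parallel degree
x = torch.rand(2, 4)
weight = torch.randn(8, 4)                 # nn.Linear(4, 8): [out_features, in_features]
bias = torch.randn(8)
ref = F.linear(x, weight, bias)

# Split along dim -1 (in_features): each "rank" produces partial sums over the
# full output, so the partials are added and the bias is applied once (kept whole).
w_cols = weight.chunk(tp, dim=-1)
x_cols = x.chunk(tp, dim=-1)
out_sum = sum(F.linear(xc, wc) for xc, wc in zip(x_cols, w_cols)) + bias
assert torch.allclose(out_sum, ref, atol=1e-6)

# Split along dim 0 (out_features): each "rank" produces a slice of the output,
# so the slices are concatenated and each rank needs the matching bias slice.
w_rows = weight.chunk(tp, dim=0)
b_rows = bias.chunk(tp, dim=0)
out_cat = torch.cat([F.linear(x, wr, br) for wr, br in zip(w_rows, b_rows)], dim=-1)
assert torch.allclose(out_cat, ref)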

tests/test_tensor/test_model.py (59 changed lines)

@@ -16,6 +16,7 @@ from colossalai.tensor import ShardSpec, ColoTensorSpec, ComputePattern, \
 from colossalai.nn.optimizer import ColoOptimizer

 from tests.components_to_test.registry import non_distributed_component_funcs
+from _utils import split_param_row_tp1d, split_param_col_tp1d


 def init_1d_row_linear(weight: ColoTensor, pg: ProcessGroup):
@@ -50,7 +51,9 @@ def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
+    rank = torch.distributed.get_rank()
+    world_size = torch.distributed.get_world_size()

     set_seed(1)
     with ColoInitContext(device=get_current_device()):
@@ -59,14 +62,15 @@ def run_1d_hybrid_tp(model_name):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
-        colo_optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)
+        optimizer_torch = ColoOptimizer(dict(model_torch.named_parameters()), torch.optim.SGD, lr=0.1)

         # Make two models have the same init params
         for p1, p2 in zip(model.parameters(), model_torch.parameters()):
             p2.data.copy_(p1.data)
+    else:
+        model_torch = None
+        optimizer_torch = None

-    rank = torch.distributed.get_rank()
-    world_size = torch.distributed.get_world_size()
     pg = ProcessGroup(tp_degree=world_size)
     if 'bert' == model_name:
         for name, p in model.named_parameters():
@@ -75,8 +79,8 @@ def run_1d_hybrid_tp(model_name):
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
             # TODO(jiaruifang) has bug if open the following 2 comments
-            # if 'classifier' in name and 'weight' in name:
-            #     init_1d_row_linear(p, pg)
+            if 'classifier' in name and 'weight' in name:
+                init_1d_row_linear(p, pg)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
                 init_1d_row_embedding(p, pg)
@@ -86,6 +90,8 @@ def run_1d_hybrid_tp(model_name):
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
                 init_1d_col_embedding(p, pg)
+            if p.process_group.tp_world_size() == 1:
+                p.set_process_group(pg)

     elif "simple_net" == model_name:
         # A naive way to set spec for all weights in Linear
         for name, p in model.named_parameters():
@@ -101,13 +107,18 @@ def run_1d_hybrid_tp(model_name):
                 init_1d_col_linear(p, pg)

     model = model.cuda()
+    model.train()
+    if rank == 0:
+        model_torch.train()
+
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
     for i, (data, label) in enumerate(train_dataloader):
         model.eval()
         # Zero grad
         colo_optimizer.zero_grad()
         if rank == 0:
             model_torch.eval()
-            colo_optimizer_torch.zero_grad()
+            optimizer_torch.zero_grad()

         data = data.to(get_current_device())
         label = label.to(get_current_device())
@@ -123,7 +134,7 @@ def run_1d_hybrid_tp(model_name):
             output = model(data, label)
             loss = output

-        # For reference
+        # Test output
         if rank == 0:
             if criterion:
                 output_torch = model_torch(data)
@@ -131,17 +142,14 @@ def run_1d_hybrid_tp(model_name):
             else:
                 output_torch = model_torch(data, label)
                 loss_torch = output_torch
-
-        if rank == 0:
-            with torch.no_grad():
-                assert torch.allclose(loss, loss_torch, rtol=1e-2)
+            assert torch.allclose(loss, loss_torch, rtol=1e-2)

         loss.backward()
         colo_optimizer.step()

         if rank == 0:
             loss_torch.backward()
-            colo_optimizer_torch.step()
+            optimizer_torch.step()

         with torch.no_grad():
             # check param
@@ -231,14 +239,19 @@ def run_1d_row_tp(model_name: str):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()

-    # A naive way to set spec for all weights in Linear
-    for name, p in model.named_parameters():
-        if not isinstance(p, ColoTensor):
-            continue
-        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            init_1d_row_linear(p, pg)
-        if 'embed' in name and 'weight' in name:
-            init_1d_row_embedding(p, pg)
+    for mo_name, module in model.named_modules():
+        # print(mo_name)
+        for pa_name, param in module.named_parameters(recurse=False):
+            # print('\t', pa_name, param.shape)
+            if not isinstance(param, ColoTensor):
+                continue
+            if 'weight' in pa_name:
+                if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
+                    split_param_row_tp1d(param, pg)
+                elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
+                    split_param_col_tp1d(param, pg)

     model = model.cuda()
@@ -313,9 +326,9 @@ def _run_pretrain_load():
 def run_model_dist(rank, world_size, port):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
-    for name in ['simple_net']:
+    for name in ['bert']:
         run_1d_row_tp(name)
-    for name in ['simple_net']:
+    for name in ['bert']:
         run_1d_hybrid_tp(name)
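
The rewritten loop above selects parameters by module name rather than by the flattened parameter name. A toy, framework-free sketch of the same selection logic (the ToyBert module below is hypothetical; the real test builds its BERT model via non_distributed_component_funcs and applies the split helpers instead of printing):

import torch.nn as nn


class ToyBert(nn.Module):

    def __init__(self):
        super().__init__()
        self.word_embeddings = nn.Embedding(30, 8)
        self.token_type_embeddings = nn.Embedding(2, 8)
        self.LayerNorm = nn.LayerNorm(8)
        self.classifier = nn.Linear(8, 2)


model = ToyBert()
for mo_name, module in model.named_modules():
    for pa_name, param in module.named_parameters(recurse=False):
        if 'weight' not in pa_name:
            continue
        if 'embed' in mo_name and 'token' not in mo_name and 'LayerNorm' not in mo_name:
            print(f'{mo_name}.{pa_name}: row split (dim 0)')       # e.g. word_embeddings
        elif 'LayerNorm' not in mo_name and 'ln' not in mo_name:
            print(f'{mo_name}.{pa_name}: column split (dim -1)')   # e.g. classifier
        else:
            print(f'{mo_name}.{pa_name}: kept replicated')         # e.g. LayerNorm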
