mirror of https://github.com/hpcaitech/ColossalAI
[shardformer] Unit test (#3928)
* fix bug in slicer, add slicer unit test
* add dropout test
* use pid as dropout seed
* update dropout test with local pattern
* add todo

pull/4157/head
parent f1cb5ac6bf
commit a73130482d
@@ -1,5 +1,4 @@
 import os
-import time
 from contextlib import contextmanager
 
 import torch
@@ -14,7 +13,8 @@ class SeedManager:
 
     def __init__(self):
         original_state = torch.cuda.get_rng_state()
-        seed = int(f"{int(time.time())}{os.environ['RANK']}")
+        # TODO: unify this seed manager with the colossalai.context.random
+        seed = os.getpid()
         torch.cuda.manual_seed(int(seed))
         self.dropout_state = torch.cuda.get_rng_state()
         torch.cuda.set_rng_state(original_state)
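The new seed comes from the worker's process id, so every spawned rank draws a different dropout mask. A minimal standalone sketch of that idea (an illustration only, assuming one OS process per rank as spawn-based launchers use; it seeds the CPU RNG so it runs without a GPU):

import os

import torch

seed = os.getpid()                  # unique per process, hence per spawned rank
torch.manual_seed(seed)
keep_mask = torch.rand(8) > 0.4     # a p=0.4 dropout would keep these positions
print(f"pid={seed} keep-mask={keep_mask.tolist()}")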
@@ -3,7 +3,7 @@ import torch
 from ..policies.basepolicy import Col_Layer, Layer, Row_Layer
 from .shard_config import ShardConfig
 
-dim_mapping = {Col_Layer: 1, Row_Layer: 0}
+dim_mapping = {Col_Layer: 0, Row_Layer: 1}
 
 
 class Slicer():
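The corrected mapping agrees with what the new slicer unit test later in this diff expects for each layer type. A small illustration derived from those test expectations (not from the Slicer implementation itself), using the same [in_feature, out_feature] weight layout as that test:

import torch

world_size, rank = 2, 0
weight, bias = torch.randn(24, 48), torch.randn(48)

col_weight = weight.chunk(world_size, dim=0)[rank]    # Col_Layer slices dim 0 -> [12, 48]
col_bias = bias.chunk(world_size, dim=0)[rank]        # Col_Layer also slices the bias
row_weight = weight.chunk(world_size, dim=1)[rank]    # Row_Layer slices dim 1 -> [24, 24]
row_bias = bias                                       # Row_Layer keeps the full bias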
@@ -40,7 +40,7 @@ class Slicer():
         # print(weight.shape, dim)
         if policy_layer_cls == Col_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
-            bias = self.slice_tensor(bias, 0, True)
+            bias = self.slice_tensor(bias, 0, True, n_cast)
         elif policy_layer_cls == Row_Layer:
             weight = self.slice_tensor(weight, dim, False, n_cast)
         else:
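Passing n_cast through for the bias makes it follow the same interleaved slicing as the weight. A minimal sketch of that interleaving pattern (illustration only; n_cast presumably accounts for fused projections that pack several logical tensors into one):

import torch

world_size, rank, n_cast = 2, 0, 3
bias = torch.arange(12.)
chunks = bias.chunk(world_size * n_cast, dim=0)    # 6 chunks of 2 elements
sliced = torch.cat([chunks[i] for i in range(rank, len(chunks), world_size)], dim=0)
print(sliced)    # tensor([0., 1., 4., 5., 8., 9.]) on rank 0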
@@ -129,13 +129,13 @@ class Slicer():
 
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=0).contiguous()
+            return torch.cat(chunk_list, dim=1).contiguous()
 
     def slice_row(
         self,
@@ -152,10 +152,10 @@ class Slicer():
            :class:`torch.Tensor`: The sliced tensor
         """
         if n_cast is None:
-            return tensor.chunk(self.shardconfig.world_size, dim=1)[self.shardconfig.rank].contiguous()
+            return tensor.chunk(self.shardconfig.world_size, dim=0)[self.shardconfig.rank].contiguous()
         else:
-            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=1)
+            tensor_chunks = tensor.chunk(self.shardconfig.world_size * n_cast, dim=0)
             chunk_list = [
                 tensor_chunks[i] for i in range(self.shardconfig.rank, len(tensor_chunks), self.shardconfig.world_size)
             ]
-            return torch.cat(chunk_list, dim=1).contiguous()
+            return torch.cat(chunk_list, dim=0).contiguous()
@@ -0,0 +1,51 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer.dropout import Dropout1D
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_dropout(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+
+    # prepare data
+    input = torch.randn(5, 4).to('cuda')
+    dropout = Dropout1D(p=0.4).to('cuda')
+    output_list = []
+    # compare the dropout pattern in each device
+    for i in range(2):
+        output = dropout(input)
+        output_list.append(output)
+        dist_output_list = [torch.zeros(*output.shape).to('cuda') for _ in range(world_size)]
+        torch.distributed.all_gather(dist_output_list, output)
+        for j in range(world_size):
+            for k in range(world_size):
+                if j != k:
+                    mask = torch.eq(dist_output_list[j], 0.0) == torch.eq(dist_output_list[k], 0.0)
+                    assert torch.all(
+                        mask
+                    ) == False, f"The dropout pattern in each device is not unique\n{dist_output_list[j]}\n{dist_output_list[k]}"
+    # compare the dropout pattern in the local device
+    for i in range(len(output_list)):
+        for j in range(len(output_list)):
+            if i != j:
+                mask = torch.eq(output_list[i], 0.0) == torch.eq(output_list[j], 0.0)
+                assert torch.all(
+                    mask
+                ) == False, f"The dropout pattern in one device is not unique\n{output_list[i]}\n{output_list[j]}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_dropout():
+    spawn(check_dropout, 2)
+
+
+if __name__ == '__main__':
+    test_dropout()
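The test treats two outputs as having the same dropout pattern only when their zero positions coincide exactly. A tiny standalone illustration of that check (CPU-only, not tied to Dropout1D):

import torch

a = torch.tensor([0.0, 1.3, 0.0, 2.1])
b = torch.tensor([0.5, 0.0, 0.0, 2.1])
same_pattern = torch.all(torch.eq(a, 0.0) == torch.eq(b, 0.0))
print(same_pattern)    # tensor(False): the zero positions differ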
@@ -0,0 +1,78 @@
+import pytest
+import torch
+import torch.nn.functional as F
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.policies.basepolicy import Col_Layer, Layer, Row_Layer
+from colossalai.shardformer.shard.shard_config import ShardConfig
+from colossalai.shardformer.shard.slicer import Slicer
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),)
+
+
+def check_slicer(rank, world_size, port, in_feature, out_feature):
+    disable_existing_loggers()
+    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, port=port, host='localhost', backend='nccl')
+    # initialize slicer
+    shardconfig = ShardConfig(rank=rank, world_size=world_size)
+    slicer = Slicer(shardconfig)
+    # initialize test data
+    weight = torch.randn(in_feature, out_feature)
+    bias = torch.randn(out_feature)
+    policy_layer_cls_list = [Layer, Col_Layer, Row_Layer]
+    n_cast_list = [None, 2, 3, 4]
+    # weight and bias
+    for n_cast in n_cast_list:
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Layer, n_cast=n_cast)
+        expected_sliced_weight = weight
+        expected_sliced_bias = bias
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Layer case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Layer case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Col_Layer, n_cast=n_cast)
+        if (n_cast is None):
+            expected_sliced_weight = weight.chunk(world_size, dim=0)[rank]
+            expected_sliced_bias = bias.chunk(world_size)[rank]
+        else:
+            chunks = weight.chunk(world_size * n_cast, dim=0)
+            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=0)
+            chunks = bias.chunk(world_size * n_cast, dim=0)
+            expected_sliced_bias = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)])
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Col_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Col_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+        sliced_weight, sliced_bias = slicer.slice_weight_bias(weight, bias, policy_layer_cls=Row_Layer, n_cast=n_cast)
+        if (n_cast is None):
+            expected_sliced_weight = weight.chunk(world_size, dim=1)[rank]
+            expected_sliced_bias = bias
+        else:
+            chunks = weight.chunk(world_size * n_cast, dim=1)
+            expected_sliced_weight = torch.cat([chunks[i] for i in range(rank, n_cast * world_size, world_size)], dim=1)
+            expected_sliced_bias = bias
+        assert torch.equal(
+            sliced_weight, expected_sliced_weight
+        ), f"In Row_Layer {n_cast} cast case, weight: sliced_weight is not equal to expected_sliced_weight\norg:{weight}\nsliced:{sliced_weight}\nexpected:{expected_sliced_weight}"
+        assert torch.equal(
+            sliced_bias, expected_sliced_bias
+        ), f"In Row_Layer {n_cast} cast case, bias: sliced_bias is not equal to expected_sliced_bias\norg:{bias}\nsliced:{sliced_bias}\nexpected:{expected_sliced_bias}"
+
+
+@pytest.mark.dist
+@rerun_if_address_is_in_use()
+def test_slicer():
+    args = dict(in_feature=24, out_feature=48)
+    spawn(check_slicer, nprocs=2, in_feature=args['in_feature'], out_feature=args['out_feature'])
+
+
+if __name__ == '__main__':
+    test_slicer()
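For concreteness, the shapes the Col_Layer branch above implies on rank 0 with the test's sizes (in_feature=24, out_feature=48, world_size=2) and n_cast=2, worked out from the test's own expectations rather than from the Slicer code:

import torch

world_size, rank, n_cast = 2, 0, 2
weight, bias = torch.randn(24, 48), torch.randn(48)
w_chunks = weight.chunk(world_size * n_cast, dim=0)    # 4 chunks of 6 rows each
col_weight = torch.cat([w_chunks[i] for i in range(rank, len(w_chunks), world_size)], dim=0)
b_chunks = bias.chunk(world_size * n_cast, dim=0)      # 4 chunks of 12 elements each
col_bias = torch.cat([b_chunks[i] for i in range(rank, len(b_chunks), world_size)])
print(col_weight.shape, col_bias.shape)    # torch.Size([12, 48]) torch.Size([24])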