Making large AI models cheaper, faster and more accessible
You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 

288 lines
10 KiB

import os
from contextlib import nullcontext
from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.testing import assert_close

import colossalai
from colossalai.lazy import LazyInitContext
from colossalai.pipeline.weight_grad_store import WeightGradStore
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, LinearWithGradAccum
from colossalai.tensor.d_tensor import is_distributed_tensor
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
# Set at import time, ahead of every check defined below.
# NOTE(review): presumably limits each GPU to a single hardware connection so that
# kernels are issued in launch order (relevant to the `overlap` path) — confirm
# against the CUDA environment-variable documentation.
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
def check_linear_1d_col(lazy_init: bool, seq_parallel_mode: Optional[str], overlap: bool):
    """Compare a ``Linear1D_Col`` layer with the ``nn.Linear`` it mirrors.

    The hard-coded shapes and chunk counts assume exactly two ranks.

    Args:
        lazy_init: Build the module to be sharded under ``LazyInitContext``
            when True, under ``nullcontext`` otherwise.
        seq_parallel_mode: Forwarded to ``Linear1D_Col.from_native_module``.
            With ``None`` the full input is fed to the sharded layer; with any
            other value each rank feeds only its half of dim 1 (seq_len).
        overlap: Forwarded to ``Linear1D_Col.from_native_module``.

    Raises:
        AssertionError: If parameters, outputs or gradients do not match.
    """
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_col = Linear1D_Col.from_native_module(
        linear_copy,
        process_group=None,
        gather_output=True,
        seq_parallel_mode=seq_parallel_mode,
        overlap=overlap,
    )
    rank = dist.get_rank()

    # The parameters must be distributed and shared with the source module.
    assert is_distributed_tensor(linear_col.weight)
    assert is_distributed_tensor(linear_col.bias)
    assert linear_copy.weight is linear_col.weight
    assert linear_copy.bias is linear_col.bias

    # Each rank holds half of the 128 output features.
    assert linear_col.weight.shape == torch.Size([64, 32])
    assert linear_col.bias.shape == torch.Size([64])

    # The state dict must load in both directions; this also gives both
    # layers the same weights for the comparison below.
    linear.load_state_dict(linear_col.state_dict())
    linear_col.load_state_dict(linear.state_dict())

    # Forward: input layout is [batch_size, seq_len, hidden_size].
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    if seq_parallel_mode is None:
        x_for_shard = x.expand_as(x.clone())
    else:
        # Sequence parallelism: this rank only feeds its slice of seq_len.
        x_for_shard = torch.chunk(x.clone(), 2, dim=1)[rank]
    x_for_shard.requires_grad_(True)

    out = linear(x_for_unshard)
    gather_out = linear_col(x_for_shard)
    assert_close(out, gather_out)

    # Backward: the local weight grad is this rank's slice along dim 0.
    out.sum().backward()
    gather_out.sum().backward()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=0)[rank]
    assert_close(target_grad, linear_col.weight.grad)

    # Input gradients must exist and agree (on the local slice when sharded).
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    if seq_parallel_mode is None:
        target_unshard_grad = x_for_unshard.grad
    else:
        target_unshard_grad = torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[rank]
    assert_close(target_unshard_grad, x_for_shard.grad)
def check_linear_1d_row(lazy_init: bool, seq_parallel_mode: Optional[str]):
    """Compare a ``Linear1D_Row`` layer with the ``nn.Linear`` it mirrors.

    The hard-coded shapes and chunk counts assume exactly two ranks.

    Args:
        lazy_init: Build the module to be sharded under ``LazyInitContext``
            when True, under ``nullcontext`` otherwise.
        seq_parallel_mode: Forwarded to ``Linear1D_Row.from_native_module``.
            With ``None`` the sharded output is compared with the full
            reference output; with any other value it is compared with this
            rank's half of the reference output along dim 1 (seq_len).

    Raises:
        AssertionError: If parameters, outputs or gradients do not match.
    """
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_row = Linear1D_Row.from_native_module(
        linear_copy,
        process_group=None,
        parallel_input=False,
        seq_parallel_mode=seq_parallel_mode,
    )
    rank = dist.get_rank()

    # The weight is split along the input dimension (32 -> 16); the bias is whole.
    assert linear_row.weight.shape == torch.Size([128, 16])
    assert linear_row.bias.shape == torch.Size([128])
    assert linear_copy.weight is linear_row.weight
    assert linear_copy.bias is linear_row.bias

    # Round-trip the state dict so both layers carry the same weights.
    linear.load_state_dict(linear_row.state_dict())
    linear_row.load_state_dict(linear.state_dict())

    # Forward: input layout is [batch_size, seq_len, hidden_size].
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    out = linear(x_for_unshard)
    gather_out = linear_row(x_for_shard)
    if seq_parallel_mode is None:
        target_out = out
    else:
        target_out = torch.chunk(out.clone(), 2, dim=1)[rank]
    assert_close(target_out, gather_out)

    # Backward: the local weight grad is this rank's slice along dim 1.
    out.sum().backward()
    gather_out.sum().backward()
    target_grad = torch.chunk(linear.weight.grad, 2, dim=1)[rank]
    assert_close(target_grad, linear_row.weight.grad)

    # Input gradients must exist and agree.
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_without_weight_grad_store(lazy_init: bool, seq_parallel_mode: Optional[str]):
    """Compare ``LinearWithGradAccum`` (``use_zbv=False``) with a plain ``nn.Linear``.

    Args:
        lazy_init: Build the module to be wrapped under ``LazyInitContext``
            when True, under ``nullcontext`` otherwise.
        seq_parallel_mode: Forwarded to
            ``LinearWithGradAccum.from_native_module``.

    Raises:
        AssertionError: If parameters, outputs or gradients do not match.
    """
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_base = LinearWithGradAccum.from_native_module(
        linear_copy,
        parallel_input=False,
        seq_parallel_mode=seq_parallel_mode,
        use_zbv=False,
    )

    # The layer keeps the full-size parameters of the source module.
    assert linear_base.weight.shape == torch.Size([128, 32])
    assert linear_base.bias.shape == torch.Size([128])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    # Round-trip the state dict so both layers carry the same weights.
    linear.load_state_dict(linear_base.state_dict())
    linear_base.load_state_dict(linear.state_dict())

    # Forward: input layout is [batch_size, seq_len, hidden_size].
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    out = linear(x_for_unshard)
    gather_out = linear_base(x_for_shard)
    assert_close(out, gather_out)

    # Backward: the weight gradient is available right after backward().
    out.sum().backward()
    gather_out.sum().backward()
    assert_close(linear.weight.grad, linear_base.weight.grad)

    # Input gradients must exist and agree.
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_with_weight_grad_store(lazy_init: bool, seq_parallel_mode: Optional[str]):
    """Compare ``LinearWithGradAccum`` (``use_zbv=True``) with a plain ``nn.Linear``.

    With ``use_zbv=True`` the weight gradient is expected to be missing after
    ``backward()`` and to appear only once ``WeightGradStore`` has been
    flushed and popped.

    Args:
        lazy_init: Build the module to be wrapped under ``LazyInitContext``
            when True, under ``nullcontext`` otherwise.
        seq_parallel_mode: Forwarded to
            ``LinearWithGradAccum.from_native_module``.

    Raises:
        AssertionError: If parameters, outputs or gradients do not match.
    """
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear = nn.Linear(32, 128).cuda()
    with ctx:
        linear_copy = nn.Linear(32, 128).cuda()
    linear_base = LinearWithGradAccum.from_native_module(
        linear_copy,
        parallel_input=False,
        seq_parallel_mode=seq_parallel_mode,
        use_zbv=True,
    )

    # The layer keeps the full-size parameters of the source module.
    assert linear_base.weight.shape == torch.Size([128, 32])
    assert linear_base.bias.shape == torch.Size([128])
    assert linear_copy.weight is linear_base.weight
    assert linear_copy.bias is linear_base.bias

    # Round-trip the state dict so both layers carry the same weights.
    linear.load_state_dict(linear_base.state_dict())
    linear_base.load_state_dict(linear.state_dict())

    # Forward: input layout is [batch_size, seq_len, hidden_size].
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    x_for_shard = x.expand_as(x.clone())
    x_for_shard.requires_grad_(True)

    out = linear(x_for_unshard)
    gather_out = linear_base(x_for_shard)
    assert_close(out, gather_out)

    # Backward.
    out.sum().backward()
    gather_out.sum().backward()

    # The weight grad is still None until the WeightGradStore is popped.
    assert linear_base.weight.grad is None
    # Once the pop has run (dw computation complete) the weight grad can be checked.
    WeightGradStore.flush(chunk=0)  # flush buffer to chunk 0 Queue
    WeightGradStore.pop(chunk=0)
    assert_close(linear.weight.grad, linear_base.weight.grad)

    # Input gradients must exist and agree.
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    assert_close(x_for_unshard.grad, x_for_shard.grad)
def check_linear_col_plus_row(lazy_init: bool, seq_parallel_mode: Optional[str], overlap: bool):
    """Compare a ``Linear1D_Col`` -> ``Linear1D_Row`` stack with two ``nn.Linear`` layers.

    The hard-coded chunk counts assume exactly two ranks.

    Args:
        lazy_init: Build the modules to be sharded under ``LazyInitContext``
            when True, under ``nullcontext`` otherwise.
        seq_parallel_mode: Forwarded to both ``from_native_module`` calls.
            With ``None`` full tensors are compared; with any other value the
            sharded input, output and input gradient are this rank's half of
            dim 1 (seq_len).
        overlap: Forwarded to ``Linear1D_Col.from_native_module``.

    Raises:
        AssertionError: If outputs or gradients do not match.
    """
    ctx = LazyInitContext() if lazy_init else nullcontext()
    linear_1 = nn.Linear(32, 128).cuda()
    linear_2 = nn.Linear(128, 32).cuda()
    with ctx:
        linear_1_copy = nn.Linear(32, 128).cuda()
        linear_2_copy = nn.Linear(128, 32).cuda()
    # The column layer keeps its output sharded (gather_output=False) and the
    # row layer consumes that sharded input directly (parallel_input=True).
    linear_col = Linear1D_Col.from_native_module(
        linear_1_copy,
        process_group=None,
        gather_output=False,
        seq_parallel_mode=seq_parallel_mode,
        overlap=overlap,
    )
    linear_row = Linear1D_Row.from_native_module(
        linear_2_copy,
        process_group=None,
        parallel_input=True,
        seq_parallel_mode=seq_parallel_mode,
    )
    rank = dist.get_rank()

    # Round-trip the state dicts so reference and sharded layers share weights.
    linear_1.load_state_dict(linear_col.state_dict())
    linear_col.load_state_dict(linear_1.state_dict())
    linear_2.load_state_dict(linear_row.state_dict())
    linear_row.load_state_dict(linear_2.state_dict())

    # Forward: input layout is [batch_size, seq_len, hidden_size].
    x = torch.rand(2, 4, 32).cuda()
    x_for_unshard = x.expand_as(x.clone())
    x_for_unshard.requires_grad_(True)
    if seq_parallel_mode is None:
        x_for_shard = x.expand_as(x.clone())
    else:
        # Sequence parallelism: this rank only feeds its slice of seq_len.
        x_for_shard = torch.chunk(x.clone(), 2, dim=1)[rank]
    x_for_shard.requires_grad_(True)

    unshard_out = linear_2(linear_1(x_for_unshard))
    shard_out = linear_row(linear_col(x_for_shard))
    if seq_parallel_mode is None:
        target_out = unshard_out
    else:
        target_out = torch.chunk(unshard_out.clone(), 2, dim=1)[rank]
    assert_close(target_out, shard_out)

    # Backward: the column layer's weight grad is this rank's slice along dim 0.
    unshard_out.sum().backward()
    shard_out.sum().backward()
    target_1_grad = torch.chunk(linear_1.weight.grad, 2, dim=0)[rank]
    assert_close(target_1_grad, linear_col.weight.grad)

    # Input gradients must exist and agree (on the local slice when sharded).
    assert x_for_shard.grad is not None
    assert x_for_unshard.grad is not None
    if seq_parallel_mode is None:
        target_unshard_grad = x_for_unshard.grad
    else:
        target_unshard_grad = torch.chunk(x_for_unshard.grad.clone(), 2, dim=1)[rank]
    assert_close(target_unshard_grad, x_for_shard.grad)
@parameterize("lazy_init", [False, True])
@parameterize("seq_parallel_mode", [None, "split_gather"])
@parameterize("overlap", [True])
def run_dist_linear_test(lazy_init, seq_parallel_mode, overlap):
    """Run every linear-layer check for one parameter combination.

    The caller in this file invokes the function without arguments, so the
    stacked ``parameterize`` decorators are what provide ``lazy_init``,
    ``seq_parallel_mode`` and ``overlap``.
    """
    # Tensor-parallel layers: column, row, then the two chained together.
    check_linear_1d_col(lazy_init, seq_parallel_mode, overlap)
    check_linear_1d_row(lazy_init, seq_parallel_mode)
    check_linear_col_plus_row(lazy_init, seq_parallel_mode, overlap)

    # LinearWithGradAccum, first with use_zbv=False and then with use_zbv=True.
    check_linear_without_weight_grad_store(lazy_init, seq_parallel_mode)
    check_linear_with_weight_grad_store(lazy_init, seq_parallel_mode)
def check_dist_linear(rank, world_size, port):
    """Per-process entry point: launch ColossalAI, then run the linear checks.

    Args:
        rank: Passed to ``colossalai.launch`` as ``rank``.
        world_size: Passed to ``colossalai.launch`` as ``world_size``.
        port: Passed to ``colossalai.launch`` as ``port``; the host is
            always ``"localhost"`` and the backend ``"nccl"``.
    """
    colossalai.launch(
        rank=rank,
        world_size=world_size,
        host="localhost",
        port=port,
        backend="nccl",
    )
    run_dist_linear_test()
@rerun_if_address_is_in_use()
def test_linear():
    """Pytest entry point: hand ``check_dist_linear`` to ``spawn`` with two processes."""
    spawn(
        check_dist_linear,
        nprocs=2,
    )
# Allow the test to be started by running this file directly as a script.
if __name__ == "__main__":
    test_linear()