ColossalAI/tests/test_lazy/test_distribute.py


from typing import Optional

import pytest
import torch
import torch.nn as nn

import colossalai
from colossalai.device.device_mesh import DeviceMesh
from colossalai.tensor.d_tensor.layout import Layout
from colossalai.tensor.d_tensor.sharding_spec import ShardingSpec
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils.common import print_rank_0

try:
    from colossalai.lazy.lazy_init import LazyInitContext, LazyTensor, _MyTensor
except ImportError:
    # lazy init is unavailable on older torch versions; the skipif guard below skips the test
    pass

from lazy_init_utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
from tests.kit.model_zoo import model_zoo
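

# Helpers for building a {parameter_or_buffer_name: ShardingSpec} mapping over a
# lazily initialized model: each tensor is sharded along its first even-sized
# dimension on mesh axis 0, or fully replicated if no such dimension exists.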
def find_shard_dim(shape: torch.Size) -> Optional[int]:
    for dim, size in enumerate(shape):
        if size % 2 == 0:
            return dim


def make_sharding_spec(original_tensor: torch.Tensor) -> ShardingSpec:
    # shard the first even-sized dimension over mesh axis 0; otherwise replicate
    shard_dim = find_shard_dim(original_tensor.shape)
    dim_partition_dict = {shard_dim: [0]} if shard_dim is not None else {}
    target_sharding_spec = ShardingSpec(dim_size=original_tensor.dim(), dim_partition_dict=dim_partition_dict)
    return target_sharding_spec


def _get_current_name(prefix: str, name: str) -> str:
    return f'{prefix}.{name}'.lstrip('.')


def generate_sharding_spec_dict(model: nn.Module) -> dict:
    sharding_spec_dict = {}

    @torch.no_grad()
    def generate_recursively(module: nn.Module, prefix: str = ''):
        # recurse into child modules first
        for name, mod in module.named_children():
            generate_recursively(mod, prefix=_get_current_name(prefix, name))

        # generate specs for parameters directly attached to the current module
        for name, param in module.named_parameters(recurse=False):
            if isinstance(param, LazyTensor):
                sharding_spec = make_sharding_spec(param)
                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

        # generate specs for buffers directly attached to the current module
        for name, buf in module.named_buffers(recurse=False):
            if isinstance(buf, LazyTensor):
                sharding_spec = make_sharding_spec(buf)
                sharding_spec_dict[_get_current_name(prefix, name)] = sharding_spec

    generate_recursively(model)

    return sharding_spec_dict
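

# For every model in the selected model-zoo subset: build one eagerly materialized
# copy and one lazily initialized copy (both re-seeded before every op so their
# weights match), distribute the lazy copy over the 2x2 device mesh, and assert
# that the distributed model equals the reference model.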
@parameterize('subset', ['torchvision', 'diffusers', 'timm', 'transformers', 'torchaudio', 'deepfm', 'dlrm'])
def run_dist_lazy_init(subset, seed: int = 42):
    sub_model_zoo = model_zoo.get_sub_registry(subset)
    device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
    _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
    LazyTensor._pre_op_fn = lambda *args: set_seed(seed)

    for name, entry in sub_model_zoo.items():
        # TODO(ver217): lazy init does not support weight norm, skip these models
        if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
            continue
        print_rank_0(name)

        model_fn, data_gen_fn, output_transform_fn, model_attr = entry

        # reference model: materialize eagerly under _MyTensor
        ctx = LazyInitContext(tensor_cls=_MyTensor)
        with ctx:
            model = model_fn()

        # deferred model: build lazily, then distribute it over the device mesh
        ctx = LazyInitContext()
        with ctx:
            deferred_model = model_fn()
        sharding_spec_dict = generate_sharding_spec_dict(deferred_model)
        ctx.distribute(deferred_model, device_mesh, sharding_spec_dict, verbose=True)
        assert_dist_model_equal(model, deferred_model, device_mesh, sharding_spec_dict)
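

# Per-process entry point: initialize the distributed environment, then run the check.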
def run_dist(rank, world_size, port) -> None:
    colossalai.launch({}, rank=rank, world_size=world_size, host='localhost', port=port)
    run_dist_lazy_init()
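

# Spawn 4 processes to match the 2x2 device mesh; skipped when lazy init is
# unavailable (torch < 1.12.0).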
@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_dist_lazy_init():
    spawn(run_dist, 4)


if __name__ == '__main__':
    test_dist_lazy_init()