@@ -15,9 +15,9 @@ try:
     from colossalai.utils.model.experimental import LazyInitContext, LazyTensor, _MyTensor
 except:
     pass
 
-from tests.kit.model_zoo import model_zoo
-# from utils import assert_dist_model_equal, set_seed
+from utils import SUPPORT_LAZY, assert_dist_model_equal, set_seed
+from tests.kit.model_zoo import model_zoo
 
 
 def find_shard_dim(shape: torch.Size) -> Optional[int]:
@@ -70,9 +70,8 @@ def generate_layout_dict(model: nn.Module, device_mesh: DeviceMesh) -> dict:
 def run_dist_lazy_init(subset, seed: int = 42):
     sub_model_zoo = model_zoo.get_sub_registry(subset)
     device_mesh = DeviceMesh(torch.Tensor([0, 1, 2, 3]), (2, 2), init_process_group=True)
-    # FIXME(ver217): uncomment this line
-    # _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
-    # LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    _MyTensor._pre_op_fn = lambda *args: set_seed(seed)
+    LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
     for name, entry in sub_model_zoo.items():
         # TODO(ver217): lazy init does not support weight norm, skip these models
@@ -88,8 +87,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
             deferred_model = model_fn()
         layout_dict = generate_layout_dict(deferred_model, device_mesh)
         ctx.distribute(deferred_model, layout_dict, verbose=True)
-        # FIXME(ver217): uncomment this line
-        # assert_dist_model_equal(model, deferred_model, layout_dict)
+        assert_dist_model_equal(model, deferred_model, layout_dict)
 
 
 def run_dist(rank, world_size, port) -> None:
@@ -97,8 +95,7 @@ def run_dist(rank, world_size, port) -> None:
    run_dist_lazy_init()
 
 
-# FIXME(ver217): temporarily skip this test since torch 1.11 does not fully support meta tensor
-@pytest.mark.skip
+@pytest.mark.skipif(not SUPPORT_LAZY, reason='torch version should be >= 1.12.0')
 @pytest.mark.dist
 @rerun_if_address_is_in_use()
 def test_dist_lazy_init():
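Note: the SUPPORT_LAZY flag imported from the local utils module is not part of this diff. A minimal sketch of how such a gate could look, assuming (per the skipif reason above) it is nothing more than a torch-version check; the module contents and exact comparison are assumptions, not the repository's actual code:

    # hypothetical contents of the local utils module next to this test (assumed)
    import torch
    from packaging import version

    # lazy init relies on meta-tensor support that the skip reason ties to torch >= 1.12.0
    SUPPORT_LAZY = version.parse(torch.__version__) >= version.parse('1.12.0')

With a flag like this, pytest skips test_dist_lazy_init only on older torch versions, instead of the unconditional skip that the removed @pytest.mark.skip imposed.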