|
|
@ -10,7 +10,7 @@ from colossalai.auto_parallel.tensor_shard.sharding_strategy import MemoryCost, |
|
|
|
from colossalai.context.singleton_meta import SingletonMeta |
|
|
|
from colossalai.context.singleton_meta import SingletonMeta |
|
|
|
from colossalai.tensor.d_tensor.comm_spec import * |
|
|
|
from colossalai.tensor.d_tensor.comm_spec import * |
|
|
|
from colossalai.tensor.d_tensor.layout import Layout |
|
|
|
from colossalai.tensor.d_tensor.layout import Layout |
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpecException |
|
|
|
from colossalai.tensor.d_tensor.misc import LayoutException |
|
|
|
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator |
|
|
|
from colossalai.tensor.utils import all_gather_simulator, all_to_all_simulator, shard_simulator |
|
|
|
|
|
|
|
|
|
|
|
from .sharding_spec import ShardingSpec |
|
|
|
from .sharding_spec import ShardingSpec |
|
|
@ -145,7 +145,7 @@ class LayoutConverter(metaclass=SingletonMeta): |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
|
|
|
|
|
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
except ShardingSpecException: |
|
|
|
except LayoutException: |
|
|
|
pass |
|
|
|
pass |
|
|
|
return valid_spec_dict |
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
|
@ -255,7 +255,7 @@ class LayoutConverter(metaclass=SingletonMeta): |
|
|
|
device_type=source_layout.device_type, |
|
|
|
device_type=source_layout.device_type, |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
except ShardingSpecException: |
|
|
|
except LayoutException: |
|
|
|
pass |
|
|
|
pass |
|
|
|
|
|
|
|
|
|
|
|
return valid_spec_dict |
|
|
|
return valid_spec_dict |
|
|
@ -343,7 +343,7 @@ class LayoutConverter(metaclass=SingletonMeta): |
|
|
|
device_type=source_layout.device_type, |
|
|
|
device_type=source_layout.device_type, |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
entire_shape=source_layout.entire_shape) |
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
valid_spec_dict[new_layout] = comm_spec |
|
|
|
except ShardingSpecException: |
|
|
|
except LayoutException: |
|
|
|
pass |
|
|
|
pass |
|
|
|
return valid_spec_dict |
|
|
|
return valid_spec_dict |
|
|
|
|
|
|
|
|
|
|
|