|
|
@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port): |
|
|
|
assert comm_action_sequence[2].logical_process_axis == 1 |
|
|
|
assert comm_action_sequence[2].logical_process_axis == 1 |
|
|
|
|
|
|
|
|
|
|
|
# checkout chached_spec_pairs_transform_path |
|
|
|
# checkout chached_spec_pairs_transform_path |
|
|
|
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path |
|
|
|
src_shape = source_layout.get_sharded_shape_per_device() |
|
|
|
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence |
|
|
|
dst_shape = target_layout.get_sharded_shape_per_device() |
|
|
|
|
|
|
|
assert ( |
|
|
|
|
|
|
|
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][0] == transform_path |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
assert ( |
|
|
|
|
|
|
|
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][1] |
|
|
|
|
|
|
|
== comm_action_sequence |
|
|
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) |
|
|
|
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout) |
|
|
|
|
|
|
|
|
|
|
|