mirror of https://github.com/hpcaitech/ColossalAI
Merge pull request #5515 from Edenzzzz/fix_layout_convert
Fix layout convertor cachingpull/5517/head
commit
9a3321e9f4
|
@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta):
|
||||||
total_steps = 0
|
total_steps = 0
|
||||||
transform_path = []
|
transform_path = []
|
||||||
comm_action_sequence: List[CommSpec] = []
|
comm_action_sequence: List[CommSpec] = []
|
||||||
spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence))
|
|
||||||
|
src_shape = source_layout.get_sharded_shape_per_device()
|
||||||
|
dst_shape = target_layout.get_sharded_shape_per_device()
|
||||||
|
spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape))
|
||||||
|
|
||||||
if spec_pairs in self.cached_solution:
|
if spec_pairs in self.cached_solution:
|
||||||
# Solution Cache hit
|
# Solution Cache hit
|
||||||
|
|
|
@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port):
|
||||||
assert comm_action_sequence[2].logical_process_axis == 1
|
assert comm_action_sequence[2].logical_process_axis == 1
|
||||||
|
|
||||||
# checkout chached_spec_pairs_transform_path
|
# checkout chached_spec_pairs_transform_path
|
||||||
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path
|
src_shape = source_layout.get_sharded_shape_per_device()
|
||||||
assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence
|
dst_shape = target_layout.get_sharded_shape_per_device()
|
||||||
|
assert (
|
||||||
|
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][0] == transform_path
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][1]
|
||||||
|
== comm_action_sequence
|
||||||
|
)
|
||||||
|
|
||||||
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
|
comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue