From 61da3fbc524c8c7939d194007d91488b89288dc5 Mon Sep 17 00:00:00 2001 From: Edenzzzz Date: Tue, 26 Mar 2024 17:22:27 +0800 Subject: [PATCH 1/2] fixed layout converter caching and updated tester --- colossalai/tensor/d_tensor/layout_converter.py | 5 ++++- .../test_tensor/test_dtensor/test_layout_converter.py | 11 +++++++++-- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/colossalai/tensor/d_tensor/layout_converter.py b/colossalai/tensor/d_tensor/layout_converter.py index abe4a86d8..667a7b78e 100644 --- a/colossalai/tensor/d_tensor/layout_converter.py +++ b/colossalai/tensor/d_tensor/layout_converter.py @@ -440,7 +440,10 @@ class LayoutConverter(metaclass=SingletonMeta): total_steps = 0 transform_path = [] comm_action_sequence: List[CommSpec] = [] - spec_pairs = (str(source_spec.sharding_sequence), str(target_spec.sharding_sequence)) + + src_shape = source_layout.get_sharded_shape_per_device() + dst_shape = target_layout.get_sharded_shape_per_device() + spec_pairs = ((str(source_spec.sharding_sequence), src_shape), (str(target_spec.sharding_sequence), dst_shape)) if spec_pairs in self.cached_solution: # Solution Cache hit diff --git a/tests/test_tensor/test_dtensor/test_layout_converter.py b/tests/test_tensor/test_dtensor/test_layout_converter.py index 4e65401bf..3bface1d2 100644 --- a/tests/test_tensor/test_dtensor/test_layout_converter.py +++ b/tests/test_tensor/test_dtensor/test_layout_converter.py @@ -123,8 +123,15 @@ def check_layout_converting(rank, world_size, port): assert comm_action_sequence[2].logical_process_axis == 1 # checkout chached_spec_pairs_transform_path - assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][0] == transform_path - assert layout_converter.cached_solution[("[R, S01, R]", "[S01, R, R]")][1] == comm_action_sequence + src_shape = source_layout.get_sharded_shape_per_device() + dst_shape = target_layout.get_sharded_shape_per_device() + assert ( + layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][0] == transform_path + ) + assert ( + layout_converter.cached_solution[(("[R, S01, R]", src_shape), ("[S01, R, R]", dst_shape))][1] + == comm_action_sequence + ) comm_cost = layout_converter.get_total_comm_cost(source_layout, target_layout)