From 12c95a9fedf1dfd4d455fe614c0e5869e7e0d4d1 Mon Sep 17 00:00:00 2001 From: Lufang Chen <64068400+vincentccc@users.noreply.github.com> Date: Wed, 30 Aug 2023 17:29:38 +0800 Subject: [PATCH] fix runtime prepare pass (#4502) Co-authored-by: lufang.chen --- colossalai/auto_parallel/passes/runtime_preparation_pass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/colossalai/auto_parallel/passes/runtime_preparation_pass.py b/colossalai/auto_parallel/passes/runtime_preparation_pass.py index 1a6dc7815..0ed0742ee 100644 --- a/colossalai/auto_parallel/passes/runtime_preparation_pass.py +++ b/colossalai/auto_parallel/passes/runtime_preparation_pass.py @@ -144,7 +144,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh # DeviceMesh information instructs the scaling of the size value device_mesh_info = {} - for dim, dim_size in enumerate(device_mesh.mesh_shape): + for dim, dim_size in enumerate(device_mesh.shape): device_mesh_info[dim] = dim_size def _extract_target_dim(node):