Browse Source

fix runtime prepare pass (#4502)

Co-authored-by: lufang.chen <lufang.chen@nio.com>
pull/4431/head^2
Lufang Chen 1 year ago committed by GitHub
parent
commit
12c95a9fed
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
  1. 2
      colossalai/auto_parallel/passes/runtime_preparation_pass.py

2
colossalai/auto_parallel/passes/runtime_preparation_pass.py

@ -144,7 +144,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
# DeviceMesh information instructs the scaling of the size value
device_mesh_info = {}
for dim, dim_size in enumerate(device_mesh.mesh_shape):
for dim, dim_size in enumerate(device_mesh.shape):
device_mesh_info[dim] = dim_size
def _extract_target_dim(node):

Loading…
Cancel
Save