mirror of https://github.com/hpcaitech/ColossalAI
parent
0b00def881
commit
12c95a9fed
|
@ -144,7 +144,7 @@ def size_value_converting_pass(gm: torch.fx.GraphModule, device_mesh: DeviceMesh
|
|||
|
||||
# DeviceMesh information instructs the scaling of the size value
|
||||
device_mesh_info = {}
|
||||
for dim, dim_size in enumerate(device_mesh.mesh_shape):
|
||||
for dim, dim_size in enumerate(device_mesh.shape):
|
||||
device_mesh_info[dim] = dim_size
|
||||
|
||||
def _extract_target_dim(node):
|
||||
|
|
Loading…
Reference in New Issue