mirror of https://github.com/hpcaitech/ColossalAI

[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

parent 21e29e2212
commit 32f81f14d4
@@ -240,7 +240,7 @@ class GradScaler(object):
                 for grads in per_dtype_grads.values():
                     torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
                                                                      per_device_inv_scale.get(device))
-        # For tensor parallel paramters it should be all-reduced over tensor parallel process group
+        # For tensor parallel parameters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
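For context on this first hunk: the surrounding code flattens each device's found_inf flag so a single collective can synchronize them across the tensor-parallel group. A minimal sketch of that flatten/all-reduce/unflatten pattern, assuming a plain torch.distributed process group in place of ColossalAI's gpc:

import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_found_inf(vals, group=None):
    # Pack the per-device found_inf flags into one buffer so a single
    # collective suffices.
    coalesced = _flatten_dense_tensors(vals)
    # MAX keeps the 1.0 flag if any rank saw a non-finite gradient.
    dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=group)
    # Copy the synchronized values back into the original tensors.
    for buf, synced in zip(vals, _unflatten_dense_tensors(coalesced, vals)):
        buf.copy_(synced)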
@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     else:
         _is_batch_dims_same = False
 
-    # retireve dimensions
+    # retrieve dimensions
     input_dim_00 = input_tensors[0].shape[-2]
     input_dim_01 = input_tensors[0].shape[-1]
     input_dim_10 = input_tensors[1].shape[-2]
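The dimensions retrieved in this hunk feed the matmul cost model. As an illustrative companion (not the file's actual formula), the forward compute cost of a batched matmul follows directly from these shapes:

import torch

def batched_matmul_flops(a: torch.Tensor, b: torch.Tensor) -> int:
    # A: (..., m, k) @ B: (..., k, n) -> each output element takes k
    # multiply-adds, i.e. 2*m*k*n FLOPs per batch element.
    m, k = a.shape[-2], a.shape[-1]
    n = b.shape[-1]
    batch = a.shape[:-2].numel() if a.dim() > 2 else 1
    return 2 * batch * m * k * n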
@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
     return gm
 
 
-def _act_annotataion_pass(gm: torch.fx.GraphModule):
+def _act_annotation_pass(gm: torch.fx.GraphModule):
     """
     This pass is used to add the act annotation to the new inserted nodes.
     """
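For readers unfamiliar with torch.fx, the renamed _act_annotation_pass has the standard pass shape: walk the traced graph and stamp metadata onto nodes through node.meta. A generic sketch, with the 'activation' key as a hypothetical stand-in for ColossalAI's real annotation:

from torch.fx import GraphModule

def annotation_pass(gm: GraphModule) -> GraphModule:
    # Visit every recorded node and attach metadata that later passes
    # can read back via node.meta.
    for node in gm.graph.nodes:
        node.meta.setdefault('activation', None)  # hypothetical key
    gm.recompile()
    return gm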
@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
     return size
 
 
-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
                                strategies_constructor: StrategiesConstructor):
     """
     This method is used to stick the solution strategy to the nodes and add the information
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              device_mesh: DeviceMesh,
                              strategies_constructor: StrategiesConstructor,
                              overlap=False):
-    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
+    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
         gm, solution, strategies_constructor)
     gm = size_value_converting_pass(gm, device_mesh)
     gm = node_args_converting_pass(gm, device_mesh)
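runtime_preparation_pass above chains several rewrites, each taking and returning the GraphModule. That composition style can be written generically; the helper below is an illustration, not ColossalAI API:

from typing import Callable, Iterable
from torch.fx import GraphModule

GraphPass = Callable[[GraphModule], GraphModule]

def run_passes(gm: GraphModule, passes: Iterable[GraphPass]) -> GraphModule:
    # Apply each pass in order, always rebinding the result since a pass
    # may return a new module rather than mutate in place.
    for p in passes:
        gm = p(gm)
    return gm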
@@ -64,7 +64,7 @@ class TraceFlow(object):
                 return False
         return True
 
-    def _assgin_single_node_flow(
+    def _assign_single_node_flow(
         self,
         arg_node: Node,
         start_idx: int,
@@ -177,7 +177,7 @@ class TraceFlow(object):
             if get_node_shape(arg) is None:
                 continue
             arg_list.append(arg)
-            flow_flag = self._assgin_single_node_flow(
+            flow_flag = self._assign_single_node_flow(
                 arg,
                 start_idx,
                 end_idx,
@@ -315,7 +315,7 @@ class TraceFlow(object):
         chunk_info["args"]["prepose_nodes"] = prepose_nodes
 
     def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
-        # we need to log input nodes to avoid deleteing them in the loop
+        # we need to log input nodes to avoid deleting them in the loop
         chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
         # also need to get some prepose node's arg out of non_chunk_inputs
         for n in chunk_info["args"]["prepose_nodes"]:
@@ -461,7 +461,7 @@ class TraceIndice(object):
             nodes_in.append(node_in)
             self._inherit_more_indice_from_node_with_exclude(node_in, node)
 
-    def _assgin_no_change_indice(self, node, idx):
+    def _assign_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
@@ -792,7 +792,7 @@ class TraceIndice(object):
                 self._add_dim(node_idx, i)
         dim_from.reverse()
 
-        # inheirt indice from current node
+        # inherit indice from current node
         if len(dim_from) != 0 and len(dim_to) != 0:
             if dim_diff == 1:
                 if origin_shape[dim_from[0]] == 1:
@@ -852,7 +852,7 @@ class TraceIndice(object):
         elif "split" == node_name:
             self._assign_split_indice(node, idx)
         elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif "new_ones" == node_name:
             self._assign_all_indice(node, idx)
         elif "flatten" == node_name:
@@ -914,7 +914,7 @@ class TraceIndice(object):
         elif "conv2d" == node_name:
             self._assign_conv2d_indice(node, idx)
         elif "identity" == node_name:
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
             self._assign_elementwise_indice(node, idx)
         else:
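The last two hunks sit in a long elif chain that routes each traced node name to an indice-assignment helper ("to", "contiguous", "clone" and "identity" all share _assign_no_change_indice). A table-driven sketch of the same routing idea, with hypothetical handler names:

def dispatch_indice(node_name, node, idx, handlers, default):
    # handlers maps node names (e.g. "split", "conv2d", "identity") to
    # the method assigning indices for that op; unknown names fall back
    # to the default handler.
    return handlers.get(node_name, default)(node, idx)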