[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

pull/3780/head^2
digger yu 2 years ago committed by GitHub
parent 21e29e2212
commit 32f81f14d4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -240,7 +240,7 @@ class GradScaler(object):
for grads in per_dtype_grads.values():
torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
per_device_inv_scale.get(device))
# For tensor parallel paramters it should be all-reduced over tensor parallel process group
# For tensor parallel parameters it should be all-reduced over tensor parallel process group
if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
vals = [val for val in per_device_found_inf._per_device_tensors.values()]
coalesced = _flatten_dense_tensors(vals)

@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
else:
_is_batch_dims_same = False
# retireve dimensions
# retrieve dimensions
input_dim_00 = input_tensors[0].shape[-2]
input_dim_01 = input_tensors[0].shape[-1]
input_dim_10 = input_tensors[1].shape[-2]

@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
return gm
def _act_annotataion_pass(gm: torch.fx.GraphModule):
def _act_annotation_pass(gm: torch.fx.GraphModule):
"""
This pass is used to add the act annotation to the new inserted nodes.
"""

@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
return size
def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
strategies_constructor: StrategiesConstructor):
"""
This method is used to stick the solution strategy to the nodes and add the information
@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
device_mesh: DeviceMesh,
strategies_constructor: StrategiesConstructor,
overlap=False):
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
gm, solution, strategies_constructor)
gm = size_value_converting_pass(gm, device_mesh)
gm = node_args_converting_pass(gm, device_mesh)

@ -64,7 +64,7 @@ class TraceFlow(object):
return False
return True
def _assgin_single_node_flow(
def _assign_single_node_flow(
self,
arg_node: Node,
start_idx: int,
@ -177,7 +177,7 @@ class TraceFlow(object):
if get_node_shape(arg) is None:
continue
arg_list.append(arg)
flow_flag = self._assgin_single_node_flow(
flow_flag = self._assign_single_node_flow(
arg,
start_idx,
end_idx,
@ -315,7 +315,7 @@ class TraceFlow(object):
chunk_info["args"]["prepose_nodes"] = prepose_nodes
def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
# we need to log input nodes to avoid deleteing them in the loop
# we need to log input nodes to avoid deleting them in the loop
chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
# also need to get some prepose node's arg out of non_chunk_inputs
for n in chunk_info["args"]["prepose_nodes"]:

@ -461,7 +461,7 @@ class TraceIndice(object):
nodes_in.append(node_in)
self._inherit_more_indice_from_node_with_exclude(node_in, node)
def _assgin_no_change_indice(self, node, idx):
def _assign_no_change_indice(self, node, idx):
self._assign_indice_as_input(node, idx)
for node_in in node.args:
if type(node_in) == type(node):
@ -792,7 +792,7 @@ class TraceIndice(object):
self._add_dim(node_idx, i)
dim_from.reverse()
# inheirt indice from current node
# inherit indice from current node
if len(dim_from) != 0 and len(dim_to) != 0:
if dim_diff == 1:
if origin_shape[dim_from[0]] == 1:
@ -852,7 +852,7 @@ class TraceIndice(object):
elif "split" == node_name:
self._assign_split_indice(node, idx)
elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif "new_ones" == node_name:
self._assign_all_indice(node, idx)
elif "flatten" == node_name:
@ -914,7 +914,7 @@ class TraceIndice(object):
elif "conv2d" == node_name:
self._assign_conv2d_indice(node, idx)
elif "identity" == node_name:
self._assgin_no_change_indice(node, idx)
self._assign_no_change_indice(node, idx)
elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
self._assign_elementwise_indice(node, idx)
else:

Loading…
Cancel
Save