[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)

pull/3780/head^2
digger yu 2023-05-19 13:50:00 +08:00 committed by GitHub
parent 21e29e2212
commit 32f81f14d4
6 changed files with 12 additions and 12 deletions

File 1 of 6:

@@ -240,7 +240,7 @@ class GradScaler(object):
             for grads in per_dtype_grads.values():
                 torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
                                                                  per_device_inv_scale.get(device))
-        # For tensor parallel paramters it should be all-reduced over tensor parallel process group
+        # For tensor parallel parameters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
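
The hunk above is the tensor-parallel patch in ColossalAI's GradScaler: after unscaling, each rank's found-inf flags must be reduced over the MODEL process group so every rank makes the same skip/step decision. A minimal sketch of that pattern, assuming an already-initialized torch.distributed process group; sync_found_inf and the MAX reduction are illustrative, not the repo's exact code.

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_found_inf(found_inf_tensors, tp_group=None):
    # Fuse the flag tensors into one buffer so a single collective covers
    # them all, mirroring the _flatten_dense_tensors call in the diff.
    coalesced = _flatten_dense_tensors(found_inf_tensors)
    # MAX keeps a flag at 1.0 if any rank in the group saw inf/NaN grads,
    # so all ranks agree to skip the optimizer step together.
    dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=tp_group)
    for flag, synced in zip(found_inf_tensors,
                            _unflatten_dense_tensors(coalesced, found_inf_tensors)):
        flag.copy_(synced)  # write the agreed flags back in place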

File 2 of 6:

@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     else:
         _is_batch_dims_same = False

-    # retireve dimensions
+    # retrieve dimensions
     input_dim_00 = input_tensors[0].shape[-2]
     input_dim_01 = input_tensors[0].shape[-1]
     input_dim_10 = input_tensors[1].shape[-2]
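
matmul_meta_info costs a (possibly batched) matmul from the trailing two dimensions of each operand, which is what the input_dim_* reads above retrieve. A hedged sketch of the arithmetic such meta-info rests on; the function name and return convention are illustrative, not colossalai's API.

import math
import torch

def matmul_forward_flops(a: torch.Tensor, b: torch.Tensor) -> int:
    # trailing dims: a is (..., m, k), b is (..., k, n)
    m, k = a.shape[-2], a.shape[-1]    # input_dim_00, input_dim_01 in the diff
    k2, n = b.shape[-2], b.shape[-1]   # input_dim_10 (and presumably input_dim_11)
    assert k == k2, "inner dimensions must match"
    batch = math.prod(a.shape[:-2]) if a.dim() > 2 else 1
    # one multiply-add per (m, k, n) triple, per batch element
    return 2 * batch * m * k * n

print(matmul_forward_flops(torch.empty(4, 8, 16), torch.empty(4, 16, 32)))  # 32768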

File 3 of 6:

@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
     return gm


-def _act_annotataion_pass(gm: torch.fx.GraphModule):
+def _act_annotation_pass(gm: torch.fx.GraphModule):
     """
     This pass is used to add the act annotation to the new inserted nodes.
     """

File 4 of 6:

@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
     return size


-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
                              strategies_constructor: StrategiesConstructor):
     """
     This method is used to stick the solution strategy to the nodes and add the information
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              device_mesh: DeviceMesh,
                              strategies_constructor: StrategiesConstructor,
                              overlap=False):
-    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
+    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
         gm, solution, strategies_constructor)
     gm = size_value_converting_pass(gm, device_mesh)
     gm = node_args_converting_pass(gm, device_mesh)
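
runtime_preparation_pass threads one GraphModule through a fixed sequence of passes, each taking and returning the module. A self-contained toy of that pipeline shape; the two passes below are trivial stand-ins, not the repo's.

import operator
import torch
from torch import fx

def tag_adds_pass(gm: fx.GraphModule) -> fx.GraphModule:
    # a trivial annotation pass: mark every add node in its meta dict
    for node in gm.graph.nodes:
        if node.op == "call_function" and node.target in (operator.add, torch.add):
            node.meta["tagged"] = True
    return gm

def recompile_pass(gm: fx.GraphModule) -> fx.GraphModule:
    gm.recompile()  # regenerate the module's forward from the graph
    return gm

class Tiny(torch.nn.Module):
    def forward(self, x):
        return x + 1

gm = fx.symbolic_trace(Tiny())
for pass_fn in (tag_adds_pass, recompile_pass):  # chain passes over one gm
    gm = pass_fn(gm)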

File 5 of 6:

@@ -64,7 +64,7 @@ class TraceFlow(object):
             return False
         return True

-    def _assgin_single_node_flow(
+    def _assign_single_node_flow(
         self,
         arg_node: Node,
         start_idx: int,
@@ -177,7 +177,7 @@ class TraceFlow(object):
                 if get_node_shape(arg) is None:
                     continue
                 arg_list.append(arg)
-                flow_flag = self._assgin_single_node_flow(
+                flow_flag = self._assign_single_node_flow(
                     arg,
                     start_idx,
                     end_idx,
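
The loop in the second hunk skips arguments with no known shape and folds a per-argument check (_assign_single_node_flow) into flow_flag. A generic sketch of that skip-and-check pattern; has_shape and check_arg are illustrative stand-ins, and the early bail-out is a simplification of the real control flow.

def check_flow(args, has_shape, check_arg):
    for arg in args:
        if not has_shape(arg):
            continue  # shapeless inputs don't participate in the flow
        if not check_arg(arg):
            return False  # one failing arg invalidates the candidate chunk
    return True

# e.g. skip None entries, require the rest to be positive
print(check_flow([1, None, 2], lambda a: a is not None, lambda a: a > 0))  # True
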
@@ -315,7 +315,7 @@ class TraceFlow(object):
         chunk_info["args"]["prepose_nodes"] = prepose_nodes

     def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
-        # we need to log input nodes to avoid deleteing them in the loop
+        # we need to log input nodes to avoid deleting them in the loop
         chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
         # also need to get some prepose node's arg out of non_chunk_inputs
         for n in chunk_info["args"]["prepose_nodes"]:
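
_get_non_chunk_inputs records which inputs of the chunked node region are produced outside it, so rewriting the region into a loop doesn't delete tensors it still needs. A hedged sketch of that bookkeeping over fx nodes; the helper is illustrative, not the repo's implementation.

from torch import fx

def non_chunk_inputs(region_nodes: list) -> list:
    region = set(region_nodes)
    inputs = []
    for node in region_nodes:
        for arg in node.all_input_nodes:
            # produced outside the chunk region: log it so the loop
            # rewrite keeps it alive
            if arg not in region and arg not in inputs:
                inputs.append(arg)
    return inputs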

File 6 of 6:

@@ -461,7 +461,7 @@ class TraceIndice(object):
             nodes_in.append(node_in)
             self._inherit_more_indice_from_node_with_exclude(node_in, node)

-    def _assgin_no_change_indice(self, node, idx):
+    def _assign_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
@@ -792,7 +792,7 @@ class TraceIndice(object):
                 self._add_dim(node_idx, i)
         dim_from.reverse()

-        # inheirt indice from current node
+        # inherit indice from current node
         if len(dim_from) != 0 and len(dim_to) != 0:
             if dim_diff == 1:
                 if origin_shape[dim_from[0]] == 1:
@@ -852,7 +852,7 @@ class TraceIndice(object):
         elif "split" == node_name:
             self._assign_split_indice(node, idx)
         elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif "new_ones" == node_name:
             self._assign_all_indice(node, idx)
         elif "flatten" == node_name:
@@ -914,7 +914,7 @@ class TraceIndice(object):
         elif "conv2d" == node_name:
             self._assign_conv2d_indice(node, idx)
         elif "identity" == node_name:
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
             self._assign_elementwise_indice(node, idx)
         else:
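
The last two hunks sit in a long name-keyed dispatch inside TraceIndice. The same routing can be read as membership tables; a sketch follows, with the op groupings copied from the diff — the table-driven form is a suggestion, not the repo's code.

NO_CHANGE_OPS = {"to", "contiguous", "clone", "type", "float", "identity"}
ELEMENTWISE_OPS = {"sigmoid", "dropout", "relu", "silu", "gelu"}

def assign_indice(tracer, node_name: str, node, idx: int) -> None:
    if node_name in NO_CHANGE_OPS:
        tracer._assign_no_change_indice(node, idx)   # output indices == input indices
    elif node_name in ELEMENTWISE_OPS:
        tracer._assign_elementwise_indice(node, idx)
    else:
        raise NotImplementedError(f"no indice handler for {node_name!r}")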