mirror of https://github.com/hpcaitech/ColossalAI
[NFC] fix typo colossalai/amp auto_parallel autochunk (#3756)
parent 21e29e2212
commit 32f81f14d4
@@ -240,7 +240,7 @@ class GradScaler(object):
             for grads in per_dtype_grads.values():
                 torch._amp_foreach_non_finite_check_and_unscale_(grads, per_device_found_inf.get(device),
                                                                  per_device_inv_scale.get(device))
-        # For tensor parallel paramters it should be all-reduced over tensor parallel process group
+        # For tensor parallel parameters it should be all-reduced over tensor parallel process group
         if gpc.is_initialized(ParallelMode.MODEL) and gpc.get_world_size(ParallelMode.MODEL) > 1:
             vals = [val for val in per_device_found_inf._per_device_tensors.values()]
             coalesced = _flatten_dense_tensors(vals)
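The hunk above touches ColossalAI's tensor parallel aware GradScaler: after unscaling, the per-device found-inf flags are flattened and all-reduced over the model parallel group so every rank reaches the same skip-or-step decision. A minimal sketch of that synchronization pattern, assuming an already initialized process group handle tp_group; this is illustrative, not the ColossalAI implementation, and the choice of MAX as the reduce op is an assumption:

import torch
import torch.distributed as dist
from torch._utils import _flatten_dense_tensors, _unflatten_dense_tensors

def sync_found_inf(found_inf_tensors, tp_group):
    # found_inf_tensors: one single-element float tensor per device; a
    # non-zero value means unscaling hit an inf/NaN gradient on that rank.
    coalesced = _flatten_dense_tensors(found_inf_tensors)
    # MAX keeps the flag set if any rank in the tensor parallel group saw an
    # overflow, so all ranks skip (or take) the optimizer step together.
    dist.all_reduce(coalesced, op=dist.ReduceOp.MAX, group=tp_group)
    for buf, synced in zip(found_inf_tensors,
                           _unflatten_dense_tensors(coalesced, found_inf_tensors)):
        buf.copy_(synced)
    return found_inf_tensors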
@@ -325,7 +325,7 @@ def matmul_meta_info(*args, **kwargs) -> Tuple[TrainCycleItem, TrainCycleItem, L
     else:
         _is_batch_dims_same = False

-    # retireve dimensions
+    # retrieve dimensions
     input_dim_00 = input_tensors[0].shape[-2]
     input_dim_01 = input_tensors[0].shape[-1]
     input_dim_10 = input_tensors[1].shape[-2]
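The dimensions retrieved above (input_dim_00 and friends are the last two axes of each operand) feed the meta profiler's cost estimate for matmul. As a hedged illustration of the kind of accounting involved, here is a standard forward FLOP count for a batched matmul; it is not the exact ColossalAI formula:

import torch

def matmul_forward_flops(a: torch.Tensor, b: torch.Tensor) -> int:
    # Last two dims of each operand, mirroring input_dim_00/01/10 above.
    m, k = a.shape[-2], a.shape[-1]
    k2, n = b.shape[-2], b.shape[-1]
    assert k == k2, "inner dimensions must match"
    batch = 1
    for d in torch.broadcast_shapes(a.shape[:-2], b.shape[:-2]):
        batch *= d
    # One multiply and one add per element of each inner product.
    return 2 * batch * m * k * n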
@@ -219,7 +219,7 @@ def _comm_spec_apply(gm: torch.fx.GraphModule):
     return gm


-def _act_annotataion_pass(gm: torch.fx.GraphModule):
+def _act_annotation_pass(gm: torch.fx.GraphModule):
     """
     This pass is used to add the act annotation to the new inserted nodes.
     """
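The renamed _act_annotation_pass propagates activation annotations onto nodes inserted by earlier passes. A minimal sketch of what such a torch.fx annotation pass can look like; the meta key and the inherit-from-first-input rule here are placeholders, not ColossalAI's actual logic:

import torch.fx

def annotate_inserted_nodes(gm: torch.fx.GraphModule, key: str = "activation_checkpoint"):
    for node in gm.graph.nodes:
        if key in node.meta:
            continue
        # Inherit the annotation from the first input node that already carries it.
        for arg in node.all_input_nodes:
            if key in arg.meta:
                node.meta[key] = arg.meta[key]
                break
    return gm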
@@ -54,7 +54,7 @@ def size_processing(size: Union[int, torch.Size],
     return size


-def solution_annotatation_pass(gm: torch.fx.GraphModule, solution: List[int],
+def solution_annotation_pass(gm: torch.fx.GraphModule, solution: List[int],
                               strategies_constructor: StrategiesConstructor):
     """
     This method is used to stick the solution strategy to the nodes and add the information
@@ -496,7 +496,7 @@ def runtime_preparation_pass(gm: torch.fx.GraphModule,
                              device_mesh: DeviceMesh,
                              strategies_constructor: StrategiesConstructor,
                              overlap=False):
-    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotatation_pass(
+    gm, sharding_spec_convert_dict, origin_node_sharding_spec_dict, comm_actions_dict = solution_annotation_pass(
         gm, solution, strategies_constructor)
     gm = size_value_converting_pass(gm, device_mesh)
     gm = node_args_converting_pass(gm, device_mesh)
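runtime_preparation_pass threads the traced GraphModule through a sequence of passes, each consuming and returning the (possibly rewritten) module. A generic sketch of that pipeline shape, assuming every pass has a gm -> gm signature; the concrete passes in the hunk are ColossalAI's, while run_passes itself is illustrative:

from typing import Callable, Iterable
import torch.fx

def run_passes(gm: torch.fx.GraphModule,
               passes: Iterable[Callable[[torch.fx.GraphModule], torch.fx.GraphModule]]) -> torch.fx.GraphModule:
    for p in passes:
        gm = p(gm)
    # Regenerate the module's forward() once the graph rewrites are done.
    gm.recompile()
    return gm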
@@ -64,7 +64,7 @@ class TraceFlow(object):
                 return False
         return True

-    def _assgin_single_node_flow(
+    def _assign_single_node_flow(
         self,
         arg_node: Node,
         start_idx: int,
@@ -177,7 +177,7 @@ class TraceFlow(object):
             if get_node_shape(arg) is None:
                 continue
             arg_list.append(arg)
-            flow_flag = self._assgin_single_node_flow(
+            flow_flag = self._assign_single_node_flow(
                 arg,
                 start_idx,
                 end_idx,
@@ -315,7 +315,7 @@ class TraceFlow(object):
         chunk_info["args"]["prepose_nodes"] = prepose_nodes

     def _get_non_chunk_inputs(self, chunk_info, start_idx, end_idx):
-        # we need to log input nodes to avoid deleteing them in the loop
+        # we need to log input nodes to avoid deleting them in the loop
         chunk_node_list = self.node_mgr.get_node_slice_by_idx(start_idx, end_idx + 1)
         # also need to get some prepose node's arg out of non_chunk_inputs
         for n in chunk_info["args"]["prepose_nodes"]:
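The corrected comment records a common Python pitfall: collect the entries to remove before the loop that mutates the collection, instead of deleting while iterating over it. A tiny self-contained illustration with placeholder names:

inputs_to_keep = {"x", "mask"}                 # placeholder node names
candidates = ["x", "tmp0", "mask", "tmp1"]

# Snapshot first, then mutate: removing from `candidates` while iterating
# over it directly would silently skip elements.
to_delete = [n for n in candidates if n not in inputs_to_keep]
for n in to_delete:
    candidates.remove(n)

assert candidates == ["x", "mask"]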
@@ -461,7 +461,7 @@ class TraceIndice(object):
                 nodes_in.append(node_in)
                 self._inherit_more_indice_from_node_with_exclude(node_in, node)

-    def _assgin_no_change_indice(self, node, idx):
+    def _assign_no_change_indice(self, node, idx):
         self._assign_indice_as_input(node, idx)
         for node_in in node.args:
             if type(node_in) == type(node):
@@ -792,7 +792,7 @@ class TraceIndice(object):
                 self._add_dim(node_idx, i)
         dim_from.reverse()

-        # inheirt indice from current node
+        # inherit indice from current node
         if len(dim_from) != 0 and len(dim_to) != 0:
             if dim_diff == 1:
                 if origin_shape[dim_from[0]] == 1:
@@ -852,7 +852,7 @@ class TraceIndice(object):
         elif "split" == node_name:
             self._assign_split_indice(node, idx)
         elif any(i == node_name for i in ["to", "contiguous", "clone", "type", "float"]):
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif "new_ones" == node_name:
             self._assign_all_indice(node, idx)
         elif "flatten" == node_name:
@@ -914,7 +914,7 @@ class TraceIndice(object):
         elif "conv2d" == node_name:
             self._assign_conv2d_indice(node, idx)
         elif "identity" == node_name:
-            self._assgin_no_change_indice(node, idx)
+            self._assign_no_change_indice(node, idx)
         elif any(n == node_name for n in ["sigmoid", "dropout", "relu", "silu", "gelu"]):
             self._assign_elementwise_indice(node, idx)
         else:
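The last two hunks sit in a long elif chain that dispatches on a traced node's operation name. A compact, hedged sketch of the same dispatch expressed with set membership; the handler names mirror the methods in the hunks, but the grouping and the assign_indice helper are illustrative, not ColossalAI's code:

NO_CHANGE_OPS = {"to", "contiguous", "clone", "type", "float", "identity"}
ELEMENTWISE_OPS = {"sigmoid", "dropout", "relu", "silu", "gelu"}

def assign_indice(tracer, node, node_name, idx):
    # Route a traced node to its indice-assignment handler by operation name.
    if node_name in NO_CHANGE_OPS:
        tracer._assign_no_change_indice(node, idx)
    elif node_name in ELEMENTWISE_OPS:
        tracer._assign_elementwise_indice(node, idx)
    elif node_name == "split":
        tracer._assign_split_indice(node, idx)
    elif node_name == "conv2d":
        tracer._assign_conv2d_indice(node, idx)
    else:
        raise NotImplementedError(f"no indice handler for {node_name!r}")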