mirror of https://github.com/hpcaitech/ColossalAI
[nfc] fix typo colossalai/cli fx kernel (#3847)
* fix typo colossalai/autochunk auto_parallel amp
* fix typo colossalai/auto_parallel nn utils etc.
* fix typo colossalai/auto_parallel autochunk fx/passes etc.
* fix typo docs/
* change placememt_policy to placement_policy in docs/ and examples/
* fix typo colossalai/ applications/
* fix typo colossalai/cli fx kernel
parent 281b33f362
commit 70c8cdecf4
@@ -28,7 +28,7 @@ from .run import launch_multi_processes
 type=str,
 default=None,
 help=
-"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --includ,"
+"Specify computing devices to NOT use during execution. Mutually exclusive with --include. Formatting is the same as --include,"
 " only effective when used with --hostfile.")
 @click.option("--num_nodes",
 type=int,
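The corrected help string belongs to what is presumably the launcher's --exclude option. As a minimal sketch of how such a pair of mutually exclusive click options can be declared (the command body, defaults, and the --include help text are illustrative assumptions, not the project's exact code):

import click

@click.command()
@click.option("--include", type=str, default=None,
              help="Specify computing devices to use during execution. Mutually exclusive with --exclude.")
@click.option("--exclude", type=str, default=None,
              help="Specify computing devices to NOT use during execution. Mutually exclusive with --include. "
                   "Formatting is the same as --include, only effective when used with --hostfile.")
@click.option("--num_nodes", type=int, default=-1,
              help="Number of nodes to launch on.")
def run(include, exclude, num_nodes):
    # click has no built-in mutual-exclusion constraint, so it is enforced by hand
    if include is not None and exclude is not None:
        raise click.UsageError("--include and --exclude are mutually exclusive.")
    click.echo(f"include={include}, exclude={exclude}, num_nodes={num_nodes}")

if __name__ == "__main__":
    run()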
@@ -38,7 +38,7 @@ class HostInfo:
 
 # socket.getfqdn("127.0.0.1") does not return localhost
 # on some users' machines
-# thus, we directly return True if hostname is locahost, 127.0.0.1 or 0.0.0.0
+# thus, we directly return True if hostname is localhost, 127.0.0.1 or 0.0.0.0
 if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
 return True
 
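The comment being fixed explains why the check short-circuits on well-known local addresses instead of trusting socket.getfqdn. A standalone sketch of that idea (the function name and the fallback comparison are assumptions, not the project's exact implementation):

import socket

def is_local_host(hostname: str) -> bool:
    # socket.getfqdn("127.0.0.1") does not return "localhost" on some machines,
    # so well-known local addresses are matched directly first
    if hostname in ("localhost", "127.0.0.1", "0.0.0.0"):
        return True
    # fall back to comparing against this machine's own host name / FQDN
    return hostname in (socket.gethostname(), socket.getfqdn())

print(is_local_host("127.0.0.1"))    # True
print(is_local_host("example.com"))  # almost certainly False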
@@ -114,7 +114,7 @@ class MultiNodeRunner:
 Receive messages from all hosts
 
 Returns:
-msg_from_node (dict): a dictionry which contains messages from each node
+msg_from_node (dict): a dictionary which contains messages from each node
 """
 
 msg_from_node = dict()
@@ -298,7 +298,7 @@ def launch_multi_processes(args: Config) -> None:
 # receive the stop status
 msg_from_node = runner.recv_from_all()
 
-# printe node status
+# print node status
 click.echo("\n====== Stopping All Nodes =====")
 for hostname, msg in msg_from_node.items():
 click.echo(f"{hostname}: {msg}")
@@ -197,7 +197,7 @@ class AlphaBetaProfiler:
 dist.broadcast_object_list(broadcast_list, src=process_group[0])
 alpha_beta_dict[process_group] = tuple(broadcast_list)
 
-# add symmetry pair to the apha_beta_dict
+# add symmetry pair to the alpha_beta_dict
 symmetry_ab_dict = {}
 for process_group, alpha_beta_pair in alpha_beta_dict.items():
 symmetry_process_group = (process_group[1], process_group[0])
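The corrected comment describes mirroring each measured (alpha, beta) pair onto the reversed rank pair, since latency and bandwidth between two ranks are symmetric. The rest of the loop body is elided in the diff, so the following is only a toy sketch of the dictionary-completion idea with made-up numbers:

# measured (latency, bandwidth) results keyed by an ordered pair of ranks
alpha_beta_dict = {(0, 1): (1.0e-4, 12.5e9), (0, 2): (1.2e-4, 10.0e9)}

# add symmetry pair to the alpha_beta_dict: (i, j) and (j, i) share the same values
symmetry_ab_dict = {}
for process_group, alpha_beta_pair in alpha_beta_dict.items():
    symmetry_process_group = (process_group[1], process_group[0])
    symmetry_ab_dict[symmetry_process_group] = alpha_beta_pair
alpha_beta_dict.update(symmetry_ab_dict)

print(alpha_beta_dict[(1, 0)])  # (0.0001, 12500000000.0)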
@@ -51,7 +51,7 @@ class BiasAdditionModule(ABC):
 
 For example:
 The kwargs for conv2d module is {} because the attributes like 'padding' or 'groups' are
-considered during module initilizing. However, we need to consider those attributes as kwargs
+considered during module initializing. However, we need to consider those attributes as kwargs
 in F.conv2d.
 """
 pass
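The docstring being fixed points out that attributes such as padding or groups are fixed when an nn.Conv2d module is constructed, but have to be passed explicitly as kwargs once the same operation is expressed through F.conv2d, which is what the bias-addition pass needs. A small self-contained illustration of that difference (not the project's pass itself):

import torch
import torch.nn.functional as F

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1, groups=1, bias=True)
x = torch.randn(1, 3, 16, 16)

# attributes chosen at module init time must be re-supplied as kwargs for F.conv2d;
# the bias is dropped here so it can be added as a separate node afterwards
out = F.conv2d(x, conv.weight, bias=None, stride=conv.stride,
               padding=conv.padding, dilation=conv.dilation, groups=conv.groups)
out = out + conv.bias.reshape(1, -1, 1, 1)

assert torch.allclose(out, conv(x), atol=1e-5)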
@@ -295,7 +295,7 @@ class ColoTracer(Tracer):
 
 @staticmethod
 def forward(ctx, run_function, preserve_rng_state, *args):
-# signal that the current tracing occurs within activaton checkpoint part
+# signal that the current tracing occurs within activation checkpoint part
 self.inside_torch_checkpoint_func = True
 out = run_function(*args)
 self.inside_torch_checkpoint_func = False
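The fixed comment sits inside the tracer's replacement for torch.utils.checkpoint: a flag marks everything created while the checkpointed function runs. A toy sketch of the same flag pattern around plain torch.utils.checkpoint.checkpoint, with the tracer internals simplified away (class and attribute layout here are assumptions):

import torch
from torch.utils.checkpoint import checkpoint

class ToyTracer:
    def __init__(self):
        self.inside_torch_checkpoint_func = False

    def run_checkpointed(self, run_function, *args):
        # signal that everything executed here belongs to an activation-checkpoint region
        self.inside_torch_checkpoint_func = True
        out = checkpoint(run_function, *args, use_reentrant=False)
        self.inside_torch_checkpoint_func = False
        return out

tracer = ToyTracer()
block = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.ReLU())
x = torch.randn(2, 4, requires_grad=True)
y = tracer.run_checkpointed(block, x)
y.sum().backward()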
@@ -92,7 +92,7 @@ class ColoTracer(Tracer):
 return proxy
 
 # if graph is traced for auto parallelism module, some extra node will be added during
-# graph construction to deal with the compatability between bias addition and all reduce.
+# graph construction to deal with the compatibility between bias addition and all reduce.
 
 # if no extra manipulation is applied, we just pass the origin arguments to create_proxy function
 # to create node on computation graph
@@ -208,7 +208,7 @@ class ColoTracer(Tracer):
 self.proxy_cls = ColoProxy
 self.tracer_type = TracerType.META
 else:
-raise ValueError(f"Unrecognised tracer type {tracer_type}")
+raise ValueError(f"Unrecognized tracer type {tracer_type}")
 
 def _meta_data_computing(self, kind, target, args, kwargs):
 
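The corrected error message is the fall-through branch of a tracer-type dispatch. A generic sketch of that enum-plus-ValueError pattern; only the META member is visible in the diff, the REAL member and the returned strings are illustrative assumptions:

from enum import Enum, auto

class TracerType(Enum):
    REAL = auto()
    META = auto()

def configure_tracer(tracer_type: TracerType) -> str:
    if tracer_type == TracerType.REAL:
        return "trace with concrete tensors"
    elif tracer_type == TracerType.META:
        return "trace with meta tensors (shapes and dtypes only)"
    else:
        # reject anything that is not a known member, mirroring the fixed message
        raise ValueError(f"Unrecognized tracer type {tracer_type}")

print(configure_tracer(TracerType.META))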
|
@@ -445,7 +445,7 @@ class ColoTracer(Tracer):
 
 @staticmethod
 def forward(ctx, run_function, preserve_rng_state, *args):
-# signal that the current tracing occurs within activaton checkpoint part
+# signal that the current tracing occurs within activation checkpoint part
 self.inside_torch_checkpoint_func = True
 out = run_function(*args)
 self.inside_torch_checkpoint_func = False
@@ -138,7 +138,7 @@ if HAS_MEM_EFF_ATTN:
 elif attn_mask_type == AttnMaskType.causal: # gpt style
 attn_bias = LowerTriangularMask()
 
-if bias is not None: # alibi / relative position emebedding
+if bias is not None: # alibi / relative position embedding
 assert allow_alibi, "flash attention with bias is not supported in this system."
 assert attn_mask_type == AttnMaskType.causal, \
 "attention with bias is only supported for causal attention so far."
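The fixed comment distinguishes the causal (GPT-style) mask from an additive bias such as ALiBi or relative position embeddings. A minimal sketch of the causal path using xformers' memory-efficient attention; this assumes the xformers package and a CUDA device, and the shapes are arbitrary:

import torch
from xformers.ops import memory_efficient_attention, LowerTriangularMask

# (batch, sequence, heads, head_dim) layout expected by xformers
q = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 128, 8, 64, device="cuda", dtype=torch.float16)

# GPT-style causal attention: each position attends only to itself and earlier positions
out = memory_efficient_attention(q, k, v, attn_bias=LowerTriangularMask())
print(out.shape)  # torch.Size([2, 128, 8, 64])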
@@ -43,7 +43,7 @@ class Config:
 attn_prob_dropout_ratio: float # attention score dropout ratio
 hidden_dropout_ratio: float # dropout ration before residual
 norm_first: bool # norm_first
-fp16: bool # fp16 presion
+fp16: bool # fp16 precision
 
 
 class MultiHeadAttention1DFunc(Function):
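The Config being touched is a field-per-hyperparameter container for the fused multi-head attention kernel. A sketch of how such a config is typically declared as a dataclass; the field list follows the diff, while the class name, decorator, and example values are assumptions:

from dataclasses import dataclass

@dataclass
class AttentionKernelConfig:  # hypothetical name; the diff only shows "Config"
    attn_prob_dropout_ratio: float  # attention score dropout ratio
    hidden_dropout_ratio: float     # dropout ratio applied before the residual connection
    norm_first: bool                # use pre-LayerNorm ordering if True
    fp16: bool                      # run the fused kernel in fp16 precision

cfg = AttentionKernelConfig(attn_prob_dropout_ratio=0.1,
                            hidden_dropout_ratio=0.1,
                            norm_first=False,
                            fp16=True)
print(cfg)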
@@ -43,7 +43,7 @@ def warmup_jit_fusion(batch_size: int,
 seq_length: int = 512,
 vocab_size: int = 32768,
 dtype: torch.dtype = torch.float32):
-""" Compilie JIT functions before the main training steps """
+""" Compile JIT functions before the main training steps """
 
 embed = Embedding(vocab_size, hidden_size).to(get_current_device())
 linear_1 = Linear(hidden_size, hidden_size * 4, skip_bias_add=True).to(get_current_device())
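The fixed docstring belongs to a helper that compiles the fused JIT kernels before training starts, so the first real iteration does not pay the compilation cost. A small self-contained sketch of the same warm-up idea with a TorchScript-fused bias-GELU; the function body is illustrative, not the project's kernel:

import torch

@torch.jit.script
def bias_gelu(bias: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    x = bias + y
    return x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

def warmup_jit_fusion(batch_size: int = 8, seq_length: int = 512, hidden_size: int = 1024):
    """Compile JIT functions before the main training steps."""
    bias = torch.randn(4 * hidden_size, requires_grad=True)
    y = torch.randn(batch_size, seq_length, 4 * hidden_size, requires_grad=True)
    # a few forward/backward passes with representative shapes trigger fusion and compilation
    for _ in range(3):
        out = bias_gelu(bias, y)
        out.sum().backward()

warmup_jit_fusion()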