[cli] support run as module option (#6135)

pull/6104/merge
Hongxin Liu 2 weeks ago committed by GitHub
parent cc40fe0e6f
commit 5a03d2696d
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -64,7 +64,8 @@ from .run import launch_multi_processes
"This will be converted to --arg1=1 --arg2=2 during execution", "This will be converted to --arg1=1 --arg2=2 during execution",
) )
@click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection") @click.option("--ssh-port", type=int, default=None, help="(optional) the port used for ssh connection")
@click.argument("user_script", type=str) @click.option("-m", type=str, default=None, help="run library module as a script (terminates option list)")
@click.argument("user_script", type=str, required=False, default=None)
@click.argument("user_args", nargs=-1) @click.argument("user_args", nargs=-1)
def run( def run(
host: str, host: str,
@ -77,8 +78,9 @@ def run(
master_port: int, master_port: int,
extra_launch_args: str, extra_launch_args: str,
ssh_port: int, ssh_port: int,
m: str,
user_script: str, user_script: str,
user_args: str, user_args: tuple,
) -> None: ) -> None:
""" """
To launch multiple processes on a single node or multiple nodes via command line. To launch multiple processes on a single node or multiple nodes via command line.
@ -102,9 +104,24 @@ def run(
# run with hostfile excluding the hosts selected # run with hostfile excluding the hosts selected
colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py colossalai run --hostfile <file_path> --master_addr host1 --exclude host2 --nprocs_per_node 4 train.py
""" """
if m is not None:
if m.endswith(".py"):
click.echo(f"Error: invalid Python module {m}. Did you use a wrong option? Try colossalai run --help")
exit()
if user_script is not None:
user_args = (user_script,) + user_args
user_script = m
m = True
else:
if user_script is None:
click.echo("Error: missing script argument. Did you use a wrong option? Try colossalai run --help")
exit()
if not user_script.endswith(".py"): if not user_script.endswith(".py"):
click.echo(f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help") click.echo(
f"Error: invalid Python file {user_script}. Did you use a wrong option? Try colossalai run --help"
)
exit() exit()
m = False
args_dict = locals() args_dict = locals()
args = Config(args_dict) args = Config(args_dict)

@ -113,6 +113,7 @@ def get_launch_command(
user_args: List[str], user_args: List[str],
node_rank: int, node_rank: int,
num_nodes: int, num_nodes: int,
run_as_module: bool,
extra_launch_args: str = None, extra_launch_args: str = None,
) -> str: ) -> str:
""" """
@ -155,6 +156,8 @@ def get_launch_command(
torch_version = version.parse(torch.__version__) torch_version = version.parse(torch.__version__)
assert torch_version.major >= 1 assert torch_version.major >= 1
if torch_version.major < 2 and run_as_module:
raise ValueError("Torch version < 2.0 does not support running as module")
if torch_version.major == 1 and torch_version.minor < 9: if torch_version.major == 1 and torch_version.minor < 9:
# torch distributed launch cmd with torch < 1.9 # torch distributed launch cmd with torch < 1.9
@ -198,7 +201,10 @@ def get_launch_command(
] ]
cmd += _arg_dict_to_list(default_torchrun_rdzv_args) cmd += _arg_dict_to_list(default_torchrun_rdzv_args)
cmd += _arg_dict_to_list(extra_launch_args) + [user_script] + user_args cmd += _arg_dict_to_list(extra_launch_args)
if run_as_module:
cmd.append("-m")
cmd += [user_script] + user_args
cmd = " ".join(cmd) cmd = " ".join(cmd)
return cmd return cmd
@ -294,6 +300,7 @@ def launch_multi_processes(args: Config) -> None:
user_args=args.user_args, user_args=args.user_args,
node_rank=node_id, node_rank=node_id,
num_nodes=len(active_device_pool), num_nodes=len(active_device_pool),
run_as_module=args.m,
extra_launch_args=args.extra_launch_args, extra_launch_args=args.extra_launch_args,
) )
runner.send(hostinfo=hostinfo, cmd=cmd) runner.send(hostinfo=hostinfo, cmd=cmd)

Loading…
Cancel
Save