2022-12-30 17:02:14 +00:00
|
|
|
from typing import Dict, List, Tuple
|
|
|
|
|
|
|
|
import torch
|
|
|
|
import torch.distributed as dist
|
|
|
|
import torch.nn as nn
|
|
|
|
from torch.fx.graph import Graph
|
|
|
|
|
2023-04-04 09:40:45 +00:00
|
|
|
from colossalai._analyzer.fx.codegen import ActivationCheckpointCodeGen
|
|
|
|
from colossalai._analyzer.fx.graph_module import ColoGraphModule
|
|
|
|
from colossalai._analyzer.fx.passes import shape_prop_pass
|
|
|
|
from colossalai._analyzer.fx.tracer.tracer import ColoTracer
|
2022-12-30 17:02:14 +00:00
|
|
|
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
|
|
|
|
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
|
2023-02-15 05:48:28 +00:00
|
|
|
from colossalai.auto_parallel.tensor_shard.options import DataloaderOption, ShardOption, SolverOptions, SolverPerference
|
2022-12-30 17:02:14 +00:00
|
|
|
from colossalai.auto_parallel.tensor_shard.sharding_strategy import CommAction
|
2023-09-19 06:20:26 +00:00
|
|
|
from colossalai.auto_parallel.tensor_shard.solver import CostGraph, Solver, StrategiesConstructor
|
2023-01-05 08:39:55 +00:00
|
|
|
from colossalai.device.alpha_beta_profiler import AlphaBetaProfiler
|
2022-12-30 17:02:14 +00:00
|
|
|
from colossalai.device.device_mesh import DeviceMesh
|
|
|
|
from colossalai.tensor.sharding_spec import ShardingSpec
|
|
|
|
|
|
|
|
|
|
|
|
class ModuleWrapper(nn.Module):
    """
    ``nn.Module`` adapter around a sharded graph module.

    The wrapped graph module is called with three extra keyword arguments
    (``sharding_spec_convert_dict``, ``origin_node_sharding_spec_dict`` and
    ``comm_actions_dict``). This wrapper stores those dictionaries once and injects
    them on every forward call, so the result can be used like a normal model.
    """

    def __init__(
        self,
        module: ColoGraphModule,
        sharding_spec_dict: Dict[int, List[ShardingSpec]],
        origin_spec_dict: Dict[int, ShardingSpec],
        comm_actions_dict: Dict[int, Dict[str, CommAction]],
    ):
        """
        Args:
            module: the graph module to wrap.
            sharding_spec_dict: target sharding specs of each tensor, as required by its user nodes.
            origin_spec_dict: original sharding spec of each tensor.
            comm_actions_dict: communication actions to apply to each tensor.
        """
        super().__init__()
        self.module = module
        self.sharding_spec_dict = sharding_spec_dict
        self.origin_spec_dict = origin_spec_dict
        self.comm_actions_dict = comm_actions_dict

    def forward(self, *args, **kwargs):
        """Call the wrapped module with the stored dictionaries added as keyword arguments."""
        injected = dict(
            sharding_spec_convert_dict=self.sharding_spec_dict,
            origin_node_sharding_spec_dict=self.origin_spec_dict,
            comm_actions_dict=self.comm_actions_dict,
        )
        # Injected names come before caller kwargs; a duplicate key still raises TypeError.
        return self.module(*args, **injected, **kwargs)
|
2022-12-30 17:02:14 +00:00
|
|
|
|
|
|
|
|
|
|
|
def extract_meta_args_from_dataloader(data_loader: torch.utils.data.DataLoader, data_process_func: callable):
    """
    Extract the meta_args from the dataloader under the instruction of the data_process_func.

    This is not implemented yet. Rather than silently returning ``None`` (which would only
    surface later as an obscure failure when ``None`` is used as ``meta_args``), it fails
    immediately with a clear message.

    Args:
        data_loader: the dataloader that would provide sample inputs.
        data_process_func: callable that would turn a batch into model inputs.

    Raises:
        NotImplementedError: always, until extraction from a dataloader is implemented.
    """
    # TODO: implement this function
    raise NotImplementedError(
        "extract_meta_args_from_dataloader is not implemented yet; please pass `meta_args` explicitly."
    )
|
|
|
|
|
|
|
|
|
|
|
|
def extract_alpha_beta_for_device_mesh(alpha_beta_dict: Dict[Tuple[int], Tuple[float]], logical_mesh_shape: Tuple[int]):
    """
    Extract the mesh_alpha and mesh_beta for the given logical mesh from the alpha_beta_dict.
    These two values are used to estimate the communication cost.

    This is not implemented yet. The caller tuple-unpacks the result, so silently returning
    ``None`` would fail with an unhelpful ``TypeError``. Instead it fails immediately with a
    clear message.

    Args:
        alpha_beta_dict: profiled alpha/beta values keyed by device groups.
        logical_mesh_shape: description of the chosen logical mesh.

    Raises:
        NotImplementedError: always, until extraction from a user-provided mesh is implemented.
    """
    # TODO: implement this function
    raise NotImplementedError(
        "extract_alpha_beta_for_device_mesh is not implemented yet; leave both `logical_mesh_shape` and "
        "`logical_mesh_id` as None so the logical mesh is searched and its alpha/beta extracted automatically."
    )
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def build_strategy_constructor(
    graph: Graph, device_mesh: DeviceMesh, solver_preference: str, dataloader_option: str, shard_option: str
):
    """
    Create a StrategiesConstructor for ``graph`` and build its strategies and costs.

    The three option strings are translated into their enum counterparts and bundled into
    a SolverOptions. After this call every node in the graph carries a strategies vector
    produced by its node handler.

    Args:
        graph: the traced graph to analyse.
        device_mesh: the device mesh the strategies target.
        solver_preference: one of 'standard', 'tp', 'dp'.
        dataloader_option: one of 'replicated', 'distributed'.
        shard_option: one of 'standard', 'shard', 'shard_last_axis', 'full_shard'.

    Returns:
        The StrategiesConstructor after ``build_strategies_and_cost()`` has run.

    Raises:
        ValueError: if any option string is not one of the accepted values.
    """

    def _match(option, candidates):
        # Linear scan comparing with ``==`` so matching behaves like a plain string comparison.
        for key, member in candidates:
            if option == key:
                return member
        return None

    preference_member = _match(
        solver_preference,
        (
            ("standard", SolverPerference.STANDARD),
            ("tp", SolverPerference.TP),
            ("dp", SolverPerference.DP),
        ),
    )
    if preference_member is None:
        raise ValueError(f"Invalid solver_preference: {solver_preference}")

    loader_member = _match(
        dataloader_option,
        (
            ("replicated", DataloaderOption.REPLICATED),
            ("distributed", DataloaderOption.DISTRIBUTED),
        ),
    )
    if loader_member is None:
        raise ValueError(f"Invalid dataloader_option: {dataloader_option}")

    shard_member = _match(
        shard_option,
        (
            ("standard", ShardOption.STANDARD),
            ("shard", ShardOption.SHARD),
            ("shard_last_axis", ShardOption.SHARD_LAST_AXIS),
            ("full_shard", ShardOption.FULL_SHARD),
        ),
    )
    if shard_member is None:
        raise ValueError(f"Invalid shard_option: {shard_option}")

    options = SolverOptions(
        solver_perference=preference_member, dataloader_option=loader_member, shard_option=shard_member
    )
    constructor = StrategiesConstructor(graph, device_mesh, options)
    constructor.build_strategies_and_cost()
    return constructor
|
|
|
|
|
|
|
|
|
2023-01-16 08:25:13 +00:00
|
|
|
def solve_solution(gm: ColoGraphModule, strategy_constructor: StrategiesConstructor, memory_budget: float = -1.0):
    """
    Solve for the best sharding strategy of every node in ``gm``.

    Args:
        gm: the traced graph module.
        strategy_constructor: constructor whose leaf strategies define the search space.
        memory_budget: memory limit forwarded to the Solver (-1.0 by default).

    Returns:
        A list of integers; each entry is the index of the chosen strategy for the
        corresponding node.
    """
    # For now every node is treated as live: backward memory cost is counted together with
    # forward memory cost in each node's memory cost, and no activation checkpointing is used
    # in this phase, so no liveness analysis is run here.
    costs = CostGraph(strategy_constructor.leaf_strategies)
    costs.simplify_graph()
    solver = Solver(gm.graph, strategy_constructor, costs, memory_budget=memory_budget)
    return list(solver.call_solver_serialized_args()[0])
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def transform_to_sharded_model(
    gm: ColoGraphModule,
    meta_args: Dict,
    solution: List[int],
    device_mesh: DeviceMesh,
    strategies_constructor: StrategiesConstructor,
    overlap: bool = False,
):
    """
    Turn the original graph module into its sharded counterpart.

    ``runtime_preparation_pass`` shards the parameters according to ``solution`` and adds the
    gradient hooks; ``runtime_apply_pass`` then inserts the communication nodes. Shapes are
    re-propagated and the module is recompiled.

    Args:
        gm: the traced graph module.
        meta_args: example inputs; their values are fed positionally to shape propagation.
        solution: chosen strategy index per node.
        device_mesh: the target device mesh.
        strategies_constructor: the constructor the solution was computed from.
        overlap: forwarded to ``runtime_preparation_pass``.

    Returns:
        ``(gm, (sharding_spec_dict, origin_spec_dict, comm_actions_dict))``.
    """
    prepared, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
        gm, solution, device_mesh, strategies_constructor, overlap=overlap
    )
    sharded = runtime_apply_pass(prepared)

    spec_dicts = (sharding_spec_dict, origin_spec_dict, comm_actions_dict)
    # The three dicts are passed after the real inputs, mirroring how ModuleWrapper
    # supplies them to the sharded module at runtime.
    shape_prop_pass(sharded, *meta_args.values(), *spec_dicts)
    sharded.recompile()
    return sharded, spec_dicts
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def initialize_device_mesh(
    world_size: int = -1,
    physical_devices: List[int] = None,
    alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
    logical_mesh_shape: Tuple[int] = None,
    logical_mesh_id: torch.Tensor = None,
):
    """
    This method is used to initialize the device mesh.

    Args:
        world_size: the size of device mesh. If the world_size is -1,
            the world size will be taken from torch.distributed.
        physical_devices: the physical devices used to initialize the device mesh.
            Defaults to ``range(world_size)``.
        alpha_beta_dict(optional): the alpha_beta_dict contains the alpha and beta values
            for each devices. If the alpha_beta_dict is None, the alpha_beta_dict will be
            obtained from an AlphaBetaProfiler.
        logical_mesh_shape(optional): the logical_mesh_shape is used to specify the logical
            mesh shape; the physical device ids are reshaped into it.
        logical_mesh_id(optional): the logical_mesh_id is used to specify the logical mesh id.
            If both logical_mesh_shape and logical_mesh_id are None, the best logical mesh is
            searched by the profiler.

    Returns:
        The initialized DeviceMesh.

    Raises:
        ValueError: if both logical_mesh_shape and logical_mesh_id are given and disagree.
    """
    # if world_size is not set, use the world size from torch.distributed
    if world_size == -1:
        world_size = dist.get_world_size()

    if physical_devices is None:
        physical_devices = list(range(world_size))
    physical_mesh = torch.tensor(physical_devices)

    if alpha_beta_dict is None:
        # if alpha_beta_dict is not given, let the profiler produce alpha and beta values for each device
        ab_profiler = AlphaBetaProfiler(physical_devices)
        alpha_beta_dict = ab_profiler.alpha_beta_dict
    else:
        ab_profiler = AlphaBetaProfiler(physical_devices, alpha_beta_dict=alpha_beta_dict)

    if logical_mesh_shape is None and logical_mesh_id is None:
        # search for the best logical mesh and extract alpha/beta values for it
        logical_mesh_id = ab_profiler.search_best_logical_mesh()
        logical_mesh_id = torch.Tensor(logical_mesh_id).to(torch.int)
        mesh_alpha, mesh_beta = ab_profiler.extract_alpha_beta_for_device_mesh()
    else:
        if logical_mesh_id is None:
            # only the shape was given: reshape the physical device ids into it
            logical_mesh_id = physical_mesh.reshape(logical_mesh_shape)
        else:
            # an explicit mesh id was given; previously this path left mesh_alpha/mesh_beta
            # unbound and crashed with UnboundLocalError at DeviceMesh construction.
            logical_mesh_id = torch.as_tensor(logical_mesh_id)
            if logical_mesh_shape is not None:
                requested_shape = (
                    (logical_mesh_shape,) if isinstance(logical_mesh_shape, int) else tuple(logical_mesh_shape)
                )
                if tuple(logical_mesh_id.shape) != requested_shape:
                    raise ValueError(
                        f"logical_mesh_shape {requested_shape} does not match the shape of "
                        f"logical_mesh_id {tuple(logical_mesh_id.shape)}"
                    )
        # extract alpha and beta values for the chosen logical mesh
        mesh_alpha, mesh_beta = extract_alpha_beta_for_device_mesh(alpha_beta_dict, logical_mesh_id)

    device_mesh = DeviceMesh(
        physical_mesh_id=physical_mesh,
        logical_mesh_id=logical_mesh_id,
        mesh_alpha=mesh_alpha,
        mesh_beta=mesh_beta,
        init_process_group=True,
    )
    return device_mesh
|
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def initialize_model(
    model: nn.Module,
    meta_args: Dict[str, torch.Tensor],
    device_mesh: DeviceMesh,
    memory_budget: float = -1.0,
    overlap: bool = False,
    solver_preference: str = "standard",
    dataloader_option: str = "replicated",
    shard_option: str = "standard",
    save_solver_solution: bool = False,
    load_solver_solution: bool = False,
    solution_path: str = None,
    return_solution: bool = False,
):
    """
    This method is used to initialize the sharded model which could be used as normal pytorch model.

    Args:
        model: the model to be sharded.
        meta_args: the meta_args is used to specify the input shapes of the model.
        device_mesh: the device mesh to execute the model.
        memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
            the memory budget will be infinity.
        overlap(optional): the overlap is used to specify whether to overlap gradient communication and
            backward computing.
        solver_preference(optional): the solver_preference is used to specify which parallelism algorithm
            has higher priority. The valid solver_preference could be 'standard', 'tp', or 'dp'.
        dataloader_option(optional): the dataloader_option is used to specify which kind of data_loader will
            be used. The valid dataloader_option could be 'replicated' or 'distributed'.
        shard_option(optional): the shard_option is used to specify how many axes will be used to shard the
            model. The valid shard_option could be 'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if the save_solver_solution is True, a freshly solved solution
            will be saved to the solution_path.
        load_solver_solution(optional): if the load_solver_solution is True, the solution will be loaded
            from the solution_path instead of being solved.
        solution_path(optional): the path to save or load the solution. Required when saving or loading.
        return_solution(optional): if the return_solution is True, the solution will be returned. The returned
            solution will be used to debug or help to analyze the sharding result. Therefore, we will not just
            return a series of integers, but return the best strategies.

    Returns:
        The wrapped sharded model, or ``(model, solution_strings)`` when return_solution is True,
        where each string is ``"<node name> <strategy name>"``.

    Raises:
        ValueError: if save_solver_solution or load_solver_solution is set without a solution_path.
    """
    # Validate before any work: tracing and solving are expensive, and a missing path would
    # otherwise only fail inside torch.save/torch.load at the very end.
    if (save_solver_solution or load_solver_solution) and solution_path is None:
        raise ValueError("solution_path must be provided when save_solver_solution or load_solver_solution is True")

    tracer = ColoTracer(trace_act_ckpt=True, bias_addition_split=True)
    graph = tracer.trace(root=model, meta_args=meta_args)
    graph.set_codegen(ActivationCheckpointCodeGen())
    gm = ColoGraphModule(model, graph, model.__class__.__name__)
    shape_prop_pass(gm, *meta_args.values())
    gm.recompile()

    strategies_constructor = build_strategy_constructor(
        graph,
        device_mesh,
        solver_preference=solver_preference,
        dataloader_option=dataloader_option,
        shard_option=shard_option,
    )

    if load_solver_solution:
        # NOTE: torch.load unpickles the file; only load solutions from trusted paths.
        solution = torch.load(solution_path)
    else:
        solution = solve_solution(gm, strategies_constructor, memory_budget)
        if save_solver_solution:
            torch.save(solution, solution_path)

    gm, sharding_spec_dicts = transform_to_sharded_model(
        gm, meta_args, solution, device_mesh, strategies_constructor, overlap
    )
    model_to_return = ModuleWrapper(gm, *sharding_spec_dicts)

    if not return_solution:
        return model_to_return

    # Pair each leaf node with the name of its chosen strategy for human inspection.
    nodes = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies]
    solution_to_return = [
        f"{node.name} {node.strategies_vector[solution[index]].name}" for index, node in enumerate(nodes)
    ]
    return model_to_return, solution_to_return
|
2022-12-30 17:02:14 +00:00
|
|
|
|
|
|
|
|
2023-09-19 06:20:26 +00:00
|
|
|
def autoparallelize(
    model: nn.Module,
    meta_args: Dict[str, torch.Tensor] = None,
    data_loader: torch.utils.data.DataLoader = None,
    data_process_func: callable = None,
    alpha_beta_dict: Dict[Tuple[int], Tuple[float]] = None,
    logical_mesh_shape: Tuple[int] = None,
    logical_mesh_id: torch.Tensor = None,
    solver_preference: str = "standard",
    dataloader_option: str = "replicated",
    shard_option: str = "standard",
    save_solver_solution: bool = False,
    load_solver_solution: bool = False,
    solver_solution_path: str = None,
    return_solution: bool = False,
    memory_budget: float = -1.0,
):
    """
    Initialize the device mesh, obtain the meta_args, and build a sharded model from them.

    Args:
        model: the model to be sharded.
        meta_args(optional): input shapes of the model. When None, they are extracted
            from the data_loader via data_process_func.
        data_loader(optional): the data_loader to be used in the normal training loop.
        data_process_func(optional): processes batches from the data_loader into model inputs.
        alpha_beta_dict(optional): alpha and beta values for each device group. When None,
            they are obtained by profiling.
        logical_mesh_shape(optional): the logical mesh shape. When both this and
            logical_mesh_id are None, the best logical mesh is searched automatically.
        logical_mesh_id(optional): the logical mesh id.
        solver_preference(optional): which parallelism algorithm has higher priority:
            'standard', 'tp', or 'dp'.
        dataloader_option(optional): which kind of data_loader will be used:
            'replicated' or 'distributed'.
        shard_option(optional): how many axes will be used to shard the model:
            'standard', 'shard', 'shard_last_axis', or 'full_shard'.
        save_solver_solution(optional): if True, the solution is saved to solver_solution_path.
        load_solver_solution(optional): if True, the solution is loaded from solver_solution_path.
        solver_solution_path(optional): the path to save or load the solution.
        return_solution(optional): if True, the readable solution is returned alongside the model.
        memory_budget(optional): the max cuda memory could be used. If the memory budget is -1.0,
            the memory budget will be infinity.

    Returns:
        The sharded model, or ``(model, solution)`` when return_solution is True.
    """
    mesh = initialize_device_mesh(
        alpha_beta_dict=alpha_beta_dict,
        logical_mesh_shape=logical_mesh_shape,
        logical_mesh_id=logical_mesh_id,
    )

    if meta_args is None:
        meta_args = extract_meta_args_from_dataloader(data_loader, data_process_func)

    # initialize_model already yields ``(model, solution)`` when return_solution is set and
    # the bare model otherwise, so its result is handed straight back.
    return initialize_model(
        model,
        meta_args,
        mesh,
        solver_preference=solver_preference,
        dataloader_option=dataloader_option,
        shard_option=shard_option,
        save_solver_solution=save_solver_solution,
        load_solver_solution=load_solver_solution,
        solution_path=solver_solution_path,
        return_solution=return_solution,
        memory_budget=memory_budget,
    )
|