diff --git a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
index 474c56a61..e4aff13e4 100644
--- a/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
+++ b/examples/tutorial/auto_parallel/auto_parallel_with_resnet.py
@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.core import global_context as gpc
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
 from colossalai.utils import get_dataloader
 
 DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10
 
 
 def parse_args():
@@ -37,14 +35,14 @@ def parse_args():
 
 
 def synthesize_data():
-    img = torch.rand(BATCH_SIZE, 3, 32, 32)
-    label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
+    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
+    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
     return img, label
 
 
 def main():
     args = parse_args()
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')
 
     logger = get_dist_logger()
 
@@ -70,16 +68,16 @@ def main():
 
         train_dataloader = get_dataloader(
             dataset=train_dataset,
-            add_sampler=False,
+            add_sampler=True,
             shuffle=True,
-            batch_size=BATCH_SIZE,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
 
         test_dataloader = get_dataloader(
             dataset=test_dataset,
-            add_sampler=False,
-            batch_size=BATCH_SIZE,
+            add_sampler=True,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
     else:
@@ -93,13 +91,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
@@ -126,9 +124,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 
     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
 
-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()
 
         if args.synthetic:
diff --git a/examples/tutorial/auto_parallel/config.py b/examples/tutorial/auto_parallel/config.py
new file mode 100644
index 000000000..fa14eda74
--- /dev/null
+++ b/examples/tutorial/auto_parallel/config.py
@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
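
The pattern this change introduces, as a standalone sketch (not part of the patch itself): hyperparameters move out of the script into config.py and are loaded into gpc.config by colossalai.launch_from_torch. All names and values below are taken from the diff above; the rest of the training script is omitted.

    # Minimal sketch of the config-driven pattern used in auto_parallel_with_resnet.py.
    # Run under a torch.distributed launcher (e.g. torchrun), as launch_from_torch expects.
    import colossalai
    from colossalai.core import global_context as gpc

    colossalai.launch_from_torch(config='./config.py')  # loads BATCH_SIZE and NUM_EPOCHS into gpc.config
    batch_size = gpc.config.BATCH_SIZE  # 128 per process; the traced meta input above uses BATCH_SIZE * world_size
    num_epochs = gpc.config.NUM_EPOCHS  # 10, defined in config.py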