[examples] update autoparallel demo (#2061)

pull/2052/head
YuliangLiu0306 2022-12-01 18:50:58 +08:00 committed by GitHub
parent 1c1fe44305
commit edf4cd46c5
2 changed files with 14 additions and 14 deletions

Changed file 1 of 2: the auto-parallel demo script

@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.core import global_context as gpc
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
 from colossalai.utils import get_dataloader
 DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10
 def parse_args():
@@ -37,14 +35,14 @@ def parse_args():
 def synthesize_data():
-    img = torch.rand(BATCH_SIZE, 3, 32, 32)
-    label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
+    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
+    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
     return img, label
 def main():
     args = parse_args()
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')
     logger = get_dist_logger()
@@ -70,16 +68,16 @@ def main():
         train_dataloader = get_dataloader(
             dataset=train_dataset,
-            add_sampler=False,
+            add_sampler=True,
             shuffle=True,
-            batch_size=BATCH_SIZE,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
         test_dataloader = get_dataloader(
             dataset=test_dataset,
-            add_sampler=False,
-            batch_size=BATCH_SIZE,
+            add_sampler=True,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
     else:
@@ -93,13 +91,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
@@ -126,9 +124,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()
         if args.synthetic:
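
A note on the traced input shape above (not part of the commit): with add_sampler=True, get_dataloader attaches a distributed sampler, so batch_size=gpc.config.BATCH_SIZE is the per-rank batch, while the meta sample handed to the tracer is sized to the global batch across all ranks. A minimal sketch of that arithmetic, assuming the process group has already been initialised by colossalai.launch_from_torch:

import torch
import torch.distributed as dist

BATCH_SIZE = 128                        # per-rank batch size from config.py
world_size = dist.get_world_size()      # number of ranks, e.g. 4
global_batch = BATCH_SIZE * world_size  # total samples per step across ranks

# Same shape the example traces on the meta device.
input_sample = {'x': torch.rand([global_batch, 3, 32, 32]).to('meta')}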

Changed file 2 of 2: the new config file (./config.py, referenced by launch_from_torch above)

@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
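
These two values are consumed at launch time: colossalai.launch_from_torch(config='./config.py') loads the file and exposes its attributes through the global context, which is how the updated demo reads gpc.config.BATCH_SIZE and gpc.config.NUM_EPOCHS. A minimal sketch of that pattern (not part of the commit; assumes a torchrun-style launch so the distributed environment variables are set):

import colossalai
from colossalai.core import global_context as gpc

# Initialise the distributed environment and load ./config.py.
colossalai.launch_from_torch(config='./config.py')

# Attributes defined in config.py become fields of gpc.config.
print(gpc.config.BATCH_SIZE)   # 128
print(gpc.config.NUM_EPOCHS)   # 10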