mirror of https://github.com/hpcaitech/ColossalAI
[examples] update autoparallel demo (#2061)
parent 1c1fe44305
commit edf4cd46c5
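This commit moves the demo's hard-coded hyperparameters (BATCH_SIZE, NUM_EPOCHS) into a new config.py loaded through colossalai.launch_from_torch, enables the distributed sampler on the train and test dataloaders, sizes the traced meta input to the global batch (per-rank batch × world size), and switches the solver from SolverOptions(fast=True) to SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED).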
@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pas
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.core import global_context as gpc
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
 from colossalai.utils import get_dataloader

 DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10


 def parse_args():
@@ -37,14 +35,14 @@ def parse_args():


 def synthesize_data():
-    img = torch.rand(BATCH_SIZE, 3, 32, 32)
-    label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
+    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
+    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
     return img, label


 def main():
     args = parse_args()
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')

     logger = get_dist_logger()

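Note: after this change the script no longer defines BATCH_SIZE or NUM_EPOCHS itself; launch_from_torch parses the file passed as config and exposes its top-level names on gpc.config. A minimal sketch of that pattern, assuming the ColossalAI launch/global-context API this example already imports (run under torchrun or colossalai run so the distributed env vars exist):

    import colossalai
    from colossalai.core import global_context as gpc

    colossalai.launch_from_torch(config='./config.py')  # parses config.py into gpc.config
    assert gpc.config.BATCH_SIZE == 128                 # top-level names from config.py
    assert gpc.config.NUM_EPOCHS == 10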
@@ -70,16 +68,16 @@ def main():

         train_dataloader = get_dataloader(
             dataset=train_dataset,
-            add_sampler=False,
+            add_sampler=True,
             shuffle=True,
-            batch_size=BATCH_SIZE,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )

         test_dataloader = get_dataloader(
             dataset=test_dataset,
-            add_sampler=False,
-            batch_size=BATCH_SIZE,
+            add_sampler=True,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
     else:
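Note: add_sampler=True makes get_dataloader attach a distributed sampler, so each rank draws a disjoint shard instead of the full dataset. A rough plain-PyTorch equivalent of what the option enables (a sketch with a hypothetical helper, not ColossalAI's exact implementation; assumes the default process group is already initialized):

    from torch.utils.data import DataLoader
    from torch.utils.data.distributed import DistributedSampler

    def make_sharded_loader(dataset, batch_size):
        # One disjoint shard per rank; shuffling is delegated to the sampler,
        # so the DataLoader itself must not also receive shuffle=True.
        sampler = DistributedSampler(dataset, shuffle=True)
        return DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=True)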
@@ -93,13 +91,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()

     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
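Note: the two changes in this hunk go together. Judging by the enum name and the sampler change above, DataloaderOption.DISTRIBUTED tells the solver that batches arrive pre-sharded across ranks, so the traced meta input is sized to the global batch rather than the old hard-coded 1024. The sizing spelled out (a world size of 8 is just an example; reuses the script's torch/gpc imports):

    world_size = 8                                      # torch.distributed.get_world_size() after launch
    global_batch = gpc.config.BATCH_SIZE * world_size   # 128 * 8 = 1024 with the new config.py
    input_sample = {'x': torch.rand([global_batch, 3, 32, 32]).to('meta')}  # meta tensor: shape only, no data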
@@ -126,9 +124,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)

     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()

         if args.synthetic:
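Note: total_steps here is measured in epochs, so the scheduler is presumably stepped once per epoch later in the loop (not shown in this hunk); reading both values from gpc.config keeps the scheduler horizon and the epoch count in sync with config.py.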
config.py (the new file referenced above as './config.py'):

@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
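With the per-rank BATCH_SIZE of 128, an 8-GPU run reproduces the old global batch of 1024 (128 × 8) that was previously hard-coded in both the dataloaders and the traced input.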