mirror of https://github.com/hpcaitech/ColossalAI
[examples] update autoparallel demo (#2061)
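In brief: this change moves the demo's hyperparameters (BATCH_SIZE, NUM_EPOCHS) out of the script into a new config.py read through colossalai's global context, turns on distributed sampling for the dataloaders, and traces the model at the global batch size so the auto-parallel solver can work with a distributed dataloader.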
parent 1c1fe44305
commit edf4cd46c5
@@ -15,7 +15,7 @@ from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
 from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
 from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
 from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
-from colossalai.auto_parallel.tensor_shard.solver.options import SolverOptions
+from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
 from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
 from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
 from colossalai.core import global_context as gpc
@@ -26,8 +26,6 @@ from colossalai.nn.lr_scheduler import CosineAnnealingLR
 from colossalai.utils import get_dataloader
 
 DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()
-BATCH_SIZE = 1024
-NUM_EPOCHS = 10
 
 
 def parse_args():
@@ -37,14 +35,14 @@ def parse_args():
 
 
 def synthesize_data():
-    img = torch.rand(BATCH_SIZE, 3, 32, 32)
-    label = torch.randint(low=0, high=10, size=(BATCH_SIZE,))
+    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
+    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
     return img, label
 
 
 def main():
     args = parse_args()
-    colossalai.launch_from_torch(config={})
+    colossalai.launch_from_torch(config='./config.py')
 
     logger = get_dist_logger()
 
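The switch from config={} to config='./config.py' is what makes every gpc.config.* reference below resolve. A minimal sketch of the pattern, using the values from the config.py added at the bottom of this diff:

# sketch: launch_from_torch parses the config file and exposes its
# module-level names as attributes of the global context
import colossalai
from colossalai.core import global_context as gpc

colossalai.launch_from_torch(config='./config.py')
print(gpc.config.BATCH_SIZE)   # 128
print(gpc.config.NUM_EPOCHS)   # 10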
@@ -70,16 +68,16 @@ def main():
 
         train_dataloader = get_dataloader(
             dataset=train_dataset,
-            add_sampler=False,
+            add_sampler=True,
             shuffle=True,
-            batch_size=BATCH_SIZE,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
 
         test_dataloader = get_dataloader(
             dataset=test_dataset,
-            add_sampler=False,
-            batch_size=BATCH_SIZE,
+            add_sampler=True,
+            batch_size=gpc.config.BATCH_SIZE,
             pin_memory=True,
         )
     else:
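Flipping add_sampler to True means each rank now draws a disjoint shard of the dataset, so batch_size here is a per-process batch. A rough plain-PyTorch sketch of what that implies (an assumed equivalent, not get_dataloader's actual implementation):

# hypothetical equivalent of get_dataloader(..., add_sampler=True):
# a DistributedSampler shards the dataset across ranks, and shuffling
# moves into the sampler rather than the DataLoader
from torch.utils.data import DataLoader, DistributedSampler

def sharded_loader(dataset, batch_size):
    sampler = DistributedSampler(dataset, shuffle=True)
    return DataLoader(dataset, batch_size=batch_size, sampler=sampler, pin_memory=True)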
@@ -93,13 +91,13 @@ def main():
     # trace the model with meta data
     tracer = ColoTracer()
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([1024, 3, 32, 32]).to('meta')}
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
     graph = tracer.trace(root=model, meta_args=input_sample)
     gm = GraphModule(model, graph, model.__class__.__name__)
     gm.recompile()
 
     # prepare info for solver
-    solver_options = SolverOptions(fast=True)
+    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
     strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
     strategies_constructor.build_strategies_and_cost()
     cost_graph = CostGraph(strategies_constructor.leaf_strategies)
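Two linked changes here: the input is now traced at the global batch size (per-rank BATCH_SIZE times world size), and SolverOptions switches from fast=True to dataloader_option=DataloaderOption.DISTRIBUTED, presumably telling the solver that the dataloader output arrives pre-sharded by the sampler above. The arithmetic lines up with the old hard-coded shape, e.g. on 8 processes:

# example with 8 processes (world size here is hypothetical; at runtime
# it comes from torch.distributed.get_world_size())
world_size = 8
global_batch = 128 * world_size   # gpc.config.BATCH_SIZE * world_size
assert global_batch == 1024       # matches the old torch.rand([1024, 3, 32, 32])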
@@ -126,9 +124,9 @@ def main():
     optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=5e-4)
 
     # lr_scheduler
-    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=NUM_EPOCHS)
+    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)
 
-    for epoch in range(NUM_EPOCHS):
+    for epoch in range(gpc.config.NUM_EPOCHS):
         gm.train()
 
         if args.synthetic:
New file (config.py, per the launch call above):
@@ -0,0 +1,2 @@
+BATCH_SIZE = 128
+NUM_EPOCHS = 10
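Note the semantics shift: the old in-script BATCH_SIZE = 1024 fed unsharded loaders, while this BATCH_SIZE = 128 is per process now that the loaders use a distributed sampler, so with 8 ranks the global batch, and hence the traced input shape, is unchanged.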