@@ -4,23 +4,14 @@ from pathlib import Path

import torch
from titans.utils import barrier_context
from torch.fx import GraphModule
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet50
from tqdm import tqdm

import colossalai
from colossalai.auto_parallel.passes.runtime_apply_pass import runtime_apply_pass
from colossalai.auto_parallel.passes.runtime_preparation_pass import runtime_preparation_pass
from colossalai.auto_parallel.tensor_shard.solver.cost_graph import CostGraph
from colossalai.auto_parallel.tensor_shard.solver.graph_analysis import GraphAnalyser
from colossalai.auto_parallel.tensor_shard.solver.options import DataloaderOption, SolverOptions
from colossalai.auto_parallel.tensor_shard.solver.solver import Solver
from colossalai.auto_parallel.tensor_shard.solver.strategies_constructor import StrategiesConstructor
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
from colossalai.core import global_context as gpc
from colossalai.device.device_mesh import DeviceMesh
from colossalai.fx.tracer.tracer import ColoTracer
from colossalai.logging import get_dist_logger
from colossalai.nn.lr_scheduler import CosineAnnealingLR
from colossalai.utils import get_dataloader
@@ -28,12 +19,6 @@ from colossalai.utils import get_dataloader
DATA_ROOT = Path(os.environ.get('DATA', '../data')).absolute()


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--synthetic', action="store_true", help="use synthetic dataset instead of CIFAR10")
    return parser.parse_args()


def synthesize_data():
    img = torch.rand(gpc.config.BATCH_SIZE, 3, 32, 32)
    label = torch.randint(low=0, high=10, size=(gpc.config.BATCH_SIZE,))
@@ -41,82 +26,15 @@ def synthesize_data():


def main():
    args = parse_args()
    colossalai.launch_from_torch(config='./config.py')

    logger = get_dist_logger()

    if not args.synthetic:
        with barrier_context():
            # build dataloaders
            train_dataset = CIFAR10(root=DATA_ROOT,
                                    download=True,
                                    transform=transforms.Compose([
                                        transforms.RandomCrop(size=32, padding=4),
                                        transforms.RandomHorizontalFlip(),
                                        transforms.ToTensor(),
                                        transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                                             std=[0.2023, 0.1994, 0.2010]),
                                    ]))

            test_dataset = CIFAR10(root=DATA_ROOT,
                                   train=False,
                                   transform=transforms.Compose([
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]),
                                   ]))

        train_dataloader = get_dataloader(
            dataset=train_dataset,
            add_sampler=True,
            shuffle=True,
            batch_size=gpc.config.BATCH_SIZE,
            pin_memory=True,
        )

        test_dataloader = get_dataloader(
            dataset=test_dataset,
            add_sampler=True,
            batch_size=gpc.config.BATCH_SIZE,
            pin_memory=True,
        )
    else:
        train_dataloader, test_dataloader = None, None

    # initialize device mesh
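    # NOTE: this example assumes 4 GPUs; the four physical devices are mapped onto a 2 x 2 logical mesh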
    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)

    # trace the model with meta data
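    # the sample input is created on the 'meta' device, so tracing records shapes without allocating GPU memory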
    tracer = ColoTracer()
    model = resnet50(num_classes=10).cuda()
    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
    graph = tracer.trace(root=model, meta_args=input_sample)
    gm = GraphModule(model, graph, model.__class__.__name__)
    gm.recompile()

    # prepare info for solver
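    # enumerate the candidate sharding strategies for each node and build the cost graph the solver works on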
    solver_options = SolverOptions(dataloader_option=DataloaderOption.DISTRIBUTED)
    strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
    strategies_constructor.build_strategies_and_cost()
    cost_graph = CostGraph(strategies_constructor.leaf_strategies)
    cost_graph.simplify_graph()
    graph_analyser = GraphAnalyser(gm)

    # solve for the best sharding strategy
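    # the solver picks one strategy per node so that the estimated overall execution cost is minimized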
    solver = Solver(gm.graph, strategies_constructor, cost_graph, graph_analyser)
    ret = solver.call_solver_serialized_args()
    solution = list(ret[0])
    if gpc.get_global_rank() == 0:
        for index, node in enumerate(graph.nodes):
            print(node.name, node.strategies_vector[solution[index]].name)

    # transform the graph for distributed execution
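    # the runtime passes shard the parameters and insert the communication ops required by the chosen strategies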
    gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(gm, solution, device_mesh)
    gm = runtime_apply_pass(gm)
    gm.recompile()

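    # alternative one-call API: autoparallelize performs the trace / solve / apply steps shown above in a single call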
    model = autoparallelize(model, input_sample)

    # build criterion
    criterion = torch.nn.CrossEntropyLoss()

@@ -127,65 +45,47 @@ def main():
    lr_scheduler = CosineAnnealingLR(optimizer, total_steps=gpc.config.NUM_EPOCHS)

    for epoch in range(gpc.config.NUM_EPOCHS):
        gm.train()
        model.train()

        if args.synthetic:
            # if we use synthetic data
            # we assume it only has 30 steps per epoch
            num_steps = range(30)
        else:
            # we use the actual number of steps for training
            num_steps = range(len(train_dataloader))
            data_iter = iter(train_dataloader)
        progress = tqdm(num_steps)

        for _ in progress:
            if args.synthetic:
                # generate fake data
                img, label = synthesize_data()
            else:
                # get the real data
                img, label = next(data_iter)

            img = img.cuda()
            label = label.cuda()
            optimizer.zero_grad()
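            # the transformed GraphModule consumes the spec and communication dicts produced by runtime_preparation_pass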
            output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
            output = model(img)
            train_loss = criterion(output, label)
            train_loss.backward()
            optimizer.step()
        lr_scheduler.step()  # anneal the learning rate once per epoch (total_steps is set to NUM_EPOCHS)

        # run evaluation
        gm.eval()
        model.eval()
        correct = 0
        total = 0

        if args.synthetic:
            # if we use synthetic data
            # we assume it only has 10 steps for evaluation
            num_steps = range(10)
        else:
            # we use the actual number of steps for evaluation
            num_steps = range(len(test_dataloader))
            data_iter = iter(test_dataloader)
        progress = tqdm(num_steps)

        for _ in progress:
            if args.synthetic:
                # generate fake data
                img, label = synthesize_data()
            else:
                # get the real data
                img, label = next(data_iter)

            img = img.cuda()
            label = label.cuda()

            with torch.no_grad():
                output = gm(img, sharding_spec_dict, origin_spec_dict, comm_actions_dict)
                output = model(img)
                test_loss = criterion(output, label)
            pred = torch.argmax(output, dim=-1)
            correct += torch.sum(pred == label)