|
|
@ -3,8 +3,9 @@ from torchvision.models import resnet50 |
|
|
|
from tqdm import tqdm |
|
|
|
from tqdm import tqdm |
|
|
|
|
|
|
|
|
|
|
|
import colossalai |
|
|
|
import colossalai |
|
|
|
from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize |
|
|
|
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model |
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
from colossalai.core import global_context as gpc |
|
|
|
|
|
|
|
from colossalai.device.device_mesh import DeviceMesh |
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
from colossalai.logging import get_dist_logger |
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingLR |
|
|
|
from colossalai.nn.lr_scheduler import CosineAnnealingLR |
|
|
|
|
|
|
|
|
|
|
@ -22,9 +23,14 @@ def main(): |
|
|
|
|
|
|
|
|
|
|
|
# trace the model with meta data |
|
|
|
# trace the model with meta data |
|
|
|
model = resnet50(num_classes=10).cuda() |
|
|
|
model = resnet50(num_classes=10).cuda() |
|
|
|
|
|
|
|
|
|
|
|
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} |
|
|
|
input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')} |
|
|
|
|
|
|
|
device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True) |
|
|
|
|
|
|
|
model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True) |
|
|
|
|
|
|
|
|
|
|
|
model = autoparallelize(model, input_sample) |
|
|
|
if gpc.get_global_rank() == 0: |
|
|
|
|
|
|
|
for node_strategy in solution: |
|
|
|
|
|
|
|
print(node_strategy) |
|
|
|
# build criterion |
|
|
|
# build criterion |
|
|
|
criterion = torch.nn.CrossEntropyLoss() |
|
|
|
criterion = torch.nn.CrossEntropyLoss() |
|
|
|
|
|
|
|
|
|
|
@ -52,6 +58,7 @@ def main(): |
|
|
|
output = model(img) |
|
|
|
output = model(img) |
|
|
|
train_loss = criterion(output, label) |
|
|
|
train_loss = criterion(output, label) |
|
|
|
train_loss.backward(train_loss) |
|
|
|
train_loss.backward(train_loss) |
|
|
|
|
|
|
|
torch.cuda.synchronize() |
|
|
|
optimizer.step() |
|
|
|
optimizer.step() |
|
|
|
lr_scheduler.step() |
|
|
|
lr_scheduler.step() |
|
|
|
|
|
|
|
|
|
|
|