mirror of https://github.com/hpcaitech/ColossalAI
[hotfix] fix autoparallel demo (#2533)
parent 63199c6687
commit f477a14f4a
@@ -3,8 +3,9 @@ from torchvision.models import resnet50
 from tqdm import tqdm
 
 import colossalai
-from colossalai.auto_parallel.tensor_shard.initialize import autoparallelize
+from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
 from colossalai.core import global_context as gpc
+from colossalai.device.device_mesh import DeviceMesh
 from colossalai.logging import get_dist_logger
 from colossalai.nn.lr_scheduler import CosineAnnealingLR
 
@@ -22,9 +23,14 @@ def main():
 
     # trace the model with meta data
     model = resnet50(num_classes=10).cuda()
-    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
-
-    model = autoparallelize(model, input_sample)
+    input_sample = {'x': torch.rand([gpc.config.BATCH_SIZE * torch.distributed.get_world_size(), 3, 32, 32]).to('meta')}
+    device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)
+    model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)
+
+    if gpc.get_global_rank() == 0:
+        for node_strategy in solution:
+            print(node_strategy)
+
     # build criterion
     criterion = torch.nn.CrossEntropyLoss()
 
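This hunk is the substance of the hotfix: the removed autoparallelize entry point is replaced by an explicit DeviceMesh plus initialize_model, which shards the model and, with return_solution=True, also returns the per-node sharding strategies chosen by the solver. Below is a minimal sketch of the new initialization path, assuming a 4-GPU launch; the launch_from_torch setup and the concrete batch size are illustrative assumptions, not part of this diff:

import torch
from torchvision.models import resnet50

import colossalai
from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh

# Assumed setup; the demo's real config/launch code lives outside this hunk.
colossalai.launch_from_torch(config={})

model = resnet50(num_classes=10).cuda()
# Meta tensors give the tracer input shapes without allocating device memory.
# 128 stands in for gpc.config.BATCH_SIZE * world_size from the demo.
input_sample = {'x': torch.rand([128, 3, 32, 32]).to('meta')}

# Arrange physical GPUs 0-3 as a 2 x 2 logical mesh; init_process_group=True
# creates the process groups used for communication along each mesh axis.
device_mesh = DeviceMesh(physical_mesh_id=torch.tensor([0, 1, 2, 3]), mesh_shape=[2, 2], init_process_group=True)

# Returns the sharded model and the solver's chosen strategy per graph node.
model, solution = initialize_model(model, input_sample, device_mesh=device_mesh, return_solution=True)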
@@ -52,6 +58,7 @@ def main():
         output = model(img)
         train_loss = criterion(output, label)
         train_loss.backward(train_loss)
+        torch.cuda.synchronize()
         optimizer.step()
         lr_scheduler.step()
 
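The final hunk inserts a torch.cuda.synchronize() between the backward pass and the optimizer step; this blocks the host until all queued CUDA work on the device, including any asynchronous collectives issued during the sharded backward, has completed. A sketch of the resulting inner training step follows; the zero_grad placement and the function wrapper are assumptions drawn from typical training loops, not shown in this hunk:

import torch

def train_step(model, img, label, criterion, optimizer, lr_scheduler):
    optimizer.zero_grad()  # assumed; not visible in this hunk
    output = model(img)
    train_loss = criterion(output, label)
    # The demo passes the loss tensor itself as the gradient argument
    # (backward(train_loss)) rather than the more common backward().
    train_loss.backward(train_loss)
    # Wait for all outstanding CUDA work before stepping the optimizer.
    torch.cuda.synchronize()
    optimizer.step()
    lr_scheduler.step()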