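"""ImageNet training example for a timm vision transformer with ColossalAI.

The model is built from the timm 'vit_small_patch16_224' variant (with the
actual sizes taken from the config), selected linear weights are sharded with
1D tensor parallelism (row- or column-style, chosen by TP_TYPE), and the whole
module is wrapped in ColoDDP. Data comes from a DALI ImageNet pipeline, or
from a dummy loader when --dummy_data is given.
"""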
import os

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from timm.models.vision_transformer import _create_vision_transformer
from titans.dataloader.imagenet import build_dali_imagenet
from tqdm import tqdm
from vit import DummyDataLoader

import colossalai
from colossalai.core import global_context as gpc
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn import CrossEntropyLoss
from colossalai.nn._ops import *
from colossalai.nn.lr_scheduler import CosineAnnealingWarmupLR
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel.data_parallel import ColoDDP
from colossalai.tensor import ComputePattern, ComputeSpec, DistSpecManager, ProcessGroup, ShardSpec
from colossalai.utils import get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext


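# Row-style 1D tensor parallelism: shard each 'weight' parameter along its
# last dimension across the tensor-parallel ranks, leaving norm layers and
# the patch-embedding projection replicated.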
def init_1d_row_for_linear_weight_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([-1], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if 'weight' in n and 'norm' not in n and 'patch_embed.proj.weight' not in n:
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


# Column-style counterpart: weights and biases are sharded along dim 0, which is a
# column split for Linear layers but a row split for others; norm layers and the
# patch embedding stay replicated.
def init_1d_col_for_linear_weight_bias_spec(model, world_size: int):
    pg = ProcessGroup(tp_degree=world_size)
    spec = (ShardSpec([0], [pg.tp_world_size()]), ComputeSpec(ComputePattern.TP1D))
    with DistSpecManager.no_grad():
        for n, p in model.named_parameters():
            if ('weight' in n or 'bias' in n) and 'norm' not in n and ('patch_embed.proj.weight' not in n
                                                                       and 'patch_embed.proj.bias' not in n):
                p.set_process_group(pg)
                p.set_tensor_spec(*spec)


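# Apply the sharding scheme selected by TP_TYPE in the config ('row' or 'col').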
def init_spec_func(model, tp_type):
    world_size = torch.distributed.get_world_size()
    if tp_type == 'row':
        init_1d_row_for_linear_weight_spec(model, world_size)
    elif tp_type == 'col':
        init_1d_col_for_linear_weight_bias_spec(model, world_size)
    else:
        raise NotImplementedError(f"Unsupported TP_TYPE: {tp_type}")


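# train_imagenet() reads every hyperparameter from the Python file passed via
# --config, which ColossalAI exposes as gpc.config. Below is a minimal sketch of
# such a config file: the field names are exactly the ones this script accesses,
# but the values are illustrative only, not an official example config.
#
#     USE_DDP = True
#     TP_TYPE = 'row'            # or 'col'
#     BATCH_SIZE = 256
#     NUM_EPOCHS = 300
#     WARMUP_EPOCHS = 32
#     LEARNING_RATE = 3e-3
#     WEIGHT_DECAY = 0.3
#     gradient_accumulation = 1
#     IMG_SIZE = 224
#     PATCH_SIZE = 16
#     HIDDEN_SIZE = 384
#     DEPTH = 12
#     NUM_HEADS = 6
#     MLP_RATIO = 4
#     NUM_CLASSES = 1000
#     # LOG_PATH = './logs'      # optional: rank 0 also logs to this directory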
def train_imagenet():

    parser = colossalai.get_default_parser()
    parser.add_argument('--resume_from', default=None, type=str, help='checkpoint path prefix to resume from')
    parser.add_argument('--dummy_data', default=False, action='store_true')

    args = parser.parse_args()
    colossalai.launch_from_torch(config=args.config)
    use_ddp = gpc.config.USE_DDP

    disable_existing_loggers()

    logger = get_dist_logger()
    if hasattr(gpc.config, 'LOG_PATH'):
        if gpc.get_global_rank() == 0:
            log_path = gpc.config.LOG_PATH
            if not os.path.exists(log_path):
                os.mkdir(log_path)
            logger.log_to_file(log_path)

    logger.info('Build data loader', ranks=[0])
    if not args.dummy_data:
        root = os.environ['DATA']
        train_dataloader, test_dataloader = build_dali_imagenet(root,
                                                                train_batch_size=gpc.config.BATCH_SIZE,
                                                                test_batch_size=gpc.config.BATCH_SIZE)
    else:
        train_dataloader = DummyDataLoader(length=10,
                                           batch_size=gpc.config.BATCH_SIZE,
                                           category=gpc.config.NUM_CLASSES,
                                           image_size=gpc.config.IMG_SIZE,
                                           return_dict=False)
        test_dataloader = DummyDataLoader(length=5,
                                          batch_size=gpc.config.BATCH_SIZE,
                                          category=gpc.config.NUM_CLASSES,
                                          image_size=gpc.config.IMG_SIZE,
                                          return_dict=False)

    logger.info('Build model', ranks=[0])

    model_kwargs = dict(img_size=gpc.config.IMG_SIZE,
                        patch_size=gpc.config.PATCH_SIZE,
                        embed_dim=gpc.config.HIDDEN_SIZE,
                        depth=gpc.config.DEPTH,
                        num_heads=gpc.config.NUM_HEADS,
                        mlp_ratio=gpc.config.MLP_RATIO,
                        num_classes=gpc.config.NUM_CLASSES,
                        drop_rate=0.1,
                        attn_drop_rate=0.1,
                        weight_init='jax')

    # Build the ViT inside ColoInitContext so parameters are created as ColoTensors,
    # then apply the tensor-parallel sharding spec chosen in the config.
    with ColoInitContext(device=get_current_device()):
        model = _create_vision_transformer('vit_small_patch16_224', pretrained=False, **model_kwargs)
    init_spec_func(model, gpc.config.TP_TYPE)

    world_size = torch.distributed.get_world_size()
    model = ColoDDP(module=model, process_group=ProcessGroup(tp_degree=world_size))
    logger.info('Build criterion, optimizer, lr_scheduler', ranks=[0])
    optimizer = HybridAdam(model.parameters(), lr=gpc.config.LEARNING_RATE, weight_decay=gpc.config.WEIGHT_DECAY)

    criterion = CrossEntropyLoss()
    lr_scheduler = CosineAnnealingWarmupLR(optimizer=optimizer,
                                           total_steps=gpc.config.NUM_EPOCHS,
                                           warmup_steps=gpc.config.WARMUP_EPOCHS)

    # Optionally resume: --resume_from is a path prefix pointing at
    # '<prefix>_model.pth' and '<prefix>_optim_rank_<rank>.pth' checkpoints.
    start_epoch = 0
    if args.resume_from:
        load_model = torch.load(args.resume_from + '_model.pth')
        start_epoch = load_model['epoch']
        model.load_state_dict(load_model['model'])
        load_optim = torch.load(args.resume_from + '_optim_rank_{}.pth'.format(dist.get_rank()))
        optimizer.load_state_dict(load_optim['optim'])

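    # Main training loop: gradients are accumulated over gradient_accumulation
    # micro-batches before each optimizer step. With USE_DDP, ColoDDP.backward()
    # runs the backward pass and synchronizes gradients across ranks; otherwise a
    # plain loss.backward() is used.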
    for epoch in range(start_epoch, gpc.config.NUM_EPOCHS):
        model.train()
        for index, (x, y) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), leave=False):
            x, y = x.cuda(), y.cuda()
            output = model(x)
            loss = criterion(output, y)
            loss = loss / gpc.config.gradient_accumulation
            if use_ddp:
                model.backward(loss)
            else:
                loss.backward()
            if (index + 1) % gpc.config.gradient_accumulation == 0:
                optimizer.step()
                if use_ddp:
                    model.zero_grad()
                else:
                    optimizer.zero_grad()

        logger.info(
            f"Finish Train Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {loss.item():.3f} lr: {optimizer.state_dict()['param_groups'][0]['lr']}",
            ranks=[0])

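        # Evaluate on the test split at the end of every epoch.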
        model.eval()
        test_loss = 0
        correct = 0
        test_sum = 0
        with torch.no_grad():
            for index, (x, y) in tqdm(enumerate(test_dataloader), total=len(test_dataloader), leave=False):
                x, y = x.cuda(), y.cuda()
                output = model(x)
                test_loss += F.cross_entropy(output, y, reduction='sum').item()
                pred = output.argmax(dim=1, keepdim=True)
                correct += pred.eq(y.view_as(pred)).sum().item()
                test_sum += y.size(0)

        test_loss /= test_sum
        logger.info(
            f"Finish Test Epoch [{epoch+1}/{gpc.config.NUM_EPOCHS}] loss: {test_loss:.3f} Accuracy: [{correct}/{test_sum}]({correct/test_sum:.3f})",
            ranks=[0])

        lr_scheduler.step()


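# The script expects to be started by a torch.distributed launcher, since
# colossalai.launch_from_torch() reads the rank/world-size environment variables
# the launcher sets. For example (file names are illustrative only), with 4 GPUs
# on a single node:
#
#     torchrun --nproc_per_node=4 train.py --config config.py
#
# When --dummy_data is not passed, the DATA environment variable must point to
# the ImageNet root expected by build_dali_imagenet.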
if __name__ == '__main__':
    train_imagenet()