mirror of https://github.com/hpcaitech/ColossalAI
[test] refactored testing components (#324)
parent 4f26fabe4f
commit 6268446b81
@ -0,0 +1 @@
from . import repeated_computed_layer, resnet, nested_model
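Importing these submodules is what populates the registry: each module registers its component getter as an import side effect. A minimal consumption sketch (editor's illustration, not part of the commit; it assumes the tests.components_to_test package layout used by the test file below):

from tests.components_to_test.registry import non_distributed_component_funcs
import tests.components_to_test  # noqa: F401 -- triggers the @register decorators

get_components = non_distributed_component_funcs.get_callable('resnet18')
model, trainloader, testloader, optim, criterion = get_components()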
@ -0,0 +1,49 @@
import torch
import torch.nn as nn
import torch.nn.functional as F

from .utils import DummyDataGenerator
from .registry import non_distributed_component_funcs


class SubNet(nn.Module):
    """Applies a linear transform using an externally supplied weight and its own bias."""

    def __init__(self, out_features) -> None:
        super().__init__()
        self.bias = nn.Parameter(torch.zeros(out_features))

    def forward(self, x, weight):
        return F.linear(x, weight, self.bias)


class NestedNet(nn.Module):
    """A nested model with parameter reuse: fc1 runs twice and its weight is shared with sub_fc."""

    def __init__(self) -> None:
        super().__init__()
        self.fc1 = nn.Linear(5, 5)
        self.sub_fc = SubNet(5)
        self.fc2 = nn.Linear(5, 2)

    def forward(self, x):
        x = self.fc1(x)
        x = self.sub_fc(x, self.fc1.weight)
        x = self.fc1(x)
        x = self.fc2(x)
        return x


class DummyDataLoader(DummyDataGenerator):

    def generate(self):
        data = torch.rand(16, 5)
        label = torch.randint(low=0, high=2, size=(16,))
        return data, label


@non_distributed_component_funcs.register(name='nested_model')
def get_training_components():
    model = NestedNet()
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    return model, trainloader, testloader, optim, criterion
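For illustration (not part of the commit), a sketch of why this model exercises parameter sharing: fc1.weight participates three times per forward pass, twice through fc1 itself and once through sub_fc, so its gradient accumulates across all three uses:

net = NestedNet()
out = net(torch.rand(16, 5))
out.sum().backward()
print(net.fc1.weight.grad.shape)  # torch.Size([5, 5]); gradient summed over all three uses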
@ -0,0 +1,39 @@
#!/usr/bin/env python


class Registry:
    """Maps names to callables; `register` is a decorator factory, and the registry is iterable."""

    def __init__(self):
        self._registry = dict()

    def register(self, name):
        assert name not in self._registry, f'{name} is already registered'

        def _register(callable_):
            self._registry[name] = callable_
            return callable_  # return the callable so the decorated name remains usable

        return _register

    def get_callable(self, name: str):
        return self._registry[name]

    def __iter__(self):
        self._idx = 0
        self._len = len(self._registry)
        self._names = list(self._registry.keys())
        return self

    def __next__(self):
        if self._idx < self._len:
            key = self._names[self._idx]
            callable_ = self._registry[key]
            self._idx += 1
            return callable_
        else:
            raise StopIteration


non_distributed_component_funcs = Registry()
model_parallel_component_funcs = Registry()

__all__ = ['non_distributed_component_funcs', 'model_parallel_component_funcs']
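A short usage sketch (editor's illustration, not in the commit): `register(name)` returns a decorator that stores the function under `name`, and iterating the registry yields the registered callables in insertion order:

demo_registry = Registry()


@demo_registry.register(name='linear_model')
def get_linear_components():
    return 'components-go-here'


fn = demo_registry.get_callable('linear_model')   # look up one entry by name
for registered_fn in demo_registry:               # or iterate over everything registered
    print(registered_fn())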
@ -0,0 +1,44 @@
#!/usr/bin/env python

import torch
import torch.nn as nn
from colossalai.nn import CheckpointModule
from .utils.dummy_data_generator import DummyDataGenerator
from .registry import non_distributed_component_funcs


class NetWithRepeatedlyComputedLayers(CheckpointModule):
    """
    A model for testing layers whose forward pass runs multiple times:
    fc1 and fc2 are each called twice per forward pass.
    """

    def __init__(self, checkpoint=False) -> None:
        super().__init__(checkpoint=checkpoint)
        self.fc1 = nn.Linear(5, 5)
        self.fc2 = nn.Linear(5, 5)
        self.fc3 = nn.Linear(5, 2)
        self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x


class DummyDataLoader(DummyDataGenerator):

    def generate(self):
        data = torch.rand(16, 5)
        label = torch.randint(low=0, high=2, size=(16,))
        return data, label


@non_distributed_component_funcs.register(name='repeated_computed_layers')
def get_training_components():
    model = NetWithRepeatedlyComputedLayers(checkpoint=True)
    trainloader = DummyDataLoader()
    testloader = DummyDataLoader()
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    return model, trainloader, testloader, optim, criterion
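Note that self.layers is a plain Python list rather than an nn.ModuleList; the modules are still registered because each was first assigned as an attribute. An illustrative sketch (not part of the diff, and assuming CheckpointModule adds no parameters of its own):

model = NetWithRepeatedlyComputedLayers(checkpoint=False)
print(len(model.layers))              # 5 entries in the sequence, but only 3 distinct Linear layers
print(len(list(model.parameters())))  # 6 tensors: 3 Linear layers x (weight, bias)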
@ -0,0 +1,30 @@
from torchvision.models import resnet18
from .registry import non_distributed_component_funcs
from pathlib import Path
import os
import torch
from torchvision import transforms
from torchvision.datasets import CIFAR10
from colossalai.utils import get_dataloader


def get_cifar10_dataloader(train):
    # build dataloaders
    dataset = CIFAR10(root=Path(os.environ['DATA']),
                      download=True,
                      train=train,
                      transform=transforms.Compose(
                          [transforms.ToTensor(),
                           transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]))
    dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True)
    return dataloader


@non_distributed_component_funcs.register(name='resnet18')
def get_resnet_training_components():
    model = resnet18(num_classes=10)
    trainloader = get_cifar10_dataloader(train=True)
    testloader = get_cifar10_dataloader(train=False)
    optim = torch.optim.Adam(model.parameters(), lr=0.001)
    criterion = torch.nn.CrossEntropyLoss()
    return model, trainloader, testloader, optim, criterion
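The loader reads the dataset root from the DATA environment variable. A local-run sketch (illustrative only; './data' is a hypothetical path, and colossalai's get_dataloader may attach a distributed sampler, so the distributed context is assumed to be initialized):

import os
os.environ.setdefault('DATA', './data')  # hypothetical local dataset root

loader = get_cifar10_dataloader(train=True)
imgs, labels = next(iter(loader))
print(imgs.shape)    # torch.Size([16, 3, 32, 32])
print(labels.shape)  # torch.Size([16])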
@ -0,0 +1 @@
from .dummy_data_generator import DummyDataGenerator
@ -0,0 +1,14 @@
from abc import ABC, abstractmethod


class DummyDataGenerator(ABC):
    """An infinite iterator of synthetic batches: each `next()` call returns `generate()`."""

    @abstractmethod
    def generate(self):
        pass

    def __iter__(self):
        return self

    def __next__(self):
        return self.generate()
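Because __next__ never raises StopIteration, any subclass (such as the DummyDataLoader classes above) is an infinite stream; a loop over it must break explicitly, as the engine test below does, or be bounded, e.g. (illustrative sketch):

from itertools import islice

for data, label in islice(DummyDataLoader(), 10):  # take 10 synthetic batches, then stop
    pass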
@ -0,0 +1,86 @@
from functools import partial

import colossalai
import pytest
import torch.multiprocessing as mp
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.context import Config
from tests.components_to_test.registry import non_distributed_component_funcs

CONFIG = dict(parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)),
              fp16=dict(mode=None),
              clip_grad_norm=1.0)


def run_train():
    for get_components_func in non_distributed_component_funcs:
        model, train_dataloader, _, optimizer, criterion = get_components_func()

        engine, train_dataloader, *args = colossalai.initialize(model=model,
                                                                optimizer=optimizer,
                                                                criterion=criterion,
                                                                train_dataloader=train_dataloader)

        try:
            engine.train()
            for img, label in train_dataloader:
                engine.zero_grad()
                img = img.cuda()
                label = label.cuda()
                output = engine(img)
                loss = engine.criterion(output, label)
                engine.backward(loss)
                engine.step()
                break
        except IndexError:
            # with apex amp, NetWithRepeatedlyComputedLayers raises an index-out-of-range error
            # because the following check inside apex fails:
            #   if cached_x.grad_fn.next_functions[1][0].variable is not x:
            continue


def run_with_no_amp():
    run_train()


def run_with_torch_amp():
    # hack config
    CONFIG['fp16']['mode'] = AMP_TYPE.TORCH
    gpc._config = Config(CONFIG)
    run_train()


def run_with_apex_amp():
    # hack config
    CONFIG['fp16']['mode'] = AMP_TYPE.APEX
    gpc._config = Config(CONFIG)
    run_train()


def run_with_naive_amp():
    # hack config
    CONFIG['fp16']['mode'] = AMP_TYPE.NAIVE
    gpc._config = Config(CONFIG)
    run_train()


def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(config=dict(), rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    run_with_no_amp()
    run_with_torch_amp()
    run_with_apex_amp()
    run_with_naive_amp()


@pytest.mark.dist
def test_engine():
    world_size = 4
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()
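The four run_with_* wrappers differ only in the fp16 mode they write into CONFIG before rebuilding gpc._config. A condensed sketch of the same pattern (editor's illustration, not in the commit):

def run_with_amp_mode(mode=None):
    if mode is not None:
        CONFIG['fp16']['mode'] = mode
        gpc._config = Config(CONFIG)
    run_train()


for mode in (None, AMP_TYPE.TORCH, AMP_TYPE.APEX, AMP_TYPE.NAIVE):
    run_with_amp_mode(mode)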
@ -1,110 +0,0 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Config
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    ),
    fp16=dict(mode=AMP_TYPE.APEX),
    clip_grad_norm=1.0
)


def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl'
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
        break

    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_engine():
    world_size = 4
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()
@ -1,109 +0,0 @@
import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Config
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    ),
    fp16=dict(
        mode=AMP_TYPE.NAIVE,
        clip_grad=1.0
    )
)


def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl'
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
        break

    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_engine():
    world_size = 4
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()
@ -1,105 +0,0 @@
import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Config
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    ),
    clip_grad_norm=1.0
)


def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl'
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
        break

    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_engine():
    world_size = 4
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()
@ -1,107 +0,0 @@
import os
from functools import partial
from pathlib import Path

import colossalai
import pytest
import torch
import torch.multiprocessing as mp
import torch.nn as nn
from colossalai.amp import AMP_TYPE
from colossalai.core import global_context as gpc
from colossalai.logging import get_dist_logger
from colossalai.utils import free_port, get_dataloader, report_memory_usage
from torch.optim import Adam
from torchvision import transforms
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18

# Config
BATCH_SIZE = 128
IMG_SIZE = 224
DIM = 768
NUM_CLASSES = 10
NUM_ATTN_HEADS = 12

CONFIG = dict(
    parallel=dict(
        pipeline=dict(size=1),
        tensor=dict(size=1, mode=None)
    ),
    fp16=dict(mode=AMP_TYPE.TORCH),
    clip_grad_norm=1.0
)


def run_engine(rank, world_size, port):
    # init dist env
    colossalai.launch(
        config=CONFIG,
        rank=rank,
        world_size=world_size,
        host='localhost',
        port=port,
        backend='nccl'
    )

    # build model
    model = resnet18(num_classes=10)

    # build dataloaders
    train_dataset = CIFAR10(
        root=Path(os.environ['DATA']),
        download=True,
        transform=transforms.Compose(
            [
                transforms.Resize(size=(IMG_SIZE, IMG_SIZE)),
                transforms.ToTensor(),
                transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
            ]
        )
    )
    train_dataloader = get_dataloader(dataset=train_dataset,
                                      shuffle=True,
                                      batch_size=BATCH_SIZE,
                                      drop_last=True)

    # build optimizer
    optimizer = Adam(model.parameters(), lr=0.001)
    criterion = nn.CrossEntropyLoss()

    engine, train_dataloader, *args = colossalai.initialize(
        model=model,
        optimizer=optimizer,
        criterion=criterion,
        train_dataloader=train_dataloader
    )
    logger = get_dist_logger()
    rank = torch.distributed.get_rank()

    engine.train()
    for img, label in train_dataloader:
        engine.zero_grad()
        img = img.cuda()
        label = label.cuda()
        output = engine(img)
        loss = engine.criterion(output, label)
        engine.backward(loss)
        engine.step()
        break

    logger.info('Rank {} returns: {}'.format(rank, loss.item()))

    gpc.destroy()
    logger.info('Test engine finished')
    report_memory_usage("After testing")
    torch.cuda.empty_cache()


@pytest.mark.dist
def test_engine():
    world_size = 4
    run_func = partial(run_engine, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_engine()