[test] merge old components to test to model zoo (#4945)

* [test] add custom models in model zoo

* [test] update legacy test

* [test] update model zoo

* [test] update gemini test

* [test] remove components to test
pull/4951/head
Hongxin Liu 2023-10-20 10:35:08 +08:00 committed by GitHub
parent 3a41e8304e
commit b8e770c832
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 461 additions and 914 deletions

View File

@ -9,6 +9,7 @@ from .comparison import (
)
from .pytest_wrapper import run_on_environment_flag
from .utils import (
DummyDataloader,
clear_cache_before_run,
free_port,
parameterize,
@ -34,4 +35,5 @@ __all__ = [
"run_on_environment_flag",
"check_state_dict_equal",
"assert_hf_output_close",
"DummyDataloader",
]

View File

@ -273,3 +273,24 @@ def clear_cache_before_run():
return _clear_cache
return _wrap_func
class DummyDataloader:
def __init__(self, data_gen_fn: Callable, length: int = 10):
self.data_gen_fn = data_gen_fn
self.length = length
self.step = 0
def __iter__(self):
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.data_gen_fn()
else:
raise StopIteration
def __len__(self):
return self.length

View File

@ -1,29 +0,0 @@
from . import (
beit,
bert,
gpt2,
hanging_param_model,
inline_op_model,
nested_model,
repeated_computed_layers,
resnet,
simple_net,
)
from .utils import run_fwd, run_fwd_bwd
from . import albert # isort:skip
__all__ = [
"bert",
"gpt2",
"hanging_param_model",
"inline_op_model",
"nested_model",
"repeated_computed_layers",
"resnet",
"simple_net",
"run_fwd_bwd",
"albert",
"beit",
"run_fwd",
]

View File

@ -1,62 +0,0 @@
import torch
from transformers import AlbertConfig, AlbertForSequenceClassification
from .bert import get_bert_data_loader
from .registry import non_distributed_component_funcs
@non_distributed_component_funcs.register(name="albert")
def get_training_components():
hidden_dim = 8
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint: bool = False):
config = AlbertConfig(
vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
print("building AlbertForSequenceClassification model")
# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(AlbertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAdaptor(config)
# if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
# model.gradient_checkpointing_enable()
return model
is_distributed = torch.distributed.is_initialized()
trainloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
testloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
criterion = None
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,44 +0,0 @@
import torch
from timm.models.beit import Beit
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class DummyDataLoader(DummyDataGenerator):
img_size = 64
num_channel = 3
num_class = 10
batch_size = 4
def generate(self):
data = torch.randn(
(
DummyDataLoader.batch_size,
DummyDataLoader.num_channel,
DummyDataLoader.img_size,
DummyDataLoader.img_size,
),
device=get_current_device(),
)
label = torch.randint(
low=0, high=DummyDataLoader.num_class, size=(DummyDataLoader.batch_size,), device=get_current_device()
)
return data, label
@non_distributed_component_funcs.register(name="beit")
def get_training_components():
def model_builder(checkpoint=False):
model = Beit(
img_size=DummyDataLoader.img_size, num_classes=DummyDataLoader.num_class, embed_dim=32, depth=2, num_heads=4
)
return model
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,88 +0,0 @@
import torch
import transformers
from packaging import version
from torch.utils.data import SequentialSampler
from transformers import BertConfig, BertForSequenceClassification
from .registry import non_distributed_component_funcs
def get_bert_data_loader(
n_class,
batch_size,
total_samples,
sequence_length,
device=torch.device("cpu:0"),
is_distributed=False,
):
train_data = torch.randint(
low=0,
high=n_class,
size=(total_samples, sequence_length),
device=device,
dtype=torch.long,
)
train_label = torch.randint(low=0, high=2, size=(total_samples,), device=device, dtype=torch.long)
train_dataset = torch.utils.data.TensorDataset(train_data, train_label)
if is_distributed:
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
else:
sampler = SequentialSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
return train_loader
@non_distributed_component_funcs.register(name="bert")
def get_training_components():
hidden_dim = 8
num_head = 4
sequence_length = 12
num_layer = 2
vocab_size = 32
def bert_model_builder(checkpoint: bool = False):
config = BertConfig(
vocab_size=vocab_size,
gradient_checkpointing=checkpoint,
hidden_size=hidden_dim,
intermediate_size=hidden_dim * 4,
num_attention_heads=num_head,
max_position_embeddings=sequence_length,
num_hidden_layers=num_layer,
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
)
# adapting huggingface BertForSequenceClassification for single unittest calling interface
class ModelAdaptor(BertForSequenceClassification):
def forward(self, input_ids, labels):
"""
inputs: data, label
outputs: loss
"""
return super().forward(input_ids=input_ids, labels=labels)[0]
model = ModelAdaptor(config)
if checkpoint and version.parse(transformers.__version__) >= version.parse("4.11.0"):
model.gradient_checkpointing_enable()
return model
is_distributed = torch.distributed.is_initialized()
trainloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
testloader = get_bert_data_loader(
n_class=vocab_size,
batch_size=2,
total_samples=10000,
sequence_length=sequence_length,
is_distributed=is_distributed,
)
criterion = None
return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,92 +0,0 @@
import torch
import torch.nn as nn
from transformers import GPT2Config, GPT2LMHeadModel
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class DummyDataLoader(DummyDataGenerator):
vocab_size = 128
batch_size = 4
seq_len = 64
def generate(self):
input_ids = torch.randint(
0,
DummyDataLoader.vocab_size,
(DummyDataLoader.batch_size, DummyDataLoader.seq_len),
device=get_current_device(),
)
return input_ids, input_ids
class GPTLMModel(nn.Module):
def __init__(
self,
hidden_size=768,
num_layers=12,
num_attention_heads=12,
max_seq_len=1024,
vocab_size=50304,
checkpoint=False,
):
super().__init__()
self.checkpoint = checkpoint
self.model = GPT2LMHeadModel(
GPT2Config(
n_embd=hidden_size,
n_layer=num_layers,
n_head=num_attention_heads,
n_positions=max_seq_len,
n_ctx=max_seq_len,
vocab_size=vocab_size,
resid_pdrop=0.0,
embd_pdrop=0.0,
attn_pdrop=0.0,
)
)
if checkpoint:
self.model.gradient_checkpointing_enable()
def forward(self, input_ids):
# Only return lm_logits
attention_mask = torch.ones_like(input_ids)
return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
def gpt2_micro(checkpoint=True):
return GPTLMModel(
checkpoint=checkpoint, hidden_size=32, num_layers=2, num_attention_heads=4, max_seq_len=64, vocab_size=128
)
def gpt2_s(checkpoint=True):
return GPTLMModel(checkpoint=checkpoint)
def gpt2_m(checkpoint=True):
return GPTLMModel(hidden_size=1024, num_layers=24, num_attention_heads=16, checkpoint=checkpoint)
class GPTLMLoss(nn.Module):
def __init__(self):
super().__init__()
self.loss_fn = nn.CrossEntropyLoss()
def forward(self, logits, labels):
shift_logits = logits[..., :-1, :].contiguous()
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
return self.loss_fn(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
@non_distributed_component_funcs.register(name="gpt2")
def get_training_components():
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = GPTLMLoss()
return gpt2_micro, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,48 +0,0 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class HangingParamModule(CheckpointModule):
"""
Hanging Parameter: a parameter dose not belong to a leaf Module.
It has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.weight = nn.Parameter(torch.randn(8, 8))
self.proj2 = nn.Linear(8, 4)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="hanging_param_model")
def get_training_components():
def model_builder(checkpoint=False):
return HangingParamModule(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,49 +0,0 @@
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class InlineOpModule(CheckpointModule):
"""
a module with inline Ops
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.proj2 = nn.Linear(8, 8)
def forward(self, x):
x = self.proj1(x)
# inline add_
x.add_(10)
x = self.proj2(x)
# inline relu_
x = torch.relu_(x)
x = self.proj2(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 4)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="inline_op_model")
def get_training_components():
def model_builder(checkpoint=False):
return InlineOpModule(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,38 +0,0 @@
#!/usr/bin/env python
class Registry:
def __init__(self):
self._registry = dict()
def register(self, name):
assert name not in self._registry
def _register(callable_):
self._registry[name] = callable_
return _register
def get_callable(self, name: str):
return self._registry[name]
def __iter__(self):
self._idx = 0
self._len = len(self._registry)
self._names = list(self._registry.keys())
return self
def __next__(self):
if self._idx < self._len:
key = self._names[self._idx]
callable_ = self._registry[key]
self._idx += 1
return callable_
else:
raise StopIteration
non_distributed_component_funcs = Registry()
model_parallel_component_funcs = Registry()
__all__ = ["non_distributed_component_funcs", "model_parallel_component_funcs"]

View File

@ -1,47 +0,0 @@
#!/usr/bin/env python
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class NetWithRepeatedlyComputedLayers(CheckpointModule):
"""
This model is to test with layers which go through forward pass multiple times.
In this model, the fc1 and fc2 call forward twice
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 2)
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 5)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
@non_distributed_component_funcs.register(name="repeated_computed_layers")
def get_training_components():
def model_builder(checkpoint=False):
return NetWithRepeatedlyComputedLayers(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,37 +0,0 @@
import os
from pathlib import Path
import torch
from torchvision.datasets import CIFAR10
from torchvision.models import resnet18
from torchvision.transforms import transforms
from colossalai.legacy.utils import get_dataloader
from .registry import non_distributed_component_funcs
def get_cifar10_dataloader(train):
# build dataloaders
dataset = CIFAR10(
root=Path(os.environ["DATA"]),
download=True,
train=train,
transform=transforms.Compose(
[transforms.ToTensor(), transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))]
),
)
dataloader = get_dataloader(dataset=dataset, shuffle=True, batch_size=16, drop_last=True)
return dataloader
@non_distributed_component_funcs.register(name="resnet18")
def get_resnet_training_components():
def model_builder(checkpoint=False):
return resnet18(num_classes=10)
trainloader = get_cifar10_dataloader(train=True)
testloader = get_cifar10_dataloader(train=False)
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion

View File

@ -1,53 +0,0 @@
import torch
import torch.nn as nn
from colossalai.legacy.nn import CheckpointModule
from colossalai.utils.cuda import get_current_device
from .registry import non_distributed_component_funcs
from .utils.dummy_data_generator import DummyDataGenerator
class SimpleNet(CheckpointModule):
"""
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.embed = nn.Embedding(20, 4)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
self.classifier = nn.Linear(4, 4)
def forward(self, x):
x = self.embed(x)
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
x = self.ln2(x)
x = self.classifier(x)
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
return data, label
@non_distributed_component_funcs.register(name="simple_net")
def get_training_components():
def model_builder(checkpoint=False):
return SimpleNet(checkpoint)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
from colossalai.nn.optimizer import HybridAdam
return model_builder, trainloader, testloader, HybridAdam, criterion

View File

@ -1,2 +0,0 @@
from .dummy_data_generator import DummyDataGenerator
from .executor import run_fwd, run_fwd_bwd

View File

@ -1,24 +0,0 @@
from abc import ABC, abstractmethod
class DummyDataGenerator(ABC):
def __init__(self, length=10):
self.length = length
@abstractmethod
def generate(self):
pass
def __iter__(self):
self.step = 0
return self
def __next__(self):
if self.step < self.length:
self.step += 1
return self.generate()
else:
raise StopIteration
def __len__(self):
return self.length

View File

@ -1,4 +1,5 @@
from . import diffusers, timm, torchaudio, torchrec, torchvision, transformers
from . import custom, diffusers, timm, torchaudio, torchrec, torchvision, transformers
from .executor import run_fwd, run_fwd_bwd
from .registry import model_zoo
__all__ = ["model_zoo"]
__all__ = ["model_zoo", "run_fwd", "run_fwd_bwd"]

View File

@ -0,0 +1,4 @@
from .hanging_param_model import *
from .nested_model import *
from .repeated_computed_layers import *
from .simple_net import *

View File

@ -0,0 +1,26 @@
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
class CheckpointModule(nn.Module):
def __init__(self, checkpoint: bool = False):
super().__init__()
self.checkpoint = checkpoint
self._use_checkpoint = checkpoint
def _forward(self, *args, **kwargs):
raise NotImplementedError("CheckpointModule should implement _forward method instead of origin forward")
def forward(self, *args, **kwargs):
if self._use_checkpoint:
return checkpoint(self._forward, *args, **kwargs)
else:
return self._forward(*args, **kwargs)
def train(self, mode: bool = True):
self._use_checkpoint = self.checkpoint
return super().train(mode=mode)
def eval(self):
self._use_checkpoint = False
return super().eval()

View File

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class HangingParamModule(CheckpointModule):
"""
Hanging Parameter: a parameter dose not belong to a leaf Module.
It has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.proj1 = nn.Linear(4, 8)
self.weight = nn.Parameter(torch.randn(8, 8))
self.proj2 = nn.Linear(8, 4)
def forward(self, x):
x = self.proj1(x)
x = F.linear(x, self.weight)
x = self.proj2(x)
return x
def data_gen():
return dict(x=torch.rand(16, 4))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_hanging_param_model",
model_fn=HangingParamModule,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -2,10 +2,8 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
from colossalai.legacy.nn import CheckpointModule
from .registry import non_distributed_component_funcs
from .utils import DummyDataGenerator
from ..registry import model_zoo
from .base import CheckpointModule
class SubNet(nn.Module):
@ -32,20 +30,24 @@ class NestedNet(CheckpointModule):
return x
class DummyDataLoader(DummyDataGenerator):
def generate(self):
data = torch.rand(16, 5)
label = torch.randint(low=0, high=2, size=(16,))
return data, label
def data_gen():
return dict(x=torch.rand(16, 5))
@non_distributed_component_funcs.register(name="nested_model")
def get_training_components():
def model_builder(checkpoint=False):
return NestedNet(checkpoint)
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
trainloader = DummyDataLoader()
testloader = DummyDataLoader()
criterion = torch.nn.CrossEntropyLoss()
return model_builder, trainloader, testloader, torch.optim.Adam, criterion
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_nested_model",
model_fn=NestedNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -0,0 +1,48 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class NetWithRepeatedlyComputedLayers(CheckpointModule):
"""
This model is to test with layers which go through forward pass multiple times.
In this model, the fc1 and fc2 call forward twice
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.fc1 = nn.Linear(5, 5)
self.fc2 = nn.Linear(5, 5)
self.fc3 = nn.Linear(5, 2)
self.layers = [self.fc1, self.fc2, self.fc1, self.fc2, self.fc3]
def forward(self, x):
for layer in self.layers:
x = layer(x)
return x
def data_gen():
return dict(x=torch.rand(16, 5))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_repeated_computed_layers",
model_fn=NetWithRepeatedlyComputedLayers,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -0,0 +1,53 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..registry import model_zoo
from .base import CheckpointModule
class SimpleNet(CheckpointModule):
"""
In this no-leaf module, it has subordinate nn.modules and a nn.Parameter.
"""
def __init__(self, checkpoint=False) -> None:
super().__init__(checkpoint=checkpoint)
self.embed = nn.Embedding(20, 4)
self.proj1 = nn.Linear(4, 8)
self.ln1 = nn.LayerNorm(8)
self.proj2 = nn.Linear(8, 4)
self.ln2 = nn.LayerNorm(4)
self.classifier = nn.Linear(4, 4)
def forward(self, x):
x = self.embed(x)
x = self.proj1(x)
x = self.ln1(x)
x = self.proj2(x)
x = self.ln2(x)
x = self.classifier(x)
return x
def data_gen():
return dict(x=torch.randint(low=0, high=20, size=(16,)))
def loss_fn(x):
outputs = x["x"]
label = torch.randint(low=0, high=2, size=(16,), device=outputs.device)
return F.cross_entropy(x["x"], label)
def output_transform(x: torch.Tensor):
return dict(x=x)
model_zoo.register(
name="custom_simple_net",
model_fn=SimpleNet,
data_gen_fn=data_gen,
output_transform_fn=output_transform,
loss_fn=loss_fn,
)

View File

@ -1,7 +1,15 @@
from typing import Callable, Dict, Optional, Union
import torch
from torch.nn import Module
from torch.optim import Optimizer
from colossalai.interface import OptimizerWrapper
def run_fwd(model, data, label, criterion) -> torch.Tensor:
def run_fwd(
model: Module, data: Dict, output_transform_fn: Callable, criterion: Optional[Callable] = None
) -> torch.Tensor:
"""run_fwd
run fwd for the model
@ -14,18 +22,22 @@ def run_fwd(model, data, label, criterion) -> torch.Tensor:
Returns:
torch.Tensor: loss of fwd
"""
outputs = model(**data)
outputs = output_transform_fn(outputs)
if criterion:
y = model(data)
y = y.float()
loss = criterion(y, label)
loss = criterion(outputs)
else:
loss = model(data, label)
loss = loss.float()
loss = next(iter(outputs.values())).sum()
return loss
def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
def run_fwd_bwd(
model: Module,
data: Dict,
output_transform_fn: Callable,
criterion: Optional[Callable] = None,
optimizer: Optional[Union[Optimizer, OptimizerWrapper]] = None,
) -> torch.Tensor:
"""run_fwd_bwd
run fwd and bwd for the model
@ -38,7 +50,7 @@ def run_fwd_bwd(model, data, label, criterion, optimizer=None) -> torch.Tensor:
Returns:
torch.Tensor: loss of fwd
"""
loss = run_fwd(model, data, label, criterion)
loss = run_fwd(model, data, output_transform_fn, criterion)
if optimizer:
optimizer.backward(loss)
else:

View File

@ -359,9 +359,9 @@ output_transform_fn = lambda x: x
# define loss funciton
loss_fn_for_bert_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.BertConfig(
hidden_size=128,

View File

@ -35,7 +35,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss funciton
loss_fn_blip2_model = lambda x: x.loss
loss_fn_blip2_model = lambda x: x["loss"]
config = transformers.Blip2Config()
config.vision_config.patch_size = 14

View File

@ -69,11 +69,11 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_bloom_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_causal_lm = lambda x: x.loss
loss_fn_for_classification = lambda x: x.loss
loss_fn_for_question_answering = lambda x: x.loss
loss_fn_for_causal_lm = lambda x: x["loss"]
loss_fn_for_classification = lambda x: x["loss"]
loss_fn_for_question_answering = lambda x: x["loss"]
config = transformers.BloomConfig(
n_layer=2, n_head=4, vocab_size=250880, hidden_dropout=0, attention_dropout=0, hidden_size=64, pad_token_id=50256

View File

@ -30,9 +30,9 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_chatglm_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = ChatGLMConfig(
num_layers=2,

View File

@ -87,13 +87,14 @@ output_transform_fn = lambda x: x
# define loss function
loss_fn_for_gpt2_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn = lambda x: x.loss
loss_fn = lambda x: x["loss"]
config = transformers.GPT2Config(
n_layer=2,
n_head=4,
n_embd=128,
vocab_size=50258,
attn_pdrop=0,
embd_pdrop=0,

View File

@ -42,9 +42,9 @@ if HAS_LLAMA:
output_transform_fn = lambda x: x
# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()
config = LlamaConfig(
num_hidden_layers=4,

View File

@ -45,9 +45,9 @@ def data_gen_for_question_answering():
output_transform_fn = lambda x: x
loss_fn_for_opt_model = lambda x: torch.nn.functional.mse_loss(
x.last_hidden_state, torch.ones_like(x.last_hidden_state)
x["last_hidden_state"], torch.ones_like(x["last_hidden_state"])
)
loss_fn_for_lm = lambda x: x.loss
loss_fn_for_lm = lambda x: x["loss"]
config = transformers.OPTConfig(
hidden_size=128,
num_hidden_layers=2,

View File

@ -40,7 +40,7 @@ def data_gen():
output_transform_fn = lambda x: x
# define loss funciton
loss_fn = lambda x: x.iou_scores.mean()
loss_fn = lambda x: x["iou_scores"].mean()
config = transformers.SamConfig()
config.vision_config.num_hidden_layers = 2

View File

@ -44,9 +44,9 @@ def data_gen_for_t5_model():
output_transform_fn = lambda x: x
# define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss
loss_fn_for_t5_model = lambda x: x["last_hidden_state"].mean()
loss_fn_for_encoder_only = lambda x: x["last_hidden_state"].mean()
loss_fn_for_conditional_generation = lambda x: x["loss"]
# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

View File

@ -34,9 +34,9 @@ def data_gen_for_masked_image_modeling():
output_transform_fn = lambda x: x
# function to get the loss
loss_fn_for_vit_model = lambda x: x.pooler_output.mean()
loss_fn_for_image_classification = lambda x: x.logits.mean()
loss_fn_for_masked_image_modeling = lambda x: x.loss
loss_fn_for_vit_model = lambda x: x["pooler_output"].mean()
loss_fn_for_image_classification = lambda x: x["logits"].mean()
loss_fn_for_masked_image_modeling = lambda x: x["loss"]
# register the following models
# transformers.ViTModel,

View File

@ -53,8 +53,8 @@ def data_gen_for_audio_classification():
output_transform_fn = lambda x: x
# define loss funciton
loss_fn = lambda x: torch.nn.functional.mse_loss(x.last_hidden_state, torch.ones_like(x.last_hidden_state))
loss_fn_attr = lambda x: x.loss
loss_fn = lambda x: torch.nn.functional.mse_loss(x["last_hidden_state"], torch.ones_like(x["last_hidden_state"]))
loss_fn_attr = lambda x: x["loss"]
config = transformers.WhisperConfig(
classifier_proj_size=256,

View File

@ -6,7 +6,7 @@ import torch
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_naive_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def check_equal(a, b):
@ -25,13 +25,12 @@ def run_naive_amp():
torch.backends.cudnn.deterministic = True
# create layer
test_models = ["repeated_computed_layers", "nested_model", "resnet18"]
test_models = ["custom_repeated_computed_layers", "custom_nested_model", "torchvision_resnet18"]
for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name)
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
# create model
naive_amp_model = model_builder(checkpoint=True).cuda()
naive_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(naive_amp_model)
# create optimizer
@ -48,13 +47,12 @@ def run_naive_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data
data_iter = iter(train_dataloader)
data, label = next(data_iter)
data = data.cuda()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# forward pass
naive_amp_output = naive_amp_model(data)
apex_amp_output = apex_amp_model(data)
naive_amp_output = naive_amp_model(**data)
apex_amp_output = apex_amp_model(**data)
assert_close_loose(naive_amp_output, apex_amp_output)
# backward

View File

@ -6,7 +6,7 @@ import torch
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp, convert_to_torch_amp
from colossalai.testing import assert_close_loose, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def run_torch_amp():
@ -18,13 +18,12 @@ def run_torch_amp():
torch.backends.cudnn.deterministic = True
# create layer
test_models = ["resnet18", "simple_net"]
test_models = ["torchvision_resnet18", "custom_simple_net"]
for test_name in test_models:
get_component_func = non_distributed_component_funcs.get_callable(test_name)
model_builder, train_dataloader, _, optim_class, _ = get_component_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(test_name).values()))
# create model
torch_amp_model = model_builder(checkpoint=True).cuda()
torch_amp_model = model_builder().cuda()
apex_amp_model = copy.deepcopy(torch_amp_model)
# create optimizer
@ -41,13 +40,12 @@ def run_torch_amp():
apex_amp_model, apex_amp_optimizer = convert_to_apex_amp(apex_amp_model, apex_amp_optimizer, apex_amp_config)
# create data
data_iter = iter(train_dataloader)
data, label = next(data_iter)
data = data.cuda()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# forward pass
torch_amp_output = torch_amp_model(data)
apex_amp_output = apex_amp_model(data)
torch_amp_output = torch_amp_model(**data)
apex_amp_output = apex_amp_model(**data)
assert_close_loose(torch_amp_output, apex_amp_output)
for torch_amp_param, apex_amp_param in zip(torch_amp_model.parameters(), apex_amp_model.parameters()):

View File

@ -1,10 +1,11 @@
import pytest
import torch
import colossalai
from colossalai.legacy.amp import AMP_TYPE
from colossalai.legacy.core import global_context as gpc
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from tests.components_to_test.registry import non_distributed_component_funcs
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
CONFIG = dict(
parallel=dict(pipeline=dict(size=1), tensor=dict(size=1, mode=None)), fp16=dict(mode=None), clip_grad_norm=1.0
@ -15,29 +16,29 @@ CONFIG = dict(
@parameterize("amp_mode", [AMP_TYPE.APEX, AMP_TYPE.TORCH, AMP_TYPE.NAIVE, None])
def run_train(model_name, amp_mode):
# FIXME: test bert
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
train_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
gpc.config.fp16["mode"] = amp_mode
model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
model = model_builder(checkpoint=False)
model = model_builder()
engine, train_dataloader, *args = colossalai.legacy.initialize(
model=model,
optimizer=optimizer_class(model.parameters(), lr=1e-3),
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3),
criterion=criterion,
train_dataloader=train_dataloader,
)
try:
engine.train()
for data, label in train_dataloader:
for data in train_dataloader:
engine.zero_grad()
data = data.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
if criterion:
output = engine(data)
loss = engine.criterion(output, label)
output = engine(**data)
loss = engine.criterion(output)
else:
loss = engine(data, label)
loss = engine(**data)
engine.backward(loss)
engine.step()
break

View File

@ -5,9 +5,9 @@ import colossalai
from colossalai.legacy.amp.amp_type import AMP_TYPE
from colossalai.legacy.trainer import Trainer
from colossalai.logging import get_dist_logger
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import MultiTimer
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
BATCH_SIZE = 4
IMG_SIZE = 32
@ -16,12 +16,14 @@ NUM_EPOCHS = 200
CONFIG = dict(fp16=dict(mode=AMP_TYPE.TORCH))
@parameterize("model_name", ["repeated_computed_layers", "resnet18", "nested_model"])
@parameterize("model_name", ["custom_repeated_computed_layers", "torchvision_resnet18", "custom_nested_model"])
def run_trainer(model_name):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model = model_builder()
optimizer = optimizer_class(model.parameters(), lr=1e-3)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
train_dataloader = DummyDataloader(data_gen_fn)
test_dataloader = DummyDataloader(data_gen_fn)
criterion = lambda x: x.sum()
engine, train_dataloader, *_ = colossalai.legacy.initialize(
model=model, optimizer=optimizer, criterion=criterion, train_dataloader=train_dataloader
)

View File

@ -2,7 +2,7 @@ import torch
from colossalai.nn.optimizer import CPUAdam, HybridAdam
from colossalai.testing import clear_cache_before_run, parameterize
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def move_some_params_to_cuda(model, torch_model):
@ -22,8 +22,7 @@ def check_params_equal(model, torch_model):
@parameterize("nvme_offload_dir", ["./offload", None])
@parameterize("adam_cls", [CPUAdam, HybridAdam])
def test_nvme_adam(nvme_offload_fraction, nvme_offload_dir, adam_cls):
get_components_func = non_distributed_component_funcs.get_callable("simple_net")
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, *_ = next(iter(model_zoo.get_sub_registry("custom_simple_net").values()))
model = model_builder()
torch_model = model_builder()
move_some_params_to_cuda(model, torch_model)

View File

@ -12,8 +12,7 @@ from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gather", [False, True])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("master_weights", [False, True])
def exam_gpt_fwd_bwd(
@ -49,17 +48,22 @@ def exam_gpt_fwd_bwd(
master_weights: bool = True,
):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42)
model = model_builder(use_grad_checkpoint)
model = model_builder()
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
torch_model.gradient_checkpointing_enable()
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
@ -77,25 +81,22 @@ def exam_gpt_fwd_bwd(
torch_model = DDP(torch_model, device_ids=[rank])
set_seed(rank)
for i, (input_ids, label) in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
torch_optim.zero_grad()
zero_optim.zero_grad()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
# set random seed is same as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
torch_optim.zero_grad()
zero_optim.zero_grad()
assert torch.equal(torch_loss, loss)
# set random seed is same as torch_model.eval()
set_seed(42)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
set_seed(42)
loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
check_grad(model, torch_model)
assert_close(torch_loss.float(), loss.float())
check_grad(model, torch_model)
def run_dist(rank, world_size, port):

View File

@ -3,38 +3,34 @@ import torch
import torch.distributed as dist
import colossalai
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
# run gemini use the runtime memory tracer
@parameterize("placement_policy", ["auto"])
@parameterize("keep_gather", [False])
@parameterize("model_name", ["repeated_computed_layers", "bert", "albert", "gpt2"])
@parameterize("model_name", ["transformers_bert_for_sequence_classification"])
@parameterize("use_grad_checkpoint", [False, True])
def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_checkpoint: bool = False):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model = model_builder(use_grad_checkpoint).cuda()
model = model_builder().cuda()
if use_grad_checkpoint:
model.gradient_checkpointing_enable()
print(f"model_name {model_name}")
runtime_mem_tracer = RuntimeMemTracer(model)
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 0:
break
input_ids, label = input_ids.cuda(), label.cuda()
# mem tracing
if i == 0:
run_fwd_bwd(runtime_mem_tracer, input_ids, label, criterion, runtime_mem_tracer)
runtime_mem_tracer = RuntimeMemTracer(model)
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
memstats = runtime_mem_tracer.memstats()
runtime_tracer_non_model_data = runtime_mem_tracer._memstats._non_model_data_cuda_list
print("runtime tracer non model data points: ", len(runtime_tracer_non_model_data))
@ -62,16 +58,17 @@ def run_gemini_use_rmt(placement_policy, keep_gather, model_name: str, use_grad_
)
set_seed(dist.get_rank())
for i, (input_ids, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
# you can only test a single fwd + bwd.
# after bwd param is grad for Gemini, due to the chunk reuse optimization.
# print(f'iteration {i}')
if i > 4:
break
input_ids, label = input_ids.cuda(), label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42)
run_fwd_bwd(model, input_ids, label, criterion, model)
run_fwd_bwd(model, data, output_transform_fn, optimizer=model)
gemini_non_model_data = model.gemini_manager._mem_stats_collector._memstats.non_model_data_list("cuda")

View File

@ -7,13 +7,12 @@ from torch.testing import assert_close
import colossalai
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -38,7 +37,7 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
# Compare gradients.
for p0, p1 in zip(model.parameters(), torch_model.parameters()):
assert_close(p0, p1.grad, rtol=1e-3, atol=5e-5)
assert_close(p0, p1.grad, rtol=2e-3, atol=2e-2)
# Release gradient chunks and move them to gradient device.
for grad_chunk, device in zip(grad_chunk_list, device_list):
@ -48,21 +47,19 @@ def check_grad(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [False, True])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("use_grad_checkpoint", [False, True])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [False, True])
def exam_gemini_grad_acc(
placement_config, keep_gathered: bool, model_name: str, use_grad_checkpoint: bool, master_weights: bool
):
def exam_gemini_grad_acc(placement_config, keep_gathered: bool, model_name: str, master_weights: bool):
init_device = get_current_device()
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
set_seed(42)
gemini_model = model_builder(use_grad_checkpoint)
gemini_model = model_builder()
set_seed(42)
torch_model = model_builder(use_grad_checkpoint).cuda()
torch_model = model_builder().cuda()
for torch_p, p in zip(torch_model.parameters(), gemini_model.parameters()):
torch_p.data.copy_(p.data)
@ -94,22 +91,23 @@ def exam_gemini_grad_acc(
set_seed(rank)
accum_iter = 4
for i, (input_ids, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
delay_unscale = False if (i + 1) % accum_iter == 0 else True
input_ids, label = input_ids.cuda(), label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
set_seed(42 + rank)
torch_loss = run_fwd(torch_model, input_ids, label, criterion)
torch_loss = run_fwd(torch_model, data, output_transform_fn, loss_fn)
torch_loss = torch_loss / accum_iter
with amp.scale_loss(torch_loss, torch_optim, delay_unscale=delay_unscale) as scaled_loss:
scaled_loss.backward()
set_seed(42 + rank)
gemini_loss = run_fwd(gemini_model, input_ids, label, criterion)
gemini_loss = run_fwd(gemini_model, data, output_transform_fn, loss_fn)
gemini_loss = gemini_loss / accum_iter
gemini_optim.backward(gemini_loss)
assert torch.allclose(torch_loss, gemini_loss, rtol=1e-3, atol=1e-5)
assert torch.allclose(torch_loss.float(), gemini_loss.float(), rtol=1e-3, atol=1e-5)
check_grad(gemini_model, torch_model)

View File

@ -7,12 +7,11 @@ from torch.testing import assert_close
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{
@ -51,12 +50,13 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("master_weights", [True, False])
def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
set_seed(1912)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=32)
@ -94,21 +94,17 @@ def exam_grad_clipping(placement_config, model_name: str, master_weights: bool):
torch_model.train()
set_seed(dist.get_rank() * 3 + 128)
for i, (data, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2:
break
data = data.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, data, label, criterion, torch_optim)
loss = run_fwd_bwd(model, data, label, criterion, zero_optim)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss)
run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
import apex.amp as apex_amp

View File

@ -9,13 +9,12 @@ from torch.testing import assert_close
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd, run_fwd_bwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -53,12 +52,11 @@ def single_chunk_init(model: torch.nn.Module, placement_config: dict):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2"])
@parameterize("model_name", ["transformers_gpt_lm"])
@parameterize("model_init_func", [single_chunk_init, multi_chunk_init])
def exam_inference(placement_config: dict, model_name: str, model_init_func: Callable):
set_seed(19360226)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=128)
@ -79,29 +77,27 @@ def exam_inference(placement_config: dict, model_name: str, model_init_func: Cal
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
train_dataloader = iter(train_dataloader)
train_dataloader = iter(DummyDataloader(data_gen_fn))
def train_iter():
input_ids, label = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda()
data = next(train_dataloader)
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss, rtol=1e-5, atol=1e-5)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, data, output_transform_fn, optimizer=zero_optim)
assert_close(torch_loss.float(), loss.float(), rtol=1e-5, atol=1e-5)
zero_optim.step()
torch_optim.step()
check_param(model, torch_model)
def inference_iter():
input_ids, label = next(train_dataloader)
input_ids, label = input_ids.cuda(), label.cuda()
data = next(train_dataloader)
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
with torch.no_grad():
torch_output = torch_model(input_ids)
torch_loss = criterion(torch_output.float(), label)
zero_output = model(input_ids)
zero_loss = criterion(zero_output.float(), label)
assert_close(torch_loss, zero_loss)
torch_loss = run_fwd(torch_model, data, output_transform_fn)
zero_loss = run_fwd(model, data, output_transform_fn)
assert_close(torch_loss.float(), zero_loss.float(), rtol=1e-5, atol=1e-5)
train_iter()
inference_iter()

View File

@ -1,20 +1,18 @@
import pytest
import torch
import torch.distributed as dist
from packaging.version import Version
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.testing import assert_close
import colossalai
from colossalai.legacy.amp import convert_to_apex_amp
from colossalai.nn.optimizer import HybridAdam
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.testing import DummyDataloader, parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.utils.cuda import get_current_device
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
@ -32,14 +30,17 @@ PLACEMENT_CONFIGS = [
]
# this model is large enough to slice to chunks
TEST_MODELS = ["gpt2"]
TEST_MODELS = ["transformers_gpt_lm"]
# these models are too small, all parameters in these models are compacted into one chunk
EXAMPLE_MODELS = ["albert", "beit", "bert", "hanging_param_model", "nested_model", "repeated_computed_layers"]
EXAMPLE_MODELS = [
"transformers_bert_for_sequence_classification",
"custom_hanging_param_model",
"custom_nested_model",
"custom_repeated_computed_layers",
]
# bfloat16 cannot represent them exactly
BF16_IGNORED_KEYS = [
"albert.embeddings.word_embeddings.weight",
"albert.embeddings.position_embeddings.weight",
"masked_bias",
]
@ -55,7 +56,7 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
temp_zero_value = zero_dict[key].to(device=value.device)
if dtype is torch.bfloat16 and any(k in key for k in BF16_IGNORED_KEYS):
continue
rtol, atol = 1e-3, 4e-3
rtol, atol = 2e-3, 6e-3
if dtype is torch.bfloat16:
rtol, atol = 4e-3, 8e-3
# debug_print([0], "max range: ", key, torch.max(torch.abs(value - temp_zero_value)))
@ -74,8 +75,9 @@ def check_param(model: GeminiDDP, torch_model: torch.nn.Module, dtype: torch.dty
@parameterize("master_weights", [True, False])
def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dtype, master_weights: bool):
set_seed(42)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda()
# apex no master weights leads to nan, so we don't use it
@ -104,19 +106,20 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1e-4, 1e-5
for i, (input_ids, label) in enumerate(train_dataloader):
rtol, atol = 4e-2, 4e-2
train_dataloader = iter(DummyDataloader(data_gen_fn))
for i, data in enumerate(train_dataloader):
if i > 2:
break
input_ids, label = input_ids.cuda(), label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
torch_loss = run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
loss = run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
# as no master weights leads to error accumulation, we don't check the loss
if master_weights:
assert_close(torch_loss, loss, rtol=rtol, atol=atol)
assert_close(torch_loss.float(), loss.float(), rtol=rtol, atol=atol)
zero_optim.step()
torch_optim.step()
@ -125,13 +128,14 @@ def exam_model_step(placement_config, model_name: str, mixed_precision: torch.dt
check_param(model, torch_model, mixed_precision)
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("placement_config", [PLACEMENT_CONFIGS[3]])
@parameterize("model_name", EXAMPLE_MODELS)
@parameterize("mixed_precision", [torch.half, torch.bfloat16])
@parameterize("mixed_precision", [torch.half])
def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.dtype):
set_seed(2008)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, loss_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
torch_model = model_builder().cuda()
amp_config = dict(opt_level="O2", keep_batchnorm_fp32=False, loss_scale=2)
@ -159,26 +163,19 @@ def exam_tiny_example(placement_config, model_name: str, mixed_precision: torch.
torch_model.eval()
set_seed(dist.get_rank() * 3 + 128)
rtol, atol = 1.5e-6, 2e-5
if mixed_precision is torch.bfloat16:
rtol, atol = 2e-3, 2e-3
elif Version(torch.__version__) >= Version("2.0.0"):
rtol, atol = 4e-5, 3e-5
for i, (input_ids, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 2:
break
input_ids = input_ids.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
zero_optim.zero_grad()
torch_optim.zero_grad()
torch_loss = run_fwd_bwd(torch_model, input_ids, label, criterion, torch_optim)
loss = run_fwd_bwd(model, input_ids, label, criterion, zero_optim)
assert_close(torch_loss, loss, rtol=rtol, atol=atol) # atol should be 2e-5 for torch lower than 1.12
run_fwd_bwd(torch_model, data, output_transform_fn, loss_fn, optimizer=torch_optim)
run_fwd_bwd(model, data, output_transform_fn, loss_fn, optimizer=zero_optim)
zero_optim.step()
torch_optim.step()

View File

@ -4,10 +4,9 @@ import numpy as np
import pytest
import torch
from colossalai.testing import clear_cache_before_run
from colossalai.testing import DummyDataloader, clear_cache_before_run
from colossalai.zero.gemini.memory_tracer.runtime_mem_tracer import RuntimeMemTracer
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo, run_fwd_bwd
@pytest.mark.skip("this is not used")
@ -16,21 +15,22 @@ def test_runtime_mem_tracer():
test_models = ["gpt2", "bert", "simple_net", "repeated_computed_layers", "nested_model", "albert"]
for model_name in test_models:
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, _, _, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry(model_name).values())
)
model = model_builder(checkpoint=False).cuda()
model = model_builder().cuda()
model_bk = deepcopy(model)
runtime_mem_tracer = RuntimeMemTracer(model)
for i, (data, label) in enumerate(train_dataloader):
train_dataloader = DummyDataloader(data_gen_fn)
for i, data in enumerate(train_dataloader):
if i > 1:
break
data = data.cuda()
label = label.cuda()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
run_fwd_bwd(runtime_mem_tracer, data, label, criterion, optimizer=runtime_mem_tracer)
run_fwd_bwd(runtime_mem_tracer, data, output_transform_fn, optimizer=runtime_mem_tracer)
for p1, p2 in zip(model_bk.parameters(), model.parameters()):
torch.allclose(p1.to(torch.half), p2)

View File

@ -5,40 +5,37 @@ import colossalai
from colossalai.testing import rerun_if_address_is_in_use, spawn
from colossalai.utils import get_current_device
from colossalai.zero.gemini.chunk import init_chunk_manager, search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
def exam_search_chunk_size():
world_size = torch.distributed.get_world_size()
get_components_func = non_distributed_component_funcs.get_callable("gpt2")
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
# make sure torch_model and model has the same parameter values
model = model_builder()
config_dict, *_ = search_chunk_configuration(
model, search_range_m=1, search_interval=16, min_chunk_size_m=0, filter_exlarge_params=True
model, search_range_m=1, search_interval=128, min_chunk_size_m=0, filter_exlarge_params=True
)
for key in config_dict:
chunk_size = config_dict[key]["chunk_size"]
if world_size == 1 or True:
assert chunk_size == 31616
else:
assert chunk_size == 1024
assert chunk_size == 527872
def exam_chunk_manager():
world_size = torch.distributed.get_world_size()
get_components_func = non_distributed_component_funcs.get_callable("gpt2")
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
sharded_ddp_model = model_builder()
chunk_manager = init_chunk_manager(
sharded_ddp_model,
get_current_device(),
hidden_dim=16,
hidden_dim=128,
search_range_m=1,
min_chunk_size_m=0,
filter_exlarge_params=True,
@ -46,7 +43,7 @@ def exam_chunk_manager():
)
config_dict = chunk_manager.dp_degree_chunk_size_dict
assert len(config_dict) == 1
assert config_dict[world_size] == 31616
assert config_dict[world_size] == 527872
def run_dist(rank, world_size, port):

View File

@ -7,7 +7,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0}, # zero2
@ -26,15 +26,16 @@ def ignore_the_first_parameter(model: torch.nn.Module):
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("model_name", ["transformers_gpt_lm", "transformers_bert_for_sequence_classification"])
@parameterize("master_weights", [False, True])
def exam_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(iter(model_zoo.get_sub_registry(model_name).values()))
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
torch_model = model_builder()
for torch_p, p in zip(torch_model.parameters(), model.parameters()):
torch_p.data.copy_(p.data)
@ -54,29 +55,7 @@ def exam_state_dict(placement_config, keep_gathered, model_name: str, master_wei
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("keep_gathered", [True, False])
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_load_state_dict(placement_config, keep_gathered, model_name: str, master_weights: bool):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
set_seed(451)
torch_model = model_builder() # get a different model
world_size = torch.distributed.get_world_size()
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
config_dict[world_size]["chunk_size"] = 5000
config_dict[world_size]["keep_gathered"] = keep_gathered
model = GeminiDDP(model, config_dict, **placement_config, pin_memory=True, master_weights=master_weights)
torch_dict = torch_model.state_dict()
# check load state dict
model.load_state_dict(torch_dict, strict=False)
zero_dict = model.state_dict(only_rank_0=False)
@ -85,23 +64,7 @@ def exam_load_state_dict(placement_config, keep_gathered, model_name: str, maste
temp_zero_value = zero_dict[key].to(device=value.device, dtype=value.dtype)
assert_close(value, temp_zero_value, rtol=1e-3, atol=1e-5)
@parameterize("placement_config", PLACEMENT_CONFIGS)
@parameterize("model_name", ["gpt2", "bert"])
@parameterize("master_weights", [False, True])
def exam_state_dict_shard(placement_config, model_name: str, master_weights: bool):
get_components_func = non_distributed_component_funcs.get_callable(model_name)
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model = model_builder()
model_size = sum(p.numel() * p.element_size() for p in model.parameters()) / 1024**2
config_dict, *_ = search_chunk_configuration(model, search_range_m=1, search_interval=100)
model = GeminiDDP(model, config_dict, **placement_config, master_weights=master_weights)
model.train()
zero_dict = model.state_dict(only_rank_0=False)
# check state dict shard
accumulated_keys = set()
# ensure number of shards > 1
for shard, _ in model.state_dict_shard(max_shard_size=(model_size / 3), only_rank_0=False):
@ -116,8 +79,6 @@ def run_dist(rank, world_size, port):
config = {}
colossalai.launch(config=config, rank=rank, world_size=world_size, host="localhost", port=port, backend="nccl")
exam_state_dict()
exam_load_state_dict()
exam_state_dict_shard()
@pytest.mark.dist

View File

@ -8,7 +8,7 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
from colossalai.utils import set_seed
from colossalai.zero import GeminiDDP, GeminiOptimizer
from colossalai.zero.gemini.chunk import search_chunk_configuration
from tests.components_to_test.registry import non_distributed_component_funcs
from tests.kit.model_zoo import model_zoo
PLACEMENT_CONFIGS = [
{"placement_policy": "static", "shard_param_frac": 0.0, "offload_optim_frac": 0.0}, # zero2
@ -22,8 +22,9 @@ PLACEMENT_CONFIGS = [
@parameterize("keep_gathered", [True, False])
def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(431)
get_components_func = non_distributed_component_funcs.get_callable("gpt2")
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
model_builder, data_gen_fn, output_transform_fn, *_ = next(
iter(model_zoo.get_sub_registry("transformers_gpt_lm").values())
)
model = model_builder()
@ -41,15 +42,15 @@ def exam_zero_optim_state_dict(placement_config, keep_gathered):
set_seed(dist.get_rank() * 3 + 128)
model.train()
for i, (input_ids, label) in enumerate(train_dataloader):
if i > 0:
break
optim.zero_grad()
logits = model(input_ids)
logits = logits.float()
loss = criterion(logits, input_ids)
optim.backward(loss)
optim.step()
data = data_gen_fn()
data = {k: v.cuda() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
optim.zero_grad()
outputs = model(**data)
outputs = output_transform_fn(outputs)
loss = next(iter(outputs.values())).sum()
optim.backward(loss)
optim.step()
optim_state_dict = optim.state_dict()
optim.load_state_dict(optim_state_dict)