[zero] new interface for ShardedOptimv2 (#406)

pull/413/head
Jiarui Fang 2022-03-14 20:48:41 +08:00 committed by GitHub
parent a9c27be42e
commit 370f567e7d
9 changed files with 51 additions and 35 deletions


@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Dict, Optional
+from typing import Callable, Dict, Optional, Union
 import torch
 import torch.distributed as dist
@@ -15,7+15,7 @@ from torch import Tensor
 from torch.distributed import ProcessGroup
 from torch.nn.parameter import Parameter
 from torch.optim import Optimizer
+from typing import Type, Any
 from ._utils import has_inf_or_nan
@@ -27,8 +27,8 @@ class OptimState(Enum):
 class ShardedOptimizerV2(ColossalaiOptimizer):
     def __init__(self,
-                 optimizer: Optimizer,
                  sharded_model: ShardedModelV2,
+                 optimizer_class: Type[Optimizer],
                  shard_strategy: BaseShardStrategy,
                  cpu_offload: bool = False,
                  initial_scale: float = 2**32,
@@ -39,9 +39,34 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                  hysteresis: float = 2,
                  max_scale: int = 2**32,
                  dp_process_group: Optional[ProcessGroup] = None,
-                 mp_process_group: Optional[ProcessGroup] = None) -> None:
+                 mp_process_group: Optional[ProcessGroup] = None,
+                 **defaults: Any) -> None:
"""
:param sharded_model: A sharded model initialized by class ShardedModelV2
:type sharded_model: sharded_model
:param optimizer_class: A type of Optimizer
:type optimizer_class: Type[Optimizer]
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:param cpu_offload: is offloading the optimizer states to CPU.
:type cpu_offload: bool
:param shard_strategy: The strategy to shard the sharded_model and optimizer model parameters.
:type shard_strategy: BaseShardStrategy
:**defaults: any trailing arguments, which are forwarded to the local optimizer.
:type defaults: dict()
"""
         assert isinstance(sharded_model, ShardedModelV2), 'model must be wrapped with ShardedModel'
-        super().__init__(optimizer)
+        self._optim_defaults = defaults
+        # initialize the M, V as zeros tensors and initialize param fp32 from sharded_model.parameters()
+        self.optimizer = optimizer_class(sharded_model.parameters(), **self._optim_defaults)
+        super().__init__(self.optimizer)
         self.shard_strategy = shard_strategy
         self.model: ShardedModelV2 = sharded_model
         if cpu_offload and not sharded_model.cpu_offload:
@@ -65,7 +90,7 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
         # Store fp32 param shards
         self.master_params: Dict[Parameter, Tensor] = {}
-        for group in optimizer.param_groups:
+        for group in self.optimizer.param_groups:
             for p in group['params']:
                 assert hasattr(p, 'col_attr'), 'The parameter must be wrapped with ShardedParam'
                 is_param_sharded = p.col_attr.data.is_sharded
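In short, the constructor contract changes: the caller no longer builds the optimizer and passes the instance in; it passes the sharded model first, then an optimizer class, and any trailing keyword arguments are forwarded when ShardedOptimizerV2 builds the local optimizer from sharded_model.parameters(). A minimal sketch of the new call pattern, assuming colossalai.zero import paths not shown in this diff and an illustrative nn.Linear model:

import torch
import torch.nn as nn
from colossalai.zero.shard_utils import TensorShardStrategy
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2

# assumes a distributed context has already been set up (e.g. via colossalai.launch)
model = nn.Linear(16, 16).cuda()          # any torch.nn.Module, illustrative only
shard_strategy = TensorShardStrategy()
zero_model = ShardedModelV2(model, shard_strategy)

# Old interface: ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model, shard_strategy, ...)
# New interface: pass the sharded model and the optimizer *class*; trailing kwargs
# (lr here) are forwarded to the local optimizer built inside ShardedOptimizerV2.
sharded_optim = ShardedOptimizerV2(zero_model,
                                   torch.optim.Adam,
                                   shard_strategy,
                                   cpu_offload=False,
                                   lr=1e-3)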


@@ -74,8 +74,5 @@ def get_training_components():
                                       sequence_length=sequence_length,
                                       is_distrbuted=True)
-    def get_optim(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = None
-    return bert_model_builder, trainloader, testloader, get_optim, criterion
+    return bert_model_builder, trainloader, testloader, torch.optim.Adam, criterion
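This and the three component helpers below all make the same change: get_training_components() now returns the optimizer class itself (torch.optim.Adam) instead of a per-model optim_builder closure, leaving construction and hyperparameters to the caller. A hypothetical caller-side sketch (the real call sites are updated in the test hunks further down):

# Hypothetical consumer of get_training_components() after this change:
model_builder, trainloader, testloader, optimizer_class, criterion = get_training_components()
model = model_builder()
# before: optimizer = optim_builder(model)   # lr was fixed inside the helper
optimizer = optimizer_class(model.parameters(), lr=1e-3)   # caller now picks the lr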


@@ -49,8 +49,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion


@@ -43,8 +43,5 @@ def get_training_components():
     trainloader = DummyDataLoader()
     testloader = DummyDataLoader()
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion


@@ -29,8 +29,5 @@ def get_resnet_training_components():
     trainloader = get_cifar10_dataloader(train=True)
     testloader = get_cifar10_dataloader(train=False)
-    def optim_builder(model):
-        return torch.optim.Adam(model.parameters(), lr=0.001)
     criterion = torch.nn.CrossEntropyLoss()
-    return model_builder, trainloader, testloader, optim_builder, criterion
+    return model_builder, trainloader, testloader, torch.optim.Adam, criterion


@@ -19,11 +19,11 @@ def run_train():
     # FIXME: test bert
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model_builder, train_dataloader, _, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, _, optimizer_class, criterion = get_components_func()
         model = model_builder(checkpoint=False)
         engine, train_dataloader, *args = colossalai.initialize(model=model,
-                                                                optimizer=optimizer_builder(model),
+                                                                optimizer=optimizer_class(model.parameters(), lr=1e-3),
                                                                 criterion=criterion,
                                                                 train_dataloader=train_dataloader)
@@ -84,7 +84,7 @@ def run_engine(rank, world_size, port):
 @pytest.mark.dist
 def test_engine():
-    world_size = 4
+    world_size = 2
     run_func = partial(run_engine, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)


@@ -25,9 +25,9 @@ def run_trainer_no_pipeline(rank, world_size, port):
     test_models = ['repeated_computed_layers', 'resnet18', 'nested_model']
     for name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(name)
-        model_builder, train_dataloader, test_dataloader, optimizer_builder, criterion = get_components_func()
+        model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model_builder()
-        optimizer = optimizer_builder(model)
+        optimizer = optimizer_class(model.parameters(), lr=1e-3)
         engine, train_dataloader, *_ = colossalai.initialize(model=model,
                                                              optimizer=optimizer,
                                                              criterion=criterion,


@@ -44,19 +44,21 @@ def run_dist(rank, world_size, port, cpu_offload, shard_strategy):
     shard_strategy = shard_strategy()
     for model_name in test_models:
         get_components_func = non_distributed_component_funcs.get_callable(model_name)
-        model, train_dataloader, test_dataloader, optimizer, criterion = get_components_func()
+        model, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
         model = model(checkpoint=True).cuda()
         zero_model = ShardedModelV2(copy.deepcopy(model),
                                     shard_strategy,
                                     offload_config=dict(device='cpu') if cpu_offload else None)
         if dist.get_world_size() > 1:
             model = DDP(model)
-        optim = Adam(model.parameters(), lr=1e-3)
-        sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3),
-                                           zero_model,
+        lr = 1e-3
+        optim = optimizer_class(model.parameters(), lr=lr)
+        sharded_optim = ShardedOptimizerV2(zero_model,
+                                           optimizer_class,
                                            shard_strategy,
                                            cpu_offload=cpu_offload,
-                                           initial_scale=2**5)
+                                           initial_scale=2**5,
+                                           lr=lr)
         for i, (data, label) in enumerate(train_dataloader):
             if i > 2:
                 break


@@ -59,11 +59,12 @@ def run_dist(rank, world_size, port, shard_strategy):
     if dist.get_world_size() > 1:
         model = DDP(model)
     optim = Adam(model.parameters(), lr=1e-3)
-    sharded_optim = ShardedOptimizerV2(CPUAdam(zero_model.parameters(), lr=1e-3),
-                                       zero_model,
+    sharded_optim = ShardedOptimizerV2(zero_model,
+                                       CPUAdam,
                                        shard_strategy,
                                        initial_scale=2**5,
-                                       cpu_offload=True)
+                                       cpu_offload=True,
+                                       lr=1e-3)
     for i, (data, label) in enumerate(train_dataloader):
         if i > 2:
             break