mirror of https://github.com/hpcaitech/ColossalAI
impl shard optim v2 and add unit test
parent 74f77e314b
commit 001ca624dd
@@ -1,12 +1,16 @@
import torch
import torch.distributed as dist
from colossalai.registry import OPHOOKS

from . import BaseOpHook


@OPHOOKS.register_module
class ShardParamHook(BaseOpHook):
    """
    A hook to process sharded params before and after the FWD and BWD operators execute.
    """

    def __init__(self):
        super().__init__()
@@ -17,25 +21,32 @@ class ShardParamHook(BaseOpHook):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre fwd shape {param.ca_attr.payload("cpu").shape}')

    def post_fwd_exec(self, module: torch.nn.Module, *args):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post fwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_bwd_exec(self, module: torch.nn.Module, input, output):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.gather()
            if dist.get_rank() == 0:
                print(f'{param._name} pre bwd shape {param.ca_attr.payload("cpu").shape}')

    def post_bwd_exec(self, module: torch.nn.Module, input):
        for param in module.parameters():
            assert hasattr(param, 'ca_attr')
            param.ca_attr.shard()
            if dist.get_rank() == 0:
                print(f'{param._name} post bwd shape {param.ca_attr.payload("cpu").shape}')

    def pre_iter(self):
        pass

    def post_iter(self):
        pass
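The hook treats each parameter's ca_attr as a small shard manager: gather() materializes the full tensor right before a forward or backward operator touches the module, and shard() drops everything but the local slice once the operator is done. Below is a minimal, self-contained sketch of that lifecycle for illustration only; ToyShardedTensor and its gather/shard signatures are hypothetical stand-ins, not ColossalAI APIs (the real hook gathers across ranks with torch.distributed collectives).

import torch

class ToyShardedTensor:
    """Hypothetical stand-in for ca_attr: only this rank's shard stays resident at rest."""

    def __init__(self, full: torch.Tensor, rank: int, world_size: int):
        self.origin_shape = full.shape
        self.rank = rank
        self.world_size = world_size
        self.payload = full.flatten().chunk(world_size)[rank].clone()
        self.is_sharded = True

    def gather(self, all_shards):
        # Before FWD/BWD: rebuild the full tensor (the real hook does this with collectives).
        self.is_sharded = False
        return torch.cat(list(all_shards)).reshape(self.origin_shape)

    def shard(self):
        # After FWD/BWD: keep only the local slice so the full tensor can be freed.
        self.is_sharded = True
        return self.payload


full = torch.arange(8.)
attrs = [ToyShardedTensor(full, r, 2) for r in range(2)]
rebuilt = attrs[0].gather([a.payload for a in attrs])   # full param visible to the operator
assert torch.equal(rebuilt, full)
attrs[0].shard()
assert attrs[0].is_sharded and attrs[0].payload.numel() == full.numel() // 2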
@@ -1,3 +1,4 @@
from .sharded_optim import ShardedOptimizer
from .sharded_optim_v2 import ShardedOptimizerV2

__all__ = ['ShardedOptimizer', 'ShardedOptimizerV2']
@@ -14,6 +14,7 @@ from torch.distributed import ProcessGroup
from torch.nn.parameter import Parameter
from torch.optim import Optimizer

from ..sharded_model._zero3_utils import free_storage
from ._utils import has_inf_or_nan
@@ -62,6 +63,8 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
                if hasattr(p, 'ca_attr'):
                    assert p.ca_attr.is_sharded, 'ShardedAdam can only be used with a sharded model'
                    self.master_params[p] = p.ca_attr.payload(self.device)
                    if dist.get_rank() == 0:
                        print(f'load payload {p._name} {self.master_params[p].shape}')
                else:
                    self.master_params[p] = p.data.to(device=self.device)
                if torch.is_floating_point(self.master_params[p]) and self.master_params[p].dtype != torch.float:
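This initialization is the usual mixed-precision bookkeeping: every parameter gets an fp32 master copy on self.device, taken from the shard payload when the param is sharded and from p.data otherwise, and any half-precision tensor is upcast so the update math runs in full precision. A generic sketch of that pattern, with the hypothetical helper build_master_params standing in for the constructor logic rather than reproducing it:

import torch


def build_master_params(params, device='cpu'):
    # Hypothetical helper: keep an fp32 "master" copy of every (possibly fp16) parameter.
    master = {}
    for p in params:
        master[p] = p.data.to(device=device)
        if torch.is_floating_point(master[p]) and master[p].dtype != torch.float:
            master[p] = master[p].float()   # upcast fp16/bf16 so updates run in full precision
    return master


model = torch.nn.Linear(4, 4).half()
masters = build_master_params(model.parameters())
assert all(t.dtype == torch.float32 for t in masters.values())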
@@ -84,19 +87,27 @@ class ShardedOptimizerV2(ColossalaiOptimizer):
            for p in group['params']:
                p.data = self.master_params[p]
        ret = self.optim.step(*args, **kwargs)
        # Write master param to payload
        for group in self.optim.param_groups:
            for p in group['params']:
                if hasattr(p, 'ca_attr'):
                    if dist.get_rank() == 0:
                        print(f'write {p._name} {p.shape} orig_shape {p.ca_attr._origin_shape} \
                            payload shape {p.ca_attr._param_payload.shape} sharded {p.ca_attr.is_sharded}')
                    p.ca_attr.set_payload(p.data)
                    # We cannot set p.data to None directly, so we free its storage
                    free_storage(p.data)
        return ret

    def backward(self, loss: Tensor) -> None:
        loss = self.loss_scale * loss
        self.optim_state = OptimState.SCALED
        if self.model_is_sharded:
            if dist.get_rank() == 0:
                print('sharded model backward')
            self.model.backward(loss)
            if dist.get_rank() == 0:
                print('sharded model backward done')
        else:
            super().backward(loss)
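step() therefore runs in three phases: point p.data at the fp32 master copy, let the wrapped optimizer update it, then write the result back into the shard payload with set_payload() and free the temporary full tensor (p.data cannot simply be set to None, hence free_storage). A condensed sketch under hypothetical names; ToyAttr mimics just enough of ca_attr to show the data flow, and the last step only re-points p.data at an empty tensor instead of resizing its storage as the real free_storage helper does:

import torch
from torch.optim import SGD


class ToyAttr:
    # Hypothetical stand-in for ca_attr: owns the persistent (possibly sharded) payload.
    def __init__(self, payload):
        self._param_payload = payload

    def set_payload(self, tensor):
        self._param_payload.copy_(tensor.view_as(self._param_payload))


p = torch.nn.Parameter(torch.zeros(4))
p.ca_attr = ToyAttr(torch.zeros(4))
master = {p: torch.randn(4)}                  # fp32 master copy held by the optimizer wrapper
inner_optim = SGD([p], lr=0.1)

p.grad = torch.ones(4)
p.data = master[p]                            # 1) hand the wrapped optimizer the master copy
inner_optim.step()
p.ca_attr.set_payload(p.data)                 # 2) write the updated values back into the payload
p.data = torch.empty(0)                       # 3) release the full tensor (free_storage in the real code)
assert torch.equal(p.ca_attr._param_payload, master[p])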
@@ -0,0 +1,68 @@
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

import copy
from functools import partial

import colossalai
import pytest
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
from colossalai.context.parallel_mode import ParallelMode
from colossalai.core import global_context as gpc
from colossalai.utils import free_port
from colossalai.zero.sharded_model import ShardedModelV2
from colossalai.zero.sharded_optim import ShardedOptimizerV2
from torch.optim import Adam

from common import (CONFIG, Net, check_grads, check_grads_padding, check_params, check_params_padding)


def run_step(model, optimizer, x, enable_autocast=False):
    model.train()
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        y = model(x)
        loss = y.sum()
    loss = loss.float()
    if isinstance(model, ShardedModelV2):
        optimizer.backward(loss)
        for p in model.parameters():
            assert p.ca_attr.is_sharded
    else:
        loss.backward()
    optimizer.step()


def run_dist(rank, world_size, port):
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    model = Net(checkpoint=True).cuda()
    zero_model = copy.deepcopy(model)
    zero_model = ShardedModelV2(zero_model, process_group=gpc.get_group(ParallelMode.DATA))
    for n, p in zero_model.named_parameters():
        p._name = n
    optim = Adam(model.parameters(), lr=1e-3)
    sharded_optim = ShardedOptimizerV2(Adam(zero_model.parameters(), lr=1e-3), zero_model)

    for _ in range(2):
        x = torch.rand(2, 5).cuda()
        run_step(zero_model, sharded_optim, x, False)
        run_step(model, optim, x, False)
        if dist.get_world_size() > 1:
            check_grads_padding(model, zero_model)
            check_params_padding(model, zero_model)
        else:
            check_grads(model, zero_model)
            check_params(model, zero_model)


@pytest.mark.dist
def test_sharded_optim_v2():
    world_size = 2
    run_func = partial(run_dist, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_sharded_optim_v2()