mirror of https://github.com/hpcaitech/ColossalAI
[Tensor] add module handler for linear (#1021)
* add module spec for linear * polish * polish * polishpull/1030/head
parent
ee50497db2
commit
32291dd73f
|
@ -8,8 +8,12 @@ from ._ops import *
|
|||
from .optim.colo_optimizer import ColoOptimizer
|
||||
from . import distspec
|
||||
from .dist_spec_mgr import DistSpecManager
|
||||
from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
|
||||
from .modules import ColoLinear
|
||||
|
||||
__all__ = [
|
||||
'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
|
||||
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager'
|
||||
'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
|
||||
'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module',
|
||||
'ColoLinear'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Dict
|
||||
from colossalai.tensor import ColoParameter, ParallelAction, TensorSpec
|
||||
from .modules import ColoModule
|
||||
import torch
|
||||
|
||||
_COLOSSAL_MODULES: Dict[type, ColoModule] = {}
|
||||
|
||||
|
||||
def register_colo_module(module_type: type, colo_module: ColoModule):
|
||||
global _COLOSSAL_MODULES
|
||||
_COLOSSAL_MODULES[module_type] = colo_module
|
||||
|
||||
def is_colo_module(module: torch.nn.Module):
|
||||
global _COLOSSAL_MODULES
|
||||
return type(module) in _COLOSSAL_MODULES
|
||||
|
||||
def get_colo_module(module: torch.nn.Module):
|
||||
global _COLOSSAL_MODULES
|
||||
if is_colo_module(module):
|
||||
colo_module = _COLOSSAL_MODULES[type(module)]
|
||||
colo_module.register()
|
||||
return colo_module
|
||||
else:
|
||||
return None
|
||||
|
||||
def check_colo_module(module: torch.nn.Module, recursive=True):
|
||||
if is_colo_module(module):
|
||||
colo_module = get_colo_module(module)
|
||||
param_names = colo_module.get_param_names()
|
||||
compute_pattern = None
|
||||
for param_name in param_names:
|
||||
param = module.get_parameter(param_name)
|
||||
if not isinstance(param, ColoParameter):
|
||||
raise Exception(f'Invalid ColoParameter spec: {param} in {module} is not a ColoParameter.')
|
||||
if param.has_spec():
|
||||
cur_compute_pattern = param.spec.parallel_action.compute_pattern
|
||||
if compute_pattern is None:
|
||||
compute_pattern = cur_compute_pattern
|
||||
else:
|
||||
if cur_compute_pattern != compute_pattern:
|
||||
raise Exception(f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
|
||||
else:
|
||||
continue
|
||||
|
||||
if compute_pattern is not None:
|
||||
if not colo_module.has_compute_pattern(compute_pattern):
|
||||
raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
|
||||
|
||||
match_specs = False
|
||||
allowed_specs = colo_module.get_dist_specs(compute_pattern)
|
||||
for _, param_specs in allowed_specs.items():
|
||||
cur_match = True
|
||||
for param_name, dist_spec in param_specs.items():
|
||||
param = module.get_parameter(param_name)
|
||||
if param.has_spec():
|
||||
if dist_spec != param.spec.dist_spec:
|
||||
cur_match = False
|
||||
break
|
||||
else:
|
||||
if dist_spec is not None:
|
||||
cur_match = False
|
||||
break
|
||||
if cur_match == True:
|
||||
match_specs = True
|
||||
break
|
||||
if match_specs == False:
|
||||
raise Exception(f'Invalid ColoParameter spec: Params in {module} are incorrectly sharded.')
|
||||
|
||||
if recursive == True:
|
||||
for submodule in module.children():
|
||||
check_colo_module(submodule, recursive=True)
|
||||
|
||||
def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, label='default'):
|
||||
compute_pattern = parallel_action.compute_pattern
|
||||
if is_colo_module(module):
|
||||
# for each param
|
||||
# set DistSpec and ParallelAction
|
||||
colo_module = get_colo_module(module)
|
||||
if not colo_module.has_compute_pattern_with_label(compute_pattern, label=label):
|
||||
raise NotImplementedError
|
||||
for param_name, dist_spec in colo_module.get_dist_specs_with_label(compute_pattern, label=label).items():
|
||||
if dist_spec is None:
|
||||
continue
|
||||
param = module.get_parameter(param_name)
|
||||
if isinstance(param, ColoParameter):
|
||||
spec = TensorSpec(dist_spec, parallel_action)
|
||||
param.set_spec(spec)
|
||||
check_colo_module(module, recursive=False)
|
||||
if recursive == True:
|
||||
for submodule in module.children():
|
||||
init_colo_module(submodule, parallel_action, recursive=True, label=label)
|
||||
|
|
@ -0,0 +1,2 @@
|
|||
from .colo_module import ColoModule
|
||||
from .linear import ColoLinear
|
|
@ -0,0 +1,51 @@
|
|||
from colossalai.tensor.distspec import _DistSpec
|
||||
from colossalai.tensor import ComputePattern
|
||||
from typing import List, Dict
|
||||
|
||||
|
||||
class ColoModule(object):
|
||||
def __init__(self):
|
||||
self._shard_params: List[str] = []
|
||||
# Example:
|
||||
# {ComputePattern.TP1D:
|
||||
# 'default':
|
||||
# 'weight':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'bias':
|
||||
# distspec.shard(xxxxx)
|
||||
# 'row': ...
|
||||
# 'col': ...
|
||||
# }
|
||||
self._allowed_patterns: Dict[ComputePattern, Dict[str, Dict[str, _DistSpec]]] = {}
|
||||
|
||||
def _register_shard_params(self, params: List[str]):
|
||||
self._shard_params = params
|
||||
|
||||
def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], label='default'):
|
||||
assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
|
||||
if not compute_pattern in self._allowed_patterns:
|
||||
self._allowed_patterns[compute_pattern] = {}
|
||||
self._allowed_patterns[compute_pattern][label] = dist_specs
|
||||
|
||||
def _set_default(self, compute_pattern: ComputePattern, target_label):
|
||||
self._allowed_patterns[compute_pattern]['default'] = self._allowed_patterns[compute_pattern][target_label]
|
||||
|
||||
def has_compute_pattern(self, compute_pattern: ComputePattern):
|
||||
return compute_pattern in self._allowed_patterns
|
||||
|
||||
def get_dist_specs(self, compute_pattern: ComputePattern):
|
||||
assert self.has_compute_pattern(compute_pattern)
|
||||
return self._allowed_patterns[compute_pattern]
|
||||
|
||||
def has_compute_pattern_with_label(self, compute_pattern: ComputePattern, label='default'):
|
||||
return compute_pattern in self._allowed_patterns and label in self._allowed_patterns[compute_pattern]
|
||||
|
||||
def get_dist_specs_with_label(self, compute_pattern: ComputePattern, label='default'):
|
||||
assert self.has_compute_pattern_with_label(compute_pattern, label)
|
||||
return self._allowed_patterns[compute_pattern][label]
|
||||
|
||||
def get_param_names(self):
|
||||
return self._shard_params
|
||||
|
||||
def register(self):
|
||||
raise NotImplementedError
|
|
@ -0,0 +1,39 @@
|
|||
from .colo_module import ColoModule
|
||||
from colossalai.tensor import ComputePattern, distspec
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
|
||||
class ColoLinear(ColoModule):
|
||||
def __init__(self):
|
||||
super(ColoLinear, self).__init__()
|
||||
self._register_shard_params(['weight', 'bias'])
|
||||
self._register = False
|
||||
|
||||
def register(self):
|
||||
if self._register == False:
|
||||
self._set_TP1D()
|
||||
self._register = True
|
||||
|
||||
def _set_TP1D(self):
|
||||
# TP1D Row Linear
|
||||
_compute_pattern = ComputePattern.TP1D
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
'bias': None
|
||||
},
|
||||
label='row',
|
||||
)
|
||||
|
||||
# TP1D Col Linear
|
||||
self._register_allowed_patterns(
|
||||
compute_pattern=_compute_pattern,
|
||||
dist_specs={
|
||||
'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
|
||||
'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
|
||||
},
|
||||
label='col',
|
||||
)
|
||||
|
||||
self._set_default(compute_pattern=_compute_pattern, target_label='row')
|
|
@ -1,6 +1,7 @@
|
|||
from .utils import InsertPostInitMethodToModuleSubClasses
|
||||
import torch
|
||||
from colossalai.tensor import ColoTensor, ColoParameter
|
||||
from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
|
||||
ColoLinear
|
||||
import types
|
||||
|
||||
from torch import nn
|
||||
|
@ -101,6 +102,17 @@ def _setattr_with_colotensor(self, name: str, value: Union[torch.Tensor, torch.n
|
|||
else:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
def _get_parameter_with_colotensor(self, target: str) -> Union[torch.nn.Parameter, ColoTensor]:
|
||||
module_path, _, param_name = target.rpartition(".")
|
||||
|
||||
mod: torch.nn.Module = self.get_submodule(module_path)
|
||||
|
||||
if not hasattr(mod, param_name):
|
||||
raise AttributeError(mod._get_name() + " has no attribute `"
|
||||
+ param_name + "`")
|
||||
|
||||
param = getattr(mod, param_name)
|
||||
return param
|
||||
|
||||
def ColoModulize(module):
|
||||
"""
|
||||
|
@ -124,6 +136,8 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
|||
|
||||
torch.nn.Module.__setattr__ = _setattr_with_colotensor
|
||||
torch.nn.Module.register_parameter = _register_parameter_with_colotensor
|
||||
torch.nn.Module.get_parameter = _get_parameter_with_colotensor
|
||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||
|
||||
def _post_init_method(self, module: torch.nn.Module, *args, **kwargs):
|
||||
"""
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
from copy import copy
|
||||
from colossalai.utils.cuda import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
import torch
|
||||
from colossalai.context.parallel_mode import ParallelMode
|
||||
from colossalai.tensor import ColoTensor, distspec
|
||||
|
||||
from functools import partial
|
||||
|
||||
import colossalai
|
||||
import pytest
|
||||
import torch
|
||||
import torch.multiprocessing as mp
|
||||
import torch.nn.functional as F
|
||||
from colossalai.testing import rerun_if_address_is_in_use
|
||||
from colossalai.utils import free_port
|
||||
from colossalai.core import global_context as gpc
|
||||
from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, ColoLinear
|
||||
from _utils import tensor_equal, tensor_shard_equal, set_seed
|
||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||
|
||||
def run_simplenet_with_spec(label):
|
||||
get_components_func = non_distributed_component_funcs.get_callable('simple_net')
|
||||
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||
rank = gpc.get_local_rank(ParallelMode.PARALLEL_1D)
|
||||
|
||||
set_seed(1)
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = model_builder(checkpoint=True)
|
||||
|
||||
if rank == 0:
|
||||
model_seq = model_builder(checkpoint=True)
|
||||
model_seq = model_seq.cuda()
|
||||
|
||||
# Make two models have the same init params
|
||||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
p2.data.copy_(p1.data)
|
||||
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, label=label)
|
||||
|
||||
model = model.cuda()
|
||||
for i, (data, label) in enumerate(train_dataloader):
|
||||
data = data.to(get_current_device())
|
||||
label = label.to(get_current_device())
|
||||
|
||||
torch.distributed.broadcast(data, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
torch.distributed.broadcast(label, 0, group=gpc.get_group(ParallelMode.PARALLEL_1D))
|
||||
|
||||
if criterion:
|
||||
output = model(data)
|
||||
loss = criterion(output, label)
|
||||
else:
|
||||
output = model(data, label)
|
||||
loss = output
|
||||
|
||||
# For reference
|
||||
if rank == 0:
|
||||
if criterion:
|
||||
output_seq = model_seq(data)
|
||||
loss_seq = criterion(output_seq, label)
|
||||
else:
|
||||
output_seq = model_seq(data, label)
|
||||
loss_seq = output_seq
|
||||
|
||||
if rank == 0:
|
||||
with torch.no_grad():
|
||||
assert torch.allclose(loss, loss_seq, rtol=1e-2)
|
||||
|
||||
loss.backward()
|
||||
|
||||
if rank == 0:
|
||||
loss_seq.backward()
|
||||
|
||||
with torch.no_grad():
|
||||
# check param
|
||||
for p1, p2 in zip(model.parameters(), model_seq.parameters()):
|
||||
if p1.size() == p2.size():
|
||||
assert torch.allclose(p1, p2)
|
||||
else:
|
||||
if p1.size(-1) < p2.size(-1): # col
|
||||
world_size = p2.size(-1) // p1.size(-1)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=-1)[0]
|
||||
|
||||
elif p1.size(0) < p2.size(0): # row
|
||||
world_size = p2.size(0) // p1.size(0)
|
||||
split_p2 = torch.chunk(p2, world_size, dim=0)[0]
|
||||
|
||||
assert torch.allclose(p1, split_p2)
|
||||
|
||||
if i > 3:
|
||||
break
|
||||
|
||||
def run_linear_with_spec(label):
|
||||
with ColoInitContext(device=get_current_device()):
|
||||
model = torch.nn.Linear(4, 8)
|
||||
|
||||
model_handy = copy(model)
|
||||
|
||||
parallel_action = ParallelAction(ComputePattern.TP1D)
|
||||
init_colo_module(model, parallel_action, recursive=True, label=label)
|
||||
|
||||
x = torch.rand(2, 4).cuda()
|
||||
out = model(x)
|
||||
colo_out = model_handy(x)
|
||||
assert tensor_equal(out, colo_out)
|
||||
grad = torch.rand_like(out)
|
||||
out.backward(grad)
|
||||
colo_out.backward(grad)
|
||||
assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
|
||||
assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
|
||||
|
||||
|
||||
def run_dist(rank, world_size, port, func):
|
||||
config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
|
||||
colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
func('col')
|
||||
func('row')
|
||||
func('default')
|
||||
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_module_linear_1d(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_linear_with_spec)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
@pytest.mark.dist
|
||||
@pytest.mark.parametrize('world_size', [1, 4])
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_module_simplenet(world_size):
|
||||
run_func = partial(run_dist, world_size=world_size, port=free_port(), func=run_simplenet_with_spec)
|
||||
mp.spawn(run_func, nprocs=world_size)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_module_simplenet(4)
|
Loading…
Reference in New Issue