mirror of https://github.com/hpcaitech/ColossalAI

reorgnize colotensor directory (#1062)

* reorgnize colotensor directory
* polish code

parent 3d10be33bd
commit a00644079e
@@ -4,3 +4,7 @@ from .lr_scheduler import *
 from .metric import *
 from .model import *
 from .optimizer import *
+from ._ops import *
+
+from .modules import ColoLinear, ColoEmbedding
+from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
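With this hunk `colossalai.nn` re-exports the tensor-parallel building blocks, so downstream code can pull everything NN-related from one place. A minimal sketch of the new import surface (names exactly as re-exported above):

    from colossalai.nn import (register_colo_module, is_colo_module, get_colo_module,
                               init_colo_module, check_colo_module, ColoLinear, ColoEmbedding)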
@@ -10,6 +10,7 @@ def register_colo_module(module_type: type, colo_module: ColoModule):
     global _COLOSSAL_MODULES
     _COLOSSAL_MODULES[module_type] = colo_module
 
+
 def is_colo_module(module: torch.nn.Module):
     global _COLOSSAL_MODULES
     for module_type in _COLOSSAL_MODULES.keys():
@@ -17,6 +18,7 @@ def is_colo_module(module: torch.nn.Module):
             return True
     return False
 
+
 def get_colo_module(module: torch.nn.Module):
     global _COLOSSAL_MODULES
     if is_colo_module(module):
@@ -26,6 +28,7 @@ def get_colo_module(module: torch.nn.Module):
     else:
         return None
 
+
 def check_colo_module(module: torch.nn.Module, recursive=True):
     if is_colo_module(module):
         colo_module = get_colo_module(module)
@@ -41,14 +44,16 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
                     compute_pattern = cur_compute_pattern
                 else:
                     if cur_compute_pattern != compute_pattern:
-                        raise Exception(f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
+                        raise Exception(
+                            f'Invalid ColoParameter spec: Params in {module} have different compute_pattern.')
             else:
                 continue
 
        if compute_pattern is not None:
            colo_module.register(compute_pattern)
            if not colo_module.has_compute_pattern(compute_pattern):
-                raise Exception(f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
+                raise Exception(
+                    f'Invalid ColoParameter spec: ComputePattern {compute_pattern} in {module} is not allowed.')
 
            match_specs = False
            allowed_specs = colo_module.get_dist_specs(compute_pattern)
@@ -73,6 +78,7 @@ def check_colo_module(module: torch.nn.Module, recursive=True):
         for submodule in module.children():
             check_colo_module(submodule, recursive=True)
 
+
 def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, recursive=True, mode='default'):
     compute_pattern = parallel_action.compute_pattern
     if is_colo_module(module):
@@ -99,4 +105,3 @@ def init_colo_module(module: torch.nn.Module, parallel_action: ParallelAction, r
     if recursive == True:
         for submodule in module.children():
             init_colo_module(submodule, parallel_action, recursive=True, mode=mode)
-
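These helpers keep a process-wide registry mapping a torch module type to its ColoModule description. A hedged usage sketch, assuming `is_colo_module` matches a module instance against the registered types; the registration call is hypothetical, but the signatures come from the hunks above:

    import torch
    from colossalai.nn import register_colo_module, is_colo_module, get_colo_module, ColoLinear

    register_colo_module(torch.nn.Linear, ColoLinear())
    layer = torch.nn.Linear(4, 8)
    assert is_colo_module(layer)      # True once torch.nn.Linear is registered
    desc = get_colo_module(layer)     # the ColoLinear description, or None if unregistered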
@@ -4,6 +4,7 @@ from typing import List, Dict
 
+
 class ColoModule(object):
 
     def __init__(self):
         self._shard_params: List[str] = []
         # Example:
@@ -21,8 +22,12 @@ class ColoModule(object):
     def _register_shard_params(self, params: List[str]):
         self._shard_params = params
 
-    def _register_allowed_patterns(self, compute_pattern: ComputePattern, dist_specs: Dict[str, _DistSpec], mode='default'):
-        assert list(dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
+    def _register_allowed_patterns(self,
+                                   compute_pattern: ComputePattern,
+                                   dist_specs: Dict[str, _DistSpec],
+                                   mode='default'):
+        assert list(
+            dist_specs.keys()).sort() == self._shard_params.sort(), 'Every registered param should have dist_spec.'
         if not compute_pattern in self._allowed_patterns:
             self._allowed_patterns[compute_pattern] = {}
         self._allowed_patterns[compute_pattern][mode] = dist_specs
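ColoModule is the per-layer description: `_register_shard_params` names the shardable parameters, and `_register_allowed_patterns` records, per compute pattern and mode, which dist spec each of those parameters may take. One caveat worth flagging: `list(dist_specs.keys()).sort()` and `self._shard_params.sort()` both return None, so the assert above compares None == None and can never fire; `sorted(dist_specs.keys()) == sorted(self._shard_params)` would actually enforce "every registered param should have dist_spec". A hedged sketch of a subclass, modeled on ColoEmbedding below (the import path is an assumption, not confirmed by this diff):

    from colossalai.nn.modules.colo_module import ColoModule   # hypothetical path

    class ColoMyLayer(ColoModule):

        def __init__(self):
            super().__init__()
            self._register_shard_params(['weight'])   # then register allowed patterns per mode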
@@ -3,7 +3,9 @@ from colossalai.tensor import ComputePattern, distspec
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
+
+
 class ColoEmbedding(ColoModule):
 
     def __init__(self):
         super(ColoEmbedding, self).__init__()
         self._register_shard_params(['weight'])
@@ -19,7 +21,9 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
             },
             mode='row',
         )
@@ -28,7 +32,9 @@ class ColoEmbedding(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
             },
             mode='col',
         )
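For an embedding whose weight is (num_embeddings, embedding_dim), mode='row' shards dim 0 (the vocabulary) and mode='col' shards dim -1 (the embedding dim), each across the 1D parallel group. A single-process illustration of what those dist specs mean, assuming `distspec.shard(group, dims, num_partitions)` chunks the tensor along `dims`:

    import torch

    weight = torch.randn(20, 4)                    # (num_embeddings=20, embedding_dim=4)
    row_shards = torch.chunk(weight, 2, dim=0)     # mode='row': each of 2 ranks holds (10, 4)
    col_shards = torch.chunk(weight, 2, dim=-1)    # mode='col': each of 2 ranks holds (20, 2)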
@@ -3,7 +3,9 @@ from colossalai.tensor import ComputePattern, distspec
 from colossalai.core import global_context as gpc
 from colossalai.context.parallel_mode import ParallelMode
+
+
 class ColoLinear(ColoModule):
 
     def __init__(self):
         super(ColoLinear, self).__init__()
         self._register_shard_params(['weight', 'bias'])
@@ -19,8 +21,11 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': None
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'bias':
+                    None
             },
             mode='row',
         )
@@ -29,8 +34,12 @@ class ColoLinear(ColoModule):
         self._register_allowed_patterns(
             compute_pattern=_compute_pattern,
             dist_specs={
-                'weight': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
-                'bias': distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
+                'weight':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+                'bias':
+                    distspec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0],
+                                   [gpc.get_world_size(ParallelMode.PARALLEL_1D)])
             },
             mode='col',
         )
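The bias treatment differs between the two modes because torch stores a Linear weight as (out_features, in_features). mode='row' shards the input dim (-1): each rank produces a partial product that must be summed, so the bias is added once and stays replicated ('bias': None). mode='col' shards the output dim (0): rank outputs are concatenated, so the bias shards along dim 0 as well. A single-process numeric check of that reasoning:

    import torch

    x, w, b = torch.randn(3, 4), torch.randn(8, 4), torch.randn(8)

    # mode='row': partial products over input shards sum to the full output
    partial = sum(xi @ wi.t() for xi, wi in zip(torch.chunk(x, 2, dim=-1), torch.chunk(w, 2, dim=-1)))
    assert torch.allclose(partial + b, x @ w.t() + b, atol=1e-5)

    # mode='col': output shards concatenate, so the bias shards the same way
    cols = [x @ wi.t() + bi for wi, bi in zip(torch.chunk(w, 2, dim=0), torch.chunk(b, 2, dim=0))]
    assert torch.allclose(torch.cat(cols, dim=-1), x @ w.t() + b, atol=1e-5)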
@@ -7,7 +7,9 @@ from .lamb import Lamb
 from .lars import Lars
 from .cpu_adam import CPUAdam
 from .hybrid_adam import HybridAdam
+from .colo_optimizer import ColoOptimizer
 
 __all__ = [
-    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam', 'CPU_ADAM_CNT'
+    'ColossalaiOptimizer', 'FusedLAMB', 'FusedAdam', 'FusedSGD', 'Lamb', 'Lars', 'CPUAdam', 'HybridAdam',
+    'CPU_ADAM_CNT', 'ColoOptimizer'
 ]
@@ -1,21 +1,14 @@
 from .spec import ComputePattern, ParallelAction, TensorSpec
-from .op_wrapper import (
-    colo_op_impl,)
 from .colo_tensor import ColoTensor
 from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
-from ._ops import *
-from .optim.colo_optimizer import ColoOptimizer
 from . import distspec
 from .dist_spec_mgr import DistSpecManager
 from .param_op_hook import ParamOpHook, use_param_op_hooks
 from .chunk import ChunkManager, TensorState
-from .module_utils import register_colo_module, is_colo_module, get_colo_module, init_colo_module, check_colo_module
-from .modules import ColoLinear, ColoEmbedding
 
 __all__ = [
-    'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-    'named_params_with_colotensor', 'ColoOptimizer', 'ColoParameter', 'distspec', 'DistSpecManager',
-    'register_colo_module', 'is_colo_module', 'get_colo_module', 'init_colo_module', 'check_colo_module', 'ColoLinear',
-    'ColoEmbedding', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
+    'ColoTensor', 'convert_parameter', 'ComputePattern', 'TensorSpec', 'ParallelAction', 'named_params_with_colotensor',
+    'ColoParameter', 'distspec', 'DistSpecManager', 'ParamOpHook', 'use_param_op_hooks', 'ChunkManager', 'TensorState'
 ]
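The net effect of the reorg: NN-layer concerns (the module registry, the Colo modules, `_ops`, the optimizer) leave `colossalai.tensor`, so the tensor package no longer reaches into `colossalai.nn`. Call sites migrate like this (paths taken from the hunks in this commit):

    # before:
    #   from colossalai.tensor import register_colo_module, init_colo_module, \
    #       ColoLinear, ColoEmbedding, ColoOptimizer
    # after:
    from colossalai.nn import register_colo_module, init_colo_module, ColoLinear, ColoEmbedding
    from colossalai.nn.optimizer import ColoOptimizer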
@@ -1,9 +1,9 @@
-from .colo_tensor import ColoTensor
-from .const import TensorType
+from colossalai.tensor.colo_tensor import ColoTensor
+from colossalai.tensor.const import TensorType
 import torch
 from colossalai.tensor import TensorSpec, distspec
 from copy import copy
-from .param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
+from colossalai.tensor.param_op_hook import _ParamOpHookWrapper, PreFwdPostBwd, PostFwdPreBwd
 from typing import Optional
 
 
@@ -1,11 +1,29 @@
 from colossalai.tensor.distspec import _DistSpec
-from colossalai.nn.layer.utils import divide
+# from colossalai.nn.layer.utils import divide
 from numpy import prod
 from contextlib import contextmanager
 import torch
 import torch.distributed as dist
 
 
+# TODO(jiaruifang) circle import, move the divide to colossalai.commons.
+# colossalai.tensor shall not import any submodule from colossal.nn
+def divide(numerator, denominator):
+    """Only allow exact division.
+
+    Args:
+        numerator (int): Numerator of the division.
+        denominator (int): Denominator of the division.
+
+    Returns:
+        int: the result of exact division.
+    """
+    assert denominator != 0, 'denominator can not be zero'
+    assert numerator % denominator == 0, \
+        '{} is not divisible by {}'.format(numerator, denominator)
+    return numerator // denominator
+
+
 class TransformDistSpec(torch.autograd.Function):
 
     @staticmethod
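`divide` is inlined here precisely because importing it from `colossalai.nn.layer.utils` would make `colossalai.tensor` depend on `colossalai.nn`, the circular dependency the TODO describes. Its behavior, as defined above:

    divide(12, 4)    # -> 3
    divide(12, 5)    # AssertionError: 12 is not divisible by 5
    divide(12, 0)    # AssertionError: denominator can not be zero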
@@ -1,10 +1,8 @@
 import torch
 
-from colossalai.tensor.colo_tensor import ColoTensor
-
 from typing import Iterator, Tuple, Union
 import torch.nn as nn
-from colossalai.tensor import ColoTensor
+from colossalai.tensor.colo_tensor import ColoTensor
 
 
 # The function is credited to PyTorch Team
@@ -1,11 +1,12 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor, ColoParameter, register_colo_module, init_colo_module, \
+from colossalai.tensor import ColoTensor, ColoParameter
+
+from colossalai.nn import register_colo_module, init_colo_module, \
     ColoLinear, ColoEmbedding
-import types
 
 from torch import nn
-from typing import Iterator, Tuple, Union, Optional
+from typing import Iterator, Tuple, Union
 
 # find named_params includes replica
 
@@ -24,6 +25,7 @@ def _named_params_with_replica(
         name = mod_prefix + ('.' if mod_prefix else '') + name
         yield name, val
 
+
 def ColoModulize(module):
     """
     Replacing the parameters() and named_parameters() with our customized ones
@@ -1,9 +1,12 @@
 from colossalai.utils import free_port, ColoInitContext, get_current_device
 from colossalai.testing import rerun_if_address_is_in_use
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, init_colo_module
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+
 from functools import partial
 from colossalai.core import global_context as gpc
 from colossalai.context import ParallelMode
+
+from colossalai.nn import init_colo_module
 from colossalai.nn.parallel import ColoDDP
 
 import colossalai
@@ -11,7 +14,9 @@ import torch
 import torch.multiprocessing as mp
 import pytest
+
+
 class Net(torch.nn.Module):
 
     def __init__(self):
         super(Net, self).__init__()
         self.embed = torch.nn.Embedding(20, 4)
@@ -27,6 +32,7 @@ class Net(torch.nn.Module):
         x = self.proj(x)
         return x
 
+
 def run_hybrid_device(use_ddp):
     with ColoInitContext(device=get_current_device()):
         model = Net()
@@ -36,7 +42,6 @@ def run_hybrid_device(use_ddp):
         model = ColoDDP(model)
         real_model = model.module
 
-
     print(f'embedding weight size: {real_model.embed.weight.size()} | device: {real_model.embed.weight.device}')
     #print(f'linear weight size: {real_model.proj.weight.size()} | device: {real_model.proj.weight.device}')
     parallel_action = ParallelAction(ComputePattern.TP1D)
@@ -54,6 +59,7 @@ def run_hybrid_device(use_ddp):
     out = model(data)
     out.sum().backward()
 
+
 def run_dist(rank, world_size, port, use_ddp):
     if use_ddp and world_size == 1:
         return
@@ -62,6 +68,7 @@ def run_dist(rank, world_size, port, use_ddp):
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_hybrid_device(use_ddp)
 
+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @pytest.mark.parametrize('use_ddp', [False, True])
@@ -71,5 +78,6 @@ def _test_hybrid_device(world_size, use_ddp):
     run_func = partial(run_dist, world_size=world_size, port=free_port(), use_ddp=use_ddp)
     mp.spawn(run_func, nprocs=world_size)
 
+
 if __name__ == '__main__':
     _test_hybrid_device(1, False)
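This test shows the intended end-to-end flow after the reorg: build the model under `ColoInitContext` so its parameters become ColoParameters, pick a compute pattern, then shard every registered submodule in one call. A condensed, hedged sketch (the mode value is an example; 'row', 'col', and 'default' are the modes registered above):

    with ColoInitContext(device=get_current_device()):
        model = Net()

    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode='col')   # recurses into embed and proj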
@@ -10,9 +10,10 @@ from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
 from colossalai.tensor import distspec, named_params_with_colotensor, TensorSpec, ComputePattern, \
-    ParallelAction, ColoTensor, ColoOptimizer, DistSpecManager
+    ParallelAction, ColoTensor, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
+from colossalai.nn.optimizer import ColoOptimizer
 from functools import partial
 from _utils import set_seed
 
 
@@ -1,24 +1,28 @@
 from copy import copy
-from colossalai.utils.cuda import get_current_device
-from colossalai.utils.model.colo_init_context import ColoInitContext
-import torch
-from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor import ColoTensor, distspec
+import pytest
 
 from functools import partial
 
-import colossalai
-import pytest
 import torch
 import torch.multiprocessing as mp
-import torch.nn.functional as F
+
+from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction
+from colossalai.nn import init_colo_module, check_colo_module
+from _utils import tensor_equal, tensor_shard_equal, set_seed
+
+import colossalai
+from colossalai.utils.cuda import get_current_device
+from colossalai.utils.model.colo_init_context import ColoInitContext
+
+from colossalai.context.parallel_mode import ParallelMode
+from colossalai.tensor import distspec
+
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
 from colossalai.core import global_context as gpc
-from colossalai.tensor import TensorSpec, ComputePattern, ParallelAction, DistSpecManager, register_colo_module, init_colo_module, check_colo_module
-from _utils import tensor_equal, tensor_shard_equal, set_seed
 from tests.components_to_test.registry import non_distributed_component_funcs
 
 
 def run_model_with_spec(mode, model_name):
     get_components_func = non_distributed_component_funcs.get_callable(model_name)
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -103,6 +107,7 @@ def run_model_with_spec(mode, model_name):
         if i > 3:
             break
 
+
 def run_linear_with_spec(mode):
     with ColoInitContext(device=get_current_device()):
         model = torch.nn.Linear(4, 8)
@@ -122,6 +127,7 @@ def run_linear_with_spec(mode):
     assert tensor_shard_equal(model.weight.grad, model_handy.weight.grad)
     assert tensor_shard_equal(model.bias.grad, model_handy.bias.grad)
 
+
 def run_check_shared_param():
     from transformers import BertForMaskedLM, BertConfig
     hidden_dim = 8
@@ -157,12 +163,14 @@ def run_check_shared_param():
     except Exception as e:
         assert 'incorrectly sharded' in str(e)
 
+
 def run_dist(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_linear_with_spec('col')
     run_linear_with_spec('row')
 
+
 def run_dist_model(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
@@ -170,11 +178,13 @@ def run_dist_model(rank, world_size, port):
         run_model_with_spec('col', model_name)
         run_model_with_spec('row', model_name)
 
+
 def run_dist_check(rank, world_size, port):
     config = dict(parallel=dict(tensor=dict(mode="1d", size=world_size),))
     colossalai.launch(config=config, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     run_check_shared_param()
 
+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
@@ -182,6 +192,7 @@ def test_module_linear_1d(world_size):
     run_func = partial(run_dist, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
@@ -189,6 +200,7 @@ def test_module_model(world_size):
     run_func = partial(run_dist_model, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 2])
 @rerun_if_address_is_in_use()
@@ -196,5 +208,6 @@ def test_module_check(world_size):
     run_func = partial(run_dist_check, world_size=world_size, port=free_port())
     mp.spawn(run_func, nprocs=world_size)
 
+
 if __name__ == '__main__':
     test_module_check(2)
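`check_colo_module` is the validation counterpart to `init_colo_module`: it re-derives the compute pattern from a module's parameters and raises if their dist specs match no registered pattern, which is what the shared-parameter BERT case above relies on ('incorrectly sharded'). A hedged sketch of the pair, assuming `model` was built under ColoInitContext as in these tests:

    from colossalai.nn import init_colo_module, check_colo_module
    from colossalai.tensor import ComputePattern, ParallelAction

    parallel_action = ParallelAction(ComputePattern.TP1D)
    init_colo_module(model, parallel_action, recursive=True, mode='row')
    check_colo_module(model, recursive=True)   # raises 'Invalid ColoParameter spec: ...' on mismatch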