mirror of https://github.com/hpcaitech/ColossalAI
[ColoTensor] ColoInitContext initialize parameters in shard mode. (#1937)
parent
b42b672842
commit
9f4fb3f28a
|
@ -1,4 +1,4 @@
|
||||||
from typing import Iterator, Tuple, Union
|
from typing import Dict, Iterator, Optional, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
@ -36,7 +36,10 @@ def ColoModulize(module):
|
||||||
|
|
||||||
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
|
|
||||||
def __init__(self, device: torch.device = torch.device('cpu'), dtype: torch.dtype = torch.float):
|
def __init__(self,
|
||||||
|
device: torch.device = torch.device('cpu'),
|
||||||
|
dtype: torch.dtype = torch.float,
|
||||||
|
default_shard_plan: Optional[Dict] = None):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
|
device (torch.device): the device where parameters initialized are resident. Defaults to torch.device('cpu').
|
||||||
|
@ -47,6 +50,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
|
|
||||||
self._register_colo_modules()
|
self._register_colo_modules()
|
||||||
|
self._default_shard_plan = default_shard_plan
|
||||||
|
|
||||||
def _register_colo_modules(self):
|
def _register_colo_modules(self):
|
||||||
register_colo_module(torch.nn.Linear, ColoLinear())
|
register_colo_module(torch.nn.Linear, ColoLinear())
|
||||||
|
@ -64,6 +68,10 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
if hasattr(module, '_colo_visited'):
|
if hasattr(module, '_colo_visited'):
|
||||||
return
|
return
|
||||||
|
|
||||||
|
if self._default_shard_plan is not None:
|
||||||
|
default_pg = self._default_shard_plan.get('pg', None)
|
||||||
|
default_shard_spec = self._default_shard_plan.get('shard_spec', None)
|
||||||
|
|
||||||
name_list = []
|
name_list = []
|
||||||
for name, param in _named_params_with_replica(module):
|
for name, param in _named_params_with_replica(module):
|
||||||
if isinstance(param, ColoTensor):
|
if isinstance(param, ColoTensor):
|
||||||
|
@ -91,7 +99,18 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
|
||||||
# TODO(jiaruifang) we initialize a Default PG memory
|
# TODO(jiaruifang) we initialize a Default PG memory
|
||||||
colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
|
colo_param = ColoParameter(param.to(device=self._device, dtype=self._dtype),
|
||||||
requires_grad=requires_grad)
|
requires_grad=requires_grad)
|
||||||
# add mapping record
|
|
||||||
|
# if default_shard_plan exists, shard the param during initialization.
|
||||||
|
# This can reduce the model size after initialization.
|
||||||
|
# NOTE() embedding usually can not be correctly sharded. So I use except to handle
|
||||||
|
# the param that can not be sharded by the default plan
|
||||||
|
if self._default_shard_plan is not None:
|
||||||
|
colo_param.set_process_group(default_pg)
|
||||||
|
try:
|
||||||
|
colo_param.set_dist_spec(default_shard_spec)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
replaced_tensors[param] = colo_param
|
replaced_tensors[param] = colo_param
|
||||||
delattr(submodule, param_name)
|
delattr(submodule, param_name)
|
||||||
setattr(submodule, param_name, colo_param)
|
setattr(submodule, param_name, colo_param)
|
||||||
|
|
|
@ -1,5 +1,66 @@
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
import colossalai
|
||||||
|
from colossalai.tensor import (
|
||||||
|
ColoParameter,
|
||||||
|
ColoTensorSpec,
|
||||||
|
ComputePattern,
|
||||||
|
ComputeSpec,
|
||||||
|
ProcessGroup,
|
||||||
|
ReplicaSpec,
|
||||||
|
ShardSpec,
|
||||||
|
)
|
||||||
|
from colossalai.testing import parameterize, rerun_if_address_is_in_use
|
||||||
|
from colossalai.utils import free_port
|
||||||
from colossalai.utils.cuda import get_current_device
|
from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
|
from tests.test_tensor.common_utils import set_seed
|
||||||
|
|
||||||
|
|
||||||
|
def run_colo_init_context(rank: int, world_size: int, port: int):
|
||||||
|
colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||||
|
|
||||||
|
# make sure seed of each process is the same, so the params are consistent among processes and the params are exactly replicated.
|
||||||
|
set_seed(42)
|
||||||
|
get_components_func = non_distributed_component_funcs.get_callable('gpt2')
|
||||||
|
model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
|
||||||
|
|
||||||
|
# keep parameters replicated during init
|
||||||
|
with ColoInitContext(device=get_current_device()):
|
||||||
|
model1 = model_builder()
|
||||||
|
|
||||||
|
# shard the parameters during init
|
||||||
|
set_seed(42)
|
||||||
|
shard_spec = ReplicaSpec()
|
||||||
|
# ShardSpec(dims=[0], num_partitions=[world_size])
|
||||||
|
default_shard_plan = {'pg': ProcessGroup(tp_degree=world_size), 'shard_spec': shard_spec}
|
||||||
|
with ColoInitContext(device=get_current_device(), default_shard_plan=default_shard_plan):
|
||||||
|
model2 = model_builder()
|
||||||
|
|
||||||
|
# reshard both models
|
||||||
|
new_shard = ShardSpec(dims=[-1], num_partitions=[world_size])
|
||||||
|
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||||
|
p1: ColoParameter = p1
|
||||||
|
p1.set_process_group(ProcessGroup(tp_degree=world_size))
|
||||||
|
p1.set_dist_spec(new_shard)
|
||||||
|
p2.set_dist_spec(new_shard)
|
||||||
|
|
||||||
|
for p1, p2 in zip(model1.parameters(), model2.parameters()):
|
||||||
|
assert (torch.allclose(p1, p2))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.dist
|
||||||
|
@pytest.mark.parametrize('world_size', [1, 4])
|
||||||
|
@rerun_if_address_is_in_use()
|
||||||
|
def test_colo_init_context(world_size):
|
||||||
|
run_func = partial(run_colo_init_context, world_size=world_size, port=free_port())
|
||||||
|
mp.spawn(run_func, nprocs=world_size)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
test_colo_init_context(2)
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
from functools import partial
|
from functools import partial
|
||||||
from lib2to3 import pgen2
|
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
|
|
@ -18,7 +18,7 @@ from colossalai.utils.cuda import get_current_device
|
||||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||||
from colossalai.zero import ZeroOptimizer
|
from colossalai.zero import ZeroOptimizer
|
||||||
from tests.components_to_test.registry import non_distributed_component_funcs
|
from tests.components_to_test.registry import non_distributed_component_funcs
|
||||||
from tests.test_tensor.common_utils import set_seed, tensor_equal, tensor_shard_equal
|
from tests.test_tensor.common_utils import set_seed, tensor_shard_equal
|
||||||
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
|
from tests.test_tensor.model.test_gpt2 import init_megatron_spec
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue