[Tensor] init ColoParameter (#914)

Jiarui Fang 2022-05-06 12:57:14 +08:00 committed by GitHub
parent 193d629311
commit ab95ec9aea
6 changed files with 77 additions and 44 deletions

colossalai/tensor/__init__.py

@@ -2,11 +2,12 @@ from .spec import ComputePattern, ParallelAction, TensorSpec, ShardPattern
 from .op_wrapper import (
     colo_op_impl,)
 from .colo_tensor import ColoTensor
+from .colo_parameter import ColoParameter
 from .utils import convert_parameter, named_params_with_colotensor
 from ._ops import *
 from .optim.colo_optimizer import ColoOptimizer

 __all__ = [
     'ColoTensor', 'convert_parameter', 'colo_op_impl', 'ComputePattern', 'TensorSpec', 'ParallelAction',
-    'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer'
+    'named_params_with_colotensor', 'ShardPattern', 'ColoOptimizer', 'ColoParameter'
 ]

colossalai/tensor/colo_parameter.py (new file)

@@ -0,0 +1,28 @@
+from .colo_tensor import ColoTensor
+from .const import TensorType
+
+import torch
+
+
+class ColoParameter(ColoTensor):
+    r"""A kind of ColoTensor to be considered as a module parameter.
+    """
+
+    def __init__(self, *args, **kargs):
+        super().__init__(*args, **kargs)
+        self._type = TensorType.MODEL
+
+    def __new__(cls, *args, **kwargs):
+        t = super(ColoParameter, cls).__new__(cls)
+        t._type = TensorType.MODEL
+        return t
+
+    @staticmethod
+    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoParameter':
+        colo_p = ColoParameter(*tensor.size(),
+                               dtype=tensor.dtype,
+                               requires_grad=tensor.requires_grad,
+                               pin_memory=tensor.is_pinned(),
+                               device=tensor.device,
+                               torch_tensor=tensor if save_payload else torch.empty(0))
+        return colo_p
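
For orientation, a minimal usage sketch of the new class (not part of the commit); the tensor shape and the assertion are illustrative only:

    import torch
    from colossalai.tensor import ColoParameter

    # wrap an existing torch.Tensor as model data; save_payload=True keeps the original payload
    weight = torch.randn(4, 4, requires_grad=True)
    colo_w = ColoParameter.init_from_torch_tensor(weight)
    assert colo_w.is_model_data()    # ColoParameter is always tagged TensorType.MODEL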

colossalai/tensor/colo_tensor.py

@@ -7,12 +7,7 @@ from colossalai.core import global_context as gpc
 from colossalai.nn.layer.utils import divide
 from colossalai.tensor import TensorSpec, ComputePattern, ShardPattern
 from colossalai.nn.layer.parallel_1d._utils import split_forward_gather_backward, gather_forward_split_backward
-from enum import Enum
-
-
-class TensorType(Enum):
-    MODEL = 0
-    NONMODEL = 1    # mainly activations
+from .const import TensorType


 class ColoTensor(object):
@@ -26,17 +21,14 @@ class ColoTensor(object):
     def __new__(cls, *args, **kwargs):
         return super(ColoTensor, cls).__new__(cls)

-    def __init__(
-            self,
-            *size: Tuple[int],
-            dtype=None,
-            requires_grad=False,
-            pin_memory=False,
-            device=None,
-            torch_tensor=torch.empty(0),
-            shard_spec: TensorSpec = TensorSpec(),
-            is_model_data: bool = False,
-    ):
+    def __init__(self,
+                 *size: Tuple[int],
+                 dtype=None,
+                 requires_grad=False,
+                 pin_memory=False,
+                 device=None,
+                 torch_tensor=torch.empty(0),
+                 shard_spec: TensorSpec = TensorSpec()):
         self._size = size
         self._dtype = dtype
         self._requires_grad = requires_grad
@@ -45,9 +37,6 @@ class ColoTensor(object):
         self._torch_tensor = torch_tensor
         self._shard_spec = shard_spec
         self._shard_pattern = ShardPattern.NA
-        if is_model_data:
-            self._type = TensorType.MODEL
-        else:
-            self._type = TensorType.NONMODEL
+        self._type = TensorType.NONMODEL

     def __getitem__(self, key):
@@ -97,14 +86,13 @@ class ColoTensor(object):
         return product(self._size)

     @staticmethod
-    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True, is_model_data=False) -> 'ColoTensor':
+    def init_from_torch_tensor(tensor: torch.Tensor, save_payload=True) -> 'ColoTensor':
         colo_t = ColoTensor(*tensor.size(),
                             dtype=tensor.dtype,
                             requires_grad=tensor.requires_grad,
                             pin_memory=tensor.is_pinned(),
                             device=tensor.device,
-                            torch_tensor=tensor if save_payload else torch.empty(0),
-                            is_model_data=is_model_data)
+                            torch_tensor=tensor if save_payload else torch.empty(0))
         return colo_t

     def del_torch_tensor(self, save_shape=False) -> None:
@@ -143,8 +131,7 @@ class ColoTensor(object):
             self.gather()
         # Model Parameters
        if self._shard_spec.num_action == 1:
-            parallel_action = self._shard_spec.get_action_by_compute_pattern(
-                self._shard_spec.compute_patterns[0])
+            parallel_action = self._shard_spec.get_action_by_compute_pattern(self._shard_spec.compute_patterns[0])
             if parallel_action.compute_pattern in [ComputePattern.TP1DRow_Linear, \
                     ComputePattern.TP1DCol_Embedding]:
                 self._shard_1d(parallel_action=parallel_action, dim=-1)
@@ -157,7 +144,7 @@ class ColoTensor(object):
             raise NotImplementedError

     def gather(self):
-        assert self.is_activation(), 'Currently we only support gather Activation ColoTensor.'
+        assert not self.is_model_data(), 'Currently we only support gather Activation ColoTensor.'
         assert not self.is_gathered(), 'Only sharded ColoTensor can be gathered.'
         parallel_action = self._shard_spec.get_action_by_compute_pattern(ComputePattern.DP)
         if self._shard_pattern == ShardPattern.Row:
@@ -174,8 +161,8 @@ class ColoTensor(object):
     def has_spec(self) -> bool:
         return self._shard_spec is not None and self._shard_spec.num_action > 0

-    def is_activation(self) -> bool:
-        return self._type == TensorType.NONMODEL
+    def is_model_data(self) -> bool:
+        return self._type == TensorType.MODEL

     def _shard_1d(self, parallel_action, dim=-1):
         num_partition = gpc.get_world_size(parallel_action.parallel_mode)
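
A hedged sketch of the renamed predicate (not from the commit; shapes are made up): an activation-type ColoTensor reports False, while a ColoParameter reports True and would therefore be rejected by gather():

    import torch
    from colossalai.tensor import ColoTensor, ColoParameter

    act = ColoTensor.init_from_torch_tensor(torch.randn(2, 2))     # defaults to TensorType.NONMODEL
    par = ColoParameter.init_from_torch_tensor(torch.randn(2, 2))  # tagged TensorType.MODEL

    assert not act.is_model_data()
    assert par.is_model_data()
    # par.gather() would now trip the `assert not self.is_model_data()` guard above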

colossalai/tensor/const.py (new file)

@@ -0,0 +1,6 @@
+from enum import Enum
+
+
+class TensorType(Enum):
+    MODEL = 0
+    NONMODEL = 1    # mainly activations

colossalai/utils/model/colo_init_context.py

@@ -1,6 +1,6 @@
 from .utils import InsertPostInitMethodToModuleSubClasses
 import torch
-from colossalai.tensor import ColoTensor
+from colossalai.tensor import ColoTensor, ColoParameter
 import types

 from torch import nn
@@ -100,10 +100,7 @@ class ColoInitContext(InsertPostInitMethodToModuleSubClasses):
             tensor_detached = param.to(self._device).detach()
             tensor_detached.requires_grad = requires_grad

-            setattr(
-                module, name,
-                ColoTensor.init_from_torch_tensor(tensor=tensor_detached,
-                                                  save_payload=save_torch_payload,
-                                                  is_model_data=True))
+            setattr(module, name,
+                    ColoParameter.init_from_torch_tensor(tensor=tensor_detached, save_payload=save_torch_payload))

         ColoModulize(module)
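
A usage sketch with caveats: the import path and the `device` keyword of ColoInitContext are assumptions based on the `self._device` attribute above, and calling `colo_named_parameters()` on a plain `nn.Linear` assumes `ColoModulize` attaches it to every module built inside the context (the tests below use it on the builder's model):

    import torch
    from torch import nn
    from colossalai.utils.model.colo_init_context import ColoInitContext

    # parameters of modules constructed inside the context are replaced by ColoParameter
    with ColoInitContext(device=torch.device('cpu')):
        model = nn.Linear(8, 8)

    for name, colo_p in model.colo_named_parameters():
        assert colo_p.is_model_data()    # every registered parameter is model data now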


@@ -38,17 +38,23 @@ def run_1d_col_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_row = TensorSpec(parallel_action_list_row)

     parallel_action_list_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DCol_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_col = TensorSpec(parallel_action_list_col)

     parallel_action_list_embedding_col = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DCol_Embedding,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)

@@ -125,6 +131,9 @@ def test_model_parameters():
         param_cnt += 1
     assert param_cnt == 5

+    for name, colo_p in model.colo_named_parameters():
+        assert colo_p.is_model_data()
+
     param_cnt = 0
     for name, p in model.named_parameters(recurse=False):
         param_cnt += 1
@@ -175,12 +184,16 @@ def run_1d_row_tp():
     model = model_builder(checkpoint=True)

     parallel_action_list = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Linear, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Linear,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec = TensorSpec(parallel_action_list)

     parallel_action_list_embedding_row = [
-        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow_Embedding, parallel_mode=ParallelMode.PARALLEL_1D)
+        ParallelAction(priority=1,
+                       compute_pattern=ComputePattern.TP1DRow_Embedding,
+                       parallel_mode=ParallelMode.PARALLEL_1D)
     ]
     spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)

@@ -243,6 +256,7 @@ def run_dist(rank, world_size, port):
     run_1d_row_tp()
     run_1d_col_tp()

+
 @pytest.mark.dist
 @parameterize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
@@ -252,6 +266,6 @@ def test_simple_net(world_size):


 if __name__ == '__main__':
-    test_simple_net()
-    # test_model_parameters()
+    # test_simple_net()
+    test_model_parameters()
     # test_colo_optimizer()