mirror of https://github.com/hpcaitech/ColossalAI
add DistSpec for loss and test_model (#947)
parent 67c33f57eb
commit 797a9dc5a9
@@ -1,6 +1,6 @@
 from .linear import colo_linear
 from .element_wise import *
 from .layernorm import colo_layernorm
-# from .loss import colo_cross_entropy
+from .loss import colo_cross_entropy
 from .embedding import colo_embedding
 from .addmm import colo_addmm
@@ -28,7 +28,7 @@ def colo_layernorm(types, args=(), kwargs=None, pg=None):
 
     if isinstance(input_tensor, ColoTensor):
         # TODO (ver217): check input dist spec
-        input_tensor.to_dist_spec(dist_spec.replicate())
+        input_tensor.to_dist_spec(dist_spec.replicate(input_tensor.spec.get_process_group()))
         input_tensor = input_tensor.torch_tensor()
     if isinstance(weight, ColoTensor):
         weight = weight.torch_tensor()
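Note: the change above passes the tensor's own process group into the replicate spec instead of calling dist_spec.replicate() with no argument. A minimal sketch of why that matters, using stand-in classes (the dataclass, the enum values 'r'/'s', and the replicate signature here are assumptions for illustration, not the library's real definitions):

    # Illustrative stand-ins, not ColossalAI's actual dist_spec module.
    from dataclasses import dataclass
    from enum import Enum
    from typing import Optional


    class DistPlacementPattern(Enum):
        REPLICATE = 'r'    # assumed value
        SHARD = 's'        # assumed value


    @dataclass
    class _DistSpec:
        placement: DistPlacementPattern
        process_group: Optional[object] = None


    def replicate(process_group: Optional[object] = None) -> _DistSpec:
        # replicate() with no argument produces a spec with no process group attached;
        # passing input_tensor.spec.get_process_group() keeps the tensor's original
        # group on the new replicated spec.
        return _DistSpec(DistPlacementPattern.REPLICATE, process_group)


    spec_default = replicate()                       # process_group is None
    spec_kept = replicate("tensor_parallel_group")   # group carried over from the input tensor's spec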
@@ -1,4 +1,4 @@
-from colossalai.tensor.spec import ShardPattern
+from colossalai.tensor.dist_spec import DistPlacementPattern
 import torch
 from colossalai.tensor.op_wrapper import colo_op_impl
 from colossalai.tensor import ColoTensor
@@ -27,12 +27,11 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
     if isinstance(target, ColoTensor):
         target = target.torch_tensor()
 
-    if input_tensor.is_gathered(): # Input is gathered
-        # TODO(jzy) Shall we make the result of loss function a ColoTensor?
+    if input_tensor.spec.is_gathered(): # Input is gathered
         return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
             input_tensor.torch_tensor(), target, weight))
-    elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1: # Single Model Parallel Applied
-        if input_tensor.shard_pattern == ShardPattern.Col:
+    elif input_tensor.has_spec() and input_tensor.spec.num_action == 1: # Single Model Parallel Applied
+        if input_tensor.spec.is_1Dcol():
             return ColoTensor.init_from_torch_tensor(
                 VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
         else:
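Note: for context, a condensed sketch of the dispatch colo_cross_entropy performs after this hunk. Only the branch conditions and the two return statements are taken from the diff; the import path for VocabParallelCrossEntropyLoss1D and the NotImplementedError fallbacks are assumptions.

    import torch
    from colossalai.tensor import ColoTensor
    # import path assumed for the sketch; the modified file already has this symbol in scope
    from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D


    def colo_cross_entropy_sketch(input_tensor: ColoTensor, target, weight=None):
        if isinstance(target, ColoTensor):
            target = target.torch_tensor()

        if input_tensor.spec.is_gathered():
            # input replicated on every rank: plain torch cross entropy
            return ColoTensor.init_from_torch_tensor(
                torch.nn.functional.cross_entropy(input_tensor.torch_tensor(), target, weight))
        elif input_tensor.has_spec() and input_tensor.spec.num_action == 1:
            if input_tensor.spec.is_1Dcol():
                # vocab dimension sharded across the 1-D tensor-parallel group
                return ColoTensor.init_from_torch_tensor(
                    VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
            raise NotImplementedError    # other shard patterns not covered in this sketch
        raise NotImplementedError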
@@ -53,7 +53,8 @@ class DistSpecManager:
 
     @staticmethod
     def _r2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group is not None and old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return tensor
 
@@ -65,7 +66,8 @@ class DistSpecManager:
 
     @staticmethod
     def _s2r(tensor: torch.Tensor, old_dist_spec: _DistSpec, dist_spec: _DistSpec) -> torch.Tensor:
-        if old_dist_spec.process_group != dist_spec.process_group:
+        if old_dist_spec.process_group != dist_spec.process_group \
+                and dist_spec.process_group is not None:
             raise NotImplementedError
         return DistSpecManager._gather(tensor, old_dist_spec)
 
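Note: both hunks add the same relaxation: a target spec whose process_group is None no longer triggers NotImplementedError, so converting to a group-less spec is treated as staying in the old group. A self-contained restatement of the two guards:

    def r2r_guard_raises(old_group, new_group) -> bool:
        # Condition under which _r2r raises after this change.
        return old_group is not None and old_group != new_group and new_group is not None


    def s2r_guard_raises(old_group, new_group) -> bool:
        # Condition under which _s2r raises after this change.
        return old_group != new_group and new_group is not None


    # A target spec with no process group no longer trips either guard.
    assert not r2r_guard_raises("tp_group", None)
    assert not s2r_guard_raises("tp_group", None)
    # Converting between two different, explicitly set groups is still unsupported.
    assert r2r_guard_raises("tp_group", "dp_group")
    assert s2r_guard_raises("tp_group", "dp_group")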
@@ -1,7 +1,7 @@
 from enum import Enum
 from typing import List
 from colossalai.context.parallel_mode import ParallelMode
-from colossalai.tensor.dist_spec import _DistSpec
+from colossalai.tensor.dist_spec import _DistSpec, DistPlacementPattern
 
 
 class ComputePattern(Enum):
@@ -84,3 +84,16 @@ class TensorSpec(object):
 
     def get_process_group(self):
         return self.dist_spec.process_group
+
+    def get_placement(self):
+        return self.dist_spec.placement
+
+    def is_gathered(self):
+        return self.dist_spec.placement == DistPlacementPattern.REPLICATE \
+            or (len(self.dist_spec.num_partitions) == 1
+                and self.dist_spec.num_partitions[0] == 1) \
+            or (self.dist_spec.process_group.size() == 1)
+
+    def is_1Dcol(self):
+        return self.dist_spec.placement == DistPlacementPattern.SHARD \
+            and len(self.dist_spec.dims) == 1 and self.dist_spec.dims[0] == -1
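Note: the predicates added to TensorSpec are what the loss op now queries. The snippet below restates them over stand-in objects (FakeGroup, FakeDistSpec, string placements) so the logic can be exercised without torch or a real process group; the actual methods read DistPlacementPattern and a torch ProcessGroup.

    from dataclasses import dataclass, field
    from typing import List


    class FakeGroup:
        def __init__(self, world_size: int):
            self._world_size = world_size

        def size(self) -> int:
            return self._world_size


    @dataclass
    class FakeDistSpec:
        placement: str                     # 'REPLICATE' or 'SHARD' (stand-in for DistPlacementPattern)
        process_group: FakeGroup
        dims: List[int] = field(default_factory=list)
        num_partitions: List[int] = field(default_factory=lambda: [1])


    def is_gathered(spec: FakeDistSpec) -> bool:
        return spec.placement == 'REPLICATE' \
            or (len(spec.num_partitions) == 1 and spec.num_partitions[0] == 1) \
            or spec.process_group.size() == 1


    def is_1Dcol(spec: FakeDistSpec) -> bool:
        return spec.placement == 'SHARD' and len(spec.dims) == 1 and spec.dims[0] == -1


    assert is_gathered(FakeDistSpec('REPLICATE', FakeGroup(4)))
    assert is_gathered(FakeDistSpec('SHARD', FakeGroup(1), dims=[-1], num_partitions=[4]))   # one-rank group
    assert is_1Dcol(FakeDistSpec('SHARD', FakeGroup(4), dims=[-1], num_partitions=[4]))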
@@ -9,7 +9,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor, ColoOptimizer
+from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, \
+    ParallelAction, ColoTensor, ColoOptimizer, dist_spec, DistSpecManager
 from colossalai.context import ParallelMode
 from colossalai.core import global_context as gpc
 
@@ -85,6 +86,34 @@ def set_seed(seed):
     torch.cuda.manual_seed(seed)
     torch.backends.cudnn.deterministic = True
 
+
+def init_1d_row_linear(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_col_linear(weight, gather_out=True):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D, \
+            gather_out=gather_out)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_row_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [0], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
+
+def init_1d_col_embedding(weight):
+    spec = TensorSpec(
+        dist_spec.shard(gpc.get_group(ParallelMode.PARALLEL_1D), [-1], [gpc.get_world_size(ParallelMode.PARALLEL_1D)]),
+        [ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DCol, parallel_mode=ParallelMode.PARALLEL_1D)])
+    with DistSpecManager.no_grad():
+        weight.set_spec(spec)
+
 
 def run_1d_hybrid_tp(model_name):
     # A simple net with two stacked nn.Linear
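Note: the four init_1d_* helpers above replace the per-site TensorSpec/ParallelAction boilerplate that the following hunks delete. A typical call site, mirroring run_1d_row_tp further down (shown as a fragment only; model and the name filters are placeholders):

    for name, p in model.colo_named_parameters():
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'embed' not in name:
            init_1d_row_linear(p)          # shard the last dim across the 1-D TP group
        elif 'embed' in name and 'weight' in name:
            init_1d_row_embedding(p)       # shard the vocab dim across the 1-D TP group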
@@ -106,84 +135,35 @@ def run_1d_hybrid_tp(model_name):
             p2.data.copy_(p1.data)
 
     if 'bert' == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_linear_row = TensorSpec(parallel_action_list_row)
-
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
-
-        parallel_action_list_embedding_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
-
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             # print(name)
             # num_class = type_vocab_size = 2 | (8, 2)
             if 'classifier' in name and 'weight' in name:
-                p.set_spec(spec_linear_row)
+                init_1d_row_linear(p)
             # num_class = vocab_size = 30524 | (30524, 8)
             if 'word_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = seq_len = 512 | (512, 8)
             if 'position_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_row)
+                init_1d_row_embedding(p)
             # num_class = type_vocab_size = 2 | (2, 8)
             if 'token_type_embeddings' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
     elif "simple_net" == model_name:
-        parallel_action_list_row = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DRow_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_row = TensorSpec(parallel_action_list_row)
-
-        parallel_action_list_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D),
-        ]
-        spec_col = TensorSpec(parallel_action_list_col)
-
-        parallel_action_list_classifier_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Linear,
-                           parallel_mode=ParallelMode.PARALLEL_1D,
-                           gather_out=False),
-        ]
-        spec_classifier_col = TensorSpec(parallel_action_list_classifier_col)
-
-        parallel_action_list_embedding_col = [
-            ParallelAction(priority=1,
-                           compute_pattern=ComputePattern.TP1DCol_Embedding,
-                           parallel_mode=ParallelMode.PARALLEL_1D)
-        ]
-        spec_embedding_col = TensorSpec(parallel_action_list_embedding_col)
         # A naive way to set spec for all weights in Linear
         for name, p in model.colo_named_parameters():
             if not isinstance(p, ColoTensor):
                 continue
             if 'embed' in name and 'weight' in name:
-                p.set_spec(spec_embedding_col)
+                init_1d_col_embedding(p)
             if 'proj1' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_col)
+                init_1d_col_linear(p)
            if 'proj2' in name and 'weight' in name:
-                p.set_spec(spec_row)
+                init_1d_row_linear(p)
             if 'classifier' in name and ('weight' in name or 'bias' in name):
-                p.set_spec(spec_classifier_col)
+                init_1d_col_linear(p, gather_out=False)
 
     model = model.cuda()
     colo_optimizer = ColoOptimizer(dict(model.named_parameters()), torch.optim.SGD, lr=0.1)
@@ -251,8 +231,6 @@ def run_1d_hybrid_tp(model_name):
             break
 
 
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 # Test the overrided parameters() and named_parameters() member functions
 def test_model_parameters():
     # build a module with 2 Linear, 4 parameters in total.
@@ -285,8 +263,6 @@ def test_model_parameters():
     assert param_cnt == 2
 
 
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 def test_colo_optimizer():
     get_components_func = non_distributed_component_funcs.get_callable('simple_net')
     model_builder, train_dataloader, test_dataloader, optimizer_class, criterion = get_components_func()
@@ -329,29 +305,14 @@ def run_1d_row_tp(model_name: str):
     if rank == 0:
         model_torch = model_builder(checkpoint=True)
         model_torch = model_torch.cuda()
 
-    parallel_action_list = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DRow_Linear,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec = TensorSpec(parallel_action_list)
-
-    parallel_action_list_embedding_row = [
-        ParallelAction(priority=1,
-                       compute_pattern=ComputePattern.TP1DRow_Embedding,
-                       parallel_mode=ParallelMode.PARALLEL_1D)
-    ]
-    spec_embedding_row = TensorSpec(parallel_action_list_embedding_row)
-
     # A naive way to set spec for all weights in Linear
     for name, p in model.colo_named_parameters():
         if not isinstance(p, ColoTensor):
             continue
         if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
-            p.set_spec(spec)
+            init_1d_row_linear(p)
         if 'embed' in name and 'weight' in name:
-            p.set_spec(spec_embedding_row)
+            init_1d_row_embedding(p)
 
     model = model.cuda()
 
@@ -434,9 +395,6 @@ def run_model_dist(rank, world_size, port):
     for name in ['bert', 'simple_net']:
         run_1d_hybrid_tp(name)
 
 
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 # @parameterize('world_size', [1, 4])
@@ -454,8 +412,6 @@ def run_pretrain_load_dist(rank, world_size, port):
 
 # The test case has to download huggingface pretrained models from the internet
 # So we manually trigger the test.
-# FIXME (ver217): enable this test
-@pytest.mark.skip
 @pytest.mark.dist
 @pytest.mark.parametrize('world_size', [1, 4])
 @rerun_if_address_is_in_use()
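Note: the decorated test entry points themselves are not shown in this diff. In this repo they typically spawn the run_*_dist functions across world_size processes; a hedged sketch of such a driver (the test name test_model is an assumption, everything else reuses names imported in the test file above):

    from functools import partial

    import pytest
    import torch.multiprocessing as mp
    from colossalai.utils import free_port
    from colossalai.testing import rerun_if_address_is_in_use


    @pytest.mark.dist
    @pytest.mark.parametrize('world_size', [1, 4])
    @rerun_if_address_is_in_use()
    def test_model(world_size):
        # run_model_dist(rank, world_size, port) is defined in the test file;
        # mp.spawn supplies the rank argument to each process.
        run_func = partial(run_model_dist, world_size=world_size, port=free_port())
        mp.spawn(run_func, nprocs=world_size)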