mirror of https://github.com/hpcaitech/ColossalAI
parent
dfaff4e243
commit
75d221918a
|
@ -1,7 +1,8 @@
|
||||||
|
from colossalai.tensor.spec import ShardPattern
|
||||||
import torch
|
import torch
|
||||||
from colossalai.tensor.op_wrapper import colo_op_impl
|
from colossalai.tensor.op_wrapper import colo_op_impl
|
||||||
from colossalai.tensor import ColoTensor
|
from colossalai.tensor import ColoTensor
|
||||||
|
from colossalai.nn.loss.loss_1d import VocabParallelCrossEntropyLoss1D
|
||||||
|
|
||||||
@colo_op_impl(torch.nn.functional.cross_entropy)
|
@colo_op_impl(torch.nn.functional.cross_entropy)
|
||||||
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
||||||
|
@ -12,18 +13,29 @@ def colo_cross_entropy(types, args=(), kwargs=None, pg=None):
|
||||||
if arg_num > 1:
|
if arg_num > 1:
|
||||||
target = args[1]
|
target = args[1]
|
||||||
if arg_num > 2:
|
if arg_num > 2:
|
||||||
weight = args[3]
|
weight = args[2]
|
||||||
|
|
||||||
if 'input' in kwargs:
|
if 'input' in kwargs:
|
||||||
input_tensor = kwargs['input']
|
input_tensor = kwargs.pop('input')
|
||||||
if 'target' in kwargs:
|
if 'target' in kwargs:
|
||||||
target = kwargs['target']
|
target = kwargs.pop('target')
|
||||||
if 'weight' in kwargs:
|
if 'weight' in kwargs:
|
||||||
weight = kwargs['weight']
|
weight = kwargs.pop('weight')
|
||||||
|
|
||||||
if isinstance(input_tensor, ColoTensor):
|
if not isinstance(input_tensor, ColoTensor):
|
||||||
input_tensor = input_tensor.torch_tensor()
|
input_tensor = ColoTensor.init_from_torch_tensor(input_tensor)
|
||||||
if isinstance(target, ColoTensor):
|
if isinstance(target, ColoTensor):
|
||||||
target = target.torch_tensor()
|
target = target.torch_tensor()
|
||||||
|
|
||||||
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(input_tensor, target, weight))
|
if input_tensor.is_gathered(): # Input is gathered
|
||||||
|
# TODO(jzy) Shall we make the result of loss function a ColoTensor?
|
||||||
|
return ColoTensor.init_from_torch_tensor(torch.nn.functional.cross_entropy(
|
||||||
|
input_tensor.torch_tensor(), target, weight))
|
||||||
|
elif input_tensor.has_spec() and input_tensor.shard_spec.num_action == 1: # Single Model Parallel Applied
|
||||||
|
if input_tensor.shard_pattern == ShardPattern.Col:
|
||||||
|
return ColoTensor.init_from_torch_tensor(
|
||||||
|
VocabParallelCrossEntropyLoss1D()(input_tensor.torch_tensor(), target))
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
else:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
|
@ -17,6 +17,7 @@ class SimpleNet(CheckpointModule):
|
||||||
self.ln1 = nn.LayerNorm(8)
|
self.ln1 = nn.LayerNorm(8)
|
||||||
self.proj2 = nn.Linear(8, 4)
|
self.proj2 = nn.Linear(8, 4)
|
||||||
self.ln2 = nn.LayerNorm(4)
|
self.ln2 = nn.LayerNorm(4)
|
||||||
|
self.classifier = nn.Linear(4, 4)
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
x = self.embed(x)
|
x = self.embed(x)
|
||||||
|
@ -24,6 +25,7 @@ class SimpleNet(CheckpointModule):
|
||||||
x = self.ln1(x)
|
x = self.ln1(x)
|
||||||
x = self.proj2(x)
|
x = self.proj2(x)
|
||||||
x = self.ln2(x)
|
x = self.ln2(x)
|
||||||
|
x = self.classifier(x)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
@ -31,8 +33,8 @@ class SimpleNet(CheckpointModule):
|
||||||
class DummyDataLoader(DummyDataGenerator):
|
class DummyDataLoader(DummyDataGenerator):
|
||||||
|
|
||||||
def generate(self):
|
def generate(self):
|
||||||
data = torch.randint(low=0, high=20, size=(16,20), device=get_current_device())
|
data = torch.randint(low=0, high=20, size=(16,), device=get_current_device())
|
||||||
label = torch.randint(low=0, high=2, size=(16,4), device=get_current_device())
|
label = torch.randint(low=0, high=2, size=(16,), device=get_current_device())
|
||||||
return data, label
|
return data, label
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -144,10 +144,18 @@ def run_1d_hybrid_tp(model_name):
|
||||||
parallel_action_list_col = [
|
parallel_action_list_col = [
|
||||||
ParallelAction(priority=1,
|
ParallelAction(priority=1,
|
||||||
compute_pattern=ComputePattern.TP1DCol_Linear,
|
compute_pattern=ComputePattern.TP1DCol_Linear,
|
||||||
parallel_mode=ParallelMode.PARALLEL_1D)
|
parallel_mode=ParallelMode.PARALLEL_1D),
|
||||||
]
|
]
|
||||||
spec_col = TensorSpec(parallel_action_list_col)
|
spec_col = TensorSpec(parallel_action_list_col)
|
||||||
|
|
||||||
|
parallel_action_list_classifier_col = [
|
||||||
|
ParallelAction(priority=1,
|
||||||
|
compute_pattern=ComputePattern.TP1DCol_Linear,
|
||||||
|
parallel_mode=ParallelMode.PARALLEL_1D,
|
||||||
|
gather_out=False),
|
||||||
|
]
|
||||||
|
spec_classifier_col = TensorSpec(parallel_action_list_classifier_col)
|
||||||
|
|
||||||
parallel_action_list_embedding_col = [
|
parallel_action_list_embedding_col = [
|
||||||
ParallelAction(priority=1,
|
ParallelAction(priority=1,
|
||||||
compute_pattern=ComputePattern.TP1DCol_Embedding,
|
compute_pattern=ComputePattern.TP1DCol_Embedding,
|
||||||
|
@ -158,12 +166,14 @@ def run_1d_hybrid_tp(model_name):
|
||||||
for name, p in model.colo_named_parameters():
|
for name, p in model.colo_named_parameters():
|
||||||
if not isinstance(p, ColoTensor):
|
if not isinstance(p, ColoTensor):
|
||||||
continue
|
continue
|
||||||
|
if 'embed' in name and 'weight' in name:
|
||||||
|
p.set_spec(spec_embedding_col)
|
||||||
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
if 'proj1' in name and ('weight' in name or 'bias' in name):
|
||||||
p.set_spec(spec_col)
|
p.set_spec(spec_col)
|
||||||
if 'proj2' in name and 'weight' in name:
|
if 'proj2' in name and 'weight' in name:
|
||||||
p.set_spec(spec_row)
|
p.set_spec(spec_row)
|
||||||
if 'embed' in name and 'weight' in name:
|
if 'classifier' in name and ('weight' in name or 'bias' in name):
|
||||||
p.set_spec(spec_embedding_col)
|
p.set_spec(spec_classifier_col)
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
if rank == 0:
|
if rank == 0:
|
||||||
|
|
Loading…
Reference in New Issue