@@ -7,7 +7,8 @@ from colossalai.testing import parameterize, rerun_if_address_is_in_use
 from colossalai.utils.cuda import get_current_device
 from colossalai.utils import free_port
 from colossalai.utils import ColoInitContext
-from colossalai.tensor import named_params_with_colotensor
+from colossalai.tensor import named_params_with_colotensor, TensorSpec, ComputePattern, ParallelAction, ColoTensor
+from colossalai.context import ParallelMode
 
 from functools import partial
@@ -20,18 +21,32 @@ def run_simple_net():
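    # ColoInitContext intercepts parameter creation so the model's
    # parameters are built as ColoTensors on the chosen device.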
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=True)
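
    # A TensorSpec carries a list of ParallelActions; TP1DRow requests
    # row-style 1D tensor parallelism in the PARALLEL_1D process group.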
    parallel_action_list = [
        ParallelAction(priority=1, compute_pattern=ComputePattern.TP1DRow, parallel_mode=ParallelMode.PARALLEL_1D)
    ]
    spec = TensorSpec(parallel_action_list)
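
    # Heuristic: shard only Linear-style weights; LayerNorm and embedding
    # weights are skipped so they keep their default (replicated) layout.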
    # A naive way to set the spec for all weights in Linear
    for name, p in named_params_with_colotensor(model):
        if not isinstance(p, ColoTensor):
            continue
        if 'weight' in name and 'LayerNorm' not in name and 'ln' not in name and 'embed' not in name:
            p.set_spec(spec)

    model.cuda()
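
    # Inspect the resulting layouts: named_params_with_colotensor yields
    # (name, param) pairs, mirroring torch's named_parameters().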
    for name, param in named_params_with_colotensor(model):
        print(name, param)
    # Alternatively, set the spec on the weight of each Linear by hand:
    # model.proj1.weight.set_spec('1Drow')
    # model.proj2.weight.set_spec('1Drow')
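
    # One training step per batch: move the batch to the device before
    # running the forward pass.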
    for i, (data, label) in enumerate(train_dataloader):
        data = data.to(get_current_device())
        label = label.to(get_current_device())
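
        # When no criterion is given, the model is assumed to compute the
        # loss itself from (data, label).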
        if criterion:
            output = model(data)
            loss = criterion(output, label)
        else:
            output = model(data, label)
            loss = output
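
        # loss is a ColoTensor; torch_tensor() exposes the wrapped torch.Tensor.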
        print(loss.torch_tensor())