mirror of https://github.com/hpcaitech/ColossalAI
ai big-model data-parallelism deep-learning distributed-computing foundation-models heterogeneous-training hpc inference large-scale model-parallelism pipeline-parallelism
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
125 lines
3.9 KiB
125 lines
3.9 KiB
#!/usr/bin/env python |
|
# -*- encoding: utf-8 -*- |
|
|
|
import pytest |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
from colossalai.context.parallel_mode import ParallelMode |
|
from colossalai.context.random import add_seed, reset_seeds, seed, set_mode |
|
from colossalai.testing import clear_cache_before_run, parameterize |
|
from colossalai.utils.activation_checkpoint import checkpoint |
|
|
|
|
|
def forward(x, weight):
    """Multiply ``x`` by ``weight`` and apply dropout (p=0.4) under the DATA seed.

    Only the dropout runs inside ``seed(ParallelMode.DATA)``, so its random mask
    is drawn while the DATA-mode seed is active; the matmul itself uses no RNG.

    Args:
        x: left operand of the matmul.
        weight: right operand of the matmul.

    Returns:
        The dropped-out product. Dropout is forced on via ``training=True``.
    """
    product = x @ weight
    with seed(ParallelMode.DATA):
        dropped = F.dropout(product, training=True, p=0.4)
    return dropped
|
|
|
|
|
def forward_inplace_ckpt(x, weight, cpu_offload=False):
    """Matmul + batch-norm, then an in-place ReLU run through a non-reentrant checkpoint.

    This is the checkpointed counterpart of ``forward_inplace``: everything is the
    same except that the final ``F.relu(..., inplace=True)`` is executed via
    ``checkpoint(..., use_reentrant=False)``.

    Args:
        x: left operand of the matmul.
        weight: right operand of the matmul.
        cpu_offload: passed positionally as ``checkpoint``'s second argument.

    Returns:
        ``relu(batch_norm(x @ weight))`` as produced by ``checkpoint``.
    """
    out = torch.matmul(x, weight)
    # Fresh BatchNorm1d without affine parameters, sized from the projection's
    # channel dimension (dim 1) and placed on the projection's device, instead of
    # hard-coding 4 features on "cuda". Identical for a (batch, 4) CUDA projection.
    bn = torch.nn.BatchNorm1d(out.shape[1], affine=False).to(device=out.device)
    out = bn(out)

    def _inplace_relu(t):
        # In-place op inside the checkpointed region is what this helper exercises.
        return F.relu(t, inplace=True)

    return checkpoint(_inplace_relu, cpu_offload, out, use_reentrant=False)
|
|
|
|
|
def forward_inplace(x, weight):
    """Matmul + batch-norm, then an in-place ReLU, with no checkpointing.

    This is the un-checkpointed reference for ``forward_inplace_ckpt``; the test
    compares the gradients the two produce.

    Args:
        x: left operand of the matmul.
        weight: right operand of the matmul.

    Returns:
        ``relu(batch_norm(x @ weight))``, with the ReLU applied in place.
    """
    out = torch.matmul(x, weight)
    # Fresh BatchNorm1d without affine parameters, sized from the projection's
    # channel dimension (dim 1) and placed on the projection's device, instead of
    # hard-coding 4 features on "cuda". Identical for a (batch, 4) CUDA projection.
    bn = torch.nn.BatchNorm1d(out.shape[1], affine=False).to(device=out.device)
    out = bn(out)
    return F.relu(out, inplace=True)
|
|
|
|
|
def _restore_cuda_rng_states(global_cuda_rng_state, data_parallel_cuda_rng_state):
    """Reload the saved GLOBAL and DATA CUDA RNG states, leaving GLOBAL active.

    The call order mirrors how the states were captured in the test (switch mode
    with ``set_mode``, then read ``torch.cuda.get_rng_state()``), so keep it as is.
    """
    set_mode(ParallelMode.GLOBAL)
    torch.cuda.set_rng_state(global_cuda_rng_state)
    set_mode(ParallelMode.DATA)
    torch.cuda.set_rng_state(data_parallel_cuda_rng_state)
    set_mode(ParallelMode.GLOBAL)


@pytest.mark.gpu
@clear_cache_before_run()
@parameterize("use_reentrant", [True, False])
@parameterize("cpu_offload", [True, False])
def test_activation_checkpointing(cpu_offload, use_reentrant):
    """Checkpointed forwards must reproduce the gradients of plain forwards.

    Runs ``forward`` directly and through ``checkpoint`` from identical CUDA RNG
    states, then compares input and weight gradients. With ``use_reentrant=False``
    it also compares ``forward_inplace`` against ``forward_inplace_ckpt``.
    """
    # as seed manager is singleton
    # if we don't reset seeds here,
    # other tests might affect this test
    reset_seeds()
    try:
        # We put initialization here to avoid changing the cuda rng state below
        inputs = torch.rand(2, 2, requires_grad=True, device='cuda')
        weight = torch.rand(2, 4, requires_grad=True, device='cuda')

        # Independent leaf copies for the checkpointed run (no RNG consumed)
        inputs_ = inputs.detach().clone().requires_grad_(True)
        weight_ = weight.detach().clone().requires_grad_(True)

        add_seed(ParallelMode.GLOBAL, 1024)
        add_seed(ParallelMode.DATA, 1026)
        set_mode(ParallelMode.GLOBAL)
        global_cuda_rng_state = torch.cuda.get_rng_state()
        set_mode(ParallelMode.DATA)
        data_parallel_cuda_rng_state = torch.cuda.get_rng_state()
        set_mode(ParallelMode.GLOBAL)

        out = forward(inputs, weight)
        out.sum().backward()

        # Recover cuda rng states so the checkpointed run sees the same dropout mask
        _restore_cuda_rng_states(global_cuda_rng_state, data_parallel_cuda_rng_state)

        out = checkpoint(forward, cpu_offload, inputs_, weight_, use_reentrant=use_reentrant)
        out.sum().backward()

        assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
        assert torch.all(weight.grad == weight_.grad), 'Gradient of the weight does not match'
        torch.cuda.empty_cache()

        # Extra test for use_reentrant=False
        if not use_reentrant:
            # Drop the gradients accumulated above so this comparison only
            # covers the in-place path.
            for leaf in (inputs, weight, inputs_, weight_):
                leaf.grad = None

            _restore_cuda_rng_states(global_cuda_rng_state, data_parallel_cuda_rng_state)
            out = forward_inplace(inputs, weight)
            out.sum().backward()

            _restore_cuda_rng_states(global_cuda_rng_state, data_parallel_cuda_rng_state)
            out = forward_inplace_ckpt(inputs_, weight_, cpu_offload=cpu_offload)
            out.sum().backward()

            assert torch.all(inputs.grad == inputs_.grad), 'Gradient of the input does not match'
            assert torch.all(weight.grad == weight_.grad), 'Gradient of the weight does not match'
            torch.cuda.empty_cache()
    finally:
        # as seed manager is singleton
        # if we don't reset seeds here,
        # other tests will fail if running together with this test
        # as other tests can't overwrite the seed set by this test.
        # Done in `finally` so a failing assertion does not leak seeds either.
        reset_seeds()
|
|
|
|
|
if __name__ == "__main__":
    # The stacked ``parameterize`` decorators supply ``cpu_offload`` and
    # ``use_reentrant`` themselves (pytest also calls the wrapper with no
    # arguments), so call it bare; this runs every parameter combination.
    test_activation_checkpointing()
|
|
|