mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] test compatibility for gemini and auto parallel (#2700)
parent d701ef81b1
commit 7fa6be49d2
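Summary of the change, as reflected in the hunks and new files below: the auto-parallel runtime pass now re-registers each re-sharded weight as a fresh torch.nn.Parameter (instead of only overwriting .data), and two distributed tests are added that check the sharded GraphModule against a single-process baseline when it is wrapped with DDP and with Gemini/ZeRO.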
@@ -377,8 +377,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
     # TODO: build a ColoParamter class to manager the distributed parameters
     # we could use .data here, because all the operations just happen before the real training
     # loop, so we don't need to track these operations in the autograd graph.
-    param.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-        param.data, param.sharding_spec, target_sharding_spec).detach().clone()
+    param = torch.nn.Parameter(
+        shape_consistency_manager.apply_for_autoparallel_runtime(param.data, param.sharding_spec,
+                                                                 target_sharding_spec).detach().clone())

     setattr(target_module, name, param)
     comm_actions = node.best_strategy.communication_actions
@@ -432,8 +433,9 @@ def _module_params_sharding(gm: torch.fx.GraphModule, device_mesh: DeviceMesh, o
     # TODO: build a ColoParamter class to manager the distributed parameters
     # we could use .data here, because all the operations just happen before the real training
     # loop, so we don't need to track these operations in the autograd graph.
-    target.data = shape_consistency_manager.apply_for_autoparallel_runtime(
-        target.data, target.sharding_spec, target_sharding_spec).detach().clone()
+    target = torch.nn.Parameter(
+        shape_consistency_manager.apply_for_autoparallel_runtime(target.data, target.sharding_spec,
+                                                                 target_sharding_spec).detach().clone())

     assert hasattr(target_module, atoms[-1])
     setattr(target_module, atoms[-1], target)
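Both hunks make the same change: instead of mutating .data on the existing parameter, the re-sharded tensor is detached, cloned, wrapped in a fresh leaf torch.nn.Parameter, and re-registered on the module with setattr, so wrappers that discover weights through named_parameters() (DDP and Gemini in the tests below) pick up the sharded copy. A minimal, self-contained sketch of that re-registration pattern, where reshard_rows and replace_with_sharded_param are hypothetical stand-ins rather than ColossalAI APIs:

import torch


def reshard_rows(t: torch.Tensor, rank: int, world_size: int) -> torch.Tensor:
    # stand-in for shape_consistency_manager.apply_for_autoparallel_runtime:
    # keep only this rank's shard of the rows
    return t.chunk(world_size, dim=0)[rank]


def replace_with_sharded_param(module: torch.nn.Module, name: str, rank: int, world_size: int) -> None:
    old = getattr(module, name)
    # wrap the detached, cloned shard in a fresh leaf Parameter and re-register it,
    # mirroring the torch.nn.Parameter(...) + setattr pattern in the hunks above
    sharded = torch.nn.Parameter(reshard_rows(old.data, rank, world_size).detach().clone())
    setattr(module, name, sharded)


linear = torch.nn.Linear(4, 16, bias=False)
replace_with_sharded_param(linear, 'weight', rank=0, world_size=2)
print(linear.weight.shape)                     # torch.Size([8, 4]): only this rank's shard
print(list(dict(linear.named_parameters())))   # ['weight']: still a registered parameter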
@@ -0,0 +1,98 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)

        return x


def check_compatibility_with_ddp(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).cuda()
    input = torch.rand(4, 4).cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    msg = '| TP strategy combination chosen by auto-parallel solver |'
    msg_length = len(msg)
    if rank == 0:
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = None
    for (ranks, process_group_handle) in device_mesh.process_groups_dict[0]:
        if rank in ranks:
            dp_process_group = process_group_handle
    assert dp_process_group is not None

    gm = DDP(gm, process_group=dp_process_group)
    output = gm(input)

    assert_close(output, output_compare)
    print(f'output on rank{rank} is correct')
    loss = output.sum()

    loss.backward()

    if rank in (0, 2):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 0, 8))

    if rank in (1, 3):
        assert_close(gm.module.module.linear_1.weight.grad, grad_compare.narrow(0, 8, 8))

    print(f'gradient on rank{rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_compatibility_with_ddp():
    world_size = 4
    run_func = partial(check_compatibility_with_ddp, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_compatibility_with_ddp()
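For MLP(4), linear_1.weight has shape (16, 4) and the solver shards it along dim 0 across the two tensor-parallel groups, so ranks 0/2 hold rows 0..7 and ranks 1/3 hold rows 8..15; that is why each rank's DDP gradient is compared against the matching narrow of the single-process gradient. A small non-distributed sketch of that correspondence (the row-sharded layout is assumed here, as the test's assertions imply):

import torch

# gradient of linear_1.weight in MLP(4) has shape (4 * 4, 4) = (16, 4)
full_grad = torch.rand(16, 4)

shard_ranks_0_2 = full_grad.narrow(0, 0, 8)    # rows 0..7, checked on ranks 0 and 2
shard_ranks_1_3 = full_grad.narrow(0, 8, 8)    # rows 8..15, checked on ranks 1 and 3

# the two shards tile the full gradient along dim 0
assert torch.equal(torch.cat([shard_ranks_0_2, shard_ranks_1_3], dim=0), full_grad)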
@@ -0,0 +1,108 @@
import copy
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp
from torch.nn.parallel import DistributedDataParallel as DDP

from colossalai.auto_parallel.tensor_shard.initialize import initialize_model
from colossalai.device.device_mesh import DeviceMesh
from colossalai.initialize import launch
from colossalai.logging import disable_existing_loggers
from colossalai.nn.optimizer import HybridAdam
from colossalai.nn.parallel import zero_model_wrapper, zero_optim_wrapper
from colossalai.tensor.process_group import ProcessGroup
from colossalai.testing import assert_close, rerun_if_address_is_in_use
from colossalai.testing.pytest_wrapper import run_on_environment_flag
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext, post_process_colo_init_ctx


class MLP(torch.nn.Module):

    def __init__(self, in_features):
        super().__init__()
        self.linear_1 = torch.nn.Linear(in_features, 4 * in_features, bias=False)
        self.linear_2 = torch.nn.Linear(4 * in_features, in_features, bias=False)

    def forward(self, x):
        x = self.linear_1(x)
        x = self.linear_2(x)

        return x


def check_auto_parallel_with_gemini(rank, world_size, port):
    disable_existing_loggers()
    launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
    model = MLP(4).half().cuda()

    input = torch.rand(4, 4).half().cuda()
    output_compare = model(input)
    loss_compare = output_compare.sum()
    loss_compare.backward()
    grad_compare = copy.deepcopy(model.linear_1.weight.grad)

    physical_mesh_id = torch.arange(0, 4)
    mesh_shape = (2, 2)
    # [[0, 1]
    #  [2, 3]]
    device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
    meta_args = {'x': torch.rand(4, 4).half().to('meta')}
    gm, solution = initialize_model(model,
                                    meta_args=meta_args,
                                    device_mesh=device_mesh,
                                    return_solution=True,
                                    solver_preference='tp',
                                    shard_option='shard_last_axis')

    if rank == 0:
        msg = '| TP strategy combination chosen by auto-parallel solver |'
        msg_length = len(msg)
        print('=' * msg_length)
        print(msg)
        print('=' * msg_length)
        for strategy in solution:
            print(strategy)
        print('=' * msg_length)

    dp_process_group = ProcessGroup(rank=rank, ranks=[0, 1, 2, 3], tp_degree=2, dp_degree=2)
    gemini_config = dict(strict_ddp_mode=False,
                         device=get_current_device(),
                         placement_policy='cpu',
                         pin_memory=True,
                         search_range_mb=128)

    post_process_colo_init_ctx(gm, device=get_current_device(), default_pg=dp_process_group)
    gm = zero_model_wrapper(gm, zero_stage=3, gemini_config=gemini_config)
    optimizer = HybridAdam(gm.parameters(), betas=(0, 0))
    optimizer = zero_optim_wrapper(gm, optimizer, initial_scale=1)
    output = gm(input)
    assert_close(output, output_compare)
    print(f'output on rank{rank} is correct')
    loss = output.sum()
    optimizer.zero_grad()
    optimizer.backward(loss)
    optimizer.step()

    if rank in (0, 2):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 0, 8).flatten())

    if rank in (1, 3):
        assert_close(list(optimizer.optim.state.values())[0]['exp_avg'].half(), grad_compare.narrow(0, 8, 8).flatten())

    print(f'gradient on rank{rank} is correct')


@run_on_environment_flag(name='AUTO_PARALLEL')
@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_auto_parallel_with_gemini():
    world_size = 4
    run_func = partial(check_auto_parallel_with_gemini, world_size=world_size, port=free_port())
    mp.spawn(run_func, nprocs=world_size)


if __name__ == '__main__':
    test_auto_parallel_with_gemini()
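The Gemini test verifies gradients through the optimizer state rather than through .grad: with betas=(0, 0), Adam's first moment after one step is exp_avg = (1 - beta1) * grad = grad, so comparing exp_avg (cast back to half) against the baseline gradient shard is equivalent to comparing gradients. A minimal check of that identity with stock torch.optim.Adam; the assumption that HybridAdam stores the same 'exp_avg' state entry comes from the test itself:

import torch

p = torch.nn.Parameter(torch.rand(8, 4))
opt = torch.optim.Adam([p], betas=(0, 0))

p.grad = torch.rand_like(p)
opt.step()

# with beta1 = 0 the stored first moment equals the raw gradient after the first step
assert torch.allclose(opt.state[p]['exp_avg'], p.grad)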