mirror of https://github.com/hpcaitech/ColossalAI
[autoparallel] update binary elementwise handler (#2451)
* [autoparallel] update binary elementwise handler * polishpull/2396/head
parent
c9ec5190a0
commit
8221fd7485
|
@ -32,20 +32,32 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
return OperationDataType.ARG
|
||||
|
||||
def _get_arg_value(idx):
|
||||
non_tensor = False
|
||||
if isinstance(self.node.args[idx], Node):
|
||||
meta_data = self.node.args[idx]._meta_data
|
||||
# The meta_data of node type argument could also possibly be a non-tensor object.
|
||||
if not isinstance(meta_data, torch.Tensor):
|
||||
assert isinstance(meta_data, (int, float))
|
||||
meta_data = torch.Tensor([meta_data]).to('meta')
|
||||
non_tensor = True
|
||||
|
||||
else:
|
||||
# this is in fact a real data like int 1
|
||||
# but we can deem it as meta data
|
||||
# as it won't affect the strategy generation
|
||||
assert isinstance(self.node.args[idx], (int, float))
|
||||
meta_data = torch.Tensor([self.node.args[idx]]).to('meta')
|
||||
return meta_data
|
||||
non_tensor = True
|
||||
|
||||
input_meta_data = _get_arg_value(0)
|
||||
other_meta_data = _get_arg_value(1)
|
||||
return meta_data, non_tensor
|
||||
|
||||
input_meta_data, non_tensor_input = _get_arg_value(0)
|
||||
other_meta_data, non_tensor_other = _get_arg_value(1)
|
||||
output_meta_data = self.node._meta_data
|
||||
|
||||
# we need record op_data with non-tensor data in this list,
|
||||
# and filter the non-tensor op_data in post_process.
|
||||
self.non_tensor_list = []
|
||||
# assert False
|
||||
input_op_data = OperationData(name=str(self.node.args[0]),
|
||||
type=_get_op_data_type(input_meta_data),
|
||||
data=input_meta_data,
|
||||
|
@ -58,6 +70,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
type=OperationDataType.OUTPUT,
|
||||
data=output_meta_data,
|
||||
logical_shape=bcast_shape)
|
||||
if non_tensor_input:
|
||||
self.non_tensor_list.append(input_op_data)
|
||||
if non_tensor_other:
|
||||
self.non_tensor_list.append(other_op_data)
|
||||
|
||||
mapping = {'input': input_op_data, 'other': other_op_data, 'output': output_op_data}
|
||||
return mapping
|
||||
|
@ -73,9 +89,10 @@ class BinaryElementwiseHandler(MetaInfoNodeHandler):
|
|||
op_data_mapping = self.get_operation_data_mapping()
|
||||
|
||||
for op_name, op_data in op_data_mapping.items():
|
||||
if not isinstance(op_data.data, torch.Tensor):
|
||||
if op_data in self.non_tensor_list:
|
||||
# remove the sharding spec if the op_data is not a tensor, e.g. torch.pow(tensor, 2)
|
||||
strategy.sharding_specs.pop(op_data)
|
||||
|
||||
else:
|
||||
# convert the logical sharding spec to physical sharding spec if broadcast
|
||||
# e.g. torch.rand(4, 4) + torch.rand(4)
|
||||
|
|
|
@ -122,25 +122,41 @@ def check_binary_elementwise_handler_with_tensor(rank, op, other_dim, world_size
|
|||
assert input_sharding_spec.sharding_sequence[-1] == other_sharding_spec.sharding_sequence[-1]
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, port):
|
||||
class BEOpModelWithNodeConst(nn.Module):
|
||||
|
||||
def __init__(self, op):
|
||||
super().__init__()
|
||||
self.op = op
|
||||
|
||||
def forward(self, x1):
|
||||
const = x1.dim()
|
||||
out = self.op(x1, const)
|
||||
return out
|
||||
|
||||
|
||||
class BEOpModelWithIntConst(nn.Module):
|
||||
|
||||
def __init__(self, op, const):
|
||||
super().__init__()
|
||||
self.op = op
|
||||
self.const = const
|
||||
|
||||
def forward(self, x1):
|
||||
out = self.op(x1, self.const)
|
||||
return out
|
||||
|
||||
|
||||
def check_binary_elementwise_handler_with_int(rank, op, other_dim, model_cls, world_size, port):
|
||||
disable_existing_loggers()
|
||||
launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
|
||||
|
||||
class BinaryElementwiseOpModel(nn.Module):
|
||||
|
||||
def __init__(self, op, const):
|
||||
super().__init__()
|
||||
self.op = op
|
||||
self.const = const
|
||||
|
||||
def forward(self, x1):
|
||||
out = self.op(x1, self.const)
|
||||
return out
|
||||
|
||||
physical_mesh_id = torch.arange(0, 4)
|
||||
mesh_shape = (2, 2)
|
||||
device_mesh = DeviceMesh(physical_mesh_id, mesh_shape, init_process_group=True)
|
||||
model = BinaryElementwiseOpModel(op, other_dim).cuda()
|
||||
if model_cls == BEOpModelWithNodeConst:
|
||||
model = model_cls(op).cuda()
|
||||
else:
|
||||
model = model_cls(op, other_dim).cuda()
|
||||
x1 = torch.rand(4, 4).cuda()
|
||||
# the index of binary-elementwise node in computation graph
|
||||
node_index = 1
|
||||
|
@ -159,9 +175,14 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
|
|||
tracer = ColoTracer()
|
||||
meta_args = {'x1': torch.rand(4, 4).to('meta')}
|
||||
graph = tracer.trace(model, meta_args=meta_args)
|
||||
print(graph)
|
||||
# assert False
|
||||
gm = ColoGraphModule(model, graph)
|
||||
|
||||
op_node = list(graph.nodes)[1]
|
||||
if model_cls == BEOpModelWithNodeConst:
|
||||
op_node = list(graph.nodes)[2]
|
||||
else:
|
||||
op_node = list(graph.nodes)[1]
|
||||
strategies_vector = StrategiesVector(op_node)
|
||||
|
||||
# build handler
|
||||
|
@ -212,7 +233,7 @@ def check_binary_elementwise_handler_with_int(rank, op, other_dim, world_size, p
|
|||
@parameterize('other_dim', [1, 2])
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler(op, other_dim):
|
||||
def test_binary_elementwise_handler_with_tensor(op, other_dim):
|
||||
world_size = 4
|
||||
run_func_tensor = partial(check_binary_elementwise_handler_with_tensor,
|
||||
op=op,
|
||||
|
@ -220,8 +241,19 @@ def test_binary_elementwise_handler(op, other_dim):
|
|||
world_size=world_size,
|
||||
port=free_port())
|
||||
mp.spawn(run_func_tensor, nprocs=world_size)
|
||||
|
||||
|
||||
@run_on_environment_flag(name='AUTO_PARALLEL')
|
||||
@parameterize('op', [torch.add])
|
||||
@parameterize('other_dim', [1, 2])
|
||||
@parameterize('model_cls', [BEOpModelWithNodeConst, BEOpModelWithIntConst])
|
||||
@pytest.mark.dist
|
||||
@rerun_if_address_is_in_use()
|
||||
def test_binary_elementwise_handler_with_int(op, model_cls, other_dim):
|
||||
world_size = 4
|
||||
run_func_int = partial(check_binary_elementwise_handler_with_int,
|
||||
op=op,
|
||||
model_cls=model_cls,
|
||||
other_dim=other_dim,
|
||||
world_size=world_size,
|
||||
port=free_port())
|
||||
|
@ -229,4 +261,5 @@ def test_binary_elementwise_handler(op, other_dim):
|
|||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_binary_elementwise_handler()
|
||||
test_binary_elementwise_handler_with_tensor()
|
||||
test_binary_elementwise_handler_with_int()
|
||||
|
|
|
@ -90,7 +90,8 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
solver_options = SolverOptions()
|
||||
strategies_constructor = StrategiesConstructor(graph, device_mesh, solver_options)
|
||||
strategies_constructor.build_strategies_and_cost()
|
||||
target_node = list(graph.nodes)[node_index]
|
||||
target_node = [strategies_vector.node for strategies_vector in strategies_constructor.leaf_strategies
|
||||
][node_index]
|
||||
if node_type == 'normal':
|
||||
solution_len = len(strategies_constructor.leaf_strategies)
|
||||
solution = [0] * solution_len
|
||||
|
@ -112,7 +113,7 @@ def numerical_test_for_node_strategy(model: torch.nn.Module,
|
|||
ret = solver.call_solver_serialized_args()
|
||||
solution = list(ret[0])
|
||||
gm, sharding_spec_dict, origin_spec_dict, comm_actions_dict = runtime_preparation_pass(
|
||||
gm, solution, device_mesh)
|
||||
gm, solution, device_mesh, strategies_constructor)
|
||||
gm = runtime_apply_pass(gm)
|
||||
gm.recompile()
|
||||
|
||||
|
|
Loading…
Reference in New Issue