mirror of https://github.com/jumpserver/jumpserver
157 lines
4.8 KiB
Python
157 lines
4.8 KiB
Python
# ~*~ coding: utf-8 ~*~
|
|
from ansible.inventory.host import Host
|
|
from ansible.vars.manager import VariableManager
|
|
from ansible.inventory.manager import InventoryManager
|
|
from ansible.parsing.dataloader import DataLoader
|
|
|
|
|
|
# Names exported by `from <this module> import *`.
__all__ = ['BaseHost', 'BaseInventory']
|
|
|
|
|
|
class BaseHost(Host):
    """An Ansible ``Host`` built from a plain ``host_data`` dict."""

    # Port used when host_data has no "port" key or a falsy (e.g. "") value.
    DEFAULT_PORT = 22

    def __init__(self, host_data):
        """
        Initialize the host and populate its Ansible connection variables.

        :param host_data: {
            "name": "",
            "ip": "",
            "port": "",
            # the keys below are optional
            "username": "",
            "password": "",
            "private_key": "",
            "become": {
                "method": "",
                "user": "",
                "pass": "",
            },
            "groups": [],
            "vars": {},
        }
            "name" is preferred as the inventory hostname, falling back to
            "ip"; at least one of the two must be non-empty. "groups" is not
            read here (BaseInventory.parse_hosts consumes it).
        :raises ValueError: if neither "name" nor "ip" is provided.
        """
        self.host_data = host_data
        hostname = host_data.get('name') or host_data.get('ip')
        if not hostname:
            # Previously this produced a host with an empty/None name (or a
            # KeyError on 'ip' later); fail early with a clear message.
            raise ValueError("host_data must provide a non-empty 'name' or 'ip'")
        port = host_data.get('port') or self.DEFAULT_PORT
        super().__init__(hostname, port)
        self.__set_required_variables()
        self.__set_extra_variables()

    def __set_required_variables(self):
        """Set connection, credential and privilege-escalation variables."""
        host_data = self.host_data

        # "ip" is optional: fall back to the hostname chosen in __init__ so a
        # host identified only by name is still addressable.
        self.set_variable('ansible_host', host_data.get('ip') or self.name)
        # Use the same resolved port that __init__ passed to Host.__init__.
        # Reading host_data['port'] directly raised KeyError when the key was
        # missing and replaced the default with ""/None when it was empty.
        self.set_variable('ansible_port', host_data.get('port') or self.DEFAULT_PORT)

        if host_data.get('username'):
            self.set_variable('ansible_user', host_data['username'])

        # Password and private key
        if host_data.get('password'):
            self.set_variable('ansible_ssh_pass', host_data['password'])
        if host_data.get('private_key'):
            # Handed to Ansible as a key *file* path, not key content.
            self.set_variable('ansible_ssh_private_key_file', host_data['private_key'])

        # Privilege escalation ("become") support
        become = host_data.get("become", False)
        if become:
            self.set_variable("ansible_become", True)
            # `or` instead of a .get() default so the empty-string placeholders
            # from the documented schema also fall back to sane defaults.
            self.set_variable("ansible_become_method", become.get('method') or 'sudo')
            self.set_variable("ansible_become_user", become.get('user') or 'root')
            self.set_variable("ansible_become_pass", become.get('pass') or '')
        else:
            self.set_variable("ansible_become", False)

    def __set_extra_variables(self):
        """
        Copy the user-supplied "vars" mapping onto the host.

        Called after __set_required_variables, so these values take
        precedence over the connection variables set there. A missing or
        None "vars" entry is treated as empty.
        """
        for key, value in (self.host_data.get('vars') or {}).items():
            self.set_variable(key, value)

    def __repr__(self):
        return self.name
|
|
|
|
|
|
class BaseInventory(InventoryManager):
    """
    Build an Ansible inventory object from in-memory host and group data.
    """
    loader_class = DataLoader
    variable_manager_class = VariableManager
    host_manager_class = BaseHost

    def __init__(self, host_list=None, group_list=None):
        """
        Dynamically build an Ansible inventory.

        Hosts and groups are loaded by parse_sources(), which this class
        overrides; it relies on super().__init__() invoking it automatically.

        host_list: [{
            "name": "",
            "ip": "",
            "port": "",
            "username": "",
            "password": "",
            "private_key": "",
            "become": {
                "method": "",
                "user": "",
                "pass": "",
            },
            "groups": [],
            "vars": {},
        }, ...]

        group_list: [
            {"name": "", "children": [""]},
        ]

        :param host_list: list of host dicts (shape above); None means no hosts.
        :param group_list: list of group dicts (shape above); None means no groups.
        :raises TypeError: if host_list is neither None/empty nor a list.
        """
        self.host_list = host_list or []
        self.group_list = group_list or []
        # Validate the normalised value: asserting on the raw argument made
        # the documented default host_list=None fail unconditionally.
        if not isinstance(self.host_list, list):
            raise TypeError(
                "host_list must be a list, got {}".format(type(host_list).__name__)
            )
        self.loader = self.loader_class()
        self.variable_manager = self.variable_manager_class()
        super().__init__(self.loader)

    def get_groups(self):
        """Return the inventory's mapping of group name -> group."""
        return self._inventory.groups

    def get_group(self, name):
        """Return the group called ``name``, or None if it does not exist."""
        return self._inventory.groups.get(name, None)

    def get_or_create_group(self, name):
        """
        Return the group called ``name``, creating it first if necessary.

        :raises ValueError: if the group still cannot be found after
            add_group() (e.g. an empty/None name, or a name the inventory
            stored under a different key). The old recursive retry looped
            forever in that situation.
        """
        group = self.get_group(name)
        if group is not None:
            return group
        added = self.add_group(name)
        # NOTE(review): newer Ansible releases appear to return the name the
        # group was stored under (possibly sanitised); use it when it is a
        # string, otherwise look up the requested name. Verify per version.
        lookup = added if isinstance(added, str) else name
        group = self.get_group(lookup)
        if group is None:
            raise ValueError("Unable to create inventory group: {!r}".format(name))
        return group

    def parse_groups(self):
        """Create every group in group_list and attach its child groups."""
        for group_data in self.group_list:
            parent = self.get_or_create_group(group_data.get("name"))
            # A missing or None "children" entry means no children.
            children = [
                self.get_or_create_group(child_name)
                for child_name in group_data.get('children') or []
            ]
            for child in children:
                parent.add_child_group(child)

    def parse_hosts(self):
        """
        Build a host object for every entry of host_list and register it.

        Each host is added to each group named in its "groups" entry
        (created on demand), or to 'ungrouped' when it lists none, and
        always to 'all'.
        """
        group_all = self.get_or_create_group('all')
        ungrouped = self.get_or_create_group('ungrouped')
        for host_data in self.host_list:
            host = self.host_manager_class(host_data=host_data)
            # Key by the host's resolved name rather than host_data['name'],
            # which raised KeyError for IP-only hosts even though the host
            # object itself falls back to the IP as its name.
            self.hosts[host.name] = host
            groups_data = host_data.get('groups')
            if groups_data:
                for group_name in groups_data:
                    group = self.get_or_create_group(group_name)
                    group.add_host(host)
            else:
                ungrouped.add_host(host)
            group_all.add_host(host)

    def parse_sources(self, cache=False):
        """
        Load groups, then hosts, from the in-memory lists.

        :param cache: accepted for signature compatibility; not used here.
        """
        self.parse_groups()
        self.parse_hosts()

    def get_matched_hosts(self, pattern):
        """Return the hosts matching the Ansible host ``pattern``."""
        return self.get_hosts(pattern)
|
|
|
|
|