From c188696328c9e198f39af30dcbdc37e9a17f536f Mon Sep 17 00:00:00 2001 From: ibuler Date: Thu, 31 Mar 2016 23:45:01 +0800 Subject: [PATCH] =?UTF-8?q?fix(connect)=20=E5=A2=9E=E5=8A=A0=E6=A8=A1?= =?UTF-8?q?=E7=B3=8A=E6=90=9C=E7=B4=A2?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 之前只是输入id登陆,增加了模糊搜索登陆 如果搜索唯一则登陆 --- connect.py | 155 +++++++++++++++++++++++++++++------------------------ 1 file changed, 85 insertions(+), 70 deletions(-) diff --git a/connect.py b/connect.py index bb4924ddd..7aee8ea83 100755 --- a/connect.py +++ b/connect.py @@ -436,8 +436,10 @@ class Nav(object): """ def __init__(self, user): self.user = user - self.search_result = {} - self.user_perm = {} + self.search_result = None + self.user_perm = get_group_user_perm(self.user) + self.perm_assets = tuple(self.user_perm.get('asset', [])) + self.perm_asset_groups = self.user_perm.get('asset_group', []) @staticmethod def print_nav(): @@ -460,46 +462,85 @@ class Nav(object): """ print textwrap.dedent(msg) - def search(self, str_r=''): + def get_asset_group_member(self, str_r): gid_pattern = re.compile(r'^g\d+$') - # 获取用户授权的所有主机信息 - if not self.user_perm: - self.user_perm = get_group_user_perm(self.user) - user_asset_all = self.user_perm.get('asset').keys() - # 搜索结果保存 - user_asset_search = [] - if str_r: - # 资产组组id匹配 - if gid_pattern.match(str_r): - gid = int(str_r.lstrip('g')) - # 获取资产组包含的资产 - asset_group = get_object(AssetGroup, id=gid) - if asset_group: - user_asset_search = asset_group.asset_set.all() - else: - color_print('没有该资产组或没有权限') - return + if gid_pattern.match(str_r): + gid = int(str_r.lstrip('g')) + # 获取资产组包含的资产 + asset_group = get_object(AssetGroup, id=gid) + if asset_group: + self.search_result = list(asset_group.asset_set.all()) else: + color_print('没有该资产组或没有权限') + return + + def search(self, str_r=''): + # 搜索结果保存 + if str_r: + try: + id_ = int(str_r) + if id_ < len(self.search_result): + self.search_result = [self.search_result[id_]] + return + else: + raise ValueError + + except (ValueError, TypeError): # 匹配 ip, hostname, 备注 - for asset in user_asset_all: - if str_r in asset.ip or str_r in str(asset.hostname) or str_r in str(asset.comment): - user_asset_search.append(asset) + self.search_result = [asset for asset in self.perm_assets if str_r in str(asset.ip) + or str_r in str(asset.hostname) or str_r in str(asset.comment)] else: # 如果没有输入就展现所有 - user_asset_search = user_asset_all + self.search_result = self.perm_assets - self.search_result = dict(zip(range(len(user_asset_search)), user_asset_search)) + self.search_result = list(set(self.search_result)) + + def print_search_result(self): color_print('[%-3s] %-12s %-15s %-5s %-10s %s' % ('ID', '主机名', 'IP', '端口', '系统用户', '备注'), 'title') - for index, asset in self.search_result.items(): - # 获取该资产信息 - asset_info = get_asset_info(asset) - # 获取该资产包含的角色 - role = [str(role.name) for role in self.user_perm.get('asset').get(asset).get('role')] - print '[%-3s] %-15s %-15s %-5s %-10s %s' % (index, asset.hostname, asset.ip, asset_info.get('port'), - role, asset.comment) + if hasattr(self.search_result, '__iter__'): + for index, asset in enumerate(self.search_result): + # 获取该资产信息 + asset_info = get_asset_info(asset) + # 获取该资产包含的角色 + role = [str(role.name) for role in self.user_perm.get('asset').get(asset).get('role')] + print '[%-3s] %-15s %-15s %-5s %-10s %s' % (index, asset.hostname, asset.ip, asset_info.get('port'), + role, asset.comment) print + def try_connect(self): + try: + asset = self.search_result[0] + roles = list(self.user_perm.get('asset').get(asset).get('role')) + if len(roles) == 1: + role = roles[0] + elif len(roles) > 1: + print "\033[32m[ID] 系统用户\033[0m" + for index, role in enumerate(roles): + print "[%-2s] %s" % (index, role.name) + print + print "授权系统用户超过1个,请输入ID, q退出" + try: + role_index = raw_input("\033[1;32mID>:\033[0m ").strip() + if role_index == 'q': + return + else: + role = roles[int(role_index)] + except IndexError: + color_print('请输入正确ID', 'red') + return + else: + color_print('没有映射用户', 'red') + return + + ssh_tty = SshTty(login_user, asset, role) + print('Connecting %s ...' % asset.hostname) + ssh_tty.connect() + except (KeyError, ValueError): + color_print('请输入正确ID', 'red') + except ServerError, e: + color_print(e, 'red') + def print_asset_group(self): """ 打印用户授权的资产组 @@ -515,9 +556,6 @@ class Nav(object): 批量执行命令 """ while True: - if not self.user_perm: - self.user_perm = get_group_user_perm(self.user) - roles = self.user_perm.get('role').keys() if len(roles) > 1: # 授权角色数大于1 color_print('[%-2s] %-15s' % ('ID', '系统用户'), 'info') @@ -587,8 +625,6 @@ class Nav(object): def upload(self): while True: - if not self.user_perm: - self.user_perm = get_group_user_perm(self.user) try: print "进入批量上传模式" print "请输入主机名或ansible支持的pattern, 多个主机:分隔 q退出" @@ -640,8 +676,6 @@ class Nav(object): def download(self): while True: - if not self.user_perm: - self.user_perm = get_group_user_perm(self.user) try: print "进入批量下载模式" print "请输入主机名或ansible支持的pattern, 多个主机:分隔,q退出" @@ -723,9 +757,14 @@ def main(): sys.exit(0) if option in ['P', 'p', '\n', '']: nav.search() + nav.print_search_result() continue - if option.startswith('/') or gid_pattern.match(option): + if option.startswith('/'): nav.search(option.lstrip('/')) + nav.print_search_result() + elif gid_pattern.match(option): + nav.get_asset_group_member(str_r=option) + nav.print_search_result() elif option in ['G', 'g']: nav.print_asset_group() continue @@ -741,36 +780,12 @@ def main(): elif option in ['Q', 'q', 'exit']: sys.exit() else: - try: - asset = nav.search_result[int(option)] - roles = nav.user_perm.get('asset').get(asset).get('role') - if len(roles) > 1: - role_check = dict(zip(range(len(roles)), roles)) - print "\033[32m[ID] 系统用户\033[0m" - for index, role in role_check.items(): - print "[%-2s] %s" % (index, role.name) - print - print "授权系统用户超过1个,请输入ID, q退出" - try: - role_index = raw_input("\033[1;32mID>:\033[0m ").strip() - if role_index == 'q': - continue - else: - role = role_check[int(role_index)] - except IndexError: - color_print('请输入正确ID', 'red') - continue - elif len(roles) == 1: - role = list(roles)[0] - else: - color_print('没有映射用户', 'red') - continue - ssh_tty = SshTty(login_user, asset, role) - ssh_tty.connect() - except (KeyError, ValueError): - color_print('请输入正确ID', 'red') - except ServerError, e: - color_print(e, 'red') + nav.search(option) + if len(nav.search_result) == 1: + nav.try_connect() + else: + nav.print_search_result() + except IndexError, e: color_print(e) time.sleep(5)