mirror of https://github.com/tp4a/teleport
安装和升级向导功能完成,开始测试。
parent
9b59d6ec99
commit
112cf02146
|
@ -72,8 +72,12 @@ class WebServerCore:
|
|||
# db_path = os.path.join(cfg.data_path, 'ts_db.db')
|
||||
get_sqlite_pool().init(cfg.data_path)
|
||||
|
||||
get_db().init_sqlite(os.path.join(cfg.data_path, 'ts_db.db'))
|
||||
if get_db().need_create or get_db().need_upgrade:
|
||||
# get_db().init_sqlite(os.path.join(cfg.data_path, 'ts_db.db'))
|
||||
_db = get_db()
|
||||
if not _db.init({'type': _db.DB_TYPE_SQLITE, 'file': os.path.join(cfg.data_path, 'ts_db.db')}):
|
||||
log.e('initialize database interface failed.\n')
|
||||
return False
|
||||
if _db.need_create or _db.need_upgrade:
|
||||
cfg.app_mode = APP_MODE_MAINTENANCE
|
||||
else:
|
||||
cfg.app_mode = APP_MODE_NORMAL
|
||||
|
|
|
@ -15,16 +15,16 @@ def _db_exec(db, step_begin, step_end, msg, sql):
|
|||
step_end(_step, 0)
|
||||
|
||||
|
||||
def create_and_init(db, step_begin, step_end, db_ver):
|
||||
def create_and_init(db, step_begin, step_end):
|
||||
try:
|
||||
_db_exec(db, step_begin, step_end, '创建表 account', """CREATE TABLE `{}account` (
|
||||
`account_id` integer PRIMARY KEY AUTOINCREMENT,
|
||||
`account_type` int(11) DEFAULT 0,
|
||||
`account_name` varchar(32) DEFAULT NULL,
|
||||
`account_pwd` varchar(32) DEFAULT NULL,
|
||||
`account_status` int(11) DEFAULT 0,
|
||||
`account_lock` int(11) DEFAULT 0,
|
||||
`account_desc` varchar(255)
|
||||
`account_id` integer PRIMARY KEY AUTOINCREMENT,
|
||||
`account_type` int(11) DEFAULT 0,
|
||||
`account_name` varchar(32) DEFAULT NULL,
|
||||
`account_pwd` varchar(32) DEFAULT NULL,
|
||||
`account_status` int(11) DEFAULT 0,
|
||||
`account_lock` int(11) DEFAULT 0,
|
||||
`account_desc` varchar(255)
|
||||
);""".format(db.table_prefix))
|
||||
|
||||
_db_exec(db, step_begin, step_end, '创建表 auth', """CREATE TABLE `{}auth`(
|
||||
|
@ -104,7 +104,7 @@ PRIMARY KEY (`name` ASC)
|
|||
|
||||
_db_exec(db, step_begin, step_end,
|
||||
'设定数据库版本',
|
||||
'INSERT INTO `{}config` VALUES ("db_ver", "{}");'.format(db.table_prefix, db_ver)
|
||||
'INSERT INTO `{}config` VALUES ("db_ver", "{}");'.format(db.table_prefix, db.DB_VERSION)
|
||||
)
|
||||
|
||||
return True
|
||||
|
|
|
@ -1,24 +1,506 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
|
||||
from eom_app.app.util import sec_generate_password
|
||||
from eom_common.eomcore.logger import log
|
||||
|
||||
|
||||
def _db_exec(db, step_begin, step_end, msg, sql):
|
||||
_step = step_begin(msg)
|
||||
class DatabaseUpgrade:
|
||||
def __init__(self, db, step_begin, step_end):
|
||||
self.db = db
|
||||
self.step_begin = step_begin
|
||||
self.step_end = step_end
|
||||
|
||||
ret = db.exec(sql)
|
||||
if not ret:
|
||||
step_end(_step, -1)
|
||||
raise RuntimeError('[FAILED] {}'.format(sql))
|
||||
else:
|
||||
step_end(_step, 0)
|
||||
def do_upgrade(self):
|
||||
for i in range(self.db.DB_VERSION):
|
||||
if self.db.current_ver < i + 1:
|
||||
_f_name = '_upgrade_to_v{}'.format(i + 1)
|
||||
if _f_name in dir(self):
|
||||
if self.__getattribute__(_f_name)():
|
||||
self.db.current_ver = i + 1
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def upgrade_database(db, step_begin, step_end, db_ver):
|
||||
try:
|
||||
pass
|
||||
return True
|
||||
except:
|
||||
log.e('ERROR')
|
||||
return False
|
||||
|
||||
def _upgrade_to_v2(self):
|
||||
# 服务端升级到版本1.2.102.3时,管理员后台和普通用户后台合并了,数据库略有调整
|
||||
|
||||
_step = self.step_begin('检查数据库版本v2...')
|
||||
|
||||
try:
|
||||
# 判断依据:
|
||||
# 如果存在名为 ${prefix}sys_user 的表,说明是旧版本,需要升级
|
||||
|
||||
ret = self.db.is_table_exists('sys_user')
|
||||
if ret is None:
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
elif not ret:
|
||||
self.step_end(_step, 0, '跳过 v1 到 v2 的升级操作')
|
||||
return True
|
||||
self.step_end(_step, 0, '需要升级到v2')
|
||||
|
||||
if self.db.db_source['type'] == self.db.DB_TYPE_SQLITE:
|
||||
_step = self.step_begin(' - 备份数据库文件')
|
||||
_bak_file = '{}.before-v1-to-v2'.format(self.db.db_source['file'])
|
||||
if not os.path.exists(_bak_file):
|
||||
shutil.copy(self.db.db_source['file'], _bak_file)
|
||||
self.step_end(_step, 0)
|
||||
|
||||
# 将原来的普通用户的account_type从 0 改为 1
|
||||
_step = self.step_begin(' - 调整用户账号类型...')
|
||||
if not self.db.exec('UPDATE `{}account` SET `account_type`=1 WHERE `account_type`=0;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
else:
|
||||
self.step_end(_step, 0)
|
||||
|
||||
# 将原来的管理员合并到用户账号表中
|
||||
_step = self.step_begin(' - 合并管理员和普通用户账号...')
|
||||
db_ret = self.db.query('SELECT * FROM `{}sys_user`;'.format(self.db.table_prefix))
|
||||
if db_ret is None:
|
||||
self.step_end(_step, 0)
|
||||
return True
|
||||
|
||||
for i in range(len(db_ret)):
|
||||
user_name = db_ret[i][1]
|
||||
user_pwd = db_ret[i][2]
|
||||
|
||||
if not self.db.exec("""INSERT INTO `{}account`
|
||||
(`account_type`, `account_name`, `account_pwd`, `account_status`, `account_lock`, `account_desc`)
|
||||
VALUES (100,"{}","{}",0,0,"{}");""".format(self.db.table_prefix, user_name, user_pwd, '超级管理员')):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
# 移除旧的表(暂时改名而不是真的删除)
|
||||
# str_sql = 'ALTER TABLE ts_sys_user RENAME TO _bak_ts_sys_user;'
|
||||
_step = self.step_begin(' - 移除不再使用的数据表...')
|
||||
if not self.db.exec('ALTER TABLE `{}sys_user` RENAME TO `_bak_ts_sys_user`;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, 0)
|
||||
return False
|
||||
else:
|
||||
self.step_end(_step, -1)
|
||||
|
||||
return True
|
||||
|
||||
except:
|
||||
log.e('failed.\n')
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
def _upgrade_to_v3(self):
|
||||
# 服务端升级到版本1.5.217.9时,为了支持一机多用户多协议,数据库结构有较大程度改动
|
||||
|
||||
_step = self.step_begin('检查数据库版本v3...')
|
||||
|
||||
try:
|
||||
# 判断依据:
|
||||
# 如果不存在名为 ts_host_info 的表,说明是旧版本,需要升级
|
||||
|
||||
ret = self.db.is_table_exists('host_info')
|
||||
if ret is None:
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
elif ret:
|
||||
self.step_end(_step, 0, '跳过 v2 到 v3 的升级操作')
|
||||
return True
|
||||
self.step_end(_step, 0, '需要升级到v3')
|
||||
|
||||
# log.v('upgrade database to version 1.5.217.9 ...\n')
|
||||
# bak_file = '{}.before-1.5.217.9'.format(db_file)
|
||||
# if not os.path.exists(bak_file):
|
||||
# shutil.copy(db_file, bak_file)
|
||||
if self.db.db_source['type'] == self.db.DB_TYPE_SQLITE:
|
||||
_step = self.step_begin(' - 备份数据库文件')
|
||||
_bak_file = '{}.before-v2-to-v3'.format(self.db.db_source['file'])
|
||||
if not os.path.exists(_bak_file):
|
||||
shutil.copy(self.db.db_source['file'], _bak_file)
|
||||
self.step_end(_step, 0)
|
||||
|
||||
_step = self.step_begin(' - 调整数据表...')
|
||||
if not self.db.exec('ALTER TABLE `{}auth` ADD `host_auth_id` INTEGER;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec('UPDATE `{}auth` SET `host_auth_id`=`host_id`;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}log` ADD `protocol` INTEGER;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec('UPDATE `{}log` SET `protocol`=1 WHERE `sys_type`=1;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec('UPDATE `{}log` SET `protocol`=2 WHERE `sys_type`=2;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec('UPDATE ``{}log`` SET `ret_code`=9999 WHERE `ret_code`=0;'.format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
# 新建两个表,用于拆分原来的 ts_host 表
|
||||
if not self.db.exec("""CREATE TABLE `{}host_info` (
|
||||
`host_id` integer PRIMARY KEY AUTOINCREMENT,
|
||||
`group_id` int(11) DEFAULT 0,
|
||||
`host_sys_type` int(11) DEFAULT 1,
|
||||
`host_ip` varchar(32) DEFAULT '',
|
||||
`pro_port` varchar(256) NULL,
|
||||
`host_lock` int(11) DEFAULT 0,
|
||||
`host_desc` varchar(128) DEFAULT ''
|
||||
);""".format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
if not self.db.exec("""CREATE TABLE `{}auth_info` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
`host_id` INTEGER,
|
||||
`pro_type` INTEGER,
|
||||
`auth_mode` INTEGER,
|
||||
`user_name` varchar(256),
|
||||
`user_pswd` varchar(256),
|
||||
`cert_id` INTEGER,
|
||||
`encrypt` INTEGER,
|
||||
`log_time` varchar(60)
|
||||
);""".format(self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
# 将原来的 ts_host 表改名
|
||||
if not self.db.exec('ALTER TABLE `{}host` RENAME TO `_bak_{}host;`'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
self.step_end(_step, 0)
|
||||
_step = self.step_begin(' - 调整数据内容...')
|
||||
|
||||
# 从原来 ts_host 表中查询出所有数据
|
||||
db_ret = self.db.query('SELECT * FROM `_bak_{}host;`'.format(self.db.table_prefix))
|
||||
if db_ret is not None:
|
||||
for i in range(len(db_ret)):
|
||||
host_id = db_ret[i][0]
|
||||
group_id = db_ret[i][1]
|
||||
host_sys_type = db_ret[i][2]
|
||||
host_ip = db_ret[i][3]
|
||||
host_pro_port = db_ret[i][4]
|
||||
host_user_name = db_ret[i][5]
|
||||
host_user_pwd = db_ret[i][6]
|
||||
host_pro_type = db_ret[i][7]
|
||||
cert_id = db_ret[i][8]
|
||||
host_lock = db_ret[i][9]
|
||||
host_encrypt = db_ret[i][10]
|
||||
host_auth_mode = db_ret[i][11]
|
||||
host_desc = db_ret[i][12]
|
||||
|
||||
_pro_port = {}
|
||||
_pro_port['ssh'] = {}
|
||||
_pro_port['ssh']['enable'] = 0
|
||||
_pro_port['ssh']['port'] = 22
|
||||
_pro_port['rdp'] = {}
|
||||
_pro_port['rdp']['enable'] = 0
|
||||
_pro_port['rdp']['port'] = 3389
|
||||
|
||||
if host_pro_type == 1:
|
||||
_pro_port['rdp']['enable'] = 1
|
||||
_pro_port['rdp']['port'] = host_pro_port
|
||||
elif host_pro_type == 2:
|
||||
_pro_port['ssh']['enable'] = 1
|
||||
_pro_port['ssh']['port'] = host_pro_port
|
||||
pro_port = json.dumps(_pro_port)
|
||||
|
||||
sql = 'INSERT INTO `{}host_info` (`host_id`, `group_id`, `host_sys_type`, `host_ip`, `pro_port`, `host_lock`, `host_desc`) ' \
|
||||
'VALUES ({}, {}, {}, "{}", "{}", {}, "{}");'.format(self.db.table_prefix, host_id, group_id, host_sys_type, host_ip, pro_port, host_lock, host_desc)
|
||||
if not self.db.exec(sql):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
sql = 'INSERT INTO `{}auth_info` (`host_id`, `pro_type`, `auth_mode`, `user_name`, `user_pswd`, `cert_id`, `encrypt`, `log_time`) ' \
|
||||
'VALUES ({}, {}, {}, "{}", "{}", {}, {}, "{}");'.format(self.db.table_prefix, host_id, host_pro_type, host_auth_mode, host_user_name, host_user_pwd, cert_id, host_encrypt, '1')
|
||||
if not self.db.exec(sql):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
self.step_end(_step, 0)
|
||||
return True
|
||||
|
||||
except:
|
||||
log.e('failed.\n')
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
def _upgrade_to_v4(self):
|
||||
_step = self.step_begin('检查数据库版本v4...')
|
||||
|
||||
# 服务端升级到版本1.6.224.3时,加入telnet支持,数据库有调整
|
||||
try:
|
||||
# 判断依据:
|
||||
# 如果ts_host_info表中还有pro_port字段,说明是旧版本,需要升级
|
||||
|
||||
db_ret = self.db.query('SELECT `pro_port` FROM `{}host_info` LIMIT 0;'.format(self.db.table_prefix))
|
||||
if db_ret is None:
|
||||
self.step_end(_step, 0, '跳过 v3 到 v4 的升级操作')
|
||||
return True
|
||||
self.step_end(_step, 0, '需要升级到v4')
|
||||
|
||||
if self.db.db_source['type'] == self.db.DB_TYPE_SQLITE:
|
||||
_step = self.step_begin(' - 备份数据库文件')
|
||||
_bak_file = '{}.before-v3-to-v4'.format(self.db.db_source['file'])
|
||||
if not os.path.exists(_bak_file):
|
||||
shutil.copy(self.db.db_source['file'], _bak_file)
|
||||
self.step_end(_step, 0)
|
||||
|
||||
# 如果ts_config表中没有ts_server_telnet_port项,则增加默认值52389
|
||||
db_ret = self.db.query('SELECT * FROM `{}config` WHERE `name`="ts_server_telnet_port";'.format(self.db.table_prefix))
|
||||
if len(db_ret) == 0:
|
||||
if not self.db.exec('INSERT INTO `{}config` (`name`, `value`) VALUES ("ts_server_telnet_port", "52389");'.format(self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
auth_info_ret = self.db.query('SELECT `id`, `host_id`, `pro_type`, `auth_mode`, `user_name`, `user_pswd`, `cert_id`, `encrypt`, `log_time` FROM `{}auth_info`;'.format(self.db.table_prefix))
|
||||
auth_ret = self.db.query('SELECT `auth_id`, `account_name`, `host_id`, `host_auth_id` FROM `{}auth`;'.format(self.db.table_prefix))
|
||||
|
||||
max_host_id = 0
|
||||
new_host_info = []
|
||||
new_auth_info = []
|
||||
new_auth = []
|
||||
|
||||
# 从原来的表中查询数据
|
||||
host_info_ret = self.db.query('SELECT `host_id`, `group_id`, `host_sys_type`, `host_ip`, `pro_port`, `host_lock`, `host_desc` FROM {}host_info;'.format(self.db.table_prefix))
|
||||
if host_info_ret is None:
|
||||
return True
|
||||
# 先找出最大的host_id,这样如果要拆分一个host,就知道新的host_id应该是多少了
|
||||
for i in range(len(host_info_ret)):
|
||||
if host_info_ret[i][0] > max_host_id:
|
||||
max_host_id = host_info_ret[i][0]
|
||||
max_host_id += 1
|
||||
|
||||
# 然后构建新的host列表
|
||||
for i in range(len(host_info_ret)):
|
||||
host_info = {}
|
||||
host_info_alt = None
|
||||
|
||||
protocol = json.loads(host_info_ret[i][4])
|
||||
host_info['host_id'] = host_info_ret[i][0]
|
||||
host_info['group_id'] = host_info_ret[i][1]
|
||||
host_info['host_sys_type'] = host_info_ret[i][2]
|
||||
host_info['host_ip'] = host_info_ret[i][3]
|
||||
host_info['host_lock'] = host_info_ret[i][5]
|
||||
host_info['host_desc'] = host_info_ret[i][6]
|
||||
host_info['_old_host_id'] = host_info_ret[i][0]
|
||||
host_info['host_port'] = 0
|
||||
host_info['protocol'] = 0
|
||||
|
||||
have_rdp = False
|
||||
have_ssh = False
|
||||
if auth_info_ret is not None:
|
||||
for j in range(len(auth_info_ret)):
|
||||
if auth_info_ret[j][1] == host_info['host_id']:
|
||||
if auth_info_ret[j][2] == 1: # 用到了此主机的RDP
|
||||
have_rdp = True
|
||||
elif auth_info_ret[j][2] == 2: # 用到了此主机的SSH
|
||||
have_ssh = True
|
||||
|
||||
if have_rdp and have_ssh:
|
||||
# 需要拆分
|
||||
host_info['protocol'] = 1
|
||||
host_info['host_port'] = protocol['rdp']['port']
|
||||
|
||||
host_info_alt = {}
|
||||
host_info_alt['host_id'] = max_host_id
|
||||
max_host_id += 1
|
||||
host_info_alt['group_id'] = host_info_ret[i][1]
|
||||
host_info_alt['host_sys_type'] = host_info_ret[i][2]
|
||||
host_info_alt['host_ip'] = host_info_ret[i][3]
|
||||
host_info_alt['host_lock'] = host_info_ret[i][5]
|
||||
host_info_alt['host_desc'] = host_info_ret[i][6]
|
||||
host_info_alt['_old_host_id'] = host_info_ret[i][0]
|
||||
host_info_alt['host_port'] = protocol['ssh']['port']
|
||||
host_info_alt['protocol'] = 2
|
||||
elif have_rdp:
|
||||
host_info['protocol'] = 1
|
||||
host_info['host_port'] = protocol['rdp']['port']
|
||||
elif have_ssh:
|
||||
host_info['host_port'] = protocol['ssh']['port']
|
||||
host_info['protocol'] = 2
|
||||
|
||||
new_host_info.append(host_info)
|
||||
if host_info_alt is not None:
|
||||
new_host_info.append(host_info_alt)
|
||||
|
||||
# print('=====================================')
|
||||
# for i in range(len(new_host_info)):
|
||||
# print(new_host_info[i])
|
||||
|
||||
# 现在有了新的ts_host_info表,重构ts_auth_info表
|
||||
# 'SELECT id, host_id, pro_type, auth_mode, user_name, user_pswd, cert_id, encrypt, log_time FROM ts_auth_info;'
|
||||
if auth_info_ret is not None:
|
||||
for i in range(len(auth_info_ret)):
|
||||
auth_info = {}
|
||||
auth_info['id'] = auth_info_ret[i][0]
|
||||
auth_info['auth_mode'] = auth_info_ret[i][3]
|
||||
auth_info['user_name'] = auth_info_ret[i][4]
|
||||
auth_info['user_pswd'] = auth_info_ret[i][5]
|
||||
auth_info['cert_id'] = auth_info_ret[i][6]
|
||||
auth_info['encrypt'] = auth_info_ret[i][7]
|
||||
auth_info['log_time'] = auth_info_ret[i][8]
|
||||
auth_info['user_param'] = 'ogin:\nassword:'
|
||||
found = False
|
||||
for j in range(len(new_host_info)):
|
||||
if auth_info_ret[i][1] == new_host_info[j]['_old_host_id'] and auth_info_ret[i][2] == new_host_info[j]['protocol']:
|
||||
found = True
|
||||
auth_info['host_id'] = new_host_info[j]['host_id']
|
||||
auth_info['_old_host_id'] = new_host_info[j]['_old_host_id']
|
||||
break
|
||||
if found:
|
||||
new_auth_info.append(auth_info)
|
||||
|
||||
# for i in range(len(new_auth_info)):
|
||||
# print(new_auth_info[i])
|
||||
|
||||
# 最后重构ts_auth表
|
||||
if auth_ret is not None:
|
||||
for i in range(len(auth_ret)):
|
||||
auth = {}
|
||||
auth['auth_id'] = auth_ret[i][0]
|
||||
auth['account_name'] = auth_ret[i][1]
|
||||
found = False
|
||||
for j in range(len(new_auth_info)):
|
||||
if auth_ret[i][2] == new_auth_info[j]['_old_host_id'] and auth_ret[i][3] == new_auth_info[j]['id']:
|
||||
found = True
|
||||
auth['host_id'] = new_auth_info[j]['host_id']
|
||||
auth['host_auth_id'] = new_auth_info[j]['id']
|
||||
break
|
||||
if found:
|
||||
new_auth.append(auth)
|
||||
|
||||
# for i in range(len(new_auth)):
|
||||
# print(new_auth[i])
|
||||
|
||||
# 将整理好的数据写入新的临时表
|
||||
# 先创建三个临时表
|
||||
if not self.db.exec("""CREATE TABLE `{}auth_tmp` (
|
||||
`auth_id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`account_name` varchar(256),
|
||||
`host_id` INTEGER,
|
||||
`host_auth_id` int(11) NOT NULL
|
||||
);""".format(self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec("""CREATE TABLE `{}host_info_tmp` (
|
||||
`host_id` integer PRIMARY KEY AUTOINCREMENT,
|
||||
`group_id` int(11) DEFAULT 0,
|
||||
`host_sys_type` int(11) DEFAULT 1,
|
||||
`host_ip` varchar(32) DEFAULT '',
|
||||
`host_port` int(11) DEFAULT 0,
|
||||
`protocol` int(11) DEFAULT 0,
|
||||
`host_lock` int(11) DEFAULT 0,
|
||||
`host_desc` DEFAULT ''
|
||||
);""".format(self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec("""CREATE TABLE `{}auth_info_tmp` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT NOT NULL,
|
||||
`host_id` INTEGER,
|
||||
`auth_mode` INTEGER,
|
||||
`user_name` varchar(256),
|
||||
`user_pswd` varchar(256),
|
||||
`user_param` varchar(256),
|
||||
`cert_id` INTEGER,
|
||||
`encrypt` INTEGER,
|
||||
`log_time` varchar(60)
|
||||
);""".format(self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
for i in range(len(new_host_info)):
|
||||
sql = 'INSERT INTO `{}host_info_tmp` (`host_id`, `group_id`, `host_sys_type`, `host_ip`, `host_port`, `protocol`, `host_lock`, `host_desc`) ' \
|
||||
'VALUES ({}, {}, {}, \'{}\', {}, {}, {}, "{}");'.format(
|
||||
self.db.table_prefix,
|
||||
new_host_info[i]['host_id'], new_host_info[i]['group_id'], new_host_info[i]['host_sys_type'],
|
||||
new_host_info[i]['host_ip'], new_host_info[i]['host_port'], new_host_info[i]['protocol'],
|
||||
new_host_info[i]['host_lock'], new_host_info[i]['host_desc']
|
||||
)
|
||||
if not self.db.exec(sql):
|
||||
return False
|
||||
|
||||
for i in range(len(new_auth_info)):
|
||||
sql = 'INSERT INTO `{}auth_info_tmp` (`id`, `host_id`, `auth_mode`, `user_name`, `user_pswd`, `user_param`, `cert_id`, `encrypt`, `log_time`) ' \
|
||||
'VALUES ({}, {}, {}, "{}", "{}", "{}", {}, {}, "{}");'.format(
|
||||
self.db.table_prefix,
|
||||
new_auth_info[i]['id'], new_auth_info[i]['host_id'], new_auth_info[i]['auth_mode'],
|
||||
new_auth_info[i]['user_name'], new_auth_info[i]['user_pswd'], new_auth_info[i]['user_param'],
|
||||
new_auth_info[i]['cert_id'], new_auth_info[i]['encrypt'], '1'
|
||||
)
|
||||
# print(str_sql)
|
||||
if not self.db.exec(sql):
|
||||
return False
|
||||
|
||||
for i in range(len(new_auth)):
|
||||
sql = 'INSERT INTO `{}auth_tmp` (`auth_id`, `account_name`, `host_id`, `host_auth_id`) ' \
|
||||
'VALUES ({}, \'{}\', {}, {});'.format(
|
||||
self.db.table_prefix,
|
||||
new_auth[i]['auth_id'], new_auth[i]['account_name'], new_auth[i]['host_id'], new_auth[i]['host_auth_id']
|
||||
)
|
||||
if not self.db.exec(sql):
|
||||
return False
|
||||
|
||||
# 表改名
|
||||
if not self.db.exec('ALTER TABLE `{}auth` RENAME TO `__bak_{}auth`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}auth_info` RENAME TO `__bak_{}auth_info`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}host_info` RENAME TO `__bak_{}host_info`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}auth_tmp` RENAME TO `{}auth`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}auth_info_tmp` RENAME TO `{}auth_info`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
if not self.db.exec('ALTER TABLE `{}host_info_tmp` RENAME TO `{}host_info`;'.format(self.db.table_prefix, self.db.table_prefix)):
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
except:
|
||||
log.e('failed.\n')
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
def _upgrade_to_v5(self):
|
||||
_step = self.step_begin('检查数据库版本v5...')
|
||||
|
||||
# 服务端升级到版本2.1.0.1时,为解决将来数据库升级的问题,在 ts_config 表中加入 db_ver 指明当前数据结构版本
|
||||
try:
|
||||
# 判断依据:
|
||||
# 如果 config 表中不存在名为db_ver的数据,说明是旧版本,需要升级
|
||||
|
||||
db_ret = self.db.query('SELECT `value` FROM `{}config` WHERE `name`="db_ver";'.format(self.db.table_prefix))
|
||||
if db_ret is None:
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
if len(db_ret) > 0 and int(db_ret[0][0]) >= self.db.DB_VERSION:
|
||||
self.step_end(_step, 0, '跳过 v4 到 v5 的升级操作')
|
||||
return True
|
||||
self.step_end(_step, 0, '需要升级到v5')
|
||||
|
||||
_step = self.step_begin(' - 更新数据库版本号')
|
||||
if not self.db.exec('INSERT INTO `{}config` VALUES ("db_ver", "{}");'.format(self.db.table_prefix, self.db.DB_VERSION)):
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
else:
|
||||
self.step_end(_step, 0)
|
||||
return True
|
||||
|
||||
except:
|
||||
log.e('failed.\n')
|
||||
self.step_end(_step, -1)
|
||||
return False
|
||||
|
||||
|
||||
|
|
|
@ -8,7 +8,7 @@ import threading
|
|||
from eom_common.eomcore.logger import log
|
||||
from .configs import app_cfg
|
||||
from .database.create import create_and_init
|
||||
from .database.upgrade import upgrade_database
|
||||
from .database.upgrade import DatabaseUpgrade
|
||||
|
||||
cfg = app_cfg()
|
||||
|
||||
|
@ -16,49 +16,117 @@ __all__ = ['get_db']
|
|||
|
||||
|
||||
# 注意,每次调整数据库结构,必须增加版本号,并且在升级接口中编写对应的升级操作
|
||||
TELEPORT_DATABASE_VERSION = 10
|
||||
|
||||
|
||||
class TPDatabase:
|
||||
DB_VERSION = 10
|
||||
|
||||
DB_TYPE_UNKNOWN = 0
|
||||
DB_TYPE_SQLITE = 1
|
||||
DB_TYPE_MYSQL = 2
|
||||
|
||||
def __init__(self):
|
||||
if '__teleport_db__' in builtins.__dict__:
|
||||
raise RuntimeError('TPDatabase object exists, you can not create more than one instance.')
|
||||
|
||||
self._table_prefix = ''
|
||||
|
||||
self.db_source = {'type': self.DB_TYPE_UNKNOWN}
|
||||
self.need_create = False # 数据尚未存在,需要创建
|
||||
self.need_upgrade = False # 数据库已存在但版本较低,需要升级
|
||||
self.current_ver = 0
|
||||
|
||||
self._table_prefix = ''
|
||||
self._conn_pool = None
|
||||
|
||||
@property
|
||||
def table_prefix(self):
|
||||
return self._table_prefix
|
||||
|
||||
def init_mysql(self):
|
||||
# NOT SUPPORTED YET
|
||||
pass
|
||||
def init(self, db_source):
|
||||
self.db_source = db_source
|
||||
|
||||
def init_sqlite(self, db_file):
|
||||
self._table_prefix = 'ts_'
|
||||
self._conn_pool = TPSqlitePool(db_file)
|
||||
if db_source['type'] == self.DB_TYPE_MYSQL:
|
||||
log.e('MySQL not supported yet.')
|
||||
return False
|
||||
elif db_source['type'] == self.DB_TYPE_SQLITE:
|
||||
self._table_prefix = 'ts_'
|
||||
self._conn_pool = TPSqlitePool(db_source['file'])
|
||||
|
||||
if not os.path.exists(db_file):
|
||||
if not os.path.exists(db_source['file']):
|
||||
log.w('database need create.\n')
|
||||
self.need_create = True
|
||||
return True
|
||||
else:
|
||||
log.e('Unknown database type: {}'.format(db_source['type']))
|
||||
return False
|
||||
|
||||
# 看看数据库中是否存在指定的数据表(如果不存在,可能是一个空数据库文件),则可能是一个新安装的系统
|
||||
# ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}account";'.format(self._table_prefix))
|
||||
ret = self.is_table_exists('group')
|
||||
if ret is None or not ret:
|
||||
# if ret is None or ret[0][0] == 0:
|
||||
log.w('database need create.\n')
|
||||
self.need_create = True
|
||||
return
|
||||
return True
|
||||
|
||||
# 看看数据库中是否存在用户表(如果不存在,可能是一个空数据库文件),则可能是一个新安装的系统
|
||||
ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}account";'.format(self._table_prefix))
|
||||
if ret is None or ret[0][0] == 0:
|
||||
log.w('database need create.\n')
|
||||
self.need_create = True
|
||||
return
|
||||
# 尝试从配置表中读取当前数据库版本号(如果不存在,说明是比较旧的版本了)
|
||||
ret = self.query('SELECT `value` FROM `{}config` WHERE `name`="db_ver";'.format(self._table_prefix))
|
||||
log.w(ret)
|
||||
if ret is None or 0 == len(ret):
|
||||
self.current_ver = 1
|
||||
else:
|
||||
self.current_ver = int(ret[0][0])
|
||||
|
||||
# 尝试从配置表中读取当前数据库版本号(如果不存在,说明是比较旧的版本了,则置为0)
|
||||
ret = self.query('SELECT `value` FROM {}config WHERE `name`="db_ver";'.format(self._table_prefix))
|
||||
if ret is None or 0 == len(ret) or ret[0][0] < TELEPORT_DATABASE_VERSION:
|
||||
if self.current_ver < self.DB_VERSION:
|
||||
log.w('database need upgrade.\n')
|
||||
self.need_upgrade = True
|
||||
return True
|
||||
|
||||
return True
|
||||
|
||||
# def init_sqlite(self, db_file):
|
||||
# self._table_prefix = 'ts_'
|
||||
# self._conn_pool = TPSqlitePool(db_file)
|
||||
#
|
||||
# if not os.path.exists(db_file):
|
||||
# log.w('database need create.\n')
|
||||
# self.need_create = True
|
||||
# return
|
||||
#
|
||||
# # 看看数据库中是否存在指定的数据表(如果不存在,可能是一个空数据库文件),则可能是一个新安装的系统
|
||||
# # ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}account";'.format(self._table_prefix))
|
||||
# ret = self.is_table_exists('group')
|
||||
# if ret is None or not ret:
|
||||
# # if ret is None or ret[0][0] == 0:
|
||||
# log.w('database need create.\n')
|
||||
# self.need_create = True
|
||||
# return
|
||||
#
|
||||
# # 尝试从配置表中读取当前数据库版本号(如果不存在,说明是比较旧的版本了)
|
||||
# ret = self.query('SELECT `value` FROM {}config WHERE `name`="db_ver";'.format(self._table_prefix))
|
||||
# if ret is None or 0 == len(ret) or ret[0][0] < TELEPORT_DATABASE_VERSION:
|
||||
# log.w('database need upgrade.\n')
|
||||
# self.need_upgrade = True
|
||||
|
||||
def is_table_exists(self, table_name):
|
||||
"""
|
||||
判断指定的表是否存在
|
||||
@param table_name: string
|
||||
@return: None or Boolean
|
||||
"""
|
||||
# return self._conn_pool.is_table_exists(table_name)
|
||||
if self.db_source['type'] == self.DB_TYPE_SQLITE:
|
||||
ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}{}";'.format(self._table_prefix, table_name))
|
||||
if ret is None:
|
||||
return None
|
||||
if len(ret) == 0:
|
||||
return False
|
||||
if ret[0][0] == 0:
|
||||
return False
|
||||
return True
|
||||
elif self.db_source['type'] == self.DB_TYPE_MYSQL:
|
||||
return None
|
||||
else:
|
||||
return None
|
||||
|
||||
def query(self, sql):
|
||||
return self._conn_pool.query(sql)
|
||||
|
@ -67,16 +135,15 @@ class TPDatabase:
|
|||
return self._conn_pool.exec(sql)
|
||||
|
||||
def create_and_init(self, step_begin, step_end):
|
||||
step_begin('准备创建数据表')
|
||||
if create_and_init(self, step_begin, step_end, TELEPORT_DATABASE_VERSION):
|
||||
if create_and_init(self, step_begin, step_end):
|
||||
self.need_create = False
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
def upgrade_database(self, step_begin, step_end):
|
||||
step_begin('准备升级数据表')
|
||||
if upgrade_database(self, step_begin, step_end, TELEPORT_DATABASE_VERSION):
|
||||
if DatabaseUpgrade(self, step_begin, step_end).do_upgrade():
|
||||
# if upgrade_database(self, step_begin, step_end):
|
||||
self.need_upgrade = False
|
||||
return True
|
||||
else:
|
||||
|
@ -88,6 +155,12 @@ class TPDatabasePool:
|
|||
self._locker = threading.RLock()
|
||||
self._connections = dict()
|
||||
|
||||
# def is_table_exists(self, table_name):
|
||||
# _conn = self._get_connect()
|
||||
# if _conn is None:
|
||||
# return None
|
||||
# return self._is_table_exists(_conn, table_name)
|
||||
|
||||
def query(self, sql):
|
||||
_conn = self._get_connect()
|
||||
if _conn is None:
|
||||
|
@ -114,6 +187,9 @@ class TPDatabasePool:
|
|||
def _do_connect(self):
|
||||
return None
|
||||
|
||||
# def _is_table_exists(self, conn, table_name):
|
||||
# return None
|
||||
|
||||
def _do_query(self, conn, sql):
|
||||
return None
|
||||
|
||||
|
@ -133,6 +209,14 @@ class TPSqlitePool(TPDatabasePool):
|
|||
log.e('[sqlite] can not connect, does the database file correct?')
|
||||
return None
|
||||
|
||||
# def _is_table_exists(self, conn, table_name):
|
||||
# ret = self._do_query(conn, 'SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}{}";'.format(self._table_prefix, table_name))
|
||||
# # ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}{}";'.format(self._table_prefix, table_name))
|
||||
# if ret is None or ret[0][0] == 0:
|
||||
# return False
|
||||
# else:
|
||||
# return True
|
||||
|
||||
def _do_query(self, conn, sql):
|
||||
cursor = conn.cursor()
|
||||
try:
|
||||
|
|
|
@ -92,7 +92,7 @@ class RpcThreadManage:
|
|||
|
||||
def _create_db(self, tid):
|
||||
def _step_begin(msg):
|
||||
self._step_begin(tid, msg)
|
||||
return self._step_begin(tid, msg)
|
||||
|
||||
def _step_end(sid, code, msg=None):
|
||||
self._step_end(tid, sid, code, msg)
|
||||
|
@ -106,7 +106,7 @@ class RpcThreadManage:
|
|||
|
||||
def _upgrade_db(self, tid):
|
||||
def _step_begin(msg):
|
||||
self._step_begin(tid, msg)
|
||||
return self._step_begin(tid, msg)
|
||||
|
||||
def _step_end(sid, code, msg=None):
|
||||
self._step_end(tid, sid, code, msg)
|
||||
|
@ -132,12 +132,10 @@ class RpcThreadManage:
|
|||
self._threads[tid]['steps'][sid]['code'] = code
|
||||
self._threads[tid]['steps'][sid]['stat'] = 0 # 0 表示此步骤已完成
|
||||
if msg is not None:
|
||||
self._threads[tid]['steps'][sid]['msg'] = msg
|
||||
self._threads[tid]['steps'][sid]['msg'] = '{}{}'.format(self._threads[tid]['steps'][sid]['msg'], msg)
|
||||
except:
|
||||
pass
|
||||
|
||||
return len(self._threads[tid]['steps']) - 1
|
||||
|
||||
def _thread_end(self, tid):
|
||||
with self._lock:
|
||||
if tid in self._threads:
|
||||
|
|
|
@ -28,8 +28,8 @@ class eom_sqlite:
|
|||
self._conn = None
|
||||
|
||||
def connect(self):
|
||||
if not os.path.exists(self._db_file):
|
||||
return None
|
||||
# if not os.path.exists(self._db_file):
|
||||
# return None
|
||||
try:
|
||||
self._conn = sqlite3.connect(self._db_file)
|
||||
except:
|
||||
|
|
|
@ -39,8 +39,8 @@ if os.path.exists(os.path.join(PATH_APP_ROOT, '..', '..', 'share', 'etc')):
|
|||
|
||||
# 检查操作系统,目前仅支持Win和Linux
|
||||
PLATFORM = platform.system().lower()
|
||||
if PLATFORM not in ['windows', 'linux']:
|
||||
print('Teleport WEB Server support Windows and Linux only.')
|
||||
if PLATFORM not in ['windows', 'linux', 'darwin']:
|
||||
print('TELEPORT WEB Server does not support `{}` platform yet.'.format(PLATFORM))
|
||||
sys.exit(1)
|
||||
|
||||
BITS = 'x64'
|
||||
|
|
Loading…
Reference in New Issue