当前数据库版本升级到5,但是升级脚本还需要调整:原来各个表的前缀是ts_,应该保持不变,从v5开始,配置表的名称改为tp_前缀,下一个版本v6则进一步将所有表的前缀改为tp_并且调整字段名称(去除不必要的字段名称前缀)。

pull/32/merge
apexliu 2017-03-23 01:51:36 +08:00
parent ba44d6c951
commit bff60a1f74
7 changed files with 122 additions and 208 deletions

View File

@ -4,21 +4,20 @@ import builtins
import os
import sqlite3
import threading
import datetime
from eom_common.eomcore.logger import log
from .configs import app_cfg
# from .configs import app_cfg
from .database.create import create_and_init
from .database.upgrade import DatabaseUpgrade
cfg = app_cfg()
# cfg = app_cfg()
__all__ = ['get_db', 'DbItem']
# 注意,每次调整数据库结构,必须增加版本号,并且在升级接口中编写对应的升级操作
class TPDatabase:
# 注意,每次调整数据库结构,必须增加版本号,并且在升级接口中编写对应的升级操作
DB_VERSION = 5
DB_TYPE_UNKNOWN = 0
@ -61,7 +60,7 @@ class TPDatabase:
# 看看数据库中是否存在指定的数据表(如果不存在,可能是一个空数据库文件),则可能是一个新安装的系统
# ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}account";'.format(self._table_prefix))
ret = self.is_table_exists('{}group'.format(self._table_prefix))
ret = self.is_table_exists('{}group'.format(self._table_prefix))
if ret is None or not ret:
# if ret is None or ret[0][0] == 0:
log.w('database need create.\n')
@ -70,7 +69,7 @@ class TPDatabase:
# 尝试从配置表中读取当前数据库版本号(如果不存在,说明是比较旧的版本了)
ret = self.query('SELECT `value` FROM `{}config` WHERE `name`="db_ver";'.format(self._table_prefix))
log.w(ret)
# log.w(ret)
if ret is None or 0 == len(ret):
self.current_ver = 1
else:
@ -81,9 +80,9 @@ class TPDatabase:
self.need_upgrade = True
return True
# DO TEST
self.alter_table('ts_account', [['account_id', 'id'], ['account_type', 'type']])
# DO TEST
# self.alter_table('ts_account', [['account_id', 'id'], ['account_type', 'type']])
return True
def is_table_exists(self, table_name):
@ -93,7 +92,7 @@ class TPDatabase:
@return: None or Boolean
"""
if self.db_source['type'] == self.DB_TYPE_SQLITE:
ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}";'.format(table_name))
ret = self.query('SELECT COUNT(*) FROM `sqlite_master` WHERE `type`="table" AND `name`="{}";'.format(table_name))
if ret is None:
return None
if len(ret) == 0:
@ -104,14 +103,24 @@ class TPDatabase:
elif self.db_source['type'] == self.DB_TYPE_MYSQL:
return None
else:
log.e('Unknown database type.\n')
log.e('Unknown database type.\n')
return None
def query(self, sql):
return self._conn_pool.query(sql)
# _start = datetime.datetime.utcnow().timestamp()
ret = self._conn_pool.query(sql)
# _end = datetime.datetime.utcnow().timestamp()
# log.d('[db] {}\n'.format(sql))
# log.d('[db] cost {} seconds.\n'.format(_end - _start))
return ret
def exec(self, sql):
return self._conn_pool.exec(sql)
# _start = datetime.datetime.utcnow().timestamp()
ret = self._conn_pool.exec(sql)
# _end = datetime.datetime.utcnow().timestamp()
# log.d('[db] {}\n'.format(sql))
# log.d('[db] cost {} seconds.\n'.format(_end - _start))
return ret
def create_and_init(self, step_begin, step_end):
if create_and_init(self, step_begin, step_end):
@ -128,61 +137,61 @@ class TPDatabase:
else:
return False
def alter_table(self, table_names, field_names=None):
"""
修改表名称及字段名称
table_name: 如果是string则指定要操作的表如果是list则第一个元素是要操作的表第二个元素是此表改名的目标名称
fields_names: 如果为None则不修改字段名否则应该是一个list其中每个元素是包含两个str的list表示将此list第一个指定的字段改名为第二个指定的名称
@return: None or Boolean
"""
if self.db_source['type'] == self.DB_TYPE_SQLITE:
if not isinstance(table_names, list) and field_names is None:
log.w('nothing to do.\n')
return False
def alter_table(self, table_names, field_names=None):
"""
修改表名称及字段名称
table_name: 如果是string则指定要操作的表如果是list则第一个元素是要操作的表第二个元素是此表改名的目标名称
fields_names: 如果为None则不修改字段名否则应该是一个list其中每个元素是包含两个str的list表示将此list第一个指定的字段改名为第二个指定的名称
@return: None or Boolean
"""
if self.db_source['type'] == self.DB_TYPE_SQLITE:
if not isinstance(table_names, list) and field_names is None:
log.w('nothing to do.\n')
return False
if isinstance(table_names, str):
old_table_name = table_names
new_table_name = table_names
elif isinstance(table_names, list) and len(table_names) == 2:
old_table_name = table_names[0]
new_table_name = table_names[1]
else:
log.w('invalid param.\n')
return False
if isinstance(field_names, list):
for i in field_names:
if not isinstance(i, list) or 2 != len(i):
log.w('invalid param.\n')
return False
if field_names is None:
# 仅数据表改名
return self.exec('ALTER TABLE `{}` RENAME TO `{}`;'.format(old_table_name, new_table_name))
else:
# sqlite不支持字段改名所以需要通过临时表中转一下
# 先获取数据表的字段名列表
ret = self.query('SELECT * FROM `sqlite_master` WHERE `type`="table" AND `name`="{}";'.format(old_table_name))
log.w('-----\n')
log.w(ret[0][4])
log.w('\n')
# 先将数据表改名,成为一个临时表
# tmp_table_name = '{}_sqlite_tmp'.format(old_table_name)
# ret = self.exec('ALTER TABLE `{}` RENAME TO `{}`;'.format(old_table_name, tmp_table_name))
# if ret is None or not ret:
# return ret
pass
elif self.db_source['type'] == self.DB_TYPE_MYSQL:
log.e('mysql not supported yet.\n')
return False
else:
log.e('Unknown database type.\n')
return False
if isinstance(table_names, str):
old_table_name = table_names
new_table_name = table_names
elif isinstance(table_names, list) and len(table_names) == 2:
old_table_name = table_names[0]
new_table_name = table_names[1]
else:
log.w('invalid param.\n')
return False
if isinstance(field_names, list):
for i in field_names:
if not isinstance(i, list) or 2 != len(i):
log.w('invalid param.\n')
return False
if field_names is None:
# 仅数据表改名
return self.exec('ALTER TABLE `{}` RENAME TO `{}`;'.format(old_table_name, new_table_name))
else:
# sqlite不支持字段改名所以需要通过临时表中转一下
# 先获取数据表的字段名列表
ret = self.query('SELECT * FROM `sqlite_master` WHERE `type`="table" AND `name`="{}";'.format(old_table_name))
log.w('-----\n')
log.w(ret)
log.w('\n')
# 先将数据表改名,成为一个临时表
# tmp_table_name = '{}_sqlite_tmp'.format(old_table_name)
# ret = self.exec('ALTER TABLE `{}` RENAME TO `{}`;'.format(old_table_name, tmp_table_name))
# if ret is None or not ret:
# return ret
pass
elif self.db_source['type'] == self.DB_TYPE_MYSQL:
log.e('mysql not supported yet.\n')
return False
else:
log.e('Unknown database type.\n')
return False
class TPDatabasePool:
def __init__(self):
self._locker = threading.RLock()

View File

@ -1,30 +0,0 @@
# -*- coding: utf-8 -*-
import eom_common.eomcore.eom_mysql as mysql
import eom_common.eomcore.eom_sqlite as sqlite
# from eom_app.app.configs import app_cfg
# cfg = app_cfg()
class DbItem(dict):
def load(self, db_item, db_fields):
if len(db_fields) != len(db_item):
raise RuntimeError('!=')
for i in range(len(db_item)):
self[db_fields[i]] = db_item[i]
def __getattr__(self, name):
try:
return self[name]
except KeyError:
raise
def get_db_con():
if False:
sql_exec = mysql.get_mysql_pool().get_tssqlcon()
else:
sql_exec = sqlite.get_sqlite_pool().get_tssqlcon()
return sql_exec

View File

@ -3,13 +3,11 @@
import json
import time
# from .common import *
from eom_app.app.db import get_db, DbItem
# 获取主机列表,包括主机的基本信息
def get_all_host_info_list(_filter, order, limit, with_pwd=False):
# sql_exec = get_db_con()
db = get_db()
_where = ''
@ -126,12 +124,9 @@ def get_all_host_info_list(_filter, order, limit, with_pwd=False):
def get_host_info_list_by_user(_filter, order, limit):
# sql_exec = get_db_con()
db = get_db()
_where = ''
# _where = ''
# _where = 'WHERE ( a.account_name=\'{}\' '.format(uname)
if len(_filter) > 0:
_where = 'WHERE ( '
@ -167,7 +162,6 @@ def get_host_info_list_by_user(_filter, order, limit):
_where += ')'
# http://www.jb51.net/article/46015.htm
field_a = ['auth_id', 'host_id', 'account_name', 'host_auth_id']
field_b = ['host_id', 'host_lock', 'host_ip', 'protocol', 'host_port', 'host_desc', 'group_id', 'host_sys_type']
field_c = ['group_name']
@ -271,7 +265,6 @@ def get_host_info_list_by_user(_filter, order, limit):
def get_group_list():
db = get_db()
field_a = ['group_id', 'group_name']
# sql_exec = get_db_con()
sql = 'SELECT {} FROM `{}group` AS a; '.format(','.join(['`a`.`{}`'.format(i) for i in field_a]), db.table_prefix)
db_ret = db.query(sql)
ret = list()
@ -286,23 +279,6 @@ def get_group_list():
return ret
# def get_config_list():
# try:
# sql_exec = get_db_con()
# field_a = ['name', 'value']
# string_sql = 'SELECT {} FROM ts_config as a ;'.format(','.join(['a.{}'.format(i) for i in field_a]))
# db_ret = sql_exec.ExecProcQuery(string_sql)
# h = dict()
# for item in db_ret:
# x = DbItem()
# x.load(item, ['a_{}'.format(i) for i in field_a])
# h[x.a_name] = x.a_value
#
# return h
# except:
# return None
def update(host_id, kv):
db = get_db()
@ -327,10 +303,7 @@ def update(host_id, kv):
def get_cert_list():
db = get_db()
# http://www.jb51.net/article/46015.htm
field_a = ['cert_id', 'cert_name', 'cert_pub', 'cert_pri', 'cert_desc']
sql = 'SELECT {} FROM `{}key` AS a;'.format(','.join(['`a`.`{}`'.format(i) for i in field_a]), db.table_prefix)
db_ret = db.query(sql)
@ -425,7 +398,7 @@ def delete_host(host_list):
sql = 'DELETE FROM `{}auth_info` WHERE `host_id`={};'.format(db.table_prefix, host_id)
ret = db.exec(sql)
str_sql = 'DELETE FROM `{}auth` WHERE `host_id`={};'.format(db.table_prefix, host_id)
sql = 'DELETE FROM `{}auth` WHERE `host_id`={};'.format(db.table_prefix, host_id)
ret = db.exec(sql)
return True
@ -477,7 +450,7 @@ def delete_group(group_id):
def update_group(group_id, group_name):
db = get_db()
sql = 'UPDATE {}group SET `group_name`="{}" ' \
sql = 'UPDATE `{}group` SET `group_name`="{}" ' \
'WHERE `group_id`={};'.format(db.table_prefix, group_name, int(group_id))
return db.exec(sql)
@ -502,8 +475,8 @@ def get_host_auth_info(host_auth_id):
sql = 'SELECT {},{} ' \
'FROM `{}auth_info` AS a ' \
'LEFT JOIN `{}host_info` AS b ON `a`.`host_id`=`b`.`host_id` ' \
'WHERE `a`.`id`={};'.format(','.join(['a.{}'.format(i) for i in field_a]),
','.join(['b.{}'.format(i) for i in field_b]),
'WHERE `a`.`id`={};'.format(','.join(['`a`.`{}`'.format(i) for i in field_a]),
','.join(['`b`.`{}`'.format(i) for i in field_b]),
db.table_prefix, db.table_prefix,
host_auth_id)
db_ret = db.query(sql)
@ -555,41 +528,6 @@ def get_host_auth_info(host_auth_id):
return h
# def update_host_extend_info(host_id, args):
# db = get_db()
#
# ip = args['ip']
# port = int(args['port'])
# user_name = args['user_name']
# user_pwd = args['user_pwd']
# cert_id = int(args['cert_id'])
# pro_type = int(args['pro_type'])
# sys_type = int(args['sys_type'])
# group_id = args['group_id']
# host_desc = args['desc']
# host_auth_mode = int(args['host_auth_mode'])
# host_encrypt = 1
#
# # if len(user_pwd) == 0 and 0 == cert_id:
# # return False
# if 0 == len(user_pwd):
# str_sql = 'UPDATE ts_host_info SET host_ip = \'{}\', ' \
# 'host_pro_port = {}, host_user_name = \'{}\', ' \
# 'cert_id = {}, host_pro_type = {},host_sys_type={}, group_id={},host_auth_mode={},host_encrypt={}, ' \
# 'host_desc=\'{}\' WHERE host_id = {}'.format(
# ip, port, user_name, cert_id, pro_type, sys_type, group_id, host_auth_mode, host_encrypt, host_desc, host_id)
#
# else:
# str_sql = 'UPDATE ts_host_info SET host_ip = \'{}\', ' \
# 'host_pro_port = {}, host_user_name = \'{}\', host_user_pwd = \'{}\', ' \
# 'cert_id = {}, host_pro_type = {},host_sys_type={}, group_id={},host_auth_mode={},host_encrypt={}, ' \
# 'host_desc=\'{}\' WHERE host_id = {}'.format(
# ip, port, user_name, user_pwd, cert_id, pro_type, sys_type, group_id, host_auth_mode, host_encrypt, host_desc, host_id)
#
# ret = sql_exec.ExecProcNonQuery(str_sql)
# return ret
def get_cert_info(cert_id):
db = get_db()
sql = 'SELECT `cert_pri` FROM `{}key` WHERE `cert_id`={};'.format(db.table_prefix, cert_id)
@ -607,11 +545,11 @@ def sys_user_list(host_id, with_pwd=True, host_auth_id=0):
if host_auth_id == 0:
sql = 'SELECT {} ' \
'FROM `{}auth_info` AS a ' \
'WHERE `a`.`host_id`={};'.format(','.join(['a.{}'.format(i) for i in field_a]), db.table_prefix, int(host_id))
'WHERE `a`.`host_id`={};'.format(','.join(['`a`.`{}`'.format(i) for i in field_a]), db.table_prefix, int(host_id))
else:
sql = 'SELECT {} ' \
'FROM `{}auth_info` AS a ' \
'WHERE `a`.`id`={} and `a`.`host_id`={};'.format(','.join(['a.{}'.format(i) for i in field_a]), db.table_prefix, int(host_auth_id), int(host_id))
'WHERE `a`.`id`={} and `a`.`host_id`={};'.format(','.join(['`a`.`{}`'.format(i) for i in field_a]), db.table_prefix, int(host_auth_id), int(host_id))
db_ret = db.query(sql)
@ -678,15 +616,15 @@ def sys_user_add(args):
log_time = GetNowTime()
if auth_mode == 1:
sql = 'INSERT INTO {}auth_info (host_id, auth_mode, user_name, user_pswd, user_param, encrypt, cert_id, log_time) ' \
sql = 'INSERT INTO `{}auth_info` (`host_id`,`auth_mode`,`user_name`,`user_pswd`,`user_param`,`encrypt`,`cert_id`,`log_time`) ' \
'VALUES ({},{},"{}","{}","{}",{}, {},"{}")' \
''.format(db.table_prefix, host_id, auth_mode, user_name, user_pswd, user_param, encrypt, 0, log_time)
elif auth_mode == 2:
sql = 'INSERT INTO {}auth_info (host_id, auth_mode, user_name, user_pswd, user_param, encrypt, cert_id, log_time) ' \
sql = 'INSERT INTO `{}auth_info` (`host_id`,`auth_mode`,`user_name`,`user_pswd`,`user_param`,`encrypt`,`cert_id`,`log_time`) ' \
'VALUES ({},{},"{}","{}","{}",{},{},"{}")' \
''.format(db.table_prefix, host_id, auth_mode, user_name, '', user_param, encrypt, cert_id, log_time)
elif auth_mode == 0:
sql = 'INSERT INTO {}auth_info (host_id, auth_mode, user_name, user_pswd, user_param, encrypt, cert_id, log_time) ' \
sql = 'INSERT INTO `{}auth_info` (`host_id`,`auth_mode`,`user_name`,`user_pswd`,`user_param`,`encrypt`,`cert_id`,`log_time`) ' \
'VALUES ({},{},"{}","{}","{}",{},{},"{}")' \
''.format(db.table_prefix, host_id, auth_mode, user_name, '', user_param, encrypt, 0, log_time)
ret = db.exec(sql)
@ -750,10 +688,10 @@ def get_auth_info(auth_id):
'LEFT JOIN `{}auth_info` AS c ON `a`.`host_auth_id`=`c`.`id` ' \
'LEFT JOIN `{}account` AS d ON `a`.`account_name`=`d`.`account_name` ' \
'WHERE `a`.`auth_id`={};' \
''.format(','.join(['a.{}'.format(i) for i in field_a]),
','.join(['b.{}'.format(i) for i in field_b]),
','.join(['c.{}'.format(i) for i in field_c]),
','.join(['d.{}'.format(i) for i in field_d]),
''.format(','.join(['`a`.`{}`'.format(i) for i in field_a]),
','.join(['`b`.`{}`'.format(i) for i in field_b]),
','.join(['`c`.`{}`'.format(i) for i in field_c]),
','.join(['`d`.`{}`'.format(i) for i in field_d]),
db.table_prefix, db.table_prefix, db.table_prefix, db.table_prefix,
auth_id)

View File

@ -1,19 +1,17 @@
# -*- coding: utf-8 -*-
import datetime
import os
import shutil
import struct
from eom_app.app.configs import app_cfg
from eom_app.app.db import get_db
from eom_common.eomcore.logger import log
from .common import *
cfg = app_cfg()
from eom_common.eomcore.utils import timestamp_utc_now
def read_record_head(record_id):
record_path = os.path.join(cfg.data_path, 'replay', 'ssh', '{:06d}'.format(int(record_id)))
record_path = os.path.join(app_cfg().data_path, 'replay', 'ssh', '{:06d}'.format(int(record_id)))
header_file_path = os.path.join(record_path, 'tp-ssh.tpr')
file = None
try:
@ -128,7 +126,7 @@ def read_record_head(record_id):
def read_record_info(record_id, file_id):
record_path = os.path.join(cfg.data_path, 'replay', 'ssh', '{:06d}'.format(int(record_id)))
record_path = os.path.join(app_cfg().data_path, 'replay', 'ssh', '{:06d}'.format(int(record_id)))
file_info = os.path.join(record_path, 'tp-ssh.{:03d}'.format(int(file_id)))
file = None
try:
@ -190,20 +188,21 @@ def delete_log(log_list):
try:
where = list()
for item in log_list:
where.append(' id={}'.format(item))
where.append(' `id`={}'.format(item))
str_sql = 'DELETE FROM ts_log WHERE{};'.format(' OR'.join(where))
ret = get_db_con().ExecProcNonQuery(str_sql)
db = get_db()
sql = 'DELETE FROM `{}log` WHERE{};'.format(db.table_prefix, ' OR'.join(where))
ret = db.exec(sql)
if not ret:
return False
for item in log_list:
log_id = int(item)
try:
record_path = os.path.join(cfg.data_path, 'replay', 'ssh', '{:06d}'.format(log_id))
record_path = os.path.join(app_cfg().data_path, 'replay', 'ssh', '{:06d}'.format(log_id))
if os.path.exists(record_path):
shutil.rmtree(record_path)
record_path = os.path.join(cfg.data_path, 'replay', 'rdp', '{:06d}'.format(log_id))
record_path = os.path.join(app_cfg().data_path, 'replay', 'rdp', '{:06d}'.format(log_id))
if os.path.exists(record_path):
shutil.rmtree(record_path)
except Exception:
@ -216,28 +215,27 @@ def delete_log(log_list):
def session_fix():
try:
sql_exec = get_db_con()
str_sql = 'UPDATE ts_log SET ret_code=7 WHERE ret_code=0;'
return sql_exec.ExecProcNonQuery(str_sql)
db = get_db()
sql = 'UPDATE `{}log` SET `ret_code`=7 WHERE `ret_code`=0;'.format(db.table_prefix)
return db.exec(sql)
except:
return False
def session_begin(sid, acc_name, host_ip, sys_type, host_port, auth_mode, user_name, protocol):
try:
_now = int(datetime.datetime.utcnow().timestamp())
sql_exec = get_db_con()
db = get_db()
sql = 'INSERT INTO `{}log` (`session_id`,`account_name`,`host_ip`,`sys_type`,`host_port`,`auth_type`,`user_name`,`ret_code`,`begin_time`,`end_time`,`log_time`,`protocol`) ' \
'VALUES ("{}","{}","{}",{},{},{},"{}",{},{},{},"{}",{});' \
''.format(db.table_prefix,
sid, acc_name, host_ip, sys_type, host_port, auth_mode, user_name, 0, timestamp_utc_now(), 0, '', protocol)
str_sql = 'INSERT INTO ts_log (session_id, account_name,host_ip,sys_type, host_port,auth_type, user_name,ret_code,begin_time,end_time,log_time, protocol) ' \
'VALUES (\'{}\',\'{}\',\'{}\',{},{},{},\'{}\',{},{},{},\'{}\',{});'.format(
sid, acc_name, host_ip, sys_type, host_port, auth_mode, user_name, 0, _now, 0, '', protocol)
ret = sql_exec.ExecProcNonQuery(str_sql)
ret = db.exec(sql)
if not ret:
return -101
str_sql = 'SELECT last_insert_rowid()'
db_ret = sql_exec.ExecProcQuery(str_sql)
sql = 'SELECT last_insert_rowid()'
db_ret = db.query(sql)
if db_ret is None:
return -102
user_id = db_ret[0][0]
@ -249,9 +247,8 @@ def session_begin(sid, acc_name, host_ip, sys_type, host_port, auth_mode, user_n
def session_end(record_id, ret_code):
try:
_now = int(datetime.datetime.utcnow().timestamp())
sql_exec = get_db_con()
str_sql = 'UPDATE ts_log SET ret_code={}, end_time={} WHERE id={};'.format(ret_code, _now, record_id)
return sql_exec.ExecProcNonQuery(str_sql)
db = get_db()
sql = 'UPDATE `{}log` SET `ret_code`={}, `end_time`={} WHERE `id`={};'.format(db.table_prefix, int(ret_code), timestamp_utc_now(), int(record_id))
return db.exec(sql)
except:
return False

View File

@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
from eom_app.app.configs import app_cfg
cfg = app_cfg()
# from eom_app.app.configs import app_cfg
# cfg = app_cfg()
# def get_config_list():
# try:
@ -33,7 +33,7 @@ cfg = app_cfg()
# ret = sql_exec.ExecProcNonQuery(str_sql)
#
# return ret
def get_config_list():
print(cfg.core)
return cfg.core
#
# def get_config_list():
# print(cfg.core)
# return cfg.core

View File

@ -138,7 +138,7 @@ def add_user(user_name, user_pwd, user_desc):
def alloc_host(user_name, host_list):
db = get_db()
field_a = ['host_id']
sql = 'SELECT {} FROM `{}auth` AS a WHERE `account_name`="{}";'.format(','.join(['a.{}'.format(i) for i in field_a]), db.table_prefix, user_name)
sql = 'SELECT {} FROM `{}auth` AS a WHERE `account_name`="{}";'.format(','.join(['`a`.`{}`'.format(i) for i in field_a]), db.table_prefix, user_name)
db_ret = db.query(sql)
ret = dict()
for item in db_ret:

View File

@ -169,6 +169,6 @@ class UniqueId():
self._id += 1
return self._id
unique_id = UniqueId()
del UniqueId