mirror of https://github.com/jumpserver/jumpserver
58 lines
1.5 KiB
Python
58 lines
1.5 KiB
Python
from contextlib import contextmanager
|
||
|
||
from django.db import connections
|
||
|
||
from common.utils import get_logger
|
||
|
||
logger = get_logger(__file__)
|
||
|
||
|
||
def get_object_if_need(model, pk):
|
||
if not isinstance(pk, model):
|
||
try:
|
||
return model.objects.get(id=pk)
|
||
except model.DoesNotExist as e:
|
||
logger.error(f'DoesNotExist: <{model.__name__}:{pk}> not exist')
|
||
raise e
|
||
return pk
|
||
|
||
|
||
def get_objects_if_need(model, pks):
|
||
if not pks:
|
||
return pks
|
||
if not isinstance(pks[0], model):
|
||
objs = list(model.objects.filter(id__in=pks))
|
||
if len(objs) != len(pks):
|
||
pks = set(pks)
|
||
exists_pks = {o.id for o in objs}
|
||
not_found_pks = ','.join(pks - exists_pks)
|
||
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
|
||
return objs
|
||
return pks
|
||
|
||
|
||
def get_objects(model, pks):
|
||
if not pks:
|
||
return pks
|
||
|
||
objs = list(model.objects.filter(id__in=pks))
|
||
if len(objs) != len(pks):
|
||
pks = set(pks)
|
||
exists_pks = {o.id for o in objs}
|
||
not_found_pks = pks - exists_pks
|
||
logger.error(f'DoesNotExist: <{model.__name__}: {not_found_pks}>')
|
||
return objs
|
||
|
||
|
||
# 复制 django.db.close_old_connections, 因为它没有导出,ide 提示有问题
|
||
def close_old_connections():
|
||
for conn in connections.all():
|
||
conn.close_if_unusable_or_obsolete()
|
||
|
||
|
||
@contextmanager
|
||
def safe_db_connection():
|
||
close_old_connections()
|
||
yield
|
||
close_old_connections()
|