diff --git a/apps/authentication/mixins.py b/apps/authentication/mixins.py index 247e2e4a0..bbc2203ed 100644 --- a/apps/authentication/mixins.py +++ b/apps/authentication/mixins.py @@ -18,7 +18,7 @@ from django.contrib.auth import ( from django.core.exceptions import ImproperlyConfigured from django.shortcuts import reverse, redirect, get_object_or_404 -from common.utils import get_request_ip, get_logger, bulk_get, FlashMessageUtil +from common.utils import get_request_ip_or_data, get_request_ip, get_logger, bulk_get, FlashMessageUtil from acls.models import LoginACL from users.models import User from users.utils import LoginBlockUtil, MFABlockUtils, LoginIpBlockUtil @@ -92,13 +92,12 @@ auth.authenticate = authenticate class CommonMixin: request: Request + _ip = '' def get_request_ip(self): ip = '' - if hasattr(self.request, 'data'): - ip = self.request.data.get('remote_addr', '') - ip = ip or get_request_ip(self.request) - return ip + if not self._ip: + self._ip = get_request_ip_or_data(self.request) def raise_credential_error(self, error): raise self.partial_credential_error(error=error) diff --git a/apps/common/utils/common.py b/apps/common/utils/common.py index a105d50c7..c198a0ad2 100644 --- a/apps/common/utils/common.py +++ b/apps/common/utils/common.py @@ -76,6 +76,7 @@ def setattr_bulk(seq, key, value): def set_attr(obj): setattr(obj, key, value) return obj + return map(set_attr, seq) @@ -97,12 +98,12 @@ def capacity_convert(size, expect='auto', rate=1000): rate_mapping = ( ('K', rate), ('KB', rate), - ('M', rate**2), - ('MB', rate**2), - ('G', rate**3), - ('GB', rate**3), - ('T', rate**4), - ('TB', rate**4), + ('M', rate ** 2), + ('MB', rate ** 2), + ('G', rate ** 3), + ('GB', rate ** 3), + ('T', rate ** 4), + ('TB', rate ** 4), ) rate_mapping = OrderedDict(rate_mapping) @@ -117,7 +118,7 @@ def capacity_convert(size, expect='auto', rate=1000): if expect == 'auto': for unit, rate_ in rate_mapping.items(): - if rate > std_size/rate_ >= 1 or unit == "T": + if rate > std_size / rate_ >= 1 or unit == "T": expect = unit break @@ -152,19 +153,28 @@ def is_uuid(seq): def get_request_ip(request): - x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '').split(',') + x_real_ip = request.META.get('HTTP_X_REAL_IP', '') + if x_real_ip: + return x_real_ip + x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR', '').split(',') if x_forwarded_for and x_forwarded_for[0]: login_ip = x_forwarded_for[0] - else: - login_ip = request.META.get('REMOTE_ADDR', '') + return login_ip + + login_ip = request.META.get('REMOTE_ADDR', '') return login_ip def get_request_ip_or_data(request): + from common.permissions import ServiceAccountSignaturePermission + ip = '' - if hasattr(request, 'data'): - ip = request.data.get('remote_addr', '') + + if hasattr(request, 'data') and request.data.get('remote_addr', ''): + permission = ServiceAccountSignaturePermission() + if permission.has_permission(request, None): + ip = request.data.get('remote_addr', '') ip = ip or get_request_ip(request) return ip @@ -195,6 +205,7 @@ def with_cache(func): res = func(*args, **kwargs) cache[key] = res return res + return wrapper @@ -216,6 +227,7 @@ def timeit(func): msg = "End call {}, using: {:.1f}ms".format(name, using) logger.debug(msg) return result + return wrapper @@ -310,7 +322,7 @@ class Time: def print(self): last, *timestamps = self._timestamps for timestamp, msg in zip(timestamps, self._msgs): - logger.debug(f'TIME_IT: {msg} {timestamp-last}') + logger.debug(f'TIME_IT: {msg} {timestamp - last}') last = timestamp @@ -366,7 +378,7 @@ def pretty_string(data: str, max_length=128, ellipsis_str='...'): def group_by_count(it, count): - return [it[i:i+count] for i in range(0, len(it), count)] + return [it[i:i + count] for i in range(0, len(it), count)] def test_ip_connectivity(host, port, timeout=0.5):