mirror of https://github.com/tp4a/teleport
Major changes to the web frontend; not yet complete.
parent
2d0ce5da20
commit
51b143c828
@@ -1,8 +1,8 @@
 # mako/__init__.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php


-__version__ = '1.0.3'
+__version__ = '1.0.6'
@@ -1,5 +1,5 @@
 # mako/_ast_util.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/ast.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/cache.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/cmd.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/codegen.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -5,6 +5,7 @@ py3k = sys.version_info >= (3, 0)
 py33 = sys.version_info >= (3, 3)
 py2k = sys.version_info < (3,)
 py26 = sys.version_info >= (2, 6)
+py27 = sys.version_info >= (2, 7)
 jython = sys.platform.startswith('java')
 win32 = sys.platform.startswith('win')
 pypy = hasattr(sys, 'pypy_version_info')
@@ -1,5 +1,5 @@
 # mako/exceptions.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # ext/autohandler.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # ext/babelplugin.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # ext/preprocessors.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # ext/pygmentplugin.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # ext/turbogears.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/filters.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/lexer.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php

@@ -95,31 +95,37 @@ class Lexer(object):
         # (match and "TRUE" or "FALSE")
         return match

-    def parse_until_text(self, *text):
+    def parse_until_text(self, watch_nesting, *text):
         startpos = self.match_position
         text_re = r'|'.join(text)
         brace_level = 0
+        paren_level = 0
+        bracket_level = 0
         while True:
             match = self.match(r'#.*\n')
             if match:
                 continue
-            match = self.match(r'(\"\"\"|\'\'\'|\"|\')((?<!\\)\\\1|.)*?\1',
+            match = self.match(r'(\"\"\"|\'\'\'|\"|\')[^\\]*?(\\.[^\\]*?)*\1',
                                re.S)
             if match:
                 continue
             match = self.match(r'(%s)' % text_re)
-            if match:
-                if match.group(1) == '}' and brace_level > 0:
-                    brace_level -= 1
-                    continue
+            if match and not (watch_nesting
+                              and (brace_level > 0 or paren_level > 0
+                                   or bracket_level > 0)):
                 return \
                     self.text[startpos:
                               self.match_position - len(match.group(1))],\
                     match.group(1)
-            match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
+            elif not match:
+                match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
             if match:
                 brace_level += match.group(1).count('{')
                 brace_level -= match.group(1).count('}')
+                paren_level += match.group(1).count('(')
+                paren_level -= match.group(1).count(')')
+                bracket_level += match.group(1).count('[')
+                bracket_level -= match.group(1).count(']')
                 continue
             raise exceptions.SyntaxException(
                 "Expected: %s" %

@@ -368,7 +374,7 @@ class Lexer(object):
         match = self.match(r"<%(!)?")
         if match:
             line, pos = self.matched_lineno, self.matched_charpos
-            text, end = self.parse_until_text(r'%>')
+            text, end = self.parse_until_text(False, r'%>')
             # the trailing newline helps
             # compiler.parse() not complain about indentation
             text = adjust_whitespace(text) + "\n"

@@ -384,9 +390,9 @@ class Lexer(object):
         match = self.match(r"\${")
         if match:
             line, pos = self.matched_lineno, self.matched_charpos
-            text, end = self.parse_until_text(r'\|', r'}')
+            text, end = self.parse_until_text(True, r'\|', r'}')
             if end == '|':
-                escapes, end = self.parse_until_text(r'}')
+                escapes, end = self.parse_until_text(True, r'}')
             else:
                 escapes = ""
             text = text.replace('\r\n', '\n')
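Note: the `watch_nesting` flag added above makes `parse_until_text()` track unbalanced `{}`, `()` and `[]` while scanning for a terminator, so a `${...}` expression may now contain nested literals. A minimal sketch of what this enables (the template string is illustrative; behavior per Mako 1.0.6):

    from mako.template import Template

    # The '}' inside the dict literal no longer ends the ${...} expression,
    # because the lexer only accepts a terminator at nesting level zero.
    t = Template("${ {'key': 'value'}['key'] }")
    print(t.render())  # -> value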
@@ -1,5 +1,5 @@
 # mako/lookup.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php

@@ -96,7 +96,7 @@ class TemplateLookup(TemplateCollection):
     .. sourcecode:: python

         lookup = TemplateLookup(["/path/to/templates"])
-        some_template = lookup.get_template("/admin_index.mako")
+        some_template = lookup.get_template("/index.html")

     The :class:`.TemplateLookup` can also be given :class:`.Template` objects
     programatically using :meth:`.put_string` or :meth:`.put_template`:

@@ -180,7 +180,8 @@ class TemplateLookup(TemplateCollection):
                  enable_loop=True,
                  input_encoding=None,
                  preprocessor=None,
-                 lexer_cls=None):
+                 lexer_cls=None,
+                 include_error_handler=None):

         self.directories = [posixpath.normpath(d) for d in
                             util.to_list(directories, ())

@@ -203,6 +204,7 @@ class TemplateLookup(TemplateCollection):
         self.template_args = {
             'format_exceptions': format_exceptions,
             'error_handler': error_handler,
+            'include_error_handler': include_error_handler,
             'disable_unicode': disable_unicode,
             'bytestring_passthrough': bytestring_passthrough,
             'output_encoding': output_encoding,
@@ -1,5 +1,5 @@
 # mako/parsetree.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/pygen.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/pyparser.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1,5 +1,5 @@
 # mako/runtime.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php

@@ -749,7 +749,16 @@ def _include_file(context, uri, calling_uri, **kwargs):
     (callable_, ctx) = _populate_self_namespace(
         context._clean_inheritance_tokens(),
         template)
-    callable_(ctx, **_kwargs_for_include(callable_, context._data, **kwargs))
+    kwargs = _kwargs_for_include(callable_, context._data, **kwargs)
+    if template.include_error_handler:
+        try:
+            callable_(ctx, **kwargs)
+        except Exception:
+            result = template.include_error_handler(ctx, compat.exception_as())
+            if not result:
+                compat.reraise(*sys.exc_info())
+    else:
+        callable_(ctx, **kwargs)


 def _inherit_from(context, uri, calling_uri):
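Note: per the new branch above, an `include_error_handler` is invoked with the include's context and the active exception; a truthy return value suppresses the re-raise. A hedged sketch of wiring one up (the handler and directory path are illustrative):

    from mako.lookup import TemplateLookup

    def on_include_error(context, error):
        # Write a placeholder instead of failing the whole outer page;
        # returning True tells Mako not to re-raise the exception.
        context.write('<!-- include failed: %s -->' % type(error).__name__)
        return True

    lookup = TemplateLookup(directories=['/path/to/templates'],
                            include_error_handler=on_include_error)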
@@ -1,5 +1,5 @@
 # mako/template.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php

@@ -109,6 +109,11 @@ class Template(object):
         completes. Is used to provide custom error-rendering
         functions.

+        .. seealso::
+
+            :paramref:`.Template.include_error_handler` - include-specific
+            error handler function
+
     :param format_exceptions: if ``True``, exceptions which occur during
         the render phase of this template will be caught and
         formatted into an HTML error page, which then becomes the

@@ -129,6 +134,16 @@ class Template(object):
        import will not appear as the first executed statement in the generated
        code and will therefore not have the desired effect.

+    :param include_error_handler: An error handler that runs when this template
+        is included within another one via the ``<%include>`` tag, and raises an
+        error. Compare to the :paramref:`.Template.error_handler` option.
+
+        .. versionadded:: 1.0.6
+
+        .. seealso::
+
+            :paramref:`.Template.error_handler` - top-level error handler function
+
     :param input_encoding: Encoding of the template's source code. Can
        be used in lieu of the coding comment. See
        :ref:`usage_unicode` as well as :ref:`unicode_toplevel` for

@@ -171,7 +186,7 @@ class Template(object):

         from mako.template import Template
         mytemplate = Template(
-            filename="admin_index.mako",
+            filename="index.html",
             module_directory="/path/to/modules",
             module_writer=module_writer
         )

@@ -243,7 +258,8 @@ class Template(object):
                  future_imports=None,
                  enable_loop=True,
                  preprocessor=None,
-                 lexer_cls=None):
+                 lexer_cls=None,
+                 include_error_handler=None):
         if uri:
             self.module_id = re.sub(r'\W', "_", uri)
             self.uri = uri

@@ -329,6 +345,7 @@ class Template(object):
         self.callable_ = self.module.render_body
         self.format_exceptions = format_exceptions
         self.error_handler = error_handler
+        self.include_error_handler = include_error_handler
         self.lookup = lookup

         self.module_directory = module_directory

@@ -475,6 +492,14 @@ class Template(object):

         return DefTemplate(self, getattr(self.module, "render_%s" % name))

+    def list_defs(self):
+        """return a list of defs in the template.
+
+        .. versionadded:: 1.0.4
+
+        """
+        return [i[7:] for i in dir(self.module) if i[:7] == 'render_']
+
     def _get_def_callable(self, name):
         return getattr(self.module, "render_%s" % name)

@@ -520,6 +545,7 @@ class ModuleTemplate(Template):
                  cache_type=None,
                  cache_dir=None,
                  cache_url=None,
+                 include_error_handler=None,
                  ):
         self.module_id = re.sub(r'\W', "_", module._template_uri)
         self.uri = module._template_uri

@@ -551,6 +577,7 @@ class ModuleTemplate(Template):
         self.callable_ = self.module.render_body
         self.format_exceptions = format_exceptions
         self.error_handler = error_handler
+        self.include_error_handler = include_error_handler
         self.lookup = lookup
         self._setup_cache_args(
             cache_impl, cache_enabled, cache_args,

@@ -571,6 +598,7 @@ class DefTemplate(Template):
         self.encoding_errors = parent.encoding_errors
         self.format_exceptions = parent.format_exceptions
         self.error_handler = parent.error_handler
+        self.include_error_handler = parent.include_error_handler
         self.enable_loop = parent.enable_loop
         self.lookup = parent.lookup
         self.bytestring_passthrough = parent.bytestring_passthrough
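Note: a quick sketch of the `list_defs()` API added above — it simply strips the `render_` prefix from the compiled module's callables ('body' always appears, since every template has a render_body):

    from mako.template import Template

    t = Template('<%def name="header()">H</%def><%def name="footer()">F</%def>')
    print(sorted(t.list_defs()))  # ['body', 'footer', 'header']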
@@ -1,5 +1,5 @@
 # mako/util.py
-# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
+# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
 #
 # This module is part of Mako and is released under
 # the MIT License: http://www.opensource.org/licenses/mit-license.php
@@ -1 +0,0 @@
-__version__ = '1.3.5'
@@ -1,12 +0,0 @@
-# API Backwards compatibility
-
-from pymemcache.client.base import Client  # noqa
-from pymemcache.client.base import PooledClient  # noqa
-
-from pymemcache.exceptions import MemcacheError  # noqa
-from pymemcache.exceptions import MemcacheClientError  # noqa
-from pymemcache.exceptions import MemcacheUnknownCommandError  # noqa
-from pymemcache.exceptions import MemcacheIllegalInputError  # noqa
-from pymemcache.exceptions import MemcacheServerError  # noqa
-from pymemcache.exceptions import MemcacheUnknownError  # noqa
-from pymemcache.exceptions import MemcacheUnexpectedCloseError  # noqa
[File diff suppressed because it is too large]
@@ -1,333 +0,0 @@
-import socket
-import time
-import logging
-
-from pymemcache.client.base import Client, PooledClient, _check_key
-from pymemcache.client.rendezvous import RendezvousHash
-
-logger = logging.getLogger(__name__)
-
-
-class HashClient(object):
-    """
-    A client for communicating with a cluster of memcached servers
-    """
-    def __init__(
-        self,
-        servers,
-        hasher=RendezvousHash,
-        serializer=None,
-        deserializer=None,
-        connect_timeout=None,
-        timeout=None,
-        no_delay=False,
-        socket_module=socket,
-        key_prefix=b'',
-        max_pool_size=None,
-        lock_generator=None,
-        retry_attempts=2,
-        retry_timeout=1,
-        dead_timeout=60,
-        use_pooling=False,
-        ignore_exc=False,
-    ):
-        """
-        Constructor.
-
-        Args:
-          servers: list(tuple(hostname, port))
-          hasher: optional class three functions ``get_node``, ``add_node``,
-                  and ``remove_node``
-                  defaults to Rendezvous (HRW) hash.
-
-          use_pooling: use py:class:`.PooledClient` as the default underlying
-                       class. ``max_pool_size`` and ``lock_generator`` can
-                       be used with this. default: False
-
-          retry_attempts: Amount of times a client should be tried before it
-                          is marked dead and removed from the pool.
-          retry_timeout (float): Time in seconds that should pass between retry
-                                 attempts.
-          dead_timeout (float): Time in seconds before attempting to add a node
-                                back in the pool.
-
-        Further arguments are interpreted as for :py:class:`.Client`
-        constructor.
-
-        The default ``hasher`` is using a pure python implementation that can
-        be significantly improved performance wise by switching to a C based
-        version. We recommend using ``python-clandestined`` if having a C
-        dependency is acceptable.
-        """
-        self.clients = {}
-        self.retry_attempts = retry_attempts
-        self.retry_timeout = retry_timeout
-        self.dead_timeout = dead_timeout
-        self.use_pooling = use_pooling
-        self.key_prefix = key_prefix
-        self.ignore_exc = ignore_exc
-        self._failed_clients = {}
-        self._dead_clients = {}
-        self._last_dead_check_time = time.time()
-
-        self.hasher = hasher()
-
-        self.default_kwargs = {
-            'connect_timeout': connect_timeout,
-            'timeout': timeout,
-            'no_delay': no_delay,
-            'socket_module': socket_module,
-            'key_prefix': key_prefix,
-            'serializer': serializer,
-            'deserializer': deserializer,
-        }
-
-        if use_pooling is True:
-            self.default_kwargs.update({
-                'max_pool_size': max_pool_size,
-                'lock_generator': lock_generator
-            })
-
-        for server, port in servers:
-            self.add_server(server, port)
-
-    def add_server(self, server, port):
-        key = '%s:%s' % (server, port)
-
-        if self.use_pooling:
-            client = PooledClient(
-                (server, port),
-                **self.default_kwargs
-            )
-        else:
-            client = Client((server, port), **self.default_kwargs)
-
-        self.clients[key] = client
-        self.hasher.add_node(key)
-
-    def remove_server(self, server, port):
-        dead_time = time.time()
-        self._failed_clients.pop((server, port))
-        self._dead_clients[(server, port)] = dead_time
-        key = '%s:%s' % (server, port)
-        self.hasher.remove_node(key)
-
-    def _get_client(self, key):
-        _check_key(key, self.key_prefix)
-        if len(self._dead_clients) > 0:
-            current_time = time.time()
-            ldc = self._last_dead_check_time
-            # we have dead clients and we have reached the
-            # timeout retry
-            if current_time - ldc > self.dead_timeout:
-                for server, dead_time in self._dead_clients.items():
-                    if current_time - dead_time > self.dead_timeout:
-                        logger.debug(
-                            'bringing server back into rotation %s',
-                            server
-                        )
-                        self.add_server(*server)
-                        self._last_dead_check_time = current_time
-
-        server = self.hasher.get_node(key)
-        # We've ran out of servers to try
-        if server is None:
-            if self.ignore_exc is True:
-                return
-            raise Exception('All servers seem to be down right now')
-
-        client = self.clients[server]
-        return client
-
-    def _safely_run_func(self, client, func, default_val, *args, **kwargs):
-        try:
-            if client.server in self._failed_clients:
-                # This server is currently failing, lets check if it is in
-                # retry or marked as dead
-                failed_metadata = self._failed_clients[client.server]
-
-                # we haven't tried our max amount yet, if it has been enough
-                # time lets just retry using it
-                if failed_metadata['attempts'] < self.retry_attempts:
-                    failed_time = failed_metadata['failed_time']
-                    if time.time() - failed_time > self.retry_timeout:
-                        logger.debug(
-                            'retrying failed server: %s', client.server
-                        )
-                        result = func(*args, **kwargs)
-                        # we were successful, lets remove it from the failed
-                        # clients
-                        self._failed_clients.pop(client.server)
-                        return result
-                    return default_val
-                else:
-                    # We've reached our max retry attempts, we need to mark
-                    # the sever as dead
-                    logger.debug('marking server as dead: %s', client.server)
-                    self.remove_server(*client.server)
-
-            result = func(*args, **kwargs)
-            return result
-
-        # Connecting to the server fail, we should enter
-        # retry mode
-        except socket.error:
-            # This client has never failed, lets mark it for failure
-            if (
-                client.server not in self._failed_clients and
-                self.retry_attempts > 0
-            ):
-                self._failed_clients[client.server] = {
-                    'failed_time': time.time(),
-                    'attempts': 0,
-                }
-            # We aren't allowing any retries, we should mark the server as
-            # dead immediately
-            elif (
-                client.server not in self._failed_clients and
-                self.retry_attempts <= 0
-            ):
-                self._failed_clients[client.server] = {
-                    'failed_time': time.time(),
-                    'attempts': 0,
-                }
-                logger.debug("marking server as dead %s", client.server)
-                self.remove_server(*client.server)
-            # This client has failed previously, we need to update the metadata
-            # to reflect that we have attempted it again
-            else:
-                failed_metadata = self._failed_clients[client.server]
-                failed_metadata['attempts'] += 1
-                failed_metadata['failed_time'] = time.time()
-                self._failed_clients[client.server] = failed_metadata
-
-            # if we haven't enabled ignore_exc, don't move on gracefully, just
-            # raise the exception
-            if not self.ignore_exc:
-                raise
-
-            return default_val
-        except:
-            # any exceptions that aren't socket.error we need to handle
-            # gracefully as well
-            if not self.ignore_exc:
-                raise
-
-            return default_val
-
-    def _run_cmd(self, cmd, key, default_val, *args, **kwargs):
-        client = self._get_client(key)
-
-        if client is None:
-            return False
-
-        func = getattr(client, cmd)
-        args = list(args)
-        args.insert(0, key)
-        return self._safely_run_func(
-            client, func, default_val, *args, **kwargs
-        )
-
-    def set(self, key, *args, **kwargs):
-        return self._run_cmd('set', key, False, *args, **kwargs)
-
-    def get(self, key, *args, **kwargs):
-        return self._run_cmd('get', key, None, *args, **kwargs)
-
-    def incr(self, key, *args, **kwargs):
-        return self._run_cmd('incr', key, False, *args, **kwargs)
-
-    def decr(self, key, *args, **kwargs):
-        return self._run_cmd('decr', key, False, *args, **kwargs)
-
-    def set_many(self, values, *args, **kwargs):
-        client_batches = {}
-        end = []
-
-        for key, value in values.items():
-            client = self._get_client(key)
-
-            if client is None:
-                end.append(False)
-                continue
-
-            if client.server not in client_batches:
-                client_batches[client.server] = {}
-
-            client_batches[client.server][key] = value
-
-        for server, values in client_batches.items():
-            client = self.clients['%s:%s' % server]
-            new_args = list(args)
-            new_args.insert(0, values)
-            result = self._safely_run_func(
-                client,
-                client.set_many, False, *new_args, **kwargs
-            )
-            end.append(result)
-
-        return all(end)
-
-    set_multi = set_many
-
-    def get_many(self, keys, *args, **kwargs):
-        client_batches = {}
-        end = {}
-
-        for key in keys:
-            client = self._get_client(key)
-
-            if client is None:
-                end[key] = False
-                continue
-
-            if client.server not in client_batches:
-                client_batches[client.server] = []
-
-            client_batches[client.server].append(key)
-
-        for server, keys in client_batches.items():
-            client = self.clients['%s:%s' % server]
-            new_args = list(args)
-            new_args.insert(0, keys)
-            result = self._safely_run_func(
-                client,
-                client.get_many, {}, *new_args, **kwargs
-            )
-            end.update(result)
-
-        return end
-
-    get_multi = get_many
-
-    def gets(self, key, *args, **kwargs):
-        return self._run_cmd('gets', key, None, *args, **kwargs)
-
-    def add(self, key, *args, **kwargs):
-        return self._run_cmd('add', key, False, *args, **kwargs)
-
-    def prepend(self, key, *args, **kwargs):
-        return self._run_cmd('prepend', key, False, *args, **kwargs)
-
-    def append(self, key, *args, **kwargs):
-        return self._run_cmd('append', key, False, *args, **kwargs)
-
-    def delete(self, key, *args, **kwargs):
-        return self._run_cmd('delete', key, False, *args, **kwargs)
-
-    def delete_many(self, keys, *args, **kwargs):
-        for key in keys:
-            self._run_cmd('delete', key, False, *args, **kwargs)
-        return True
-
-    delete_multi = delete_many
-
-    def cas(self, key, *args, **kwargs):
-        return self._run_cmd('cas', key, False, *args, **kwargs)
-
-    def replace(self, key, *args, **kwargs):
-        return self._run_cmd('replace', key, False, *args, **kwargs)
-
-    def flush_all(self):
-        for _, client in self.clients.items():
-            self._safely_run_func(client, client.flush_all, False)
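Note: for reference, the removed `HashClient` was used roughly like this (addresses are illustrative; `set`/`get` dispatch to the node chosen by the configured hasher):

    client = HashClient([('127.0.0.1', 11211), ('127.0.0.1', 11212)],
                        use_pooling=True, ignore_exc=True)
    client.set('some_key', 'some value')
    print(client.get('some_key'))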
@@ -1,51 +0,0 @@
-def murmur3_32(data, seed=0):
-    """MurmurHash3 was written by Austin Appleby, and is placed in the
-    public domain. The author hereby disclaims copyright to this source
-    code."""
-
-    c1 = 0xcc9e2d51
-    c2 = 0x1b873593
-
-    length = len(data)
-    h1 = seed
-    roundedEnd = (length & 0xfffffffc)  # round down to 4 byte block
-    for i in range(0, roundedEnd, 4):
-        # little endian load order
-        k1 = (ord(data[i]) & 0xff) | ((ord(data[i + 1]) & 0xff) << 8) | \
-            ((ord(data[i + 2]) & 0xff) << 16) | (ord(data[i + 3]) << 24)
-        k1 *= c1
-        k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17)  # ROTL32(k1,15)
-        k1 *= c2
-
-        h1 ^= k1
-        h1 = (h1 << 13) | ((h1 & 0xffffffff) >> 19)  # ROTL32(h1,13)
-        h1 = h1 * 5 + 0xe6546b64
-
-    # tail
-    k1 = 0
-
-    val = length & 0x03
-    if val == 3:
-        k1 = (ord(data[roundedEnd + 2]) & 0xff) << 16
-    # fallthrough
-    if val in [2, 3]:
-        k1 |= (ord(data[roundedEnd + 1]) & 0xff) << 8
-    # fallthrough
-    if val in [1, 2, 3]:
-        k1 |= ord(data[roundedEnd]) & 0xff
-        k1 *= c1
-        k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17)  # ROTL32(k1,15)
-        k1 *= c2
-        h1 ^= k1
-
-    # finalization
-    h1 ^= length
-
-    # fmix(h1)
-    h1 ^= ((h1 & 0xffffffff) >> 16)
-    h1 *= 0x85ebca6b
-    h1 ^= ((h1 & 0xffffffff) >> 13)
-    h1 *= 0xc2b2ae35
-    h1 ^= ((h1 & 0xffffffff) >> 16)
-
-    return h1 & 0xffffffff
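Note: one value of the deleted hash is derivable by inspection — for empty input with seed 0 there are no 4-byte blocks and an empty tail, `h1 ^= length` leaves 0, and the final avalanche mix of 0 is still 0:

    assert murmur3_32('', seed=0) == 0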
@@ -1,46 +0,0 @@
-from pymemcache.client.murmur3 import murmur3_32
-
-
-class RendezvousHash(object):
-    """
-    Implements the Highest Random Weight (HRW) hashing algorithm most
-    commonly referred to as rendezvous hashing.
-
-    Originally developed as part of python-clandestined.
-
-    Copyright (c) 2014 Ernest W. Durbin III
-    """
-    def __init__(self, nodes=None, seed=0, hash_function=murmur3_32):
-        """
-        Constructor.
-        """
-        self.nodes = []
-        self.seed = seed
-        if nodes is not None:
-            self.nodes = nodes
-        self.hash_function = lambda x: hash_function(x, seed)
-
-    def add_node(self, node):
-        if node not in self.nodes:
-            self.nodes.append(node)
-
-    def remove_node(self, node):
-        if node in self.nodes:
-            self.nodes.remove(node)
-        else:
-            raise ValueError("No such node %s to remove" % (node))
-
-    def get_node(self, key):
-        high_score = -1
-        winner = None
-
-        for node in self.nodes:
-            score = self.hash_function(
-                "%s-%s" % (str(node), str(key)))
-
-            if score > high_score:
-                (high_score, winner) = (score, node)
-            elif score == high_score:
-                (high_score, winner) = (score, max(str(node), str(winner)))
-
-        return winner
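Note: the removed class implements plain HRW hashing — every node scores each key and the highest score wins, so adding or removing a node only remaps the keys that node wins. Sketch (node names illustrative):

    hasher = RendezvousHash(nodes=['10.0.0.1:11211', '10.0.0.2:11211'])
    node = hasher.get_node('user:42')   # node with the highest murmur3 score wins
    hasher.add_node('10.0.0.3:11211')   # only keys the new node wins get remapped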
@@ -1,40 +0,0 @@
-class MemcacheError(Exception):
-    "Base exception class"
-    pass
-
-
-class MemcacheClientError(MemcacheError):
-    """Raised when memcached fails to parse the arguments to a request, likely
-    due to a malformed key and/or value, a bug in this library, or a version
-    mismatch with memcached."""
-    pass
-
-
-class MemcacheUnknownCommandError(MemcacheClientError):
-    """Raised when memcached fails to parse a request, likely due to a bug in
-    this library or a version mismatch with memcached."""
-    pass
-
-
-class MemcacheIllegalInputError(MemcacheClientError):
-    """Raised when a key or value is not legal for Memcache (see the class docs
-    for Client for more details)."""
-    pass
-
-
-class MemcacheServerError(MemcacheError):
-    """Raised when memcached reports a failure while processing a request,
-    likely due to a bug or transient issue in memcached."""
-    pass
-
-
-class MemcacheUnknownError(MemcacheError):
-    """Raised when this library receives a response from memcached that it
-    cannot parse, likely due to a bug in this library or a version mismatch
-    with memcached."""
-    pass
-
-
-class MemcacheUnexpectedCloseError(MemcacheServerError):
-    "Raised when the connection with memcached closes unexpectedly."
-    pass
@@ -1,123 +0,0 @@
-# Copyright 2012 Pinterest.com
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-"""
-A client for falling back to older memcached servers when performing reads.
-
-It is sometimes necessary to deploy memcached on new servers, or with a
-different configuration. In theses cases, it is undesirable to start up an
-empty memcached server and point traffic to it, since the cache will be cold,
-and the backing store will have a large increase in traffic.
-
-This class attempts to solve that problem by providing an interface identical
-to the Client interface, but which can fall back to older memcached servers
-when reads to the primary server fail. The approach for upgrading memcached
-servers or configuration then becomes:
-
- 1. Deploy a new host (or fleet) with memcached, possibly with a new
-    configuration.
- 2. From your application servers, use FallbackClient to write and read from
-    the new cluster, and to read from the old cluster when there is a miss in
-    the new cluster.
- 3. Wait until the new cache is warm enough to support the load.
- 4. Switch from FallbackClient to a regular Client library for doing all
-    reads and writes to the new cluster.
- 5. Take down the old cluster.
-
-Best Practices:
----------------
- - Make sure that the old client has "ignore_exc" set to True, so that it
-   treats failures like cache misses. That will allow you to take down the
-   old cluster before you switch away from FallbackClient.
-"""
-
-
-class FallbackClient(object):
-    def __init__(self, caches):
-        assert len(caches) > 0
-        self.caches = caches
-
-    def close(self):
-        "Close each of the memcached clients"
-        for cache in self.caches:
-            cache.close()
-
-    def set(self, key, value, expire=0, noreply=True):
-        self.caches[0].set(key, value, expire, noreply)
-
-    def add(self, key, value, expire=0, noreply=True):
-        self.caches[0].add(key, value, expire, noreply)
-
-    def replace(self, key, value, expire=0, noreply=True):
-        self.caches[0].replace(key, value, expire, noreply)
-
-    def append(self, key, value, expire=0, noreply=True):
-        self.caches[0].append(key, value, expire, noreply)
-
-    def prepend(self, key, value, expire=0, noreply=True):
-        self.caches[0].prepend(key, value, expire, noreply)
-
-    def cas(self, key, value, cas, expire=0, noreply=True):
-        self.caches[0].cas(key, value, cas, expire, noreply)
-
-    def get(self, key):
-        for cache in self.caches:
-            result = cache.get(key)
-            if result is not None:
-                return result
-        return None
-
-    def get_many(self, keys):
-        for cache in self.caches:
-            result = cache.get_many(keys)
-            if result:
-                return result
-        return []
-
-    def gets(self, key):
-        for cache in self.caches:
-            result = cache.gets(key)
-            if result is not None:
-                return result
-        return None
-
-    def gets_many(self, keys):
-        for cache in self.caches:
-            result = cache.gets_many(keys)
-            if result:
-                return result
-        return []
-
-    def delete(self, key, noreply=True):
-        self.caches[0].delete(key, noreply)
-
-    def incr(self, key, value, noreply=True):
-        self.caches[0].incr(key, value, noreply)
-
-    def decr(self, key, value, noreply=True):
-        self.caches[0].decr(key, value, noreply)
-
-    def touch(self, key, expire=0, noreply=True):
-        self.caches[0].touch(key, expire, noreply)
-
-    def stats(self):
-        # TODO: ??
-        pass
-
-    def flush_all(self, delay=0, noreply=True):
-        self.caches[0].flush_all(delay, noreply)
-
-    def quit(self):
-        # TODO: ??
-        pass
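Note: usage of the removed `FallbackClient`, following its own docstring — writes go to the first cache, reads fall through in order (addresses illustrative):

    from pymemcache.client.base import Client

    old = Client(('10.0.0.1', 11211), ignore_exc=True)  # misses instead of raising
    new = Client(('10.0.0.2', 11211))
    cache = FallbackClient((new, old))
    cache.set('greeting', 'hello')   # written to the new cluster only
    print(cache.get('greeting'))     # falls back to the old cluster on a miss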
@@ -1,114 +0,0 @@
-# Copyright 2015 Yahoo.com
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import collections
-import contextlib
-import sys
-import threading
-
-import six
-
-
-class ObjectPool(object):
-    """A pool of objects that release/creates/destroys as needed."""
-
-    def __init__(self, obj_creator,
-                 after_remove=None, max_size=None,
-                 lock_generator=None):
-        self._used_objs = collections.deque()
-        self._free_objs = collections.deque()
-        self._obj_creator = obj_creator
-        if lock_generator is None:
-            self._lock = threading.Lock()
-        else:
-            self._lock = lock_generator()
-        self._after_remove = after_remove
-        max_size = max_size or 2 ** 31
-        if not isinstance(max_size, six.integer_types) or max_size < 0:
-            raise ValueError('"max_size" must be a positive integer')
-        self.max_size = max_size
-
-    @property
-    def used(self):
-        return tuple(self._used_objs)
-
-    @property
-    def free(self):
-        return tuple(self._free_objs)
-
-    @contextlib.contextmanager
-    def get_and_release(self, destroy_on_fail=False):
-        obj = self.get()
-        try:
-            yield obj
-        except Exception:
-            exc_info = sys.exc_info()
-            if not destroy_on_fail:
-                self.release(obj)
-            else:
-                self.destroy(obj)
-            six.reraise(exc_info[0], exc_info[1], exc_info[2])
-        self.release(obj)
-
-    def get(self):
-        with self._lock:
-            if not self._free_objs:
-                curr_count = len(self._used_objs)
-                if curr_count >= self.max_size:
-                    raise RuntimeError("Too many objects,"
-                                       " %s >= %s" % (curr_count,
-                                                      self.max_size))
-                obj = self._obj_creator()
-                self._used_objs.append(obj)
-                return obj
-            else:
-                obj = self._free_objs.pop()
-                self._used_objs.append(obj)
-                return obj
-
-    def destroy(self, obj, silent=True):
-        was_dropped = False
-        with self._lock:
-            try:
-                self._used_objs.remove(obj)
-                was_dropped = True
-            except ValueError:
-                if not silent:
-                    raise
-        if was_dropped and self._after_remove is not None:
-            self._after_remove(obj)
-
-    def release(self, obj, silent=True):
-        with self._lock:
-            try:
-                self._used_objs.remove(obj)
-                self._free_objs.append(obj)
-            except ValueError:
-                if not silent:
-                    raise
-
-    def clear(self):
-        if self._after_remove is not None:
-            needs_destroy = []
-            with self._lock:
-                needs_destroy.extend(self._used_objs)
-                needs_destroy.extend(self._free_objs)
-                self._free_objs.clear()
-                self._used_objs.clear()
-            for obj in needs_destroy:
-                self._after_remove(obj)
-        else:
-            with self._lock:
-                self._free_objs.clear()
-                self._used_objs.clear()
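Note: the removed pool's `get_and_release()` context manager returns the object to the free list on success and can destroy it on failure. Minimal sketch:

    pool = ObjectPool(lambda: object(), max_size=4)
    with pool.get_and_release(destroy_on_fail=True) as obj:
        pass  # use obj; it is released back to the pool on success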
@@ -1,69 +0,0 @@
-# Copyright 2012 Pinterest.com
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-import logging
-import pickle
-
-try:
-    from cStringIO import StringIO
-except ImportError:
-    from StringIO import StringIO
-
-
-FLAG_PICKLE = 1 << 0
-FLAG_INTEGER = 1 << 1
-FLAG_LONG = 1 << 2
-
-
-def python_memcache_serializer(key, value):
-    flags = 0
-
-    if isinstance(value, str):
-        pass
-    elif isinstance(value, int):
-        flags |= FLAG_INTEGER
-        value = "%d" % value
-    elif isinstance(value, long):
-        flags |= FLAG_LONG
-        value = "%d" % value
-    else:
-        flags |= FLAG_PICKLE
-        output = StringIO()
-        pickler = pickle.Pickler(output, 0)
-        pickler.dump(value)
-        value = output.getvalue()
-
-    return value, flags
-
-
-def python_memcache_deserializer(key, value, flags):
-    if flags == 0:
-        return value
-
-    if flags & FLAG_INTEGER:
-        return int(value)
-
-    if flags & FLAG_LONG:
-        return long(value)
-
-    if flags & FLAG_PICKLE:
-        try:
-            buf = StringIO(value)
-            unpickler = pickle.Unpickler(buf)
-            return unpickler.load()
-        except Exception:
-            logging.info('Pickle error', exc_info=True)
-            return None
-
-    return value
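Note: a round trip through the removed Python 2 serializer, following the flag logic above — values that are neither str nor int/long are pickled and tagged FLAG_PICKLE:

    stored, flags = python_memcache_serializer('k', {'a': 1})
    assert python_memcache_deserializer('k', stored, flags) == {'a': 1}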
@@ -1,7 +1,7 @@
-'''
+"""
 PyMySQL: A pure-Python MySQL client library.

-Copyright (c) 2010, 2013 PyMySQL contributors
+Copyright (c) 2010-2016 PyMySQL contributors

 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal

@@ -20,30 +20,29 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
 THE SOFTWARE.

-'''
-
-VERSION = (0, 6, 7, None)
-
-from ._compat import text_type, JYTHON, IRONPYTHON
-from .constants import FIELD_TYPE
-from .converters import escape_dict, escape_sequence, escape_string
-from .err import Warning, Error, InterfaceError, DataError, \
-    DatabaseError, OperationalError, IntegrityError, InternalError, \
-    NotSupportedError, ProgrammingError, MySQLError
-from .times import Date, Time, Timestamp, \
-    DateFromTicks, TimeFromTicks, TimestampFromTicks
+"""
+import sys
+
+from ._compat import PY2
+from .constants import FIELD_TYPE
+from .converters import escape_dict, escape_sequence, escape_string
+from .err import (
+    Warning, Error, InterfaceError, DataError,
+    DatabaseError, OperationalError, IntegrityError, InternalError,
+    NotSupportedError, ProgrammingError, MySQLError)
+from .times import (
+    Date, Time, Timestamp,
+    DateFromTicks, TimeFromTicks, TimestampFromTicks)
+
+
+VERSION = (0, 7, 11, None)
 threadsafety = 1
 apilevel = "2.0"
-paramstyle = "format"
+paramstyle = "pyformat"


 class DBAPISet(frozenset):


     def __ne__(self, other):
         if isinstance(other, set):
             return frozenset.__ne__(self, other)

@@ -73,11 +72,14 @@ TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
 DATETIME = TIMESTAMP
 ROWID = DBAPISet()


 def Binary(x):
     """Return x as a binary type."""
-    if isinstance(x, text_type) and not (JYTHON or IRONPYTHON):
-        return x.encode()
-    return bytes(x)
+    if PY2:
+        return bytearray(x)
+    else:
+        return bytes(x)


 def Connect(*args, **kwargs):
     """

@@ -87,27 +89,26 @@ def Connect(*args, **kwargs):
     from .connections import Connection
     return Connection(*args, **kwargs)

-from pymysql import connections as _orig_conn
+from . import connections as _orig_conn
 if _orig_conn.Connection.__init__.__doc__ is not None:
-    Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + ("""
-    See connections.Connection.__init__() for information about defaults.
-    """)
+    Connect.__doc__ = _orig_conn.Connection.__init__.__doc__
 del _orig_conn


 def get_client_info():  # for MySQLdb compatibility
     return '.'.join(map(str, VERSION))

 connect = Connection = Connect

 # we include a doctored version_info here for MySQLdb compatibility
-version_info = (1,2,2,"final",0)
+version_info = (1,2,6,"final",0)

 NULL = "NULL"

 __version__ = get_client_info()

 def thread_safe():
-    return True # match MySQLdb.thread_safe()
+    return True  # match MySQLdb.thread_safe()

 def install_as_MySQLdb():
     """

@@ -116,6 +117,7 @@ def install_as_MySQLdb():
     """
     sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]

+
 __all__ = [
     'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
     'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',

@@ -128,6 +130,5 @@ __all__ = [
     'paramstyle', 'threadsafety', 'version_info',

     "install_as_MySQLdb",
-
-    "NULL","__version__",
-]
+    "NULL", "__version__",
+]
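Note: the `install_as_MySQLdb()` shim shown above simply aliases this module in `sys.modules`, so code written against MySQLdb keeps working; also visible in this hunk, `paramstyle` changed from "format" to "pyformat". Sketch:

    import pymysql
    pymysql.install_as_MySQLdb()

    import MySQLdb              # actually resolves to pymysql
    print(MySQLdb.paramstyle)   # 'pyformat' as of 0.7.x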
@@ -7,12 +7,15 @@ IRONPYTHON = sys.platform == 'cli'
 CPYTHON = not PYPY and not JYTHON and not IRONPYTHON

 if PY2:
+    import __builtin__
     range_type = xrange
     text_type = unicode
     long_type = long
     str_type = basestring
+    unichr = __builtin__.unichr
 else:
     range_type = range
     text_type = str
     long_type = int
     str_type = str
+    unichr = chr
@@ -11,6 +11,10 @@ class Charset(object):
         self.id, self.name, self.collation = id, name, collation
         self.is_default = is_default == 'Yes'

+    def __repr__(self):
+        return "Charset(id=%s, name=%r, collation=%r)" % (
+                self.id, self.name, self.collation)
+
     @property
     def encoding(self):
         name = self.name

@@ -249,6 +253,10 @@ _charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', ''))
 _charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', ''))
 _charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', ''))
 _charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', ''))
+_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', ''))
+_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', ''))
+_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', ''))
+_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', ''))


 charset_by_name = _charsets.by_name
@@ -17,9 +17,8 @@ import traceback
 import warnings

 from .charset import MBLENGTH, charset_by_name, charset_by_id
-from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
-from .converters import (
-    escape_item, encoders, decoders, escape_string, through)
+from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
+from .converters import escape_item, escape_string, through, conversions as _conv
 from .cursors import Cursor
 from .optionfile import Parser
 from .util import byte2int, int2byte
@@ -36,7 +35,8 @@ try:
     import getpass
     DEFAULT_USER = getpass.getuser()
    del getpass
-except ImportError:
+except (ImportError, KeyError):
+    # KeyError occurs when there's no entry in OS database for a current user.
     DEFAULT_USER = None

@@ -117,26 +117,24 @@ def dump_packet(data): # pragma: no cover

     try:
         print("packet length:", len(data))
-        print("method call[1]:", sys._getframe(1).f_code.co_name)
-        print("method call[2]:", sys._getframe(2).f_code.co_name)
-        print("method call[3]:", sys._getframe(3).f_code.co_name)
-        print("method call[4]:", sys._getframe(4).f_code.co_name)
-        print("method call[5]:", sys._getframe(5).f_code.co_name)
-        print("-" * 88)
+        for i in range(1, 6):
+            f = sys._getframe(i)
+            print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
+        print("-" * 66)
     except ValueError:
         pass
     dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
     for d in dump_data:
         print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) +
               ' ' * (16 - len(d)) + ' ' * 2 +
-              ' '.join(map(lambda x: "{}".format(is_ascii(x)), d)))
-    print("-" * 88)
+              ''.join(map(lambda x: "{}".format(is_ascii(x)), d)))
+    print("-" * 66)
+    print()


 def _scramble(password, message):
     if not password:
-        return b'\0'
+        return b''
     if DEBUG: print('password=' + str(password))
     stage1 = sha_new(password).digest()
     stage2 = sha_new(stage1).digest()
@@ -149,7 +147,7 @@ def _scramble(password, message):

 def _my_crypt(message1, message2):
     length = len(message1)
-    result = struct.pack('B', length)
+    result = b''
     for i in range_type(length):
         x = (struct.unpack('B', message1[i:i+1])[0] ^
              struct.unpack('B', message2[i:i+1])[0])
@@ -196,7 +194,8 @@ def _hash_password_323(password):
     add = 7
     nr2 = 0x12345671

-    for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
+    # x in py3 is numbers, p27 is chars
+    for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
         nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
         nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
         add = (add + c) & 0xFFFFFFFF
@@ -209,6 +208,20 @@ def _hash_password_323(password):
 def pack_int24(n):
     return struct.pack('<I', n)[:3]

+# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
+def lenenc_int(i):
+    if (i < 0):
+        raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
+    elif (i < 0xfb):
+        return int2byte(i)
+    elif (i < (1 << 16)):
+        return b'\xfc' + struct.pack('<H', i)
+    elif (i < (1 << 24)):
+        return b'\xfd' + struct.pack('<I', i)[:3]
+    elif (i < (1 << 64)):
+        return b'\xfe' + struct.pack('<Q', i)
+    else:
+        raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))

 class MysqlPacket(object):
     """Representation of a MySQL response packet.
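Note: worked examples of the `lenenc_int()` encoder added above, following MySQL's length-encoded-integer format:

    assert lenenc_int(250) == b'\xfa'                  # < 0xfb: single byte
    assert lenenc_int(251) == b'\xfc\xfb\x00'          # 0xfc + 2-byte little-endian
    assert lenenc_int(1 << 16) == b'\xfd\x00\x00\x01'  # 0xfd + 3-byte little-endian
    assert lenenc_int(1 << 24) == b'\xfe\x00\x00\x00\x01\x00\x00\x00\x00'  # 0xfe + 8 bytes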
@@ -303,6 +316,14 @@ class MysqlPacket(object):
         self._position += 8
         return result

+    def read_string(self):
+        end_pos = self._data.find(b'\0', self._position)
+        if end_pos < 0:
+            return None
+        result = self._data[self._position:end_pos]
+        self._position = end_pos + 1
+        return result
+
     def read_length_encoded_integer(self):
         """Read a 'Length Coded Binary' number from the data buffer.

@@ -340,13 +361,18 @@ class MysqlPacket(object):
         return result

     def is_ok_packet(self):
-        return self._data[0:1] == b'\0'
+        # https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
+        return self._data[0:1] == b'\0' and len(self._data) >= 7

     def is_eof_packet(self):
         # http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
         # Caution: \xFE may be LengthEncodedInteger.
         # If \xFE is LengthEncodedInteger header, 8bytes followed.
-        return len(self._data) < 9 and self._data[0:1] == b'\xfe'
+        return self._data[0:1] == b'\xfe' and len(self._data) < 9
+
+    def is_auth_switch_request(self):
+        # http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
+        return self._data[0:1] == b'\xfe'

     def is_resultset_packet(self):
         field_count = ord(self._data[0:1])
@@ -379,9 +405,9 @@ class FieldDescriptorPacket(MysqlPacket):

     def __init__(self, data, encoding):
         MysqlPacket.__init__(self, data, encoding)
-        self.__parse_field_descriptor(encoding)
+        self._parse_field_descriptor(encoding)

-    def __parse_field_descriptor(self, encoding):
+    def _parse_field_descriptor(self, encoding):
         """Parse the 'Field Descriptor' (Metadata) packet.

         This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
@@ -494,20 +520,23 @@ class Connection(object):

     The proper way to get an instance of this class is to call
     connect().

     """

-    socket = None
+    _sock = None
+    _auth_plugin_name = ''
+    _closed = False

-    def __init__(self, host="localhost", user=None, password="",
-                 database=None, port=3306, unix_socket=None,
+    def __init__(self, host=None, user=None, password="",
+                 database=None, port=0, unix_socket=None,
                  charset='', sql_mode=None,
-                 read_default_file=None, conv=decoders, use_unicode=None,
+                 read_default_file=None, conv=None, use_unicode=None,
                  client_flag=0, cursorclass=Cursor, init_command=None,
-                 connect_timeout=None, ssl=None, read_default_group=None,
+                 connect_timeout=10, ssl=None, read_default_group=None,
                  compress=None, named_pipe=None, no_delay=None,
                  autocommit=False, db=None, passwd=None, local_infile=False,
-                 max_allowed_packet=16*1024*1024, defer_connect=False):
+                 max_allowed_packet=16*1024*1024, defer_connect=False,
+                 auth_plugin_map={}, read_timeout=None, write_timeout=None,
+                 bind_address=None):
         """
         Establish a connection to the MySQL database. Accepts several
         arguments:
@@ -516,15 +545,19 @@ class Connection(object):
         user: Username to log in as
         password: Password to use.
         database: Database to use, None to not use a particular one.
-        port: MySQL port to use, default is usually OK.
+        port: MySQL port to use, default is usually OK. (default: 3306)
+        bind_address: When the client has multiple network interfaces, specify
+            the interface from which to connect to the host. Argument can be
+            a hostname or an IP address.
         unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
         charset: Charset you want to use.
         sql_mode: Default SQL_MODE to use.
         read_default_file:
             Specifies my.cnf file to read these parameters from under the [client] section.
         conv:
-            Decoders dictionary to use instead of the default one.
-            This is used to provide custom marshalling of types. See converters.
+            Conversion dictionary to use instead of the default one.
+            This is used to provide custom marshalling and unmarshaling of types.
+            See converters.
         use_unicode:
             Whether or not to default to unicode strings.
             This option defaults to true for Py3k.
@ -532,27 +565,29 @@ class Connection(object):
|
|||
cursorclass: Custom cursor class to use.
|
||||
init_command: Initial SQL statement to run when connection is established.
|
||||
connect_timeout: Timeout before throwing an exception when connecting.
|
||||
(default: 10, min: 1, max: 31536000)
|
||||
ssl:
|
||||
A dict of arguments similar to mysql_ssl_set()'s parameters.
|
||||
For now the capath and cipher arguments are not supported.
|
||||
read_default_group: Group to read from in the configuration file.
|
||||
compress; Not supported
|
||||
named_pipe: Not supported
|
||||
no_delay: Disable Nagle's algorithm on the socket. (deprecated, default: True)
|
||||
autocommit: Autocommit mode. None means use server default. (default: False)
|
||||
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
||||
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
||||
Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
|
||||
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
||||
(default: False)
|
||||
|
||||
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
|
||||
The class will take the Connection object as the argument to the constructor.
|
||||
The class needs an authenticate method taking an authentication packet as
|
||||
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
|
||||
(if no authenticate method) for returning a string from the user. (experimental)
|
||||
db: Alias for database. (for compatibility to MySQLdb)
|
||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||
"""
|
||||
if no_delay is not None:
|
||||
warnings.warn("no_delay option is deprecated", DeprecationWarning)
|
||||
no_delay = bool(no_delay)
|
||||
else:
|
||||
no_delay = True
|
||||
|
||||
if use_unicode is None and sys.version_info[0] > 2:
|
||||
use_unicode = True
|
||||
|
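To make the new constructor arguments concrete, here is a minimal usage sketch of the updated signature; the host, credentials, and database name are invented placeholders, not values from this commit:

    import pymysql

    # read_timeout, write_timeout and bind_address are the arguments added above;
    # host/user/password/database are hypothetical example values.
    conn = pymysql.connect(host='127.0.0.1', port=3306,
                           user='app', password='secret', database='test',
                           connect_timeout=10, read_timeout=30, write_timeout=30)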
@@ -565,24 +600,10 @@ class Connection(object):
        if compress or named_pipe:
            raise NotImplementedError("compress and named_pipe arguments are not supported")

        if local_infile:
        self._local_infile = bool(local_infile)
        if self._local_infile:
            client_flag |= CLIENT.LOCAL_FILES

        if ssl and ('capath' in ssl or 'cipher' in ssl):
            raise NotImplementedError('ssl options capath and cipher are not supported')

        self.ssl = False
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError("ssl module not found")
            self.ssl = True
            client_flag |= CLIENT.SSL
            for k in ('key', 'cert', 'ca'):
                v = None
                if k in ssl:
                    v = ssl[k]
                setattr(self, k, v)

        if read_default_group and not read_default_file:
            if sys.platform.startswith("win"):
                read_default_file = "c:\\my.ini"

@@ -610,15 +631,40 @@ class Connection(object):
            database = _config("database", database)
            unix_socket = _config("socket", unix_socket)
            port = int(_config("port", port))
            bind_address = _config("bind-address", bind_address)
            charset = _config("default-character-set", charset)
            if not ssl:
                ssl = {}
            if isinstance(ssl, dict):
                for key in ["ca", "capath", "cert", "key", "cipher"]:
                    value = _config("ssl-" + key, ssl.get(key))
                    if value:
                        ssl[key] = value

        self.host = host
        self.port = port
        self.ssl = False
        if ssl:
            if not SSL_ENABLED:
                raise NotImplementedError("ssl module not found")
            self.ssl = True
            client_flag |= CLIENT.SSL
            self.ctx = self._create_ssl_ctx(ssl)

        self.host = host or "localhost"
        self.port = port or 3306
        self.user = user or DEFAULT_USER
        self.password = password or ""
        self.db = database
        self.no_delay = no_delay
        self.unix_socket = unix_socket
        self.bind_address = bind_address
        if not (0 < connect_timeout <= 31536000):
            raise ValueError("connect_timeout should be >0 and <=31536000")
        self.connect_timeout = connect_timeout or None
        if read_timeout is not None and read_timeout <= 0:
            raise ValueError("read_timeout should be >= 0")
        self._read_timeout = read_timeout
        if write_timeout is not None and write_timeout <= 0:
            raise ValueError("write_timeout should be >= 0")
        self._write_timeout = write_timeout
        if charset:
            self.charset = charset
            self.use_unicode = True

@@ -631,13 +677,12 @@ class Connection(object):

        self.encoding = charset_by_name(self.charset).encoding

        client_flag |= CLIENT.CAPABILITIES | CLIENT.MULTI_STATEMENTS
        client_flag |= CLIENT.CAPABILITIES
        if self.db:
            client_flag |= CLIENT.CONNECT_WITH_DB
        self.client_flag = client_flag

        self.cursorclass = cursorclass
        self.connect_timeout = connect_timeout
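The _config() lookups above fall back to a MySQL option file. A hedged sketch of wiring that up; the path and the file contents in the comment are invented for illustration:

    import pymysql

    # Hypothetical option file at /etc/app/my.cnf:
    #   [client]
    #   user = app
    #   password = secret
    #   default-character-set = utf8
    conn = pymysql.connect(host='127.0.0.1',
                           read_default_file='/etc/app/my.cnf')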
@@ -646,44 +691,68 @@ class Connection(object):
        #: specified autocommit mode. None means use server default.
        self.autocommit_mode = autocommit

        self.encoders = encoders  # Need for MySQLdb compatibility.
        self.decoders = conv
        if conv is None:
            conv = _conv
        # Need for MySQLdb compatibility.
        self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
        self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
        self.sql_mode = sql_mode
        self.init_command = init_command
        self.max_allowed_packet = max_allowed_packet
        self._auth_plugin_map = auth_plugin_map
        if defer_connect:
            self.socket = None
            self._sock = None
        else:
            self.connect()

    def _create_ssl_ctx(self, sslp):
        if isinstance(sslp, ssl.SSLContext):
            return sslp
        ca = sslp.get('ca')
        capath = sslp.get('capath')
        hasnoca = ca is None and capath is None
        ctx = ssl.create_default_context(cafile=ca, capath=capath)
        ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
        ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
        if 'cert' in sslp:
            ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
        if 'cipher' in sslp:
            ctx.set_ciphers(sslp['cipher'])
        ctx.options |= ssl.OP_NO_SSLv2
        ctx.options |= ssl.OP_NO_SSLv3
        return ctx

    def close(self):
        """Send the quit message and close the socket"""
        if self.socket is None:
        if self._closed:
            raise err.Error("Already closed")
        self._closed = True
        if self._sock is None:
            return
        send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
        try:
            self._write_bytes(send_data)
        except Exception:
            pass
        finally:
            sock = self.socket
            self.socket = None
            self._rfile = None
            sock.close()
            self._force_close()

    @property
    def open(self):
        return self.socket is not None
        return self._sock is not None

    def __del__(self):
        if self.socket:
    def _force_close(self):
        """Close connection without QUIT message"""
        if self._sock:
            try:
                self.socket.close()
                self._sock.close()
            except:
                pass
        self.socket = None
        self._sock = None
        self._rfile = None

    __del__ = _force_close

    def autocommit(self, value):
        self.autocommit_mode = bool(value)
        current = self.get_autocommit()
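_create_ssl_ctx above replaces the old ssl.wrap_socket() call with an ssl.SSLContext. A standalone sketch of the same idea; the file paths are placeholders, and passing a ready-made context works because _create_ssl_ctx returns it unchanged:

    import ssl
    import pymysql

    # Hypothetical certificate paths; mirrors the logic of _create_ssl_ctx above.
    ctx = ssl.create_default_context(cafile='/etc/ssl/mysql-ca.pem')
    ctx.check_hostname = True          # only meaningful when a CA is given
    ctx.verify_mode = ssl.CERT_REQUIRED
    ctx.load_cert_chain('/etc/ssl/client-cert.pem',
                        keyfile='/etc/ssl/client-key.pem')

    conn = pymysql.connect(host='db.example.com', user='app',
                           password='secret', ssl=ctx)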
@@ -731,19 +800,25 @@ class Connection(object):
        return result.rows

    def select_db(self, db):
        '''Set current db'''
        """Set current db"""
        self._execute_command(COMMAND.COM_INIT_DB, db)
        self._read_ok_packet()

    def escape(self, obj, mapping=None):
        """Escape whatever value you pass to it"""
        """Escape whatever value you pass to it.

        Non-standard, for internal use; do not use this in your applications.
        """
        if isinstance(obj, str_type):
            return "'" + self.escape_string(obj) + "'"
        return escape_item(obj, self.charset, mapping=mapping)

    def literal(self, obj):
        '''Alias for escape()'''
        return self.escape(obj)
        """Alias for escape()

        Non-standard, for internal use; do not use this in your applications.
        """
        return self.escape(obj, self.encoders)

    def escape_string(self, s):
        if (self.server_status &
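The distinction matters for the cursor changes later in this diff: literal() now forces the encoders mapping, while escape() keeps its default. A quick hedged demonstration; `conn` is assumed to be an open Connection as in the earlier sketch:

    import datetime

    print(conn.escape("it's"))                      # prints 'it\'s' -- quoted and escaped
    print(conn.literal(datetime.date(2016, 1, 1)))  # prints '2016-01-01' via the encoders map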
@@ -795,7 +870,7 @@ class Connection(object):

    def ping(self, reconnect=True):
        """Check if the server is alive"""
        if self.socket is None:
        if self._sock is None:
            if reconnect:
                self.connect()
                reconnect = False

@@ -821,6 +896,7 @@ class Connection(object):
        self.encoding = encoding

    def connect(self, sock=None):
        self._closed = False
        try:
            if sock is None:
                if self.unix_socket and self.host in ('localhost', '127.0.0.1'):

@@ -830,10 +906,14 @@ class Connection(object):
                    self.host_info = "Localhost via UNIX socket"
                    if DEBUG: print('connected using unix_socket')
                else:
                    kwargs = {}
                    if self.bind_address is not None:
                        kwargs['source_address'] = (self.bind_address, 0)
                    while True:
                        try:
                            sock = socket.create_connection(
                                (self.host, self.port), self.connect_timeout)
                                (self.host, self.port), self.connect_timeout,
                                **kwargs)
                            break
                        except (OSError, IOError) as e:
                            if e.errno == errno.EINTR:

@@ -841,12 +921,13 @@ class Connection(object):
                            raise
                    self.host_info = "socket %s:%d" % (self.host, self.port)
                    if DEBUG: print('connected using socket')
                if self.no_delay:
                    sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                # sock.settimeout(None)
                sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
                sock.settimeout(None)
                sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
            self.socket = sock
            self._sock = sock
            self._rfile = _makefile(sock, 'rb')
            self._next_seq_id = 0

            self._get_server_information()
            self._request_authentication()

@@ -886,6 +967,17 @@ class Connection(object):
            # So just reraise it.
            raise

    def write_packet(self, payload):
        """Writes an entire "mysql packet" in its entirety to the network,
        adding its length and sequence number.
        """
        # Internal note: when you build a packet manually and call _write_bytes()
        # directly, you should set self._next_seq_id properly.
        data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
        if DEBUG: dump_packet(data)
        self._write_bytes(data)
        self._next_seq_id = (self._next_seq_id + 1) % 256

    def _read_packet(self, packet_type=MysqlPacket):
        """Read an entire "mysql packet" in its entirety from the network
        and return a MysqlPacket type that represents the results.
@@ -894,19 +986,36 @@ class Connection(object):
        while True:
            packet_header = self._read_bytes(4)
            if DEBUG: dump_packet(packet_header)

            btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
            bytes_to_read = btrl + (btrh << 16)
            #TODO: check sequence id
            if packet_number != self._next_seq_id:
                self._force_close()
                if packet_number == 0:
                    # MariaDB sends error packet with seqno==0 when shutdown
                    raise err.OperationalError(
                        CR.CR_SERVER_LOST,
                        "Lost connection to MySQL server during query")
                raise err.InternalError(
                    "Packet sequence number wrong - got %d expected %d"
                    % (packet_number, self._next_seq_id))
            self._next_seq_id = (self._next_seq_id + 1) % 256

            recv_data = self._read_bytes(bytes_to_read)
            if DEBUG: dump_packet(recv_data)
            buff += recv_data
            # https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
            if bytes_to_read == 0xffffff:
                continue
            if bytes_to_read < MAX_PACKET_LEN:
                break

        packet = packet_type(buff, self.encoding)
        packet.check_error()
        return packet

    def _read_bytes(self, num_bytes):
        self._sock.settimeout(self._read_timeout)
        while True:
            try:
                data = self._rfile.read(num_bytes)
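The 4-byte header parsed above is the standard MySQL wire framing: a 3-byte little-endian length plus a 1-byte sequence number. A minimal standalone sketch of the same parsing, independent of the class; the sample bytes are made up:

    import struct

    header = b'\x05\x00\x00\x01'   # hypothetical: payload length = 5, sequence id = 1
    btrl, btrh, seq = struct.unpack('<HBB', header)
    length = btrl + (btrh << 16)   # the 3-byte length is split as uint16 + uint8
    assert (length, seq) == (5, 1)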
@@ -914,19 +1023,25 @@ class Connection(object):
            except (IOError, OSError) as e:
                if e.errno == errno.EINTR:
                    continue
                self._force_close()
                raise err.OperationalError(
                    2013,
                    CR.CR_SERVER_LOST,
                    "Lost connection to MySQL server during query (%s)" % (e,))
        if len(data) < num_bytes:
            self._force_close()
            raise err.OperationalError(
                2013, "Lost connection to MySQL server during query")
                CR.CR_SERVER_LOST, "Lost connection to MySQL server during query")
        return data

    def _write_bytes(self, data):
        self._sock.settimeout(self._write_timeout)
        try:
            self.socket.sendall(data)
            self._sock.sendall(data)
        except IOError as e:
            raise err.OperationalError(2006, "MySQL server has gone away (%r)" % (e,))
            self._force_close()
            raise err.OperationalError(
                CR.CR_SERVER_GONE_ERROR,
                "MySQL server has gone away (%r)" % (e,))

    def _read_query_result(self, unbuffered=False):
        if unbuffered:
@@ -952,42 +1067,45 @@ class Connection(object):
            return 0

    def _execute_command(self, command, sql):
        if not self.socket:
        if not self._sock:
            raise err.InterfaceError("(0, '')")

        # If the last query was unbuffered, make sure it finishes before
        # sending new commands
        if self._result is not None and self._result.unbuffered_active:
            warnings.warn("Previous unbuffered result was left incomplete")
            self._result._finish_unbuffered_query()
        if self._result is not None:
            if self._result.unbuffered_active:
                warnings.warn("Previous unbuffered result was left incomplete")
                self._result._finish_unbuffered_query()
            while self._result.has_next:
                self.next_result()
            self._result = None

        if isinstance(sql, text_type):
            sql = sql.encode(self.encoding)

        chunk_size = min(self.max_allowed_packet, len(sql) + 1)  # +1 is for command
        packet_size = min(MAX_PACKET_LEN, len(sql) + 1)  # +1 is for command

        prelude = struct.pack('<iB', chunk_size, command)
        self._write_bytes(prelude + sql[:chunk_size-1])
        if DEBUG: dump_packet(prelude + sql)
        # tiny optimization: build first packet manually instead of
        # calling self.write_packet()
        prelude = struct.pack('<iB', packet_size, command)
        packet = prelude + sql[:packet_size-1]
        self._write_bytes(packet)
        if DEBUG: dump_packet(packet)
        self._next_seq_id = 1

        if chunk_size < self.max_allowed_packet:
        if packet_size < MAX_PACKET_LEN:
            return

        seq_id = 1
        sql = sql[chunk_size-1:]
        sql = sql[packet_size-1:]
        while True:
            chunk_size = min(self.max_allowed_packet, len(sql))
            prelude = struct.pack('<i', chunk_size)[:3]
            data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
            self._write_bytes(data)
            if DEBUG: dump_packet(data)
            sql = sql[chunk_size:]
            if not sql and chunk_size < self.max_allowed_packet:
            packet_size = min(MAX_PACKET_LEN, len(sql))
            self.write_packet(sql[:packet_size])
            sql = sql[packet_size:]
            if not sql and packet_size < MAX_PACKET_LEN:
                break
            seq_id += 1

    def _request_authentication(self):
        self.client_flag |= CLIENT.CAPABILITIES
        # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
        if int(self.server_version.split('.', 1)[0]) >= 5:
            self.client_flag |= CLIENT.MULTI_RESULTS
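The rewritten loop above splits an oversized command at MAX_PACKET_LEN (0xffffff bytes) boundaries, as the protocol requires, including the empty terminating packet after a final full-size chunk. A self-contained sketch of the same splitting rule on a plain byte string; the helper name and the tiny limit are invented for the demo:

    MAX_PACKET_LEN = 2 ** 24 - 1  # protocol limit; the small value below is demo-only

    def split_payload(payload, limit=8):  # hypothetical helper; limit stands in for MAX_PACKET_LEN
        chunks = []
        while True:
            size = min(limit, len(payload))
            chunks.append(payload[:size])
            payload = payload[size:]
            if not payload and size < limit:  # a final full-size chunk forces an empty trailer
                break
        return chunks

    assert split_payload(b'abcdefghij') == [b'abcdefgh', b'ij']
    assert split_payload(b'abcdefgh') == [b'abcdefgh', b'']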
@@ -1000,48 +1118,114 @@ class Connection(object):

        data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')

        next_packet = 1
        if self.ssl and self.server_capabilities & CLIENT.SSL:
            self.write_packet(data_init)

        if self.ssl:
            data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
            next_packet += 1
            self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
            self._rfile = _makefile(self._sock, 'rb')

            if DEBUG: dump_packet(data)
            self._write_bytes(data)
        data = data_init + self.user + b'\0'

            cert_reqs = ssl.CERT_NONE if self.ca is None else ssl.CERT_REQUIRED
            self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
                                          certfile=self.cert,
                                          ssl_version=ssl.PROTOCOL_TLSv1,
                                          cert_reqs=cert_reqs,
                                          ca_certs=self.ca)
            self._rfile = _makefile(self.socket, 'rb')
        authresp = b''
        if self._auth_plugin_name in ('', 'mysql_native_password'):
            authresp = _scramble(self.password.encode('latin1'), self.salt)

        data = data_init + self.user + b'\0' + \
            _scramble(self.password.encode('latin1'), self.salt)
        if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
            data += lenenc_int(len(authresp)) + authresp
        elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
            data += struct.pack('B', len(authresp)) + authresp
        else:  # pragma: no cover - not testing against servers without secure auth (>=5.0)
            data += authresp + b'\0'

        if self.db:
        if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
            if isinstance(self.db, text_type):
                self.db = self.db.encode(self.encoding)
            data += self.db + int2byte(0)
            data += self.db + b'\0'

        data = pack_int24(len(data)) + int2byte(next_packet) + data
        next_packet += 2

        if DEBUG: dump_packet(data)
        self._write_bytes(data)
        if self.server_capabilities & CLIENT.PLUGIN_AUTH:
            name = self._auth_plugin_name
            if isinstance(name, text_type):
                name = name.encode('ascii')
            data += name + b'\0'

        self.write_packet(data)
        auth_packet = self._read_packet()

        # if old_passwords is enabled the packet will be 1 byte long and
        # have the octet 254
        # if authentication method isn't accepted the first byte
        # will have the octet 254
        if auth_packet.is_auth_switch_request():
            # https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
            auth_packet.read_uint8()  # 0xfe packet identifier
            plugin_name = auth_packet.read_string()
            if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
                auth_packet = self._process_auth(plugin_name, auth_packet)
            else:
                # send legacy handshake
                data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
                self.write_packet(data)
                auth_packet = self._read_packet()

        if auth_packet.is_eof_packet():
            # send legacy handshake
            data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
            data = pack_int24(len(data)) + int2byte(next_packet) + data
            self._write_bytes(data)
            auth_packet = self._read_packet()

    def _process_auth(self, plugin_name, auth_packet):
        plugin_class = self._auth_plugin_map.get(plugin_name)
        if not plugin_class:
            plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
        if plugin_class:
            try:
                handler = plugin_class(self)
                return handler.authenticate(auth_packet)
            except AttributeError:
                if plugin_name != b'dialog':
                    raise err.OperationalError(2059, "Authentication plugin '%s'" \
                        " not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
            except TypeError:
                raise err.OperationalError(2059, "Authentication plugin '%s'" \
                    " not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
        else:
            handler = None
        if plugin_name == b"mysql_native_password":
            # https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
            data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
        elif plugin_name == b"mysql_old_password":
            # https://dev.mysql.com/doc/internals/en/old-password-authentication.html
            data = _scramble_323(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
        elif plugin_name == b"mysql_clear_password":
            # https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
            data = self.password.encode('latin1') + b'\0'
        elif plugin_name == b"dialog":
            pkt = auth_packet
            while True:
                flag = pkt.read_uint8()
                echo = (flag & 0x06) == 0x02
                last = (flag & 0x01) == 0x01
                prompt = pkt.read_all()

                if prompt == b"Password: ":
                    self.write_packet(self.password.encode('latin1') + b'\0')
                elif handler:
                    resp = 'no response - TypeError within plugin.prompt method'
                    try:
                        resp = handler.prompt(echo, prompt)
                        self.write_packet(resp + b'\0')
                    except AttributeError:
                        raise err.OperationalError(2059, "Authentication plugin '%s'" \
                            " not loaded: - %r missing prompt method" % (plugin_name, handler))
                    except TypeError:
                        raise err.OperationalError(2061, "Authentication plugin '%s'" \
                            " %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
                else:
                    raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
                pkt = self._read_packet()
                pkt.check_error()
                if pkt.is_ok_packet() or last:
                    break
            return pkt
        else:
            raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)

        self.write_packet(data)
        pkt = self._read_packet()
        pkt.check_error()
        return pkt

    # _mysql support
    def thread_id(self):
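The auth_plugin_map contract described in the constructor docstring (a class taking the Connection, plus an authenticate() or prompt() method) can be satisfied by a small handler. A hedged sketch for the 'dialog' plugin path above; the class name and returned password are invented:

    import pymysql

    class EchoDialogHandler(object):
        """Hypothetical handler for the 'dialog' auth plugin."""

        def __init__(self, con):
            self.con = con  # the Connection object, per the contract above

        def prompt(self, echo, prompt):
            # Called once per server prompt; must return bytes.
            return b'my-one-time-password'

    conn = pymysql.connect(host='127.0.0.1', user='app', password='secret',
                           auth_plugin_map={b'dialog': EchoDialogHandler})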
@@ -1065,7 +1249,7 @@ class Connection(object):
        self.protocol_version = byte2int(data[i:i+1])
        i += 1

        server_end = data.find(int2byte(0), i)
        server_end = data.find(b'\0', i)
        self.server_version = data[i:server_end].decode('latin1')
        i = server_end + 1

@@ -1097,7 +1281,22 @@ class Connection(object):
        if len(data) >= i + salt_len:
            # salt_len includes auth_plugin_data_part_1 and filler
            self.salt += data[i:i+salt_len]
            # TODO: AUTH PLUGIN NAME may appeare here.
            i += salt_len

        i += 1
        # AUTH PLUGIN NAME may appear here.
        if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
            # Due to Bug#59453 the auth-plugin-name is missing the terminating
            # NUL-char in versions prior to 5.5.10 and 5.6.2.
            # ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
            # didn't use version checks as mariadb is corrected and reports
            # earlier than those two.
            server_end = data.find(b'\0', i)
            if server_end < 0:  # pragma: no cover - very specific upstream bug
                # not found \0 and last field so take it all
                self._auth_plugin_name = data[i:].decode('latin1')
            else:
                self._auth_plugin_name = data[i:server_end].decode('latin1')

    def get_server_info(self):
        return self.server_version

@@ -1117,6 +1316,9 @@ class Connection(object):
class MySQLResult(object):

    def __init__(self, connection):
        """
        :type connection: Connection
        """
        self.connection = connection
        self.affected_rows = None
        self.insert_id = None

@@ -1144,7 +1346,7 @@ class MySQLResult(object):
        else:
            self._read_result_packet(first_packet)
        finally:
            self.connection = False
            self.connection = None

    def init_unbuffered_query(self):
        self.unbuffered_active = True

@@ -1154,6 +1356,10 @@ class MySQLResult(object):
            self._read_ok_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        elif first_packet.is_load_local_packet():
            self._read_load_local_packet(first_packet)
            self.unbuffered_active = False
            self.connection = None
        else:
            self.field_count = first_packet.read_length_encoded_integer()
            self._get_descriptions()
@@ -1173,22 +1379,33 @@ class MySQLResult(object):
        self.has_next = ok_packet.has_next

    def _read_load_local_packet(self, first_packet):
        if not self.connection._local_infile:
            raise RuntimeError(
                "**WARN**: Received LOAD_LOCAL packet but local_infile option is false.")
        load_packet = LoadLocalPacketWrapper(first_packet)
        sender = LoadLocalFile(load_packet.filename, self.connection)
        sender.send_data()
        try:
            sender.send_data()
        except:
            self.connection._read_packet()  # skip ok packet
            raise

        ok_packet = self.connection._read_packet()
        if not ok_packet.is_ok_packet():
        if not ok_packet.is_ok_packet():  # pragma: no cover - upstream induced protocol error
            raise err.OperationalError(2014, "Commands Out of Sync")
        self._read_ok_packet(ok_packet)

    def _check_packet_is_eof(self, packet):
        if packet.is_eof_packet():
            eof_packet = EOFPacketWrapper(packet)
            self.warning_count = eof_packet.warning_count
            self.has_next = eof_packet.has_next
            return True
        return False
        if not packet.is_eof_packet():
            return False
        #TODO: Support CLIENT.DEPRECATE_EOF
        # 1) Add DEPRECATE_EOF to CAPABILITIES
        # 2) Mask CAPABILITIES with server_capabilities
        # 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
        wp = EOFPacketWrapper(packet)
        self.warning_count = wp.warning_count
        self.has_next = wp.has_next
        return True

    def _read_result_packet(self, first_packet):
        self.field_count = first_packet.read_length_encoded_integer()

@@ -1239,7 +1456,12 @@ class MySQLResult(object):
    def _read_row_from_packet(self, packet):
        row = []
        for encoding, converter in self.converters:
            data = packet.read_length_coded_string()
            try:
                data = packet.read_length_coded_string()
            except IndexError:
                # No more columns in this row
                # See https://github.com/PyMySQL/PyMySQL/pull/434
                break
            if data is not None:
                if encoding is not None:
                    data = data.decode(encoding)

@@ -1254,21 +1476,30 @@ class MySQLResult(object):
        self.fields = []
        self.converters = []
        use_unicode = self.connection.use_unicode
        conn_encoding = self.connection.encoding
        description = []

        for i in range_type(self.field_count):
            field = self.connection._read_packet(FieldDescriptorPacket)
            self.fields.append(field)
            description.append(field.description())
            field_type = field.type_code
            if use_unicode:
                if field_type in TEXT_TYPES:
                    charset = charset_by_id(field.charsetnr)
                    if charset.is_binary:
                if field_type == FIELD_TYPE.JSON:
                    # When SELECT from JSON column: charset = binary
                    # When SELECT CAST(... AS JSON): charset = connection encoding
                    # This behavior is different from TEXT / BLOB.
                    # We should decode result by connection encoding regardless charsetnr.
                    # See https://github.com/PyMySQL/PyMySQL/issues/488
                    encoding = conn_encoding  # SELECT CAST(... AS JSON)
                elif field_type in TEXT_TYPES:
                    if field.charsetnr == 63:  # binary
                        # TEXTs with charset=binary means BINARY types.
                        encoding = None
                    else:
                        encoding = charset.encoding
                        encoding = conn_encoding
                else:
                    # Integers, Dates and Times, and other basic data is encoded in ascii
                    encoding = 'ascii'
            else:
                encoding = None
@@ -1290,28 +1521,20 @@ class LoadLocalFile(object):

    def send_data(self):
        """Send data packets from the local file to the server"""
        if not self.connection.socket:
        if not self.connection._sock:
            raise err.InterfaceError("(0, '')")
        conn = self.connection

        # sequence id is 2 as we already sent a query packet
        seq_id = 2
        try:
            with open(self.filename, 'rb') as open_file:
                chunk_size = self.connection.max_allowed_packet
                packet = b""

                packet_size = min(conn.max_allowed_packet, 16*1024)  # 16KB is efficient enough
                while True:
                    chunk = open_file.read(chunk_size)
                    chunk = open_file.read(packet_size)
                    if not chunk:
                        break
                    packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
                    format_str = '!{0}s'.format(len(chunk))
                    packet += struct.pack(format_str, chunk)
                    self.connection._write_bytes(packet)
                    seq_id += 1
                    conn.write_packet(chunk)
        except IOError:
            raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
        finally:
            # send the empty packet to signify we are done sending data
            packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
            self.connection._write_bytes(packet)
            conn.write_packet(b'')
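End to end, the LOAD DATA LOCAL path above requires opting in at connect time. A hedged usage sketch; the table and file names are placeholders:

    import pymysql

    conn = pymysql.connect(host='127.0.0.1', user='app', password='secret',
                           database='test', local_infile=True)  # sets CLIENT.LOCAL_FILES
    with conn.cursor() as cur:
        # The server replies with a LOAD_LOCAL packet; LoadLocalFile then streams
        # the file in write_packet() chunks and finishes with an empty packet.
        cur.execute("LOAD DATA LOCAL INFILE '/tmp/rows.csv' INTO TABLE t1")
    conn.commit()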
@@ -1,3 +1,4 @@
# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
LONG_PASSWORD = 1
FOUND_ROWS = 1 << 1
LONG_FLAG = 1 << 2

@@ -15,5 +16,16 @@ TRANSACTIONS = 1 << 13
SECURE_CONNECTION = 1 << 15
MULTI_STATEMENTS = 1 << 16
MULTI_RESULTS = 1 << 17
CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS |
                PROTOCOL_41 | SECURE_CONNECTION)
PS_MULTI_RESULTS = 1 << 18
PLUGIN_AUTH = 1 << 19
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAPABILITIES = (
    LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS
    | SECURE_CONNECTION | MULTI_STATEMENTS | MULTI_RESULTS
    | PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA)

# Not done yet
CONNECT_ATTRS = 1 << 20
HANDLE_EXPIRED_PASSWORDS = 1 << 22
SESSION_TRACK = 1 << 23
DEPRECATE_EOF = 1 << 24

@@ -1,3 +1,4 @@
# flake8: noqa
# errmsg.h
CR_ERROR_FIRST = 2000
CR_UNKNOWN_ERROR = 2000

@@ -17,6 +17,7 @@ YEAR = 13
NEWDATE = 14
VARCHAR = 15
BIT = 16
JSON = 245
NEWDECIMAL = 246
ENUM = 247
SET = 248
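Capability flags like the ones added above are single bits packed into one integer, so membership tests are bitwise ANDs. A small illustration using values from this hunk; the server_capabilities value is hypothetical:

    PLUGIN_AUTH = 1 << 19
    SECURE_CONNECTION = 1 << 15

    server_capabilities = PLUGIN_AUTH | SECURE_CONNECTION  # made-up handshake value
    if server_capabilities & PLUGIN_AUTH:
        print("server supports pluggable authentication")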
@@ -1,7 +1,5 @@
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON, unichr

import sys
import binascii
import datetime
from decimal import Decimal
import re

@@ -11,10 +9,6 @@ from .constants import FIELD_TYPE, FLAG
from .charset import charset_by_id, charset_to_encoding


ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
              '\'': '\\\'', '"': '\\"', '\\': '\\\\'}

def escape_item(val, charset, mapping=None):
    if mapping is None:
        mapping = encoders

@@ -48,8 +42,7 @@ def escape_sequence(val, charset, mapping=None):
    return "(" + ",".join(n) + ")"

def escape_set(val, charset, mapping=None):
    val = map(lambda x: escape_item(x, charset, mapping), val)
    return ','.join(val)
    return ','.join([escape_item(x, charset, mapping) for x in val])

def escape_bool(value, mapping=None):
    return str(int(value))

@@ -63,19 +56,61 @@ def escape_int(value, mapping=None):
def escape_float(value, mapping=None):
    return ('%.15g' % value)

def escape_string(value, mapping=None):
    return ("%s" % (ESCAPE_REGEX.sub(
        lambda match: ESCAPE_MAP.get(match.group(0)), value),))
_escape_table = [unichr(x) for x in range(128)]
_escape_table[0] = u'\\0'
_escape_table[ord('\\')] = u'\\\\'
_escape_table[ord('\n')] = u'\\n'
_escape_table[ord('\r')] = u'\\r'
_escape_table[ord('\032')] = u'\\Z'
_escape_table[ord('"')] = u'\\"'
_escape_table[ord("'")] = u"\\'"

def _escape_unicode(value, mapping=None):
    """escapes *value* without adding quote.

    Value should be unicode
    """
    return value.translate(_escape_table)

if PY2:
    def escape_string(value, mapping=None):
        """escape_string escapes *value* but does not surround it with quotes.

        Value should be bytes or unicode.
        """
        if isinstance(value, unicode):
            return _escape_unicode(value)
        assert isinstance(value, (bytes, bytearray))
        value = value.replace('\\', '\\\\')
        value = value.replace('\0', '\\0')
        value = value.replace('\n', '\\n')
        value = value.replace('\r', '\\r')
        value = value.replace('\032', '\\Z')
        value = value.replace("'", "\\'")
        value = value.replace('"', '\\"')
        return value

    def escape_bytes(value, mapping=None):
        assert isinstance(value, (bytes, bytearray))
        return b"_binary'%s'" % escape_string(value)
else:
    escape_string = _escape_unicode

    # On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow.
    # (fixed in Python 3.6, http://bugs.python.org/issue24870)
    # Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff.
    # We can escape special chars and surrogateescape at once.
    _escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)]

    def escape_bytes(value, mapping=None):
        return "_binary'%s'" % value.decode('latin1').translate(_escape_bytes_table)

def escape_str(value, mapping=None):
    return "'%s'" % escape_string(value, mapping)

def escape_unicode(value, mapping=None):
    return escape_str(value, mapping)
    return u"'%s'" % _escape_unicode(value)

def escape_bytes(value, mapping=None):
    # escape_bytes is called only on Python 3.
    return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
def escape_str(value, mapping=None):
    return "'%s'" % escape_string(str(value), mapping)

def escape_None(value, mapping=None):
    return 'NULL'
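The new _escape_table drives str.translate(), which maps every code point below 128 through the table in one C-level pass; only the seven special characters are remapped, everything else passes through as-is. A quick demonstration of the mechanism:

    _escape_table = [chr(x) for x in range(128)]  # identity for plain characters
    _escape_table[ord("'")] = "\\'"
    _escape_table[ord('\\')] = '\\\\'

    assert "it's a \\ test".translate(_escape_table) == "it\\'s a \\\\ test"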
@@ -111,6 +146,16 @@ def escape_date(obj, mapping=None):
def escape_struct_time(obj, mapping=None):
    return escape_datetime(datetime.datetime(*obj[:6]))

def _convert_second_fraction(s):
    if not s:
        return 0
    # Pad zeros to ensure the fraction length in microseconds
    s = s.ljust(6, '0')
    return int(s[:6])

DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_datetime(obj):
    """Returns a DATETIME or TIMESTAMP column value as a datetime object:

@@ -127,23 +172,22 @@ def convert_datetime(obj):
    True

    """
    if ' ' in obj:
        sep = ' '
    elif 'T' in obj:
        sep = 'T'
    else:
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = DATETIME_RE.match(obj)
    if not m:
        return convert_date(obj)

    try:
        ymd, hms = obj.split(sep, 1)
        usecs = '0'
        if '.' in hms:
            hms, usecs = hms.split('.')
        usecs = float('0.' + usecs) * 1e6
        return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':')+[usecs] ])
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        return datetime.datetime(*[ int(x) for x in groups ])
    except ValueError:
        return convert_date(obj)

TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_timedelta(obj):
    """Returns a TIME column as a timedelta object:
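_convert_second_fraction exists because '.12' in '2016-01-01 10:00:00.12' means 120000 microseconds, not 12; padding to six digits before int() preserves that. A quick check of the regex-plus-padding pipeline:

    import re

    DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")

    def _convert_second_fraction(s):
        if not s:
            return 0
        return int(s.ljust(6, '0')[:6])

    m = DATETIME_RE.match('2016-01-01 10:00:00.12')
    assert _convert_second_fraction(m.group(7)) == 120000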
@@ -162,16 +206,19 @@ def convert_timedelta(obj):
    can accept values as (+|-)DD HH:MM:SS. The latter format will not
    be parsed correctly by this function.
    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = TIMEDELTA_RE.match(obj)
    if not m:
        return None

    try:
        microseconds = 0
        if "." in obj:
            (obj, tail) = obj.split('.')
            microseconds = float('0.' + tail) * 1e6
        hours, minutes, seconds = obj.split(':')
        negate = 1
        if hours.startswith("-"):
            hours = hours[1:]
            negate = -1
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        negate = -1 if groups[0] else 1
        hours, minutes, seconds, microseconds = groups[1:]

        tdelta = datetime.timedelta(
            hours = int(hours),
            minutes = int(minutes),

@@ -182,6 +229,9 @@ def convert_timedelta(obj):
    except ValueError:
        return None

TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")


def convert_time(obj):
    """Returns a TIME column as a time object:

@@ -204,17 +254,23 @@ def convert_time(obj):
    to be treated as time-of-day and not a time offset, then you can
    use this function as the converter for FIELD_TYPE.TIME.
    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')

    m = TIME_RE.match(obj)
    if not m:
        return None

    try:
        microseconds = 0
        if "." in obj:
            (obj, tail) = obj.split('.')
            microseconds = float('0.' + tail) * 1e6
        hours, minutes, seconds = obj.split(':')
        groups = list(m.groups())
        groups[-1] = _convert_second_fraction(groups[-1])
        hours, minutes, seconds, microseconds = groups
        return datetime.time(hour=int(hours), minute=int(minutes),
                             second=int(seconds), microsecond=int(microseconds))
    except ValueError:
        return None


def convert_date(obj):
    """Returns a DATE column as a date object:

@@ -229,6 +285,8 @@ def convert_date(obj):
    True

    """
    if not PY2 and isinstance(obj, (bytes, bytearray)):
        obj = obj.decode('ascii')
    try:
        return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
    except ValueError:

@@ -256,6 +314,8 @@ def convert_mysql_timestamp(timestamp):
    True

    """
    if not PY2 and isinstance(timestamp, (bytes, bytearray)):
        timestamp = timestamp.decode('ascii')
    if timestamp[4] == '-':
        return convert_datetime(timestamp)
    timestamp += "0"*(14-len(timestamp))  # padding

@@ -268,6 +328,8 @@ def convert_mysql_timestamp(timestamp):
        return None

def convert_set(s):
    if isinstance(s, (bytes, bytearray)):
        return set(s.split(b","))
    return set(s.split(","))
@@ -278,7 +340,7 @@ def through(x):
#def convert_bit(b):
#    b = "\x00" * (8 - len(b)) + b  # pad w/ zeroes
#    return struct.unpack(">Q", b)[0]
#
#
# the snippet above is right, but MySQLdb doesn't process bits,
# so we shouldn't either
convert_bit = through

@@ -309,7 +371,9 @@ encoders = {
    tuple: escape_sequence,
    list: escape_sequence,
    set: escape_sequence,
    frozenset: escape_sequence,
    dict: escape_dict,
    bytearray: escape_bytes,
    type(None): escape_None,
    datetime.date: escape_date,
    datetime.datetime: escape_datetime,

@@ -350,7 +414,6 @@ decoders = {


# for MySQLdb compatibility
conversions = decoders

def Thing2Literal(obj):
    return escape_str(str(obj))
conversions = encoders.copy()
conversions.update(decoders)
Thing2Literal = escape_str
@@ -5,33 +5,37 @@ import re
import warnings

from ._compat import range_type, text_type, PY2

from . import err


#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only supports simple bulk insert.
#: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(r"""(INSERT\s.+\sVALUES\s+)(\(\s*%s\s*(?:,\s*%s\s*)*\))(\s*(?:ON DUPLICATE.*)?)\Z""",
                              re.IGNORECASE | re.DOTALL)
RE_INSERT_VALUES = re.compile(
    r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
    r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
    r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
    re.IGNORECASE | re.DOTALL)


class Cursor(object):
    '''
    """
    This is the object you use to interact with the database.
    '''
    """

    #: Max stetement size which :meth:`executemany` generates.
    #: Max statement size which :meth:`executemany` generates.
    #:
    #: Max size of allowed statement is max_allowed_packet - packet_header_size.
    #: Default value of max_allowed_packet is 1048576.
    max_stmt_length = 1024000

    _defer_warnings = False

    def __init__(self, connection):
        '''
        """
        Do not create an instance of a Cursor yourself. Call
        connections.Connection.cursor().
        '''
        """
        self.connection = connection
        self.description = None
        self.rownumber = 0

@@ -40,11 +44,12 @@ class Cursor(object):
        self._executed = None
        self._result = None
        self._rows = None
        self._warnings_handled = False

    def close(self):
        '''
        """
        Closing a cursor just exhausts all remaining data.
        '''
        """
        conn = self.connection
        if conn is None:
            return

@@ -83,6 +88,9 @@ class Cursor(object):
        """Get the next query set"""
        conn = self._get_db()
        current_result = self._result
        # for unbuffered queries warnings are only available once whole result has been read
        if unbuffered:
            self._show_warnings()
        if current_result is None or current_result is not conn._result:
            return None
        if not current_result.has_next:

@@ -107,17 +115,17 @@ class Cursor(object):
        if isinstance(args, (tuple, list)):
            if PY2:
                args = tuple(map(ensure_bytes, args))
            return tuple(conn.escape(arg) for arg in args)
            return tuple(conn.literal(arg) for arg in args)
        elif isinstance(args, dict):
            if PY2:
                args = dict((ensure_bytes(key), ensure_bytes(val)) for
                            (key, val) in args.items())
            return dict((key, conn.escape(val)) for (key, val) in args.items())
            return dict((key, conn.literal(val)) for (key, val) in args.items())
        else:
            # If it's not a dictionary let's try escaping it anyways.
            # Worst case it will throw a ValueError
            if PY2:
                ensure_bytes(args)
                args = ensure_bytes(args)
            return conn.escape(args)

    def mogrify(self, query, args=None):
@@ -137,7 +145,19 @@ class Cursor(object):
        return query

    def execute(self, query, args=None):
        '''Execute a query'''
        """Execute a query

        :param str query: Query to execute.

        :param args: parameters used with query. (optional)
        :type args: tuple, list or dict

        :return: Number of affected rows
        :rtype: int

        If args is a list or tuple, %s can be used as a placeholder in the query.
        If args is a dict, %(name)s can be used as a placeholder in the query.
        """
        while self.nextset():
            pass

@@ -148,17 +168,23 @@ class Cursor(object):
        return result

    def executemany(self, query, args):
        # type: (str, list) -> int
        """Run several data against one query

        PyMySQL can execute bulk insert for a query like 'INSERT ... VALUES (%s)'.
        In other form of queries, just run :meth:`execute` many times.
        :param query: query to execute on server
        :param args: Sequence of sequences or mappings. It is used as parameter.
        :return: Number of rows affected, if any.

        This method improves performance on multiple-row INSERT and
        REPLACE. Otherwise it is equivalent to looping over args with
        execute().
        """
        if not args:
            return

        m = RE_INSERT_VALUES.match(query)
        if m:
            q_prefix = m.group(1)
            q_prefix = m.group(1) % ()
            q_values = m.group(2).rstrip()
            q_postfix = m.group(3) or ''
            assert q_values[0] == '(' and q_values[-1] == ')'
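The new RE_INSERT_VALUES also recognizes REPLACE and %(name)s placeholders, so the bulk rewriting above now kicks in for both forms. A hedged usage sketch; the table and data are invented, and `conn` is assumed to be an open Connection:

    # One multi-row INSERT is sent instead of three single-row statements.
    data = [(1, 'a'), (2, 'b'), (3, 'c')]
    with conn.cursor() as cur:
        cur.executemany("INSERT INTO t1 (id, name) VALUES (%s, %s)", data)
    conn.commit()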
@@ -247,7 +273,7 @@ class Cursor(object):
        return args

    def fetchone(self):
        ''' Fetch the next row '''
        """Fetch the next row"""
        self._check_executed()
        if self._rows is None or self.rownumber >= len(self._rows):
            return None

@@ -256,7 +282,7 @@ class Cursor(object):
        return result

    def fetchmany(self, size=None):
        ''' Fetch several rows '''
        """Fetch several rows"""
        self._check_executed()
        if self._rows is None:
            return ()

@@ -266,7 +292,7 @@ class Cursor(object):
        return result

    def fetchall(self):
        ''' Fetch all the rows '''
        """Fetch all the rows"""
        self._check_executed()
        if self._rows is None:
            return ()

@@ -307,14 +333,18 @@ class Cursor(object):
        self.description = result.description
        self.lastrowid = result.insert_id
        self._rows = result.rows
        self._warnings_handled = False

        if result.warning_count > 0:
            self._show_warnings(conn)
        if not self._defer_warnings:
            self._show_warnings()

    def _show_warnings(self, conn):
        if self._result and self._result.has_next:
    def _show_warnings(self):
        if self._warnings_handled:
            return
        ws = conn.show_warnings()
        self._warnings_handled = True
        if self._result and (self._result.has_next or not self._result.warning_count):
            return
        ws = self._get_db().show_warnings()
        if ws is None:
            return
        for w in ws:

@@ -322,7 +352,7 @@ class Cursor(object):
            if PY2:
                if isinstance(msg, unicode):
                    msg = msg.encode('utf-8', 'replace')
            warnings.warn(str(msg), err.Warning, 4)
            warnings.warn(err.Warning(*w[1:3]), stacklevel=4)

    def __iter__(self):
        return iter(self.fetchone, None)

@@ -373,8 +403,8 @@ class SSCursor(Cursor):
    or for connections to remote servers over a slow network.

    Instead of copying every row of data into a buffer, this will fetch
    rows as needed. The upside of this, is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network,
    rows as needed. The upside of this is the client uses much less memory,
    and rows are returned much faster when traveling over a slow network
    or if the result set is very big.

    There are limitations, though. The MySQL protocol doesn't support

@@ -383,6 +413,8 @@ class SSCursor(Cursor):
    possible to scroll backwards, as only the current row is held in memory.
    """

    _defer_warnings = True

    def _conv_row(self, row):
        return row

@@ -411,14 +443,15 @@ class SSCursor(Cursor):
        return self._nextset(unbuffered=True)

    def read_next(self):
        """ Read next row """
        """Read next row"""
        return self._conv_row(self._result._read_rowdata_packet_unbuffered())

    def fetchone(self):
        """ Fetch next row """
        """Fetch next row"""
        self._check_executed()
        row = self.read_next()
        if row is None:
            self._show_warnings()
            return None
        self.rownumber += 1
        return row

@@ -443,7 +476,7 @@ class SSCursor(Cursor):
        return self.fetchall_unbuffered()

    def fetchmany(self, size=None):
        """ Fetch many """
        """Fetch many"""
        self._check_executed()
        if size is None:
            size = self.arraysize

@@ -452,6 +485,7 @@ class SSCursor(Cursor):
        for i in range_type(size):
            row = self.read_next()
            if row is None:
                self._show_warnings()
                break
            rows.append(row)
            self.rownumber += 1

@@ -482,4 +516,4 @@ class SSCursor(Cursor):


class SSDictCursor(DictCursorMixin, SSCursor):
    """ An unbuffered cursor, which returns results as a dictionary """
    """An unbuffered cursor, which returns results as a dictionary"""
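Because SSCursor defers warnings until the unbuffered result has been drained (the _defer_warnings flag above), the usual pattern is to iterate it to completion. A hedged sketch; the query and table are illustrative only:

    import pymysql
    from pymysql.cursors import SSCursor

    conn = pymysql.connect(host='127.0.0.1', user='app', password='secret',
                           database='test', cursorclass=SSCursor)
    with conn.cursor() as cur:
        cur.execute("SELECT id, payload FROM big_table")
        for row in cur:   # rows stream one at a time; warnings surface at the end
            pass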
@@ -68,10 +68,12 @@ class NotSupportedError(DatabaseError):

error_map = {}


def _map_error(exc, *errors):
    for error in errors:
        error_map[error] = exc


_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
           ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
           ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,

@@ -89,32 +91,17 @@ _map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
           ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR,
           ER.COLUMNACCESS_DENIED_ERROR)


del _map_error, ER


def _get_error_info(data):
def raise_mysql_exception(data):
    errno = struct.unpack('<h', data[1:3])[0]
    is_41 = data[3:4] == b"#"
    if is_41:
        # version 4.1
        sqlstate = data[4:9].decode("utf8", 'replace')
        errorvalue = data[9:].decode("utf8", 'replace')
        return (errno, sqlstate, errorvalue)
        # client protocol 4.1
        errval = data[9:].decode('utf-8', 'replace')
    else:
        # version 4.0
        return (errno, None, data[3:].decode("utf8", 'replace'))


def _check_mysql_exception(errinfo):
    errno, sqlstate, errorvalue = errinfo
    errorclass = error_map.get(errno, None)
    if errorclass:
        raise errorclass(errno, errorvalue)

    # couldn't find the right error number
    raise InternalError(errno, errorvalue)


def raise_mysql_exception(data):
    errinfo = _get_error_info(data)
    _check_mysql_exception(errinfo)
        errval = data[3:].decode('utf-8', 'replace')
    errorclass = error_map.get(errno, InternalError)
    raise errorclass(errno, errval)
@@ -1,16 +1,20 @@
from time import localtime
from datetime import date, datetime, time, timedelta


Date = date
Time = time
TimeDelta = timedelta
Timestamp = datetime


def DateFromTicks(ticks):
    return date(*localtime(ticks)[:3])


def TimeFromTicks(ticks):
    return time(*localtime(ticks)[3:6])


def TimestampFromTicks(ticks):
    return datetime(*localtime(ticks)[:6])

@@ -1,14 +1,17 @@
import struct


def byte2int(b):
    if isinstance(b, int):
        return b
    else:
        return struct.unpack("!B", b)[0]


def int2byte(i):
    return struct.pack("!B", i)


def join_bytes(bs):
    if len(bs) == 0:
        return ""
@ -1,45 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RSA module
|
||||
|
||||
Module for calculating large primes, and RSA encryption, decryption, signing
|
||||
and verification. Includes generating public and private keys.
|
||||
|
||||
WARNING: this implementation does not use random padding, compression of the
|
||||
cleartext input to prevent repetitions, or other common security improvements.
|
||||
Use with care.
|
||||
|
||||
If you want to have a more secure implementation, use the functions from the
|
||||
``rsa.pkcs1`` module.
|
||||
|
||||
"""
|
||||
|
||||
__author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly"
|
||||
__date__ = "2015-07-29"
|
||||
__version__ = '3.2'
|
||||
|
||||
from rsa.key import newkeys, PrivateKey, PublicKey
|
||||
from rsa.pkcs1 import encrypt, decrypt, sign, verify, DecryptionError, \
|
||||
VerificationError
|
||||
|
||||
# Do doctest if we're run directly
|
||||
if __name__ == "__main__":
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
||||
__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify", 'PublicKey',
|
||||
'PrivateKey', 'DecryptionError', 'VerificationError']
|
||||
|
|
@ -1,160 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Python compatibility wrappers."""

from __future__ import absolute_import

import sys
from struct import pack

try:
    MAX_INT = sys.maxsize
except AttributeError:
    MAX_INT = sys.maxint

MAX_INT64 = (1 << 63) - 1
MAX_INT32 = (1 << 31) - 1
MAX_INT16 = (1 << 15) - 1

# Determine the word size of the processor.
if MAX_INT == MAX_INT64:
    # 64-bit processor.
    MACHINE_WORD_SIZE = 64
elif MAX_INT == MAX_INT32:
    # 32-bit processor.
    MACHINE_WORD_SIZE = 32
else:
    # Else we just assume a 64-bit processor, keeping up with modern times.
    MACHINE_WORD_SIZE = 64


try:
    # < Python 3
    unicode_type = unicode
    have_python3 = False
except NameError:
    # Python 3.
    unicode_type = str
    have_python3 = True

# Fake byte literals.
if str is unicode_type:
    def byte_literal(s):
        return s.encode('latin1')
else:
    def byte_literal(s):
        return s

# ``long`` is no more. Do type detection using this instead.
try:
    integer_types = (int, long)
except NameError:
    integer_types = (int,)

b = byte_literal

try:
    # Python 2.6 or higher.
    bytes_type = bytes
except NameError:
    # Python 2.5
    bytes_type = str


# To avoid calling b() multiple times in tight loops.
ZERO_BYTE = b('\x00')
EMPTY_BYTE = b('')


def is_bytes(obj):
    """
    Determines whether the given value is a byte string.

    :param obj:
        The value to test.
    :returns:
        ``True`` if ``obj`` is a byte string; ``False`` otherwise.
    """
    return isinstance(obj, bytes_type)


def is_integer(obj):
    """
    Determines whether the given value is an integer.

    :param obj:
        The value to test.
    :returns:
        ``True`` if ``obj`` is an integer; ``False`` otherwise.
    """
    return isinstance(obj, integer_types)


def byte(num):
    """
    Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
    representation.

    Use it as a replacement for ``chr`` where you are expecting a byte,
    because this will work on all current versions of Python.

    :param num:
        An unsigned integer between 0 and 255 (both inclusive).
    :returns:
        A single byte.
    """
    return pack("B", num)


def get_word_alignment(num, force_arch=64,
                       _machine_word_size=MACHINE_WORD_SIZE):
    """
    Returns alignment details for the given number based on the platform
    Python is running on.

    :param num:
        Unsigned integral number.
    :param force_arch:
        If you don't want to use 64-bit unsigned chunks, set this to
        anything other than 64. 32-bit chunks will be preferred then.
        Default 64 will be used when on a 64-bit machine.
    :param _machine_word_size:
        (Internal) The machine word size used for alignment.
    :returns:
        4-tuple::

            (word_bits, word_bytes,
             max_uint, packing_format_type)
    """
    max_uint64 = 0xffffffffffffffff
    max_uint32 = 0xffffffff
    max_uint16 = 0xffff
    max_uint8 = 0xff

    if force_arch == 64 and _machine_word_size >= 64 and num > max_uint32:
        # 64-bit unsigned integer.
        return 64, 8, max_uint64, "Q"
    elif num > max_uint16:
        # 32-bit unsigned integer.
        return 32, 4, max_uint32, "L"
    elif num > max_uint8:
        # 16-bit unsigned integer.
        return 16, 2, max_uint16, "H"
    else:
        # 8-bit unsigned integer.
        return 8, 1, max_uint8, "B"
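
# Illustrative only (not in the original module): how these helpers compose.
# byte() packs a single unsigned byte, and get_word_alignment() picks the
# struct packing code for the smallest word that still holds the value.
def _demo_word_alignment():
    assert is_bytes(byte(255))                                 # b'\xff'
    assert get_word_alignment(0xffff) == (16, 2, 0xffff, "H")  # fits 16 bits
    assert get_word_alignment(0x10000) == (32, 4, 0xffffffff, "L")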
@ -1,442 +0,0 @@
"""RSA module
pri = k[1] //Private part of keys d,p,q

Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.

WARNING: this code implements the mathematics of RSA. It is not suitable for
real-world secure cryptography purposes. It has not been reviewed by a security
expert. It does not include padding of data. There are many ways in which the
output of this module, when used without any modification, can be successfully
attacked.
"""

__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer"
__date__ = "2010-02-05"
__version__ = '1.3.3'

# NOTE: Python's modulo can return negative numbers. We compensate for
# this behaviour using the abs() function.

from cPickle import dumps, loads
import base64
import math
import os
import random
import sys
import types
import zlib

from rsa._compat import byte

# Display a warning that this insecure version is imported.
import warnings
warnings.warn('Insecure version of the RSA module is imported as %s, be careful'
              % __name__)


def gcd(p, q):
    """Returns the greatest common divisor of p and q

    >>> gcd(42, 6)
    6
    """
    if p < q: return gcd(q, p)
    if q == 0: return p
    return gcd(q, abs(p % q))


def bytes2int(bytes):
    """Converts a list of bytes or a string to an integer

    >>> (128 * 256 + 64) * 256 + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l)
    8405007
    """

    if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
        raise TypeError("You must pass a string or a list")

    # Convert byte stream to integer
    integer = 0
    for byte in bytes:
        integer *= 256
        if type(byte) is types.StringType: byte = ord(byte)
        integer += byte

    return integer


def int2bytes(number):
    """Converts a number to a string of bytes

    >>> bytes2int(int2bytes(123456789))
    123456789
    """

    if not (type(number) is types.LongType or type(number) is types.IntType):
        raise TypeError("You must pass a long or an int")

    string = ""

    while number > 0:
        string = "%s%s" % (byte(number & 0xFF), string)
        number /= 256

    return string


def fast_exponentiation(a, p, n):
    """Calculates r = a**p mod n
    """
    result = a % n
    remainders = []
    while p != 1:
        remainders.append(p & 1)
        p = p >> 1
    while remainders:
        rem = remainders.pop()
        result = ((a ** rem) * result ** 2) % n
    return result
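
# Illustrative only: fast_exponentiation() is binary (square-and-multiply)
# exponentiation, so for any a, p >= 1 and n it must agree with Python's
# built-in three-argument pow().
def _demo_fast_exponentiation_matches_pow():
    for a, p, n in [(2, 10, 1000), (7, 65537, 991 * 997)]:
        assert fast_exponentiation(a, p, n) == pow(a, p, n)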

def read_random_int(nbits):
    """Reads a random integer of approximately nbits bits rounded up
    to whole bytes"""

    nbytes = ceil(nbits / 8.)
    randomdata = os.urandom(nbytes)
    return bytes2int(randomdata)


def ceil(x):
    """ceil(x) -> int(math.ceil(x))"""

    return int(math.ceil(x))


def randint(minvalue, maxvalue):
    """Returns a random integer x with minvalue <= x <= maxvalue"""

    # Safety - get a lot of random data even if the range is fairly
    # small
    min_nbits = 32

    # The range of the random numbers we need to generate
    range = maxvalue - minvalue

    # Which is this number of bytes
    rangebytes = ceil(math.log(range, 2) / 8.)

    # Convert to bits, but make sure it's always at least min_nbits*2
    rangebits = max(rangebytes * 8, min_nbits * 2)

    # Take a random number of bits between min_nbits and rangebits
    nbits = random.randint(min_nbits, rangebits)

    return (read_random_int(nbits) % range) + minvalue


def fermat_little_theorem(p):
    """Returns 1 if p may be prime, and something else if p definitely
    is not prime"""

    a = randint(1, p - 1)
    return fast_exponentiation(a, p - 1, p)


def jacobi(a, b):
    """Calculates the value of the Jacobi symbol (a/b)
    """

    if a % b == 0:
        return 0
    result = 1
    while a > 1:
        if a & 1:
            if ((a - 1) * (b - 1) >> 2) & 1:
                result = -result
            b, a = a, b % a
        else:
            if ((b ** 2 - 1) >> 3) & 1:
                result = -result
            a = a >> 1
    return result


def jacobi_witness(x, n):
    """Returns False if n is an Euler pseudo-prime with base x, and
    True otherwise.
    """

    j = jacobi(x, n) % n
    f = fast_exponentiation(x, (n - 1) / 2, n)

    if j == f: return False
    return True


def randomized_primality_testing(n, k):
    """Calculates whether n is composite (which is always correct) or
    prime (which is incorrect with error probability 2**-k)

    Returns False if the number is composite, and True if it's
    probably prime.
    """

    q = 0.5     # Property of the jacobi_witness function

    # t = int(math.ceil(k / math.log(1 / q, 2)))
    t = ceil(k / math.log(1 / q, 2))
    for i in range(t + 1):
        x = randint(1, n - 1)
        if jacobi_witness(x, n): return False

    return True


def is_prime(number):
    """Returns True if the number is prime, and False otherwise.

    >>> is_prime(42)
    0
    >>> is_prime(41)
    1
    """

    """
    if not fermat_little_theorem(number) == 1:
        # Not prime, according to Fermat's little theorem
        return False
    """

    if randomized_primality_testing(number, 5):
        # Prime, according to Jacobi
        return True

    # Not prime
    return False


def getprime(nbits):
    """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
    other words: nbits is rounded up to whole bytes.

    >>> p = getprime(8)
    >>> is_prime(p-1)
    0
    >>> is_prime(p)
    1
    >>> is_prime(p+1)
    0
    """

    nbytes = int(math.ceil(nbits / 8.))

    while True:
        integer = read_random_int(nbits)

        # Make sure it's odd
        integer |= 1

        # Test for primeness
        if is_prime(integer): break

        # Retry if not prime

    return integer


def are_relatively_prime(a, b):
    """Returns True if a and b are relatively prime, and False if they
    are not.

    >>> are_relatively_prime(2, 3)
    1
    >>> are_relatively_prime(2, 4)
    0
    """

    d = gcd(a, b)
    return (d == 1)


def find_p_q(nbits):
    """Returns a tuple of two different primes of nbits bits"""

    p = getprime(nbits)
    while True:
        q = getprime(nbits)
        if not q == p: break

    return (p, q)


def extended_euclid_gcd(a, b):
    """Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb
    """

    if b == 0:
        return (a, 1, 0)

    q = abs(a % b)
    r = long(a / b)
    (d, k, l) = extended_euclid_gcd(b, q)

    return (d, l, k - l * r)


# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
    """Calculates an encryption and a decryption key for p and q, and
    returns them as a tuple (e, d)"""

    n = p * q
    phi_n = (p - 1) * (q - 1)

    while True:
        # Make sure e has enough bits so we ensure "wrapping" through
        # modulo n
        e = getprime(max(8, nbits / 2))
        if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break

    (d, i, j) = extended_euclid_gcd(e, phi_n)

    if not d == 1:
        raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))

    if not (e * i) % phi_n == 1:
        raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))

    return (e, i)


def gen_keys(nbits):
    """Generate RSA keys of nbits bits. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.
    """

    while True:
        (p, q) = find_p_q(nbits)
        (e, d) = calculate_keys(p, q, nbits)

        # For some reason, d is sometimes negative. We don't know how
        # to fix it (yet), so we keep trying until everything is shiny.
        if d > 0: break

    return (p, q, e, d)


def gen_pubpriv_keys(nbits):
    """Generates public and private keys, and returns them as (pub,
    priv).

    The public key consists of a dict {e: ..., n: ...}. The private
    key consists of a dict {d: ..., p: ..., q: ...}.
    """

    (p, q, e, d) = gen_keys(nbits)

    return ({'e': e, 'n': p * q}, {'d': d, 'p': p, 'q': q})


def encrypt_int(message, ekey, n):
    """Encrypts a message using encryption key 'ekey', working modulo
    n"""

    if type(message) is types.IntType:
        return encrypt_int(long(message), ekey, n)

    if not type(message) is types.LongType:
        raise TypeError("You must pass a long or an int")

    if message > 0 and \
            math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)):
        raise OverflowError("The message is too long")

    return fast_exponentiation(message, ekey, n)


def decrypt_int(cyphertext, dkey, n):
    """Decrypts a cypher text using the decryption key 'dkey', working
    modulo n"""

    return encrypt_int(cyphertext, dkey, n)


def sign_int(message, dkey, n):
    """Signs 'message' using key 'dkey', working modulo n"""

    return decrypt_int(message, dkey, n)


def verify_int(signed, ekey, n):
    """Verifies 'signed' using key 'ekey', working modulo n"""

    return encrypt_int(signed, ekey, n)


def picklechops(chops):
    """Pickles and base64-encodes its argument chops"""

    value = zlib.compress(dumps(chops))
    encoded = base64.encodestring(value)
    return encoded.strip()


def unpicklechops(string):
    """base64-decodes and unpickles its argument string into chops"""

    return loads(zlib.decompress(base64.decodestring(string)))


def chopstring(message, key, n, funcref):
    """Splits 'message' into chops that are at most as long as n,
    converts these into integers, and calls funcref(integer, key, n)
    for each chop.

    Used by 'encrypt' and 'sign'.
    """

    msglen = len(message)
    mbits = msglen * 8
    nbits = int(math.floor(math.log(n, 2)))
    nbytes = nbits / 8
    blocks = msglen / nbytes

    if msglen % nbytes > 0:
        blocks += 1

    cypher = []

    for bindex in range(blocks):
        offset = bindex * nbytes
        block = message[offset:offset + nbytes]
        value = bytes2int(block)
        cypher.append(funcref(value, key, n))

    return picklechops(cypher)


def gluechops(chops, key, n, funcref):
    """Glues chops back together into a string. Calls
    funcref(integer, key, n) for each chop.

    Used by 'decrypt' and 'verify'.
    """
    message = ""

    chops = unpicklechops(chops)

    for cpart in chops:
        mpart = funcref(cpart, key, n)
        message += int2bytes(mpart)

    return message


def encrypt(message, key):
    """Encrypts a string 'message' with the public key 'key'"""

    return chopstring(message, key['e'], key['n'], encrypt_int)


def sign(message, key):
    """Signs a string 'message' with the private key 'key'"""

    return chopstring(message, key['d'], key['p'] * key['q'], decrypt_int)


def decrypt(cypher, key):
    """Decrypts a cypher with the private key 'key'"""

    return gluechops(cypher, key['d'], key['p'] * key['q'], decrypt_int)


def verify(cypher, key):
    """Verifies a cypher with the public key 'key'"""

    return gluechops(cypher, key['e'], key['n'], encrypt_int)


# Do doctest if we're run directly
if __name__ == "__main__":
    import doctest
    doctest.testmod()

__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"]
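
# Illustrative only: a round trip through the 1.3.3 API (Python 2 only,
# since the module relies on cPickle and the types module).
def _demo_roundtrip_133():
    pub, priv = gen_pubpriv_keys(64)
    assert decrypt(encrypt('hello', pub), priv) == 'hello'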
@ -1,529 +0,0 @@
"""RSA module

Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.

WARNING: this implementation does not use random padding, compression of the
cleartext input to prevent repetitions, or other common security improvements.
Use with care.

"""

__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
__date__ = "2010-02-08"
__version__ = '2.0'

import math
import os
import random
import sys
import types
from rsa._compat import byte

# Display a warning that this insecure version is imported.
import warnings
warnings.warn('Insecure version of the RSA module is imported as %s' % __name__)


def bit_size(number):
    """Returns the number of bits required to hold a specific long number"""

    return int(math.ceil(math.log(number, 2)))


def gcd(p, q):
    """Returns the greatest common divisor of p and q
    >>> gcd(48, 180)
    12
    """
    # The iterative version is faster and uses much less stack space
    while q != 0:
        if p < q: (p, q) = (q, p)
        (p, q) = (q, p % q)
    return p


def bytes2int(bytes):
    """Converts a list of bytes or a string to an integer

    >>> (((128 * 256) + 64) * 256) + 15
    8405007
    >>> l = [128, 64, 15]
    >>> bytes2int(l)  # same as bytes2int('\x80@\x0f')
    8405007
    """

    if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
        raise TypeError("You must pass a string or a list")

    # Convert byte stream to integer
    integer = 0
    for byte in bytes:
        integer *= 256
        if type(byte) is types.StringType: byte = ord(byte)
        integer += byte

    return integer


def int2bytes(number):
    """
    Converts a number to a string of bytes
    """

    if not (type(number) is types.LongType or type(number) is types.IntType):
        raise TypeError("You must pass a long or an int")

    string = ""

    while number > 0:
        string = "%s%s" % (byte(number & 0xFF), string)
        number /= 256

    return string


def to64(number):
    """Converts a number in the range of 0 to 63 into a base-64 digit
    character in the range of '0'-'9', 'A'-'Z', 'a'-'z', '-', '_'.

    >>> to64(10)
    'A'
    """

    if not (type(number) is types.LongType or type(number) is types.IntType):
        raise TypeError("You must pass a long or an int")

    if 0 <= number <= 9:            # 00-09 translates to '0' - '9'
        return byte(number + 48)

    if 10 <= number <= 35:
        return byte(number + 55)    # 10-35 translates to 'A' - 'Z'

    if 36 <= number <= 61:
        return byte(number + 61)    # 36-61 translates to 'a' - 'z'

    if number == 62:                # 62 translates to '-' (minus)
        return byte(45)

    if number == 63:                # 63 translates to '_' (underscore)
        return byte(95)

    raise ValueError('Invalid Base64 value: %i' % number)


def from64(number):
    """Converts an ordinal character value in the range of
    0-9, A-Z, a-z, -, _ to a number in the range of 0-63.

    >>> from64(49)
    1
    """

    if not (type(number) is types.LongType or type(number) is types.IntType):
        raise TypeError("You must pass a long or an int")

    if 48 <= number <= 57:          # ord('0') - ord('9') translates to 0-9
        return (number - 48)

    if 65 <= number <= 90:          # ord('A') - ord('Z') translates to 10-35
        return (number - 55)

    if 97 <= number <= 122:         # ord('a') - ord('z') translates to 36-61
        return (number - 61)

    if number == 45:                # ord('-') translates to 62
        return (62)

    if number == 95:                # ord('_') translates to 63
        return (63)

    raise ValueError('Invalid Base64 value: %i' % number)


def int2str64(number):
    """Converts a number to a string of base64-encoded characters in
    the range of '0'-'9', 'A'-'Z', 'a'-'z', '-', '_'.

    >>> int2str64(123456789)
    '7MyqL'
    """

    if not (type(number) is types.LongType or type(number) is types.IntType):
        raise TypeError("You must pass a long or an int")

    string = ""

    while number > 0:
        string = "%s%s" % (to64(number & 0x3F), string)
        number /= 64

    return string


def str642int(string):
    """Converts a base64-encoded string into an integer.
    The chars of this string are in the range '0'-'9', 'A'-'Z', 'a'-'z', '-', '_'.

    >>> str642int('7MyqL')
    123456789
    """

    if not (type(string) is types.ListType or type(string) is types.StringType):
        raise TypeError("You must pass a string or a list")

    integer = 0
    for byte in string:
        integer *= 64
        if type(byte) is types.StringType: byte = ord(byte)
        integer += from64(byte)

    return integer
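
# Illustrative only: the base-64 digit helpers round-trip, most significant
# digit first (the values below match the doctests above).
def _demo_base64_digits_roundtrip():
    assert int2str64(123456789) == '7MyqL'
    assert str642int('7MyqL') == 123456789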

def read_random_int(nbits):
    """Reads a random integer of approximately nbits bits rounded up
    to whole bytes"""

    nbytes = int(math.ceil(nbits / 8.))
    randomdata = os.urandom(nbytes)
    return bytes2int(randomdata)


def randint(minvalue, maxvalue):
    """Returns a random integer x with minvalue <= x <= maxvalue"""

    # Safety - get a lot of random data even if the range is fairly
    # small
    min_nbits = 32

    # The range of the random numbers we need to generate
    range = (maxvalue - minvalue) + 1

    # Which is this number of bytes
    rangebytes = ((bit_size(range) + 7) / 8)

    # Convert to bits, but make sure it's always at least min_nbits*2
    rangebits = max(rangebytes * 8, min_nbits * 2)

    # Take a random number of bits between min_nbits and rangebits
    nbits = random.randint(min_nbits, rangebits)

    return (read_random_int(nbits) % range) + minvalue


def jacobi(a, b):
    """Calculates the value of the Jacobi symbol (a/b)
    where both a and b are positive integers, and b is odd
    """

    if a == 0: return 0
    result = 1
    while a > 1:
        if a & 1:
            if ((a - 1) * (b - 1) >> 2) & 1:
                result = -result
            a, b = b % a, a
        else:
            if (((b * b) - 1) >> 3) & 1:
                result = -result
            a >>= 1
    if a == 0: return 0
    return result


def jacobi_witness(x, n):
    """Returns False if n is an Euler pseudo-prime with base x, and
    True otherwise.
    """

    j = jacobi(x, n) % n
    f = pow(x, (n - 1) / 2, n)

    if j == f: return False
    return True


def randomized_primality_testing(n, k):
    """Calculates whether n is composite (which is always correct) or
    prime (which is incorrect with error probability 2**-k)

    Returns False if the number is composite, and True if it's
    probably prime.
    """

    # 50% of Jacobi witnesses can report compositeness of non-prime numbers

    for i in range(k):
        x = randint(1, n - 1)
        if jacobi_witness(x, n): return False

    return True


def is_prime(number):
    """Returns True if the number is prime, and False otherwise.

    >>> is_prime(42)
    0
    >>> is_prime(41)
    1
    """

    if randomized_primality_testing(number, 6):
        # Prime, according to Jacobi
        return True

    # Not prime
    return False


def getprime(nbits):
    """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
    other words: nbits is rounded up to whole bytes.

    >>> p = getprime(8)
    >>> is_prime(p-1)
    0
    >>> is_prime(p)
    1
    >>> is_prime(p+1)
    0
    """

    while True:
        integer = read_random_int(nbits)

        # Make sure it's odd
        integer |= 1

        # Test for primeness
        if is_prime(integer): break

        # Retry if not prime

    return integer


def are_relatively_prime(a, b):
    """Returns True if a and b are relatively prime, and False if they
    are not.

    >>> are_relatively_prime(2, 3)
    1
    >>> are_relatively_prime(2, 4)
    0
    """

    d = gcd(a, b)
    return (d == 1)


def find_p_q(nbits):
    """Returns a tuple of two different primes of nbits bits"""
    pbits = nbits + (nbits / 16)  # Make sure that p and q aren't too close
    qbits = nbits - (nbits / 16)  # or the factoring programs can factor n
    p = getprime(pbits)
    while True:
        q = getprime(qbits)
        # Make sure p and q are different.
        if not q == p: break
    return (p, q)


def extended_gcd(a, b):
    """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
    """
    # r = gcd(a,b); i = multiplicative inverse of a mod b,
    # or j = multiplicative inverse of b mod a.
    # Negative return values for i or j are made positive mod b or a respectively.
    # The iterative version is faster and uses much less stack space.
    x = 0
    y = 1
    lx = 1
    ly = 0
    oa = a  # Remember original a/b to remove
    ob = b  # negative values from return results
    while b != 0:
        q = long(a / b)
        (a, b) = (b, a % b)
        (x, lx) = ((lx - (q * x)), x)
        (y, ly) = ((ly - (q * y)), y)
    if lx < 0: lx += ob  # If negative, wrap modulo original b
    if ly < 0: ly += oa  # If negative, wrap modulo original a
    return (a, lx, ly)  # Return only positive values


# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
    """Calculates an encryption and a decryption key for p and q, and
    returns them as a tuple (e, d)"""

    n = p * q
    phi_n = (p - 1) * (q - 1)

    while True:
        # Make sure e has enough bits so we ensure "wrapping" through
        # modulo n
        e = max(65537, getprime(nbits / 4))
        if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break

    (d, i, j) = extended_gcd(e, phi_n)

    if not d == 1:
        raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
    if i < 0:
        raise Exception("New extended_gcd shouldn't return negative values")
    if not (e * i) % phi_n == 1:
        raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))

    return (e, i)


def gen_keys(nbits):
    """Generate RSA keys of nbits bits. Returns (p, q, e, d).

    Note: this can take a long time, depending on the key size.
    """

    (p, q) = find_p_q(nbits)
    (e, d) = calculate_keys(p, q, nbits)

    return (p, q, e, d)


def newkeys(nbits):
    """Generates public and private keys, and returns them as (pub,
    priv).

    The public key consists of a dict {e: ..., n: ...}. The private
    key consists of a dict {d: ..., p: ..., q: ...}.
    """
    nbits = max(9, nbits)  # Don't let nbits go below 9 bits
    (p, q, e, d) = gen_keys(nbits)

    return ({'e': e, 'n': p * q}, {'d': d, 'p': p, 'q': q})


def encrypt_int(message, ekey, n):
    """Encrypts a message using encryption key 'ekey', working modulo n"""

    if type(message) is types.IntType:
        message = long(message)

    if not type(message) is types.LongType:
        raise TypeError("You must pass a long or an int")

    if message < 0 or message > n:
        raise OverflowError("The message is too long")

    # Note: bit exponents start at zero (bit counts start at 1), this is correct
    safebit = bit_size(n) - 2       # compute safe bit (MSB - 1)
    message += (1 << safebit)       # add safebit to ensure folding

    return pow(message, ekey, n)


def decrypt_int(cyphertext, dkey, n):
    """Decrypts a cypher text using the decryption key 'dkey', working
    modulo n"""

    message = pow(cyphertext, dkey, n)

    safebit = bit_size(n) - 2       # compute safe bit (MSB - 1)
    message -= (1 << safebit)       # remove safebit before decode

    return message
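
# Illustrative only: the safebit folding with tiny textbook parameters
# n = 3233 = 61 * 53, e = 17, d = 2753. bit_size(3233) is 12, so the safebit
# is bit 10; encrypt_int() sets it before pow(), decrypt_int() clears it.
def _demo_safebit_roundtrip():
    n, e, d = 3233, 17, 2753
    for m in (0, 1, 42):
        assert decrypt_int(encrypt_int(m, e, n), d, n) == m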

def encode64chops(chops):
    """base64-encodes chops and combines them into a ','-delimited string"""

    chips = []                      # chips are character chops

    for value in chops:
        chips.append(int2str64(value))

    # Delimit chops with commas
    encoded = ','.join(chips)

    return encoded


def decode64chops(string):
    """base64-decodes a ','-delimited string back into chops"""

    chips = string.split(',')       # split chops at commas

    chops = []

    for string in chips:            # make char chops (chips) into chops
        chops.append(str642int(string))

    return chops


def chopstring(message, key, n, funcref):
    """Chops the 'message' into integers that fit into n,
    leaving room for a safebit to be added to ensure that all
    messages fold during exponentiation. The MSB of the number n
    is not independent modulo n (setting it could cause overflow), so
    use the next lower bit for the safebit. Therefore reserve 2 bits
    in the number n for non-data bits. Calls the specified encryption
    function for each chop.

    Used by 'encrypt' and 'sign'.
    """

    msglen = len(message)
    mbits = msglen * 8
    # Set aside 2 bits so setting of the safebit won't overflow modulo n.
    nbits = bit_size(n) - 2         # leave room for safebit
    nbytes = nbits / 8
    blocks = msglen / nbytes

    if msglen % nbytes > 0:
        blocks += 1

    cypher = []

    for bindex in range(blocks):
        offset = bindex * nbytes
        block = message[offset:offset + nbytes]
        value = bytes2int(block)
        cypher.append(funcref(value, key, n))

    return encode64chops(cypher)    # encode encrypted ints to base64 strings


def gluechops(string, key, n, funcref):
    """Glues chops back together into a string. Calls
    funcref(integer, key, n) for each chop.

    Used by 'decrypt' and 'verify'.
    """
    message = ""

    chops = decode64chops(string)   # decode base64 strings into integer chops

    for cpart in chops:
        mpart = funcref(cpart, key, n)  # decrypt each chop
        message += int2bytes(mpart)     # combine decrypted strings into a msg

    return message


def encrypt(message, key):
    """Encrypts a string 'message' with the public key 'key'"""
    if 'n' not in key:
        raise Exception("You must use the public key with encrypt")

    return chopstring(message, key['e'], key['n'], encrypt_int)


def sign(message, key):
    """Signs a string 'message' with the private key 'key'"""
    if 'p' not in key:
        raise Exception("You must use the private key with sign")

    return chopstring(message, key['d'], key['p'] * key['q'], encrypt_int)


def decrypt(cypher, key):
    """Decrypts a string 'cypher' with the private key 'key'"""
    if 'p' not in key:
        raise Exception("You must use the private key with decrypt")

    return gluechops(cypher, key['d'], key['p'] * key['q'], decrypt_int)


def verify(cypher, key):
    """Verifies a string 'cypher' with the public key 'key'"""
    if 'n' not in key:
        raise Exception("You must use the public key with verify")

    return gluechops(cypher, key['e'], key['n'], decrypt_int)


# Do doctest if we're run directly
if __name__ == "__main__":
    import doctest
    doctest.testmod()

__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"]
@ -1,35 +0,0 @@
'''ASN.1 definitions.

Not all ASN.1-handling code uses these definitions, but when it does, they
should be here.
'''

from pyasn1.type import univ, namedtype, tag


class PubKeyHeader(univ.Sequence):
    componentType = namedtype.NamedTypes(
        namedtype.NamedType('oid', univ.ObjectIdentifier()),
        namedtype.NamedType('parameters', univ.Null()),
    )


class OpenSSLPubKey(univ.Sequence):
    componentType = namedtype.NamedTypes(
        namedtype.NamedType('header', PubKeyHeader()),

        # This little hack (the implicit tag) allows us to get a Bit String as Octet String
        namedtype.NamedType('key', univ.OctetString().subtype(
            implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3))),
    )


class AsnPubKey(univ.Sequence):
    '''ASN.1 contents of a DER-encoded public key:

    RSAPublicKey ::= SEQUENCE {
        modulus           INTEGER,  -- n
        publicExponent    INTEGER,  -- e
    }
    '''

    componentType = namedtype.NamedTypes(
        namedtype.NamedType('modulus', univ.Integer()),
        namedtype.NamedType('publicExponent', univ.Integer()),
    )
@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''Large file support

- break a file into smaller blocks, encrypt them, and store the
  encrypted blocks in another file.

- take such an encrypted file, decrypt its blocks, and reconstruct the
  original file.

The encrypted file format is as follows, where || denotes byte concatenation:

    FILE := VERSION || BLOCK || BLOCK ...

    BLOCK := LENGTH || DATA

    LENGTH := varint-encoded length of the subsequent data. Varint comes from
    Google Protobuf, and encodes an integer into a variable number of bytes.
    Each byte uses the 7 lowest bits to encode the value. The highest bit set
    to 1 indicates the next byte is also part of the varint. The last byte will
    have this bit set to 0.

This file format is called the VARBLOCK format, in line with the varint format
used to denote the block sizes.
'''
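
# Illustrative only: a minimal varint encoder/decoder matching the format
# described above. The module's real helpers live in rsa.varblock and are
# not shown in this diff.
def _demo_write_varint(outfile, value):
    # Emit 7 bits per byte, least significant group first; the high bit
    # flags that another byte follows.
    while True:
        to_write = value & 0x7f
        value >>= 7
        if value:
            outfile.write(chr(to_write | 0x80))
        else:
            outfile.write(chr(to_write))
            return

def _demo_read_varint(infile):
    value, shift = 0, 0
    while True:
        char = ord(infile.read(1))
        value += (char & 0x7f) << shift
        shift += 7
        if not char & 0x80:
            return value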

from rsa import key, common, pkcs1, varblock
from rsa._compat import byte


def encrypt_bigfile(infile, outfile, pub_key):
    '''Encrypts a file, writing it to 'outfile' in VARBLOCK format.

    :param infile: file-like object to read the cleartext from
    :param outfile: file-like object to write the crypto in VARBLOCK format to
    :param pub_key: :py:class:`rsa.PublicKey` to encrypt with
    '''

    if not isinstance(pub_key, key.PublicKey):
        raise TypeError('Public key required, but got %r' % pub_key)

    key_bytes = common.bit_size(pub_key.n) // 8
    blocksize = key_bytes - 11  # keep space for PKCS#1 padding

    # Write the version number to the VARBLOCK file
    outfile.write(byte(varblock.VARBLOCK_VERSION))

    # Encrypt and write each block
    for block in varblock.yield_fixedblocks(infile, blocksize):
        crypto = pkcs1.encrypt(block, pub_key)

        varblock.write_varint(outfile, len(crypto))
        outfile.write(crypto)


def decrypt_bigfile(infile, outfile, priv_key):
    '''Decrypts an encrypted VARBLOCK file, writing it to 'outfile'.

    :param infile: file-like object to read the crypto in VARBLOCK format from
    :param outfile: file-like object to write the cleartext to
    :param priv_key: :py:class:`rsa.PrivateKey` to decrypt with
    '''

    if not isinstance(priv_key, key.PrivateKey):
        raise TypeError('Private key required, but got %r' % priv_key)

    for block in varblock.yield_varblocks(infile):
        cleartext = pkcs1.decrypt(block, priv_key)
        outfile.write(cleartext)


__all__ = ['encrypt_bigfile', 'decrypt_bigfile']
@ -1,379 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''Commandline scripts.

These scripts are called by the executables defined in setup.py.
'''

from __future__ import with_statement, print_function

import abc
import sys
from optparse import OptionParser

import rsa
import rsa.bigfile
import rsa.pkcs1

HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())


def keygen():
    '''Key generator.'''

    # Parse the CLI options
    parser = OptionParser(usage='usage: %prog [options] keysize',
                          description='Generates a new RSA keypair of "keysize" bits.')

    parser.add_option('--pubout', type='string',
                      help='Output filename for the public key. The public key is '
                           'not saved if this option is not present. You can use '
                           'pyrsa-priv2pub to create the public key file later.')

    parser.add_option('-o', '--out', type='string',
                      help='Output filename for the private key. The key is '
                           'written to stdout if this option is not present.')

    parser.add_option('--form',
                      help='key format of the private and public keys - default PEM',
                      choices=('PEM', 'DER'), default='PEM')

    (cli, cli_args) = parser.parse_args(sys.argv[1:])

    if len(cli_args) != 1:
        parser.print_help()
        raise SystemExit(1)

    try:
        keysize = int(cli_args[0])
    except ValueError:
        parser.print_help()
        print('Not a valid number: %s' % cli_args[0], file=sys.stderr)
        raise SystemExit(1)

    print('Generating %i-bit key' % keysize, file=sys.stderr)
    (pub_key, priv_key) = rsa.newkeys(keysize)

    # Save public key
    if cli.pubout:
        print('Writing public key to %s' % cli.pubout, file=sys.stderr)
        data = pub_key.save_pkcs1(format=cli.form)
        with open(cli.pubout, 'wb') as outfile:
            outfile.write(data)

    # Save private key
    data = priv_key.save_pkcs1(format=cli.form)

    if cli.out:
        print('Writing private key to %s' % cli.out, file=sys.stderr)
        with open(cli.out, 'wb') as outfile:
            outfile.write(data)
    else:
        print('Writing private key to stdout', file=sys.stderr)
        sys.stdout.write(data)


class CryptoOperation(object):
    '''CLI-callable that operates with input, output, and a key.'''

    __metaclass__ = abc.ABCMeta

    keyname = 'public'  # or 'private'
    usage = 'usage: %%prog [options] %(keyname)s_key'
    description = None
    operation = 'decrypt'
    operation_past = 'decrypted'
    operation_progressive = 'decrypting'
    input_help = 'Name of the file to %(operation)s. Reads from stdin if ' \
                 'not specified.'
    output_help = 'Name of the file to write the %(operation_past)s file ' \
                  'to. Written to stdout if this option is not present.'
    expected_cli_args = 1
    has_output = True

    key_class = rsa.PublicKey

    def __init__(self):
        self.usage = self.usage % self.__class__.__dict__
        self.input_help = self.input_help % self.__class__.__dict__
        self.output_help = self.output_help % self.__class__.__dict__

    @abc.abstractmethod
    def perform_operation(self, indata, key, cli_args=None):
        '''Performs the program's operation.

        Implement in a subclass.

        :returns: the data to write to the output.
        '''

    def __call__(self):
        '''Runs the program.'''

        (cli, cli_args) = self.parse_cli()

        key = self.read_key(cli_args[0], cli.keyform)

        indata = self.read_infile(cli.input)

        print(self.operation_progressive.title(), file=sys.stderr)
        outdata = self.perform_operation(indata, key, cli_args)

        if self.has_output:
            self.write_outfile(outdata, cli.output)

    def parse_cli(self):
        '''Parse the CLI options.

        :returns: (cli_opts, cli_args)
        '''

        parser = OptionParser(usage=self.usage, description=self.description)

        parser.add_option('-i', '--input', type='string', help=self.input_help)

        if self.has_output:
            parser.add_option('-o', '--output', type='string', help=self.output_help)

        parser.add_option('--keyform',
                          help='Key format of the %s key - default PEM' % self.keyname,
                          choices=('PEM', 'DER'), default='PEM')

        (cli, cli_args) = parser.parse_args(sys.argv[1:])

        if len(cli_args) != self.expected_cli_args:
            parser.print_help()
            raise SystemExit(1)

        return (cli, cli_args)

    def read_key(self, filename, keyform):
        '''Reads a public or private key.'''

        print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr)
        with open(filename, 'rb') as keyfile:
            keydata = keyfile.read()

        return self.key_class.load_pkcs1(keydata, keyform)

    def read_infile(self, inname):
        '''Read the input file.'''

        if inname:
            print('Reading input from %s' % inname, file=sys.stderr)
            with open(inname, 'rb') as infile:
                return infile.read()

        print('Reading input from stdin', file=sys.stderr)
        return sys.stdin.read()

    def write_outfile(self, outdata, outname):
        '''Write the output file.'''

        if outname:
            print('Writing output to %s' % outname, file=sys.stderr)
            with open(outname, 'wb') as outfile:
                outfile.write(outdata)
        else:
            print('Writing output to stdout', file=sys.stderr)
            sys.stdout.write(outdata)


class EncryptOperation(CryptoOperation):
    '''Encrypts a file.'''

    keyname = 'public'
    description = ('Encrypts a file. The file must be shorter than the key '
                   'length in order to be encrypted. For larger files, use the '
                   'pyrsa-encrypt-bigfile command.')
    operation = 'encrypt'
    operation_past = 'encrypted'
    operation_progressive = 'encrypting'

    def perform_operation(self, indata, pub_key, cli_args=None):
        '''Encrypts files.'''

        return rsa.encrypt(indata, pub_key)


class DecryptOperation(CryptoOperation):
    '''Decrypts a file.'''

    keyname = 'private'
    description = ('Decrypts a file. The original file must be shorter than '
                   'the key length in order to have been encrypted. For larger '
                   'files, use the pyrsa-decrypt-bigfile command.')
    operation = 'decrypt'
    operation_past = 'decrypted'
    operation_progressive = 'decrypting'
    key_class = rsa.PrivateKey

    def perform_operation(self, indata, priv_key, cli_args=None):
        '''Decrypts files.'''

        return rsa.decrypt(indata, priv_key)


class SignOperation(CryptoOperation):
    '''Signs a file.'''

    keyname = 'private'
    usage = 'usage: %%prog [options] private_key hash_method'
    description = ('Signs a file, outputs the signature. Choose the hash '
                   'method from %s' % ', '.join(HASH_METHODS))
    operation = 'sign'
    operation_past = 'signature'
    operation_progressive = 'Signing'
    key_class = rsa.PrivateKey
    expected_cli_args = 2

    output_help = ('Name of the file to write the signature to. Written '
                   'to stdout if this option is not present.')

    def perform_operation(self, indata, priv_key, cli_args):
        '''Signs files.'''

        hash_method = cli_args[1]
        if hash_method not in HASH_METHODS:
            raise SystemExit('Invalid hash method, choose one of %s' %
                             ', '.join(HASH_METHODS))

        return rsa.sign(indata, priv_key, hash_method)


class VerifyOperation(CryptoOperation):
    '''Verifies a signature.'''

    keyname = 'public'
    usage = 'usage: %%prog [options] public_key signature_file'
    description = ('Verifies a signature: exits with status 0 upon success, '
                   'prints an error message and exits with status 1 upon error.')
    operation = 'verify'
    operation_past = 'verified'
    operation_progressive = 'Verifying'
    key_class = rsa.PublicKey
    expected_cli_args = 2
    has_output = False

    def perform_operation(self, indata, pub_key, cli_args):
        '''Verifies files.'''

        signature_file = cli_args[1]

        with open(signature_file, 'rb') as sigfile:
            signature = sigfile.read()

        try:
            rsa.verify(indata, signature, pub_key)
        except rsa.VerificationError:
            raise SystemExit('Verification failed.')

        print('Verification OK', file=sys.stderr)


class BigfileOperation(CryptoOperation):
    '''CryptoOperation that doesn't read the entire file into memory.'''

    def __init__(self):
        CryptoOperation.__init__(self)

        self.file_objects = []

    def __del__(self):
        '''Closes any open file handles.'''

        for fobj in self.file_objects:
            fobj.close()

    def __call__(self):
        '''Runs the program.'''

        (cli, cli_args) = self.parse_cli()

        key = self.read_key(cli_args[0], cli.keyform)

        # Get the file handles
        infile = self.get_infile(cli.input)
        outfile = self.get_outfile(cli.output)

        # Call the operation
        print(self.operation_progressive.title(), file=sys.stderr)
        self.perform_operation(infile, outfile, key, cli_args)

    def get_infile(self, inname):
        '''Returns the input file object.'''

        if inname:
            print('Reading input from %s' % inname, file=sys.stderr)
            fobj = open(inname, 'rb')
            self.file_objects.append(fobj)
        else:
            print('Reading input from stdin', file=sys.stderr)
            fobj = sys.stdin

        return fobj

    def get_outfile(self, outname):
        '''Returns the output file object.'''

        if outname:
            print('Will write output to %s' % outname, file=sys.stderr)
            fobj = open(outname, 'wb')
            self.file_objects.append(fobj)
        else:
            print('Will write output to stdout', file=sys.stderr)
            fobj = sys.stdout

        return fobj


class EncryptBigfileOperation(BigfileOperation):
    '''Encrypts a file to VARBLOCK format.'''

    keyname = 'public'
    description = ('Encrypts a file to an encrypted VARBLOCK file. The file '
                   'can be larger than the key length, but the output file is only '
                   'compatible with Python-RSA.')
    operation = 'encrypt'
    operation_past = 'encrypted'
    operation_progressive = 'encrypting'

    def perform_operation(self, infile, outfile, pub_key, cli_args=None):
        '''Encrypts files to VARBLOCK.'''

        return rsa.bigfile.encrypt_bigfile(infile, outfile, pub_key)


class DecryptBigfileOperation(BigfileOperation):
    '''Decrypts a file in VARBLOCK format.'''

    keyname = 'private'
    description = ('Decrypts an encrypted VARBLOCK file that was encrypted '
                   'with pyrsa-encrypt-bigfile')
    operation = 'decrypt'
    operation_past = 'decrypted'
    operation_progressive = 'decrypting'
    key_class = rsa.PrivateKey

    def perform_operation(self, infile, outfile, priv_key, cli_args=None):
        '''Decrypts a VARBLOCK file.'''

        return rsa.bigfile.decrypt_bigfile(infile, outfile, priv_key)


encrypt = EncryptOperation()
decrypt = DecryptOperation()
sign = SignOperation()
verify = VerifyOperation()
encrypt_bigfile = EncryptBigfileOperation()
decrypt_bigfile = DecryptBigfileOperation()
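
# Illustrative only (hypothetical subclass, not part of the library): each
# console script above is just an instance of a CryptoOperation subclass, so
# a new command only needs the class attributes and perform_operation().
class _DemoSha1SignOperation(CryptoOperation):
    '''Hypothetical: signs input with SHA-1 only.'''

    keyname = 'private'
    description = 'Signs a file with SHA-1.'
    operation = 'sign'
    operation_past = 'signed'
    operation_progressive = 'signing'
    key_class = rsa.PrivateKey

    def perform_operation(self, indata, priv_key, cli_args=None):
        return rsa.sign(indata, priv_key, 'SHA-1')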
@ -1,185 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''Common functionality shared by several modules.'''


def bit_size(num):
    '''
    Number of bits needed to represent an integer, excluding any prefix
    0 bits.

    As per the definition from http://wiki.python.org/moin/BitManipulation and
    to match the behavior of the Python 3 API.

    Usage::

        >>> bit_size(1023)
        10
        >>> bit_size(1024)
        11
        >>> bit_size(1025)
        11

    :param num:
        Integer value. If num is 0, returns 0. Only the absolute value of the
        number is considered. Therefore, signed integers will be abs(num)
        before the number's bit length is determined.
    :returns:
        Returns the number of bits in the integer.
    '''
    if num == 0:
        return 0
    if num < 0:
        num = -num

    # Make sure this is an int and not a float.
    num & 1

    hex_num = "%x" % num
    return ((len(hex_num) - 1) * 4) + {
        '0': 0, '1': 1, '2': 2, '3': 2,
        '4': 3, '5': 3, '6': 3, '7': 3,
        '8': 4, '9': 4, 'a': 4, 'b': 4,
        'c': 4, 'd': 4, 'e': 4, 'f': 4,
    }[hex_num[0]]


def _bit_size(number):
    '''
    Returns the number of bits required to hold a specific long number.
    '''
    if number < 0:
        raise ValueError('Only nonnegative numbers possible: %s' % number)

    if number == 0:
        return 0

    # This works, even with very large numbers. When using math.log(number, 2),
    # you'll get rounding errors and it'll fail.
    bits = 0
    while number:
        bits += 1
        number >>= 1

    return bits


def byte_size(number):
    '''
    Returns the number of bytes required to hold a specific long number.

    The number of bytes is rounded up.

    Usage::

        >>> byte_size(1 << 1023)
        128
        >>> byte_size((1 << 1024) - 1)
        128
        >>> byte_size(1 << 1024)
        129

    :param number:
        An unsigned integer
    :returns:
        The number of bytes required to hold a specific long number.
    '''
    quanta, mod = divmod(bit_size(number), 8)
    if mod or number == 0:
        quanta += 1
    return quanta
    # return int(math.ceil(bit_size(number) / 8.0))


def extended_gcd(a, b):
    '''Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
    '''
    # r = gcd(a,b); i = multiplicative inverse of a mod b,
    # or j = multiplicative inverse of b mod a.
    # Negative return values for i or j are made positive mod b or a respectively.
    # The iterative version is faster and uses much less stack space.
    x = 0
    y = 1
    lx = 1
    ly = 0
    oa = a  # Remember original a/b to remove
    ob = b  # negative values from return results
    while b != 0:
        q = a // b
        (a, b) = (b, a % b)
        (x, lx) = ((lx - (q * x)), x)
        (y, ly) = ((ly - (q * y)), y)
    if lx < 0: lx += ob  # If negative, wrap modulo original b
    if ly < 0: ly += oa  # If negative, wrap modulo original a
    return (a, lx, ly)  # Return only positive values
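
# Illustrative only: extended_gcd() returns Bezout coefficients already
# wrapped to non-negative values, which is what inverse() below relies on.
def _demo_extended_gcd_bezout():
    (r, i, j) = extended_gcd(120, 23)
    assert r == 1
    assert (120 * i) % 23 == 1   # i = 14, the inverse of 120 mod 23
    assert (23 * j) % 120 == 1   # j = 47, the inverse of 23 mod 120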

def inverse(x, n):
    '''Returns x^-1 (mod n)

    >>> inverse(7, 4)
    3
    >>> (inverse(143, 4) * 143) % 4
    1
    '''

    (divider, inv, _) = extended_gcd(x, n)

    if divider != 1:
        raise ValueError("x (%d) and n (%d) are not relatively prime" % (x, n))

    return inv


def crt(a_values, modulo_values):
    '''Chinese Remainder Theorem.

    Calculates x such that x = a[i] (mod m[i]) for each i.

    :param a_values: the a-values of the above equation
    :param modulo_values: the m-values of the above equation
    :returns: x such that x = a[i] (mod m[i]) for each i

    >>> crt([2, 3], [3, 5])
    8

    >>> crt([2, 3, 2], [3, 5, 7])
    23

    >>> crt([2, 3, 0], [7, 11, 15])
    135
    '''

    m = 1
    x = 0

    for modulo in modulo_values:
        m *= modulo

    for (m_i, a_i) in zip(modulo_values, a_values):
        M_i = m // m_i
        inv = inverse(M_i, m_i)

        x = (x + a_i * M_i * inv) % m

    return x


if __name__ == '__main__':
    import doctest
    doctest.testmod()
@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''Core mathematical operations.

This is the actual core RSA implementation, which is only defined
mathematically on integers.
'''


from rsa._compat import is_integer


def assert_int(var, name):

    if is_integer(var):
        return

    raise TypeError('%s should be an integer, not %s' % (name, var.__class__))


def encrypt_int(message, ekey, n):
    '''Encrypts a message using encryption key 'ekey', working modulo n'''

    assert_int(message, 'message')
    assert_int(ekey, 'ekey')
    assert_int(n, 'n')

    if message < 0:
        raise ValueError('Only non-negative numbers are supported')

    if message > n:
        raise OverflowError("The message %i is too long for n=%i" % (message, n))

    return pow(message, ekey, n)


def decrypt_int(cyphertext, dkey, n):
    '''Decrypts a cypher text using the decryption key 'dkey', working
    modulo n'''

    assert_int(cyphertext, 'cyphertext')
    assert_int(dkey, 'dkey')
    assert_int(n, 'n')

    message = pow(cyphertext, dkey, n)
    return message
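
# Illustrative only: the two primitives are inverses whenever d is the
# multiplicative inverse of e modulo phi(n); textbook parameters p=61, q=53.
def _demo_core_roundtrip():
    n, e, d = 61 * 53, 17, 2753     # 17 * 2753 == 1 (mod 60 * 52)
    for message in (0, 1, 42, 3000):
        assert decrypt_int(encrypt_int(message, e, n), d, n) == message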
@ -1,612 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''RSA key generation code.
|
||||
|
||||
Create new keys with the newkeys() function. It will give you a PublicKey and a
|
||||
PrivateKey object.
|
||||
|
||||
Loading and saving keys requires the pyasn1 module. This module is imported as
|
||||
late as possible, such that other functionality will remain working in absence
|
||||
of pyasn1.
|
||||
|
||||
'''
|
||||
|
||||
import logging
|
||||
from rsa._compat import b, bytes_type
|
||||
|
||||
import rsa.prime
|
||||
import rsa.pem
|
||||
import rsa.common
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
class AbstractKey(object):
|
||||
'''Abstract superclass for private and public keys.'''
|
||||
|
||||
@classmethod
|
||||
def load_pkcs1(cls, keyfile, format='PEM'):
|
||||
r'''Loads a key in PKCS#1 DER or PEM format.
|
||||
|
||||
:param keyfile: contents of a DER- or PEM-encoded file that contains
|
||||
the key.
|
||||
:param format: the format of the file to load; 'PEM' or 'DER'
|
||||
|
||||
:return: a PublicKey or PrivateKey object, depending on the class
|
||||
|
||||
'''
|
||||
|
||||
methods = {
|
||||
'PEM': cls._load_pkcs1_pem,
|
||||
'DER': cls._load_pkcs1_der,
|
||||
}
|
||||
|
||||
if format not in methods:
|
||||
formats = ', '.join(sorted(methods.keys()))
|
||||
raise ValueError('Unsupported format: %r, try one of %s' % (format,
|
||||
formats))
|
||||
|
||||
method = methods[format]
|
||||
return method(keyfile)
|
||||
|
||||
def save_pkcs1(self, format='PEM'):
|
||||
'''Saves the key in PKCS#1 DER or PEM format.
|
||||
|
||||
:param format: the format to save; 'PEM' or 'DER'
|
||||
:returns: the DER- or PEM-encoded key.
|
||||
|
||||
'''
|
||||
|
||||
methods = {
|
||||
'PEM': self._save_pkcs1_pem,
|
||||
'DER': self._save_pkcs1_der,
|
||||
}
|
||||
|
||||
if format not in methods:
|
||||
formats = ', '.join(sorted(methods.keys()))
|
||||
raise ValueError('Unsupported format: %r, try one of %s' % (format,
|
||||
formats))
|
||||
|
||||
method = methods[format]
|
||||
return method()
|
||||
|
||||
class PublicKey(AbstractKey):
|
||||
'''Represents a public RSA key.
|
||||
|
||||
This key is also known as the 'encryption key'. It contains the 'n' and 'e'
|
||||
values.
|
||||
|
||||
Supports attributes as well as dictionary-like access. Attribute access is
|
||||
faster, though.
|
||||
|
||||
>>> PublicKey(5, 3)
|
||||
PublicKey(5, 3)
|
||||
|
||||
>>> key = PublicKey(5, 3)
|
||||
>>> key.n
|
||||
5
|
||||
>>> key['n']
|
||||
5
|
||||
>>> key.e
|
||||
3
|
||||
>>> key['e']
|
||||
3
|
||||
|
||||
'''
|
||||
|
||||
__slots__ = ('n', 'e')
|
||||
|
||||
def __init__(self, n, e):
|
||||
self.n = n
|
||||
self.e = e
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __repr__(self):
|
||||
return 'PublicKey(%i, %i)' % (self.n, self.e)
|
||||
|
||||
def __eq__(self, other):
|
||||
if other is None:
|
||||
return False
|
||||
|
||||
if not isinstance(other, PublicKey):
|
||||
return False
|
||||
|
||||
return self.n == other.n and self.e == other.e
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
@classmethod
|
||||
def _load_pkcs1_der(cls, keyfile):
|
||||
r'''Loads a key in PKCS#1 DER format.
|
||||
|
||||
@param keyfile: contents of a DER-encoded file that contains the public
|
||||
key.
|
||||
@return: a PublicKey object
|
||||
|
||||
First let's construct a DER encoded key:
|
||||
|
||||
>>> import base64
|
||||
>>> b64der = 'MAwCBQCNGmYtAgMBAAE='
|
||||
>>> der = base64.decodestring(b64der)
|
||||
|
||||
This loads the file:
|
||||
|
||||
>>> PublicKey._load_pkcs1_der(der)
|
||||
PublicKey(2367317549, 65537)
|
||||
|
||||
'''
|
||||
|
||||
from pyasn1.codec.der import decoder
|
||||
from rsa.asn1 import AsnPubKey
|
||||
|
||||
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
|
||||
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
|
||||
|
||||
def _save_pkcs1_der(self):
|
||||
'''Saves the public key in PKCS#1 DER format.
|
||||
|
||||
@returns: the DER-encoded public key.
|
||||
'''
|
||||
|
||||
from pyasn1.codec.der import encoder
|
||||
from rsa.asn1 import AsnPubKey
|
||||
|
||||
# Create the ASN object
|
||||
asn_key = AsnPubKey()
|
||||
asn_key.setComponentByName('modulus', self.n)
|
||||
asn_key.setComponentByName('publicExponent', self.e)
|
||||
|
||||
return encoder.encode(asn_key)
|
||||
|
||||
@classmethod
|
||||
def _load_pkcs1_pem(cls, keyfile):
|
||||
'''Loads a PKCS#1 PEM-encoded public key file.
|
||||
|
||||
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
|
||||
after the "-----END RSA PUBLIC KEY-----" lines is ignored.
|
||||
|
||||
@param keyfile: contents of a PEM-encoded file that contains the public
|
||||
key.
|
||||
@return: a PublicKey object
|
||||
'''
|
||||
|
||||
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
|
||||
return cls._load_pkcs1_der(der)
|
||||
|
||||
def _save_pkcs1_pem(self):
|
||||
'''Saves a PKCS#1 PEM-encoded public key file.
|
||||
|
||||
@return: contents of a PEM-encoded file that contains the public key.
|
||||
'''
|
||||
|
||||
der = self._save_pkcs1_der()
|
||||
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
|
||||
|
||||
@classmethod
|
||||
def load_pkcs1_openssl_pem(cls, keyfile):
|
||||
'''Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
|
||||
|
||||
These files can be recognised in that they start with BEGIN PUBLIC KEY
|
||||
rather than BEGIN RSA PUBLIC KEY.
|
||||
|
||||
The contents of the file before the "-----BEGIN PUBLIC KEY-----" and
|
||||
after the "-----END PUBLIC KEY-----" lines is ignored.
|
||||
|
||||
@param keyfile: contents of a PEM-encoded file that contains the public
|
||||
key, from OpenSSL.
|
||||
@return: a PublicKey object
|
||||
'''
|
||||
|
||||
der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY')
|
||||
return cls.load_pkcs1_openssl_der(der)
|
||||
|
||||
@classmethod
|
||||
def load_pkcs1_openssl_der(cls, keyfile):
|
||||
'''Loads a PKCS#1 DER-encoded public key file from OpenSSL.
|
||||
|
||||
@param keyfile: contents of a DER-encoded file that contains the public
|
||||
key, from OpenSSL.
|
||||
@return: a PublicKey object
|
||||
'''
|
||||
|
||||
from rsa.asn1 import OpenSSLPubKey
|
||||
from pyasn1.codec.der import decoder
|
||||
from pyasn1.type import univ
|
||||
|
||||
(keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())
|
||||
|
||||
if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'):
|
||||
raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")
|
||||
|
||||
return cls._load_pkcs1_der(keyinfo['key'][1:])
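
# Usage sketch (assumes pyasn1 is installed, since the PEM/DER codecs
# import it lazily): save a PublicKey to PKCS#1 PEM and load it back.
key = PublicKey(2367317549, 65537)
pem = key.save_pkcs1(format='PEM')      # '-----BEGIN RSA PUBLIC KEY-----' block
assert PublicKey.load_pkcs1(pem, format='PEM') == key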
|
||||
|
||||
|
||||
|
||||
|
||||
class PrivateKey(AbstractKey):
|
||||
'''Represents a private RSA key.
|
||||
|
||||
This key is also known as the 'decryption key'. It contains the 'n', 'e',
|
||||
'd', 'p', 'q' and other values.
|
||||
|
||||
Supports attributes as well as dictionary-like access. Attribute access is
|
||||
faster, though.
|
||||
|
||||
>>> PrivateKey(3247, 65537, 833, 191, 17)
|
||||
PrivateKey(3247, 65537, 833, 191, 17)
|
||||
|
||||
exp1, exp2 and coef don't have to be given; if omitted, they will be calculated:
|
||||
|
||||
>>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
|
||||
>>> pk.exp1
|
||||
55063
|
||||
>>> pk.exp2
|
||||
10095
|
||||
>>> pk.coef
|
||||
50797
|
||||
|
||||
If you give exp1, exp2 or coef, they will be used as-is:
|
||||
|
||||
>>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
|
||||
>>> pk.exp1
|
||||
6
|
||||
>>> pk.exp2
|
||||
7
|
||||
>>> pk.coef
|
||||
8
|
||||
|
||||
'''
|
||||
|
||||
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
|
||||
|
||||
def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
|
||||
self.n = n
|
||||
self.e = e
|
||||
self.d = d
|
||||
self.p = p
|
||||
self.q = q
|
||||
|
||||
# Calculate the other values if they aren't supplied
|
||||
if exp1 is None:
|
||||
self.exp1 = int(d % (p - 1))
|
||||
else:
|
||||
self.exp1 = exp1
|
||||
|
||||
if exp2 is None:
|
||||
self.exp2 = int(d % (q - 1))
|
||||
else:
|
||||
self.exp2 = exp2
|
||||
|
||||
if coef is None:
|
||||
self.coef = rsa.common.inverse(q, p)
|
||||
else:
|
||||
self.coef = coef
|
||||
|
||||
def __getitem__(self, key):
|
||||
return getattr(self, key)
|
||||
|
||||
def __repr__(self):
|
||||
return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
|
||||
|
||||
def __eq__(self, other):
|
||||
if other is None:
|
||||
return False
|
||||
|
||||
if not isinstance(other, PrivateKey):
|
||||
return False
|
||||
|
||||
return (self.n == other.n and
|
||||
self.e == other.e and
|
||||
self.d == other.d and
|
||||
self.p == other.p and
|
||||
self.q == other.q and
|
||||
self.exp1 == other.exp1 and
|
||||
self.exp2 == other.exp2 and
|
||||
self.coef == other.coef)
|
||||
|
||||
def __ne__(self, other):
|
||||
return not (self == other)
|
||||
|
||||
@classmethod
|
||||
def _load_pkcs1_der(cls, keyfile):
|
||||
r'''Loads a key in PKCS#1 DER format.
|
||||
|
||||
@param keyfile: contents of a DER-encoded file that contains the private
|
||||
key.
|
||||
@return: a PrivateKey object
|
||||
|
||||
First let's construct a DER encoded key:
|
||||
|
||||
>>> import base64
|
||||
>>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
|
||||
>>> der = base64.decodestring(b64der)
|
||||
|
||||
This loads the file:
|
||||
|
||||
>>> PrivateKey._load_pkcs1_der(der)
|
||||
PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
|
||||
|
||||
'''
|
||||
|
||||
from pyasn1.codec.der import decoder
|
||||
(priv, _) = decoder.decode(keyfile)
|
||||
|
||||
# ASN.1 contents of DER encoded private key:
|
||||
#
|
||||
# RSAPrivateKey ::= SEQUENCE {
|
||||
# version Version,
|
||||
# modulus INTEGER, -- n
|
||||
# publicExponent INTEGER, -- e
|
||||
# privateExponent INTEGER, -- d
|
||||
# prime1 INTEGER, -- p
|
||||
# prime2 INTEGER, -- q
|
||||
# exponent1 INTEGER, -- d mod (p-1)
|
||||
# exponent2 INTEGER, -- d mod (q-1)
|
||||
# coefficient INTEGER, -- (inverse of q) mod p
|
||||
# otherPrimeInfos OtherPrimeInfos OPTIONAL
|
||||
# }
|
||||
|
||||
if priv[0] != 0:
|
||||
raise ValueError('Unable to read this file, version %s != 0' % priv[0])
|
||||
|
||||
as_ints = tuple(int(x) for x in priv[1:9])
|
||||
return cls(*as_ints)
|
||||
|
||||
def _save_pkcs1_der(self):
|
||||
'''Saves the private key in PKCS#1 DER format.
|
||||
|
||||
@returns: the DER-encoded private key.
|
||||
'''
|
||||
|
||||
from pyasn1.type import univ, namedtype
|
||||
from pyasn1.codec.der import encoder
|
||||
|
||||
class AsnPrivKey(univ.Sequence):
|
||||
componentType = namedtype.NamedTypes(
|
||||
namedtype.NamedType('version', univ.Integer()),
|
||||
namedtype.NamedType('modulus', univ.Integer()),
|
||||
namedtype.NamedType('publicExponent', univ.Integer()),
|
||||
namedtype.NamedType('privateExponent', univ.Integer()),
|
||||
namedtype.NamedType('prime1', univ.Integer()),
|
||||
namedtype.NamedType('prime2', univ.Integer()),
|
||||
namedtype.NamedType('exponent1', univ.Integer()),
|
||||
namedtype.NamedType('exponent2', univ.Integer()),
|
||||
namedtype.NamedType('coefficient', univ.Integer()),
|
||||
)
|
||||
|
||||
# Create the ASN object
|
||||
asn_key = AsnPrivKey()
|
||||
asn_key.setComponentByName('version', 0)
|
||||
asn_key.setComponentByName('modulus', self.n)
|
||||
asn_key.setComponentByName('publicExponent', self.e)
|
||||
asn_key.setComponentByName('privateExponent', self.d)
|
||||
asn_key.setComponentByName('prime1', self.p)
|
||||
asn_key.setComponentByName('prime2', self.q)
|
||||
asn_key.setComponentByName('exponent1', self.exp1)
|
||||
asn_key.setComponentByName('exponent2', self.exp2)
|
||||
asn_key.setComponentByName('coefficient', self.coef)
|
||||
|
||||
return encoder.encode(asn_key)
|
||||
|
||||
@classmethod
|
||||
def _load_pkcs1_pem(cls, keyfile):
|
||||
'''Loads a PKCS#1 PEM-encoded private key file.
|
||||
|
||||
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
|
||||
after the "-----END RSA PRIVATE KEY-----" lines is ignored.
|
||||
|
||||
@param keyfile: contents of a PEM-encoded file that contains the private
|
||||
key.
|
||||
@return: a PrivateKey object
|
||||
'''
|
||||
|
||||
der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY'))
|
||||
return cls._load_pkcs1_der(der)
|
||||
|
||||
def _save_pkcs1_pem(self):
|
||||
'''Saves a PKCS#1 PEM-encoded private key file.
|
||||
|
||||
@return: contents of a PEM-encoded file that contains the private key.
|
||||
'''
|
||||
|
||||
der = self._save_pkcs1_der()
|
||||
return rsa.pem.save_pem(der, b('RSA PRIVATE KEY'))
|
||||
|
||||
def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
|
||||
'''Returns a tuple of two different primes of nbits bits each.
|
||||
|
||||
The resulting p * q has exactly 2 * nbits bits, and the returned p and q
|
||||
will not be equal.
|
||||
|
||||
:param nbits: the number of bits in each of p and q.
|
||||
:param getprime_func: the getprime function, defaults to
|
||||
:py:func:`rsa.prime.getprime`.
|
||||
|
||||
*Introduced in Python-RSA 3.1*
|
||||
|
||||
:param accurate: whether to enable accurate mode or not.
|
||||
:returns: (p, q), where p > q
|
||||
|
||||
>>> (p, q) = find_p_q(128)
|
||||
>>> from rsa import common
|
||||
>>> common.bit_size(p * q)
|
||||
256
|
||||
|
||||
When not in accurate mode, the number of bits can be slightly less
|
||||
|
||||
>>> (p, q) = find_p_q(128, accurate=False)
|
||||
>>> from rsa import common
|
||||
>>> common.bit_size(p * q) <= 256
|
||||
True
|
||||
>>> common.bit_size(p * q) > 240
|
||||
True
|
||||
|
||||
'''
|
||||
|
||||
total_bits = nbits * 2
|
||||
|
||||
# Make sure that p and q aren't too close or the factoring programs can
|
||||
# factor n.
|
||||
shift = nbits // 16
|
||||
pbits = nbits + shift
|
||||
qbits = nbits - shift
|
||||
|
||||
# Choose the two initial primes
|
||||
log.debug('find_p_q(%i): Finding p', nbits)
|
||||
p = getprime_func(pbits)
|
||||
log.debug('find_p_q(%i): Finding q', nbits)
|
||||
q = getprime_func(qbits)
|
||||
|
||||
def is_acceptable(p, q):
|
||||
'''Returns True iff p and q are acceptable:
|
||||
|
||||
- p and q differ
|
||||
- (p * q) has the right nr of bits (when accurate=True)
|
||||
'''
|
||||
|
||||
if p == q:
|
||||
return False
|
||||
|
||||
if not accurate:
|
||||
return True
|
||||
|
||||
# Make sure we have just the right amount of bits
|
||||
found_size = rsa.common.bit_size(p * q)
|
||||
return total_bits == found_size
|
||||
|
||||
# Keep choosing other primes until they match our requirements.
|
||||
change_p = False
|
||||
while not is_acceptable(p, q):
|
||||
# Change p on one iteration and q on the other
|
||||
if change_p:
|
||||
p = getprime_func(pbits)
|
||||
else:
|
||||
q = getprime_func(qbits)
|
||||
|
||||
change_p = not change_p
|
||||
|
||||
# We want p > q as described on
|
||||
# http://www.di-mgt.com.au/rsa_alg.html#crt
|
||||
return (max(p, q), min(p, q))
|
||||
|
||||
def calculate_keys(p, q, nbits):
|
||||
'''Calculates an encryption and a decryption key given p and q, and
|
||||
returns them as a tuple (e, d)
|
||||
|
||||
'''
|
||||
|
||||
phi_n = (p - 1) * (q - 1)
|
||||
|
||||
# A very common choice for e is 65537
|
||||
e = 65537
|
||||
|
||||
try:
|
||||
d = rsa.common.inverse(e, phi_n)
|
||||
except ValueError:
|
||||
raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
|
||||
(e, phi_n))
|
||||
|
||||
if (e * d) % phi_n != 1:
|
||||
raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
|
||||
"phi_n (%d)" % (e, d, phi_n))
|
||||
|
||||
return (e, d)
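
# Worked example with the textbook primes p=61, q=53: phi_n = 60 * 52 = 3120,
# and 65537 = 21 * 3120 + 17, so d is the inverse of 17 modulo 3120, which is
# 2753 (17 * 2753 = 46801 = 15 * 3120 + 1). Note that the nbits argument is
# accepted but not used by the computation.
assert calculate_keys(61, 53, None) == (65537, 2753)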
|
||||
|
||||
def gen_keys(nbits, getprime_func, accurate=True):
|
||||
'''Generate RSA keys of nbits bits. Returns (p, q, e, d).
|
||||
|
||||
Note: this can take a long time, depending on the key size.
|
||||
|
||||
:param nbits: the total number of bits in ``p`` and ``q``. Both ``p`` and
|
||||
``q`` will use ``nbits/2`` bits.
|
||||
:param getprime_func: either :py:func:`rsa.prime.getprime` or a function
|
||||
with similar signature.
|
||||
'''
|
||||
|
||||
(p, q) = find_p_q(nbits // 2, getprime_func, accurate)
|
||||
(e, d) = calculate_keys(p, q, nbits // 2)
|
||||
|
||||
return (p, q, e, d)
|
||||
|
||||
def newkeys(nbits, accurate=True, poolsize=1):
|
||||
'''Generates public and private keys, and returns them as (pub, priv).
|
||||
|
||||
The public key is also known as the 'encryption key', and is a
|
||||
:py:class:`rsa.PublicKey` object. The private key is also known as the
|
||||
'decryption key' and is a :py:class:`rsa.PrivateKey` object.
|
||||
|
||||
:param nbits: the number of bits required to store ``n = p*q``.
|
||||
:param accurate: when True, ``n`` will have exactly the number of bits you
|
||||
asked for. However, this makes key generation much slower. When False,
|
||||
``n`` may have slightly fewer bits.
|
||||
:param poolsize: the number of processes to use to generate the prime
|
||||
numbers. If set to a number > 1, a parallel algorithm will be used.
|
||||
This requires Python 2.6 or newer.
|
||||
|
||||
:returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
|
||||
|
||||
The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires
|
||||
Python 2.6 or newer.
|
||||
|
||||
'''
|
||||
|
||||
if nbits < 16:
|
||||
raise ValueError('Key too small')
|
||||
|
||||
if poolsize < 1:
|
||||
raise ValueError('Pool size (%i) should be >= 1' % poolsize)
|
||||
|
||||
# Determine which getprime function to use
|
||||
if poolsize > 1:
|
||||
from rsa import parallel
|
||||
import functools
|
||||
|
||||
getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
|
||||
else: getprime_func = rsa.prime.getprime
|
||||
|
||||
# Generate the key components
|
||||
(p, q, e, d) = gen_keys(nbits, getprime_func)
|
||||
|
||||
# Create the key objects
|
||||
n = p * q
|
||||
|
||||
return (
|
||||
PublicKey(n, e),
|
||||
PrivateKey(n, e, d, p, q)
|
||||
)
|
||||
|
||||
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
|
||||
try:
|
||||
for count in range(100):
|
||||
(failures, tests) = doctest.testmod()
|
||||
if failures:
|
||||
break
|
||||
|
||||
if (count and count % 10 == 0) or count == 1:
|
||||
print('%i times' % count)
|
||||
except KeyboardInterrupt:
|
||||
print('Aborted')
|
||||
else:
|
||||
print('Doctests done')
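
# Typical usage, as a minimal sketch: with accurate=True (the default),
# n has exactly the requested number of bits.
(pub, priv) = newkeys(256)
assert pub.n == priv.n and pub.e == priv.e
assert rsa.common.bit_size(pub.n) == 256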
|
|
@@ -1,94 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Functions for parallel computation on multiple cores.
|
||||
|
||||
Introduced in Python-RSA 3.1.
|
||||
|
||||
.. note::
|
||||
|
||||
Requires Python 2.6 or newer.
|
||||
|
||||
'''
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import multiprocessing as mp
|
||||
|
||||
import rsa.prime
|
||||
import rsa.randnum
|
||||
|
||||
def _find_prime(nbits, pipe):
|
||||
while True:
|
||||
integer = rsa.randnum.read_random_int(nbits)
|
||||
|
||||
# Make sure it's odd
|
||||
integer |= 1
|
||||
|
||||
# Test for primeness
|
||||
if rsa.prime.is_prime(integer):
|
||||
pipe.send(integer)
|
||||
return
|
||||
|
||||
def getprime(nbits, poolsize):
|
||||
'''Returns a prime number that can be stored in 'nbits' bits.
|
||||
|
||||
Works in multiple processes at the same time.
|
||||
|
||||
>>> p = getprime(128, 3)
|
||||
>>> rsa.prime.is_prime(p-1)
|
||||
False
|
||||
>>> rsa.prime.is_prime(p)
|
||||
True
|
||||
>>> rsa.prime.is_prime(p+1)
|
||||
False
|
||||
|
||||
>>> from rsa import common
|
||||
>>> common.bit_size(p) == 128
|
||||
True
|
||||
|
||||
'''
|
||||
|
||||
(pipe_recv, pipe_send) = mp.Pipe(duplex=False)
|
||||
|
||||
# Create processes
|
||||
procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send))
|
||||
for _ in range(poolsize)]
|
||||
[p.start() for p in procs]
|
||||
|
||||
result = pipe_recv.recv()
|
||||
|
||||
[p.terminate() for p in procs]
|
||||
|
||||
return result
|
||||
|
||||
__all__ = ['getprime']
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Running doctests 100x or until failure')
|
||||
import doctest
|
||||
|
||||
for count in range(100):
|
||||
(failures, tests) = doctest.testmod()
|
||||
if failures:
|
||||
break
|
||||
|
||||
if count and count % 10 == 0:
|
||||
print('%i times' % count)
|
||||
|
||||
print('Doctests done')
|
||||
|
|
@@ -1,120 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Functions that load and write PEM-encoded files.'''
|
||||
|
||||
import base64
|
||||
from rsa._compat import b, is_bytes
|
||||
|
||||
def _markers(pem_marker):
|
||||
'''
|
||||
Returns the start and end PEM markers
|
||||
'''
|
||||
|
||||
if is_bytes(pem_marker):
|
||||
pem_marker = pem_marker.decode('utf-8')
|
||||
|
||||
return (b('-----BEGIN %s-----' % pem_marker),
|
||||
b('-----END %s-----' % pem_marker))
|
||||
|
||||
def load_pem(contents, pem_marker):
|
||||
'''Loads a PEM file.
|
||||
|
||||
@param contents: the contents of the file to interpret
|
||||
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
|
||||
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
|
||||
'-----END RSA PRIVATE KEY-----' markers.
|
||||
|
||||
@return the base64-decoded content between the start and end markers.
|
||||
|
||||
@raise ValueError: when the content is invalid, for example when the start
|
||||
marker cannot be found.
|
||||
|
||||
'''
|
||||
|
||||
(pem_start, pem_end) = _markers(pem_marker)
|
||||
|
||||
pem_lines = []
|
||||
in_pem_part = False
|
||||
|
||||
for line in contents.splitlines():
|
||||
line = line.strip()
|
||||
|
||||
# Skip empty lines
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Handle start marker
|
||||
if line == pem_start:
|
||||
if in_pem_part:
|
||||
raise ValueError('Seen start marker "%s" twice' % pem_start)
|
||||
|
||||
in_pem_part = True
|
||||
continue
|
||||
|
||||
# Skip stuff before first marker
|
||||
if not in_pem_part:
|
||||
continue
|
||||
|
||||
# Handle end marker
|
||||
if in_pem_part and line == pem_end:
|
||||
in_pem_part = False
|
||||
break
|
||||
|
||||
# Load fields
|
||||
if b(':') in line:
|
||||
continue
|
||||
|
||||
pem_lines.append(line)
|
||||
|
||||
# Do some sanity checks
|
||||
if not pem_lines:
|
||||
raise ValueError('No PEM start marker "%s" found' % pem_start)
|
||||
|
||||
if in_pem_part:
|
||||
raise ValueError('No PEM end marker "%s" found' % pem_end)
|
||||
|
||||
# Base64-decode the contents
|
||||
pem = b('').join(pem_lines)
|
||||
return base64.decodestring(pem)
|
||||
|
||||
|
||||
def save_pem(contents, pem_marker):
|
||||
'''Saves a PEM file.
|
||||
|
||||
@param contents: the contents to encode in PEM format
|
||||
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
|
||||
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
|
||||
'-----END RSA PRIVATE KEY-----' markers.
|
||||
|
||||
@return the base64-encoded content between the start and end markers.
|
||||
|
||||
'''
|
||||
|
||||
(pem_start, pem_end) = _markers(pem_marker)
|
||||
|
||||
b64 = base64.encodestring(contents).replace(b('\n'), b(''))
|
||||
pem_lines = [pem_start]
|
||||
|
||||
for block_start in range(0, len(b64), 64):
|
||||
block = b64[block_start:block_start + 64]
|
||||
pem_lines.append(block)
|
||||
|
||||
pem_lines.append(pem_end)
|
||||
pem_lines.append(b(''))
|
||||
|
||||
return b('\n').join(pem_lines)
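
# Round-trip sketch: save_pem and load_pem are inverses for any payload and
# marker (the 'TEST' marker here is arbitrary, chosen for illustration).
payload = b('\x00\x01binary payload')
pem = save_pem(payload, 'TEST')
assert pem.splitlines()[0] == b('-----BEGIN TEST-----')
assert load_pem(pem, 'TEST') == payload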
|
||||
|
|
@@ -1,391 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Functions for PKCS#1 version 1.5 encryption and signing
|
||||
|
||||
This module implements certain functionality from PKCS#1 version 1.5. For a
|
||||
very clear example, read http://www.di-mgt.com.au/rsa_alg.html#pkcs1schemes
|
||||
|
||||
At least 8 bytes of random padding are used when encrypting a message. This makes
|
||||
these methods much more secure than the ones in the ``rsa`` module.
|
||||
|
||||
WARNING: this module leaks information when decryption or verification fails.
|
||||
The exceptions that are raised contain the Python traceback information, which
|
||||
can be used to deduce where in the process the failure occurred. DO NOT PASS
|
||||
SUCH INFORMATION to your users.
|
||||
'''
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
from rsa._compat import b
|
||||
from rsa import common, transform, core, varblock
|
||||
|
||||
# ASN.1 codes that describe the hash algorithm used.
|
||||
HASH_ASN1 = {
|
||||
'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
|
||||
'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
|
||||
'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
|
||||
'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
|
||||
'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'),
|
||||
}
|
||||
|
||||
HASH_METHODS = {
|
||||
'MD5': hashlib.md5,
|
||||
'SHA-1': hashlib.sha1,
|
||||
'SHA-256': hashlib.sha256,
|
||||
'SHA-384': hashlib.sha384,
|
||||
'SHA-512': hashlib.sha512,
|
||||
}
|
||||
|
||||
class CryptoError(Exception):
|
||||
'''Base class for all exceptions in this module.'''
|
||||
|
||||
class DecryptionError(CryptoError):
|
||||
'''Raised when decryption fails.'''
|
||||
|
||||
class VerificationError(CryptoError):
|
||||
'''Raised when verification fails.'''
|
||||
|
||||
def _pad_for_encryption(message, target_length):
|
||||
r'''Pads the message for encryption, returning the padded message.
|
||||
|
||||
:return: 00 02 RANDOM_DATA 00 MESSAGE
|
||||
|
||||
>>> block = _pad_for_encryption('hello', 16)
|
||||
>>> len(block)
|
||||
16
|
||||
>>> block[0:2]
|
||||
'\x00\x02'
|
||||
>>> block[-6:]
|
||||
'\x00hello'
|
||||
|
||||
'''
|
||||
|
||||
max_msglength = target_length - 11
|
||||
msglength = len(message)
|
||||
|
||||
if msglength > max_msglength:
|
||||
raise OverflowError('%i bytes needed for message, but there is only'
|
||||
' space for %i' % (msglength, max_msglength))
|
||||
|
||||
# Get random padding
|
||||
padding = b('')
|
||||
padding_length = target_length - msglength - 3
|
||||
|
||||
# We remove 0-bytes, so we'll end up with less padding than we've asked for,
|
||||
# so keep adding data until we're at the correct length.
|
||||
while len(padding) < padding_length:
|
||||
needed_bytes = padding_length - len(padding)
|
||||
|
||||
# Always read a few bytes more than we need, and trim off the rest
|
||||
# after removing the 0-bytes. This increases the chance of getting
|
||||
# enough bytes, especially when needed_bytes is small
|
||||
new_padding = os.urandom(needed_bytes + 5)
|
||||
new_padding = new_padding.replace(b('\x00'), b(''))
|
||||
padding = padding + new_padding[:needed_bytes]
|
||||
|
||||
assert len(padding) == padding_length
|
||||
|
||||
return b('').join([b('\x00\x02'),
|
||||
padding,
|
||||
b('\x00'),
|
||||
message])
|
||||
|
||||
|
||||
def _pad_for_signing(message, target_length):
|
||||
r'''Pads the message for signing, returning the padded message.
|
||||
|
||||
The padding is always a repetition of FF bytes.
|
||||
|
||||
:return: 00 01 PADDING 00 MESSAGE
|
||||
|
||||
>>> block = _pad_for_signing('hello', 16)
|
||||
>>> len(block)
|
||||
16
|
||||
>>> block[0:2]
|
||||
'\x00\x01'
|
||||
>>> block[-6:]
|
||||
'\x00hello'
|
||||
>>> block[2:-6]
|
||||
'\xff\xff\xff\xff\xff\xff\xff\xff'
|
||||
|
||||
'''
|
||||
|
||||
max_msglength = target_length - 11
|
||||
msglength = len(message)
|
||||
|
||||
if msglength > max_msglength:
|
||||
raise OverflowError('%i bytes needed for message, but there is only'
|
||||
' space for %i' % (msglength, max_msglength))
|
||||
|
||||
padding_length = target_length - msglength - 3
|
||||
|
||||
return b('').join([b('\x00\x01'),
|
||||
padding_length * b('\xff'),
|
||||
b('\x00'),
|
||||
message])
|
||||
|
||||
|
||||
def encrypt(message, pub_key):
|
||||
'''Encrypts the given message using PKCS#1 v1.5
|
||||
|
||||
:param message: the message to encrypt. Must be a byte string no longer than
|
||||
``k-11`` bytes, where ``k`` is the number of bytes needed to encode
|
||||
the ``n`` component of the public key.
|
||||
:param pub_key: the :py:class:`rsa.PublicKey` to encrypt with.
|
||||
:raise OverflowError: when the message is too large to fit in the padded
|
||||
block.
|
||||
|
||||
>>> from rsa import key, common
|
||||
>>> (pub_key, priv_key) = key.newkeys(256)
|
||||
>>> message = 'hello'
|
||||
>>> crypto = encrypt(message, pub_key)
|
||||
|
||||
The crypto text should be just as long as the public key 'n' component:
|
||||
|
||||
>>> len(crypto) == common.byte_size(pub_key.n)
|
||||
True
|
||||
|
||||
'''
|
||||
|
||||
keylength = common.byte_size(pub_key.n)
|
||||
padded = _pad_for_encryption(message, keylength)
|
||||
|
||||
payload = transform.bytes2int(padded)
|
||||
encrypted = core.encrypt_int(payload, pub_key.e, pub_key.n)
|
||||
block = transform.int2bytes(encrypted, keylength)
|
||||
|
||||
return block
|
||||
|
||||
def decrypt(crypto, priv_key):
|
||||
r'''Decrypts the given message using PKCS#1 v1.5
|
||||
|
||||
The decryption is considered 'failed' when the resulting cleartext doesn't
|
||||
start with the bytes 00 02, or when the 00 byte between the padding and
|
||||
the message cannot be found.
|
||||
|
||||
:param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
|
||||
:param priv_key: the :py:class:`rsa.PrivateKey` to decrypt with.
|
||||
:raise DecryptionError: when the decryption fails. No details are given as
|
||||
to why the code thinks the decryption fails, as this would leak
|
||||
information about the private key.
|
||||
|
||||
|
||||
>>> import rsa
|
||||
>>> (pub_key, priv_key) = rsa.newkeys(256)
|
||||
|
||||
It works with strings:
|
||||
|
||||
>>> crypto = encrypt('hello', pub_key)
|
||||
>>> decrypt(crypto, priv_key)
|
||||
'hello'
|
||||
|
||||
And with binary data:
|
||||
|
||||
>>> crypto = encrypt('\x00\x00\x00\x00\x01', pub_key)
|
||||
>>> decrypt(crypto, priv_key)
|
||||
'\x00\x00\x00\x00\x01'
|
||||
|
||||
Altering the encrypted information will *likely* cause a
|
||||
:py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use
|
||||
:py:func:`rsa.sign`.
|
||||
|
||||
|
||||
.. warning::
|
||||
|
||||
Never display the stack trace of a
|
||||
:py:class:`rsa.pkcs1.DecryptionError` exception. It shows where in the
|
||||
code the exception occurred, and thus leaks information about the key.
|
||||
It's only a tiny bit of information, but every bit makes cracking the
|
||||
keys easier.
|
||||
|
||||
>>> crypto = encrypt('hello', pub_key)
|
||||
>>> crypto = crypto[0:5] + 'X' + crypto[6:] # change a byte
|
||||
>>> decrypt(crypto, priv_key)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
DecryptionError: Decryption failed
|
||||
|
||||
'''
|
||||
|
||||
blocksize = common.byte_size(priv_key.n)
|
||||
encrypted = transform.bytes2int(crypto)
|
||||
decrypted = core.decrypt_int(encrypted, priv_key.d, priv_key.n)
|
||||
cleartext = transform.int2bytes(decrypted, blocksize)
|
||||
|
||||
# If we can't find the cleartext marker, decryption failed.
|
||||
if cleartext[0:2] != b('\x00\x02'):
|
||||
raise DecryptionError('Decryption failed')
|
||||
|
||||
# Find the 00 separator between the padding and the message
|
||||
try:
|
||||
sep_idx = cleartext.index(b('\x00'), 2)
|
||||
except ValueError:
|
||||
raise DecryptionError('Decryption failed')
|
||||
|
||||
return cleartext[sep_idx+1:]
|
||||
|
||||
def sign(message, priv_key, hash):
|
||||
'''Signs the message with the private key.
|
||||
|
||||
Hashes the message, then signs the hash with the given key. This is known
|
||||
as a "detached signature", because the message itself isn't altered.
|
||||
|
||||
:param message: the message to sign. Can be an 8-bit string or a file-like
|
||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
||||
file-like object.
|
||||
:param priv_key: the :py:class:`rsa.PrivateKey` to sign with
|
||||
:param hash: the hash method used on the message. Use 'MD5', 'SHA-1',
|
||||
'SHA-256', 'SHA-384' or 'SHA-512'.
|
||||
:return: a message signature block.
|
||||
:raise OverflowError: if the private key is too small to contain the
|
||||
requested hash.
|
||||
|
||||
'''
|
||||
|
||||
# Get the ASN1 code for this hash method
|
||||
if hash not in HASH_ASN1:
|
||||
raise ValueError('Invalid hash method: %s' % hash)
|
||||
asn1code = HASH_ASN1[hash]
|
||||
|
||||
# Calculate the hash
|
||||
hash = _hash(message, hash)
|
||||
|
||||
# Encrypt the hash with the private key
|
||||
cleartext = asn1code + hash
|
||||
keylength = common.byte_size(priv_key.n)
|
||||
padded = _pad_for_signing(cleartext, keylength)
|
||||
|
||||
payload = transform.bytes2int(padded)
|
||||
encrypted = core.encrypt_int(payload, priv_key.d, priv_key.n)
|
||||
block = transform.int2bytes(encrypted, keylength)
|
||||
|
||||
return block
|
||||
|
||||
def verify(message, signature, pub_key):
|
||||
'''Verifies that the signature matches the message.
|
||||
|
||||
The hash method is detected automatically from the signature.
|
||||
|
||||
:param message: the signed message. Can be an 8-bit string or a file-like
|
||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
||||
file-like object.
|
||||
:param signature: the signature block, as created with :py:func:`rsa.sign`.
|
||||
:param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message.
|
||||
:raise VerificationError: when the signature doesn't match the message.
|
||||
|
||||
.. warning::
|
||||
|
||||
Never display the stack trace of a
|
||||
:py:class:`rsa.pkcs1.VerificationError` exception. It shows where in
|
||||
the code the exception occurred, and thus leaks information about the
|
||||
key. It's only a tiny bit of information, but every bit makes cracking
|
||||
the keys easier.
|
||||
|
||||
'''
|
||||
|
||||
blocksize = common.byte_size(pub_key.n)
|
||||
encrypted = transform.bytes2int(signature)
|
||||
decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
|
||||
clearsig = transform.int2bytes(decrypted, blocksize)
|
||||
|
||||
# If we can't find the signature marker, verification failed.
|
||||
if clearsig[0:2] != b('\x00\x01'):
|
||||
raise VerificationError('Verification failed')
|
||||
|
||||
# Find the 00 separator between the padding and the payload
|
||||
try:
|
||||
sep_idx = clearsig.index(b('\x00'), 2)
|
||||
except ValueError:
|
||||
raise VerificationError('Verification failed')
|
||||
|
||||
# Get the hash and the hash method
|
||||
(method_name, signature_hash) = _find_method_hash(clearsig[sep_idx+1:])
|
||||
message_hash = _hash(message, method_name)
|
||||
|
||||
# Compare the real hash to the hash in the signature
|
||||
if message_hash != signature_hash:
|
||||
raise VerificationError('Verification failed')
|
||||
|
||||
return True
|
||||
|
||||
def _hash(message, method_name):
|
||||
'''Returns the message digest.
|
||||
|
||||
:param message: the signed message. Can be an 8-bit string or a file-like
|
||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
||||
file-like object.
|
||||
:param method_name: the hash method, must be a key of
|
||||
:py:const:`HASH_METHODS`.
|
||||
|
||||
'''
|
||||
|
||||
if method_name not in HASH_METHODS:
|
||||
raise ValueError('Invalid hash method: %s' % method_name)
|
||||
|
||||
method = HASH_METHODS[method_name]
|
||||
hasher = method()
|
||||
|
||||
if hasattr(message, 'read') and hasattr(message.read, '__call__'):
|
||||
# read as 1K blocks
|
||||
for block in varblock.yield_fixedblocks(message, 1024):
|
||||
hasher.update(block)
|
||||
else:
|
||||
# hash the message object itself.
|
||||
hasher.update(message)
|
||||
|
||||
return hasher.digest()
|
||||
|
||||
|
||||
def _find_method_hash(method_hash):
|
||||
'''Finds the hash method and the hash itself.
|
||||
|
||||
:param method_hash: ASN1 code for the hash method concatenated with the
|
||||
hash itself.
|
||||
|
||||
:return: tuple (method, hash) where ``method`` is the used hash method, and
|
||||
``hash`` is the hash itself.
|
||||
|
||||
:raise VerificationFailed: when the hash method cannot be found
|
||||
|
||||
'''
|
||||
|
||||
for (hashname, asn1code) in HASH_ASN1.items():
|
||||
if not method_hash.startswith(asn1code):
|
||||
continue
|
||||
|
||||
return (hashname, method_hash[len(asn1code):])
|
||||
|
||||
raise VerificationError('Verification failed')
|
||||
|
||||
|
||||
__all__ = ['encrypt', 'decrypt', 'sign', 'verify',
|
||||
'DecryptionError', 'VerificationError', 'CryptoError']
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Running doctests 1000x or until failure')
|
||||
import doctest
|
||||
|
||||
for count in range(1000):
|
||||
(failures, tests) = doctest.testmod()
|
||||
if failures:
|
||||
break
|
||||
|
||||
if count and count % 100 == 0:
|
||||
print('%i times' % count)
|
||||
|
||||
print('Doctests done')
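
# End-to-end sketch with a deliberately small 512-bit key (real deployments
# should use far larger keys); byte-string semantics follow this Python 2 era
# codebase.
import rsa
(pub_key, priv_key) = rsa.newkeys(512)
crypto = encrypt('hello', pub_key)
assert decrypt(crypto, priv_key) == 'hello'
signature = sign('hello', priv_key, 'SHA-256')
assert verify('hello', signature, pub_key)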
|
|
@@ -1,166 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Numerical functions related to primes.
|
||||
|
||||
Implementation based on the book Algorithm Design by Michael T. Goodrich and
|
||||
Roberto Tamassia, 2002.
|
||||
'''
|
||||
|
||||
__all__ = ['getprime', 'are_relatively_prime']
|
||||
|
||||
import rsa.randnum
|
||||
|
||||
def gcd(p, q):
|
||||
'''Returns the greatest common divisor of p and q
|
||||
|
||||
>>> gcd(48, 180)
|
||||
12
|
||||
'''
|
||||
|
||||
while q != 0:
|
||||
if p < q: (p,q) = (q,p)
|
||||
(p,q) = (q, p % q)
|
||||
return p
|
||||
|
||||
|
||||
def jacobi(a, b):
|
||||
'''Calculates the value of the Jacobi symbol (a/b) where both a and b are
|
||||
positive integers, and b is odd
|
||||
|
||||
:returns: -1, 0 or 1
|
||||
'''
|
||||
|
||||
assert a > 0
|
||||
assert b > 0
|
||||
|
||||
if a == 0: return 0
|
||||
result = 1
|
||||
while a > 1:
|
||||
if a & 1:
|
||||
if ((a-1)*(b-1) >> 2) & 1:
|
||||
result = -result
|
||||
a, b = b % a, a
|
||||
else:
|
||||
if (((b * b) - 1) >> 3) & 1:
|
||||
result = -result
|
||||
a >>= 1
|
||||
if a == 0: return 0
|
||||
return result
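
# Cross-check sketch: for an odd prime b the Jacobi symbol equals the
# Legendre symbol, and Euler's criterion gives (a/b) = a ** ((b - 1) // 2)
# mod b, i.e. 1 for quadratic residues and b - 1 (= -1 mod b) otherwise.
for a in range(1, 7):
    euler = pow(a, (7 - 1) // 2, 7)
    assert jacobi(a, 7) == (1 if euler == 1 else -1)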
|
||||
|
||||
def jacobi_witness(x, n):
|
||||
'''Returns False if n is an Euler pseudo-prime with base x, and
|
||||
True otherwise.
|
||||
'''
|
||||
|
||||
j = jacobi(x, n) % n
|
||||
|
||||
f = pow(x, n >> 1, n)
|
||||
|
||||
if j == f: return False
|
||||
return True
|
||||
|
||||
def randomized_primality_testing(n, k):
|
||||
'''Calculates whether n is composite (which is always correct) or
|
||||
prime (which is incorrect with error probability 2**-k)
|
||||
|
||||
Returns False if the number is composite, and True if it's
|
||||
probably prime.
|
||||
'''
|
||||
|
||||
# At least 50% of Jacobi witnesses report the compositeness of a composite number
|
||||
|
||||
# The implemented algorithm using the Jacobi witness function has error
|
||||
# probability q <= 0.5, according to Goodrich et al.
|
||||
#
|
||||
# q = 0.5
|
||||
# t = int(math.ceil(k / log(1 / q, 2)))
|
||||
# So t = k / log(2, 2) = k / 1 = k
|
||||
# this means we can use range(k) rather than range(t)
|
||||
|
||||
for _ in range(k):
|
||||
x = rsa.randnum.randint(n-1)
|
||||
if jacobi_witness(x, n): return False
|
||||
|
||||
return True
|
||||
|
||||
def is_prime(number):
|
||||
'''Returns True if the number is prime, and False otherwise.
|
||||
|
||||
>>> is_prime(42)
|
||||
False
|
||||
>>> is_prime(41)
|
||||
True
|
||||
'''
|
||||
|
||||
return randomized_primality_testing(number, 6)
|
||||
|
||||
def getprime(nbits):
|
||||
'''Returns a prime number that can be stored in 'nbits' bits.
|
||||
|
||||
>>> p = getprime(128)
|
||||
>>> is_prime(p-1)
|
||||
False
|
||||
>>> is_prime(p)
|
||||
True
|
||||
>>> is_prime(p+1)
|
||||
False
|
||||
|
||||
>>> from rsa import common
|
||||
>>> common.bit_size(p) == 128
|
||||
True
|
||||
|
||||
'''
|
||||
|
||||
while True:
|
||||
integer = rsa.randnum.read_random_int(nbits)
|
||||
|
||||
# Make sure it's odd
|
||||
integer |= 1
|
||||
|
||||
# Test for primeness
|
||||
if is_prime(integer):
|
||||
return integer
|
||||
|
||||
# Retry if not prime
|
||||
|
||||
|
||||
def are_relatively_prime(a, b):
|
||||
'''Returns True if a and b are relatively prime, and False if they
|
||||
are not.
|
||||
|
||||
>>> are_relatively_prime(2, 3)
|
||||
True
|
||||
>>> are_relatively_prime(2, 4)
|
||||
False
|
||||
'''
|
||||
|
||||
d = gcd(a, b)
|
||||
return (d == 1)
|
||||
|
||||
if __name__ == '__main__':
|
||||
print('Running doctests 1000x or until failure')
|
||||
import doctest
|
||||
|
||||
for count in range(1000):
|
||||
(failures, tests) = doctest.testmod()
|
||||
if failures:
|
||||
break
|
||||
|
||||
if count and count % 100 == 0:
|
||||
print('%i times' % count)
|
||||
|
||||
print('Doctests done')
|
|
@@ -1,85 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Functions for generating random numbers.'''
|
||||
|
||||
# Source inspired by code by Yesudeep Mangalapilly <yesudeep@gmail.com>
|
||||
|
||||
import os
|
||||
|
||||
from rsa import common, transform
|
||||
from rsa._compat import byte
|
||||
|
||||
def read_random_bits(nbits):
|
||||
'''Reads 'nbits' random bits.
|
||||
|
||||
If nbits isn't a whole number of bytes, an extra byte will be appended with
|
||||
only the lower bits set.
|
||||
'''
|
||||
|
||||
nbytes, rbits = divmod(nbits, 8)
|
||||
|
||||
# Get the random bytes
|
||||
randomdata = os.urandom(nbytes)
|
||||
|
||||
# Add the remaining random bits
|
||||
if rbits > 0:
|
||||
randomvalue = ord(os.urandom(1))
|
||||
randomvalue >>= (8 - rbits)
|
||||
randomdata = byte(randomvalue) + randomdata
|
||||
|
||||
return randomdata
|
||||
|
||||
|
||||
def read_random_int(nbits):
|
||||
'''Reads a random integer of exactly nbits bits (the top bit is forced on).
|
||||
'''
|
||||
|
||||
randomdata = read_random_bits(nbits)
|
||||
value = transform.bytes2int(randomdata)
|
||||
|
||||
# Ensure that the number is large enough to just fill out the required
|
||||
# number of bits.
|
||||
value |= 1 << (nbits - 1)
|
||||
|
||||
return value
|
||||
|
||||
def randint(maxvalue):
|
||||
'''Returns a random integer x with 1 <= x <= maxvalue
|
||||
|
||||
May take a very long time in specific situations. If maxvalue needs N bits
|
||||
to store, the closer maxvalue is to (2 ** N) - 1, the faster this function
|
||||
is.
|
||||
'''
|
||||
|
||||
bit_size = common.bit_size(maxvalue)
|
||||
|
||||
tries = 0
|
||||
while True:
|
||||
value = read_random_int(bit_size)
|
||||
if value <= maxvalue:
|
||||
break
|
||||
|
||||
if tries and tries % 10 == 0:
|
||||
# After a lot of tries to get the right number of bits but still
|
||||
# smaller than maxvalue, decrease the number of bits by 1. That'll
|
||||
# dramatically increase the chances to get a large enough number.
|
||||
bit_size -= 1
|
||||
tries += 1
|
||||
|
||||
return value
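
# Usage sketch: read_random_int always sets the top bit, so the result has
# exactly nbits bits; randint stays within its inclusive bounds.
value = read_random_int(128)
assert common.bit_size(value) == 128
x = randint(1000)
assert 1 <= x <= 1000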
|
||||
|
||||
|
|
@@ -1,220 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Data transformation functions.
|
||||
|
||||
From bytes to a number, number to bytes, etc.
|
||||
'''
|
||||
|
||||
from __future__ import absolute_import
|
||||
|
||||
try:
|
||||
# We'll use psyco if available on 32-bit architectures to speed up code.
|
||||
# Using psyco (if available) cuts down the execution time on Python 2.5
|
||||
# at least by half.
|
||||
import psyco
|
||||
psyco.full()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
import binascii
|
||||
from struct import pack
|
||||
from rsa import common
|
||||
from rsa._compat import is_integer, b, byte, get_word_alignment, ZERO_BYTE, EMPTY_BYTE
|
||||
|
||||
|
||||
def bytes2int(raw_bytes):
|
||||
r'''Converts a list of bytes or an 8-bit string to an integer.
|
||||
|
||||
When using unicode strings, encode it to some encoding like UTF8 first.
|
||||
|
||||
>>> (((128 * 256) + 64) * 256) + 15
|
||||
8405007
|
||||
>>> bytes2int('\x80@\x0f')
|
||||
8405007
|
||||
|
||||
'''
|
||||
|
||||
return int(binascii.hexlify(raw_bytes), 16)
|
||||
|
||||
|
||||
def _int2bytes(number, block_size=None):
|
||||
r'''Converts a number to a string of bytes.
|
||||
|
||||
Usage::
|
||||
|
||||
>>> _int2bytes(123456789)
|
||||
'\x07[\xcd\x15'
|
||||
>>> bytes2int(_int2bytes(123456789))
|
||||
123456789
|
||||
|
||||
>>> _int2bytes(123456789, 6)
|
||||
'\x00\x00\x07[\xcd\x15'
|
||||
>>> bytes2int(_int2bytes(123456789, 128))
|
||||
123456789
|
||||
|
||||
>>> _int2bytes(123456789, 3)
|
||||
Traceback (most recent call last):
|
||||
...
|
||||
OverflowError: Needed 4 bytes for number, but block size is 3
|
||||
|
||||
@param number: the number to convert
|
||||
@param block_size: the number of bytes to output. If the number encoded to
|
||||
bytes is less than this, the block will be zero-padded. When not given,
|
||||
the returned block is not padded.
|
||||
|
||||
@throws OverflowError when block_size is given and the number takes up more
|
||||
bytes than fit into the block.
|
||||
'''
|
||||
# Type checking
|
||||
if not is_integer(number):
|
||||
raise TypeError("You must pass an integer for 'number', not %s" %
|
||||
number.__class__)
|
||||
|
||||
if number < 0:
|
||||
raise ValueError('Negative numbers cannot be used: %i' % number)
|
||||
|
||||
# Do some bounds checking
|
||||
if number == 0:
|
||||
needed_bytes = 1
|
||||
raw_bytes = [ZERO_BYTE]
|
||||
else:
|
||||
needed_bytes = common.byte_size(number)
|
||||
raw_bytes = []
|
||||
|
||||
# You cannot compare None > 0 in Python 3x. It will fail with a TypeError.
|
||||
if block_size and block_size > 0:
|
||||
if needed_bytes > block_size:
|
||||
raise OverflowError('Needed %i bytes for number, but block size '
|
||||
'is %i' % (needed_bytes, block_size))
|
||||
|
||||
# Convert the number to bytes.
|
||||
while number > 0:
|
||||
raw_bytes.insert(0, byte(number & 0xFF))
|
||||
number >>= 8
|
||||
|
||||
# Pad with zeroes to fill the block
|
||||
if block_size and block_size > 0:
|
||||
padding = (block_size - needed_bytes) * ZERO_BYTE
|
||||
else:
|
||||
padding = EMPTY_BYTE
|
||||
|
||||
return padding + EMPTY_BYTE.join(raw_bytes)
|
||||
|
||||
|
||||
def bytes_leading(raw_bytes, needle=ZERO_BYTE):
|
||||
'''
|
||||
Finds the number of prefixed byte occurrences in the haystack.
|
||||
|
||||
Useful when you want to deal with padding.
|
||||
|
||||
:param raw_bytes:
|
||||
Raw bytes.
|
||||
:param needle:
|
||||
The byte to count. Default \000.
|
||||
:returns:
|
||||
The number of leading needle bytes.
|
||||
'''
|
||||
leading = 0
|
||||
# Indexing keeps compatibility between Python 2.x and Python 3.x
|
||||
_byte = needle[0]
|
||||
for x in raw_bytes:
|
||||
if x == _byte:
|
||||
leading += 1
|
||||
else:
|
||||
break
|
||||
return leading
|
||||
|
||||
|
||||
def int2bytes(number, fill_size=None, chunk_size=None, overflow=False):
|
||||
'''
|
||||
Convert an unsigned integer to bytes (base-256 representation):
|
||||
|
||||
Does not preserve leading zeros if you don't specify a chunk size or
|
||||
fill size.
|
||||
|
||||
.. NOTE:
|
||||
You must not specify both fill_size and chunk_size. Only one
|
||||
of them is allowed.
|
||||
|
||||
:param number:
|
||||
Integer value
|
||||
:param fill_size:
|
||||
If the optional fill size is given the length of the resulting
|
||||
byte string is expected to be the fill size and will be padded
|
||||
with prefix zero bytes to satisfy that length.
|
||||
:param chunk_size:
|
||||
If optional chunk size is given and greater than zero, pad the front of
|
||||
the byte string with binary zeros so that the length is a multiple of
|
||||
``chunk_size``.
|
||||
:param overflow:
|
||||
``False`` (default). If this is ``True``, no ``OverflowError``
|
||||
will be raised when the fill_size is shorter than the length
|
||||
of the generated byte sequence. Instead the byte sequence will
|
||||
be returned as is.
|
||||
:returns:
|
||||
Raw bytes (base-256 representation).
|
||||
:raises:
|
||||
``OverflowError`` when fill_size is given and the number takes up more
|
||||
bytes than fit into the block. This requires the ``overflow``
|
||||
argument to this function to be set to ``False``; otherwise, no
|
||||
error will be raised.
|
||||
'''
|
||||
if number < 0:
|
||||
raise ValueError("Number must be an unsigned integer: %d" % number)
|
||||
|
||||
if fill_size and chunk_size:
|
||||
raise ValueError("You can either fill or pad chunks, but not both")
|
||||
|
||||
# Ensure the input is an integer; the bitwise op raises TypeError otherwise.
|
||||
number & 1
|
||||
|
||||
raw_bytes = b('')
|
||||
|
||||
# Pack the integer one machine word at a time into bytes.
|
||||
num = number
|
||||
word_bits, _, max_uint, pack_type = get_word_alignment(num)
|
||||
pack_format = ">%s" % pack_type
|
||||
while num > 0:
|
||||
raw_bytes = pack(pack_format, num & max_uint) + raw_bytes
|
||||
num >>= word_bits
|
||||
# Obtain the index of the first non-zero byte.
|
||||
zero_leading = bytes_leading(raw_bytes)
|
||||
if number == 0:
|
||||
raw_bytes = ZERO_BYTE
|
||||
# De-padding.
|
||||
raw_bytes = raw_bytes[zero_leading:]
|
||||
|
||||
length = len(raw_bytes)
|
||||
if fill_size and fill_size > 0:
|
||||
if not overflow and length > fill_size:
|
||||
raise OverflowError(
|
||||
"Need %d bytes for number, but fill size is %d" %
|
||||
(length, fill_size)
|
||||
)
|
||||
raw_bytes = raw_bytes.rjust(fill_size, ZERO_BYTE)
|
||||
elif chunk_size and chunk_size > 0:
|
||||
remainder = length % chunk_size
|
||||
if remainder:
|
||||
padding_size = chunk_size - remainder
|
||||
raw_bytes = raw_bytes.rjust(length + padding_size, ZERO_BYTE)
|
||||
return raw_bytes
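
# Behaviour sketch, reusing the value from the bytes2int doctest above:
# 8405007 is 0x80400f, i.e. the three bytes '\x80@\x0f'.
assert int2bytes(8405007) == b('\x80@\x0f')
assert int2bytes(8405007, fill_size=6) == b('\x00\x00\x00\x80@\x0f')
assert len(int2bytes(8405007, chunk_size=4)) == 4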
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
import doctest
|
||||
doctest.testmod()
|
||||
|
|
@@ -1,81 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''Utility functions.'''
|
||||
|
||||
from __future__ import with_statement, print_function
|
||||
|
||||
import sys
|
||||
from optparse import OptionParser
|
||||
|
||||
import rsa.key
|
||||
|
||||
def private_to_public():
|
||||
'''Reads a private key and outputs the corresponding public key.'''
|
||||
|
||||
# Parse the CLI options
|
||||
parser = OptionParser(usage='usage: %prog [options]',
|
||||
description='Reads a private key and outputs the '
|
||||
'corresponding public key. Both private and public keys use '
|
||||
'the format described in PKCS#1 v1.5')
|
||||
|
||||
parser.add_option('-i', '--input', dest='infilename', type='string',
|
||||
help='Input filename. Reads from stdin if not specified')
|
||||
parser.add_option('-o', '--output', dest='outfilename', type='string',
|
||||
help='Output filename. Writes to stdout if not specified')
|
||||
|
||||
parser.add_option('--inform', dest='inform',
|
||||
help='key format of input - default PEM',
|
||||
choices=('PEM', 'DER'), default='PEM')
|
||||
|
||||
parser.add_option('--outform', dest='outform',
|
||||
help='key format of output - default PEM',
|
||||
choices=('PEM', 'DER'), default='PEM')
|
||||
|
||||
(cli, cli_args) = parser.parse_args(sys.argv)
|
||||
|
||||
# Read the input data
|
||||
if cli.infilename:
|
||||
print('Reading private key from %s in %s format' % \
|
||||
(cli.infilename, cli.inform), file=sys.stderr)
|
||||
with open(cli.infilename, 'rb') as infile:
|
||||
in_data = infile.read()
|
||||
else:
|
||||
print('Reading private key from stdin in %s format' % cli.inform,
|
||||
file=sys.stderr)
|
||||
in_data = sys.stdin.read().encode('ascii')
|
||||
|
||||
assert type(in_data) == bytes, type(in_data)
|
||||
|
||||
|
||||
# Take the public fields and create a public key
|
||||
priv_key = rsa.key.PrivateKey.load_pkcs1(in_data, cli.inform)
|
||||
pub_key = rsa.key.PublicKey(priv_key.n, priv_key.e)
|
||||
|
||||
# Save to the output file
|
||||
out_data = pub_key.save_pkcs1(cli.outform)
|
||||
|
||||
if cli.outfilename:
|
||||
print('Writing public key to %s in %s format' % \
|
||||
(cli.outfilename, cli.outform), file=sys.stderr)
|
||||
with open(cli.outfilename, 'wb') as outfile:
|
||||
outfile.write(out_data)
|
||||
else:
|
||||
print('Writing public key to stdout in %s format' % cli.outform,
|
||||
file=sys.stderr)
|
||||
sys.stdout.write(out_data.decode('ascii'))
|
||||
|
||||
|
|
@@ -1,155 +0,0 @@
|
|||
# -*- coding: utf-8 -*-
|
||||
#
|
||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
'''VARBLOCK file support
|
||||
|
||||
The VARBLOCK file format is as follows, where || denotes byte concatenation:
|
||||
|
||||
FILE := VERSION || BLOCK || BLOCK ...
|
||||
|
||||
BLOCK := LENGTH || DATA
|
||||
|
||||
LENGTH := varint-encoded length of the subsequent data. Varint comes from
|
||||
Google Protobuf, and encodes an integer into a variable number of bytes.
|
||||
Each byte uses the 7 lowest bits to encode the value. The highest bit set
|
||||
to 1 indicates the next byte is also part of the varint. The last byte will
|
||||
have this bit set to 0.
|
||||
|
||||
This file format is called the VARBLOCK format, in line with the varint format
|
||||
used to denote the block sizes.
|
||||
|
||||
'''
|
||||
|
||||
from rsa._compat import byte, b
|
||||
|
||||
|
||||
ZERO_BYTE = b('\x00')
|
||||
VARBLOCK_VERSION = 1
|
||||
|
||||
def read_varint(infile):
|
||||
'''Reads a varint from the file.
|
||||
|
||||
When the first byte to be read indicates EOF, (0, 0) is returned. When an
|
||||
EOF occurs when at least one byte has been read, an EOFError exception is
|
||||
raised.
|
||||
|
||||
@param infile: the file-like object to read from. It should have a read()
|
||||
method.
|
||||
@returns (varint, length), the read varint and the number of read bytes.
|
||||
'''
|
||||
|
||||
varint = 0
|
||||
read_bytes = 0
|
||||
|
||||
while True:
|
||||
char = infile.read(1)
|
||||
if len(char) == 0:
|
||||
if read_bytes == 0:
|
||||
return (0, 0)
|
||||
raise EOFError('EOF while reading varint, value is %i so far' %
|
||||
varint)
|
||||
|
||||
byte = ord(char)
|
||||
varint += (byte & 0x7F) << (7 * read_bytes)
|
||||
|
||||
read_bytes += 1
|
||||
|
||||
if not byte & 0x80:
|
||||
return (varint, read_bytes)
|
||||
|
||||
|
||||
def write_varint(outfile, value):
|
||||
'''Writes a varint to a file.
|
||||
|
||||
@param outfile: the file-like object to write to. It should have a write()
|
||||
method.
|
||||
@returns the number of written bytes.
|
||||
'''
|
||||
|
||||
# there is a big difference between 'write the value 0' (this case) and
|
||||
# 'there is nothing left to write' (the false-case of the while loop)
|
||||
|
||||
if value == 0:
|
||||
outfile.write(ZERO_BYTE)
|
||||
return 1
|
||||
|
||||
written_bytes = 0
|
||||
while value > 0:
|
||||
to_write = value & 0x7f
|
||||
value = value >> 7
|
||||
|
||||
if value > 0:
|
||||
to_write |= 0x80
|
||||
|
||||
outfile.write(byte(to_write))
|
||||
written_bytes += 1
|
||||
|
||||
return written_bytes
|
||||
|
||||
|
||||
def yield_varblocks(infile):
|
||||
'''Generator, yields each block in the input file.
|
||||
|
||||
@param infile: file to read, is expected to have the VARBLOCK format as
|
||||
described in the module's docstring.
|
||||
@yields the contents of each block.
|
||||
'''
|
||||
|
||||
# Check the version number
|
||||
first_char = infile.read(1)
|
||||
if len(first_char) == 0:
|
||||
raise EOFError('Unable to read VARBLOCK version number')
|
||||
|
||||
version = ord(first_char)
|
||||
if version != VARBLOCK_VERSION:
|
||||
raise ValueError('VARBLOCK version %i not supported' % version)
|
||||
|
||||
while True:
|
||||
(block_size, read_bytes) = read_varint(infile)
|
||||
|
||||
# EOF at block boundary, that's fine.
|
||||
if read_bytes == 0 and block_size == 0:
|
||||
break
|
||||
|
||||
block = infile.read(block_size)
|
||||
|
||||
read_size = len(block)
|
||||
if read_size != block_size:
|
||||
raise EOFError('Block size is %i, but could read only %i bytes' %
|
||||
(block_size, read_size))
|
||||
|
||||
yield block
|
||||
|
||||
|
||||
def yield_fixedblocks(infile, blocksize):
|
||||
'''Generator, yields each block of ``blocksize`` bytes in the input file.
|
||||
|
||||
:param infile: file to read and separate in blocks.
|
||||
:returns: a generator that yields the contents of each block
|
||||
'''
|
||||
|
||||
while True:
|
||||
block = infile.read(blocksize)
|
||||
|
||||
read_bytes = len(block)
|
||||
if read_bytes == 0:
|
||||
break
|
||||
|
||||
yield block
|
||||
|
||||
if read_bytes < blocksize:
|
||||
break
|
||||
|
|
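# A minimal round-trip sketch of the varint scheme documented above,
# assuming the read_varint/write_varint functions from this (now removed)
# module are importable, with io.BytesIO standing in for the file object:
#
#     import io
#
#     buf = io.BytesIO()
#     write_varint(buf, 300)               # 300 -> b'\xac\x02', low 7 bits first
#     assert buf.getvalue() == b'\xac\x02'
#     buf.seek(0)
#     assert read_varint(buf) == (300, 2)  # (value, number of bytes read)
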
@@ -16,7 +16,7 @@

"""The Tornado web server and tools."""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

# version is a human-readable version number.

@@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.3"
version_info = (4, 3, 0, 0)
version = "4.5.1"
version_info = (4, 5, 1, 0)

@@ -17,78 +17,69 @@

"""Data used by the tornado.locale module."""

from __future__ import absolute_import, division, print_function, with_statement

# NOTE: This file is supposed to contain unicode strings, which is
# exactly what you'd get with e.g. u"Español" in most python versions.
# However, Python 3.2 doesn't support the u"" syntax, so we use a u()
# function instead. tornado.util.u cannot be used because it doesn't
# support non-ascii characters on python 2.
# When we drop support for Python 3.2, we can remove the parens
# and make these plain unicode strings.
from tornado.escape import to_unicode as u
from __future__ import absolute_import, division, print_function

LOCALE_NAMES = {
    "af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
    "am_ET": {"name_en": u("Amharic"), "name": u("አማርኛ")},
    "ar_AR": {"name_en": u("Arabic"), "name": u("العربية")},
    "bg_BG": {"name_en": u("Bulgarian"), "name": u("Български")},
    "bn_IN": {"name_en": u("Bengali"), "name": u("বাংলা")},
    "bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
    "ca_ES": {"name_en": u("Catalan"), "name": u("Català")},
    "cs_CZ": {"name_en": u("Czech"), "name": u("Čeština")},
    "cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
    "da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
    "de_DE": {"name_en": u("German"), "name": u("Deutsch")},
    "el_GR": {"name_en": u("Greek"), "name": u("Ελληνικά")},
    "en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
    "en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
    "es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Español (España)")},
    "es_LA": {"name_en": u("Spanish"), "name": u("Español")},
    "et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
    "eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
    "fa_IR": {"name_en": u("Persian"), "name": u("فارسی")},
    "fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
    "fr_CA": {"name_en": u("French (Canada)"), "name": u("Français (Canada)")},
    "fr_FR": {"name_en": u("French"), "name": u("Français")},
    "ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
    "gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
    "he_IL": {"name_en": u("Hebrew"), "name": u("עברית")},
    "hi_IN": {"name_en": u("Hindi"), "name": u("हिन्दी")},
    "hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
    "hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
    "id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
    "is_IS": {"name_en": u("Icelandic"), "name": u("Íslenska")},
    "it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
    "ja_JP": {"name_en": u("Japanese"), "name": u("日本語")},
    "ko_KR": {"name_en": u("Korean"), "name": u("한국어")},
    "lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvių")},
    "lv_LV": {"name_en": u("Latvian"), "name": u("Latviešu")},
    "mk_MK": {"name_en": u("Macedonian"), "name": u("Македонски")},
    "ml_IN": {"name_en": u("Malayalam"), "name": u("മലയാളം")},
    "ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
    "nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokmål)")},
    "nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
    "nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
    "pa_IN": {"name_en": u("Punjabi"), "name": u("ਪੰਜਾਬੀ")},
    "pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
    "pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Português (Brasil)")},
    "pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Português (Portugal)")},
    "ro_RO": {"name_en": u("Romanian"), "name": u("Română")},
    "ru_RU": {"name_en": u("Russian"), "name": u("Русский")},
    "sk_SK": {"name_en": u("Slovak"), "name": u("Slovenčina")},
    "sl_SI": {"name_en": u("Slovenian"), "name": u("Slovenščina")},
    "sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
    "sr_RS": {"name_en": u("Serbian"), "name": u("Српски")},
    "sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
    "sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
    "ta_IN": {"name_en": u("Tamil"), "name": u("தமிழ்")},
    "te_IN": {"name_en": u("Telugu"), "name": u("తెలుగు")},
    "th_TH": {"name_en": u("Thai"), "name": u("ภาษาไทย")},
    "tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
    "tr_TR": {"name_en": u("Turkish"), "name": u("Türkçe")},
    "uk_UA": {"name_en": u("Ukraini "), "name": u("Українська")},
    "vi_VN": {"name_en": u("Vietnamese"), "name": u("Tiếng Việt")},
    "zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("中文(简体)")},
    "zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("中文(繁體)")},
    "af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
    "am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
    "ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
    "bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
    "bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
    "bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
    "ca_ES": {"name_en": u"Catalan", "name": u"Català"},
    "cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
    "cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
    "da_DK": {"name_en": u"Danish", "name": u"Dansk"},
    "de_DE": {"name_en": u"German", "name": u"Deutsch"},
    "el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
    "en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
    "en_US": {"name_en": u"English (US)", "name": u"English (US)"},
    "es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
    "es_LA": {"name_en": u"Spanish", "name": u"Español"},
    "et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
    "eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
    "fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
    "fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
    "fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
    "fr_FR": {"name_en": u"French", "name": u"Français"},
    "ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
    "gl_ES": {"name_en": u"Galician", "name": u"Galego"},
    "he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
    "hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
    "hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
    "hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
    "id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
    "is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
    "it_IT": {"name_en": u"Italian", "name": u"Italiano"},
    "ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
    "ko_KR": {"name_en": u"Korean", "name": u"한국어"},
    "lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
    "lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
    "mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
    "ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
    "ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
    "nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
    "nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
    "nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
    "pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
    "pl_PL": {"name_en": u"Polish", "name": u"Polski"},
    "pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
    "pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
    "ro_RO": {"name_en": u"Romanian", "name": u"Română"},
    "ru_RU": {"name_en": u"Russian", "name": u"Русский"},
    "sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
    "sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
    "sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
    "sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
    "sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
    "sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
    "ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
    "te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
    "th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
    "tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
    "tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
    "uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
    "vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
    "zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
    "zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
}

@@ -65,7 +65,7 @@ Example usage for Google OAuth:
errors are more consistently reported through the ``Future`` interfaces.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import base64
import binascii

@@ -82,22 +82,15 @@ from tornado import escape
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import u, unicode_type, ArgReplacer
from tornado.util import unicode_type, ArgReplacer, PY3

try:
    import urlparse  # py2
except ImportError:
    import urllib.parse as urlparse  # py3

try:
    import urllib.parse as urllib_parse  # py3
except ImportError:
    import urllib as urllib_parse  # py2

try:
    long  # py2
except NameError:
    long = int  # py3
if PY3:
    import urllib.parse as urlparse
    import urllib.parse as urllib_parse
    long = int
else:
    import urlparse
    import urllib as urllib_parse


class AuthError(Exception):

@@ -188,7 +181,7 @@ class OpenIdMixin(object):
        """
        # Verify the OpenID response via direct request to the OP
        args = dict((k, v[-1]) for k, v in self.request.arguments.items())
        args["openid.mode"] = u("check_authentication")
        args["openid.mode"] = u"check_authentication"
        url = self._OPENID_ENDPOINT
        if http_client is None:
            http_client = self.get_auth_http_client()

@@ -255,13 +248,13 @@ class OpenIdMixin(object):
        ax_ns = None
        for name in self.request.arguments:
            if name.startswith("openid.ns.") and \
                    self.get_argument(name) == u("http://openid.net/srv/ax/1.0"):
                    self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
                ax_ns = name[10:]
                break

        def get_ax_arg(uri):
            if not ax_ns:
                return u("")
                return u""
            prefix = "openid." + ax_ns + ".type."
            ax_name = None
            for name in self.request.arguments.keys():

@@ -270,8 +263,8 @@ class OpenIdMixin(object):
                    ax_name = "openid." + ax_ns + ".value." + part
                    break
            if not ax_name:
                return u("")
            return self.get_argument(ax_name, u(""))
                return u""
            return self.get_argument(ax_name, u"")

        email = get_ax_arg("http://axschema.org/contact/email")
        name = get_ax_arg("http://axschema.org/namePerson")

@@ -290,7 +283,7 @@ class OpenIdMixin(object):
        if name:
            user["name"] = name
        elif name_parts:
            user["name"] = u(" ").join(name_parts)
            user["name"] = u" ".join(name_parts)
        elif email:
            user["name"] = email.split("@")[0]
        if email:

@@ -961,6 +954,20 @@ class FacebookGraphMixin(OAuth2Mixin):
        .. testoutput::
           :hide:

        This method returns a dictionary which may contain the following fields:

        * ``access_token``, a string which may be passed to `facebook_request`
        * ``session_expires``, an integer encoded as a string representing
          the time until the access token expires in seconds. This field should
          be used like ``int(user['session_expires'])``; in a future version of
          Tornado it will change from a string to an integer.
        * ``id``, ``name``, ``first_name``, ``last_name``, ``locale``, ``picture``,
          ``link``, plus any fields named in the ``extra_fields`` argument. These
          fields are copied from the Facebook graph API `user object <https://developers.facebook.com/docs/graph-api/reference/user>`_

        .. versionchanged:: 4.5
           The ``session_expires`` field was updated to support changes made to the
           Facebook API in March 2017.
        """
        http = self.get_auth_http_client()
        args = {

@@ -985,10 +992,10 @@ class FacebookGraphMixin(OAuth2Mixin):
            future.set_exception(AuthError('Facebook auth error: %s' % str(response)))
            return

        args = urlparse.parse_qs(escape.native_str(response.body))
        args = escape.json_decode(response.body)
        session = {
            "access_token": args["access_token"][-1],
            "expires": args.get("expires")
            "access_token": args.get("access_token"),
            "expires_in": args.get("expires_in")
        }

        self.facebook_request(

@@ -996,6 +1003,9 @@ class FacebookGraphMixin(OAuth2Mixin):
            callback=functools.partial(
                self._on_get_user_info, future, session, fields),
            access_token=session["access_token"],
            appsecret_proof=hmac.new(key=client_secret.encode('utf8'),
                                     msg=session["access_token"].encode('utf8'),
                                     digestmod=hashlib.sha256).hexdigest(),
            fields=",".join(fields)
        )

@@ -1008,7 +1018,12 @@ class FacebookGraphMixin(OAuth2Mixin):
        for field in fields:
            fieldmap[field] = user.get(field)

        fieldmap.update({"access_token": session["access_token"], "session_expires": session.get("expires")})
        # session_expires is converted to str for compatibility with
        # older versions in which the server used url-encoding and
        # this code simply returned the string verbatim.
        # This should change in Tornado 5.0.
        fieldmap.update({"access_token": session["access_token"],
                         "session_expires": str(session.get("expires_in"))})
        future.set_result(fieldmap)

    @_auth_return_future

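# A hedged sketch of consuming the ``session_expires`` field documented
# above; ``user`` is a hypothetical dict as returned by
# FacebookGraphMixin.get_authenticated_user():
#
#     expires_in = user.get('session_expires')
#     if expires_in is not None:
#         # Currently a string; int() also covers the planned change to int.
#         expires_in = int(expires_in)
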
@@ -45,7 +45,7 @@ incorrectly.

"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import os
import sys

@@ -83,7 +83,7 @@ if __name__ == "__main__":
import functools
import logging
import os
import pkgutil
import pkgutil  # type: ignore
import sys
import traceback
import types

@@ -103,16 +103,12 @@ except ImportError:
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
# This distinction is also important because when we use execv, we want to
# close the IOLoop and all its file descriptors, to guard against any
# file descriptors that were not set CLOEXEC. When execv is not available,
# we must not close the IOLoop because we want the process to exit cleanly.
_has_execv = sys.platform != 'win32'

_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary()
_io_loops = weakref.WeakKeyDictionary()  # type: ignore


def start(io_loop=None, check_time=500):

@@ -127,8 +123,6 @@ def start(io_loop=None, check_time=500):
    _io_loops[io_loop] = True
    if len(_io_loops) > 1:
        gen_log.warning("tornado.autoreload started more than once in the same process")
    if _has_execv:
        add_reload_hook(functools.partial(io_loop.close, all_fds=True))
    modify_times = {}
    callback = functools.partial(_reload_on_update, modify_times)
    scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)

@@ -249,6 +243,7 @@ def _reload():
    # unwind, so just exit uncleanly.
    os._exit(0)


_USAGE = """\
Usage:
  python -m tornado.autoreload -m module.to.run [args...]

@@ -21,7 +21,7 @@ a mostly-compatible `Future` class designed for use from coroutines,
as well as some utility functions for interacting with the
`concurrent.futures` package.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import functools
import platform

@@ -31,13 +31,18 @@ import sys

from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
from tornado.util import raise_exc_info, ArgReplacer, is_finalizing

try:
    from concurrent import futures
except ImportError:
    futures = None

try:
    import typing
except ImportError:
    typing = None


# Can the garbage collector handle cycles that include __del__ methods?
# This is true in cpython beginning with version 3.4 (PEP 442).

@@ -118,8 +123,8 @@ class _TracebackLogger(object):
        self.exc_info = None
        self.formatted_tb = None

    def __del__(self):
        if self.formatted_tb:
    def __del__(self, is_finalizing=is_finalizing):
        if not is_finalizing() and self.formatted_tb:
            app_log.error('Future exception was never retrieved: %s',
                          ''.join(self.formatted_tb).rstrip())

@@ -229,7 +234,10 @@ class Future(object):
        if self._result is not None:
            return self._result
        if self._exc_info is not None:
            raise_exc_info(self._exc_info)
            try:
                raise_exc_info(self._exc_info)
            finally:
                self = None
        self._check_done()
        return self._result

@@ -324,8 +332,8 @@ class Future(object):
    # cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
    # the PEP 442.
    if _GC_CYCLE_FINALIZERS:
        def __del__(self):
            if not self._log_traceback:
        def __del__(self, is_finalizing=is_finalizing):
            if is_finalizing() or not self._log_traceback:
                # set_exception() was not called, or result() or exception()
                # has consumed the exception
                return

@@ -335,10 +343,11 @@ class Future(object):
            app_log.error('Future %r exception was never retrieved: %s',
                          self, ''.join(tb).rstrip())


TracebackFuture = Future

if futures is None:
    FUTURES = Future
    FUTURES = Future  # type: typing.Union[type, typing.Tuple[type, ...]]
else:
    FUTURES = (futures.Future, Future)


@@ -359,6 +368,7 @@ class DummyExecutor(object):
    def shutdown(self, wait=True):
        pass


dummy_executor = DummyExecutor()


@@ -500,8 +510,9 @@ def chain_future(a, b):
        assert future is a
        if b.done():
            return
        if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
                and a.exc_info() is not None):
        if (isinstance(a, TracebackFuture) and
                isinstance(b, TracebackFuture) and
                a.exc_info() is not None):
            b.set_exc_info(a.exc_info())
        elif a.exception() is not None:
            b.set_exception(a.exception())

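# A short usage sketch for chain_future() as restyled above: the result
# (or exception) of future ``a`` is copied onto future ``b``. Tornado's
# Future runs done-callbacks synchronously, so this works without an IOLoop:
#
#     from tornado.concurrent import Future, chain_future
#
#     a, b = Future(), Future()
#     chain_future(a, b)
#     a.set_result(42)
#     assert b.result() == 42
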
@@ -16,12 +16,12 @@

"""Non-blocking HTTP client implementation using pycurl."""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import collections
import functools
import logging
import pycurl
import pycurl  # type: ignore
import threading
import time
from io import BytesIO

@@ -221,6 +221,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
            # _process_queue() is called from
            # _finish_pending_requests the exceptions have
            # nowhere to go.
            self._free_list.append(curl)
            callback(HTTPResponse(
                request=request,
                code=599,

@@ -277,6 +278,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
        if curl_log.isEnabledFor(logging.DEBUG):
            curl.setopt(pycurl.VERBOSE, 1)
            curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
        if hasattr(pycurl, 'PROTOCOLS'):  # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
            curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
            curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
        return curl

    def _curl_setup_request(self, curl, request, buffer, headers):

@@ -341,6 +345,15 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
            credentials = '%s:%s' % (request.proxy_username,
                                     request.proxy_password)
            curl.setopt(pycurl.PROXYUSERPWD, credentials)

            if (request.proxy_auth_mode is None or
                    request.proxy_auth_mode == "basic"):
                curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
            elif request.proxy_auth_mode == "digest":
                curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
            else:
                raise ValueError(
                    "Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
        else:
            curl.setopt(pycurl.PROXY, '')
            curl.unsetopt(pycurl.PROXYUSERPWD)

@@ -461,7 +474,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
            request.prepare_curl_callback(curl)

    def _curl_header_callback(self, headers, header_callback, header_line):
        header_line = native_str(header_line)
        header_line = native_str(header_line.decode('latin1'))
        if header_callback is not None:
            self.io_loop.add_callback(header_callback, header_line)
        # header_line as returned by curl includes the end-of-line characters.

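# A minimal sketch of the new digest proxy support wired up above; the
# proxy coordinates and credentials are placeholders:
#
#     from tornado import httpclient
#
#     httpclient.AsyncHTTPClient.configure(
#         "tornado.curl_httpclient.CurlAsyncHTTPClient")  # proxies need curl_httpclient
#     request = httpclient.HTTPRequest(
#         "http://example.com/",
#         proxy_host="proxy.example.com", proxy_port=3128,
#         proxy_username="user", proxy_password="secret",
#         proxy_auth_mode="digest")  # "basic" (the default) or "digest"
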
@@ -20,34 +20,28 @@ Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""

from __future__ import absolute_import, division, print_function, with_statement

import re
import sys

from tornado.util import unicode_type, basestring_type, u

try:
    from urllib.parse import parse_qs as _parse_qs  # py3
except ImportError:
    from urlparse import parse_qs as _parse_qs  # Python 2.6+

try:
    import htmlentitydefs  # py2
except ImportError:
    import html.entities as htmlentitydefs  # py3

try:
    import urllib.parse as urllib_parse  # py3
except ImportError:
    import urllib as urllib_parse  # py2
from __future__ import absolute_import, division, print_function

import json
import re

from tornado.util import PY3, unicode_type, basestring_type

if PY3:
    from urllib.parse import parse_qs as _parse_qs
    import html.entities as htmlentitydefs
    import urllib.parse as urllib_parse
    unichr = chr
else:
    from urlparse import parse_qs as _parse_qs
    import htmlentitydefs
    import urllib as urllib_parse

try:
    unichr
except NameError:
    unichr = chr
    import typing  # noqa
except ImportError:
    pass


_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',

@@ -116,7 +110,7 @@ def url_escape(value, plus=True):
# python 3 changed things around enough that we need two separate
# implementations of url_unescape. We also need our own implementation
# of parse_qs since python 3's version insists on decoding everything.
if sys.version_info[0] < 3:
if not PY3:
    def url_unescape(value, encoding='utf-8', plus=True):
        """Decodes the given value from a URL.


@@ -191,6 +185,7 @@ _UTF8_TYPES = (bytes, type(None))


def utf8(value):
    # type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
    """Converts a string argument to a byte string.

    If the argument is already a byte string or None, it is returned unchanged.

@@ -204,6 +199,7 @@ def utf8(value):
    )
    return value.encode("utf-8")


_TO_UNICODE_TYPES = (unicode_type, type(None))


@@ -221,6 +217,7 @@ def to_unicode(value):
    )
    return value.decode("utf-8")


# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode

@@ -269,6 +266,7 @@ def recursive_unicode(obj):
    else:
        return obj


# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing

@@ -366,7 +364,7 @@ def linkify(text, shorten=False, extra_params="",
            # have a status bar, such as Safari by default)
            params += ' title="%s"' % href

        return u('<a href="%s"%s>%s</a>') % (href, params, url)
        return u'<a href="%s"%s>%s</a>' % (href, params, url)

    # First HTML-escape so that our strings are all safe.
    # The regex is modified to avoid character entites other than &amp; so

@@ -396,4 +394,5 @@ def _build_unicode_map():
        unicode_map[name] = unichr(value)
    return unicode_map


_HTML_UNICODE_MAP = _build_unicode_map()

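# A short sketch of the byte/unicode helpers touched above; behavior is
# identical on Python 2 and 3 after the PY3 import shuffle:
#
#     from tornado.escape import utf8, to_unicode, xhtml_escape
#
#     assert utf8(u'caf\u00e9') == b'caf\xc3\xa9'
#     assert to_unicode(b'caf\xc3\xa9') == u'caf\u00e9'
#     assert xhtml_escape('a < b') == 'a &lt; b'
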
@@ -74,7 +74,7 @@ See the `convert_yielded` function to extend this mechanism.
via ``singledispatch``.

"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import collections
import functools

@@ -83,16 +83,18 @@ import os
import sys
import textwrap
import types
import weakref

from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.util import raise_exc_info
from tornado.util import PY3, raise_exc_info

try:
    try:
        from functools import singledispatch  # py34+
        # py34+
        from functools import singledispatch  # type: ignore
    except ImportError:
        from singledispatch import singledispatch  # backport
except ImportError:

@@ -108,12 +110,14 @@ except ImportError:

try:
    try:
        from collections.abc import Generator as GeneratorType  # py35+
        # py35+
        from collections.abc import Generator as GeneratorType  # type: ignore
    except ImportError:
        from backports_abc import Generator as GeneratorType
        from backports_abc import Generator as GeneratorType  # type: ignore

    try:
        from inspect import isawaitable  # py35+
        # py35+
        from inspect import isawaitable  # type: ignore
    except ImportError:
        from backports_abc import isawaitable
except ImportError:

@@ -121,12 +125,12 @@ except ImportError:
        raise
    from types import GeneratorType

    def isawaitable(x):
    def isawaitable(x):  # type: ignore
        return False

try:
    import builtins  # py3
except ImportError:
if PY3:
    import builtins
else:
    import __builtin__ as builtins


@@ -242,6 +246,26 @@ def coroutine(func, replace_callback=True):
    return _make_coroutine_wrapper(func, replace_callback=True)


# Ties lifetime of runners to their result futures. Github Issue #1769
# Generators, like any object in Python, must be strong referenced
# in order to not be cleaned up by the garbage collector. When using
# coroutines, the Runner object is what strong-refs the inner
# generator. However, the only item that strong-reffed the Runner
# was the last Future that the inner generator yielded (via the
# Future's internal done_callback list). Usually this is enough, but
# it is also possible for this Future to not have any strong references
# other than other objects referenced by the Runner object (usually
# when using other callback patterns and/or weakrefs). In this
# situation, if a garbage collection ran, a cycle would be detected and
# Runner objects could be destroyed along with their inner generators
# and everything in their local scope.
# This map provides strong references to Runner objects as long as
# their result future objects also have strong references (typically
# from the parent coroutine's Runner). This keeps the coroutine's
# Runner alive.
_futures_to_runners = weakref.WeakKeyDictionary()


def _make_coroutine_wrapper(func, replace_callback):
    """The inner workings of ``@gen.coroutine`` and ``@gen.engine``.


@@ -251,10 +275,11 @@ def _make_coroutine_wrapper(func, replace_callback):
    """
    # On Python 3.5, set the coroutine flag on our generator, to allow it
    # to be used with 'await'.
    wrapped = func
    if hasattr(types, 'coroutine'):
        func = types.coroutine(func)

    @functools.wraps(func)
    @functools.wraps(wrapped)
    def wrapper(*args, **kwargs):
        future = TracebackFuture()


@@ -291,7 +316,8 @@ def _make_coroutine_wrapper(func, replace_callback):
                except Exception:
                    future.set_exc_info(sys.exc_info())
                else:
                    Runner(result, future, yielded)
                    _futures_to_runners[future] = Runner(result, future, yielded)
                yielded = None
                try:
                    return future
                finally:

@@ -306,9 +332,21 @@ def _make_coroutine_wrapper(func, replace_callback):
                    future = None
        future.set_result(result)
        return future

    wrapper.__wrapped__ = wrapped
    wrapper.__tornado_coroutine__ = True
    return wrapper


def is_coroutine_function(func):
    """Return whether *func* is a coroutine function, i.e. a function
    wrapped with `~.gen.coroutine`.

    .. versionadded:: 4.5
    """
    return getattr(func, '__tornado_coroutine__', False)


class Return(Exception):
    """Special exception to return a value from a `coroutine`.


@@ -682,6 +720,7 @@ def multi(children, quiet_exceptions=()):
    else:
        return multi_future(children, quiet_exceptions=quiet_exceptions)


Multi = multi


@@ -830,7 +869,7 @@ def maybe_future(x):


def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
    """Wraps a `.Future` in a timeout.
    """Wraps a `.Future` (or other yieldable object) in a timeout.

    Raises `TimeoutError` if the input future does not complete before
    ``timeout``, which may be specified in any form allowed by

@@ -841,15 +880,18 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
    will be logged unless it is of a type contained in ``quiet_exceptions``
    (which may be an exception type or a sequence of types).

    Currently only supports Futures, not other `YieldPoint` classes.
    Does not support `YieldPoint` subclasses.

    .. versionadded:: 4.0

    .. versionchanged:: 4.1
       Added the ``quiet_exceptions`` argument and the logging of unhandled
       exceptions.

    .. versionchanged:: 4.4
       Added support for yieldable objects other than `.Future`.
    """
    # TODO: allow yield points in addition to futures?
    # TODO: allow YieldPoints in addition to other yieldables?
    # Tricky to do with stack_context semantics.
    #
    # It's tempting to optimize this by cancelling the input future on timeout

@@ -857,6 +899,7 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
    # one waiting on the input future, so cancelling it might disrupt other
    # callers and B) concurrent futures can only be cancelled while they are
    # in the queue, so cancellation cannot reliably bound our waiting time.
    future = convert_yielded(future)
    result = Future()
    chain_future(future, result)
    if io_loop is None:

@@ -923,6 +966,9 @@ coroutines that are likely to yield Futures that are ready instantly.
Usage: ``yield gen.moment``

.. versionadded:: 4.0

.. deprecated:: 4.5
   ``yield None`` is now equivalent to ``yield gen.moment``.
"""
moment.set_result(None)


@@ -953,6 +999,7 @@ class Runner(object):
        # of the coroutine.
        self.stack_context_deactivate = None
        if self.handle_yield(first_yielded):
            gen = result_future = first_yielded = None
            self.run()

    def register_callback(self, key):

@@ -1009,10 +1056,15 @@ class Runner(object):
                except Exception:
                    self.had_exception = True
                    exc_info = sys.exc_info()
                future = None

                if exc_info is not None:
                    yielded = self.gen.throw(*exc_info)
                    exc_info = None
                    try:
                        yielded = self.gen.throw(*exc_info)
                    finally:
                        # Break up a reference to itself
                        # for faster GC on CPython.
                        exc_info = None
                else:
                    yielded = self.gen.send(value)


@@ -1045,6 +1097,7 @@ class Runner(object):
                    return
                if not self.handle_yield(yielded):
                    return
                yielded = None
        finally:
            self.running = False


@@ -1093,8 +1146,12 @@ class Runner(object):
            self.future.set_exc_info(sys.exc_info())

        if not self.future.done() or self.future is moment:
            def inner(f):
                # Break a reference cycle to speed GC.
                f = None  # noqa
                self.run()
            self.io_loop.add_future(
                self.future, lambda f: self.run())
                self.future, inner)
            return False
        return True


@@ -1116,6 +1173,7 @@ class Runner(object):
        self.stack_context_deactivate()
        self.stack_context_deactivate = None


Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])


@@ -1135,6 +1193,7 @@ def _argument_adapter(callback):
        callback(None)
    return wrapper


# Convert Awaitables into Futures. It is unfortunately possible
# to have infinite recursion here if those Awaitables assume that
# we're using a different coroutine runner and yield objects

@@ -1212,7 +1271,9 @@ def convert_yielded(yielded):
    .. versionadded:: 4.1
    """
    # Lists and dicts containing YieldPoints were handled earlier.
    if isinstance(yielded, (list, dict)):
    if yielded is None:
        return moment
    elif isinstance(yielded, (list, dict)):
        return multi(yielded)
    elif is_future(yielded):
        return yielded

@@ -1221,6 +1282,7 @@ def convert_yielded(yielded):
    else:
        raise BadYieldError("yielded unknown object %r" % (yielded,))


if singledispatch is not None:
    convert_yielded = singledispatch(convert_yielded)

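# A hedged sketch tying together the gen.py changes above (Tornado 4.5):
# ``yield None`` is now ``yield gen.moment``, is_coroutine_function() is new,
# and with_timeout() accepts any yieldable:
#
#     import datetime
#     from tornado import gen
#     from tornado.ioloop import IOLoop
#
#     @gen.coroutine
#     def fetch_flag():
#         yield None                   # same as `yield gen.moment` as of 4.5
#         raise gen.Return(True)
#
#     assert gen.is_coroutine_function(fetch_flag)   # added in 4.5
#
#     @gen.coroutine
#     def main():
#         result = yield gen.with_timeout(datetime.timedelta(seconds=1),
#                                         fetch_flag())
#         raise gen.Return(result)
#
#     assert IOLoop.current().run_sync(main) is True
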
@@ -19,7 +19,7 @@

.. versionadded:: 4.0
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import re


@@ -30,7 +30,7 @@ from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado import stack_context
from tornado.util import GzipDecompressor
from tornado.util import GzipDecompressor, PY3


class _QuietException(Exception):

@@ -257,6 +257,7 @@ class HTTP1Connection(httputil.HTTPConnection):
            if need_delegate_close:
                with _ExceptionLoggingContext(app_log):
                    delegate.on_connection_close()
            header_future = None
            self._clear_callbacks()
        raise gen.Return(True)


@@ -342,7 +343,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                'Transfer-Encoding' not in headers)
        else:
            self._response_start_line = start_line
            lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
            lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
            self._chunking_output = (
                # TODO: should this use
                # self._request_start_line.version or

@@ -351,7 +352,7 @@ class HTTP1Connection(httputil.HTTPConnection):
                # 304 responses have no body (not even a zero-length body), and so
                # should not have either Content-Length or Transfer-Encoding.
                # headers.
                start_line.code != 304 and
                start_line.code not in (204, 304) and
                # No need to chunk the output if a Content-Length is specified.
                'Content-Length' not in headers and
                # Applications are discouraged from touching Transfer-Encoding,

@@ -359,8 +360,8 @@ class HTTP1Connection(httputil.HTTPConnection):
                'Transfer-Encoding' not in headers)
            # If a 1.0 client asked for keep-alive, add the header.
            if (self._request_start_line.version == 'HTTP/1.0' and
                (self._request_headers.get('Connection', '').lower()
                 == 'keep-alive')):
                    (self._request_headers.get('Connection', '').lower() ==
                     'keep-alive')):
                headers['Connection'] = 'Keep-Alive'
        if self._chunking_output:
            headers['Transfer-Encoding'] = 'chunked'

@@ -372,7 +373,14 @@ class HTTP1Connection(httputil.HTTPConnection):
            self._expected_content_remaining = int(headers['Content-Length'])
        else:
            self._expected_content_remaining = None
        lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
        # TODO: headers are supposed to be of type str, but we still have some
        # cases that let bytes slip through. Remove these native_str calls when those
        # are fixed.
        header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
        if PY3:
            lines.extend(l.encode('latin1') for l in header_lines)
        else:
            lines.extend(header_lines)
        for line in lines:
            if b'\n' in line:
                raise ValueError('Newline in header: ' + repr(line))

@@ -479,9 +487,11 @@ class HTTP1Connection(httputil.HTTPConnection):
            connection_header = connection_header.lower()
        if start_line.version == "HTTP/1.1":
            return connection_header != "close"
        elif ("Content-Length" in headers
              or headers.get("Transfer-Encoding", "").lower() == "chunked"
              or start_line.method in ("HEAD", "GET")):
        elif ("Content-Length" in headers or
              headers.get("Transfer-Encoding", "").lower() == "chunked" or
              getattr(start_line, 'method', None) in ("HEAD", "GET")):
            # start_line may be a request or response start line; only
            # the former has a method attribute.
            return connection_header == "keep-alive"
        return False


@@ -531,7 +541,13 @@ class HTTP1Connection(httputil.HTTPConnection):
                    "Multiple unequal Content-Lengths: %r" %
                    headers["Content-Length"])
            headers["Content-Length"] = pieces[0]
        content_length = int(headers["Content-Length"])

        try:
            content_length = int(headers["Content-Length"])
        except ValueError:
            # Handles non-integer Content-Length value.
            raise httputil.HTTPInputError(
                "Only integer Content-Length is allowed: %s" % headers["Content-Length"])

        if content_length > self._max_body_size:
            raise httputil.HTTPInputError("Content-Length too long")

@@ -550,7 +566,7 @@ class HTTP1Connection(httputil.HTTPConnection):

        if content_length is not None:
            return self._read_fixed_body(content_length, delegate)
        if headers.get("Transfer-Encoding") == "chunked":
        if headers.get("Transfer-Encoding", "").lower() == "chunked":
            return self._read_chunked_body(delegate)
        if self.is_client:
            return self._read_body_until_close(delegate)

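# A hedged distillation of the chunking rule above (not Tornado API): an
# HTTP/1.1 response is chunked only when it may legally carry a body and
# no explicit framing header was set by the application:
#
#     def should_chunk(code, headers, version='HTTP/1.1'):
#         return (version == 'HTTP/1.1' and
#                 code not in (204, 304) and        # bodyless status codes
#                 'Content-Length' not in headers and
#                 'Transfer-Encoding' not in headers)
#
#     assert should_chunk(200, {}) and not should_chunk(204, {})
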
@@ -25,7 +25,7 @@ to switch to ``curl_httpclient`` for reasons such as the following:
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
7.21.1, and the minimum version of pycurl is 7.18.2. It is highly
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more

@@ -38,7 +38,7 @@ To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::

    AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import functools
import time

@@ -61,7 +61,7 @@ class HTTPClient(object):
        http_client = httpclient.HTTPClient()
        try:
            response = http_client.fetch("http://www.google.com/")
            print response.body
            print(response.body)
        except httpclient.HTTPError as e:
            # HTTPError is raised for non-200 responses; the response
            # can be found in e.response.

@@ -108,14 +108,14 @@ class AsyncHTTPClient(Configurable):

    Example usage::

        def handle_request(response):
        def handle_response(response):
            if response.error:
                print "Error:", response.error
                print("Error: %s" % response.error)
            else:
                print response.body
                print(response.body)

        http_client = AsyncHTTPClient()
        http_client.fetch("http://www.google.com/", handle_request)
        http_client.fetch("http://www.google.com/", handle_response)

    The constructor for this class is magic in several respects: It
    actually creates an instance of an implementation-specific

@@ -211,10 +211,12 @@ class AsyncHTTPClient(Configurable):
        kwargs: ``HTTPRequest(request, **kwargs)``

        This method returns a `.Future` whose result is an
        `HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
        if the request returned a non-200 response code. Instead, if
        ``raise_error`` is set to False, the response will always be
        returned regardless of the response code.
        `HTTPResponse`. By default, the ``Future`` will raise an
        `HTTPError` if the request returned a non-200 response code
        (other errors may also be raised if the server could not be
        contacted). Instead, if ``raise_error`` is set to False, the
        response will always be returned regardless of the response
        code.

        If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
        In the callback interface, `HTTPError` is not automatically raised.

@@ -225,6 +227,9 @@ class AsyncHTTPClient(Configurable):
            raise RuntimeError("fetch() called on closed AsyncHTTPClient")
        if not isinstance(request, HTTPRequest):
            request = HTTPRequest(url=request, **kwargs)
        else:
            if kwargs:
                raise ValueError("kwargs can't be used if request is an HTTPRequest object")
        # We may modify this (to add Host, Accept-Encoding, etc),
        # so make sure we don't modify the caller's object. This is also
        # where normal dicts get converted to HTTPHeaders objects.

@@ -305,10 +310,10 @@ class HTTPRequest(object):
                 network_interface=None, streaming_callback=None,
                 header_callback=None, prepare_curl_callback=None,
                 proxy_host=None, proxy_port=None, proxy_username=None,
                 proxy_password=None, allow_nonstandard_methods=None,
                 validate_cert=None, ca_certs=None,
                 allow_ipv6=None,
                 client_key=None, client_cert=None, body_producer=None,
                 proxy_password=None, proxy_auth_mode=None,
                 allow_nonstandard_methods=None, validate_cert=None,
                 ca_certs=None, allow_ipv6=None, client_key=None,
                 client_cert=None, body_producer=None,
                 expect_100_continue=False, decompress_response=None,
                 ssl_options=None):
        r"""All parameters except ``url`` are optional.

@@ -336,13 +341,15 @@ class HTTPRequest(object):
           Allowed values are implementation-defined; ``curl_httpclient``
           supports "basic" and "digest"; ``simple_httpclient`` only supports
           "basic"
        :arg float connect_timeout: Timeout for initial connection in seconds
        :arg float request_timeout: Timeout for entire request in seconds
        :arg float connect_timeout: Timeout for initial connection in seconds,
           default 20 seconds
        :arg float request_timeout: Timeout for entire request in seconds,
           default 20 seconds
        :arg if_modified_since: Timestamp for ``If-Modified-Since`` header
        :type if_modified_since: `datetime` or `float`
        :arg bool follow_redirects: Should redirects be followed automatically
           or return the 3xx response?
        :arg int max_redirects: Limit for ``follow_redirects``
           or return the 3xx response? Default True.
        :arg int max_redirects: Limit for ``follow_redirects``, default 5.
        :arg string user_agent: String to send as ``User-Agent`` header
        :arg bool decompress_response: Request a compressed response from
           the server and decompress it after downloading. Default is True.

@@ -367,16 +374,18 @@ class HTTPRequest(object):
           a ``pycurl.Curl`` object to allow the application to make additional
           ``setopt`` calls.
        :arg string proxy_host: HTTP proxy hostname. To use proxies,
           ``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
           ``proxy_pass`` are optional. Proxies are currently only supported
           with ``curl_httpclient``.
           ``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
           ``proxy_pass`` and ``proxy_auth_mode`` are optional. Proxies are
           currently only supported with ``curl_httpclient``.
        :arg int proxy_port: HTTP proxy port
        :arg string proxy_username: HTTP proxy username
        :arg string proxy_password: HTTP proxy password
        :arg string proxy_auth_mode: HTTP proxy Authentication mode;
           default is "basic". supports "basic" and "digest"
        :arg bool allow_nonstandard_methods: Allow unknown values for ``method``
           argument?
           argument? Default is False.
        :arg bool validate_cert: For HTTPS requests, validate the server's
           certificate?
           certificate? Default is True.
        :arg string ca_certs: filename of CA certificates in PEM format,
           or None to use defaults. See note below when used with
           ``curl_httpclient``.

@@ -414,6 +423,9 @@ class HTTPRequest(object):

        .. versionadded:: 4.2
           The ``ssl_options`` argument.

        .. versionadded:: 4.5
           The ``proxy_auth_mode`` argument.
        """
        # Note that some of these attributes go through property setters
        # defined below.

@@ -425,6 +437,7 @@ class HTTPRequest(object):
        self.proxy_port = proxy_port
        self.proxy_username = proxy_username
        self.proxy_password = proxy_password
        self.proxy_auth_mode = proxy_auth_mode
        self.url = url
        self.method = method
        self.body = body

@@ -525,7 +538,7 @@ class HTTPResponse(object):

    * buffer: ``cStringIO`` object for response body

    * body: response body as string (created on demand from ``self.buffer``)
    * body: response body as bytes (created on demand from ``self.buffer``)

    * error: Exception object, if any


@@ -567,7 +580,8 @@ class HTTPResponse(object):
        self.request_time = request_time
        self.time_info = time_info or {}

    def _get_body(self):
    @property
    def body(self):
        if self.buffer is None:
            return None
        elif self._body is None:

@@ -575,8 +589,6 @@ class HTTPResponse(object):

        return self._body

    body = property(_get_body)

    def rethrow(self):
        """If there was an error on the request, raise an `HTTPError`."""
        if self.error:

@@ -610,6 +622,12 @@ class HTTPError(Exception):
    def __str__(self):
        return "HTTP %d: %s" % (self.code, self.message)

    # There is a cyclic reference between self and self.response,
    # which breaks the default __repr__ implementation.
    # (especially on pypy, which doesn't have the same recursion
    # detection as cpython).
    __repr__ = __str__


class _RequestProxy(object):
    """Combines an object with a dictionary of defaults.

@@ -655,5 +673,6 @@ def main():
        print(native_str(response.body))
    client.close()


if __name__ == "__main__":
    main()

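# A small sketch of the fetch()/raise_error behavior documented above;
# the URL is a placeholder:
#
#     from tornado import httpclient
#
#     client = httpclient.HTTPClient()
#     # With raise_error=False a non-2xx response is returned, not raised:
#     response = client.fetch("http://example.com/missing", raise_error=False)
#     print(response.code)  # e.g. 404
#     client.close()
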
@@ -26,7 +26,7 @@ class except to start a server at the beginning of the process
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import socket


@@ -62,6 +62,13 @@ class HTTPServer(TCPServer, Configurable,
    if Tornado is run behind an SSL-decoding proxy that does not set one of
    the supported ``xheaders``.

    By default, when parsing the ``X-Forwarded-For`` header, Tornado will
    select the last (i.e., the closest) address on the list of hosts as the
    remote host IP address. To select the next server in the chain, a list of
    trusted downstream hosts may be passed as the ``trusted_downstream``
    argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
    header.

    To make this server serve SSL traffic, send the ``ssl_options`` keyword
    argument with an `ssl.SSLContext` object. For compatibility with older
    versions of Python ``ssl_options`` may also be a dictionary of keyword

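# A minimal sketch of the ``trusted_downstream`` behavior described above;
# the handler and the load-balancer address are placeholders:
#
#     from tornado.httpserver import HTTPServer
#     from tornado.web import Application, RequestHandler
#
#     class Hello(RequestHandler):
#         def get(self):
#             # With xheaders=True, remote_ip is rewritten from
#             # X-Forwarded-For, skipping the trusted proxy at 10.0.0.1.
#             self.write(self.request.remote_ip)
#
#     app = Application([(r"/", Hello)])
#     server = HTTPServer(app, xheaders=True, trusted_downstream=["10.0.0.1"])
#     server.listen(8888)
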
@ -124,6 +131,9 @@ class HTTPServer(TCPServer, Configurable,
|
|||
|
||||
.. versionchanged:: 4.2
|
||||
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
|
||||
|
||||
.. versionchanged:: 4.5
|
||||
Added the ``trusted_downstream`` argument.
|
||||
"""
|
||||
def __init__(self, *args, **kwargs):
|
||||
# Ignore args to __init__; real initialization belongs in
|
||||
|
@ -138,7 +148,8 @@ class HTTPServer(TCPServer, Configurable,
|
|||
decompress_request=False,
|
||||
chunk_size=None, max_header_size=None,
|
||||
idle_connection_timeout=None, body_timeout=None,
|
||||
max_body_size=None, max_buffer_size=None):
|
||||
max_body_size=None, max_buffer_size=None,
|
||||
trusted_downstream=None):
|
||||
self.request_callback = request_callback
|
||||
self.no_keep_alive = no_keep_alive
|
||||
self.xheaders = xheaders
|
||||
|
@ -149,11 +160,13 @@ class HTTPServer(TCPServer, Configurable,
|
|||
max_header_size=max_header_size,
|
||||
header_timeout=idle_connection_timeout or 3600,
|
||||
max_body_size=max_body_size,
|
||||
body_timeout=body_timeout)
|
||||
body_timeout=body_timeout,
|
||||
no_keep_alive=no_keep_alive)
|
||||
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
|
||||
max_buffer_size=max_buffer_size,
|
||||
read_chunk_size=chunk_size)
|
||||
self._connections = set()
|
||||
self.trusted_downstream = trusted_downstream
|
||||
|
||||
@classmethod
|
||||
def configurable_base(cls):
|
||||
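A usage sketch of the new ``trusted_downstream`` argument added above (the addresses and handler are illustrative, not from the diff): requests pass through a proxy at 10.0.0.5 whose own address should be skipped when ``X-Forwarded-For`` is parsed.

    from tornado.httpserver import HTTPServer
    from tornado.web import Application, RequestHandler

    class EchoIP(RequestHandler):
        def get(self):
            # With xheaders=True, remote_ip is rewritten from the
            # X-Forwarded-For chain, skipping 10.0.0.5.
            self.write(self.request.remote_ip)

    app = Application([(r"/", EchoIP)])
    server = HTTPServer(app, xheaders=True, trusted_downstream=["10.0.0.5"])
    server.listen(8888)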
|
@ -172,21 +185,55 @@ class HTTPServer(TCPServer, Configurable,
|
|||
|
||||
def handle_stream(self, stream, address):
|
||||
context = _HTTPRequestContext(stream, address,
|
||||
self.protocol)
|
||||
self.protocol,
|
||||
self.trusted_downstream)
|
||||
conn = HTTP1ServerConnection(
|
||||
stream, self.conn_params, context)
|
||||
self._connections.add(conn)
|
||||
conn.start_serving(self)
|
||||
|
||||
def start_request(self, server_conn, request_conn):
|
||||
return _ServerRequestAdapter(self, server_conn, request_conn)
|
||||
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
|
||||
delegate = self.request_callback.start_request(server_conn, request_conn)
|
||||
else:
|
||||
delegate = _CallableAdapter(self.request_callback, request_conn)
|
||||
|
||||
if self.xheaders:
|
||||
delegate = _ProxyAdapter(delegate, request_conn)
|
||||
|
||||
return delegate
|
||||
|
||||
def on_close(self, server_conn):
|
||||
self._connections.remove(server_conn)
|
||||
|
||||
|
||||
class _CallableAdapter(httputil.HTTPMessageDelegate):
|
||||
def __init__(self, request_callback, request_conn):
|
||||
self.connection = request_conn
|
||||
self.request_callback = request_callback
|
||||
self.request = None
|
||||
self.delegate = None
|
||||
self._chunks = []
|
||||
|
||||
def headers_received(self, start_line, headers):
|
||||
self.request = httputil.HTTPServerRequest(
|
||||
connection=self.connection, start_line=start_line,
|
||||
headers=headers)
|
||||
|
||||
def data_received(self, chunk):
|
||||
self._chunks.append(chunk)
|
||||
|
||||
def finish(self):
|
||||
self.request.body = b''.join(self._chunks)
|
||||
self.request._parse_body()
|
||||
self.request_callback(self.request)
|
||||
|
||||
def on_connection_close(self):
|
||||
self._chunks = None
|
||||
|
||||
|
||||
class _HTTPRequestContext(object):
-   def __init__(self, stream, address, protocol):
+   def __init__(self, stream, address, protocol, trusted_downstream=None):
        self.address = address
        # Save the socket's address family now so we know how to
        # interpret self.address even after the stream is closed

@@ -210,6 +257,7 @@ class _HTTPRequestContext(object):
            self.protocol = "http"
        self._orig_remote_ip = self.remote_ip
        self._orig_protocol = self.protocol
+       self.trusted_downstream = set(trusted_downstream or [])

    def __str__(self):
        if self.address_family in (socket.AF_INET, socket.AF_INET6):

@@ -226,7 +274,10 @@ class _HTTPRequestContext(object):
        """Rewrite the ``remote_ip`` and ``protocol`` fields."""
        # Squid uses X-Forwarded-For, others use X-Real-Ip
        ip = headers.get("X-Forwarded-For", self.remote_ip)
-       ip = ip.split(',')[-1].strip()
+       # Skip trusted downstream hosts in X-Forwarded-For list
+       for ip in (cand.strip() for cand in reversed(ip.split(','))):
+           if ip not in self.trusted_downstream:
+               break
        ip = headers.get("X-Real-Ip", ip)
        if netutil.is_valid_ip(ip):
            self.remote_ip = ip

@@ -247,58 +298,28 @@ class _HTTPRequestContext(object):
        self.protocol = self._orig_protocol


-class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
-   """Adapts the `HTTPMessageDelegate` interface to the interface expected
-   by our clients.
-   """
-   def __init__(self, server, server_conn, request_conn):
-       self.server = server
+class _ProxyAdapter(httputil.HTTPMessageDelegate):
+   def __init__(self, delegate, request_conn):
        self.connection = request_conn
-       self.request = None
-       if isinstance(server.request_callback,
-                     httputil.HTTPServerConnectionDelegate):
-           self.delegate = server.request_callback.start_request(
-               server_conn, request_conn)
-           self._chunks = None
-       else:
-           self.delegate = None
-           self._chunks = []
+       self.delegate = delegate

    def headers_received(self, start_line, headers):
-       if self.server.xheaders:
-           self.connection.context._apply_xheaders(headers)
-       if self.delegate is None:
-           self.request = httputil.HTTPServerRequest(
-               connection=self.connection, start_line=start_line,
-               headers=headers)
-       else:
-           return self.delegate.headers_received(start_line, headers)
+       self.connection.context._apply_xheaders(headers)
+       return self.delegate.headers_received(start_line, headers)

    def data_received(self, chunk):
-       if self.delegate is None:
-           self._chunks.append(chunk)
-       else:
-           return self.delegate.data_received(chunk)
+       return self.delegate.data_received(chunk)

    def finish(self):
-       if self.delegate is None:
-           self.request.body = b''.join(self._chunks)
-           self.request._parse_body()
-           self.server.request_callback(self.request)
-       else:
-           self.delegate.finish()
+       self.delegate.finish()
        self._cleanup()

    def on_connection_close(self):
-       if self.delegate is None:
-           self._chunks = None
-       else:
-           self.delegate.on_connection_close()
+       self.delegate.on_connection_close()
        self._cleanup()

    def _cleanup(self):
-       if self.server.xheaders:
-           self.connection.context._unapply_xheaders()
+       self.connection.context._unapply_xheaders()


HTTPRequest = httputil.HTTPServerRequest
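The `_apply_xheaders` loop above walks the `X-Forwarded-For` chain right to left and stops at the first address that is not a trusted downstream proxy. A standalone restatement of that logic (a sketch, not Tornado's API; names are illustrative):

    def pick_remote_ip(xff, remote_ip, trusted):
        # Mirror of the loop in _apply_xheaders: candidates are scanned
        # from closest to farthest; trusted proxies are skipped.
        ip = remote_ip
        for cand in reversed(xff.split(',')):
            ip = cand.strip()
            if ip not in trusted:
                break
        return ip

    assert pick_remote_ip("1.2.3.4, 10.0.0.5", "127.0.0.1",
                          {"10.0.0.5"}) == "1.2.3.4"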
@@ -20,7 +20,7 @@ This module also defines the `HTTPServerRequest` class which is exposed
via `tornado.web.RequestHandler.request`.
"""

-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import calendar
import collections

@@ -33,33 +33,37 @@ import time

from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
-from tornado.util import ObjectDict
+from tornado.util import ObjectDict, PY3

-try:
-    import Cookie  # py2
-except ImportError:
-    import http.cookies as Cookie  # py3
+if PY3:
+    import http.cookies as Cookie
+    from http.client import responses
+    from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
+else:
+    import Cookie
+    from httplib import responses
+    from urllib import urlencode
+    from urlparse import urlparse, urlunparse, parse_qsl

-try:
-    from httplib import responses  # py2
-except ImportError:
-    from http.client import responses  # py3

# responses is unused in this file, but we re-export it to other files.
# Reference it so pyflakes doesn't complain.
responses

-try:
-    from urllib import urlencode  # py2
-except ImportError:
-    from urllib.parse import urlencode  # py3

try:
    from ssl import SSLError
except ImportError:
    # ssl is unavailable on app engine.
-   class SSLError(Exception):
+   class _SSLError(Exception):
        pass
+   # Hack around a mypy limitation. We can't simply put "type: ignore"
+   # on the class definition itself; must go through an assignment.
+   SSLError = _SSLError  # type: ignore

+try:
+    import typing
+except ImportError:
+    pass


# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line

@@ -95,6 +99,7 @@ class _NormalizedHeaderCache(dict):
            del self[old_key]
        return normalized


_normalized_headers = _NormalizedHeaderCache(1000)
@@ -127,8 +132,8 @@ class HTTPHeaders(collections.MutableMapping):
    Set-Cookie: C=D
    """
    def __init__(self, *args, **kwargs):
-       self._dict = {}
-       self._as_list = {}
+       self._dict = {}  # type: typing.Dict[str, str]
+       self._as_list = {}  # type: typing.Dict[str, typing.List[str]]
        self._last_key = None
        if (len(args) == 1 and len(kwargs) == 0 and
                isinstance(args[0], HTTPHeaders)):

@@ -142,6 +147,7 @@ class HTTPHeaders(collections.MutableMapping):
    # new public methods

    def add(self, name, value):
+       # type: (str, str) -> None
        """Adds a new value for the given key."""
        norm_name = _normalized_headers[name]
        self._last_key = norm_name

@@ -158,6 +164,7 @@ class HTTPHeaders(collections.MutableMapping):
        return self._as_list.get(norm_name, [])

    def get_all(self):
+       # type: () -> typing.Iterable[typing.Tuple[str, str]]
        """Returns an iterable of all (name, value) pairs.

        If a header has multiple values, multiple pairs will be

@@ -206,6 +213,7 @@ class HTTPHeaders(collections.MutableMapping):
            self._as_list[norm_name] = [value]

    def __getitem__(self, name):
+       # type: (str) -> str
        return self._dict[_normalized_headers[name]]

    def __delitem__(self, name):

@@ -228,6 +236,14 @@ class HTTPHeaders(collections.MutableMapping):
    # the appearance that HTTPHeaders is a single container.
    __copy__ = copy

+   def __str__(self):
+       lines = []
+       for name, value in self.get_all():
+           lines.append("%s: %s\n" % (name, value))
+       return "".join(lines)
+
+   __unicode__ = __str__
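The new `__str__` above emits one "Name: value" line per pair from `get_all()`, which matters for multi-valued headers. A quick demonstration using the public `HTTPHeaders` API:

    from tornado.httputil import HTTPHeaders

    h = HTTPHeaders()
    h.add("Set-Cookie", "A=B")
    h.add("Set-Cookie", "C=D")        # second value for the same name
    print(h["Set-Cookie"])            # 'A=B,C=D' -- joined single-value view
    print(h.get_list("Set-Cookie"))   # ['A=B', 'C=D']
    print(str(h))                     # 'Set-Cookie: A=B\nSet-Cookie: C=D\n'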
class HTTPServerRequest(object):
    """A single HTTP request.

@@ -323,7 +339,7 @@ class HTTPServerRequest(object):
    """
    def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
                 body=None, host=None, files=None, connection=None,
-                start_line=None):
+                start_line=None, server_connection=None):
        if start_line is not None:
            method, uri, version = start_line
        self.method = method

@@ -338,8 +354,10 @@ class HTTPServerRequest(object):
        self.protocol = getattr(context, 'protocol', "http")

        self.host = host or self.headers.get("Host") or "127.0.0.1"
+       self.host_name = split_host_and_port(self.host.lower())[0]
        self.files = files or {}
        self.connection = connection
+       self.server_connection = server_connection
        self._start_time = time.time()
        self._finish_time = None

@@ -365,10 +383,18 @@ class HTTPServerRequest(object):
            self._cookies = Cookie.SimpleCookie()
            if "Cookie" in self.headers:
                try:
-                   self._cookies.load(
-                       native_str(self.headers["Cookie"]))
+                   parsed = parse_cookie(self.headers["Cookie"])
                except Exception:
-                   self._cookies = {}
+                   pass
+               else:
+                   for k, v in parsed.items():
+                       try:
+                           self._cookies[k] = v
+                       except Exception:
+                           # SimpleCookie imposes some restrictions on keys;
+                           # parse_cookie does not. Discard any cookies
+                           # with disallowed keys.
+                           pass
        return self._cookies

    def write(self, chunk, callback=None):
@@ -577,11 +603,28 @@ def url_concat(url, args):
    >>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
    'http://example.com/foo?a=b&c=d&c=d2'
    """
-   if not args:
+   if args is None:
        return url
-   if url[-1] not in ('?', '&'):
-       url += '&' if ('?' in url) else '?'
-   return url + urlencode(args)
+   parsed_url = urlparse(url)
+   if isinstance(args, dict):
+       parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
+       parsed_query.extend(args.items())
+   elif isinstance(args, list) or isinstance(args, tuple):
+       parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
+       parsed_query.extend(args)
+   else:
+       err = "'args' parameter should be dict, list or tuple. Not {0}".format(
+           type(args))
+       raise TypeError(err)
+   final_query = urlencode(parsed_query)
+   url = urlunparse((
+       parsed_url[0],
+       parsed_url[1],
+       parsed_url[2],
+       parsed_url[3],
+       final_query,
+       parsed_url[5]))
+   return url
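The rewrite above makes `url_concat` parse and re-encode the query string instead of appending text, and it now accepts a dict, list, or tuple (or ``None`` for a no-op). A quick usage check based on the new code:

    from tornado.httputil import url_concat

    print(url_concat("http://example.com/foo", dict(c="d")))
    # http://example.com/foo?c=d
    print(url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")]))
    # http://example.com/foo?a=b&c=d&c=d2
    print(url_concat("http://example.com/foo?a=b", None))
    # http://example.com/foo?a=b  -- unchanged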
class HTTPFile(ObjectDict):

@@ -743,7 +786,7 @@ def parse_multipart_form_data(boundary, data, arguments, files):
        name = disp_params["name"]
        if disp_params.get("filename"):
            ctype = headers.get("Content-Type", "application/unknown")
-           files.setdefault(name, []).append(HTTPFile(
+           files.setdefault(name, []).append(HTTPFile(  # type: ignore
                filename=disp_params["filename"], body=value,
                content_type=ctype))
        else:

@@ -895,3 +938,84 @@ def split_host_and_port(netloc):
        host = netloc
        port = None
    return (host, port)


+_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
+_QuotePatt = re.compile(r"[\\].")
+_nulljoin = ''.join
+
+
+def _unquote_cookie(str):
+   """Handle double quotes and escaping in cookie values.
+
+   This method is copied verbatim from the Python 3.5 standard
+   library (http.cookies._unquote) so we don't have to depend on
+   non-public interfaces.
+   """
+   # If there aren't any doublequotes,
+   # then there can't be any special characters. See RFC 2109.
+   if str is None or len(str) < 2:
+       return str
+   if str[0] != '"' or str[-1] != '"':
+       return str
+
+   # We have to assume that we must decode this string.
+   # Down to work.
+
+   # Remove the "s
+   str = str[1:-1]
+
+   # Check for special sequences. Examples:
+   #     \012 --> \n
+   #     \"   --> "
+   #
+   i = 0
+   n = len(str)
+   res = []
+   while 0 <= i < n:
+       o_match = _OctalPatt.search(str, i)
+       q_match = _QuotePatt.search(str, i)
+       if not o_match and not q_match:              # Neither matched
+           res.append(str[i:])
+           break
+       # else:
+       j = k = -1
+       if o_match:
+           j = o_match.start(0)
+       if q_match:
+           k = q_match.start(0)
+       if q_match and (not o_match or k < j):     # QuotePatt matched
+           res.append(str[i:k])
+           res.append(str[k + 1])
+           i = k + 2
+       else:                                      # OctalPatt matched
+           res.append(str[i:j])
+           res.append(chr(int(str[j + 1:j + 4], 8)))
+           i = j + 4
+   return _nulljoin(res)
+
+
+def parse_cookie(cookie):
+   """Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
+
+   This function attempts to mimic browser cookie parsing behavior;
+   it specifically does not follow any of the cookie-related RFCs
+   (because browsers don't either).
+
+   The algorithm used is identical to that used by Django version 1.9.10.
+
+   .. versionadded:: 4.4.2
+   """
+   cookiedict = {}
+   for chunk in cookie.split(str(';')):
+       if str('=') in chunk:
+           key, val = chunk.split(str('='), 1)
+       else:
+           # Assume an empty name per
+           # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
+           key, val = str(''), chunk
+       key, val = key.strip(), val.strip()
+       if key or val:
+           # unquote using Python's algorithm.
+           cookiedict[key] = _unquote_cookie(val)
+   return cookiedict
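A short usage sketch of the new `parse_cookie` added above, including a quoted value run through `_unquote_cookie` (octal ``\054`` decodes to a comma):

    from tornado.httputil import parse_cookie

    print(parse_cookie("a=b; c=d"))          # {'a': 'b', 'c': 'd'}
    print(parse_cookie('key="val\\054ue"'))  # {'key': 'val,ue'}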
@@ -26,8 +26,9 @@ In addition to I/O events, the `IOLoop` can also schedule time-based events.
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""

-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import collections
import datetime
import errno
import functools

@@ -45,20 +46,20 @@ import math

from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
from tornado.platform.auto import set_close_exec, Waker
from tornado import stack_context
-from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
+from tornado.util import PY3, Configurable, errno_from_exception, timedelta_to_seconds

try:
    import signal
except ImportError:
    signal = None

-try:
-    import thread  # py2
-except ImportError:
-    import _thread as thread  # py3
-
-from tornado.platform.auto import set_close_exec, Waker
+if PY3:
+    import _thread as thread
+else:
+    import thread


_POLL_TIMEOUT = 3600.0

@@ -172,6 +173,10 @@ class IOLoop(Configurable):
        This is normally not necessary as `instance()` will create
        an `IOLoop` on demand, but you may want to call `install` to use
        a custom subclass of `IOLoop`.

+       When using an `IOLoop` subclass, `install` must be called prior
+       to creating any objects that implicitly create their own
+       `IOLoop` (e.g., :class:`tornado.httpclient.AsyncHTTPClient`).
        """
        assert not IOLoop.initialized()
        IOLoop._instance = self

@@ -612,10 +617,14 @@ class IOLoop(Configurable):
                # result, which should just be ignored.
                pass
            else:
-               self.add_future(ret, lambda f: f.result())
+               self.add_future(ret, self._discard_future_result)
        except Exception:
            self.handle_callback_exception(callback)

+   def _discard_future_result(self, future):
+       """Avoid unhandled-exception warnings from spawned coroutines."""
+       future.result()

    def handle_callback_exception(self, callback):
        """This method is called whenever a callback run by the `IOLoop`
        throws an exception.

@@ -685,8 +694,7 @@ class PollIOLoop(IOLoop):
        self.time_func = time_func or time.time
        self._handlers = {}
        self._events = {}
-       self._callbacks = []
-       self._callback_lock = threading.Lock()
+       self._callbacks = collections.deque()
        self._timeouts = []
        self._cancellations = 0
        self._running = False
@@ -704,11 +712,10 @@ class PollIOLoop(IOLoop):
                         self.READ)

    def close(self, all_fds=False):
-       with self._callback_lock:
-           self._closing = True
+       self._closing = True
        self.remove_handler(self._waker.fileno())
        if all_fds:
-           for fd, handler in self._handlers.values():
+           for fd, handler in list(self._handlers.values()):
                self.close_fd(fd)
        self._waker.close()
        self._impl.close()

@@ -792,9 +799,7 @@ class PollIOLoop(IOLoop):
        while True:
            # Prevent IO event starvation by delaying new callbacks
            # to the next iteration of the event loop.
-           with self._callback_lock:
-               callbacks = self._callbacks
-               self._callbacks = []
+           ncallbacks = len(self._callbacks)

            # Add any timeouts that have come due to the callback list.
            # Do not run anything until we have determined which ones

@@ -814,8 +819,8 @@ class PollIOLoop(IOLoop):
                        due_timeouts.append(heapq.heappop(self._timeouts))
                    else:
                        break
-               if (self._cancellations > 512
-                       and self._cancellations > (len(self._timeouts) >> 1)):
+               if (self._cancellations > 512 and
+                       self._cancellations > (len(self._timeouts) >> 1)):
                    # Clean up the timeout queue when it gets large and it's
                    # more than half cancellations.
                    self._cancellations = 0

@@ -823,14 +828,14 @@ class PollIOLoop(IOLoop):
                                      if x.callback is not None]
                heapq.heapify(self._timeouts)

-           for callback in callbacks:
-               self._run_callback(callback)
+           for i in range(ncallbacks):
+               self._run_callback(self._callbacks.popleft())
            for timeout in due_timeouts:
                if timeout.callback is not None:
                    self._run_callback(timeout.callback)
            # Closures may be holding on to a lot of memory, so allow
            # them to be freed before we go into our poll wait.
-           callbacks = callback = due_timeouts = timeout = None
+           due_timeouts = timeout = None

            if self._callbacks:
                # If any callbacks or timeouts called add_callback,
@@ -874,7 +879,7 @@ class PollIOLoop(IOLoop):
            # Pop one fd at a time from the set of pending fds and run
            # its handler. Since that handler may perform actions on
            # other file descriptors, there may be reentrant calls to
-           # this IOLoop that update self._events
+           # this IOLoop that modify self._events
            self._events.update(event_pairs)
            while self._events:
                fd, events = self._events.popitem()

@@ -926,36 +931,20 @@ class PollIOLoop(IOLoop):
            self._cancellations += 1

    def add_callback(self, callback, *args, **kwargs):
+       if self._closing:
+           return
+       # Blindly insert into self._callbacks. This is safe even
+       # from signal handlers because deque.append is atomic.
+       self._callbacks.append(functools.partial(
+           stack_context.wrap(callback), *args, **kwargs))
        if thread.get_ident() != self._thread_ident:
-           # If we're not on the IOLoop's thread, we need to synchronize
-           # with other threads, or waking logic will induce a race.
-           with self._callback_lock:
-               if self._closing:
-                   return
-               list_empty = not self._callbacks
-               self._callbacks.append(functools.partial(
-                   stack_context.wrap(callback), *args, **kwargs))
-               if list_empty:
-                   # If we're not in the IOLoop's thread, and we added the
-                   # first callback to an empty list, we may need to wake it
-                   # up (it may wake up on its own, but an occasional extra
-                   # wake is harmless). Waking up a polling IOLoop is
-                   # relatively expensive, so we try to avoid it when we can.
-                   self._waker.wake()
+           # This will write one byte but Waker.consume() reads many
+           # at once, so it's ok to write even when not strictly
+           # necessary.
+           self._waker.wake()
        else:
-           if self._closing:
-               return
-           # If we're on the IOLoop's thread, we don't need the lock,
-           # since we don't need to wake anyone, just add the
-           # callback. Blindly insert into self._callbacks. This is
-           # safe even from signal handlers because the GIL makes
-           # list.append atomic. One subtlety is that if the signal
-           # is interrupting another thread holding the
-           # _callback_lock block in IOLoop.start, we may modify
-           # either the old or new version of self._callbacks, but
-           # either way will work.
-           self._callbacks.append(functools.partial(
-               stack_context.wrap(callback), *args, **kwargs))
+           # If we're on the IOLoop's thread, we don't need to wake anyone.
+           pass

    def add_callback_from_signal(self, callback, *args, **kwargs):
        with stack_context.NullContext():
@@ -966,26 +955,24 @@ class _Timeout(object):
    """An IOLoop timeout, a UNIX timestamp and a callback"""

    # Reduce memory overhead when there are lots of pending callbacks
-   __slots__ = ['deadline', 'callback', 'tiebreaker']
+   __slots__ = ['deadline', 'callback', 'tdeadline']

    def __init__(self, deadline, callback, io_loop):
        if not isinstance(deadline, numbers.Real):
            raise TypeError("Unsupported deadline %r" % deadline)
        self.deadline = deadline
        self.callback = callback
-       self.tiebreaker = next(io_loop._timeout_counter)
+       self.tdeadline = (deadline, next(io_loop._timeout_counter))

    # Comparison methods to sort by deadline, with object id as a tiebreaker
    # to guarantee a consistent ordering. The heapq module uses __le__
    # in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
    # use __lt__).
    def __lt__(self, other):
-       return ((self.deadline, self.tiebreaker) <
-               (other.deadline, other.tiebreaker))
+       return self.tdeadline < other.tdeadline

    def __le__(self, other):
-       return ((self.deadline, self.tiebreaker) <=
-               (other.deadline, other.tiebreaker))
+       return self.tdeadline <= other.tdeadline


class PeriodicCallback(object):

@@ -1048,6 +1035,7 @@ class PeriodicCallback(object):

        if self._next_timeout <= current_time:
            callback_time_sec = self.callback_time / 1000.0
-           self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
+           self._next_timeout += (math.floor((current_time - self._next_timeout) /
+                                             callback_time_sec) + 1) * callback_time_sec

        self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
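The `tdeadline` change above precomputes the `(deadline, counter)` tuple once, so each heap comparison is a single tuple compare instead of building two tuples, and the counter keeps equal deadlines in insertion order. A standalone sketch of that ordering with `heapq`:

    import heapq
    import itertools

    counter = itertools.count()
    heap = []
    for deadline, name in [(5.0, "b"), (1.0, "a"), (5.0, "c")]:
        # one precomputed tuple per entry, like _Timeout.tdeadline
        heapq.heappush(heap, ((deadline, next(counter)), name))

    print([name for _, name in (heapq.heappop(heap) for _ in range(3))])
    # ['a', 'b', 'c'] -- ties at deadline 5.0 keep insertion order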
@@ -24,7 +24,7 @@ Contents:
* `PipeIOStream`: Pipe-based IOStream implementation.
"""

-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import collections
import errno

@@ -58,7 +58,7 @@ except ImportError:
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)

if hasattr(errno, "WSAEWOULDBLOCK"):
-   _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
+   _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)  # type: ignore

# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.

@@ -66,7 +66,7 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
                    errno.ETIMEDOUT)

if hasattr(errno, "WSAECONNRESET"):
-   _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
+   _ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)  # type: ignore

if sys.platform == 'darwin':
    # OSX appears to have a race condition that causes send(2) to return

@@ -74,13 +74,15 @@ if sys.platform == 'darwin':
    # http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
    # Since the socket is being closed anyway, treat this as an ECONNRESET
    # instead of an unexpected error.
-   _ERRNO_CONNRESET += (errno.EPROTOTYPE,)
+   _ERRNO_CONNRESET += (errno.EPROTOTYPE,)  # type: ignore

# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)

if hasattr(errno, "WSAEINPROGRESS"):
-   _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
+   _ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)  # type: ignore

+_WINDOWS = sys.platform.startswith('win')


class StreamClosedError(IOError):

@@ -158,11 +160,16 @@ class BaseIOStream(object):
                                      self.max_buffer_size // 2)
        self.max_write_buffer_size = max_write_buffer_size
        self.error = None
-       self._read_buffer = collections.deque()
-       self._write_buffer = collections.deque()
+       self._read_buffer = bytearray()
+       self._read_buffer_pos = 0
        self._read_buffer_size = 0
+       self._write_buffer = bytearray()
+       self._write_buffer_pos = 0
        self._write_buffer_size = 0
        self._write_buffer_frozen = False
+       self._total_write_index = 0
+       self._total_write_done_index = 0
+       self._pending_writes_while_frozen = []
        self._read_delimiter = None
        self._read_regex = None
        self._read_max_bytes = None
@@ -173,7 +180,7 @@ class BaseIOStream(object):
        self._read_future = None
        self._streaming_callback = None
        self._write_callback = None
-       self._write_future = None
+       self._write_futures = collections.deque()
        self._close_callback = None
        self._connect_callback = None
        self._connect_future = None

@@ -367,36 +374,37 @@ class BaseIOStream(object):

        If no ``callback`` is given, this method returns a `.Future` that
        resolves (with a result of ``None``) when the write has been
-       completed. If `write` is called again before that `.Future` has
-       resolved, the previous future will be orphaned and will never resolve.
+       completed.

+       The ``data`` argument may be of type `bytes` or `memoryview`.

        .. versionchanged:: 4.0
            Now returns a `.Future` if no callback is given.

+       .. versionchanged:: 4.5
+           Added support for `memoryview` arguments.
        """
-       assert isinstance(data, bytes)
        self._check_closed()
-       # We use bool(_write_buffer) as a proxy for write_buffer_size>0,
-       # so never put empty strings in the buffer.
        if data:
            if (self.max_write_buffer_size is not None and
                    self._write_buffer_size + len(data) > self.max_write_buffer_size):
                raise StreamBufferFullError("Reached maximum write buffer size")
-           # Break up large contiguous strings before inserting them in the
-           # write buffer, so we don't have to recopy the entire thing
-           # as we slice off pieces to send to the socket.
-           WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
-           for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
-               self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
-           self._write_buffer_size += len(data)
+           if self._write_buffer_frozen:
+               self._pending_writes_while_frozen.append(data)
+           else:
+               self._write_buffer += data
+               self._write_buffer_size += len(data)
+           self._total_write_index += len(data)
        if callback is not None:
            self._write_callback = stack_context.wrap(callback)
            future = None
        else:
-           future = self._write_future = TracebackFuture()
+           future = TracebackFuture()
            future.add_done_callback(lambda f: f.exception())
+           self._write_futures.append((self._total_write_index, future))
        if not self._connecting:
            self._handle_write()
-           if self._write_buffer:
+           if self._write_buffer_size:
                self._add_io_state(self.io_loop.WRITE)
            self._maybe_add_error_listener()
        return future
@@ -445,9 +453,8 @@ class BaseIOStream(object):
        if self._read_future is not None:
            futures.append(self._read_future)
            self._read_future = None
-       if self._write_future is not None:
-           futures.append(self._write_future)
-           self._write_future = None
+       futures += [future for _, future in self._write_futures]
+       self._write_futures.clear()
        if self._connect_future is not None:
            futures.append(self._connect_future)
            self._connect_future = None

@@ -466,6 +473,7 @@ class BaseIOStream(object):
            # if the IOStream object is kept alive by a reference cycle.
            # TODO: Clear the read buffer too; it currently breaks some tests.
            self._write_buffer = None
+           self._write_buffer_size = 0

    def reading(self):
        """Returns true if we are currently reading from the stream."""

@@ -473,7 +481,7 @@ class BaseIOStream(object):

    def writing(self):
        """Returns true if we are currently writing to the stream."""
-       return bool(self._write_buffer)
+       return self._write_buffer_size > 0

    def closed(self):
        """Returns true if the stream has been closed."""

@@ -743,7 +751,7 @@ class BaseIOStream(object):
                break
            if chunk is None:
                return 0
-           self._read_buffer.append(chunk)
+           self._read_buffer += chunk
            self._read_buffer_size += len(chunk)
            if self._read_buffer_size > self.max_buffer_size:
                gen_log.error("Reached maximum read buffer size")

@@ -791,30 +799,25 @@ class BaseIOStream(object):
            # since large merges are relatively expensive and get undone in
            # _consume().
-           if self._read_buffer:
-               while True:
-                   loc = self._read_buffer[0].find(self._read_delimiter)
-                   if loc != -1:
-                       delimiter_len = len(self._read_delimiter)
-                       self._check_max_bytes(self._read_delimiter,
-                                             loc + delimiter_len)
-                       return loc + delimiter_len
-                   if len(self._read_buffer) == 1:
-                       break
-                   _double_prefix(self._read_buffer)
+           loc = self._read_buffer.find(self._read_delimiter,
+                                        self._read_buffer_pos)
+           if loc != -1:
+               loc -= self._read_buffer_pos
+               delimiter_len = len(self._read_delimiter)
+               self._check_max_bytes(self._read_delimiter,
+                                     loc + delimiter_len)
+               return loc + delimiter_len
            self._check_max_bytes(self._read_delimiter,
-                                 len(self._read_buffer[0]))
+                                 self._read_buffer_size)
        elif self._read_regex is not None:
-           if self._read_buffer:
-               while True:
-                   m = self._read_regex.search(self._read_buffer[0])
-                   if m is not None:
-                       self._check_max_bytes(self._read_regex, m.end())
-                       return m.end()
-                   if len(self._read_buffer) == 1:
-                       break
-                   _double_prefix(self._read_buffer)
-               self._check_max_bytes(self._read_regex,
-                                     len(self._read_buffer[0]))
+           m = self._read_regex.search(self._read_buffer,
+                                       self._read_buffer_pos)
+           if m is not None:
+               loc = m.end() - self._read_buffer_pos
+               self._check_max_bytes(self._read_regex, loc)
+               return loc
+           self._check_max_bytes(self._read_regex, self._read_buffer_size)
        return None

    def _check_max_bytes(self, delimiter, size):
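The rewritten `_find_read_pos` above replaces the deque merge-and-rescan loop with a single `bytearray.find` (or regex search) starting at a consumed-bytes offset. A standalone sketch of that single-scan search (data and names are illustrative):

    buf = bytearray(b"xxxxHTTP/1.1 200 OK\r\n\r\nbody")
    pos = 4                       # bytes before pos were already consumed
    delim = b"\r\n\r\n"

    loc = buf.find(delim, pos)    # one scan over the contiguous buffer
    if loc != -1:
        consumed = loc - pos + len(delim)   # length relative to pos
        print(bytes(buf[pos:pos + consumed]))
        # b'HTTP/1.1 200 OK\r\n\r\n'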
@@ -824,35 +827,56 @@ class BaseIOStream(object):
                "delimiter %r not found within %d bytes" % (
                    delimiter, self._read_max_bytes))

+   def _freeze_write_buffer(self, size):
+       self._write_buffer_frozen = size
+
+   def _unfreeze_write_buffer(self):
+       self._write_buffer_frozen = False
+       self._write_buffer += b''.join(self._pending_writes_while_frozen)
+       self._write_buffer_size += sum(map(len, self._pending_writes_while_frozen))
+       self._pending_writes_while_frozen[:] = []
+
+   def _got_empty_write(self, size):
+       """
+       Called when a non-blocking write() failed writing anything.
+       Can be overridden in subclasses.
+       """

    def _handle_write(self):
-       while self._write_buffer:
+       while self._write_buffer_size:
+           assert self._write_buffer_size >= 0
            try:
-               if not self._write_buffer_frozen:
+               start = self._write_buffer_pos
+               if self._write_buffer_frozen:
+                   size = self._write_buffer_frozen
+               elif _WINDOWS:
                    # On windows, socket.send blows up if given a
                    # write buffer that's too large, instead of just
                    # returning the number of bytes it was able to
                    # process. Therefore we must not call socket.send
                    # with more than 128KB at a time.
-                   _merge_prefix(self._write_buffer, 128 * 1024)
-               num_bytes = self.write_to_fd(self._write_buffer[0])
+                   size = 128 * 1024
+               else:
+                   size = self._write_buffer_size
+               num_bytes = self.write_to_fd(
+                   memoryview(self._write_buffer)[start:start + size])
                if num_bytes == 0:
-                   # With OpenSSL, if we couldn't write the entire buffer,
-                   # the very same string object must be used on the
-                   # next call to send. Therefore we suppress
-                   # merging the write buffer after an incomplete send.
-                   # A cleaner solution would be to set
-                   # SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
-                   # not yet accessible from python
-                   # (http://bugs.python.org/issue8240)
-                   self._write_buffer_frozen = True
+                   self._got_empty_write(size)
                    break
-               self._write_buffer_frozen = False
-               _merge_prefix(self._write_buffer, num_bytes)
-               self._write_buffer.popleft()
+               self._write_buffer_pos += num_bytes
                self._write_buffer_size -= num_bytes
+               # Amortized O(1) shrink
+               # (this heuristic is implemented natively in Python 3.4+
+               # but is replicated here for Python 2)
+               if self._write_buffer_pos > self._write_buffer_size:
+                   del self._write_buffer[:self._write_buffer_pos]
+                   self._write_buffer_pos = 0
+               if self._write_buffer_frozen:
+                   self._unfreeze_write_buffer()
+               self._total_write_done_index += num_bytes
            except (socket.error, IOError, OSError) as e:
                if e.args[0] in _ERRNO_WOULDBLOCK:
-                   self._write_buffer_frozen = True
+                   self._got_empty_write(size)
                    break
                else:
                    if not self._is_connreset(e):

@@ -863,22 +887,38 @@ class BaseIOStream(object):
                        self.fileno(), e)
                    self.close(exc_info=True)
                    return
-       if not self._write_buffer:

+       while self._write_futures:
+           index, future = self._write_futures[0]
+           if index > self._total_write_done_index:
+               break
+           self._write_futures.popleft()
+           future.set_result(None)
+
+       if not self._write_buffer_size:
            if self._write_callback:
                callback = self._write_callback
                self._write_callback = None
                self._run_callback(callback)
-           if self._write_future:
-               future = self._write_future
-               self._write_future = None
-               future.set_result(None)

    def _consume(self, loc):
+       # Consume loc bytes from the read buffer and return them
        if loc == 0:
            return b""
-       _merge_prefix(self._read_buffer, loc)
+       assert loc <= self._read_buffer_size
+       # Slice the bytearray buffer into bytes, without intermediate copying
+       b = (memoryview(self._read_buffer)
+            [self._read_buffer_pos:self._read_buffer_pos + loc]
+            ).tobytes()
+       self._read_buffer_pos += loc
        self._read_buffer_size -= loc
-       return self._read_buffer.popleft()
+       # Amortized O(1) shrink
+       # (this heuristic is implemented natively in Python 3.4+
+       # but is replicated here for Python 2)
+       if self._read_buffer_pos > self._read_buffer_size:
+           del self._read_buffer[:self._read_buffer_pos]
+           self._read_buffer_pos = 0
+       return b

    def _check_closed(self):
        if self.closed():
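The pos/size bookkeeping above avoids deleting from the front of the bytearray on every consume: a position advances past dead bytes, and the buffer is compacted only once consumed bytes outnumber live ones, so each byte is moved a bounded number of times. A standalone sketch of the heuristic (module-level globals for brevity; illustrative, not Tornado's API):

    buf = bytearray(b"abcdefgh")
    pos, size = 0, len(buf)

    def consume(n):
        global pos, size
        chunk = bytes(memoryview(buf)[pos:pos + n])
        pos += n
        size -= n
        if pos > size:            # more dead bytes than live ones
            del buf[:pos]         # one O(live-bytes) compaction
            pos = 0
        return chunk

    print(consume(3), consume(3), bytes(buf))
    # b'abc' b'def' b'gh' -- buffer compacted after the second consume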
@@ -1124,7 +1164,7 @@ class IOStream(BaseIOStream):
        suitably-configured `ssl.SSLContext` to disable.
        """
        if (self._read_callback or self._read_future or
-               self._write_callback or self._write_future or
+               self._write_callback or self._write_futures or
                self._connect_callback or self._connect_future or
                self._pending_callbacks or self._closed or
                self._read_buffer or self._write_buffer):

@@ -1251,6 +1291,17 @@ class SSLIOStream(IOStream):
    def writing(self):
        return self._handshake_writing or super(SSLIOStream, self).writing()

+   def _got_empty_write(self, size):
+       # With OpenSSL, if we couldn't write the entire buffer,
+       # the very same string object must be used on the
+       # next call to send. Therefore we suppress
+       # merging the write buffer after an incomplete send.
+       # A cleaner solution would be to set
+       # SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
+       # not yet accessible from python
+       # (http://bugs.python.org/issue8240)
+       self._freeze_write_buffer(size)

    def _do_ssl_handshake(self):
        # Based on code from test_ssl.py in the python stdlib
        try:

@@ -1498,53 +1549,6 @@ class PipeIOStream(BaseIOStream):
        return chunk


-def _double_prefix(deque):
-   """Grow by doubling, but don't split the second chunk just because the
-   first one is small.
-   """
-   new_len = max(len(deque[0]) * 2,
-                 (len(deque[0]) + len(deque[1])))
-   _merge_prefix(deque, new_len)
-
-
-def _merge_prefix(deque, size):
-   """Replace the first entries in a deque of strings with a single
-   string of up to size bytes.
-
-   >>> d = collections.deque(['abc', 'de', 'fghi', 'j'])
-   >>> _merge_prefix(d, 5); print(d)
-   deque(['abcde', 'fghi', 'j'])
-
-   Strings will be split as necessary to reach the desired size.
-   >>> _merge_prefix(d, 7); print(d)
-   deque(['abcdefg', 'hi', 'j'])
-
-   >>> _merge_prefix(d, 3); print(d)
-   deque(['abc', 'defg', 'hi', 'j'])
-
-   >>> _merge_prefix(d, 100); print(d)
-   deque(['abcdefghij'])
-   """
-   if len(deque) == 1 and len(deque[0]) <= size:
-       return
-   prefix = []
-   remaining = size
-   while deque and remaining > 0:
-       chunk = deque.popleft()
-       if len(chunk) > remaining:
-           deque.appendleft(chunk[remaining:])
-           chunk = chunk[:remaining]
-       prefix.append(chunk)
-       remaining -= len(chunk)
-   # This data structure normally just contains byte strings, but
-   # the unittest gets messy if it doesn't use the default str() type,
-   # so do the merge based on the type of data that's actually present.
-   if prefix:
-       deque.appendleft(type(prefix[0])().join(prefix))
-   if not deque:
-       deque.appendleft(b"")


def doctests():
    import doctest
    return doctest.DocTestSuite()
@@ -19,7 +19,7 @@
To load a locale and generate a translated string::

    user_locale = tornado.locale.get("es_LA")
-   print user_locale.translate("Sign out")
+   print(user_locale.translate("Sign out"))

`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with

@@ -28,7 +28,7 @@ additional arguments to `~Locale.translate()`, e.g.::

    people = [...]
    message = user_locale.translate(
        "%(list)s is online", "%(list)s are online", len(people))
-   print message % {"list": user_locale.list(people)}
+   print(message % {"list": user_locale.list(people)})

The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.

@@ -39,7 +39,7 @@ supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""

-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import codecs
import csv

@@ -51,12 +51,12 @@ import re

from tornado import escape
from tornado.log import gen_log
-from tornado.util import u
+from tornado.util import PY3

from tornado._locale_data import LOCALE_NAMES

_default_locale = "en_US"
-_translations = {}
+_translations = {}  # type: dict
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"

@@ -148,11 +148,11 @@ def load_translations(directory, encoding=None):
            # in most cases but is common with CSV files because Excel
            # cannot read utf-8 files without a BOM.
            encoding = 'utf-8-sig'
-       try:
+       if PY3:
            # python 3: csv.reader requires a file open in text mode.
            # Force utf8 to avoid dependence on $LANG environment variable.
            f = open(full_path, "r", encoding=encoding)
-       except TypeError:
+       else:
            # python 2: csv can only handle byte strings (in ascii-compatible
            # encodings), which we decode below. Transcode everything into
            # utf8 before passing it to csv.reader.

@@ -187,7 +187,7 @@ def load_gettext_translations(directory, domain):

    {directory}/{lang}/LC_MESSAGES/{domain}.mo

-   Three steps are required to have you app translated:
+   Three steps are required to have your app translated:

    1. Generate POT translation file::

@@ -274,7 +274,7 @@ class Locale(object):

    def __init__(self, code, translations):
        self.code = code
-       self.name = LOCALE_NAMES.get(code, {}).get("name", u("Unknown"))
+       self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
        self.rtl = False
        for prefix in ["fa", "ar", "he"]:
            if self.code.startswith(prefix):

@@ -376,7 +376,7 @@ class Locale(object):
            str_time = "%d:%02d" % (local_date.hour, local_date.minute)
        elif self.code == "zh_CN":
            str_time = "%s%d:%02d" % (
-               (u('\u4e0a\u5348'), u('\u4e0b\u5348'))[local_date.hour >= 12],
+               (u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
                local_date.hour % 12 or 12, local_date.minute)
        else:
            str_time = "%d:%02d %s" % (

@@ -422,7 +422,7 @@ class Locale(object):
            return ""
        if len(parts) == 1:
            return parts[0]
-       comma = u(' \u0648 ') if self.code.startswith("fa") else u(", ")
+       comma = u' \u0648 ' if self.code.startswith("fa") else u", "
        return _("%(commas)s and %(last)s") % {
            "commas": comma.join(parts[:-1]),
            "last": parts[len(parts) - 1],
@@ -12,15 +12,15 @@
# License for the specific language governing permissions and limitations
# under the License.

-from __future__ import absolute_import, division, print_function, with_statement
-
-__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
+from __future__ import absolute_import, division, print_function

import collections

from tornado import gen, ioloop
from tornado.concurrent import Future

+__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']


class _TimeoutGarbageCollector(object):
    """Base class for objects that periodically clean up timed-out waiters.

@@ -465,7 +465,7 @@ class Lock(object):
        ...
        ...    # Now the lock is released.

-   .. versionchanged:: 3.5
+   .. versionchanged:: 4.3
       Added ``async with`` support in Python 3.5.

    """
@@ -28,7 +28,7 @@ These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import logging
import logging.handlers

@@ -38,7 +38,12 @@ from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type

try:
-   import curses
+   import colorama
except ImportError:
+   colorama = None
+
+try:
+   import curses  # type: ignore
+except ImportError:
    curses = None

@@ -49,15 +54,21 @@ gen_log = logging.getLogger("tornado.general")


def _stderr_supports_color():
-   color = False
-   if curses and hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
-       try:
-           curses.setupterm()
-           if curses.tigetnum("colors") > 0:
-               color = True
-       except Exception:
-           pass
-   return color
+   try:
+       if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
+           if curses:
+               curses.setupterm()
+               if curses.tigetnum("colors") > 0:
+                   return True
+           elif colorama:
+               if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
+                                        object()):
+                   return True
+   except Exception:
+       # Very broad exception handling because it's always better to
+       # fall back to non-colored logs than to break at startup.
+       pass
+   return False


def _safe_unicode(s):

@@ -77,8 +88,19 @@ class LogFormatter(logging.Formatter):
    * Robust against str/bytes encoding problems.

    This formatter is enabled automatically by
-   `tornado.options.parse_command_line` (unless ``--logging=none`` is
-   used).
+   `tornado.options.parse_command_line` or `tornado.options.parse_config_file`
+   (unless ``--logging=none`` is used).

+   Color support on Windows versions that do not support ANSI color codes is
+   enabled by use of the colorama__ library. Applications that wish to use
+   this must first initialize colorama with a call to ``colorama.init``.
+   See the colorama documentation for details.
+
+   __ https://pypi.python.org/pypi/colorama
+
+   .. versionchanged:: 4.5
+      Added support for ``colorama``. Changed the constructor
+      signature to be compatible with `logging.config.dictConfig`.
    """
    DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
    DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'

@@ -89,8 +111,8 @@ class LogFormatter(logging.Formatter):
        logging.ERROR: 1,  # Red
    }

-   def __init__(self, color=True, fmt=DEFAULT_FORMAT,
-                datefmt=DEFAULT_DATE_FORMAT, colors=DEFAULT_COLORS):
+   def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
+                style='%', color=True, colors=DEFAULT_COLORS):
        r"""
        :arg bool color: Enables color support.
        :arg string fmt: Log message format.

@@ -111,21 +133,28 @@ class LogFormatter(logging.Formatter):

        self._colors = {}
        if color and _stderr_supports_color():
-           # The curses module has some str/bytes confusion in
-           # python3. Until version 3.2.3, most methods return
-           # bytes, but only accept strings. In addition, we want to
-           # output these strings with the logging module, which
-           # works with unicode strings. The explicit calls to
-           # unicode() below are harmless in python2 but will do the
-           # right conversion in python 3.
-           fg_color = (curses.tigetstr("setaf") or
-                       curses.tigetstr("setf") or "")
-           if (3, 0) < sys.version_info < (3, 2, 3):
-               fg_color = unicode_type(fg_color, "ascii")
+           if curses is not None:
+               # The curses module has some str/bytes confusion in
+               # python3. Until version 3.2.3, most methods return
+               # bytes, but only accept strings. In addition, we want to
+               # output these strings with the logging module, which
+               # works with unicode strings. The explicit calls to
+               # unicode() below are harmless in python2 but will do the
+               # right conversion in python 3.
+               fg_color = (curses.tigetstr("setaf") or
+                           curses.tigetstr("setf") or "")
+               if (3, 0) < sys.version_info < (3, 2, 3):
+                   fg_color = unicode_type(fg_color, "ascii")

-           for levelno, code in colors.items():
-               self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
-           self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
+               for levelno, code in colors.items():
+                   self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
+               self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
+           else:
+               # If curses is not present (currently we'll only get here for
+               # colorama on windows), assume hard-coded ANSI color codes.
+               for levelno, code in colors.items():
+                   self._colors[levelno] = '\033[2;3%dm' % code
+               self._normal = '\033[0m'
        else:
            self._normal = ''

@@ -183,7 +212,8 @@ def enable_pretty_logging(options=None, logger=None):
    and `tornado.options.parse_config_file`.
    """
    if options is None:
-       from tornado.options import options
+       import tornado.options
+       options = tornado.options.options
    if options.logging is None or options.logging.lower() == 'none':
        return
    if logger is None:

@@ -228,7 +258,8 @@ def define_logging_options(options=None):
    """
    if options is None:
        # late import to prevent cycle
-       from tornado.options import options
+       import tornado.options
+       options = tornado.options.options
    options.define("logging", default="info",
                   help=("Set the Python log level. If 'none', tornado won't touch the "
                         "logging configuration."),
@@ -16,7 +16,7 @@

"""Miscellaneous network utility code."""

-from __future__ import absolute_import, division, print_function, with_statement
+from __future__ import absolute_import, division, print_function

import errno
import os

@@ -27,7 +27,7 @@ import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
-from tornado.util import u, Configurable, errno_from_exception
+from tornado.util import PY3, Configurable, errno_from_exception

try:
    import ssl

@@ -44,20 +44,18 @@ except ImportError:
else:
    raise

-try:
-    xrange  # py2
-except NameError:
-    xrange = range  # py3
+if PY3:
+    xrange = range

if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'):  # python 3.2+
    ssl_match_hostname = ssl.match_hostname
    SSLCertificateError = ssl.CertificateError
elif ssl is None:
-   ssl_match_hostname = SSLCertificateError = None
+   ssl_match_hostname = SSLCertificateError = None  # type: ignore
else:
    import backports.ssl_match_hostname
    ssl_match_hostname = backports.ssl_match_hostname.match_hostname
-   SSLCertificateError = backports.ssl_match_hostname.CertificateError
+   SSLCertificateError = backports.ssl_match_hostname.CertificateError  # type: ignore

if hasattr(ssl, 'SSLContext'):
    if hasattr(ssl, 'create_default_context'):

@@ -96,7 +94,10 @@ else:
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
-u('foo').encode('idna')
+u'foo'.encode('idna')

+# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
+u'foo'.encode('latin1')

# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on

@@ -104,7 +105,7 @@ u('foo').encode('idna')
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)

if hasattr(errno, "WSAEWOULDBLOCK"):
-   _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
+   _ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)  # type: ignore

# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128

@@ -131,7 +132,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
    ``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
    ``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.

-   ``resuse_port`` option sets ``SO_REUSEPORT`` option for every socket
+   ``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
    in the list. If your platform doesn't support this option ValueError will
    be raised.
    """

@@ -199,6 +200,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
        sockets.append(sock)
    return sockets


if hasattr(socket, 'AF_UNIX'):
    def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
        """Creates a listening unix socket.

@@ -334,6 +336,11 @@ class Resolver(Configurable):
           port)`` pair for IPv4; additional fields may be present for
           IPv6). If a ``callback`` is passed, it will be run with the
           result as an argument when it is complete.

+          :raises IOError: if the address cannot be resolved.

           .. versionchanged:: 4.4
              Standardized all implementations to raise `IOError`.
           """
           raise NotImplementedError()

@@ -413,8 +420,8 @@ class ThreadedResolver(ExecutorResolver):
    All ``ThreadedResolvers`` share a single thread pool, whose
    size is set by the first one to be created.
    """
-   _threadpool = None
-   _threadpool_pid = None
+   _threadpool = None  # type: ignore
+   _threadpool_pid = None  # type: int

    def initialize(self, io_loop=None, num_threads=10):
        threadpool = ThreadedResolver._create_threadpool(num_threads)

@@ -518,4 +525,4 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
        else:
            return context.wrap_socket(socket, **kwargs)
    else:
-       return ssl.wrap_socket(socket, **dict(context, **kwargs))
+       return ssl.wrap_socket(socket, **dict(context, **kwargs))  # type: ignore
@ -41,6 +41,12 @@ either::
    # or
    tornado.options.parse_config_file("/etc/server.conf")

.. note:

   When using tornado.options.parse_command_line or
   tornado.options.parse_config_file, the only options that are set are
   ones that were previously defined with tornado.options.define.

Command line formats are what you would expect (``--myoption=myvalue``).
Config files are just Python files. Global names become options, e.g.::

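A minimal sketch of the define-then-parse pattern the note describes (option
names and defaults are arbitrary)::

    from tornado.options import define, options, parse_command_line

    define("port", default=8888, help="run on the given port", type=int)
    define("debug", default=False, type=bool)

    parse_command_line()  # e.g. --port=9000 --debug
    print(options.port, options.debug)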
@ -76,7 +82,7 @@ instances to define isolated sets of options, such as for subcommands.
   underscores.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import datetime
import numbers
@ -132,8 +138,10 @@ class OptionParser(object):
        return name in self._options

    def __getitem__(self, name):
        name = self._normalize_name(name)
        return self._options[name].value()
        return self.__getattr__(name)

    def __setitem__(self, name, value):
        return self.__setattr__(name, value)

    def items(self):
        """A sequence of (name, value) pairs.

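The change above routes dict-style access through ``__getattr__`` and adds
item assignment; a small sketch of the resulting behavior (the option name
is arbitrary)::

    from tornado.options import define, options

    define("port", default=8888, type=int)
    options["port"] = 9000          # same as options.port = 9000
    assert options["port"] == options.port == 9000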
@ -300,8 +308,12 @@ class OptionParser(object):
        .. versionchanged:: 4.1
           Config files are now always interpreted as utf-8 instead of
           the system default encoding.

        .. versionchanged:: 4.4
           The special variable ``__file__`` is available inside config
           files, specifying the absolute path to the config file itself.
        """
        config = {}
        config = {'__file__': os.path.abspath(path)}
        with open(path, 'rb') as f:
            exec_in(native_str(f.read()), config, config)
        for name in config:

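A sketch of what the new ``__file__`` variable enables inside a config file
(the option and file names are arbitrary)::

    # server.conf -- a Tornado config file is plain Python.
    import os

    # Resolve a path relative to this config file, not the CWD.
    template_path = os.path.join(os.path.dirname(__file__), "templates")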
@ -14,12 +14,12 @@ loops.

.. note::

   Tornado requires the `~asyncio.BaseEventLoop.add_reader` family of methods,
   so it is not compatible with the `~asyncio.ProactorEventLoop` on Windows.
   Use the `~asyncio.SelectorEventLoop` instead.
   Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
   methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
   Windows. Use the `~asyncio.SelectorEventLoop` instead.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools

import tornado.concurrent

@ -30,11 +30,11 @@ from tornado import stack_context
try:
    # Import the real asyncio module for py33+ first. Older versions of the
    # trollius backport also use this name.
    import asyncio
    import asyncio  # type: ignore
except ImportError as e:
    # Asyncio itself isn't available; see if trollius is (backport to py26+).
    try:
        import trollius as asyncio
        import trollius as asyncio  # type: ignore
    except ImportError:
        # Re-raise the original asyncio error, not the trollius one.
        raise e
@ -141,6 +141,8 @@ class BaseAsyncIOLoop(IOLoop):

    def add_callback(self, callback, *args, **kwargs):
        if self.closing:
            # TODO: this is racy; we need a lock to ensure that the
            # loop isn't closed during call_soon_threadsafe.
            raise RuntimeError("IOLoop is closing")
        self.asyncio_loop.call_soon_threadsafe(
            self._run_callback,

@ -158,6 +160,9 @@ class AsyncIOMainLoop(BaseAsyncIOLoop):
        import asyncio
        AsyncIOMainLoop().install()
        asyncio.get_event_loop().run_forever()

    See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
    installing alternative IOLoops.
    """
    def initialize(self, **kwargs):
        super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),

@ -212,5 +217,6 @@ def to_asyncio_future(tornado_future):
    tornado.concurrent.chain_future(tornado_future, af)
    return af


if hasattr(convert_yielded, 'register'):
    convert_yielded.register(asyncio.Future, to_tornado_future)
    convert_yielded.register(asyncio.Future, to_tornado_future)  # type: ignore

@ -23,7 +23,7 @@ Most code that needs access to this functionality should do e.g.::

    from tornado.platform.auto import set_close_exec
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import os

@ -47,8 +47,13 @@ try:
except ImportError:
    pass
try:
    from time import monotonic as monotonic_time
    # monotonic can provide a monotonic function in versions of python before
    # 3.3, too.
    from monotonic import monotonic as monotonic_time
except ImportError:
    monotonic_time = None
try:
    from time import monotonic as monotonic_time
except ImportError:
    monotonic_time = None

__all__ = ['Waker', 'set_close_exec', 'monotonic_time']

@ -1,5 +1,5 @@
from __future__ import absolute_import, division, print_function, with_statement
import pycares
from __future__ import absolute_import, division, print_function
import pycares  # type: ignore
import socket

from tornado import gen

@ -61,8 +61,8 @@ class CaresResolver(Resolver):
        assert not callback_args.kwargs
        result, error = callback_args.args
        if error:
            raise Exception('C-Ares returned error %s: %s while resolving %s' %
                            (error, pycares.errno.strerror(error), host))
            raise IOError('C-Ares returned error %s: %s while resolving %s' %
                          (error, pycares.errno.strerror(error), host))
        addresses = result.addresses
        addrinfo = []
        for address in addresses:

@ -73,7 +73,7 @@ class CaresResolver(Resolver):
        else:
            address_family = socket.AF_UNSPEC
        if family != socket.AF_UNSPEC and family != address_family:
            raise Exception('Requested socket family %d but got %d' %
                            (family, address_family))
            raise IOError('Requested socket family %d but got %d' %
                          (family, address_family))
        addrinfo.append((address_family, (address, port)))
        raise gen.Return(addrinfo)

@ -1,10 +1,27 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import errno
import socket
import time

from tornado.platform import interface
from tornado.util import errno_from_exception


def try_close(f):
    # Avoid issue #875 (race condition when using the file in another
    # thread).
    for i in range(10):
        try:
            f.close()
        except IOError:
            # Yield to another thread
            time.sleep(1e-3)
        else:
            break
    # Try a last time and let raise
    f.close()


class Waker(interface.Waker):

@ -45,7 +62,7 @@ class Waker(interface.Waker):
                break  # success
            except socket.error as detail:
                if (not hasattr(errno, 'WSAEADDRINUSE') or
                        detail[0] != errno.WSAEADDRINUSE):
                        errno_from_exception(detail) != errno.WSAEADDRINUSE):
                    # "Address already in use" is the only error
                    # I've seen on two WinXP Pro SP2 boxes, under
                    # Pythons 2.3.5 and 2.4.1.

@ -75,7 +92,7 @@ class Waker(interface.Waker):
    def wake(self):
        try:
            self.writer.send(b"x")
        except (IOError, socket.error):
        except (IOError, socket.error, ValueError):
            pass

    def consume(self):

@ -89,4 +106,4 @@ class Waker(interface.Waker):

    def close(self):
        self.reader.close()
        self.writer.close()
        try_close(self.writer)

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""EPoll-based IOLoop implementation for Linux systems."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import select

@ -21,7 +21,7 @@ for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function


def set_close_exec(fd):

@ -61,3 +61,7 @@ class Waker(object):
    def close(self):
        """Closes the waker's file descriptor(s)."""
        raise NotImplementedError()


def monotonic_time():
    raise NotImplementedError()

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import select

@ -16,12 +16,12 @@

"""Posix implementations of platform-specific functionality."""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import fcntl
import os

from tornado.platform import interface
from tornado.platform import common, interface


def set_close_exec(fd):

@ -53,7 +53,7 @@ class Waker(interface.Waker):
    def wake(self):
        try:
            self.writer.write(b"x")
        except IOError:
        except (IOError, ValueError):
            pass

    def consume(self):

@ -67,4 +67,4 @@ class Waker(interface.Waker):

    def close(self):
        self.reader.close()
        self.writer.close()
        common.try_close(self.writer)

@ -17,7 +17,7 @@

Used as a fallback for systems that don't support epoll or kqueue.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import select

@ -21,7 +21,7 @@ depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import datetime
import functools

@ -29,19 +29,18 @@ import numbers
import socket
import sys

import twisted.internet.abstract
from twisted.internet.defer import Deferred
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
    IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
from twisted.python import failure, log
from twisted.internet import error
import twisted.names.cache
import twisted.names.client
import twisted.names.hosts
import twisted.names.resolve
import twisted.internet.abstract  # type: ignore
from twisted.internet.defer import Deferred  # type: ignore
from twisted.internet.posixbase import PosixReactorBase  # type: ignore
from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor  # type: ignore
from twisted.python import failure, log  # type: ignore
from twisted.internet import error  # type: ignore
import twisted.names.cache  # type: ignore
import twisted.names.client  # type: ignore
import twisted.names.hosts  # type: ignore
import twisted.names.resolve  # type: ignore

from zope.interface import implementer
from zope.interface import implementer  # type: ignore

from tornado.concurrent import Future
from tornado.escape import utf8

@ -354,7 +353,7 @@ def install(io_loop=None):
    if not io_loop:
        io_loop = tornado.ioloop.IOLoop.current()
    reactor = TornadoReactor(io_loop)
    from twisted.internet.main import installReactor
    from twisted.internet.main import installReactor  # type: ignore
    installReactor(reactor)
    return reactor

@ -408,11 +407,14 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
    Not compatible with `tornado.process.Subprocess.set_exit_callback`
    because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
    with each other.

    See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
    installing alternative IOLoops.
    """
    def initialize(self, reactor=None, **kwargs):
        super(TwistedIOLoop, self).initialize(**kwargs)
        if reactor is None:
            import twisted.internet.reactor
            import twisted.internet.reactor  # type: ignore
            reactor = twisted.internet.reactor
        self.reactor = reactor
        self.fds = {}

@ -554,7 +556,10 @@ class TwistedResolver(Resolver):
        deferred = self.resolver.getHostByName(utf8(host))
        resolved = yield gen.Task(deferred.addBoth)
        if isinstance(resolved, failure.Failure):
            resolved.raiseException()
            try:
                resolved.raiseException()
            except twisted.names.error.DomainError as e:
                raise IOError(e)
        elif twisted.internet.abstract.isIPAddress(resolved):
            resolved_family = socket.AF_INET
        elif twisted.internet.abstract.isIPv6Address(resolved):

@ -569,8 +574,9 @@ class TwistedResolver(Resolver):
        ]
        raise gen.Return(result)


if hasattr(gen.convert_yielded, 'register'):
    @gen.convert_yielded.register(Deferred)
    @gen.convert_yielded.register(Deferred)  # type: ignore
    def _(d):
        f = Future()

@ -2,9 +2,9 @@
# for production use.


from __future__ import absolute_import, division, print_function, with_statement
import ctypes
import ctypes.wintypes
from __future__ import absolute_import, division, print_function
import ctypes  # type: ignore
import ctypes.wintypes  # type: ignore

# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation

@ -17,4 +17,4 @@ HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
    success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
    if not success:
        raise ctypes.GetLastError()
        raise ctypes.WinError()

@ -18,7 +18,7 @@
the server into multiple processes and managing subprocesses.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import errno
import os

@ -35,7 +35,7 @@ from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
from tornado.util import errno_from_exception
from tornado.util import errno_from_exception, PY3

try:
    import multiprocessing

@ -43,11 +43,8 @@ except ImportError:
    # Multiprocessing is not available on Google App Engine.
    multiprocessing = None

try:
    long  # py2
except NameError:
    long = int  # py3

if PY3:
    long = int

# Re-export this exception for convenience.
try:

@ -70,7 +67,7 @@ def cpu_count():
        pass
    try:
        return os.sysconf("SC_NPROCESSORS_CONF")
    except ValueError:
    except (AttributeError, ValueError):
        pass
    gen_log.error("Could not detect number of processors; assuming 1")
    return 1

@ -147,6 +144,7 @@ def fork_processes(num_processes, max_restarts=100):
        else:
            children[pid] = i
            return None

    for i in range(num_processes):
        id = start_child(i)
        if id is not None:

@ -204,13 +202,19 @@ class Subprocess(object):
      attribute of the resulting Subprocess a `.PipeIOStream`.
    * A new keyword argument ``io_loop`` may be used to pass in an IOLoop.

    The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
    ``wait_for_exit`` methods do not work on Windows. There is
    therefore no reason to use this class instead of
    ``subprocess.Popen`` on that platform.

    .. versionchanged:: 4.1
       The ``io_loop`` argument is deprecated.

    """
    STREAM = object()

    _initialized = False
    _waiting = {}
    _waiting = {}  # type: ignore

    def __init__(self, *args, **kwargs):
        self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
@ -351,6 +355,10 @@ class Subprocess(object):
        else:
            assert os.WIFEXITED(status)
            self.returncode = os.WEXITSTATUS(status)
            # We've taken over wait() duty from the subprocess.Popen
            # object. If we don't inform it of the process's return code,
            # it will log a warning at destruction in python 3.6+.
            self.proc.returncode = self.returncode
        if self._exit_callback:
            callback = self._exit_callback
            self._exit_callback = None

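A minimal sketch of the ``Subprocess.STREAM`` pattern whose Windows caveat is
documented above (the command is arbitrary; not applicable on Windows)::

    from tornado import gen
    from tornado.process import Subprocess

    @gen.coroutine
    def run():
        proc = Subprocess(["echo", "hello"], stdout=Subprocess.STREAM)
        out = yield proc.stdout.read_until_close()
        yield proc.wait_for_exit()
        print(out)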
@ -12,9 +12,17 @@
# License for the specific language governing permissions and limitations
# under the License.

from __future__ import absolute_import, division, print_function, with_statement
"""Asynchronous queues for coroutines.

__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
.. warning::

   Unlike the standard library's `queue` module, the classes defined here
   are *not* thread-safe. To use these queues from another thread,
   use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
   before calling any queue methods.
"""

from __future__ import absolute_import, division, print_function

import collections
import heapq

@ -23,6 +31,8 @@ from tornado import gen, ioloop
from tornado.concurrent import Future
from tornado.locks import Event

__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']


class QueueEmpty(Exception):
    """Raised by `.Queue.get_nowait` when the queue has no items."""

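A small sketch of the thread-safety rule stated in the new module docstring
(queue methods only on the IOLoop thread; the payload is arbitrary)::

    from tornado.ioloop import IOLoop
    from tornado.queues import Queue

    q = Queue()
    io_loop = IOLoop.current()  # captured on the IOLoop thread

    def from_another_thread(item):
        # Hop onto the IOLoop thread before touching the queue.
        io_loop.add_callback(q.put_nowait, item)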
@ -0,0 +1,625 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.

"""Flexible routing implementation.

Tornado routes HTTP requests to appropriate handlers using `Router`
class implementations. The `tornado.web.Application` class is a
`Router` implementation and may be used directly, or the classes in
this module may be used for additional flexibility. The `RuleRouter`
class can match on more criteria than `.Application`, or the `Router`
interface can be subclassed for maximum customization.

`Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
to provide additional routing capabilities. This also means that any
`Router` implementation can be used directly as a ``request_callback``
for `~.httpserver.HTTPServer` constructor.

`Router` subclass must implement a ``find_handler`` method to provide
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
request:

.. code-block:: python

    class CustomRouter(Router):
        def find_handler(self, request, **kwargs):
            # some routing logic providing a suitable HTTPMessageDelegate instance
            return MessageDelegate(request.connection)

    class MessageDelegate(HTTPMessageDelegate):
        def __init__(self, connection):
            self.connection = connection

        def finish(self):
            self.connection.write_headers(
                ResponseStartLine("HTTP/1.1", 200, "OK"),
                HTTPHeaders({"Content-Length": "2"}),
                b"OK")
            self.connection.finish()

    router = CustomRouter()
    server = HTTPServer(router)

The main responsibility of `Router` implementation is to provide a
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
that will handle this request. In the example above we can see that
routing is possible even without instantiating an `~.web.Application`.

For routing to `~.web.RequestHandler` implementations we need an
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
for a given request and `~.web.RequestHandler`.

Here is a simple example of how we can route to
`~.web.RequestHandler` subclasses by HTTP method:

.. code-block:: python

    resources = {}

    class GetResource(RequestHandler):
        def get(self, path):
            if path not in resources:
                raise HTTPError(404)

            self.finish(resources[path])

    class PostResource(RequestHandler):
        def post(self, path):
            resources[path] = self.request.body

    class HTTPMethodRouter(Router):
        def __init__(self, app):
            self.app = app

        def find_handler(self, request, **kwargs):
            handler = GetResource if request.method == "GET" else PostResource
            return self.app.get_handler_delegate(request, handler, path_args=[request.path])

    router = HTTPMethodRouter(Application())
    server = HTTPServer(router)

`ReversibleRouter` interface adds the ability to distinguish between
the routes and reverse them to the original urls using route's name
and additional arguments. `~.web.Application` is itself an
implementation of `ReversibleRouter` class.

`RuleRouter` and `ReversibleRuleRouter` are implementations of
`Router` and `ReversibleRouter` interfaces and can be used for
creating rule-based routing configurations.

Rules are instances of `Rule` class. They contain a `Matcher`, which
provides the logic for determining whether the rule is a match for a
particular request and a target, which can be one of the following.

1) An instance of `~.httputil.HTTPServerConnectionDelegate`:

.. code-block:: python

    router = RuleRouter([
        Rule(PathMatches("/handler"), ConnectionDelegate()),
        # ... more rules
    ])

    class ConnectionDelegate(HTTPServerConnectionDelegate):
        def start_request(self, server_conn, request_conn):
            return MessageDelegate(request_conn)

2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:

.. code-block:: python

    router = RuleRouter([
        Rule(PathMatches("/callable"), request_callable)
    ])

    def request_callable(request):
        request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
        request.finish()

3) Another `Router` instance:

.. code-block:: python

    router = RuleRouter([
        Rule(PathMatches("/router.*"), CustomRouter())
    ])

Of course a nested `RuleRouter` or a `~.web.Application` is allowed:

.. code-block:: python

    router = RuleRouter([
        Rule(HostMatches("example.com"), RuleRouter([
            Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)])),
        ]))
    ])

    server = HTTPServer(router)

In the example below `RuleRouter` is used to route between applications:

.. code-block:: python

    app1 = Application([
        (r"/app1/handler", Handler1),
        # other handlers ...
    ])

    app2 = Application([
        (r"/app2/handler", Handler2),
        # other handlers ...
    ])

    router = RuleRouter([
        Rule(PathMatches("/app1.*"), app1),
        Rule(PathMatches("/app2.*"), app2)
    ])

    server = HTTPServer(router)

For more information on application-level routing see docs for `~.web.Application`.

.. versionadded:: 4.5

"""

from __future__ import absolute_import, division, print_function

import re
from functools import partial

from tornado import httputil
from tornado.httpserver import _CallableAdapter
from tornado.escape import url_escape, url_unescape, utf8
from tornado.log import app_log
from tornado.util import basestring_type, import_object, re_unescape, unicode_type

try:
    import typing  # noqa
except ImportError:
    pass


class Router(httputil.HTTPServerConnectionDelegate):
    """Abstract router interface."""

    def find_handler(self, request, **kwargs):
        # type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
        """Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
        that can serve the request.
        Routing implementations may pass additional kwargs to extend the routing logic.

        :arg httputil.HTTPServerRequest request: current HTTP request.
        :arg kwargs: additional keyword arguments passed by routing implementation.
        :returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
            process the request.
        """
        raise NotImplementedError()

    def start_request(self, server_conn, request_conn):
        return _RoutingDelegate(self, server_conn, request_conn)


class ReversibleRouter(Router):
    """Abstract router interface for routers that can handle named routes
    and support reversing them to original urls.
    """

    def reverse_url(self, name, *args):
        """Returns url string for a given route name and arguments
        or ``None`` if no match is found.

        :arg str name: route name.
        :arg args: url parameters.
        :returns: parametrized url string for a given route name (or ``None``).
        """
        raise NotImplementedError()


class _RoutingDelegate(httputil.HTTPMessageDelegate):
    def __init__(self, router, server_conn, request_conn):
        self.server_conn = server_conn
        self.request_conn = request_conn
        self.delegate = None
        self.router = router  # type: Router

    def headers_received(self, start_line, headers):
        request = httputil.HTTPServerRequest(
            connection=self.request_conn,
            server_connection=self.server_conn,
            start_line=start_line, headers=headers)

        self.delegate = self.router.find_handler(request)
        return self.delegate.headers_received(start_line, headers)

    def data_received(self, chunk):
        return self.delegate.data_received(chunk)

    def finish(self):
        self.delegate.finish()

    def on_connection_close(self):
        self.delegate.on_connection_close()


class RuleRouter(Router):
    """Rule-based router implementation."""

    def __init__(self, rules=None):
        """Constructs a router from an ordered list of rules::

            RuleRouter([
                Rule(PathMatches("/handler"), Target),
                # ... more rules
            ])

        You can also omit explicit `Rule` constructor and use tuples of arguments::

            RuleRouter([
                (PathMatches("/handler"), Target),
            ])

        `PathMatches` is a default matcher, so the example above can be simplified::

            RuleRouter([
                ("/handler", Target),
            ])

        In the examples above, ``Target`` can be a nested `Router` instance, an instance of
        `~.httputil.HTTPServerConnectionDelegate` or an old-style callable, accepting a request argument.

        :arg rules: a list of `Rule` instances or tuples of `Rule`
            constructor arguments.
        """
        self.rules = []  # type: typing.List[Rule]
        if rules:
            self.add_rules(rules)

    def add_rules(self, rules):
        """Appends new rules to the router.

        :arg rules: a list of Rule instances (or tuples of arguments, which are
            passed to Rule constructor).
        """
        for rule in rules:
            if isinstance(rule, (tuple, list)):
                assert len(rule) in (2, 3, 4)
                if isinstance(rule[0], basestring_type):
                    rule = Rule(PathMatches(rule[0]), *rule[1:])
                else:
                    rule = Rule(*rule)

            self.rules.append(self.process_rule(rule))

    def process_rule(self, rule):
        """Override this method for additional preprocessing of each rule.

        :arg Rule rule: a rule to be processed.
        :returns: the same or modified Rule instance.
        """
        return rule

    def find_handler(self, request, **kwargs):
        for rule in self.rules:
            target_params = rule.matcher.match(request)
            if target_params is not None:
                if rule.target_kwargs:
                    target_params['target_kwargs'] = rule.target_kwargs

                delegate = self.get_target_delegate(
                    rule.target, request, **target_params)

                if delegate is not None:
                    return delegate

        return None

    def get_target_delegate(self, target, request, **target_params):
        """Returns an instance of `~.httputil.HTTPMessageDelegate` for a
        Rule's target. This method is called by `~.find_handler` and can be
        extended to provide additional target types.

        :arg target: a Rule's target.
        :arg httputil.HTTPServerRequest request: current request.
        :arg target_params: additional parameters that can be useful
            for `~.httputil.HTTPMessageDelegate` creation.
        """
        if isinstance(target, Router):
            return target.find_handler(request, **target_params)

        elif isinstance(target, httputil.HTTPServerConnectionDelegate):
            return target.start_request(request.server_connection, request.connection)

        elif callable(target):
            return _CallableAdapter(
                partial(target, **target_params), request.connection
            )

        return None


class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
    """A rule-based router that implements ``reverse_url`` method.

    Each rule added to this router may have a ``name`` attribute that can be
    used to reconstruct an original uri. The actual reconstruction takes place
    in a rule's matcher (see `Matcher.reverse`).
    """

    def __init__(self, rules=None):
        self.named_rules = {}  # type: typing.Dict[str]
        super(ReversibleRuleRouter, self).__init__(rules)

    def process_rule(self, rule):
        rule = super(ReversibleRuleRouter, self).process_rule(rule)

        if rule.name:
            if rule.name in self.named_rules:
                app_log.warning(
                    "Multiple handlers named %s; replacing previous value",
                    rule.name)
            self.named_rules[rule.name] = rule

        return rule

    def reverse_url(self, name, *args):
        if name in self.named_rules:
            return self.named_rules[name].matcher.reverse(*args)

        for rule in self.rules:
            if isinstance(rule.target, ReversibleRouter):
                reversed_url = rule.target.reverse_url(name, *args)
                if reversed_url is not None:
                    return reversed_url

        return None


class Rule(object):
    """A routing rule."""

    def __init__(self, matcher, target, target_kwargs=None, name=None):
        """Constructs a Rule instance.

        :arg Matcher matcher: a `Matcher` instance used for determining
            whether the rule should be considered a match for a specific
            request.
        :arg target: a Rule's target (typically a ``RequestHandler`` or
            `~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
            depending on routing implementation).
        :arg dict target_kwargs: a dict of parameters that can be useful
            at the moment of target instantiation (for example, ``status_code``
            for a ``RequestHandler`` subclass). They end up in
            ``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
            method.
        :arg str name: the name of the rule that can be used to find it
            in `ReversibleRouter.reverse_url` implementation.
        """
        if isinstance(target, str):
            # import the Module and instantiate the class
            # Must be a fully qualified name (module.ClassName)
            target = import_object(target)

        self.matcher = matcher  # type: Matcher
        self.target = target
        self.target_kwargs = target_kwargs if target_kwargs else {}
        self.name = name

    def reverse(self, *args):
        return self.matcher.reverse(*args)

    def __repr__(self):
        return '%s(%r, %s, kwargs=%r, name=%r)' % \
            (self.__class__.__name__, self.matcher,
             self.target, self.target_kwargs, self.name)


class Matcher(object):
    """Represents a matcher for request features."""

    def match(self, request):
        """Matches current instance against the request.

        :arg httputil.HTTPServerRequest request: current HTTP request
        :returns: a dict of parameters to be passed to the target handler
            (for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
            can be passed for proper `~.web.RequestHandler` instantiation).
            An empty dict is a valid (and common) return value to indicate a match
            when the argument-passing features are not used.
            ``None`` must be returned to indicate that there is no match."""
        raise NotImplementedError()

    def reverse(self, *args):
        """Reconstructs full url from matcher instance and additional arguments."""
        return None


class AnyMatches(Matcher):
    """Matches any request."""

    def match(self, request):
        return {}


class HostMatches(Matcher):
    """Matches requests from hosts specified by ``host_pattern`` regex."""

    def __init__(self, host_pattern):
        if isinstance(host_pattern, basestring_type):
            if not host_pattern.endswith("$"):
                host_pattern += "$"
            self.host_pattern = re.compile(host_pattern)
        else:
            self.host_pattern = host_pattern

    def match(self, request):
        if self.host_pattern.match(request.host_name):
            return {}

        return None


class DefaultHostMatches(Matcher):
    """Matches requests from host that is equal to application's default_host.
    Always returns no match if ``X-Real-Ip`` header is present.
    """

    def __init__(self, application, host_pattern):
        self.application = application
        self.host_pattern = host_pattern

    def match(self, request):
        # Look for default host if not behind load balancer (for debugging)
        if "X-Real-Ip" not in request.headers:
            if self.host_pattern.match(self.application.default_host):
                return {}
        return None


class PathMatches(Matcher):
    """Matches requests with paths specified by ``path_pattern`` regex."""

    def __init__(self, path_pattern):
        if isinstance(path_pattern, basestring_type):
            if not path_pattern.endswith('$'):
                path_pattern += '$'
            self.regex = re.compile(path_pattern)
        else:
            self.regex = path_pattern

        assert len(self.regex.groupindex) in (0, self.regex.groups), \
            ("groups in url regexes must either be all named or all "
             "positional: %r" % self.regex.pattern)

        self._path, self._group_count = self._find_groups()

    def match(self, request):
        match = self.regex.match(request.path)
        if match is None:
            return None
        if not self.regex.groups:
            return {}

        path_args, path_kwargs = [], {}

        # Pass matched groups to the handler. Since
        # match.groups() includes both named and
        # unnamed groups, we want to use either groups
        # or groupdict but not both.
        if self.regex.groupindex:
            path_kwargs = dict(
                (str(k), _unquote_or_none(v))
                for (k, v) in match.groupdict().items())
        else:
            path_args = [_unquote_or_none(s) for s in match.groups()]

        return dict(path_args=path_args, path_kwargs=path_kwargs)

    def reverse(self, *args):
        if self._path is None:
            raise ValueError("Cannot reverse url regex " + self.regex.pattern)
        assert len(args) == self._group_count, "required number of arguments " \
            "not found"
        if not len(args):
            return self._path
        converted_args = []
        for a in args:
            if not isinstance(a, (unicode_type, bytes)):
                a = str(a)
            converted_args.append(url_escape(utf8(a), plus=False))
        return self._path % tuple(converted_args)

    def _find_groups(self):
        """Returns a tuple (reverse string, group count) for a url.

        For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
        would return ('/%s/%s/', 2).
        """
        pattern = self.regex.pattern
        if pattern.startswith('^'):
            pattern = pattern[1:]
        if pattern.endswith('$'):
            pattern = pattern[:-1]

        if self.regex.groups != pattern.count('('):
            # The pattern is too complicated for our simplistic matching,
            # so we can't support reversing it.
            return None, None

        pieces = []
        for fragment in pattern.split('('):
            if ')' in fragment:
                paren_loc = fragment.index(')')
                if paren_loc >= 0:
                    pieces.append('%s' + fragment[paren_loc + 1:])
            else:
                try:
                    unescaped_fragment = re_unescape(fragment)
                except ValueError as exc:
                    # If we can't unescape part of it, we can't
                    # reverse this url.
                    return (None, None)
                pieces.append(unescaped_fragment)

        return ''.join(pieces), self.regex.groups


class URLSpec(Rule):
    """Specifies mappings between URLs and handlers.

    .. versionchanged: 4.5
       `URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
       backwards compatibility.
    """
    def __init__(self, pattern, handler, kwargs=None, name=None):
        """Parameters:

        * ``pattern``: Regular expression to be matched. Any capturing
          groups in the regex will be passed in to the handler's
          get/post/etc methods as arguments (by keyword if named, by
          position if unnamed. Named and unnamed capturing groups
          may not be mixed in the same rule).

        * ``handler``: `~.web.RequestHandler` subclass to be invoked.

        * ``kwargs`` (optional): A dictionary of additional arguments
          to be passed to the handler's constructor.

        * ``name`` (optional): A name for this handler. Used by
          `~.web.Application.reverse_url`.

        """
        super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)

        self.regex = self.matcher.regex
        self.handler_class = self.target
        self.kwargs = kwargs

    def __repr__(self):
        return '%s(%r, %s, kwargs=%r, name=%r)' % \
            (self.__class__.__name__, self.regex.pattern,
             self.handler_class, self.kwargs, self.name)


def _unquote_or_none(s):
    """None-safe wrapper around url_unescape to handle unmatched optional
    groups correctly.

    Note that args are passed as bytes so the handler can decide what
    encoding to use.
    """
    if s is None:
        return s
    return url_unescape(s, encoding=None, plus=False)

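To make the extension points above concrete, a minimal sketch of a custom
`Matcher` plugged into a `ReversibleRuleRouter`, together with named-rule
reversing (the class, header, and route names are arbitrary)::

    from tornado.routing import Matcher, PathMatches, ReversibleRuleRouter, Rule
    from tornado.web import Application, RequestHandler

    class HeaderMatches(Matcher):
        """Matches when a request carries the given header value."""
        def __init__(self, name, value):
            self.name = name
            self.value = value

        def match(self, request):
            # An empty dict signals a match; None signals no match.
            return {} if request.headers.get(self.name) == self.value else None

    class PageHandler(RequestHandler):
        def get(self, page_id):
            self.write("page %s" % page_id)

    app = Application([(r"/page/([0-9]+)", PageHandler)])
    router = ReversibleRuleRouter([
        Rule(HeaderMatches("X-Api-Version", "2"), app),
        Rule(PathMatches(r"/page/([0-9]+)"), app, name="page"),
    ])
    assert router.reverse_url("page", 42) == "/page/42"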
@ -1,5 +1,5 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

from tornado.escape import utf8, _unicode
from tornado import gen

@ -11,6 +11,7 @@ from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
from tornado.util import PY3

import base64
import collections
@ -22,10 +23,10 @@ import sys
from io import BytesIO


try:
    import urlparse  # py2
except ImportError:
    import urllib.parse as urlparse  # py3
if PY3:
    import urllib.parse as urlparse
else:
    import urlparse

try:
    import ssl

@ -126,7 +127,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
            timeout_handle = self.io_loop.add_timeout(
                self.io_loop.time() + min(request.connect_timeout,
                                          request.request_timeout),
                functools.partial(self._on_timeout, key))
                functools.partial(self._on_timeout, key, "in request queue"))
        else:
            timeout_handle = None
        self.waiting[key] = (request, callback, timeout_handle)

@ -167,11 +168,20 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
            self.io_loop.remove_timeout(timeout_handle)
        del self.waiting[key]

    def _on_timeout(self, key):
    def _on_timeout(self, key, info=None):
        """Timeout callback of request.

        Construct a timeout HTTPResponse when a timeout occurs.

        :arg object key: A simple object to mark the request.
        :arg string info: More detailed timeout information.
        """
        request, callback, timeout_handle = self.waiting[key]
        self.queue.remove((key, request, callback))

        error_message = "Timeout {0}".format(info) if info else "Timeout"
        timeout_response = HTTPResponse(
            request, 599, error=HTTPError(599, "Timeout"),
            request, 599, error=HTTPError(599, error_message),
            request_time=self.io_loop.time() - request.start_time)
        self.io_loop.add_callback(callback, timeout_response)
        del self.waiting[key]
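With the ``info`` argument threaded through, timeout errors now say where the
request died; a minimal sketch of observing this (the URL and timeout value
are arbitrary)::

    from tornado import gen
    from tornado.httpclient import AsyncHTTPClient

    @gen.coroutine
    def fetch_with_timeout():
        client = AsyncHTTPClient()
        response = yield client.fetch("http://example.com/slow",
                                      request_timeout=0.1, raise_error=False)
        # e.g. "HTTP 599: Timeout during request" or "... while connecting"
        print(response.error)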
@ -229,7 +239,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
            if timeout:
                self._timeout = self.io_loop.add_timeout(
                    self.start_time + timeout,
                    stack_context.wrap(self._on_timeout))
                    stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
            self.tcp_client.connect(host, port, af=af,
                                    ssl_options=ssl_options,
                                    max_buffer_size=self.max_buffer_size,

@ -284,10 +294,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
            return ssl_options
        return None

    def _on_timeout(self):
    def _on_timeout(self, info=None):
        """Timeout callback of _HTTPConnection instance.

        Raise a timeout HTTPError when a timeout occurs.

        :arg string info: More detailed timeout information.
        """
        self._timeout = None
        error_message = "Timeout {0}".format(info) if info else "Timeout"
        if self.final_callback is not None:
            raise HTTPError(599, "Timeout")
            raise HTTPError(599, error_message)

    def _remove_timeout(self):
        if self._timeout is not None:

@ -307,13 +324,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
        if self.request.request_timeout:
            self._timeout = self.io_loop.add_timeout(
                self.start_time + self.request.request_timeout,
                stack_context.wrap(self._on_timeout))
                stack_context.wrap(functools.partial(self._on_timeout, "during request")))
        if (self.request.method not in self._SUPPORTED_METHODS and
                not self.request.allow_nonstandard_methods):
            raise KeyError("unknown method %s" % self.request.method)
        for key in ('network_interface',
                    'proxy_host', 'proxy_port',
                    'proxy_username', 'proxy_password'):
                    'proxy_username', 'proxy_password',
                    'proxy_auth_mode'):
            if getattr(self.request, key, None):
                raise NotImplementedError('%s not supported' % key)
        if "Connection" not in self.request.headers:

@ -481,7 +499,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
    def _should_follow_redirect(self):
        return (self.request.follow_redirects and
                self.request.max_redirects > 0 and
                self.code in (301, 302, 303, 307))
                self.code in (301, 302, 303, 307, 308))

    def finish(self):
        data = b''.join(self.chunks)

@ -67,7 +67,7 @@ Here are a few rules of thumb for when it's necessary:
  block that references your `StackContext`.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import sys
import threading

@ -82,6 +82,8 @@ class StackContextInconsistentError(Exception):
class _State(threading.local):
    def __init__(self):
        self.contexts = (tuple(), None)


_state = _State()

@ -16,7 +16,7 @@

"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import functools
import socket

@ -155,16 +155,30 @@ class TCPClient(object):

    @gen.coroutine
    def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
                max_buffer_size=None):
                max_buffer_size=None, source_ip=None, source_port=None):
        """Connect to the given host and port.

        Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
        ``ssl_options`` is not None).

        Using the ``source_ip`` kwarg, one can specify the source
        IP address to use when establishing the connection.
        In case the user needs to resolve and
        use a specific interface, it has to be handled outside
        of Tornado as this depends very much on the platform.

        Similarly, when the user requires a certain source port, it can
        be specified using the ``source_port`` arg.

        .. versionchanged:: 4.5
           Added the ``source_ip`` and ``source_port`` arguments.
        """
        addrinfo = yield self.resolver.resolve(host, port, af)
        connector = _Connector(
            addrinfo, self.io_loop,
            functools.partial(self._create_stream, max_buffer_size))
            functools.partial(self._create_stream, max_buffer_size,
                              source_ip=source_ip, source_port=source_port)
        )
        af, addr, stream = yield connector.start()
        # TODO: For better performance we could cache the (af, addr)
        # information here and re-use it on subsequent connections to
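A minimal sketch of the new source-address options described above (the host,
port, and source IP are arbitrary; binding requires a local address you
actually own)::

    from tornado import gen
    from tornado.tcpclient import TCPClient

    @gen.coroutine
    def connect_from():
        stream = yield TCPClient().connect(
            "example.com", 80,
            source_ip="192.0.2.10",   # bind the outgoing socket to this IP
            source_port=54321)        # and to this local port
        stream.close()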
@ -174,10 +188,35 @@ class TCPClient(object):
                                    server_hostname=host)
        raise gen.Return(stream)

    def _create_stream(self, max_buffer_size, af, addr):
    def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
                       source_port=None):
        # Always connect in plaintext; we'll convert to ssl if necessary
        # after one connection has completed.
        stream = IOStream(socket.socket(af),
                          io_loop=self.io_loop,
                          max_buffer_size=max_buffer_size)
        return stream.connect(addr)
        source_port_bind = source_port if isinstance(source_port, int) else 0
        source_ip_bind = source_ip
        if source_port_bind and not source_ip:
            # User required a specific port, but did not specify
            # a certain source IP, will bind to the default loopback.
            source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
            # Trying to use the same address family as the requested af socket:
            # - 127.0.0.1 for IPv4
            # - ::1 for IPv6
        socket_obj = socket.socket(af)
        if source_port_bind or source_ip_bind:
            # If the user requires binding also to a specific IP/port.
            try:
                socket_obj.bind((source_ip_bind, source_port_bind))
            except socket.error:
                socket_obj.close()
                # Fail loudly if unable to use the IP/port.
                raise
        try:
            stream = IOStream(socket_obj,
                              io_loop=self.io_loop,
                              max_buffer_size=max_buffer_size)
        except socket.error as e:
            fu = Future()
            fu.set_exception(e)
            return fu
        else:
            return stream.connect(addr)

@ -15,12 +15,13 @@
# under the License.

"""A non-blocking, single-threaded TCP server."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import errno
import os
import socket

from tornado import gen
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream

@ -39,7 +40,21 @@ class TCPServer(object):
    r"""A non-blocking, single-threaded TCP server.

    To use `TCPServer`, define a subclass which overrides the `handle_stream`
    method.
    method. For example, a simple echo server could be defined like this::

      from tornado.tcpserver import TCPServer
      from tornado.iostream import StreamClosedError
      from tornado import gen

      class EchoServer(TCPServer):
          @gen.coroutine
          def handle_stream(self, stream, address):
              while True:
                  try:
                      data = yield stream.read_until(b"\n")
                      yield stream.write(data)
                  except StreamClosedError:
                      break

    To make this server serve SSL traffic, send the ``ssl_options`` keyword
    argument with an `ssl.SSLContext` object. For compatibility with older
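Continuing that docstring example, a minimal sketch of actually running the
server (single-process; the port number is arbitrary)::

    from tornado.ioloop import IOLoop

    server = EchoServer()
    server.listen(8888)
    IOLoop.current().start()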
@ -95,6 +110,7 @@ class TCPServer(object):
        self._sockets = {}  # fd -> socket object
        self._pending_sockets = []
        self._started = False
        self._stopped = False
        self.max_buffer_size = max_buffer_size
        self.read_chunk_size = read_chunk_size

@ -147,7 +163,8 @@ class TCPServer(object):
        """Singular version of `add_sockets`. Takes a single socket object."""
        self.add_sockets([socket])

    def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
    def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
             reuse_port=False):
        """Binds this server to the given port on the given address.

        To start the server, call `start`. If you want to run this server

@ -162,13 +179,17 @@ class TCPServer(object):
        both will be used if available.

        The ``backlog`` argument has the same meaning as for
        `socket.listen <socket.socket.listen>`.
        `socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
        has the same meaning as for `.bind_sockets`.

        This method may be called multiple times prior to `start` to listen
        on multiple ports or interfaces.

        .. versionchanged:: 4.4
           Added the ``reuse_port`` argument.
        """
        sockets = bind_sockets(port, address=address, family=family,
                               backlog=backlog)
                               backlog=backlog, reuse_port=reuse_port)
        if self._started:
            self.add_sockets(sockets)
        else:

@ -208,7 +229,11 @@ class TCPServer(object):
        Requests currently in progress may still continue after the
        server is stopped.
        """
        if self._stopped:
            return
        self._stopped = True
        for fd, sock in self._sockets.items():
            assert sock.fileno() == fd
            self.io_loop.remove_handler(fd)
            sock.close()

@ -266,8 +291,10 @@ class TCPServer(object):
                stream = IOStream(connection, io_loop=self.io_loop,
                                  max_buffer_size=self.max_buffer_size,
                                  read_chunk_size=self.read_chunk_size)

                future = self.handle_stream(stream, address)
                if future is not None:
                    self.io_loop.add_future(future, lambda f: f.result())
                    self.io_loop.add_future(gen.convert_yielded(future),
                                            lambda f: f.result())
            except Exception:
                app_log.error("Error in connection callback", exc_info=True)

@ -19,13 +19,13 @@
Basic usage looks like::

    t = template.Template("<html>{{ myvalue }}</html>")
    print t.generate(myvalue="XXX")
    print(t.generate(myvalue="XXX"))

`Loader` is a class that loads templates from a root directory and caches
the compiled templates::

    loader = template.Loader("/home/btaylor")
    print loader.load("test.html").generate(myvalue="XXX")
    print(loader.load("test.html").generate(myvalue="XXX"))

We compile all templates to raw Python. Error-reporting is currently... uh,
interesting. Syntax for the templates::

@ -94,12 +94,15 @@ Syntax Reference
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
The contents may be any python expression, which will be escaped according
to the current autoescape setting and inserted into the output. Other
template directives use ``{% %}``. These tags may be escaped as ``{{!``
and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
template directives use ``{% %}``.

To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.

These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.


``{% apply *function* %}...{% end %}``
    Applies a function to the output of all template code between ``apply``
    and ``end``::
@ -193,7 +196,7 @@ with ``{# ... #}``.
    `filter_whitespace` for available options. New in Tornado 4.3.
"""

from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function

import datetime
import linecache

@ -204,12 +207,12 @@ import threading

from tornado import escape
from tornado.log import app_log
from tornado.util import ObjectDict, exec_in, unicode_type
from tornado.util import ObjectDict, exec_in, unicode_type, PY3

try:
    from cStringIO import StringIO  # py2
except ImportError:
    from io import StringIO  # py3
if PY3:
    from io import StringIO
else:
    from cStringIO import StringIO

_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()

@ -665,7 +668,7 @@ class ParseError(Exception):
    .. versionchanged:: 4.3
       Added ``filename`` and ``lineno`` attributes.
    """
    def __init__(self, message, filename, lineno):
    def __init__(self, message, filename=None, lineno=0):
        self.message = message
        # The names "filename" and "lineno" are chosen for consistency
        # with python SyntaxError.
Some files were not shown because too many files have changed in this diff.