mirror of https://github.com/tp4a/teleport
web端做了很大改动,尚未完成。
parent
2d0ce5da20
commit
51b143c828
|
@ -1,8 +1,8 @@
|
||||||
# mako/__init__.py
|
# mako/__init__.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
||||||
|
|
||||||
__version__ = '1.0.3'
|
__version__ = '1.0.6'
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/_ast_util.py
|
# mako/_ast_util.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/ast.py
|
# mako/ast.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/cache.py
|
# mako/cache.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/cmd.py
|
# mako/cmd.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/codegen.py
|
# mako/codegen.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -5,6 +5,7 @@ py3k = sys.version_info >= (3, 0)
|
||||||
py33 = sys.version_info >= (3, 3)
|
py33 = sys.version_info >= (3, 3)
|
||||||
py2k = sys.version_info < (3,)
|
py2k = sys.version_info < (3,)
|
||||||
py26 = sys.version_info >= (2, 6)
|
py26 = sys.version_info >= (2, 6)
|
||||||
|
py27 = sys.version_info >= (2, 7)
|
||||||
jython = sys.platform.startswith('java')
|
jython = sys.platform.startswith('java')
|
||||||
win32 = sys.platform.startswith('win')
|
win32 = sys.platform.startswith('win')
|
||||||
pypy = hasattr(sys, 'pypy_version_info')
|
pypy = hasattr(sys, 'pypy_version_info')
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/exceptions.py
|
# mako/exceptions.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# ext/autohandler.py
|
# ext/autohandler.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# ext/babelplugin.py
|
# ext/babelplugin.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# ext/preprocessors.py
|
# ext/preprocessors.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# ext/pygmentplugin.py
|
# ext/pygmentplugin.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# ext/turbogears.py
|
# ext/turbogears.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/filters.py
|
# mako/filters.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/lexer.py
|
# mako/lexer.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
@ -95,31 +95,37 @@ class Lexer(object):
|
||||||
# (match and "TRUE" or "FALSE")
|
# (match and "TRUE" or "FALSE")
|
||||||
return match
|
return match
|
||||||
|
|
||||||
def parse_until_text(self, *text):
|
def parse_until_text(self, watch_nesting, *text):
|
||||||
startpos = self.match_position
|
startpos = self.match_position
|
||||||
text_re = r'|'.join(text)
|
text_re = r'|'.join(text)
|
||||||
brace_level = 0
|
brace_level = 0
|
||||||
|
paren_level = 0
|
||||||
|
bracket_level = 0
|
||||||
while True:
|
while True:
|
||||||
match = self.match(r'#.*\n')
|
match = self.match(r'#.*\n')
|
||||||
if match:
|
if match:
|
||||||
continue
|
continue
|
||||||
match = self.match(r'(\"\"\"|\'\'\'|\"|\')((?<!\\)\\\1|.)*?\1',
|
match = self.match(r'(\"\"\"|\'\'\'|\"|\')[^\\]*?(\\.[^\\]*?)*\1',
|
||||||
re.S)
|
re.S)
|
||||||
if match:
|
if match:
|
||||||
continue
|
continue
|
||||||
match = self.match(r'(%s)' % text_re)
|
match = self.match(r'(%s)' % text_re)
|
||||||
if match:
|
if match and not (watch_nesting
|
||||||
if match.group(1) == '}' and brace_level > 0:
|
and (brace_level > 0 or paren_level > 0
|
||||||
brace_level -= 1
|
or bracket_level > 0)):
|
||||||
continue
|
|
||||||
return \
|
return \
|
||||||
self.text[startpos:
|
self.text[startpos:
|
||||||
self.match_position - len(match.group(1))],\
|
self.match_position - len(match.group(1))],\
|
||||||
match.group(1)
|
match.group(1)
|
||||||
match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
|
elif not match:
|
||||||
|
match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
|
||||||
if match:
|
if match:
|
||||||
brace_level += match.group(1).count('{')
|
brace_level += match.group(1).count('{')
|
||||||
brace_level -= match.group(1).count('}')
|
brace_level -= match.group(1).count('}')
|
||||||
|
paren_level += match.group(1).count('(')
|
||||||
|
paren_level -= match.group(1).count(')')
|
||||||
|
bracket_level += match.group(1).count('[')
|
||||||
|
bracket_level -= match.group(1).count(']')
|
||||||
continue
|
continue
|
||||||
raise exceptions.SyntaxException(
|
raise exceptions.SyntaxException(
|
||||||
"Expected: %s" %
|
"Expected: %s" %
|
||||||
|
@ -368,7 +374,7 @@ class Lexer(object):
|
||||||
match = self.match(r"<%(!)?")
|
match = self.match(r"<%(!)?")
|
||||||
if match:
|
if match:
|
||||||
line, pos = self.matched_lineno, self.matched_charpos
|
line, pos = self.matched_lineno, self.matched_charpos
|
||||||
text, end = self.parse_until_text(r'%>')
|
text, end = self.parse_until_text(False, r'%>')
|
||||||
# the trailing newline helps
|
# the trailing newline helps
|
||||||
# compiler.parse() not complain about indentation
|
# compiler.parse() not complain about indentation
|
||||||
text = adjust_whitespace(text) + "\n"
|
text = adjust_whitespace(text) + "\n"
|
||||||
|
@ -384,9 +390,9 @@ class Lexer(object):
|
||||||
match = self.match(r"\${")
|
match = self.match(r"\${")
|
||||||
if match:
|
if match:
|
||||||
line, pos = self.matched_lineno, self.matched_charpos
|
line, pos = self.matched_lineno, self.matched_charpos
|
||||||
text, end = self.parse_until_text(r'\|', r'}')
|
text, end = self.parse_until_text(True, r'\|', r'}')
|
||||||
if end == '|':
|
if end == '|':
|
||||||
escapes, end = self.parse_until_text(r'}')
|
escapes, end = self.parse_until_text(True, r'}')
|
||||||
else:
|
else:
|
||||||
escapes = ""
|
escapes = ""
|
||||||
text = text.replace('\r\n', '\n')
|
text = text.replace('\r\n', '\n')
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/lookup.py
|
# mako/lookup.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
@ -96,7 +96,7 @@ class TemplateLookup(TemplateCollection):
|
||||||
.. sourcecode:: python
|
.. sourcecode:: python
|
||||||
|
|
||||||
lookup = TemplateLookup(["/path/to/templates"])
|
lookup = TemplateLookup(["/path/to/templates"])
|
||||||
some_template = lookup.get_template("/admin_index.mako")
|
some_template = lookup.get_template("/index.html")
|
||||||
|
|
||||||
The :class:`.TemplateLookup` can also be given :class:`.Template` objects
|
The :class:`.TemplateLookup` can also be given :class:`.Template` objects
|
||||||
programatically using :meth:`.put_string` or :meth:`.put_template`:
|
programatically using :meth:`.put_string` or :meth:`.put_template`:
|
||||||
|
@ -180,7 +180,8 @@ class TemplateLookup(TemplateCollection):
|
||||||
enable_loop=True,
|
enable_loop=True,
|
||||||
input_encoding=None,
|
input_encoding=None,
|
||||||
preprocessor=None,
|
preprocessor=None,
|
||||||
lexer_cls=None):
|
lexer_cls=None,
|
||||||
|
include_error_handler=None):
|
||||||
|
|
||||||
self.directories = [posixpath.normpath(d) for d in
|
self.directories = [posixpath.normpath(d) for d in
|
||||||
util.to_list(directories, ())
|
util.to_list(directories, ())
|
||||||
|
@ -203,6 +204,7 @@ class TemplateLookup(TemplateCollection):
|
||||||
self.template_args = {
|
self.template_args = {
|
||||||
'format_exceptions': format_exceptions,
|
'format_exceptions': format_exceptions,
|
||||||
'error_handler': error_handler,
|
'error_handler': error_handler,
|
||||||
|
'include_error_handler': include_error_handler,
|
||||||
'disable_unicode': disable_unicode,
|
'disable_unicode': disable_unicode,
|
||||||
'bytestring_passthrough': bytestring_passthrough,
|
'bytestring_passthrough': bytestring_passthrough,
|
||||||
'output_encoding': output_encoding,
|
'output_encoding': output_encoding,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/parsetree.py
|
# mako/parsetree.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/pygen.py
|
# mako/pygen.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/pyparser.py
|
# mako/pyparser.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/runtime.py
|
# mako/runtime.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
@ -749,7 +749,16 @@ def _include_file(context, uri, calling_uri, **kwargs):
|
||||||
(callable_, ctx) = _populate_self_namespace(
|
(callable_, ctx) = _populate_self_namespace(
|
||||||
context._clean_inheritance_tokens(),
|
context._clean_inheritance_tokens(),
|
||||||
template)
|
template)
|
||||||
callable_(ctx, **_kwargs_for_include(callable_, context._data, **kwargs))
|
kwargs = _kwargs_for_include(callable_, context._data, **kwargs)
|
||||||
|
if template.include_error_handler:
|
||||||
|
try:
|
||||||
|
callable_(ctx, **kwargs)
|
||||||
|
except Exception:
|
||||||
|
result = template.include_error_handler(ctx, compat.exception_as())
|
||||||
|
if not result:
|
||||||
|
compat.reraise(*sys.exc_info())
|
||||||
|
else:
|
||||||
|
callable_(ctx, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
def _inherit_from(context, uri, calling_uri):
|
def _inherit_from(context, uri, calling_uri):
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/template.py
|
# mako/template.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
@ -109,6 +109,11 @@ class Template(object):
|
||||||
completes. Is used to provide custom error-rendering
|
completes. Is used to provide custom error-rendering
|
||||||
functions.
|
functions.
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
:paramref:`.Template.include_error_handler` - include-specific
|
||||||
|
error handler function
|
||||||
|
|
||||||
:param format_exceptions: if ``True``, exceptions which occur during
|
:param format_exceptions: if ``True``, exceptions which occur during
|
||||||
the render phase of this template will be caught and
|
the render phase of this template will be caught and
|
||||||
formatted into an HTML error page, which then becomes the
|
formatted into an HTML error page, which then becomes the
|
||||||
|
@ -129,6 +134,16 @@ class Template(object):
|
||||||
import will not appear as the first executed statement in the generated
|
import will not appear as the first executed statement in the generated
|
||||||
code and will therefore not have the desired effect.
|
code and will therefore not have the desired effect.
|
||||||
|
|
||||||
|
:param include_error_handler: An error handler that runs when this template
|
||||||
|
is included within another one via the ``<%include>`` tag, and raises an
|
||||||
|
error. Compare to the :paramref:`.Template.error_handler` option.
|
||||||
|
|
||||||
|
.. versionadded:: 1.0.6
|
||||||
|
|
||||||
|
.. seealso::
|
||||||
|
|
||||||
|
:paramref:`.Template.error_handler` - top-level error handler function
|
||||||
|
|
||||||
:param input_encoding: Encoding of the template's source code. Can
|
:param input_encoding: Encoding of the template's source code. Can
|
||||||
be used in lieu of the coding comment. See
|
be used in lieu of the coding comment. See
|
||||||
:ref:`usage_unicode` as well as :ref:`unicode_toplevel` for
|
:ref:`usage_unicode` as well as :ref:`unicode_toplevel` for
|
||||||
|
@ -171,7 +186,7 @@ class Template(object):
|
||||||
|
|
||||||
from mako.template import Template
|
from mako.template import Template
|
||||||
mytemplate = Template(
|
mytemplate = Template(
|
||||||
filename="admin_index.mako",
|
filename="index.html",
|
||||||
module_directory="/path/to/modules",
|
module_directory="/path/to/modules",
|
||||||
module_writer=module_writer
|
module_writer=module_writer
|
||||||
)
|
)
|
||||||
|
@ -243,7 +258,8 @@ class Template(object):
|
||||||
future_imports=None,
|
future_imports=None,
|
||||||
enable_loop=True,
|
enable_loop=True,
|
||||||
preprocessor=None,
|
preprocessor=None,
|
||||||
lexer_cls=None):
|
lexer_cls=None,
|
||||||
|
include_error_handler=None):
|
||||||
if uri:
|
if uri:
|
||||||
self.module_id = re.sub(r'\W', "_", uri)
|
self.module_id = re.sub(r'\W', "_", uri)
|
||||||
self.uri = uri
|
self.uri = uri
|
||||||
|
@ -329,6 +345,7 @@ class Template(object):
|
||||||
self.callable_ = self.module.render_body
|
self.callable_ = self.module.render_body
|
||||||
self.format_exceptions = format_exceptions
|
self.format_exceptions = format_exceptions
|
||||||
self.error_handler = error_handler
|
self.error_handler = error_handler
|
||||||
|
self.include_error_handler = include_error_handler
|
||||||
self.lookup = lookup
|
self.lookup = lookup
|
||||||
|
|
||||||
self.module_directory = module_directory
|
self.module_directory = module_directory
|
||||||
|
@ -475,6 +492,14 @@ class Template(object):
|
||||||
|
|
||||||
return DefTemplate(self, getattr(self.module, "render_%s" % name))
|
return DefTemplate(self, getattr(self.module, "render_%s" % name))
|
||||||
|
|
||||||
|
def list_defs(self):
|
||||||
|
"""return a list of defs in the template.
|
||||||
|
|
||||||
|
.. versionadded:: 1.0.4
|
||||||
|
|
||||||
|
"""
|
||||||
|
return [i[7:] for i in dir(self.module) if i[:7] == 'render_']
|
||||||
|
|
||||||
def _get_def_callable(self, name):
|
def _get_def_callable(self, name):
|
||||||
return getattr(self.module, "render_%s" % name)
|
return getattr(self.module, "render_%s" % name)
|
||||||
|
|
||||||
|
@ -520,6 +545,7 @@ class ModuleTemplate(Template):
|
||||||
cache_type=None,
|
cache_type=None,
|
||||||
cache_dir=None,
|
cache_dir=None,
|
||||||
cache_url=None,
|
cache_url=None,
|
||||||
|
include_error_handler=None,
|
||||||
):
|
):
|
||||||
self.module_id = re.sub(r'\W', "_", module._template_uri)
|
self.module_id = re.sub(r'\W', "_", module._template_uri)
|
||||||
self.uri = module._template_uri
|
self.uri = module._template_uri
|
||||||
|
@ -551,6 +577,7 @@ class ModuleTemplate(Template):
|
||||||
self.callable_ = self.module.render_body
|
self.callable_ = self.module.render_body
|
||||||
self.format_exceptions = format_exceptions
|
self.format_exceptions = format_exceptions
|
||||||
self.error_handler = error_handler
|
self.error_handler = error_handler
|
||||||
|
self.include_error_handler = include_error_handler
|
||||||
self.lookup = lookup
|
self.lookup = lookup
|
||||||
self._setup_cache_args(
|
self._setup_cache_args(
|
||||||
cache_impl, cache_enabled, cache_args,
|
cache_impl, cache_enabled, cache_args,
|
||||||
|
@ -571,6 +598,7 @@ class DefTemplate(Template):
|
||||||
self.encoding_errors = parent.encoding_errors
|
self.encoding_errors = parent.encoding_errors
|
||||||
self.format_exceptions = parent.format_exceptions
|
self.format_exceptions = parent.format_exceptions
|
||||||
self.error_handler = parent.error_handler
|
self.error_handler = parent.error_handler
|
||||||
|
self.include_error_handler = parent.include_error_handler
|
||||||
self.enable_loop = parent.enable_loop
|
self.enable_loop = parent.enable_loop
|
||||||
self.lookup = parent.lookup
|
self.lookup = parent.lookup
|
||||||
self.bytestring_passthrough = parent.bytestring_passthrough
|
self.bytestring_passthrough = parent.bytestring_passthrough
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
# mako/util.py
|
# mako/util.py
|
||||||
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
|
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
|
||||||
#
|
#
|
||||||
# This module is part of Mako and is released under
|
# This module is part of Mako and is released under
|
||||||
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
# the MIT License: http://www.opensource.org/licenses/mit-license.php
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
__version__ = '1.3.5'
|
|
|
@ -1,12 +0,0 @@
|
||||||
# API Backwards compatibility
|
|
||||||
|
|
||||||
from pymemcache.client.base import Client # noqa
|
|
||||||
from pymemcache.client.base import PooledClient # noqa
|
|
||||||
|
|
||||||
from pymemcache.exceptions import MemcacheError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheClientError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheUnknownCommandError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheIllegalInputError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheServerError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheUnknownError # noqa
|
|
||||||
from pymemcache.exceptions import MemcacheUnexpectedCloseError # noqa
|
|
File diff suppressed because it is too large
Load Diff
|
@ -1,333 +0,0 @@
|
||||||
import socket
|
|
||||||
import time
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from pymemcache.client.base import Client, PooledClient, _check_key
|
|
||||||
from pymemcache.client.rendezvous import RendezvousHash
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
class HashClient(object):
|
|
||||||
"""
|
|
||||||
A client for communicating with a cluster of memcached servers
|
|
||||||
"""
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
servers,
|
|
||||||
hasher=RendezvousHash,
|
|
||||||
serializer=None,
|
|
||||||
deserializer=None,
|
|
||||||
connect_timeout=None,
|
|
||||||
timeout=None,
|
|
||||||
no_delay=False,
|
|
||||||
socket_module=socket,
|
|
||||||
key_prefix=b'',
|
|
||||||
max_pool_size=None,
|
|
||||||
lock_generator=None,
|
|
||||||
retry_attempts=2,
|
|
||||||
retry_timeout=1,
|
|
||||||
dead_timeout=60,
|
|
||||||
use_pooling=False,
|
|
||||||
ignore_exc=False,
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Constructor.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
servers: list(tuple(hostname, port))
|
|
||||||
hasher: optional class three functions ``get_node``, ``add_node``,
|
|
||||||
and ``remove_node``
|
|
||||||
defaults to Rendezvous (HRW) hash.
|
|
||||||
|
|
||||||
use_pooling: use py:class:`.PooledClient` as the default underlying
|
|
||||||
class. ``max_pool_size`` and ``lock_generator`` can
|
|
||||||
be used with this. default: False
|
|
||||||
|
|
||||||
retry_attempts: Amount of times a client should be tried before it
|
|
||||||
is marked dead and removed from the pool.
|
|
||||||
retry_timeout (float): Time in seconds that should pass between retry
|
|
||||||
attempts.
|
|
||||||
dead_timeout (float): Time in seconds before attempting to add a node
|
|
||||||
back in the pool.
|
|
||||||
|
|
||||||
Further arguments are interpreted as for :py:class:`.Client`
|
|
||||||
constructor.
|
|
||||||
|
|
||||||
The default ``hasher`` is using a pure python implementation that can
|
|
||||||
be significantly improved performance wise by switching to a C based
|
|
||||||
version. We recommend using ``python-clandestined`` if having a C
|
|
||||||
dependency is acceptable.
|
|
||||||
"""
|
|
||||||
self.clients = {}
|
|
||||||
self.retry_attempts = retry_attempts
|
|
||||||
self.retry_timeout = retry_timeout
|
|
||||||
self.dead_timeout = dead_timeout
|
|
||||||
self.use_pooling = use_pooling
|
|
||||||
self.key_prefix = key_prefix
|
|
||||||
self.ignore_exc = ignore_exc
|
|
||||||
self._failed_clients = {}
|
|
||||||
self._dead_clients = {}
|
|
||||||
self._last_dead_check_time = time.time()
|
|
||||||
|
|
||||||
self.hasher = hasher()
|
|
||||||
|
|
||||||
self.default_kwargs = {
|
|
||||||
'connect_timeout': connect_timeout,
|
|
||||||
'timeout': timeout,
|
|
||||||
'no_delay': no_delay,
|
|
||||||
'socket_module': socket_module,
|
|
||||||
'key_prefix': key_prefix,
|
|
||||||
'serializer': serializer,
|
|
||||||
'deserializer': deserializer,
|
|
||||||
}
|
|
||||||
|
|
||||||
if use_pooling is True:
|
|
||||||
self.default_kwargs.update({
|
|
||||||
'max_pool_size': max_pool_size,
|
|
||||||
'lock_generator': lock_generator
|
|
||||||
})
|
|
||||||
|
|
||||||
for server, port in servers:
|
|
||||||
self.add_server(server, port)
|
|
||||||
|
|
||||||
def add_server(self, server, port):
|
|
||||||
key = '%s:%s' % (server, port)
|
|
||||||
|
|
||||||
if self.use_pooling:
|
|
||||||
client = PooledClient(
|
|
||||||
(server, port),
|
|
||||||
**self.default_kwargs
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
client = Client((server, port), **self.default_kwargs)
|
|
||||||
|
|
||||||
self.clients[key] = client
|
|
||||||
self.hasher.add_node(key)
|
|
||||||
|
|
||||||
def remove_server(self, server, port):
|
|
||||||
dead_time = time.time()
|
|
||||||
self._failed_clients.pop((server, port))
|
|
||||||
self._dead_clients[(server, port)] = dead_time
|
|
||||||
key = '%s:%s' % (server, port)
|
|
||||||
self.hasher.remove_node(key)
|
|
||||||
|
|
||||||
def _get_client(self, key):
|
|
||||||
_check_key(key, self.key_prefix)
|
|
||||||
if len(self._dead_clients) > 0:
|
|
||||||
current_time = time.time()
|
|
||||||
ldc = self._last_dead_check_time
|
|
||||||
# we have dead clients and we have reached the
|
|
||||||
# timeout retry
|
|
||||||
if current_time - ldc > self.dead_timeout:
|
|
||||||
for server, dead_time in self._dead_clients.items():
|
|
||||||
if current_time - dead_time > self.dead_timeout:
|
|
||||||
logger.debug(
|
|
||||||
'bringing server back into rotation %s',
|
|
||||||
server
|
|
||||||
)
|
|
||||||
self.add_server(*server)
|
|
||||||
self._last_dead_check_time = current_time
|
|
||||||
|
|
||||||
server = self.hasher.get_node(key)
|
|
||||||
# We've ran out of servers to try
|
|
||||||
if server is None:
|
|
||||||
if self.ignore_exc is True:
|
|
||||||
return
|
|
||||||
raise Exception('All servers seem to be down right now')
|
|
||||||
|
|
||||||
client = self.clients[server]
|
|
||||||
return client
|
|
||||||
|
|
||||||
def _safely_run_func(self, client, func, default_val, *args, **kwargs):
|
|
||||||
try:
|
|
||||||
if client.server in self._failed_clients:
|
|
||||||
# This server is currently failing, lets check if it is in
|
|
||||||
# retry or marked as dead
|
|
||||||
failed_metadata = self._failed_clients[client.server]
|
|
||||||
|
|
||||||
# we haven't tried our max amount yet, if it has been enough
|
|
||||||
# time lets just retry using it
|
|
||||||
if failed_metadata['attempts'] < self.retry_attempts:
|
|
||||||
failed_time = failed_metadata['failed_time']
|
|
||||||
if time.time() - failed_time > self.retry_timeout:
|
|
||||||
logger.debug(
|
|
||||||
'retrying failed server: %s', client.server
|
|
||||||
)
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
# we were successful, lets remove it from the failed
|
|
||||||
# clients
|
|
||||||
self._failed_clients.pop(client.server)
|
|
||||||
return result
|
|
||||||
return default_val
|
|
||||||
else:
|
|
||||||
# We've reached our max retry attempts, we need to mark
|
|
||||||
# the sever as dead
|
|
||||||
logger.debug('marking server as dead: %s', client.server)
|
|
||||||
self.remove_server(*client.server)
|
|
||||||
|
|
||||||
result = func(*args, **kwargs)
|
|
||||||
return result
|
|
||||||
|
|
||||||
# Connecting to the server fail, we should enter
|
|
||||||
# retry mode
|
|
||||||
except socket.error:
|
|
||||||
# This client has never failed, lets mark it for failure
|
|
||||||
if (
|
|
||||||
client.server not in self._failed_clients and
|
|
||||||
self.retry_attempts > 0
|
|
||||||
):
|
|
||||||
self._failed_clients[client.server] = {
|
|
||||||
'failed_time': time.time(),
|
|
||||||
'attempts': 0,
|
|
||||||
}
|
|
||||||
# We aren't allowing any retries, we should mark the server as
|
|
||||||
# dead immediately
|
|
||||||
elif (
|
|
||||||
client.server not in self._failed_clients and
|
|
||||||
self.retry_attempts <= 0
|
|
||||||
):
|
|
||||||
self._failed_clients[client.server] = {
|
|
||||||
'failed_time': time.time(),
|
|
||||||
'attempts': 0,
|
|
||||||
}
|
|
||||||
logger.debug("marking server as dead %s", client.server)
|
|
||||||
self.remove_server(*client.server)
|
|
||||||
# This client has failed previously, we need to update the metadata
|
|
||||||
# to reflect that we have attempted it again
|
|
||||||
else:
|
|
||||||
failed_metadata = self._failed_clients[client.server]
|
|
||||||
failed_metadata['attempts'] += 1
|
|
||||||
failed_metadata['failed_time'] = time.time()
|
|
||||||
self._failed_clients[client.server] = failed_metadata
|
|
||||||
|
|
||||||
# if we haven't enabled ignore_exc, don't move on gracefully, just
|
|
||||||
# raise the exception
|
|
||||||
if not self.ignore_exc:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return default_val
|
|
||||||
except:
|
|
||||||
# any exceptions that aren't socket.error we need to handle
|
|
||||||
# gracefully as well
|
|
||||||
if not self.ignore_exc:
|
|
||||||
raise
|
|
||||||
|
|
||||||
return default_val
|
|
||||||
|
|
||||||
def _run_cmd(self, cmd, key, default_val, *args, **kwargs):
|
|
||||||
client = self._get_client(key)
|
|
||||||
|
|
||||||
if client is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
func = getattr(client, cmd)
|
|
||||||
args = list(args)
|
|
||||||
args.insert(0, key)
|
|
||||||
return self._safely_run_func(
|
|
||||||
client, func, default_val, *args, **kwargs
|
|
||||||
)
|
|
||||||
|
|
||||||
def set(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('set', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def get(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('get', key, None, *args, **kwargs)
|
|
||||||
|
|
||||||
def incr(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('incr', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def decr(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('decr', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def set_many(self, values, *args, **kwargs):
|
|
||||||
client_batches = {}
|
|
||||||
end = []
|
|
||||||
|
|
||||||
for key, value in values.items():
|
|
||||||
client = self._get_client(key)
|
|
||||||
|
|
||||||
if client is None:
|
|
||||||
end.append(False)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if client.server not in client_batches:
|
|
||||||
client_batches[client.server] = {}
|
|
||||||
|
|
||||||
client_batches[client.server][key] = value
|
|
||||||
|
|
||||||
for server, values in client_batches.items():
|
|
||||||
client = self.clients['%s:%s' % server]
|
|
||||||
new_args = list(args)
|
|
||||||
new_args.insert(0, values)
|
|
||||||
result = self._safely_run_func(
|
|
||||||
client,
|
|
||||||
client.set_many, False, *new_args, **kwargs
|
|
||||||
)
|
|
||||||
end.append(result)
|
|
||||||
|
|
||||||
return all(end)
|
|
||||||
|
|
||||||
set_multi = set_many
|
|
||||||
|
|
||||||
def get_many(self, keys, *args, **kwargs):
|
|
||||||
client_batches = {}
|
|
||||||
end = {}
|
|
||||||
|
|
||||||
for key in keys:
|
|
||||||
client = self._get_client(key)
|
|
||||||
|
|
||||||
if client is None:
|
|
||||||
end[key] = False
|
|
||||||
continue
|
|
||||||
|
|
||||||
if client.server not in client_batches:
|
|
||||||
client_batches[client.server] = []
|
|
||||||
|
|
||||||
client_batches[client.server].append(key)
|
|
||||||
|
|
||||||
for server, keys in client_batches.items():
|
|
||||||
client = self.clients['%s:%s' % server]
|
|
||||||
new_args = list(args)
|
|
||||||
new_args.insert(0, keys)
|
|
||||||
result = self._safely_run_func(
|
|
||||||
client,
|
|
||||||
client.get_many, {}, *new_args, **kwargs
|
|
||||||
)
|
|
||||||
end.update(result)
|
|
||||||
|
|
||||||
return end
|
|
||||||
|
|
||||||
get_multi = get_many
|
|
||||||
|
|
||||||
def gets(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('gets', key, None, *args, **kwargs)
|
|
||||||
|
|
||||||
def add(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('add', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def prepend(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('prepend', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def append(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('append', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def delete(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('delete', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def delete_many(self, keys, *args, **kwargs):
|
|
||||||
for key in keys:
|
|
||||||
self._run_cmd('delete', key, False, *args, **kwargs)
|
|
||||||
return True
|
|
||||||
|
|
||||||
delete_multi = delete_many
|
|
||||||
|
|
||||||
def cas(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('cas', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def replace(self, key, *args, **kwargs):
|
|
||||||
return self._run_cmd('replace', key, False, *args, **kwargs)
|
|
||||||
|
|
||||||
def flush_all(self):
|
|
||||||
for _, client in self.clients.items():
|
|
||||||
self._safely_run_func(client, client.flush_all, False)
|
|
|
@ -1,51 +0,0 @@
|
||||||
def murmur3_32(data, seed=0):
|
|
||||||
"""MurmurHash3 was written by Austin Appleby, and is placed in the
|
|
||||||
public domain. The author hereby disclaims copyright to this source
|
|
||||||
code."""
|
|
||||||
|
|
||||||
c1 = 0xcc9e2d51
|
|
||||||
c2 = 0x1b873593
|
|
||||||
|
|
||||||
length = len(data)
|
|
||||||
h1 = seed
|
|
||||||
roundedEnd = (length & 0xfffffffc) # round down to 4 byte block
|
|
||||||
for i in range(0, roundedEnd, 4):
|
|
||||||
# little endian load order
|
|
||||||
k1 = (ord(data[i]) & 0xff) | ((ord(data[i + 1]) & 0xff) << 8) | \
|
|
||||||
((ord(data[i + 2]) & 0xff) << 16) | (ord(data[i + 3]) << 24)
|
|
||||||
k1 *= c1
|
|
||||||
k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17) # ROTL32(k1,15)
|
|
||||||
k1 *= c2
|
|
||||||
|
|
||||||
h1 ^= k1
|
|
||||||
h1 = (h1 << 13) | ((h1 & 0xffffffff) >> 19) # ROTL32(h1,13)
|
|
||||||
h1 = h1 * 5 + 0xe6546b64
|
|
||||||
|
|
||||||
# tail
|
|
||||||
k1 = 0
|
|
||||||
|
|
||||||
val = length & 0x03
|
|
||||||
if val == 3:
|
|
||||||
k1 = (ord(data[roundedEnd + 2]) & 0xff) << 16
|
|
||||||
# fallthrough
|
|
||||||
if val in [2, 3]:
|
|
||||||
k1 |= (ord(data[roundedEnd + 1]) & 0xff) << 8
|
|
||||||
# fallthrough
|
|
||||||
if val in [1, 2, 3]:
|
|
||||||
k1 |= ord(data[roundedEnd]) & 0xff
|
|
||||||
k1 *= c1
|
|
||||||
k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17) # ROTL32(k1,15)
|
|
||||||
k1 *= c2
|
|
||||||
h1 ^= k1
|
|
||||||
|
|
||||||
# finalization
|
|
||||||
h1 ^= length
|
|
||||||
|
|
||||||
# fmix(h1)
|
|
||||||
h1 ^= ((h1 & 0xffffffff) >> 16)
|
|
||||||
h1 *= 0x85ebca6b
|
|
||||||
h1 ^= ((h1 & 0xffffffff) >> 13)
|
|
||||||
h1 *= 0xc2b2ae35
|
|
||||||
h1 ^= ((h1 & 0xffffffff) >> 16)
|
|
||||||
|
|
||||||
return h1 & 0xffffffff
|
|
|
@ -1,46 +0,0 @@
|
||||||
from pymemcache.client.murmur3 import murmur3_32
|
|
||||||
|
|
||||||
|
|
||||||
class RendezvousHash(object):
|
|
||||||
"""
|
|
||||||
Implements the Highest Random Weight (HRW) hashing algorithm most
|
|
||||||
commonly referred to as rendezvous hashing.
|
|
||||||
|
|
||||||
Originally developed as part of python-clandestined.
|
|
||||||
|
|
||||||
Copyright (c) 2014 Ernest W. Durbin III
|
|
||||||
"""
|
|
||||||
def __init__(self, nodes=None, seed=0, hash_function=murmur3_32):
|
|
||||||
"""
|
|
||||||
Constructor.
|
|
||||||
"""
|
|
||||||
self.nodes = []
|
|
||||||
self.seed = seed
|
|
||||||
if nodes is not None:
|
|
||||||
self.nodes = nodes
|
|
||||||
self.hash_function = lambda x: hash_function(x, seed)
|
|
||||||
|
|
||||||
def add_node(self, node):
|
|
||||||
if node not in self.nodes:
|
|
||||||
self.nodes.append(node)
|
|
||||||
|
|
||||||
def remove_node(self, node):
|
|
||||||
if node in self.nodes:
|
|
||||||
self.nodes.remove(node)
|
|
||||||
else:
|
|
||||||
raise ValueError("No such node %s to remove" % (node))
|
|
||||||
|
|
||||||
def get_node(self, key):
|
|
||||||
high_score = -1
|
|
||||||
winner = None
|
|
||||||
|
|
||||||
for node in self.nodes:
|
|
||||||
score = self.hash_function(
|
|
||||||
"%s-%s" % (str(node), str(key)))
|
|
||||||
|
|
||||||
if score > high_score:
|
|
||||||
(high_score, winner) = (score, node)
|
|
||||||
elif score == high_score:
|
|
||||||
(high_score, winner) = (score, max(str(node), str(winner)))
|
|
||||||
|
|
||||||
return winner
|
|
|
@ -1,40 +0,0 @@
|
||||||
class MemcacheError(Exception):
|
|
||||||
"Base exception class"
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheClientError(MemcacheError):
|
|
||||||
"""Raised when memcached fails to parse the arguments to a request, likely
|
|
||||||
due to a malformed key and/or value, a bug in this library, or a version
|
|
||||||
mismatch with memcached."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheUnknownCommandError(MemcacheClientError):
|
|
||||||
"""Raised when memcached fails to parse a request, likely due to a bug in
|
|
||||||
this library or a version mismatch with memcached."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheIllegalInputError(MemcacheClientError):
|
|
||||||
"""Raised when a key or value is not legal for Memcache (see the class docs
|
|
||||||
for Client for more details)."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheServerError(MemcacheError):
|
|
||||||
"""Raised when memcached reports a failure while processing a request,
|
|
||||||
likely due to a bug or transient issue in memcached."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheUnknownError(MemcacheError):
|
|
||||||
"""Raised when this library receives a response from memcached that it
|
|
||||||
cannot parse, likely due to a bug in this library or a version mismatch
|
|
||||||
with memcached."""
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class MemcacheUnexpectedCloseError(MemcacheServerError):
|
|
||||||
"Raised when the connection with memcached closes unexpectedly."
|
|
||||||
pass
|
|
|
@ -1,123 +0,0 @@
|
||||||
# Copyright 2012 Pinterest.com
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""
|
|
||||||
A client for falling back to older memcached servers when performing reads.
|
|
||||||
|
|
||||||
It is sometimes necessary to deploy memcached on new servers, or with a
|
|
||||||
different configuration. In theses cases, it is undesirable to start up an
|
|
||||||
empty memcached server and point traffic to it, since the cache will be cold,
|
|
||||||
and the backing store will have a large increase in traffic.
|
|
||||||
|
|
||||||
This class attempts to solve that problem by providing an interface identical
|
|
||||||
to the Client interface, but which can fall back to older memcached servers
|
|
||||||
when reads to the primary server fail. The approach for upgrading memcached
|
|
||||||
servers or configuration then becomes:
|
|
||||||
|
|
||||||
1. Deploy a new host (or fleet) with memcached, possibly with a new
|
|
||||||
configuration.
|
|
||||||
2. From your application servers, use FallbackClient to write and read from
|
|
||||||
the new cluster, and to read from the old cluster when there is a miss in
|
|
||||||
the new cluster.
|
|
||||||
3. Wait until the new cache is warm enough to support the load.
|
|
||||||
4. Switch from FallbackClient to a regular Client library for doing all
|
|
||||||
reads and writes to the new cluster.
|
|
||||||
5. Take down the old cluster.
|
|
||||||
|
|
||||||
Best Practices:
|
|
||||||
---------------
|
|
||||||
- Make sure that the old client has "ignore_exc" set to True, so that it
|
|
||||||
treats failures like cache misses. That will allow you to take down the
|
|
||||||
old cluster before you switch away from FallbackClient.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class FallbackClient(object):
|
|
||||||
def __init__(self, caches):
|
|
||||||
assert len(caches) > 0
|
|
||||||
self.caches = caches
|
|
||||||
|
|
||||||
def close(self):
|
|
||||||
"Close each of the memcached clients"
|
|
||||||
for cache in self.caches:
|
|
||||||
cache.close()
|
|
||||||
|
|
||||||
def set(self, key, value, expire=0, noreply=True):
|
|
||||||
self.caches[0].set(key, value, expire, noreply)
|
|
||||||
|
|
||||||
def add(self, key, value, expire=0, noreply=True):
|
|
||||||
self.caches[0].add(key, value, expire, noreply)
|
|
||||||
|
|
||||||
def replace(self, key, value, expire=0, noreply=True):
|
|
||||||
self.caches[0].replace(key, value, expire, noreply)
|
|
||||||
|
|
||||||
def append(self, key, value, expire=0, noreply=True):
|
|
||||||
self.caches[0].append(key, value, expire, noreply)
|
|
||||||
|
|
||||||
def prepend(self, key, value, expire=0, noreply=True):
|
|
||||||
self.caches[0].prepend(key, value, expire, noreply)
|
|
||||||
|
|
||||||
def cas(self, key, value, cas, expire=0, noreply=True):
|
|
||||||
self.caches[0].cas(key, value, cas, expire, noreply)
|
|
||||||
|
|
||||||
def get(self, key):
|
|
||||||
for cache in self.caches:
|
|
||||||
result = cache.get(key)
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_many(self, keys):
|
|
||||||
for cache in self.caches:
|
|
||||||
result = cache.get_many(keys)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
return []
|
|
||||||
|
|
||||||
def gets(self, key):
|
|
||||||
for cache in self.caches:
|
|
||||||
result = cache.gets(key)
|
|
||||||
if result is not None:
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
def gets_many(self, keys):
|
|
||||||
for cache in self.caches:
|
|
||||||
result = cache.gets_many(keys)
|
|
||||||
if result:
|
|
||||||
return result
|
|
||||||
return []
|
|
||||||
|
|
||||||
def delete(self, key, noreply=True):
|
|
||||||
self.caches[0].delete(key, noreply)
|
|
||||||
|
|
||||||
def incr(self, key, value, noreply=True):
|
|
||||||
self.caches[0].incr(key, value, noreply)
|
|
||||||
|
|
||||||
def decr(self, key, value, noreply=True):
|
|
||||||
self.caches[0].decr(key, value, noreply)
|
|
||||||
|
|
||||||
def touch(self, key, expire=0, noreply=True):
|
|
||||||
self.caches[0].touch(key, expire, noreply)
|
|
||||||
|
|
||||||
def stats(self):
|
|
||||||
# TODO: ??
|
|
||||||
pass
|
|
||||||
|
|
||||||
def flush_all(self, delay=0, noreply=True):
|
|
||||||
self.caches[0].flush_all(delay, noreply)
|
|
||||||
|
|
||||||
def quit(self):
|
|
||||||
# TODO: ??
|
|
||||||
pass
|
|
|
@ -1,114 +0,0 @@
|
||||||
# Copyright 2015 Yahoo.com
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import collections
|
|
||||||
import contextlib
|
|
||||||
import sys
|
|
||||||
import threading
|
|
||||||
|
|
||||||
import six
|
|
||||||
|
|
||||||
|
|
||||||
class ObjectPool(object):
|
|
||||||
"""A pool of objects that release/creates/destroys as needed."""
|
|
||||||
|
|
||||||
def __init__(self, obj_creator,
|
|
||||||
after_remove=None, max_size=None,
|
|
||||||
lock_generator=None):
|
|
||||||
self._used_objs = collections.deque()
|
|
||||||
self._free_objs = collections.deque()
|
|
||||||
self._obj_creator = obj_creator
|
|
||||||
if lock_generator is None:
|
|
||||||
self._lock = threading.Lock()
|
|
||||||
else:
|
|
||||||
self._lock = lock_generator()
|
|
||||||
self._after_remove = after_remove
|
|
||||||
max_size = max_size or 2 ** 31
|
|
||||||
if not isinstance(max_size, six.integer_types) or max_size < 0:
|
|
||||||
raise ValueError('"max_size" must be a positive integer')
|
|
||||||
self.max_size = max_size
|
|
||||||
|
|
||||||
@property
|
|
||||||
def used(self):
|
|
||||||
return tuple(self._used_objs)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def free(self):
|
|
||||||
return tuple(self._free_objs)
|
|
||||||
|
|
||||||
@contextlib.contextmanager
|
|
||||||
def get_and_release(self, destroy_on_fail=False):
|
|
||||||
obj = self.get()
|
|
||||||
try:
|
|
||||||
yield obj
|
|
||||||
except Exception:
|
|
||||||
exc_info = sys.exc_info()
|
|
||||||
if not destroy_on_fail:
|
|
||||||
self.release(obj)
|
|
||||||
else:
|
|
||||||
self.destroy(obj)
|
|
||||||
six.reraise(exc_info[0], exc_info[1], exc_info[2])
|
|
||||||
self.release(obj)
|
|
||||||
|
|
||||||
def get(self):
|
|
||||||
with self._lock:
|
|
||||||
if not self._free_objs:
|
|
||||||
curr_count = len(self._used_objs)
|
|
||||||
if curr_count >= self.max_size:
|
|
||||||
raise RuntimeError("Too many objects,"
|
|
||||||
" %s >= %s" % (curr_count,
|
|
||||||
self.max_size))
|
|
||||||
obj = self._obj_creator()
|
|
||||||
self._used_objs.append(obj)
|
|
||||||
return obj
|
|
||||||
else:
|
|
||||||
obj = self._free_objs.pop()
|
|
||||||
self._used_objs.append(obj)
|
|
||||||
return obj
|
|
||||||
|
|
||||||
def destroy(self, obj, silent=True):
|
|
||||||
was_dropped = False
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._used_objs.remove(obj)
|
|
||||||
was_dropped = True
|
|
||||||
except ValueError:
|
|
||||||
if not silent:
|
|
||||||
raise
|
|
||||||
if was_dropped and self._after_remove is not None:
|
|
||||||
self._after_remove(obj)
|
|
||||||
|
|
||||||
def release(self, obj, silent=True):
|
|
||||||
with self._lock:
|
|
||||||
try:
|
|
||||||
self._used_objs.remove(obj)
|
|
||||||
self._free_objs.append(obj)
|
|
||||||
except ValueError:
|
|
||||||
if not silent:
|
|
||||||
raise
|
|
||||||
|
|
||||||
def clear(self):
|
|
||||||
if self._after_remove is not None:
|
|
||||||
needs_destroy = []
|
|
||||||
with self._lock:
|
|
||||||
needs_destroy.extend(self._used_objs)
|
|
||||||
needs_destroy.extend(self._free_objs)
|
|
||||||
self._free_objs.clear()
|
|
||||||
self._used_objs.clear()
|
|
||||||
for obj in needs_destroy:
|
|
||||||
self._after_remove(obj)
|
|
||||||
else:
|
|
||||||
with self._lock:
|
|
||||||
self._free_objs.clear()
|
|
||||||
self._used_objs.clear()
|
|
|
@ -1,69 +0,0 @@
|
||||||
# Copyright 2012 Pinterest.com
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import pickle
|
|
||||||
|
|
||||||
try:
|
|
||||||
from cStringIO import StringIO
|
|
||||||
except ImportError:
|
|
||||||
from StringIO import StringIO
|
|
||||||
|
|
||||||
|
|
||||||
FLAG_PICKLE = 1 << 0
|
|
||||||
FLAG_INTEGER = 1 << 1
|
|
||||||
FLAG_LONG = 1 << 2
|
|
||||||
|
|
||||||
|
|
||||||
def python_memcache_serializer(key, value):
|
|
||||||
flags = 0
|
|
||||||
|
|
||||||
if isinstance(value, str):
|
|
||||||
pass
|
|
||||||
elif isinstance(value, int):
|
|
||||||
flags |= FLAG_INTEGER
|
|
||||||
value = "%d" % value
|
|
||||||
elif isinstance(value, long):
|
|
||||||
flags |= FLAG_LONG
|
|
||||||
value = "%d" % value
|
|
||||||
else:
|
|
||||||
flags |= FLAG_PICKLE
|
|
||||||
output = StringIO()
|
|
||||||
pickler = pickle.Pickler(output, 0)
|
|
||||||
pickler.dump(value)
|
|
||||||
value = output.getvalue()
|
|
||||||
|
|
||||||
return value, flags
|
|
||||||
|
|
||||||
|
|
||||||
def python_memcache_deserializer(key, value, flags):
|
|
||||||
if flags == 0:
|
|
||||||
return value
|
|
||||||
|
|
||||||
if flags & FLAG_INTEGER:
|
|
||||||
return int(value)
|
|
||||||
|
|
||||||
if flags & FLAG_LONG:
|
|
||||||
return long(value)
|
|
||||||
|
|
||||||
if flags & FLAG_PICKLE:
|
|
||||||
try:
|
|
||||||
buf = StringIO(value)
|
|
||||||
unpickler = pickle.Unpickler(buf)
|
|
||||||
return unpickler.load()
|
|
||||||
except Exception:
|
|
||||||
logging.info('Pickle error', exc_info=True)
|
|
||||||
return None
|
|
||||||
|
|
||||||
return value
|
|
|
@ -1,7 +1,7 @@
|
||||||
'''
|
"""
|
||||||
PyMySQL: A pure-Python MySQL client library.
|
PyMySQL: A pure-Python MySQL client library.
|
||||||
|
|
||||||
Copyright (c) 2010, 2013 PyMySQL contributors
|
Copyright (c) 2010-2016 PyMySQL contributors
|
||||||
|
|
||||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
of this software and associated documentation files (the "Software"), to deal
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
@ -20,30 +20,29 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
|
||||||
THE SOFTWARE.
|
THE SOFTWARE.
|
||||||
|
"""
|
||||||
'''
|
|
||||||
|
|
||||||
VERSION = (0, 6, 7, None)
|
|
||||||
|
|
||||||
from ._compat import text_type, JYTHON, IRONPYTHON
|
|
||||||
from .constants import FIELD_TYPE
|
|
||||||
from .converters import escape_dict, escape_sequence, escape_string
|
|
||||||
from .err import Warning, Error, InterfaceError, DataError, \
|
|
||||||
DatabaseError, OperationalError, IntegrityError, InternalError, \
|
|
||||||
NotSupportedError, ProgrammingError, MySQLError
|
|
||||||
from .times import Date, Time, Timestamp, \
|
|
||||||
DateFromTicks, TimeFromTicks, TimestampFromTicks
|
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
from ._compat import PY2
|
||||||
|
from .constants import FIELD_TYPE
|
||||||
|
from .converters import escape_dict, escape_sequence, escape_string
|
||||||
|
from .err import (
|
||||||
|
Warning, Error, InterfaceError, DataError,
|
||||||
|
DatabaseError, OperationalError, IntegrityError, InternalError,
|
||||||
|
NotSupportedError, ProgrammingError, MySQLError)
|
||||||
|
from .times import (
|
||||||
|
Date, Time, Timestamp,
|
||||||
|
DateFromTicks, TimeFromTicks, TimestampFromTicks)
|
||||||
|
|
||||||
|
|
||||||
|
VERSION = (0, 7, 11, None)
|
||||||
threadsafety = 1
|
threadsafety = 1
|
||||||
apilevel = "2.0"
|
apilevel = "2.0"
|
||||||
paramstyle = "format"
|
paramstyle = "pyformat"
|
||||||
|
|
||||||
|
|
||||||
class DBAPISet(frozenset):
|
class DBAPISet(frozenset):
|
||||||
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
def __ne__(self, other):
|
||||||
if isinstance(other, set):
|
if isinstance(other, set):
|
||||||
return frozenset.__ne__(self, other)
|
return frozenset.__ne__(self, other)
|
||||||
|
@ -73,11 +72,14 @@ TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
|
||||||
DATETIME = TIMESTAMP
|
DATETIME = TIMESTAMP
|
||||||
ROWID = DBAPISet()
|
ROWID = DBAPISet()
|
||||||
|
|
||||||
|
|
||||||
def Binary(x):
|
def Binary(x):
|
||||||
"""Return x as a binary type."""
|
"""Return x as a binary type."""
|
||||||
if isinstance(x, text_type) and not (JYTHON or IRONPYTHON):
|
if PY2:
|
||||||
return x.encode()
|
return bytearray(x)
|
||||||
return bytes(x)
|
else:
|
||||||
|
return bytes(x)
|
||||||
|
|
||||||
|
|
||||||
def Connect(*args, **kwargs):
|
def Connect(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
|
@ -87,27 +89,26 @@ def Connect(*args, **kwargs):
|
||||||
from .connections import Connection
|
from .connections import Connection
|
||||||
return Connection(*args, **kwargs)
|
return Connection(*args, **kwargs)
|
||||||
|
|
||||||
from pymysql import connections as _orig_conn
|
from . import connections as _orig_conn
|
||||||
if _orig_conn.Connection.__init__.__doc__ is not None:
|
if _orig_conn.Connection.__init__.__doc__ is not None:
|
||||||
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + ("""
|
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__
|
||||||
See connections.Connection.__init__() for information about defaults.
|
|
||||||
""")
|
|
||||||
del _orig_conn
|
del _orig_conn
|
||||||
|
|
||||||
|
|
||||||
def get_client_info(): # for MySQLdb compatibility
|
def get_client_info(): # for MySQLdb compatibility
|
||||||
return '.'.join(map(str, VERSION))
|
return '.'.join(map(str, VERSION))
|
||||||
|
|
||||||
connect = Connection = Connect
|
connect = Connection = Connect
|
||||||
|
|
||||||
# we include a doctored version_info here for MySQLdb compatibility
|
# we include a doctored version_info here for MySQLdb compatibility
|
||||||
version_info = (1,2,2,"final",0)
|
version_info = (1,2,6,"final",0)
|
||||||
|
|
||||||
NULL = "NULL"
|
NULL = "NULL"
|
||||||
|
|
||||||
__version__ = get_client_info()
|
__version__ = get_client_info()
|
||||||
|
|
||||||
def thread_safe():
|
def thread_safe():
|
||||||
return True # match MySQLdb.thread_safe()
|
return True # match MySQLdb.thread_safe()
|
||||||
|
|
||||||
def install_as_MySQLdb():
|
def install_as_MySQLdb():
|
||||||
"""
|
"""
|
||||||
|
@ -116,6 +117,7 @@ def install_as_MySQLdb():
|
||||||
"""
|
"""
|
||||||
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
|
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
|
||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
|
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
|
||||||
'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
|
'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
|
||||||
|
@ -128,6 +130,5 @@ __all__ = [
|
||||||
'paramstyle', 'threadsafety', 'version_info',
|
'paramstyle', 'threadsafety', 'version_info',
|
||||||
|
|
||||||
"install_as_MySQLdb",
|
"install_as_MySQLdb",
|
||||||
|
"NULL", "__version__",
|
||||||
"NULL","__version__",
|
]
|
||||||
]
|
|
||||||
|
|
|
@ -7,12 +7,15 @@ IRONPYTHON = sys.platform == 'cli'
|
||||||
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
|
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
|
||||||
|
|
||||||
if PY2:
|
if PY2:
|
||||||
|
import __builtin__
|
||||||
range_type = xrange
|
range_type = xrange
|
||||||
text_type = unicode
|
text_type = unicode
|
||||||
long_type = long
|
long_type = long
|
||||||
str_type = basestring
|
str_type = basestring
|
||||||
|
unichr = __builtin__.unichr
|
||||||
else:
|
else:
|
||||||
range_type = range
|
range_type = range
|
||||||
text_type = str
|
text_type = str
|
||||||
long_type = int
|
long_type = int
|
||||||
str_type = str
|
str_type = str
|
||||||
|
unichr = chr
|
||||||
|
|
|
@ -11,6 +11,10 @@ class Charset(object):
|
||||||
self.id, self.name, self.collation = id, name, collation
|
self.id, self.name, self.collation = id, name, collation
|
||||||
self.is_default = is_default == 'Yes'
|
self.is_default = is_default == 'Yes'
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return "Charset(id=%s, name=%r, collation=%r)" % (
|
||||||
|
self.id, self.name, self.collation)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def encoding(self):
|
def encoding(self):
|
||||||
name = self.name
|
name = self.name
|
||||||
|
@ -249,6 +253,10 @@ _charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', ''))
|
||||||
_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', ''))
|
_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', ''))
|
||||||
_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', ''))
|
_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', ''))
|
||||||
_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', ''))
|
_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', ''))
|
||||||
|
_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', ''))
|
||||||
|
_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', ''))
|
||||||
|
_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', ''))
|
||||||
|
_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', ''))
|
||||||
|
|
||||||
|
|
||||||
charset_by_name = _charsets.by_name
|
charset_by_name = _charsets.by_name
|
||||||
|
|
|
@ -17,9 +17,8 @@ import traceback
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from .charset import MBLENGTH, charset_by_name, charset_by_id
|
from .charset import MBLENGTH, charset_by_name, charset_by_id
|
||||||
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
|
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
|
||||||
from .converters import (
|
from .converters import escape_item, escape_string, through, conversions as _conv
|
||||||
escape_item, encoders, decoders, escape_string, through)
|
|
||||||
from .cursors import Cursor
|
from .cursors import Cursor
|
||||||
from .optionfile import Parser
|
from .optionfile import Parser
|
||||||
from .util import byte2int, int2byte
|
from .util import byte2int, int2byte
|
||||||
|
@ -36,7 +35,8 @@ try:
|
||||||
import getpass
|
import getpass
|
||||||
DEFAULT_USER = getpass.getuser()
|
DEFAULT_USER = getpass.getuser()
|
||||||
del getpass
|
del getpass
|
||||||
except ImportError:
|
except (ImportError, KeyError):
|
||||||
|
# KeyError occurs when there's no entry in OS database for a current user.
|
||||||
DEFAULT_USER = None
|
DEFAULT_USER = None
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,26 +117,24 @@ def dump_packet(data): # pragma: no cover
|
||||||
|
|
||||||
try:
|
try:
|
||||||
print("packet length:", len(data))
|
print("packet length:", len(data))
|
||||||
print("method call[1]:", sys._getframe(1).f_code.co_name)
|
for i in range(1, 6):
|
||||||
print("method call[2]:", sys._getframe(2).f_code.co_name)
|
f = sys._getframe(i)
|
||||||
print("method call[3]:", sys._getframe(3).f_code.co_name)
|
print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
|
||||||
print("method call[4]:", sys._getframe(4).f_code.co_name)
|
print("-" * 66)
|
||||||
print("method call[5]:", sys._getframe(5).f_code.co_name)
|
|
||||||
print("-" * 88)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
|
dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
|
||||||
for d in dump_data:
|
for d in dump_data:
|
||||||
print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) +
|
print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) +
|
||||||
' ' * (16 - len(d)) + ' ' * 2 +
|
' ' * (16 - len(d)) + ' ' * 2 +
|
||||||
' '.join(map(lambda x: "{}".format(is_ascii(x)), d)))
|
''.join(map(lambda x: "{}".format(is_ascii(x)), d)))
|
||||||
print("-" * 88)
|
print("-" * 66)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
|
|
||||||
def _scramble(password, message):
|
def _scramble(password, message):
|
||||||
if not password:
|
if not password:
|
||||||
return b'\0'
|
return b''
|
||||||
if DEBUG: print('password=' + str(password))
|
if DEBUG: print('password=' + str(password))
|
||||||
stage1 = sha_new(password).digest()
|
stage1 = sha_new(password).digest()
|
||||||
stage2 = sha_new(stage1).digest()
|
stage2 = sha_new(stage1).digest()
|
||||||
|
@ -149,7 +147,7 @@ def _scramble(password, message):
|
||||||
|
|
||||||
def _my_crypt(message1, message2):
|
def _my_crypt(message1, message2):
|
||||||
length = len(message1)
|
length = len(message1)
|
||||||
result = struct.pack('B', length)
|
result = b''
|
||||||
for i in range_type(length):
|
for i in range_type(length):
|
||||||
x = (struct.unpack('B', message1[i:i+1])[0] ^
|
x = (struct.unpack('B', message1[i:i+1])[0] ^
|
||||||
struct.unpack('B', message2[i:i+1])[0])
|
struct.unpack('B', message2[i:i+1])[0])
|
||||||
|
@ -196,7 +194,8 @@ def _hash_password_323(password):
|
||||||
add = 7
|
add = 7
|
||||||
nr2 = 0x12345671
|
nr2 = 0x12345671
|
||||||
|
|
||||||
for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
|
# x in py3 is numbers, p27 is chars
|
||||||
|
for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
|
||||||
nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
|
nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
|
||||||
nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
|
nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
|
||||||
add = (add + c) & 0xFFFFFFFF
|
add = (add + c) & 0xFFFFFFFF
|
||||||
|
@ -209,6 +208,20 @@ def _hash_password_323(password):
|
||||||
def pack_int24(n):
|
def pack_int24(n):
|
||||||
return struct.pack('<I', n)[:3]
|
return struct.pack('<I', n)[:3]
|
||||||
|
|
||||||
|
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
|
||||||
|
def lenenc_int(i):
|
||||||
|
if (i < 0):
|
||||||
|
raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
|
||||||
|
elif (i < 0xfb):
|
||||||
|
return int2byte(i)
|
||||||
|
elif (i < (1 << 16)):
|
||||||
|
return b'\xfc' + struct.pack('<H', i)
|
||||||
|
elif (i < (1 << 24)):
|
||||||
|
return b'\xfd' + struct.pack('<I', i)[:3]
|
||||||
|
elif (i < (1 << 64)):
|
||||||
|
return b'\xfe' + struct.pack('<Q', i)
|
||||||
|
else:
|
||||||
|
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
|
||||||
|
|
||||||
class MysqlPacket(object):
|
class MysqlPacket(object):
|
||||||
"""Representation of a MySQL response packet.
|
"""Representation of a MySQL response packet.
|
||||||
|
@ -303,6 +316,14 @@ class MysqlPacket(object):
|
||||||
self._position += 8
|
self._position += 8
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def read_string(self):
|
||||||
|
end_pos = self._data.find(b'\0', self._position)
|
||||||
|
if end_pos < 0:
|
||||||
|
return None
|
||||||
|
result = self._data[self._position:end_pos]
|
||||||
|
self._position = end_pos + 1
|
||||||
|
return result
|
||||||
|
|
||||||
def read_length_encoded_integer(self):
|
def read_length_encoded_integer(self):
|
||||||
"""Read a 'Length Coded Binary' number from the data buffer.
|
"""Read a 'Length Coded Binary' number from the data buffer.
|
||||||
|
|
||||||
|
@ -340,13 +361,18 @@ class MysqlPacket(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def is_ok_packet(self):
|
def is_ok_packet(self):
|
||||||
return self._data[0:1] == b'\0'
|
# https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
|
||||||
|
return self._data[0:1] == b'\0' and len(self._data) >= 7
|
||||||
|
|
||||||
def is_eof_packet(self):
|
def is_eof_packet(self):
|
||||||
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
|
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
|
||||||
# Caution: \xFE may be LengthEncodedInteger.
|
# Caution: \xFE may be LengthEncodedInteger.
|
||||||
# If \xFE is LengthEncodedInteger header, 8bytes followed.
|
# If \xFE is LengthEncodedInteger header, 8bytes followed.
|
||||||
return len(self._data) < 9 and self._data[0:1] == b'\xfe'
|
return self._data[0:1] == b'\xfe' and len(self._data) < 9
|
||||||
|
|
||||||
|
def is_auth_switch_request(self):
|
||||||
|
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
|
||||||
|
return self._data[0:1] == b'\xfe'
|
||||||
|
|
||||||
def is_resultset_packet(self):
|
def is_resultset_packet(self):
|
||||||
field_count = ord(self._data[0:1])
|
field_count = ord(self._data[0:1])
|
||||||
|
@ -379,9 +405,9 @@ class FieldDescriptorPacket(MysqlPacket):
|
||||||
|
|
||||||
def __init__(self, data, encoding):
|
def __init__(self, data, encoding):
|
||||||
MysqlPacket.__init__(self, data, encoding)
|
MysqlPacket.__init__(self, data, encoding)
|
||||||
self.__parse_field_descriptor(encoding)
|
self._parse_field_descriptor(encoding)
|
||||||
|
|
||||||
def __parse_field_descriptor(self, encoding):
|
def _parse_field_descriptor(self, encoding):
|
||||||
"""Parse the 'Field Descriptor' (Metadata) packet.
|
"""Parse the 'Field Descriptor' (Metadata) packet.
|
||||||
|
|
||||||
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
|
||||||
|
@ -494,20 +520,23 @@ class Connection(object):
|
||||||
|
|
||||||
The proper way to get an instance of this class is to call
|
The proper way to get an instance of this class is to call
|
||||||
connect().
|
connect().
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
socket = None
|
_sock = None
|
||||||
|
_auth_plugin_name = ''
|
||||||
|
_closed = False
|
||||||
|
|
||||||
def __init__(self, host="localhost", user=None, password="",
|
def __init__(self, host=None, user=None, password="",
|
||||||
database=None, port=3306, unix_socket=None,
|
database=None, port=0, unix_socket=None,
|
||||||
charset='', sql_mode=None,
|
charset='', sql_mode=None,
|
||||||
read_default_file=None, conv=decoders, use_unicode=None,
|
read_default_file=None, conv=None, use_unicode=None,
|
||||||
client_flag=0, cursorclass=Cursor, init_command=None,
|
client_flag=0, cursorclass=Cursor, init_command=None,
|
||||||
connect_timeout=None, ssl=None, read_default_group=None,
|
connect_timeout=10, ssl=None, read_default_group=None,
|
||||||
compress=None, named_pipe=None, no_delay=None,
|
compress=None, named_pipe=None, no_delay=None,
|
||||||
autocommit=False, db=None, passwd=None, local_infile=False,
|
autocommit=False, db=None, passwd=None, local_infile=False,
|
||||||
max_allowed_packet=16*1024*1024, defer_connect=False):
|
max_allowed_packet=16*1024*1024, defer_connect=False,
|
||||||
|
auth_plugin_map={}, read_timeout=None, write_timeout=None,
|
||||||
|
bind_address=None):
|
||||||
"""
|
"""
|
||||||
Establish a connection to the MySQL database. Accepts several
|
Establish a connection to the MySQL database. Accepts several
|
||||||
arguments:
|
arguments:
|
||||||
|
@ -516,15 +545,19 @@ class Connection(object):
|
||||||
user: Username to log in as
|
user: Username to log in as
|
||||||
password: Password to use.
|
password: Password to use.
|
||||||
database: Database to use, None to not use a particular one.
|
database: Database to use, None to not use a particular one.
|
||||||
port: MySQL port to use, default is usually OK.
|
port: MySQL port to use, default is usually OK. (default: 3306)
|
||||||
|
bind_address: When the client has multiple network interfaces, specify
|
||||||
|
the interface from which to connect to the host. Argument can be
|
||||||
|
a hostname or an IP address.
|
||||||
unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
|
unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
|
||||||
charset: Charset you want to use.
|
charset: Charset you want to use.
|
||||||
sql_mode: Default SQL_MODE to use.
|
sql_mode: Default SQL_MODE to use.
|
||||||
read_default_file:
|
read_default_file:
|
||||||
Specifies my.cnf file to read these parameters from under the [client] section.
|
Specifies my.cnf file to read these parameters from under the [client] section.
|
||||||
conv:
|
conv:
|
||||||
Decoders dictionary to use instead of the default one.
|
Conversion dictionary to use instead of the default one.
|
||||||
This is used to provide custom marshalling of types. See converters.
|
This is used to provide custom marshalling and unmarshaling of types.
|
||||||
|
See converters.
|
||||||
use_unicode:
|
use_unicode:
|
||||||
Whether or not to default to unicode strings.
|
Whether or not to default to unicode strings.
|
||||||
This option defaults to true for Py3k.
|
This option defaults to true for Py3k.
|
||||||
|
@ -532,27 +565,29 @@ class Connection(object):
|
||||||
cursorclass: Custom cursor class to use.
|
cursorclass: Custom cursor class to use.
|
||||||
init_command: Initial SQL statement to run when connection is established.
|
init_command: Initial SQL statement to run when connection is established.
|
||||||
connect_timeout: Timeout before throwing an exception when connecting.
|
connect_timeout: Timeout before throwing an exception when connecting.
|
||||||
|
(default: 10, min: 1, max: 31536000)
|
||||||
ssl:
|
ssl:
|
||||||
A dict of arguments similar to mysql_ssl_set()'s parameters.
|
A dict of arguments similar to mysql_ssl_set()'s parameters.
|
||||||
For now the capath and cipher arguments are not supported.
|
For now the capath and cipher arguments are not supported.
|
||||||
read_default_group: Group to read from in the configuration file.
|
read_default_group: Group to read from in the configuration file.
|
||||||
compress; Not supported
|
compress; Not supported
|
||||||
named_pipe: Not supported
|
named_pipe: Not supported
|
||||||
no_delay: Disable Nagle's algorithm on the socket. (deprecated, default: True)
|
|
||||||
autocommit: Autocommit mode. None means use server default. (default: False)
|
autocommit: Autocommit mode. None means use server default. (default: False)
|
||||||
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
|
||||||
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
|
||||||
|
Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
|
||||||
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
defer_connect: Don't explicitly connect on contruction - wait for connect call.
|
||||||
(default: False)
|
(default: False)
|
||||||
|
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
|
||||||
|
The class will take the Connection object as the argument to the constructor.
|
||||||
|
The class needs an authenticate method taking an authentication packet as
|
||||||
|
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
|
||||||
|
(if no authenticate method) for returning a string from the user. (experimental)
|
||||||
db: Alias for database. (for compatibility to MySQLdb)
|
db: Alias for database. (for compatibility to MySQLdb)
|
||||||
passwd: Alias for password. (for compatibility to MySQLdb)
|
passwd: Alias for password. (for compatibility to MySQLdb)
|
||||||
"""
|
"""
|
||||||
if no_delay is not None:
|
if no_delay is not None:
|
||||||
warnings.warn("no_delay option is deprecated", DeprecationWarning)
|
warnings.warn("no_delay option is deprecated", DeprecationWarning)
|
||||||
no_delay = bool(no_delay)
|
|
||||||
else:
|
|
||||||
no_delay = True
|
|
||||||
|
|
||||||
if use_unicode is None and sys.version_info[0] > 2:
|
if use_unicode is None and sys.version_info[0] > 2:
|
||||||
use_unicode = True
|
use_unicode = True
|
||||||
|
@ -565,24 +600,10 @@ class Connection(object):
|
||||||
if compress or named_pipe:
|
if compress or named_pipe:
|
||||||
raise NotImplementedError("compress and named_pipe arguments are not supported")
|
raise NotImplementedError("compress and named_pipe arguments are not supported")
|
||||||
|
|
||||||
if local_infile:
|
self._local_infile = bool(local_infile)
|
||||||
|
if self._local_infile:
|
||||||
client_flag |= CLIENT.LOCAL_FILES
|
client_flag |= CLIENT.LOCAL_FILES
|
||||||
|
|
||||||
if ssl and ('capath' in ssl or 'cipher' in ssl):
|
|
||||||
raise NotImplementedError('ssl options capath and cipher are not supported')
|
|
||||||
|
|
||||||
self.ssl = False
|
|
||||||
if ssl:
|
|
||||||
if not SSL_ENABLED:
|
|
||||||
raise NotImplementedError("ssl module not found")
|
|
||||||
self.ssl = True
|
|
||||||
client_flag |= CLIENT.SSL
|
|
||||||
for k in ('key', 'cert', 'ca'):
|
|
||||||
v = None
|
|
||||||
if k in ssl:
|
|
||||||
v = ssl[k]
|
|
||||||
setattr(self, k, v)
|
|
||||||
|
|
||||||
if read_default_group and not read_default_file:
|
if read_default_group and not read_default_file:
|
||||||
if sys.platform.startswith("win"):
|
if sys.platform.startswith("win"):
|
||||||
read_default_file = "c:\\my.ini"
|
read_default_file = "c:\\my.ini"
|
||||||
|
@ -610,15 +631,40 @@ class Connection(object):
|
||||||
database = _config("database", database)
|
database = _config("database", database)
|
||||||
unix_socket = _config("socket", unix_socket)
|
unix_socket = _config("socket", unix_socket)
|
||||||
port = int(_config("port", port))
|
port = int(_config("port", port))
|
||||||
|
bind_address = _config("bind-address", bind_address)
|
||||||
charset = _config("default-character-set", charset)
|
charset = _config("default-character-set", charset)
|
||||||
|
if not ssl:
|
||||||
|
ssl = {}
|
||||||
|
if isinstance(ssl, dict):
|
||||||
|
for key in ["ca", "capath", "cert", "key", "cipher"]:
|
||||||
|
value = _config("ssl-" + key, ssl.get(key))
|
||||||
|
if value:
|
||||||
|
ssl[key] = value
|
||||||
|
|
||||||
self.host = host
|
self.ssl = False
|
||||||
self.port = port
|
if ssl:
|
||||||
|
if not SSL_ENABLED:
|
||||||
|
raise NotImplementedError("ssl module not found")
|
||||||
|
self.ssl = True
|
||||||
|
client_flag |= CLIENT.SSL
|
||||||
|
self.ctx = self._create_ssl_ctx(ssl)
|
||||||
|
|
||||||
|
self.host = host or "localhost"
|
||||||
|
self.port = port or 3306
|
||||||
self.user = user or DEFAULT_USER
|
self.user = user or DEFAULT_USER
|
||||||
self.password = password or ""
|
self.password = password or ""
|
||||||
self.db = database
|
self.db = database
|
||||||
self.no_delay = no_delay
|
|
||||||
self.unix_socket = unix_socket
|
self.unix_socket = unix_socket
|
||||||
|
self.bind_address = bind_address
|
||||||
|
if not (0 < connect_timeout <= 31536000):
|
||||||
|
raise ValueError("connect_timeout should be >0 and <=31536000")
|
||||||
|
self.connect_timeout = connect_timeout or None
|
||||||
|
if read_timeout is not None and read_timeout <= 0:
|
||||||
|
raise ValueError("read_timeout should be >= 0")
|
||||||
|
self._read_timeout = read_timeout
|
||||||
|
if write_timeout is not None and write_timeout <= 0:
|
||||||
|
raise ValueError("write_timeout should be >= 0")
|
||||||
|
self._write_timeout = write_timeout
|
||||||
if charset:
|
if charset:
|
||||||
self.charset = charset
|
self.charset = charset
|
||||||
self.use_unicode = True
|
self.use_unicode = True
|
||||||
|
@ -631,13 +677,12 @@ class Connection(object):
|
||||||
|
|
||||||
self.encoding = charset_by_name(self.charset).encoding
|
self.encoding = charset_by_name(self.charset).encoding
|
||||||
|
|
||||||
client_flag |= CLIENT.CAPABILITIES | CLIENT.MULTI_STATEMENTS
|
client_flag |= CLIENT.CAPABILITIES
|
||||||
if self.db:
|
if self.db:
|
||||||
client_flag |= CLIENT.CONNECT_WITH_DB
|
client_flag |= CLIENT.CONNECT_WITH_DB
|
||||||
self.client_flag = client_flag
|
self.client_flag = client_flag
|
||||||
|
|
||||||
self.cursorclass = cursorclass
|
self.cursorclass = cursorclass
|
||||||
self.connect_timeout = connect_timeout
|
|
||||||
|
|
||||||
self._result = None
|
self._result = None
|
||||||
self._affected_rows = 0
|
self._affected_rows = 0
|
||||||
|
@ -646,44 +691,68 @@ class Connection(object):
|
||||||
#: specified autocommit mode. None means use server default.
|
#: specified autocommit mode. None means use server default.
|
||||||
self.autocommit_mode = autocommit
|
self.autocommit_mode = autocommit
|
||||||
|
|
||||||
self.encoders = encoders # Need for MySQLdb compatibility.
|
if conv is None:
|
||||||
self.decoders = conv
|
conv = _conv
|
||||||
|
# Need for MySQLdb compatibility.
|
||||||
|
self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
|
||||||
|
self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
|
||||||
self.sql_mode = sql_mode
|
self.sql_mode = sql_mode
|
||||||
self.init_command = init_command
|
self.init_command = init_command
|
||||||
self.max_allowed_packet = max_allowed_packet
|
self.max_allowed_packet = max_allowed_packet
|
||||||
|
self._auth_plugin_map = auth_plugin_map
|
||||||
if defer_connect:
|
if defer_connect:
|
||||||
self.socket = None
|
self._sock = None
|
||||||
else:
|
else:
|
||||||
self.connect()
|
self.connect()
|
||||||
|
|
||||||
|
def _create_ssl_ctx(self, sslp):
|
||||||
|
if isinstance(sslp, ssl.SSLContext):
|
||||||
|
return sslp
|
||||||
|
ca = sslp.get('ca')
|
||||||
|
capath = sslp.get('capath')
|
||||||
|
hasnoca = ca is None and capath is None
|
||||||
|
ctx = ssl.create_default_context(cafile=ca, capath=capath)
|
||||||
|
ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
|
||||||
|
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
|
||||||
|
if 'cert' in sslp:
|
||||||
|
ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
|
||||||
|
if 'cipher' in sslp:
|
||||||
|
ctx.set_ciphers(sslp['cipher'])
|
||||||
|
ctx.options |= ssl.OP_NO_SSLv2
|
||||||
|
ctx.options |= ssl.OP_NO_SSLv3
|
||||||
|
return ctx
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Send the quit message and close the socket"""
|
"""Send the quit message and close the socket"""
|
||||||
if self.socket is None:
|
if self._closed:
|
||||||
raise err.Error("Already closed")
|
raise err.Error("Already closed")
|
||||||
|
self._closed = True
|
||||||
|
if self._sock is None:
|
||||||
|
return
|
||||||
send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
|
send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
|
||||||
try:
|
try:
|
||||||
self._write_bytes(send_data)
|
self._write_bytes(send_data)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
sock = self.socket
|
self._force_close()
|
||||||
self.socket = None
|
|
||||||
self._rfile = None
|
|
||||||
sock.close()
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def open(self):
|
def open(self):
|
||||||
return self.socket is not None
|
return self._sock is not None
|
||||||
|
|
||||||
def __del__(self):
|
def _force_close(self):
|
||||||
if self.socket:
|
"""Close connection without QUIT message"""
|
||||||
|
if self._sock:
|
||||||
try:
|
try:
|
||||||
self.socket.close()
|
self._sock.close()
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
self.socket = None
|
self._sock = None
|
||||||
self._rfile = None
|
self._rfile = None
|
||||||
|
|
||||||
|
__del__ = _force_close
|
||||||
|
|
||||||
def autocommit(self, value):
|
def autocommit(self, value):
|
||||||
self.autocommit_mode = bool(value)
|
self.autocommit_mode = bool(value)
|
||||||
current = self.get_autocommit()
|
current = self.get_autocommit()
|
||||||
|
@ -731,19 +800,25 @@ class Connection(object):
|
||||||
return result.rows
|
return result.rows
|
||||||
|
|
||||||
def select_db(self, db):
|
def select_db(self, db):
|
||||||
'''Set current db'''
|
"""Set current db"""
|
||||||
self._execute_command(COMMAND.COM_INIT_DB, db)
|
self._execute_command(COMMAND.COM_INIT_DB, db)
|
||||||
self._read_ok_packet()
|
self._read_ok_packet()
|
||||||
|
|
||||||
def escape(self, obj, mapping=None):
|
def escape(self, obj, mapping=None):
|
||||||
"""Escape whatever value you pass to it"""
|
"""Escape whatever value you pass to it.
|
||||||
|
|
||||||
|
Non-standard, for internal use; do not use this in your applications.
|
||||||
|
"""
|
||||||
if isinstance(obj, str_type):
|
if isinstance(obj, str_type):
|
||||||
return "'" + self.escape_string(obj) + "'"
|
return "'" + self.escape_string(obj) + "'"
|
||||||
return escape_item(obj, self.charset, mapping=mapping)
|
return escape_item(obj, self.charset, mapping=mapping)
|
||||||
|
|
||||||
def literal(self, obj):
|
def literal(self, obj):
|
||||||
'''Alias for escape()'''
|
"""Alias for escape()
|
||||||
return self.escape(obj)
|
|
||||||
|
Non-standard, for internal use; do not use this in your applications.
|
||||||
|
"""
|
||||||
|
return self.escape(obj, self.encoders)
|
||||||
|
|
||||||
def escape_string(self, s):
|
def escape_string(self, s):
|
||||||
if (self.server_status &
|
if (self.server_status &
|
||||||
|
@ -795,7 +870,7 @@ class Connection(object):
|
||||||
|
|
||||||
def ping(self, reconnect=True):
|
def ping(self, reconnect=True):
|
||||||
"""Check if the server is alive"""
|
"""Check if the server is alive"""
|
||||||
if self.socket is None:
|
if self._sock is None:
|
||||||
if reconnect:
|
if reconnect:
|
||||||
self.connect()
|
self.connect()
|
||||||
reconnect = False
|
reconnect = False
|
||||||
|
@ -821,6 +896,7 @@ class Connection(object):
|
||||||
self.encoding = encoding
|
self.encoding = encoding
|
||||||
|
|
||||||
def connect(self, sock=None):
|
def connect(self, sock=None):
|
||||||
|
self._closed = False
|
||||||
try:
|
try:
|
||||||
if sock is None:
|
if sock is None:
|
||||||
if self.unix_socket and self.host in ('localhost', '127.0.0.1'):
|
if self.unix_socket and self.host in ('localhost', '127.0.0.1'):
|
||||||
|
@ -830,10 +906,14 @@ class Connection(object):
|
||||||
self.host_info = "Localhost via UNIX socket"
|
self.host_info = "Localhost via UNIX socket"
|
||||||
if DEBUG: print('connected using unix_socket')
|
if DEBUG: print('connected using unix_socket')
|
||||||
else:
|
else:
|
||||||
|
kwargs = {}
|
||||||
|
if self.bind_address is not None:
|
||||||
|
kwargs['source_address'] = (self.bind_address, 0)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
sock = socket.create_connection(
|
sock = socket.create_connection(
|
||||||
(self.host, self.port), self.connect_timeout)
|
(self.host, self.port), self.connect_timeout,
|
||||||
|
**kwargs)
|
||||||
break
|
break
|
||||||
except (OSError, IOError) as e:
|
except (OSError, IOError) as e:
|
||||||
if e.errno == errno.EINTR:
|
if e.errno == errno.EINTR:
|
||||||
|
@ -841,12 +921,13 @@ class Connection(object):
|
||||||
raise
|
raise
|
||||||
self.host_info = "socket %s:%d" % (self.host, self.port)
|
self.host_info = "socket %s:%d" % (self.host, self.port)
|
||||||
if DEBUG: print('connected using socket')
|
if DEBUG: print('connected using socket')
|
||||||
if self.no_delay:
|
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||||
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
sock.settimeout(None)
|
||||||
# sock.settimeout(None)
|
|
||||||
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
|
||||||
self.socket = sock
|
self._sock = sock
|
||||||
self._rfile = _makefile(sock, 'rb')
|
self._rfile = _makefile(sock, 'rb')
|
||||||
|
self._next_seq_id = 0
|
||||||
|
|
||||||
self._get_server_information()
|
self._get_server_information()
|
||||||
self._request_authentication()
|
self._request_authentication()
|
||||||
|
|
||||||
|
@ -886,6 +967,17 @@ class Connection(object):
|
||||||
# So just reraise it.
|
# So just reraise it.
|
||||||
raise
|
raise
|
||||||
|
|
||||||
|
def write_packet(self, payload):
|
||||||
|
"""Writes an entire "mysql packet" in its entirety to the network
|
||||||
|
addings its length and sequence number.
|
||||||
|
"""
|
||||||
|
# Internal note: when you build packet manualy and calls _write_bytes()
|
||||||
|
# directly, you should set self._next_seq_id properly.
|
||||||
|
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
|
||||||
|
if DEBUG: dump_packet(data)
|
||||||
|
self._write_bytes(data)
|
||||||
|
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||||
|
|
||||||
def _read_packet(self, packet_type=MysqlPacket):
|
def _read_packet(self, packet_type=MysqlPacket):
|
||||||
"""Read an entire "mysql packet" in its entirety from the network
|
"""Read an entire "mysql packet" in its entirety from the network
|
||||||
and return a MysqlPacket type that represents the results.
|
and return a MysqlPacket type that represents the results.
|
||||||
|
@ -894,19 +986,36 @@ class Connection(object):
|
||||||
while True:
|
while True:
|
||||||
packet_header = self._read_bytes(4)
|
packet_header = self._read_bytes(4)
|
||||||
if DEBUG: dump_packet(packet_header)
|
if DEBUG: dump_packet(packet_header)
|
||||||
|
|
||||||
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
|
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
|
||||||
bytes_to_read = btrl + (btrh << 16)
|
bytes_to_read = btrl + (btrh << 16)
|
||||||
#TODO: check sequence id
|
if packet_number != self._next_seq_id:
|
||||||
|
self._force_close()
|
||||||
|
if packet_number == 0:
|
||||||
|
# MariaDB sends error packet with seqno==0 when shutdown
|
||||||
|
raise err.OperationalError(
|
||||||
|
CR.CR_SERVER_LOST,
|
||||||
|
"Lost connection to MySQL server during query")
|
||||||
|
raise err.InternalError(
|
||||||
|
"Packet sequence number wrong - got %d expected %d"
|
||||||
|
% (packet_number, self._next_seq_id))
|
||||||
|
self._next_seq_id = (self._next_seq_id + 1) % 256
|
||||||
|
|
||||||
recv_data = self._read_bytes(bytes_to_read)
|
recv_data = self._read_bytes(bytes_to_read)
|
||||||
if DEBUG: dump_packet(recv_data)
|
if DEBUG: dump_packet(recv_data)
|
||||||
buff += recv_data
|
buff += recv_data
|
||||||
|
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
|
||||||
|
if bytes_to_read == 0xffffff:
|
||||||
|
continue
|
||||||
if bytes_to_read < MAX_PACKET_LEN:
|
if bytes_to_read < MAX_PACKET_LEN:
|
||||||
break
|
break
|
||||||
|
|
||||||
packet = packet_type(buff, self.encoding)
|
packet = packet_type(buff, self.encoding)
|
||||||
packet.check_error()
|
packet.check_error()
|
||||||
return packet
|
return packet
|
||||||
|
|
||||||
def _read_bytes(self, num_bytes):
|
def _read_bytes(self, num_bytes):
|
||||||
|
self._sock.settimeout(self._read_timeout)
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
data = self._rfile.read(num_bytes)
|
data = self._rfile.read(num_bytes)
|
||||||
|
@ -914,19 +1023,25 @@ class Connection(object):
|
||||||
except (IOError, OSError) as e:
|
except (IOError, OSError) as e:
|
||||||
if e.errno == errno.EINTR:
|
if e.errno == errno.EINTR:
|
||||||
continue
|
continue
|
||||||
|
self._force_close()
|
||||||
raise err.OperationalError(
|
raise err.OperationalError(
|
||||||
2013,
|
CR.CR_SERVER_LOST,
|
||||||
"Lost connection to MySQL server during query (%s)" % (e,))
|
"Lost connection to MySQL server during query (%s)" % (e,))
|
||||||
if len(data) < num_bytes:
|
if len(data) < num_bytes:
|
||||||
|
self._force_close()
|
||||||
raise err.OperationalError(
|
raise err.OperationalError(
|
||||||
2013, "Lost connection to MySQL server during query")
|
CR.CR_SERVER_LOST, "Lost connection to MySQL server during query")
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def _write_bytes(self, data):
|
def _write_bytes(self, data):
|
||||||
|
self._sock.settimeout(self._write_timeout)
|
||||||
try:
|
try:
|
||||||
self.socket.sendall(data)
|
self._sock.sendall(data)
|
||||||
except IOError as e:
|
except IOError as e:
|
||||||
raise err.OperationalError(2006, "MySQL server has gone away (%r)" % (e,))
|
self._force_close()
|
||||||
|
raise err.OperationalError(
|
||||||
|
CR.CR_SERVER_GONE_ERROR,
|
||||||
|
"MySQL server has gone away (%r)" % (e,))
|
||||||
|
|
||||||
def _read_query_result(self, unbuffered=False):
|
def _read_query_result(self, unbuffered=False):
|
||||||
if unbuffered:
|
if unbuffered:
|
||||||
|
@ -952,42 +1067,45 @@ class Connection(object):
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def _execute_command(self, command, sql):
|
def _execute_command(self, command, sql):
|
||||||
if not self.socket:
|
if not self._sock:
|
||||||
raise err.InterfaceError("(0, '')")
|
raise err.InterfaceError("(0, '')")
|
||||||
|
|
||||||
# If the last query was unbuffered, make sure it finishes before
|
# If the last query was unbuffered, make sure it finishes before
|
||||||
# sending new commands
|
# sending new commands
|
||||||
if self._result is not None and self._result.unbuffered_active:
|
if self._result is not None:
|
||||||
warnings.warn("Previous unbuffered result was left incomplete")
|
if self._result.unbuffered_active:
|
||||||
self._result._finish_unbuffered_query()
|
warnings.warn("Previous unbuffered result was left incomplete")
|
||||||
|
self._result._finish_unbuffered_query()
|
||||||
|
while self._result.has_next:
|
||||||
|
self.next_result()
|
||||||
|
self._result = None
|
||||||
|
|
||||||
if isinstance(sql, text_type):
|
if isinstance(sql, text_type):
|
||||||
sql = sql.encode(self.encoding)
|
sql = sql.encode(self.encoding)
|
||||||
|
|
||||||
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
|
packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command
|
||||||
|
|
||||||
prelude = struct.pack('<iB', chunk_size, command)
|
# tiny optimization: build first packet manually instead of
|
||||||
self._write_bytes(prelude + sql[:chunk_size-1])
|
# calling self..write_packet()
|
||||||
if DEBUG: dump_packet(prelude + sql)
|
prelude = struct.pack('<iB', packet_size, command)
|
||||||
|
packet = prelude + sql[:packet_size-1]
|
||||||
|
self._write_bytes(packet)
|
||||||
|
if DEBUG: dump_packet(packet)
|
||||||
|
self._next_seq_id = 1
|
||||||
|
|
||||||
if chunk_size < self.max_allowed_packet:
|
if packet_size < MAX_PACKET_LEN:
|
||||||
return
|
return
|
||||||
|
|
||||||
seq_id = 1
|
sql = sql[packet_size-1:]
|
||||||
sql = sql[chunk_size-1:]
|
|
||||||
while True:
|
while True:
|
||||||
chunk_size = min(self.max_allowed_packet, len(sql))
|
packet_size = min(MAX_PACKET_LEN, len(sql))
|
||||||
prelude = struct.pack('<i', chunk_size)[:3]
|
self.write_packet(sql[:packet_size])
|
||||||
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
|
sql = sql[packet_size:]
|
||||||
self._write_bytes(data)
|
if not sql and packet_size < MAX_PACKET_LEN:
|
||||||
if DEBUG: dump_packet(data)
|
|
||||||
sql = sql[chunk_size:]
|
|
||||||
if not sql and chunk_size < self.max_allowed_packet:
|
|
||||||
break
|
break
|
||||||
seq_id += 1
|
|
||||||
|
|
||||||
def _request_authentication(self):
|
def _request_authentication(self):
|
||||||
self.client_flag |= CLIENT.CAPABILITIES
|
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
|
||||||
if int(self.server_version.split('.', 1)[0]) >= 5:
|
if int(self.server_version.split('.', 1)[0]) >= 5:
|
||||||
self.client_flag |= CLIENT.MULTI_RESULTS
|
self.client_flag |= CLIENT.MULTI_RESULTS
|
||||||
|
|
||||||
|
@ -1000,48 +1118,114 @@ class Connection(object):
|
||||||
|
|
||||||
data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')
|
data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')
|
||||||
|
|
||||||
next_packet = 1
|
if self.ssl and self.server_capabilities & CLIENT.SSL:
|
||||||
|
self.write_packet(data_init)
|
||||||
|
|
||||||
if self.ssl:
|
self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
|
||||||
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
|
self._rfile = _makefile(self._sock, 'rb')
|
||||||
next_packet += 1
|
|
||||||
|
|
||||||
if DEBUG: dump_packet(data)
|
data = data_init + self.user + b'\0'
|
||||||
self._write_bytes(data)
|
|
||||||
|
|
||||||
cert_reqs = ssl.CERT_NONE if self.ca is None else ssl.CERT_REQUIRED
|
authresp = b''
|
||||||
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
|
if self._auth_plugin_name in ('', 'mysql_native_password'):
|
||||||
certfile=self.cert,
|
authresp = _scramble(self.password.encode('latin1'), self.salt)
|
||||||
ssl_version=ssl.PROTOCOL_TLSv1,
|
|
||||||
cert_reqs=cert_reqs,
|
|
||||||
ca_certs=self.ca)
|
|
||||||
self._rfile = _makefile(self.socket, 'rb')
|
|
||||||
|
|
||||||
data = data_init + self.user + b'\0' + \
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
|
||||||
_scramble(self.password.encode('latin1'), self.salt)
|
data += lenenc_int(len(authresp)) + authresp
|
||||||
|
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
|
||||||
|
data += struct.pack('B', len(authresp)) + authresp
|
||||||
|
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
|
||||||
|
data += authresp + b'\0'
|
||||||
|
|
||||||
if self.db:
|
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
|
||||||
if isinstance(self.db, text_type):
|
if isinstance(self.db, text_type):
|
||||||
self.db = self.db.encode(self.encoding)
|
self.db = self.db.encode(self.encoding)
|
||||||
data += self.db + int2byte(0)
|
data += self.db + b'\0'
|
||||||
|
|
||||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
|
||||||
next_packet += 2
|
name = self._auth_plugin_name
|
||||||
|
if isinstance(name, text_type):
|
||||||
if DEBUG: dump_packet(data)
|
name = name.encode('ascii')
|
||||||
self._write_bytes(data)
|
data += name + b'\0'
|
||||||
|
|
||||||
|
self.write_packet(data)
|
||||||
auth_packet = self._read_packet()
|
auth_packet = self._read_packet()
|
||||||
|
|
||||||
# if old_passwords is enabled the packet will be 1 byte long and
|
# if authentication method isn't accepted the first byte
|
||||||
# have the octet 254
|
# will have the octet 254
|
||||||
|
if auth_packet.is_auth_switch_request():
|
||||||
|
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
|
||||||
|
auth_packet.read_uint8() # 0xfe packet identifier
|
||||||
|
plugin_name = auth_packet.read_string()
|
||||||
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
|
||||||
|
auth_packet = self._process_auth(plugin_name, auth_packet)
|
||||||
|
else:
|
||||||
|
# send legacy handshake
|
||||||
|
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
|
||||||
|
self.write_packet(data)
|
||||||
|
auth_packet = self._read_packet()
|
||||||
|
|
||||||
if auth_packet.is_eof_packet():
|
def _process_auth(self, plugin_name, auth_packet):
|
||||||
# send legacy handshake
|
plugin_class = self._auth_plugin_map.get(plugin_name)
|
||||||
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
|
if not plugin_class:
|
||||||
data = pack_int24(len(data)) + int2byte(next_packet) + data
|
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
|
||||||
self._write_bytes(data)
|
if plugin_class:
|
||||||
auth_packet = self._read_packet()
|
try:
|
||||||
|
handler = plugin_class(self)
|
||||||
|
return handler.authenticate(auth_packet)
|
||||||
|
except AttributeError:
|
||||||
|
if plugin_name != b'dialog':
|
||||||
|
raise err.OperationalError(2059, "Authentication plugin '%s'" \
|
||||||
|
" not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
|
||||||
|
except TypeError:
|
||||||
|
raise err.OperationalError(2059, "Authentication plugin '%s'" \
|
||||||
|
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
|
||||||
|
else:
|
||||||
|
handler = None
|
||||||
|
if plugin_name == b"mysql_native_password":
|
||||||
|
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
|
||||||
|
data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
|
||||||
|
elif plugin_name == b"mysql_old_password":
|
||||||
|
# https://dev.mysql.com/doc/internals/en/old-password-authentication.html
|
||||||
|
data = _scramble_323(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
|
||||||
|
elif plugin_name == b"mysql_clear_password":
|
||||||
|
# https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
|
||||||
|
data = self.password.encode('latin1') + b'\0'
|
||||||
|
elif plugin_name == b"dialog":
|
||||||
|
pkt = auth_packet
|
||||||
|
while True:
|
||||||
|
flag = pkt.read_uint8()
|
||||||
|
echo = (flag & 0x06) == 0x02
|
||||||
|
last = (flag & 0x01) == 0x01
|
||||||
|
prompt = pkt.read_all()
|
||||||
|
|
||||||
|
if prompt == b"Password: ":
|
||||||
|
self.write_packet(self.password.encode('latin1') + b'\0')
|
||||||
|
elif handler:
|
||||||
|
resp = 'no response - TypeError within plugin.prompt method'
|
||||||
|
try:
|
||||||
|
resp = handler.prompt(echo, prompt)
|
||||||
|
self.write_packet(resp + b'\0')
|
||||||
|
except AttributeError:
|
||||||
|
raise err.OperationalError(2059, "Authentication plugin '%s'" \
|
||||||
|
" not loaded: - %r missing prompt method" % (plugin_name, handler))
|
||||||
|
except TypeError:
|
||||||
|
raise err.OperationalError(2061, "Authentication plugin '%s'" \
|
||||||
|
" %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
|
||||||
|
else:
|
||||||
|
raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
|
||||||
|
pkt = self._read_packet()
|
||||||
|
pkt.check_error()
|
||||||
|
if pkt.is_ok_packet() or last:
|
||||||
|
break
|
||||||
|
return pkt
|
||||||
|
else:
|
||||||
|
raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)
|
||||||
|
|
||||||
|
self.write_packet(data)
|
||||||
|
pkt = self._read_packet()
|
||||||
|
pkt.check_error()
|
||||||
|
return pkt
|
||||||
|
|
||||||
# _mysql support
|
# _mysql support
|
||||||
def thread_id(self):
|
def thread_id(self):
|
||||||
|
@ -1065,7 +1249,7 @@ class Connection(object):
|
||||||
self.protocol_version = byte2int(data[i:i+1])
|
self.protocol_version = byte2int(data[i:i+1])
|
||||||
i += 1
|
i += 1
|
||||||
|
|
||||||
server_end = data.find(int2byte(0), i)
|
server_end = data.find(b'\0', i)
|
||||||
self.server_version = data[i:server_end].decode('latin1')
|
self.server_version = data[i:server_end].decode('latin1')
|
||||||
i = server_end + 1
|
i = server_end + 1
|
||||||
|
|
||||||
|
@ -1097,7 +1281,22 @@ class Connection(object):
|
||||||
if len(data) >= i + salt_len:
|
if len(data) >= i + salt_len:
|
||||||
# salt_len includes auth_plugin_data_part_1 and filler
|
# salt_len includes auth_plugin_data_part_1 and filler
|
||||||
self.salt += data[i:i+salt_len]
|
self.salt += data[i:i+salt_len]
|
||||||
# TODO: AUTH PLUGIN NAME may appeare here.
|
i += salt_len
|
||||||
|
|
||||||
|
i+=1
|
||||||
|
# AUTH PLUGIN NAME may appear here.
|
||||||
|
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
|
||||||
|
# Due to Bug#59453 the auth-plugin-name is missing the terminating
|
||||||
|
# NUL-char in versions prior to 5.5.10 and 5.6.2.
|
||||||
|
# ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
|
||||||
|
# didn't use version checks as mariadb is corrected and reports
|
||||||
|
# earlier than those two.
|
||||||
|
server_end = data.find(b'\0', i)
|
||||||
|
if server_end < 0: # pragma: no cover - very specific upstream bug
|
||||||
|
# not found \0 and last field so take it all
|
||||||
|
self._auth_plugin_name = data[i:].decode('latin1')
|
||||||
|
else:
|
||||||
|
self._auth_plugin_name = data[i:server_end].decode('latin1')
|
||||||
|
|
||||||
def get_server_info(self):
|
def get_server_info(self):
|
||||||
return self.server_version
|
return self.server_version
|
||||||
|
@ -1117,6 +1316,9 @@ class Connection(object):
|
||||||
class MySQLResult(object):
|
class MySQLResult(object):
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
|
"""
|
||||||
|
:type connection: Connection
|
||||||
|
"""
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.affected_rows = None
|
self.affected_rows = None
|
||||||
self.insert_id = None
|
self.insert_id = None
|
||||||
|
@ -1144,7 +1346,7 @@ class MySQLResult(object):
|
||||||
else:
|
else:
|
||||||
self._read_result_packet(first_packet)
|
self._read_result_packet(first_packet)
|
||||||
finally:
|
finally:
|
||||||
self.connection = False
|
self.connection = None
|
||||||
|
|
||||||
def init_unbuffered_query(self):
|
def init_unbuffered_query(self):
|
||||||
self.unbuffered_active = True
|
self.unbuffered_active = True
|
||||||
|
@ -1154,6 +1356,10 @@ class MySQLResult(object):
|
||||||
self._read_ok_packet(first_packet)
|
self._read_ok_packet(first_packet)
|
||||||
self.unbuffered_active = False
|
self.unbuffered_active = False
|
||||||
self.connection = None
|
self.connection = None
|
||||||
|
elif first_packet.is_load_local_packet():
|
||||||
|
self._read_load_local_packet(first_packet)
|
||||||
|
self.unbuffered_active = False
|
||||||
|
self.connection = None
|
||||||
else:
|
else:
|
||||||
self.field_count = first_packet.read_length_encoded_integer()
|
self.field_count = first_packet.read_length_encoded_integer()
|
||||||
self._get_descriptions()
|
self._get_descriptions()
|
||||||
|
@ -1173,22 +1379,33 @@ class MySQLResult(object):
|
||||||
self.has_next = ok_packet.has_next
|
self.has_next = ok_packet.has_next
|
||||||
|
|
||||||
def _read_load_local_packet(self, first_packet):
|
def _read_load_local_packet(self, first_packet):
|
||||||
|
if not self.connection._local_infile:
|
||||||
|
raise RuntimeError(
|
||||||
|
"**WARN**: Received LOAD_LOCAL packet but local_infile option is false.")
|
||||||
load_packet = LoadLocalPacketWrapper(first_packet)
|
load_packet = LoadLocalPacketWrapper(first_packet)
|
||||||
sender = LoadLocalFile(load_packet.filename, self.connection)
|
sender = LoadLocalFile(load_packet.filename, self.connection)
|
||||||
sender.send_data()
|
try:
|
||||||
|
sender.send_data()
|
||||||
|
except:
|
||||||
|
self.connection._read_packet() # skip ok packet
|
||||||
|
raise
|
||||||
|
|
||||||
ok_packet = self.connection._read_packet()
|
ok_packet = self.connection._read_packet()
|
||||||
if not ok_packet.is_ok_packet():
|
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
|
||||||
raise err.OperationalError(2014, "Commands Out of Sync")
|
raise err.OperationalError(2014, "Commands Out of Sync")
|
||||||
self._read_ok_packet(ok_packet)
|
self._read_ok_packet(ok_packet)
|
||||||
|
|
||||||
def _check_packet_is_eof(self, packet):
|
def _check_packet_is_eof(self, packet):
|
||||||
if packet.is_eof_packet():
|
if not packet.is_eof_packet():
|
||||||
eof_packet = EOFPacketWrapper(packet)
|
return False
|
||||||
self.warning_count = eof_packet.warning_count
|
#TODO: Support CLIENT.DEPRECATE_EOF
|
||||||
self.has_next = eof_packet.has_next
|
# 1) Add DEPRECATE_EOF to CAPABILITIES
|
||||||
return True
|
# 2) Mask CAPABILITIES with server_capabilities
|
||||||
return False
|
# 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
|
||||||
|
wp = EOFPacketWrapper(packet)
|
||||||
|
self.warning_count = wp.warning_count
|
||||||
|
self.has_next = wp.has_next
|
||||||
|
return True
|
||||||
|
|
||||||
def _read_result_packet(self, first_packet):
|
def _read_result_packet(self, first_packet):
|
||||||
self.field_count = first_packet.read_length_encoded_integer()
|
self.field_count = first_packet.read_length_encoded_integer()
|
||||||
|
@ -1239,7 +1456,12 @@ class MySQLResult(object):
|
||||||
def _read_row_from_packet(self, packet):
|
def _read_row_from_packet(self, packet):
|
||||||
row = []
|
row = []
|
||||||
for encoding, converter in self.converters:
|
for encoding, converter in self.converters:
|
||||||
data = packet.read_length_coded_string()
|
try:
|
||||||
|
data = packet.read_length_coded_string()
|
||||||
|
except IndexError:
|
||||||
|
# No more columns in this row
|
||||||
|
# See https://github.com/PyMySQL/PyMySQL/pull/434
|
||||||
|
break
|
||||||
if data is not None:
|
if data is not None:
|
||||||
if encoding is not None:
|
if encoding is not None:
|
||||||
data = data.decode(encoding)
|
data = data.decode(encoding)
|
||||||
|
@ -1254,21 +1476,30 @@ class MySQLResult(object):
|
||||||
self.fields = []
|
self.fields = []
|
||||||
self.converters = []
|
self.converters = []
|
||||||
use_unicode = self.connection.use_unicode
|
use_unicode = self.connection.use_unicode
|
||||||
|
conn_encoding = self.connection.encoding
|
||||||
description = []
|
description = []
|
||||||
|
|
||||||
for i in range_type(self.field_count):
|
for i in range_type(self.field_count):
|
||||||
field = self.connection._read_packet(FieldDescriptorPacket)
|
field = self.connection._read_packet(FieldDescriptorPacket)
|
||||||
self.fields.append(field)
|
self.fields.append(field)
|
||||||
description.append(field.description())
|
description.append(field.description())
|
||||||
field_type = field.type_code
|
field_type = field.type_code
|
||||||
if use_unicode:
|
if use_unicode:
|
||||||
if field_type in TEXT_TYPES:
|
if field_type == FIELD_TYPE.JSON:
|
||||||
charset = charset_by_id(field.charsetnr)
|
# When SELECT from JSON column: charset = binary
|
||||||
if charset.is_binary:
|
# When SELECT CAST(... AS JSON): charset = connection encoding
|
||||||
|
# This behavior is different from TEXT / BLOB.
|
||||||
|
# We should decode result by connection encoding regardless charsetnr.
|
||||||
|
# See https://github.com/PyMySQL/PyMySQL/issues/488
|
||||||
|
encoding = conn_encoding # SELECT CAST(... AS JSON)
|
||||||
|
elif field_type in TEXT_TYPES:
|
||||||
|
if field.charsetnr == 63: # binary
|
||||||
# TEXTs with charset=binary means BINARY types.
|
# TEXTs with charset=binary means BINARY types.
|
||||||
encoding = None
|
encoding = None
|
||||||
else:
|
else:
|
||||||
encoding = charset.encoding
|
encoding = conn_encoding
|
||||||
else:
|
else:
|
||||||
|
# Integers, Dates and Times, and other basic data is encoded in ascii
|
||||||
encoding = 'ascii'
|
encoding = 'ascii'
|
||||||
else:
|
else:
|
||||||
encoding = None
|
encoding = None
|
||||||
|
@ -1290,28 +1521,20 @@ class LoadLocalFile(object):
|
||||||
|
|
||||||
def send_data(self):
|
def send_data(self):
|
||||||
"""Send data packets from the local file to the server"""
|
"""Send data packets from the local file to the server"""
|
||||||
if not self.connection.socket:
|
if not self.connection._sock:
|
||||||
raise err.InterfaceError("(0, '')")
|
raise err.InterfaceError("(0, '')")
|
||||||
|
conn = self.connection
|
||||||
|
|
||||||
# sequence id is 2 as we already sent a query packet
|
|
||||||
seq_id = 2
|
|
||||||
try:
|
try:
|
||||||
with open(self.filename, 'rb') as open_file:
|
with open(self.filename, 'rb') as open_file:
|
||||||
chunk_size = self.connection.max_allowed_packet
|
packet_size = min(conn.max_allowed_packet, 16*1024) # 16KB is efficient enough
|
||||||
packet = b""
|
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
chunk = open_file.read(chunk_size)
|
chunk = open_file.read(packet_size)
|
||||||
if not chunk:
|
if not chunk:
|
||||||
break
|
break
|
||||||
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
|
conn.write_packet(chunk)
|
||||||
format_str = '!{0}s'.format(len(chunk))
|
|
||||||
packet += struct.pack(format_str, chunk)
|
|
||||||
self.connection._write_bytes(packet)
|
|
||||||
seq_id += 1
|
|
||||||
except IOError:
|
except IOError:
|
||||||
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
|
||||||
finally:
|
finally:
|
||||||
# send the empty packet to signify we are done sending data
|
# send the empty packet to signify we are done sending data
|
||||||
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
|
conn.write_packet(b'')
|
||||||
self.connection._write_bytes(packet)
|
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
|
||||||
LONG_PASSWORD = 1
|
LONG_PASSWORD = 1
|
||||||
FOUND_ROWS = 1 << 1
|
FOUND_ROWS = 1 << 1
|
||||||
LONG_FLAG = 1 << 2
|
LONG_FLAG = 1 << 2
|
||||||
|
@ -15,5 +16,16 @@ TRANSACTIONS = 1 << 13
|
||||||
SECURE_CONNECTION = 1 << 15
|
SECURE_CONNECTION = 1 << 15
|
||||||
MULTI_STATEMENTS = 1 << 16
|
MULTI_STATEMENTS = 1 << 16
|
||||||
MULTI_RESULTS = 1 << 17
|
MULTI_RESULTS = 1 << 17
|
||||||
CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS |
|
PS_MULTI_RESULTS = 1 << 18
|
||||||
PROTOCOL_41 | SECURE_CONNECTION)
|
PLUGIN_AUTH = 1 << 19
|
||||||
|
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
|
||||||
|
CAPABILITIES = (
|
||||||
|
LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS
|
||||||
|
| SECURE_CONNECTION | MULTI_STATEMENTS | MULTI_RESULTS
|
||||||
|
| PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA)
|
||||||
|
|
||||||
|
# Not done yet
|
||||||
|
CONNECT_ATTRS = 1 << 20
|
||||||
|
HANDLE_EXPIRED_PASSWORDS = 1 << 22
|
||||||
|
SESSION_TRACK = 1 << 23
|
||||||
|
DEPRECATE_EOF = 1 << 24
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
# flake8: noqa
|
||||||
# errmsg.h
|
# errmsg.h
|
||||||
CR_ERROR_FIRST = 2000
|
CR_ERROR_FIRST = 2000
|
||||||
CR_UNKNOWN_ERROR = 2000
|
CR_UNKNOWN_ERROR = 2000
|
||||||
|
|
|
@ -17,6 +17,7 @@ YEAR = 13
|
||||||
NEWDATE = 14
|
NEWDATE = 14
|
||||||
VARCHAR = 15
|
VARCHAR = 15
|
||||||
BIT = 16
|
BIT = 16
|
||||||
|
JSON = 245
|
||||||
NEWDECIMAL = 246
|
NEWDECIMAL = 246
|
||||||
ENUM = 247
|
ENUM = 247
|
||||||
SET = 248
|
SET = 248
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON
|
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON, unichr
|
||||||
|
|
||||||
import sys
|
|
||||||
import binascii
|
|
||||||
import datetime
|
import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
import re
|
import re
|
||||||
|
@ -11,10 +9,6 @@ from .constants import FIELD_TYPE, FLAG
|
||||||
from .charset import charset_by_id, charset_to_encoding
|
from .charset import charset_by_id, charset_to_encoding
|
||||||
|
|
||||||
|
|
||||||
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
|
|
||||||
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
|
|
||||||
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
|
|
||||||
|
|
||||||
def escape_item(val, charset, mapping=None):
|
def escape_item(val, charset, mapping=None):
|
||||||
if mapping is None:
|
if mapping is None:
|
||||||
mapping = encoders
|
mapping = encoders
|
||||||
|
@ -48,8 +42,7 @@ def escape_sequence(val, charset, mapping=None):
|
||||||
return "(" + ",".join(n) + ")"
|
return "(" + ",".join(n) + ")"
|
||||||
|
|
||||||
def escape_set(val, charset, mapping=None):
|
def escape_set(val, charset, mapping=None):
|
||||||
val = map(lambda x: escape_item(x, charset, mapping), val)
|
return ','.join([escape_item(x, charset, mapping) for x in val])
|
||||||
return ','.join(val)
|
|
||||||
|
|
||||||
def escape_bool(value, mapping=None):
|
def escape_bool(value, mapping=None):
|
||||||
return str(int(value))
|
return str(int(value))
|
||||||
|
@ -63,19 +56,61 @@ def escape_int(value, mapping=None):
|
||||||
def escape_float(value, mapping=None):
|
def escape_float(value, mapping=None):
|
||||||
return ('%.15g' % value)
|
return ('%.15g' % value)
|
||||||
|
|
||||||
def escape_string(value, mapping=None):
|
_escape_table = [unichr(x) for x in range(128)]
|
||||||
return ("%s" % (ESCAPE_REGEX.sub(
|
_escape_table[0] = u'\\0'
|
||||||
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
|
_escape_table[ord('\\')] = u'\\\\'
|
||||||
|
_escape_table[ord('\n')] = u'\\n'
|
||||||
|
_escape_table[ord('\r')] = u'\\r'
|
||||||
|
_escape_table[ord('\032')] = u'\\Z'
|
||||||
|
_escape_table[ord('"')] = u'\\"'
|
||||||
|
_escape_table[ord("'")] = u"\\'"
|
||||||
|
|
||||||
|
def _escape_unicode(value, mapping=None):
|
||||||
|
"""escapes *value* without adding quote.
|
||||||
|
|
||||||
|
Value should be unicode
|
||||||
|
"""
|
||||||
|
return value.translate(_escape_table)
|
||||||
|
|
||||||
|
if PY2:
|
||||||
|
def escape_string(value, mapping=None):
|
||||||
|
"""escape_string escapes *value* but not surround it with quotes.
|
||||||
|
|
||||||
|
Value should be bytes or unicode.
|
||||||
|
"""
|
||||||
|
if isinstance(value, unicode):
|
||||||
|
return _escape_unicode(value)
|
||||||
|
assert isinstance(value, (bytes, bytearray))
|
||||||
|
value = value.replace('\\', '\\\\')
|
||||||
|
value = value.replace('\0', '\\0')
|
||||||
|
value = value.replace('\n', '\\n')
|
||||||
|
value = value.replace('\r', '\\r')
|
||||||
|
value = value.replace('\032', '\\Z')
|
||||||
|
value = value.replace("'", "\\'")
|
||||||
|
value = value.replace('"', '\\"')
|
||||||
|
return value
|
||||||
|
|
||||||
|
def escape_bytes(value, mapping=None):
|
||||||
|
assert isinstance(value, (bytes, bytearray))
|
||||||
|
return b"_binary'%s'" % escape_string(value)
|
||||||
|
else:
|
||||||
|
escape_string = _escape_unicode
|
||||||
|
|
||||||
|
# On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow.
|
||||||
|
# (fixed in Python 3.6, http://bugs.python.org/issue24870)
|
||||||
|
# Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff.
|
||||||
|
# We can escape special chars and surrogateescape at once.
|
||||||
|
_escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)]
|
||||||
|
|
||||||
|
def escape_bytes(value, mapping=None):
|
||||||
|
return "_binary'%s'" % value.decode('latin1').translate(_escape_bytes_table)
|
||||||
|
|
||||||
def escape_str(value, mapping=None):
|
|
||||||
return "'%s'" % escape_string(value, mapping)
|
|
||||||
|
|
||||||
def escape_unicode(value, mapping=None):
|
def escape_unicode(value, mapping=None):
|
||||||
return escape_str(value, mapping)
|
return u"'%s'" % _escape_unicode(value)
|
||||||
|
|
||||||
def escape_bytes(value, mapping=None):
|
def escape_str(value, mapping=None):
|
||||||
# escape_bytes is calld only on Python 3.
|
return "'%s'" % escape_string(str(value), mapping)
|
||||||
return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
|
|
||||||
|
|
||||||
def escape_None(value, mapping=None):
|
def escape_None(value, mapping=None):
|
||||||
return 'NULL'
|
return 'NULL'
|
||||||
|
@ -111,6 +146,16 @@ def escape_date(obj, mapping=None):
|
||||||
def escape_struct_time(obj, mapping=None):
|
def escape_struct_time(obj, mapping=None):
|
||||||
return escape_datetime(datetime.datetime(*obj[:6]))
|
return escape_datetime(datetime.datetime(*obj[:6]))
|
||||||
|
|
||||||
|
def _convert_second_fraction(s):
|
||||||
|
if not s:
|
||||||
|
return 0
|
||||||
|
# Pad zeros to ensure the fraction length in microseconds
|
||||||
|
s = s.ljust(6, '0')
|
||||||
|
return int(s[:6])
|
||||||
|
|
||||||
|
DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
|
||||||
|
|
||||||
|
|
||||||
def convert_datetime(obj):
|
def convert_datetime(obj):
|
||||||
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
|
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
|
||||||
|
|
||||||
|
@ -127,23 +172,22 @@ def convert_datetime(obj):
|
||||||
True
|
True
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if ' ' in obj:
|
if not PY2 and isinstance(obj, (bytes, bytearray)):
|
||||||
sep = ' '
|
obj = obj.decode('ascii')
|
||||||
elif 'T' in obj:
|
|
||||||
sep = 'T'
|
m = DATETIME_RE.match(obj)
|
||||||
else:
|
if not m:
|
||||||
return convert_date(obj)
|
return convert_date(obj)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ymd, hms = obj.split(sep, 1)
|
groups = list(m.groups())
|
||||||
usecs = '0'
|
groups[-1] = _convert_second_fraction(groups[-1])
|
||||||
if '.' in hms:
|
return datetime.datetime(*[ int(x) for x in groups ])
|
||||||
hms, usecs = hms.split('.')
|
|
||||||
usecs = float('0.' + usecs) * 1e6
|
|
||||||
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':')+[usecs] ])
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return convert_date(obj)
|
return convert_date(obj)
|
||||||
|
|
||||||
|
TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
|
||||||
|
|
||||||
|
|
||||||
def convert_timedelta(obj):
|
def convert_timedelta(obj):
|
||||||
"""Returns a TIME column as a timedelta object:
|
"""Returns a TIME column as a timedelta object:
|
||||||
|
@ -162,16 +206,19 @@ def convert_timedelta(obj):
|
||||||
can accept values as (+|-)DD HH:MM:SS. The latter format will not
|
can accept values as (+|-)DD HH:MM:SS. The latter format will not
|
||||||
be parsed correctly by this function.
|
be parsed correctly by this function.
|
||||||
"""
|
"""
|
||||||
|
if not PY2 and isinstance(obj, (bytes, bytearray)):
|
||||||
|
obj = obj.decode('ascii')
|
||||||
|
|
||||||
|
m = TIMEDELTA_RE.match(obj)
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
microseconds = 0
|
groups = list(m.groups())
|
||||||
if "." in obj:
|
groups[-1] = _convert_second_fraction(groups[-1])
|
||||||
(obj, tail) = obj.split('.')
|
negate = -1 if groups[0] else 1
|
||||||
microseconds = float('0.' + tail) * 1e6
|
hours, minutes, seconds, microseconds = groups[1:]
|
||||||
hours, minutes, seconds = obj.split(':')
|
|
||||||
negate = 1
|
|
||||||
if hours.startswith("-"):
|
|
||||||
hours = hours[1:]
|
|
||||||
negate = -1
|
|
||||||
tdelta = datetime.timedelta(
|
tdelta = datetime.timedelta(
|
||||||
hours = int(hours),
|
hours = int(hours),
|
||||||
minutes = int(minutes),
|
minutes = int(minutes),
|
||||||
|
@ -182,6 +229,9 @@ def convert_timedelta(obj):
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
|
||||||
|
|
||||||
|
|
||||||
def convert_time(obj):
|
def convert_time(obj):
|
||||||
"""Returns a TIME column as a time object:
|
"""Returns a TIME column as a time object:
|
||||||
|
|
||||||
|
@ -204,17 +254,23 @@ def convert_time(obj):
|
||||||
to be treated as time-of-day and not a time offset, then you can
|
to be treated as time-of-day and not a time offset, then you can
|
||||||
use set this function as the converter for FIELD_TYPE.TIME.
|
use set this function as the converter for FIELD_TYPE.TIME.
|
||||||
"""
|
"""
|
||||||
|
if not PY2 and isinstance(obj, (bytes, bytearray)):
|
||||||
|
obj = obj.decode('ascii')
|
||||||
|
|
||||||
|
m = TIME_RE.match(obj)
|
||||||
|
if not m:
|
||||||
|
return None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
microseconds = 0
|
groups = list(m.groups())
|
||||||
if "." in obj:
|
groups[-1] = _convert_second_fraction(groups[-1])
|
||||||
(obj, tail) = obj.split('.')
|
hours, minutes, seconds, microseconds = groups
|
||||||
microseconds = float('0.' + tail) * 1e6
|
|
||||||
hours, minutes, seconds = obj.split(':')
|
|
||||||
return datetime.time(hour=int(hours), minute=int(minutes),
|
return datetime.time(hour=int(hours), minute=int(minutes),
|
||||||
second=int(seconds), microsecond=int(microseconds))
|
second=int(seconds), microsecond=int(microseconds))
|
||||||
except ValueError:
|
except ValueError:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def convert_date(obj):
|
def convert_date(obj):
|
||||||
"""Returns a DATE column as a date object:
|
"""Returns a DATE column as a date object:
|
||||||
|
|
||||||
|
@ -229,6 +285,8 @@ def convert_date(obj):
|
||||||
True
|
True
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if not PY2 and isinstance(obj, (bytes, bytearray)):
|
||||||
|
obj = obj.decode('ascii')
|
||||||
try:
|
try:
|
||||||
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
|
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
|
||||||
except ValueError:
|
except ValueError:
|
||||||
|
@ -256,6 +314,8 @@ def convert_mysql_timestamp(timestamp):
|
||||||
True
|
True
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
if not PY2 and isinstance(timestamp, (bytes, bytearray)):
|
||||||
|
timestamp = timestamp.decode('ascii')
|
||||||
if timestamp[4] == '-':
|
if timestamp[4] == '-':
|
||||||
return convert_datetime(timestamp)
|
return convert_datetime(timestamp)
|
||||||
timestamp += "0"*(14-len(timestamp)) # padding
|
timestamp += "0"*(14-len(timestamp)) # padding
|
||||||
|
@ -268,6 +328,8 @@ def convert_mysql_timestamp(timestamp):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def convert_set(s):
|
def convert_set(s):
|
||||||
|
if isinstance(s, (bytes, bytearray)):
|
||||||
|
return set(s.split(b","))
|
||||||
return set(s.split(","))
|
return set(s.split(","))
|
||||||
|
|
||||||
|
|
||||||
|
@ -278,7 +340,7 @@ def through(x):
|
||||||
#def convert_bit(b):
|
#def convert_bit(b):
|
||||||
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
|
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
|
||||||
# return struct.unpack(">Q", b)[0]
|
# return struct.unpack(">Q", b)[0]
|
||||||
#
|
#
|
||||||
# the snippet above is right, but MySQLdb doesn't process bits,
|
# the snippet above is right, but MySQLdb doesn't process bits,
|
||||||
# so we shouldn't either
|
# so we shouldn't either
|
||||||
convert_bit = through
|
convert_bit = through
|
||||||
|
@ -309,7 +371,9 @@ encoders = {
|
||||||
tuple: escape_sequence,
|
tuple: escape_sequence,
|
||||||
list: escape_sequence,
|
list: escape_sequence,
|
||||||
set: escape_sequence,
|
set: escape_sequence,
|
||||||
|
frozenset: escape_sequence,
|
||||||
dict: escape_dict,
|
dict: escape_dict,
|
||||||
|
bytearray: escape_bytes,
|
||||||
type(None): escape_None,
|
type(None): escape_None,
|
||||||
datetime.date: escape_date,
|
datetime.date: escape_date,
|
||||||
datetime.datetime: escape_datetime,
|
datetime.datetime: escape_datetime,
|
||||||
|
@ -350,7 +414,6 @@ decoders = {
|
||||||
|
|
||||||
|
|
||||||
# for MySQLdb compatibility
|
# for MySQLdb compatibility
|
||||||
conversions = decoders
|
conversions = encoders.copy()
|
||||||
|
conversions.update(decoders)
|
||||||
def Thing2Literal(obj):
|
Thing2Literal = escape_str
|
||||||
return escape_str(str(obj))
|
|
||||||
|
|
|
@ -5,33 +5,37 @@ import re
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ._compat import range_type, text_type, PY2
|
from ._compat import range_type, text_type, PY2
|
||||||
|
|
||||||
from . import err
|
from . import err
|
||||||
|
|
||||||
|
|
||||||
#: Regular expression for :meth:`Cursor.executemany`.
|
#: Regular expression for :meth:`Cursor.executemany`.
|
||||||
#: executemany only suports simple bulk insert.
|
#: executemany only suports simple bulk insert.
|
||||||
#: You can use it to load large dataset.
|
#: You can use it to load large dataset.
|
||||||
RE_INSERT_VALUES = re.compile(r"""(INSERT\s.+\sVALUES\s+)(\(\s*%s\s*(?:,\s*%s\s*)*\))(\s*(?:ON DUPLICATE.*)?)\Z""",
|
RE_INSERT_VALUES = re.compile(
|
||||||
re.IGNORECASE | re.DOTALL)
|
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
|
||||||
|
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
|
||||||
|
r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
|
||||||
|
re.IGNORECASE | re.DOTALL)
|
||||||
|
|
||||||
|
|
||||||
class Cursor(object):
|
class Cursor(object):
|
||||||
'''
|
"""
|
||||||
This is the object you use to interact with the database.
|
This is the object you use to interact with the database.
|
||||||
'''
|
"""
|
||||||
|
|
||||||
#: Max stetement size which :meth:`executemany` generates.
|
#: Max statement size which :meth:`executemany` generates.
|
||||||
#:
|
#:
|
||||||
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
|
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
|
||||||
#: Default value of max_allowed_packet is 1048576.
|
#: Default value of max_allowed_packet is 1048576.
|
||||||
max_stmt_length = 1024000
|
max_stmt_length = 1024000
|
||||||
|
|
||||||
|
_defer_warnings = False
|
||||||
|
|
||||||
def __init__(self, connection):
|
def __init__(self, connection):
|
||||||
'''
|
"""
|
||||||
Do not create an instance of a Cursor yourself. Call
|
Do not create an instance of a Cursor yourself. Call
|
||||||
connections.Connection.cursor().
|
connections.Connection.cursor().
|
||||||
'''
|
"""
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
self.description = None
|
self.description = None
|
||||||
self.rownumber = 0
|
self.rownumber = 0
|
||||||
|
@ -40,11 +44,12 @@ class Cursor(object):
|
||||||
self._executed = None
|
self._executed = None
|
||||||
self._result = None
|
self._result = None
|
||||||
self._rows = None
|
self._rows = None
|
||||||
|
self._warnings_handled = False
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
'''
|
"""
|
||||||
Closing a cursor just exhausts all remaining data.
|
Closing a cursor just exhausts all remaining data.
|
||||||
'''
|
"""
|
||||||
conn = self.connection
|
conn = self.connection
|
||||||
if conn is None:
|
if conn is None:
|
||||||
return
|
return
|
||||||
|
@ -83,6 +88,9 @@ class Cursor(object):
|
||||||
"""Get the next query set"""
|
"""Get the next query set"""
|
||||||
conn = self._get_db()
|
conn = self._get_db()
|
||||||
current_result = self._result
|
current_result = self._result
|
||||||
|
# for unbuffered queries warnings are only available once whole result has been read
|
||||||
|
if unbuffered:
|
||||||
|
self._show_warnings()
|
||||||
if current_result is None or current_result is not conn._result:
|
if current_result is None or current_result is not conn._result:
|
||||||
return None
|
return None
|
||||||
if not current_result.has_next:
|
if not current_result.has_next:
|
||||||
|
@ -107,17 +115,17 @@ class Cursor(object):
|
||||||
if isinstance(args, (tuple, list)):
|
if isinstance(args, (tuple, list)):
|
||||||
if PY2:
|
if PY2:
|
||||||
args = tuple(map(ensure_bytes, args))
|
args = tuple(map(ensure_bytes, args))
|
||||||
return tuple(conn.escape(arg) for arg in args)
|
return tuple(conn.literal(arg) for arg in args)
|
||||||
elif isinstance(args, dict):
|
elif isinstance(args, dict):
|
||||||
if PY2:
|
if PY2:
|
||||||
args = dict((ensure_bytes(key), ensure_bytes(val)) for
|
args = dict((ensure_bytes(key), ensure_bytes(val)) for
|
||||||
(key, val) in args.items())
|
(key, val) in args.items())
|
||||||
return dict((key, conn.escape(val)) for (key, val) in args.items())
|
return dict((key, conn.literal(val)) for (key, val) in args.items())
|
||||||
else:
|
else:
|
||||||
# If it's not a dictionary let's try escaping it anyways.
|
# If it's not a dictionary let's try escaping it anyways.
|
||||||
# Worst case it will throw a Value error
|
# Worst case it will throw a Value error
|
||||||
if PY2:
|
if PY2:
|
||||||
ensure_bytes(args)
|
args = ensure_bytes(args)
|
||||||
return conn.escape(args)
|
return conn.escape(args)
|
||||||
|
|
||||||
def mogrify(self, query, args=None):
|
def mogrify(self, query, args=None):
|
||||||
|
@ -137,7 +145,19 @@ class Cursor(object):
|
||||||
return query
|
return query
|
||||||
|
|
||||||
def execute(self, query, args=None):
|
def execute(self, query, args=None):
|
||||||
'''Execute a query'''
|
"""Execute a query
|
||||||
|
|
||||||
|
:param str query: Query to execute.
|
||||||
|
|
||||||
|
:param args: parameters used with query. (optional)
|
||||||
|
:type args: tuple, list or dict
|
||||||
|
|
||||||
|
:return: Number of affected rows
|
||||||
|
:rtype: int
|
||||||
|
|
||||||
|
If args is a list or tuple, %s can be used as a placeholder in the query.
|
||||||
|
If args is a dict, %(name)s can be used as a placeholder in the query.
|
||||||
|
"""
|
||||||
while self.nextset():
|
while self.nextset():
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@ -148,17 +168,23 @@ class Cursor(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def executemany(self, query, args):
|
def executemany(self, query, args):
|
||||||
|
# type: (str, list) -> int
|
||||||
"""Run several data against one query
|
"""Run several data against one query
|
||||||
|
|
||||||
PyMySQL can execute bulkinsert for query like 'INSERT ... VALUES (%s)'.
|
:param query: query to execute on server
|
||||||
In other form of queries, just run :meth:`execute` many times.
|
:param args: Sequence of sequences or mappings. It is used as parameter.
|
||||||
|
:return: Number of rows affected, if any.
|
||||||
|
|
||||||
|
This method improves performance on multiple-row INSERT and
|
||||||
|
REPLACE. Otherwise it is equivalent to looping over args with
|
||||||
|
execute().
|
||||||
"""
|
"""
|
||||||
if not args:
|
if not args:
|
||||||
return
|
return
|
||||||
|
|
||||||
m = RE_INSERT_VALUES.match(query)
|
m = RE_INSERT_VALUES.match(query)
|
||||||
if m:
|
if m:
|
||||||
q_prefix = m.group(1)
|
q_prefix = m.group(1) % ()
|
||||||
q_values = m.group(2).rstrip()
|
q_values = m.group(2).rstrip()
|
||||||
q_postfix = m.group(3) or ''
|
q_postfix = m.group(3) or ''
|
||||||
assert q_values[0] == '(' and q_values[-1] == ')'
|
assert q_values[0] == '(' and q_values[-1] == ')'
|
||||||
|
@ -247,7 +273,7 @@ class Cursor(object):
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def fetchone(self):
|
def fetchone(self):
|
||||||
''' Fetch the next row '''
|
"""Fetch the next row"""
|
||||||
self._check_executed()
|
self._check_executed()
|
||||||
if self._rows is None or self.rownumber >= len(self._rows):
|
if self._rows is None or self.rownumber >= len(self._rows):
|
||||||
return None
|
return None
|
||||||
|
@ -256,7 +282,7 @@ class Cursor(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def fetchmany(self, size=None):
|
def fetchmany(self, size=None):
|
||||||
''' Fetch several rows '''
|
"""Fetch several rows"""
|
||||||
self._check_executed()
|
self._check_executed()
|
||||||
if self._rows is None:
|
if self._rows is None:
|
||||||
return ()
|
return ()
|
||||||
|
@ -266,7 +292,7 @@ class Cursor(object):
|
||||||
return result
|
return result
|
||||||
|
|
||||||
def fetchall(self):
|
def fetchall(self):
|
||||||
''' Fetch all the rows '''
|
"""Fetch all the rows"""
|
||||||
self._check_executed()
|
self._check_executed()
|
||||||
if self._rows is None:
|
if self._rows is None:
|
||||||
return ()
|
return ()
|
||||||
|
@ -307,14 +333,18 @@ class Cursor(object):
|
||||||
self.description = result.description
|
self.description = result.description
|
||||||
self.lastrowid = result.insert_id
|
self.lastrowid = result.insert_id
|
||||||
self._rows = result.rows
|
self._rows = result.rows
|
||||||
|
self._warnings_handled = False
|
||||||
|
|
||||||
if result.warning_count > 0:
|
if not self._defer_warnings:
|
||||||
self._show_warnings(conn)
|
self._show_warnings()
|
||||||
|
|
||||||
def _show_warnings(self, conn):
|
def _show_warnings(self):
|
||||||
if self._result and self._result.has_next:
|
if self._warnings_handled:
|
||||||
return
|
return
|
||||||
ws = conn.show_warnings()
|
self._warnings_handled = True
|
||||||
|
if self._result and (self._result.has_next or not self._result.warning_count):
|
||||||
|
return
|
||||||
|
ws = self._get_db().show_warnings()
|
||||||
if ws is None:
|
if ws is None:
|
||||||
return
|
return
|
||||||
for w in ws:
|
for w in ws:
|
||||||
|
@ -322,7 +352,7 @@ class Cursor(object):
|
||||||
if PY2:
|
if PY2:
|
||||||
if isinstance(msg, unicode):
|
if isinstance(msg, unicode):
|
||||||
msg = msg.encode('utf-8', 'replace')
|
msg = msg.encode('utf-8', 'replace')
|
||||||
warnings.warn(str(msg), err.Warning, 4)
|
warnings.warn(err.Warning(*w[1:3]), stacklevel=4)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.fetchone, None)
|
return iter(self.fetchone, None)
|
||||||
|
@ -373,8 +403,8 @@ class SSCursor(Cursor):
|
||||||
or for connections to remote servers over a slow network.
|
or for connections to remote servers over a slow network.
|
||||||
|
|
||||||
Instead of copying every row of data into a buffer, this will fetch
|
Instead of copying every row of data into a buffer, this will fetch
|
||||||
rows as needed. The upside of this, is the client uses much less memory,
|
rows as needed. The upside of this is the client uses much less memory,
|
||||||
and rows are returned much faster when traveling over a slow network,
|
and rows are returned much faster when traveling over a slow network
|
||||||
or if the result set is very big.
|
or if the result set is very big.
|
||||||
|
|
||||||
There are limitations, though. The MySQL protocol doesn't support
|
There are limitations, though. The MySQL protocol doesn't support
|
||||||
|
@ -383,6 +413,8 @@ class SSCursor(Cursor):
|
||||||
possible to scroll backwards, as only the current row is held in memory.
|
possible to scroll backwards, as only the current row is held in memory.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
_defer_warnings = True
|
||||||
|
|
||||||
def _conv_row(self, row):
|
def _conv_row(self, row):
|
||||||
return row
|
return row
|
||||||
|
|
||||||
|
@ -411,14 +443,15 @@ class SSCursor(Cursor):
|
||||||
return self._nextset(unbuffered=True)
|
return self._nextset(unbuffered=True)
|
||||||
|
|
||||||
def read_next(self):
|
def read_next(self):
|
||||||
""" Read next row """
|
"""Read next row"""
|
||||||
return self._conv_row(self._result._read_rowdata_packet_unbuffered())
|
return self._conv_row(self._result._read_rowdata_packet_unbuffered())
|
||||||
|
|
||||||
def fetchone(self):
|
def fetchone(self):
|
||||||
""" Fetch next row """
|
"""Fetch next row"""
|
||||||
self._check_executed()
|
self._check_executed()
|
||||||
row = self.read_next()
|
row = self.read_next()
|
||||||
if row is None:
|
if row is None:
|
||||||
|
self._show_warnings()
|
||||||
return None
|
return None
|
||||||
self.rownumber += 1
|
self.rownumber += 1
|
||||||
return row
|
return row
|
||||||
|
@ -443,7 +476,7 @@ class SSCursor(Cursor):
|
||||||
return self.fetchall_unbuffered()
|
return self.fetchall_unbuffered()
|
||||||
|
|
||||||
def fetchmany(self, size=None):
|
def fetchmany(self, size=None):
|
||||||
""" Fetch many """
|
"""Fetch many"""
|
||||||
self._check_executed()
|
self._check_executed()
|
||||||
if size is None:
|
if size is None:
|
||||||
size = self.arraysize
|
size = self.arraysize
|
||||||
|
@ -452,6 +485,7 @@ class SSCursor(Cursor):
|
||||||
for i in range_type(size):
|
for i in range_type(size):
|
||||||
row = self.read_next()
|
row = self.read_next()
|
||||||
if row is None:
|
if row is None:
|
||||||
|
self._show_warnings()
|
||||||
break
|
break
|
||||||
rows.append(row)
|
rows.append(row)
|
||||||
self.rownumber += 1
|
self.rownumber += 1
|
||||||
|
@ -482,4 +516,4 @@ class SSCursor(Cursor):
|
||||||
|
|
||||||
|
|
||||||
class SSDictCursor(DictCursorMixin, SSCursor):
|
class SSDictCursor(DictCursorMixin, SSCursor):
|
||||||
""" An unbuffered cursor, which returns results as a dictionary """
|
"""An unbuffered cursor, which returns results as a dictionary"""
|
||||||
|
|
|
@ -68,10 +68,12 @@ class NotSupportedError(DatabaseError):
|
||||||
|
|
||||||
error_map = {}
|
error_map = {}
|
||||||
|
|
||||||
|
|
||||||
def _map_error(exc, *errors):
|
def _map_error(exc, *errors):
|
||||||
for error in errors:
|
for error in errors:
|
||||||
error_map[error] = exc
|
error_map[error] = exc
|
||||||
|
|
||||||
|
|
||||||
_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
|
_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
|
||||||
ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
|
ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
|
||||||
ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
|
ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
|
||||||
|
@ -89,32 +91,17 @@ _map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
|
||||||
ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR,
|
ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR,
|
||||||
ER.COLUMNACCESS_DENIED_ERROR)
|
ER.COLUMNACCESS_DENIED_ERROR)
|
||||||
|
|
||||||
|
|
||||||
del _map_error, ER
|
del _map_error, ER
|
||||||
|
|
||||||
|
|
||||||
def _get_error_info(data):
|
def raise_mysql_exception(data):
|
||||||
errno = struct.unpack('<h', data[1:3])[0]
|
errno = struct.unpack('<h', data[1:3])[0]
|
||||||
is_41 = data[3:4] == b"#"
|
is_41 = data[3:4] == b"#"
|
||||||
if is_41:
|
if is_41:
|
||||||
# version 4.1
|
# client protocol 4.1
|
||||||
sqlstate = data[4:9].decode("utf8", 'replace')
|
errval = data[9:].decode('utf-8', 'replace')
|
||||||
errorvalue = data[9:].decode("utf8", 'replace')
|
|
||||||
return (errno, sqlstate, errorvalue)
|
|
||||||
else:
|
else:
|
||||||
# version 4.0
|
errval = data[3:].decode('utf-8', 'replace')
|
||||||
return (errno, None, data[3:].decode("utf8", 'replace'))
|
errorclass = error_map.get(errno, InternalError)
|
||||||
|
raise errorclass(errno, errval)
|
||||||
|
|
||||||
def _check_mysql_exception(errinfo):
|
|
||||||
errno, sqlstate, errorvalue = errinfo
|
|
||||||
errorclass = error_map.get(errno, None)
|
|
||||||
if errorclass:
|
|
||||||
raise errorclass(errno, errorvalue)
|
|
||||||
|
|
||||||
# couldn't find the right error number
|
|
||||||
raise InternalError(errno, errorvalue)
|
|
||||||
|
|
||||||
|
|
||||||
def raise_mysql_exception(data):
|
|
||||||
errinfo = _get_error_info(data)
|
|
||||||
_check_mysql_exception(errinfo)
|
|
||||||
|
|
|
@ -1,16 +1,20 @@
|
||||||
from time import localtime
|
from time import localtime
|
||||||
from datetime import date, datetime, time, timedelta
|
from datetime import date, datetime, time, timedelta
|
||||||
|
|
||||||
|
|
||||||
Date = date
|
Date = date
|
||||||
Time = time
|
Time = time
|
||||||
TimeDelta = timedelta
|
TimeDelta = timedelta
|
||||||
Timestamp = datetime
|
Timestamp = datetime
|
||||||
|
|
||||||
|
|
||||||
def DateFromTicks(ticks):
|
def DateFromTicks(ticks):
|
||||||
return date(*localtime(ticks)[:3])
|
return date(*localtime(ticks)[:3])
|
||||||
|
|
||||||
|
|
||||||
def TimeFromTicks(ticks):
|
def TimeFromTicks(ticks):
|
||||||
return time(*localtime(ticks)[3:6])
|
return time(*localtime(ticks)[3:6])
|
||||||
|
|
||||||
|
|
||||||
def TimestampFromTicks(ticks):
|
def TimestampFromTicks(ticks):
|
||||||
return datetime(*localtime(ticks)[:6])
|
return datetime(*localtime(ticks)[:6])
|
||||||
|
|
|
@ -1,14 +1,17 @@
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
||||||
def byte2int(b):
|
def byte2int(b):
|
||||||
if isinstance(b, int):
|
if isinstance(b, int):
|
||||||
return b
|
return b
|
||||||
else:
|
else:
|
||||||
return struct.unpack("!B", b)[0]
|
return struct.unpack("!B", b)[0]
|
||||||
|
|
||||||
|
|
||||||
def int2byte(i):
|
def int2byte(i):
|
||||||
return struct.pack("!B", i)
|
return struct.pack("!B", i)
|
||||||
|
|
||||||
|
|
||||||
def join_bytes(bs):
|
def join_bytes(bs):
|
||||||
if len(bs) == 0:
|
if len(bs) == 0:
|
||||||
return ""
|
return ""
|
||||||
|
|
|
@ -1,45 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
"""RSA module
|
|
||||||
|
|
||||||
Module for calculating large primes, and RSA encryption, decryption, signing
|
|
||||||
and verification. Includes generating public and private keys.
|
|
||||||
|
|
||||||
WARNING: this implementation does not use random padding, compression of the
|
|
||||||
cleartext input to prevent repetitions, or other common security improvements.
|
|
||||||
Use with care.
|
|
||||||
|
|
||||||
If you want to have a more secure implementation, use the functions from the
|
|
||||||
``rsa.pkcs1`` module.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
__author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly"
|
|
||||||
__date__ = "2015-07-29"
|
|
||||||
__version__ = '3.2'
|
|
||||||
|
|
||||||
from rsa.key import newkeys, PrivateKey, PublicKey
|
|
||||||
from rsa.pkcs1 import encrypt, decrypt, sign, verify, DecryptionError, \
|
|
||||||
VerificationError
|
|
||||||
|
|
||||||
# Do doctest if we're run directly
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
||||||
__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify", 'PublicKey',
|
|
||||||
'PrivateKey', 'DecryptionError', 'VerificationError']
|
|
||||||
|
|
|
@ -1,160 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
"""Python compatibility wrappers."""
|
|
||||||
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from struct import pack
|
|
||||||
|
|
||||||
try:
|
|
||||||
MAX_INT = sys.maxsize
|
|
||||||
except AttributeError:
|
|
||||||
MAX_INT = sys.maxint
|
|
||||||
|
|
||||||
MAX_INT64 = (1 << 63) - 1
|
|
||||||
MAX_INT32 = (1 << 31) - 1
|
|
||||||
MAX_INT16 = (1 << 15) - 1
|
|
||||||
|
|
||||||
# Determine the word size of the processor.
|
|
||||||
if MAX_INT == MAX_INT64:
|
|
||||||
# 64-bit processor.
|
|
||||||
MACHINE_WORD_SIZE = 64
|
|
||||||
elif MAX_INT == MAX_INT32:
|
|
||||||
# 32-bit processor.
|
|
||||||
MACHINE_WORD_SIZE = 32
|
|
||||||
else:
|
|
||||||
# Else we just assume 64-bit processor keeping up with modern times.
|
|
||||||
MACHINE_WORD_SIZE = 64
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
# < Python3
|
|
||||||
unicode_type = unicode
|
|
||||||
have_python3 = False
|
|
||||||
except NameError:
|
|
||||||
# Python3.
|
|
||||||
unicode_type = str
|
|
||||||
have_python3 = True
|
|
||||||
|
|
||||||
# Fake byte literals.
|
|
||||||
if str is unicode_type:
|
|
||||||
def byte_literal(s):
|
|
||||||
return s.encode('latin1')
|
|
||||||
else:
|
|
||||||
def byte_literal(s):
|
|
||||||
return s
|
|
||||||
|
|
||||||
# ``long`` is no more. Do type detection using this instead.
|
|
||||||
try:
|
|
||||||
integer_types = (int, long)
|
|
||||||
except NameError:
|
|
||||||
integer_types = (int,)
|
|
||||||
|
|
||||||
b = byte_literal
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Python 2.6 or higher.
|
|
||||||
bytes_type = bytes
|
|
||||||
except NameError:
|
|
||||||
# Python 2.5
|
|
||||||
bytes_type = str
|
|
||||||
|
|
||||||
|
|
||||||
# To avoid calling b() multiple times in tight loops.
|
|
||||||
ZERO_BYTE = b('\x00')
|
|
||||||
EMPTY_BYTE = b('')
|
|
||||||
|
|
||||||
|
|
||||||
def is_bytes(obj):
|
|
||||||
"""
|
|
||||||
Determines whether the given value is a byte string.
|
|
||||||
|
|
||||||
:param obj:
|
|
||||||
The value to test.
|
|
||||||
:returns:
|
|
||||||
``True`` if ``value`` is a byte string; ``False`` otherwise.
|
|
||||||
"""
|
|
||||||
return isinstance(obj, bytes_type)
|
|
||||||
|
|
||||||
|
|
||||||
def is_integer(obj):
|
|
||||||
"""
|
|
||||||
Determines whether the given value is an integer.
|
|
||||||
|
|
||||||
:param obj:
|
|
||||||
The value to test.
|
|
||||||
:returns:
|
|
||||||
``True`` if ``value`` is an integer; ``False`` otherwise.
|
|
||||||
"""
|
|
||||||
return isinstance(obj, integer_types)
|
|
||||||
|
|
||||||
|
|
||||||
def byte(num):
|
|
||||||
"""
|
|
||||||
Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
|
|
||||||
representation.
|
|
||||||
|
|
||||||
Use it as a replacement for ``chr`` where you are expecting a byte
|
|
||||||
because this will work on all current versions of Python::
|
|
||||||
|
|
||||||
:param num:
|
|
||||||
An unsigned integer between 0 and 255 (both inclusive).
|
|
||||||
:returns:
|
|
||||||
A single byte.
|
|
||||||
"""
|
|
||||||
return pack("B", num)
|
|
||||||
|
|
||||||
|
|
||||||
def get_word_alignment(num, force_arch=64,
|
|
||||||
_machine_word_size=MACHINE_WORD_SIZE):
|
|
||||||
"""
|
|
||||||
Returns alignment details for the given number based on the platform
|
|
||||||
Python is running on.
|
|
||||||
|
|
||||||
:param num:
|
|
||||||
Unsigned integral number.
|
|
||||||
:param force_arch:
|
|
||||||
If you don't want to use 64-bit unsigned chunks, set this to
|
|
||||||
anything other than 64. 32-bit chunks will be preferred then.
|
|
||||||
Default 64 will be used when on a 64-bit machine.
|
|
||||||
:param _machine_word_size:
|
|
||||||
(Internal) The machine word size used for alignment.
|
|
||||||
:returns:
|
|
||||||
4-tuple::
|
|
||||||
|
|
||||||
(word_bits, word_bytes,
|
|
||||||
max_uint, packing_format_type)
|
|
||||||
"""
|
|
||||||
max_uint64 = 0xffffffffffffffff
|
|
||||||
max_uint32 = 0xffffffff
|
|
||||||
max_uint16 = 0xffff
|
|
||||||
max_uint8 = 0xff
|
|
||||||
|
|
||||||
if force_arch == 64 and _machine_word_size >= 64 and num > max_uint32:
|
|
||||||
# 64-bit unsigned integer.
|
|
||||||
return 64, 8, max_uint64, "Q"
|
|
||||||
elif num > max_uint16:
|
|
||||||
# 32-bit unsigned integer
|
|
||||||
return 32, 4, max_uint32, "L"
|
|
||||||
elif num > max_uint8:
|
|
||||||
# 16-bit unsigned integer.
|
|
||||||
return 16, 2, max_uint16, "H"
|
|
||||||
else:
|
|
||||||
# 8-bit unsigned integer.
|
|
||||||
return 8, 1, max_uint8, "B"
|
|
|
@ -1,442 +0,0 @@
|
||||||
"""RSA module
|
|
||||||
pri = k[1] //Private part of keys d,p,q
|
|
||||||
|
|
||||||
Module for calculating large primes, and RSA encryption, decryption,
|
|
||||||
signing and verification. Includes generating public and private keys.
|
|
||||||
|
|
||||||
WARNING: this code implements the mathematics of RSA. It is not suitable for
|
|
||||||
real-world secure cryptography purposes. It has not been reviewed by a security
|
|
||||||
expert. It does not include padding of data. There are many ways in which the
|
|
||||||
output of this module, when used without any modification, can be sucessfully
|
|
||||||
attacked.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer"
|
|
||||||
__date__ = "2010-02-05"
|
|
||||||
__version__ = '1.3.3'
|
|
||||||
|
|
||||||
# NOTE: Python's modulo can return negative numbers. We compensate for
|
|
||||||
# this behaviour using the abs() function
|
|
||||||
|
|
||||||
from cPickle import dumps, loads
|
|
||||||
import base64
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
import zlib
|
|
||||||
|
|
||||||
from rsa._compat import byte
|
|
||||||
|
|
||||||
# Display a warning that this insecure version is imported.
|
|
||||||
import warnings
|
|
||||||
warnings.warn('Insecure version of the RSA module is imported as %s, be careful'
|
|
||||||
% __name__)
|
|
||||||
|
|
||||||
def gcd(p, q):
|
|
||||||
"""Returns the greatest common divisor of p and q
|
|
||||||
|
|
||||||
|
|
||||||
>>> gcd(42, 6)
|
|
||||||
6
|
|
||||||
"""
|
|
||||||
if p<q: return gcd(q, p)
|
|
||||||
if q == 0: return p
|
|
||||||
return gcd(q, abs(p%q))
|
|
||||||
|
|
||||||
def bytes2int(bytes):
|
|
||||||
"""Converts a list of bytes or a string to an integer
|
|
||||||
|
|
||||||
>>> (128*256 + 64)*256 + + 15
|
|
||||||
8405007
|
|
||||||
>>> l = [128, 64, 15]
|
|
||||||
>>> bytes2int(l)
|
|
||||||
8405007
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
|
|
||||||
raise TypeError("You must pass a string or a list")
|
|
||||||
|
|
||||||
# Convert byte stream to integer
|
|
||||||
integer = 0
|
|
||||||
for byte in bytes:
|
|
||||||
integer *= 256
|
|
||||||
if type(byte) is types.StringType: byte = ord(byte)
|
|
||||||
integer += byte
|
|
||||||
|
|
||||||
return integer
|
|
||||||
|
|
||||||
def int2bytes(number):
|
|
||||||
"""Converts a number to a string of bytes
|
|
||||||
|
|
||||||
>>> bytes2int(int2bytes(123456789))
|
|
||||||
123456789
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(number) is types.LongType or type(number) is types.IntType):
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
string = ""
|
|
||||||
|
|
||||||
while number > 0:
|
|
||||||
string = "%s%s" % (byte(number & 0xFF), string)
|
|
||||||
number /= 256
|
|
||||||
|
|
||||||
return string
|
|
||||||
|
|
||||||
def fast_exponentiation(a, p, n):
|
|
||||||
"""Calculates r = a^p mod n
|
|
||||||
"""
|
|
||||||
result = a % n
|
|
||||||
remainders = []
|
|
||||||
while p != 1:
|
|
||||||
remainders.append(p & 1)
|
|
||||||
p = p >> 1
|
|
||||||
while remainders:
|
|
||||||
rem = remainders.pop()
|
|
||||||
result = ((a ** rem) * result ** 2) % n
|
|
||||||
return result
|
|
||||||
|
|
||||||
def read_random_int(nbits):
|
|
||||||
"""Reads a random integer of approximately nbits bits rounded up
|
|
||||||
to whole bytes"""
|
|
||||||
|
|
||||||
nbytes = ceil(nbits/8.)
|
|
||||||
randomdata = os.urandom(nbytes)
|
|
||||||
return bytes2int(randomdata)
|
|
||||||
|
|
||||||
def ceil(x):
|
|
||||||
"""ceil(x) -> int(math.ceil(x))"""
|
|
||||||
|
|
||||||
return int(math.ceil(x))
|
|
||||||
|
|
||||||
def randint(minvalue, maxvalue):
|
|
||||||
"""Returns a random integer x with minvalue <= x <= maxvalue"""
|
|
||||||
|
|
||||||
# Safety - get a lot of random data even if the range is fairly
|
|
||||||
# small
|
|
||||||
min_nbits = 32
|
|
||||||
|
|
||||||
# The range of the random numbers we need to generate
|
|
||||||
range = maxvalue - minvalue
|
|
||||||
|
|
||||||
# Which is this number of bytes
|
|
||||||
rangebytes = ceil(math.log(range, 2) / 8.)
|
|
||||||
|
|
||||||
# Convert to bits, but make sure it's always at least min_nbits*2
|
|
||||||
rangebits = max(rangebytes * 8, min_nbits * 2)
|
|
||||||
|
|
||||||
# Take a random number of bits between min_nbits and rangebits
|
|
||||||
nbits = random.randint(min_nbits, rangebits)
|
|
||||||
|
|
||||||
return (read_random_int(nbits) % range) + minvalue
|
|
||||||
|
|
||||||
def fermat_little_theorem(p):
|
|
||||||
"""Returns 1 if p may be prime, and something else if p definitely
|
|
||||||
is not prime"""
|
|
||||||
|
|
||||||
a = randint(1, p-1)
|
|
||||||
return fast_exponentiation(a, p-1, p)
|
|
||||||
|
|
||||||
def jacobi(a, b):
|
|
||||||
"""Calculates the value of the Jacobi symbol (a/b)
|
|
||||||
"""
|
|
||||||
|
|
||||||
if a % b == 0:
|
|
||||||
return 0
|
|
||||||
result = 1
|
|
||||||
while a > 1:
|
|
||||||
if a & 1:
|
|
||||||
if ((a-1)*(b-1) >> 2) & 1:
|
|
||||||
result = -result
|
|
||||||
b, a = a, b % a
|
|
||||||
else:
|
|
||||||
if ((b ** 2 - 1) >> 3) & 1:
|
|
||||||
result = -result
|
|
||||||
a = a >> 1
|
|
||||||
return result
|
|
||||||
|
|
||||||
def jacobi_witness(x, n):
|
|
||||||
"""Returns False if n is an Euler pseudo-prime with base x, and
|
|
||||||
True otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
j = jacobi(x, n) % n
|
|
||||||
f = fast_exponentiation(x, (n-1)/2, n)
|
|
||||||
|
|
||||||
if j == f: return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def randomized_primality_testing(n, k):
|
|
||||||
"""Calculates whether n is composite (which is always correct) or
|
|
||||||
prime (which is incorrect with error probability 2**-k)
|
|
||||||
|
|
||||||
Returns False if the number if composite, and True if it's
|
|
||||||
probably prime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
q = 0.5 # Property of the jacobi_witness function
|
|
||||||
|
|
||||||
# t = int(math.ceil(k / math.log(1/q, 2)))
|
|
||||||
t = ceil(k / math.log(1/q, 2))
|
|
||||||
for i in range(t+1):
|
|
||||||
x = randint(1, n-1)
|
|
||||||
if jacobi_witness(x, n): return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_prime(number):
|
|
||||||
"""Returns True if the number is prime, and False otherwise.
|
|
||||||
|
|
||||||
>>> is_prime(42)
|
|
||||||
0
|
|
||||||
>>> is_prime(41)
|
|
||||||
1
|
|
||||||
"""
|
|
||||||
|
|
||||||
"""
|
|
||||||
if not fermat_little_theorem(number) == 1:
|
|
||||||
# Not prime, according to Fermat's little theorem
|
|
||||||
return False
|
|
||||||
"""
|
|
||||||
|
|
||||||
if randomized_primality_testing(number, 5):
|
|
||||||
# Prime, according to Jacobi
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Not prime
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def getprime(nbits):
|
|
||||||
"""Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
|
|
||||||
other words: nbits is rounded up to whole bytes.
|
|
||||||
|
|
||||||
>>> p = getprime(8)
|
|
||||||
>>> is_prime(p-1)
|
|
||||||
0
|
|
||||||
>>> is_prime(p)
|
|
||||||
1
|
|
||||||
>>> is_prime(p+1)
|
|
||||||
0
|
|
||||||
"""
|
|
||||||
|
|
||||||
nbytes = int(math.ceil(nbits/8.))
|
|
||||||
|
|
||||||
while True:
|
|
||||||
integer = read_random_int(nbits)
|
|
||||||
|
|
||||||
# Make sure it's odd
|
|
||||||
integer |= 1
|
|
||||||
|
|
||||||
# Test for primeness
|
|
||||||
if is_prime(integer): break
|
|
||||||
|
|
||||||
# Retry if not prime
|
|
||||||
|
|
||||||
return integer
|
|
||||||
|
|
||||||
def are_relatively_prime(a, b):
|
|
||||||
"""Returns True if a and b are relatively prime, and False if they
|
|
||||||
are not.
|
|
||||||
|
|
||||||
>>> are_relatively_prime(2, 3)
|
|
||||||
1
|
|
||||||
>>> are_relatively_prime(2, 4)
|
|
||||||
0
|
|
||||||
"""
|
|
||||||
|
|
||||||
d = gcd(a, b)
|
|
||||||
return (d == 1)
|
|
||||||
|
|
||||||
def find_p_q(nbits):
|
|
||||||
"""Returns a tuple of two different primes of nbits bits"""
|
|
||||||
|
|
||||||
p = getprime(nbits)
|
|
||||||
while True:
|
|
||||||
q = getprime(nbits)
|
|
||||||
if not q == p: break
|
|
||||||
|
|
||||||
return (p, q)
|
|
||||||
|
|
||||||
def extended_euclid_gcd(a, b):
|
|
||||||
"""Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb
|
|
||||||
"""
|
|
||||||
|
|
||||||
if b == 0:
|
|
||||||
return (a, 1, 0)
|
|
||||||
|
|
||||||
q = abs(a % b)
|
|
||||||
r = long(a / b)
|
|
||||||
(d, k, l) = extended_euclid_gcd(b, q)
|
|
||||||
|
|
||||||
return (d, l, k - l*r)
|
|
||||||
|
|
||||||
# Main function: calculate encryption and decryption keys
|
|
||||||
def calculate_keys(p, q, nbits):
|
|
||||||
"""Calculates an encryption and a decryption key for p and q, and
|
|
||||||
returns them as a tuple (e, d)"""
|
|
||||||
|
|
||||||
n = p * q
|
|
||||||
phi_n = (p-1) * (q-1)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Make sure e has enough bits so we ensure "wrapping" through
|
|
||||||
# modulo n
|
|
||||||
e = getprime(max(8, nbits/2))
|
|
||||||
if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break
|
|
||||||
|
|
||||||
(d, i, j) = extended_euclid_gcd(e, phi_n)
|
|
||||||
|
|
||||||
if not d == 1:
|
|
||||||
raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
|
|
||||||
|
|
||||||
if not (e * i) % phi_n == 1:
|
|
||||||
raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))
|
|
||||||
|
|
||||||
return (e, i)
|
|
||||||
|
|
||||||
|
|
||||||
def gen_keys(nbits):
|
|
||||||
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
|
|
||||||
|
|
||||||
Note: this can take a long time, depending on the key size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
(p, q) = find_p_q(nbits)
|
|
||||||
(e, d) = calculate_keys(p, q, nbits)
|
|
||||||
|
|
||||||
# For some reason, d is sometimes negative. We don't know how
|
|
||||||
# to fix it (yet), so we keep trying until everything is shiny
|
|
||||||
if d > 0: break
|
|
||||||
|
|
||||||
return (p, q, e, d)
|
|
||||||
|
|
||||||
def gen_pubpriv_keys(nbits):
|
|
||||||
"""Generates public and private keys, and returns them as (pub,
|
|
||||||
priv).
|
|
||||||
|
|
||||||
The public key consists of a dict {e: ..., , n: ....). The private
|
|
||||||
key consists of a dict {d: ...., p: ...., q: ....).
|
|
||||||
"""
|
|
||||||
|
|
||||||
(p, q, e, d) = gen_keys(nbits)
|
|
||||||
|
|
||||||
return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} )
|
|
||||||
|
|
||||||
def encrypt_int(message, ekey, n):
|
|
||||||
"""Encrypts a message using encryption key 'ekey', working modulo
|
|
||||||
n"""
|
|
||||||
|
|
||||||
if type(message) is types.IntType:
|
|
||||||
return encrypt_int(long(message), ekey, n)
|
|
||||||
|
|
||||||
if not type(message) is types.LongType:
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
if message > 0 and \
|
|
||||||
math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)):
|
|
||||||
raise OverflowError("The message is too long")
|
|
||||||
|
|
||||||
return fast_exponentiation(message, ekey, n)
|
|
||||||
|
|
||||||
def decrypt_int(cyphertext, dkey, n):
|
|
||||||
"""Decrypts a cypher text using the decryption key 'dkey', working
|
|
||||||
modulo n"""
|
|
||||||
|
|
||||||
return encrypt_int(cyphertext, dkey, n)
|
|
||||||
|
|
||||||
def sign_int(message, dkey, n):
|
|
||||||
"""Signs 'message' using key 'dkey', working modulo n"""
|
|
||||||
|
|
||||||
return decrypt_int(message, dkey, n)
|
|
||||||
|
|
||||||
def verify_int(signed, ekey, n):
|
|
||||||
"""verifies 'signed' using key 'ekey', working modulo n"""
|
|
||||||
|
|
||||||
return encrypt_int(signed, ekey, n)
|
|
||||||
|
|
||||||
def picklechops(chops):
|
|
||||||
"""Pickles and base64encodes it's argument chops"""
|
|
||||||
|
|
||||||
value = zlib.compress(dumps(chops))
|
|
||||||
encoded = base64.encodestring(value)
|
|
||||||
return encoded.strip()
|
|
||||||
|
|
||||||
def unpicklechops(string):
|
|
||||||
"""base64decodes and unpickes it's argument string into chops"""
|
|
||||||
|
|
||||||
return loads(zlib.decompress(base64.decodestring(string)))
|
|
||||||
|
|
||||||
def chopstring(message, key, n, funcref):
|
|
||||||
"""Splits 'message' into chops that are at most as long as n,
|
|
||||||
converts these into integers, and calls funcref(integer, key, n)
|
|
||||||
for each chop.
|
|
||||||
|
|
||||||
Used by 'encrypt' and 'sign'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
msglen = len(message)
|
|
||||||
mbits = msglen * 8
|
|
||||||
nbits = int(math.floor(math.log(n, 2)))
|
|
||||||
nbytes = nbits / 8
|
|
||||||
blocks = msglen / nbytes
|
|
||||||
|
|
||||||
if msglen % nbytes > 0:
|
|
||||||
blocks += 1
|
|
||||||
|
|
||||||
cypher = []
|
|
||||||
|
|
||||||
for bindex in range(blocks):
|
|
||||||
offset = bindex * nbytes
|
|
||||||
block = message[offset:offset+nbytes]
|
|
||||||
value = bytes2int(block)
|
|
||||||
cypher.append(funcref(value, key, n))
|
|
||||||
|
|
||||||
return picklechops(cypher)
|
|
||||||
|
|
||||||
def gluechops(chops, key, n, funcref):
|
|
||||||
"""Glues chops back together into a string. calls
|
|
||||||
funcref(integer, key, n) for each chop.
|
|
||||||
|
|
||||||
Used by 'decrypt' and 'verify'.
|
|
||||||
"""
|
|
||||||
message = ""
|
|
||||||
|
|
||||||
chops = unpicklechops(chops)
|
|
||||||
|
|
||||||
for cpart in chops:
|
|
||||||
mpart = funcref(cpart, key, n)
|
|
||||||
message += int2bytes(mpart)
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
def encrypt(message, key):
|
|
||||||
"""Encrypts a string 'message' with the public key 'key'"""
|
|
||||||
|
|
||||||
return chopstring(message, key['e'], key['n'], encrypt_int)
|
|
||||||
|
|
||||||
def sign(message, key):
|
|
||||||
"""Signs a string 'message' with the private key 'key'"""
|
|
||||||
|
|
||||||
return chopstring(message, key['d'], key['p']*key['q'], decrypt_int)
|
|
||||||
|
|
||||||
def decrypt(cypher, key):
|
|
||||||
"""Decrypts a cypher with the private key 'key'"""
|
|
||||||
|
|
||||||
return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
|
|
||||||
|
|
||||||
def verify(cypher, key):
|
|
||||||
"""Verifies a cypher with the public key 'key'"""
|
|
||||||
|
|
||||||
return gluechops(cypher, key['e'], key['n'], encrypt_int)
|
|
||||||
|
|
||||||
# Do doctest if we're not imported
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
||||||
__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"]
|
|
||||||
|
|
|
@ -1,529 +0,0 @@
|
||||||
"""RSA module
|
|
||||||
|
|
||||||
Module for calculating large primes, and RSA encryption, decryption,
|
|
||||||
signing and verification. Includes generating public and private keys.
|
|
||||||
|
|
||||||
WARNING: this implementation does not use random padding, compression of the
|
|
||||||
cleartext input to prevent repetitions, or other common security improvements.
|
|
||||||
Use with care.
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
|
|
||||||
__date__ = "2010-02-08"
|
|
||||||
__version__ = '2.0'
|
|
||||||
|
|
||||||
import math
|
|
||||||
import os
|
|
||||||
import random
|
|
||||||
import sys
|
|
||||||
import types
|
|
||||||
from rsa._compat import byte
|
|
||||||
|
|
||||||
# Display a warning that this insecure version is imported.
|
|
||||||
import warnings
|
|
||||||
warnings.warn('Insecure version of the RSA module is imported as %s' % __name__)
|
|
||||||
|
|
||||||
|
|
||||||
def bit_size(number):
|
|
||||||
"""Returns the number of bits required to hold a specific long number"""
|
|
||||||
|
|
||||||
return int(math.ceil(math.log(number,2)))
|
|
||||||
|
|
||||||
def gcd(p, q):
|
|
||||||
"""Returns the greatest common divisor of p and q
|
|
||||||
>>> gcd(48, 180)
|
|
||||||
12
|
|
||||||
"""
|
|
||||||
# Iterateive Version is faster and uses much less stack space
|
|
||||||
while q != 0:
|
|
||||||
if p < q: (p,q) = (q,p)
|
|
||||||
(p,q) = (q, p % q)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def bytes2int(bytes):
|
|
||||||
"""Converts a list of bytes or a string to an integer
|
|
||||||
|
|
||||||
>>> (((128 * 256) + 64) * 256) + 15
|
|
||||||
8405007
|
|
||||||
>>> l = [128, 64, 15]
|
|
||||||
>>> bytes2int(l) #same as bytes2int('\x80@\x0f')
|
|
||||||
8405007
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
|
|
||||||
raise TypeError("You must pass a string or a list")
|
|
||||||
|
|
||||||
# Convert byte stream to integer
|
|
||||||
integer = 0
|
|
||||||
for byte in bytes:
|
|
||||||
integer *= 256
|
|
||||||
if type(byte) is types.StringType: byte = ord(byte)
|
|
||||||
integer += byte
|
|
||||||
|
|
||||||
return integer
|
|
||||||
|
|
||||||
def int2bytes(number):
|
|
||||||
"""
|
|
||||||
Converts a number to a string of bytes
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(number) is types.LongType or type(number) is types.IntType):
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
string = ""
|
|
||||||
|
|
||||||
while number > 0:
|
|
||||||
string = "%s%s" % (byte(number & 0xFF), string)
|
|
||||||
number /= 256
|
|
||||||
|
|
||||||
return string
|
|
||||||
|
|
||||||
def to64(number):
|
|
||||||
"""Converts a number in the range of 0 to 63 into base 64 digit
|
|
||||||
character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.
|
|
||||||
|
|
||||||
>>> to64(10)
|
|
||||||
'A'
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(number) is types.LongType or type(number) is types.IntType):
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
if 0 <= number <= 9: #00-09 translates to '0' - '9'
|
|
||||||
return byte(number + 48)
|
|
||||||
|
|
||||||
if 10 <= number <= 35:
|
|
||||||
return byte(number + 55) #10-35 translates to 'A' - 'Z'
|
|
||||||
|
|
||||||
if 36 <= number <= 61:
|
|
||||||
return byte(number + 61) #36-61 translates to 'a' - 'z'
|
|
||||||
|
|
||||||
if number == 62: # 62 translates to '-' (minus)
|
|
||||||
return byte(45)
|
|
||||||
|
|
||||||
if number == 63: # 63 translates to '_' (underscore)
|
|
||||||
return byte(95)
|
|
||||||
|
|
||||||
raise ValueError('Invalid Base64 value: %i' % number)
|
|
||||||
|
|
||||||
|
|
||||||
def from64(number):
|
|
||||||
"""Converts an ordinal character value in the range of
|
|
||||||
0-9,A-Z,a-z,-,_ to a number in the range of 0-63.
|
|
||||||
|
|
||||||
>>> from64(49)
|
|
||||||
1
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(number) is types.LongType or type(number) is types.IntType):
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9
|
|
||||||
return(number - 48)
|
|
||||||
|
|
||||||
if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35
|
|
||||||
return(number - 55)
|
|
||||||
|
|
||||||
if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61
|
|
||||||
return(number - 61)
|
|
||||||
|
|
||||||
if number == 45: #ord('-') translates to 62
|
|
||||||
return(62)
|
|
||||||
|
|
||||||
if number == 95: #ord('_') translates to 63
|
|
||||||
return(63)
|
|
||||||
|
|
||||||
raise ValueError('Invalid Base64 value: %i' % number)
|
|
||||||
|
|
||||||
|
|
||||||
def int2str64(number):
|
|
||||||
"""Converts a number to a string of base64 encoded characters in
|
|
||||||
the range of '0'-'9','A'-'Z,'a'-'z','-','_'.
|
|
||||||
|
|
||||||
>>> int2str64(123456789)
|
|
||||||
'7MyqL'
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(number) is types.LongType or type(number) is types.IntType):
|
|
||||||
raise TypeError("You must pass a long or an int")
|
|
||||||
|
|
||||||
string = ""
|
|
||||||
|
|
||||||
while number > 0:
|
|
||||||
string = "%s%s" % (to64(number & 0x3F), string)
|
|
||||||
number /= 64
|
|
||||||
|
|
||||||
return string
|
|
||||||
|
|
||||||
|
|
||||||
def str642int(string):
|
|
||||||
"""Converts a base64 encoded string into an integer.
|
|
||||||
The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_'
|
|
||||||
|
|
||||||
>>> str642int('7MyqL')
|
|
||||||
123456789
|
|
||||||
"""
|
|
||||||
|
|
||||||
if not (type(string) is types.ListType or type(string) is types.StringType):
|
|
||||||
raise TypeError("You must pass a string or a list")
|
|
||||||
|
|
||||||
integer = 0
|
|
||||||
for byte in string:
|
|
||||||
integer *= 64
|
|
||||||
if type(byte) is types.StringType: byte = ord(byte)
|
|
||||||
integer += from64(byte)
|
|
||||||
|
|
||||||
return integer
|
|
||||||
|
|
||||||
def read_random_int(nbits):
|
|
||||||
"""Reads a random integer of approximately nbits bits rounded up
|
|
||||||
to whole bytes"""
|
|
||||||
|
|
||||||
nbytes = int(math.ceil(nbits/8.))
|
|
||||||
randomdata = os.urandom(nbytes)
|
|
||||||
return bytes2int(randomdata)
|
|
||||||
|
|
||||||
def randint(minvalue, maxvalue):
|
|
||||||
"""Returns a random integer x with minvalue <= x <= maxvalue"""
|
|
||||||
|
|
||||||
# Safety - get a lot of random data even if the range is fairly
|
|
||||||
# small
|
|
||||||
min_nbits = 32
|
|
||||||
|
|
||||||
# The range of the random numbers we need to generate
|
|
||||||
range = (maxvalue - minvalue) + 1
|
|
||||||
|
|
||||||
# Which is this number of bytes
|
|
||||||
rangebytes = ((bit_size(range) + 7) / 8)
|
|
||||||
|
|
||||||
# Convert to bits, but make sure it's always at least min_nbits*2
|
|
||||||
rangebits = max(rangebytes * 8, min_nbits * 2)
|
|
||||||
|
|
||||||
# Take a random number of bits between min_nbits and rangebits
|
|
||||||
nbits = random.randint(min_nbits, rangebits)
|
|
||||||
|
|
||||||
return (read_random_int(nbits) % range) + minvalue
|
|
||||||
|
|
||||||
def jacobi(a, b):
|
|
||||||
"""Calculates the value of the Jacobi symbol (a/b)
|
|
||||||
where both a and b are positive integers, and b is odd
|
|
||||||
"""
|
|
||||||
|
|
||||||
if a == 0: return 0
|
|
||||||
result = 1
|
|
||||||
while a > 1:
|
|
||||||
if a & 1:
|
|
||||||
if ((a-1)*(b-1) >> 2) & 1:
|
|
||||||
result = -result
|
|
||||||
a, b = b % a, a
|
|
||||||
else:
|
|
||||||
if (((b * b) - 1) >> 3) & 1:
|
|
||||||
result = -result
|
|
||||||
a >>= 1
|
|
||||||
if a == 0: return 0
|
|
||||||
return result
|
|
||||||
|
|
||||||
def jacobi_witness(x, n):
|
|
||||||
"""Returns False if n is an Euler pseudo-prime with base x, and
|
|
||||||
True otherwise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
j = jacobi(x, n) % n
|
|
||||||
f = pow(x, (n-1)/2, n)
|
|
||||||
|
|
||||||
if j == f: return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def randomized_primality_testing(n, k):
|
|
||||||
"""Calculates whether n is composite (which is always correct) or
|
|
||||||
prime (which is incorrect with error probability 2**-k)
|
|
||||||
|
|
||||||
Returns False if the number is composite, and True if it's
|
|
||||||
probably prime.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# 50% of Jacobi-witnesses can report compositness of non-prime numbers
|
|
||||||
|
|
||||||
for i in range(k):
|
|
||||||
x = randint(1, n-1)
|
|
||||||
if jacobi_witness(x, n): return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_prime(number):
|
|
||||||
"""Returns True if the number is prime, and False otherwise.
|
|
||||||
|
|
||||||
>>> is_prime(42)
|
|
||||||
0
|
|
||||||
>>> is_prime(41)
|
|
||||||
1
|
|
||||||
"""
|
|
||||||
|
|
||||||
if randomized_primality_testing(number, 6):
|
|
||||||
# Prime, according to Jacobi
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Not prime
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def getprime(nbits):
|
|
||||||
"""Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
|
|
||||||
other words: nbits is rounded up to whole bytes.
|
|
||||||
|
|
||||||
>>> p = getprime(8)
|
|
||||||
>>> is_prime(p-1)
|
|
||||||
0
|
|
||||||
>>> is_prime(p)
|
|
||||||
1
|
|
||||||
>>> is_prime(p+1)
|
|
||||||
0
|
|
||||||
"""
|
|
||||||
|
|
||||||
while True:
|
|
||||||
integer = read_random_int(nbits)
|
|
||||||
|
|
||||||
# Make sure it's odd
|
|
||||||
integer |= 1
|
|
||||||
|
|
||||||
# Test for primeness
|
|
||||||
if is_prime(integer): break
|
|
||||||
|
|
||||||
# Retry if not prime
|
|
||||||
|
|
||||||
return integer
|
|
||||||
|
|
||||||
def are_relatively_prime(a, b):
|
|
||||||
"""Returns True if a and b are relatively prime, and False if they
|
|
||||||
are not.
|
|
||||||
|
|
||||||
>>> are_relatively_prime(2, 3)
|
|
||||||
1
|
|
||||||
>>> are_relatively_prime(2, 4)
|
|
||||||
0
|
|
||||||
"""
|
|
||||||
|
|
||||||
d = gcd(a, b)
|
|
||||||
return (d == 1)
|
|
||||||
|
|
||||||
def find_p_q(nbits):
|
|
||||||
"""Returns a tuple of two different primes of nbits bits"""
|
|
||||||
pbits = nbits + (nbits/16) #Make sure that p and q aren't too close
|
|
||||||
qbits = nbits - (nbits/16) #or the factoring programs can factor n
|
|
||||||
p = getprime(pbits)
|
|
||||||
while True:
|
|
||||||
q = getprime(qbits)
|
|
||||||
#Make sure p and q are different.
|
|
||||||
if not q == p: break
|
|
||||||
return (p, q)
|
|
||||||
|
|
||||||
def extended_gcd(a, b):
|
|
||||||
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
|
|
||||||
"""
|
|
||||||
# r = gcd(a,b) i = multiplicitive inverse of a mod b
|
|
||||||
# or j = multiplicitive inverse of b mod a
|
|
||||||
# Neg return values for i or j are made positive mod b or a respectively
|
|
||||||
# Iterateive Version is faster and uses much less stack space
|
|
||||||
x = 0
|
|
||||||
y = 1
|
|
||||||
lx = 1
|
|
||||||
ly = 0
|
|
||||||
oa = a #Remember original a/b to remove
|
|
||||||
ob = b #negative values from return results
|
|
||||||
while b != 0:
|
|
||||||
q = long(a/b)
|
|
||||||
(a, b) = (b, a % b)
|
|
||||||
(x, lx) = ((lx - (q * x)),x)
|
|
||||||
(y, ly) = ((ly - (q * y)),y)
|
|
||||||
if (lx < 0): lx += ob #If neg wrap modulo orignal b
|
|
||||||
if (ly < 0): ly += oa #If neg wrap modulo orignal a
|
|
||||||
return (a, lx, ly) #Return only positive values
|
|
||||||
|
|
||||||
# Main function: calculate encryption and decryption keys
|
|
||||||
def calculate_keys(p, q, nbits):
|
|
||||||
"""Calculates an encryption and a decryption key for p and q, and
|
|
||||||
returns them as a tuple (e, d)"""
|
|
||||||
|
|
||||||
n = p * q
|
|
||||||
phi_n = (p-1) * (q-1)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
# Make sure e has enough bits so we ensure "wrapping" through
|
|
||||||
# modulo n
|
|
||||||
e = max(65537,getprime(nbits/4))
|
|
||||||
if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break
|
|
||||||
|
|
||||||
(d, i, j) = extended_gcd(e, phi_n)
|
|
||||||
|
|
||||||
if not d == 1:
|
|
||||||
raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
|
|
||||||
if (i < 0):
|
|
||||||
raise Exception("New extended_gcd shouldn't return negative values")
|
|
||||||
if not (e * i) % phi_n == 1:
|
|
||||||
raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))
|
|
||||||
|
|
||||||
return (e, i)
|
|
||||||
|
|
||||||
|
|
||||||
def gen_keys(nbits):
|
|
||||||
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
|
|
||||||
|
|
||||||
Note: this can take a long time, depending on the key size.
|
|
||||||
"""
|
|
||||||
|
|
||||||
(p, q) = find_p_q(nbits)
|
|
||||||
(e, d) = calculate_keys(p, q, nbits)
|
|
||||||
|
|
||||||
return (p, q, e, d)
|
|
||||||
|
|
||||||
def newkeys(nbits):
|
|
||||||
"""Generates public and private keys, and returns them as (pub,
|
|
||||||
priv).
|
|
||||||
|
|
||||||
The public key consists of a dict {e: ..., , n: ....). The private
|
|
||||||
key consists of a dict {d: ...., p: ...., q: ....).
|
|
||||||
"""
|
|
||||||
nbits = max(9,nbits) # Don't let nbits go below 9 bits
|
|
||||||
(p, q, e, d) = gen_keys(nbits)
|
|
||||||
|
|
||||||
return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} )
|
|
||||||
|
|
||||||
def encrypt_int(message, ekey, n):
|
|
||||||
"""Encrypts a message using encryption key 'ekey', working modulo n"""
|
|
||||||
|
|
||||||
if type(message) is types.IntType:
|
|
||||||
message = long(message)
|
|
||||||
|
|
||||||
if not type(message) is types.LongType:
|
|
||||||
raise TypeError("You must pass a long or int")
|
|
||||||
|
|
||||||
if message < 0 or message > n:
|
|
||||||
raise OverflowError("The message is too long")
|
|
||||||
|
|
||||||
#Note: Bit exponents start at zero (bit counts start at 1) this is correct
|
|
||||||
safebit = bit_size(n) - 2 #compute safe bit (MSB - 1)
|
|
||||||
message += (1 << safebit) #add safebit to ensure folding
|
|
||||||
|
|
||||||
return pow(message, ekey, n)
|
|
||||||
|
|
||||||
def decrypt_int(cyphertext, dkey, n):
|
|
||||||
"""Decrypts a cypher text using the decryption key 'dkey', working
|
|
||||||
modulo n"""
|
|
||||||
|
|
||||||
message = pow(cyphertext, dkey, n)
|
|
||||||
|
|
||||||
safebit = bit_size(n) - 2 #compute safe bit (MSB - 1)
|
|
||||||
message -= (1 << safebit) #remove safebit before decode
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
def encode64chops(chops):
|
|
||||||
"""base64encodes chops and combines them into a ',' delimited string"""
|
|
||||||
|
|
||||||
chips = [] #chips are character chops
|
|
||||||
|
|
||||||
for value in chops:
|
|
||||||
chips.append(int2str64(value))
|
|
||||||
|
|
||||||
#delimit chops with comma
|
|
||||||
encoded = ','.join(chips)
|
|
||||||
|
|
||||||
return encoded
|
|
||||||
|
|
||||||
def decode64chops(string):
|
|
||||||
"""base64decodes and makes a ',' delimited string into chops"""
|
|
||||||
|
|
||||||
chips = string.split(',') #split chops at commas
|
|
||||||
|
|
||||||
chops = []
|
|
||||||
|
|
||||||
for string in chips: #make char chops (chips) into chops
|
|
||||||
chops.append(str642int(string))
|
|
||||||
|
|
||||||
return chops
|
|
||||||
|
|
||||||
def chopstring(message, key, n, funcref):
|
|
||||||
"""Chops the 'message' into integers that fit into n,
|
|
||||||
leaving room for a safebit to be added to ensure that all
|
|
||||||
messages fold during exponentiation. The MSB of the number n
|
|
||||||
is not independant modulo n (setting it could cause overflow), so
|
|
||||||
use the next lower bit for the safebit. Therefore reserve 2-bits
|
|
||||||
in the number n for non-data bits. Calls specified encryption
|
|
||||||
function for each chop.
|
|
||||||
|
|
||||||
Used by 'encrypt' and 'sign'.
|
|
||||||
"""
|
|
||||||
|
|
||||||
msglen = len(message)
|
|
||||||
mbits = msglen * 8
|
|
||||||
#Set aside 2-bits so setting of safebit won't overflow modulo n.
|
|
||||||
nbits = bit_size(n) - 2 # leave room for safebit
|
|
||||||
nbytes = nbits / 8
|
|
||||||
blocks = msglen / nbytes
|
|
||||||
|
|
||||||
if msglen % nbytes > 0:
|
|
||||||
blocks += 1
|
|
||||||
|
|
||||||
cypher = []
|
|
||||||
|
|
||||||
for bindex in range(blocks):
|
|
||||||
offset = bindex * nbytes
|
|
||||||
block = message[offset:offset+nbytes]
|
|
||||||
value = bytes2int(block)
|
|
||||||
cypher.append(funcref(value, key, n))
|
|
||||||
|
|
||||||
return encode64chops(cypher) #Encode encrypted ints to base64 strings
|
|
||||||
|
|
||||||
def gluechops(string, key, n, funcref):
|
|
||||||
"""Glues chops back together into a string. calls
|
|
||||||
funcref(integer, key, n) for each chop.
|
|
||||||
|
|
||||||
Used by 'decrypt' and 'verify'.
|
|
||||||
"""
|
|
||||||
message = ""
|
|
||||||
|
|
||||||
chops = decode64chops(string) #Decode base64 strings into integer chops
|
|
||||||
|
|
||||||
for cpart in chops:
|
|
||||||
mpart = funcref(cpart, key, n) #Decrypt each chop
|
|
||||||
message += int2bytes(mpart) #Combine decrypted strings into a msg
|
|
||||||
|
|
||||||
return message
|
|
||||||
|
|
||||||
def encrypt(message, key):
|
|
||||||
"""Encrypts a string 'message' with the public key 'key'"""
|
|
||||||
if 'n' not in key:
|
|
||||||
raise Exception("You must use the public key with encrypt")
|
|
||||||
|
|
||||||
return chopstring(message, key['e'], key['n'], encrypt_int)
|
|
||||||
|
|
||||||
def sign(message, key):
|
|
||||||
"""Signs a string 'message' with the private key 'key'"""
|
|
||||||
if 'p' not in key:
|
|
||||||
raise Exception("You must use the private key with sign")
|
|
||||||
|
|
||||||
return chopstring(message, key['d'], key['p']*key['q'], encrypt_int)
|
|
||||||
|
|
||||||
def decrypt(cypher, key):
|
|
||||||
"""Decrypts a string 'cypher' with the private key 'key'"""
|
|
||||||
if 'p' not in key:
|
|
||||||
raise Exception("You must use the private key with decrypt")
|
|
||||||
|
|
||||||
return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
|
|
||||||
|
|
||||||
def verify(cypher, key):
|
|
||||||
"""Verifies a string 'cypher' with the public key 'key'"""
|
|
||||||
if 'n' not in key:
|
|
||||||
raise Exception("You must use the public key with verify")
|
|
||||||
|
|
||||||
return gluechops(cypher, key['e'], key['n'], decrypt_int)
|
|
||||||
|
|
||||||
# Do doctest if we're not imported
|
|
||||||
if __name__ == "__main__":
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
||||||
__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"]
|
|
||||||
|
|
|
@ -1,35 +0,0 @@
|
||||||
'''ASN.1 definitions.
|
|
||||||
|
|
||||||
Not all ASN.1-handling code use these definitions, but when it does, they should be here.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pyasn1.type import univ, namedtype, tag
|
|
||||||
|
|
||||||
class PubKeyHeader(univ.Sequence):
|
|
||||||
componentType = namedtype.NamedTypes(
|
|
||||||
namedtype.NamedType('oid', univ.ObjectIdentifier()),
|
|
||||||
namedtype.NamedType('parameters', univ.Null()),
|
|
||||||
)
|
|
||||||
|
|
||||||
class OpenSSLPubKey(univ.Sequence):
|
|
||||||
componentType = namedtype.NamedTypes(
|
|
||||||
namedtype.NamedType('header', PubKeyHeader()),
|
|
||||||
|
|
||||||
# This little hack (the implicit tag) allows us to get a Bit String as Octet String
|
|
||||||
namedtype.NamedType('key', univ.OctetString().subtype(
|
|
||||||
implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3))),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AsnPubKey(univ.Sequence):
|
|
||||||
'''ASN.1 contents of DER encoded public key:
|
|
||||||
|
|
||||||
RSAPublicKey ::= SEQUENCE {
|
|
||||||
modulus INTEGER, -- n
|
|
||||||
publicExponent INTEGER, -- e
|
|
||||||
'''
|
|
||||||
|
|
||||||
componentType = namedtype.NamedTypes(
|
|
||||||
namedtype.NamedType('modulus', univ.Integer()),
|
|
||||||
namedtype.NamedType('publicExponent', univ.Integer()),
|
|
||||||
)
|
|
|
@ -1,87 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Large file support
|
|
||||||
|
|
||||||
- break a file into smaller blocks, and encrypt them, and store the
|
|
||||||
encrypted blocks in another file.
|
|
||||||
|
|
||||||
- take such an encrypted files, decrypt its blocks, and reconstruct the
|
|
||||||
original file.
|
|
||||||
|
|
||||||
The encrypted file format is as follows, where || denotes byte concatenation:
|
|
||||||
|
|
||||||
FILE := VERSION || BLOCK || BLOCK ...
|
|
||||||
|
|
||||||
BLOCK := LENGTH || DATA
|
|
||||||
|
|
||||||
LENGTH := varint-encoded length of the subsequent data. Varint comes from
|
|
||||||
Google Protobuf, and encodes an integer into a variable number of bytes.
|
|
||||||
Each byte uses the 7 lowest bits to encode the value. The highest bit set
|
|
||||||
to 1 indicates the next byte is also part of the varint. The last byte will
|
|
||||||
have this bit set to 0.
|
|
||||||
|
|
||||||
This file format is called the VARBLOCK format, in line with the varint format
|
|
||||||
used to denote the block sizes.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
from rsa import key, common, pkcs1, varblock
|
|
||||||
from rsa._compat import byte
|
|
||||||
|
|
||||||
def encrypt_bigfile(infile, outfile, pub_key):
|
|
||||||
'''Encrypts a file, writing it to 'outfile' in VARBLOCK format.
|
|
||||||
|
|
||||||
:param infile: file-like object to read the cleartext from
|
|
||||||
:param outfile: file-like object to write the crypto in VARBLOCK format to
|
|
||||||
:param pub_key: :py:class:`rsa.PublicKey` to encrypt with
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
if not isinstance(pub_key, key.PublicKey):
|
|
||||||
raise TypeError('Public key required, but got %r' % pub_key)
|
|
||||||
|
|
||||||
key_bytes = common.bit_size(pub_key.n) // 8
|
|
||||||
blocksize = key_bytes - 11 # keep space for PKCS#1 padding
|
|
||||||
|
|
||||||
# Write the version number to the VARBLOCK file
|
|
||||||
outfile.write(byte(varblock.VARBLOCK_VERSION))
|
|
||||||
|
|
||||||
# Encrypt and write each block
|
|
||||||
for block in varblock.yield_fixedblocks(infile, blocksize):
|
|
||||||
crypto = pkcs1.encrypt(block, pub_key)
|
|
||||||
|
|
||||||
varblock.write_varint(outfile, len(crypto))
|
|
||||||
outfile.write(crypto)
|
|
||||||
|
|
||||||
def decrypt_bigfile(infile, outfile, priv_key):
|
|
||||||
'''Decrypts an encrypted VARBLOCK file, writing it to 'outfile'
|
|
||||||
|
|
||||||
:param infile: file-like object to read the crypto in VARBLOCK format from
|
|
||||||
:param outfile: file-like object to write the cleartext to
|
|
||||||
:param priv_key: :py:class:`rsa.PrivateKey` to decrypt with
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
if not isinstance(priv_key, key.PrivateKey):
|
|
||||||
raise TypeError('Private key required, but got %r' % priv_key)
|
|
||||||
|
|
||||||
for block in varblock.yield_varblocks(infile):
|
|
||||||
cleartext = pkcs1.decrypt(block, priv_key)
|
|
||||||
outfile.write(cleartext)
|
|
||||||
|
|
||||||
__all__ = ['encrypt_bigfile', 'decrypt_bigfile']
|
|
||||||
|
|
|
@ -1,379 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Commandline scripts.
|
|
||||||
|
|
||||||
These scripts are called by the executables defined in setup.py.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import with_statement, print_function
|
|
||||||
|
|
||||||
import abc
|
|
||||||
import sys
|
|
||||||
from optparse import OptionParser
|
|
||||||
|
|
||||||
import rsa
|
|
||||||
import rsa.bigfile
|
|
||||||
import rsa.pkcs1
|
|
||||||
|
|
||||||
HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())
|
|
||||||
|
|
||||||
def keygen():
|
|
||||||
'''Key generator.'''
|
|
||||||
|
|
||||||
# Parse the CLI options
|
|
||||||
parser = OptionParser(usage='usage: %prog [options] keysize',
|
|
||||||
description='Generates a new RSA keypair of "keysize" bits.')
|
|
||||||
|
|
||||||
parser.add_option('--pubout', type='string',
|
|
||||||
help='Output filename for the public key. The public key is '
|
|
||||||
'not saved if this option is not present. You can use '
|
|
||||||
'pyrsa-priv2pub to create the public key file later.')
|
|
||||||
|
|
||||||
parser.add_option('-o', '--out', type='string',
|
|
||||||
help='Output filename for the private key. The key is '
|
|
||||||
'written to stdout if this option is not present.')
|
|
||||||
|
|
||||||
parser.add_option('--form',
|
|
||||||
help='key format of the private and public keys - default PEM',
|
|
||||||
choices=('PEM', 'DER'), default='PEM')
|
|
||||||
|
|
||||||
(cli, cli_args) = parser.parse_args(sys.argv[1:])
|
|
||||||
|
|
||||||
if len(cli_args) != 1:
|
|
||||||
parser.print_help()
|
|
||||||
raise SystemExit(1)
|
|
||||||
|
|
||||||
try:
|
|
||||||
keysize = int(cli_args[0])
|
|
||||||
except ValueError:
|
|
||||||
parser.print_help()
|
|
||||||
print('Not a valid number: %s' % cli_args[0], file=sys.stderr)
|
|
||||||
raise SystemExit(1)
|
|
||||||
|
|
||||||
print('Generating %i-bit key' % keysize, file=sys.stderr)
|
|
||||||
(pub_key, priv_key) = rsa.newkeys(keysize)
|
|
||||||
|
|
||||||
|
|
||||||
# Save public key
|
|
||||||
if cli.pubout:
|
|
||||||
print('Writing public key to %s' % cli.pubout, file=sys.stderr)
|
|
||||||
data = pub_key.save_pkcs1(format=cli.form)
|
|
||||||
with open(cli.pubout, 'wb') as outfile:
|
|
||||||
outfile.write(data)
|
|
||||||
|
|
||||||
# Save private key
|
|
||||||
data = priv_key.save_pkcs1(format=cli.form)
|
|
||||||
|
|
||||||
if cli.out:
|
|
||||||
print('Writing private key to %s' % cli.out, file=sys.stderr)
|
|
||||||
with open(cli.out, 'wb') as outfile:
|
|
||||||
outfile.write(data)
|
|
||||||
else:
|
|
||||||
print('Writing private key to stdout', file=sys.stderr)
|
|
||||||
sys.stdout.write(data)
|
|
||||||
|
|
||||||
|
|
||||||
class CryptoOperation(object):
|
|
||||||
'''CLI callable that operates with input, output, and a key.'''
|
|
||||||
|
|
||||||
__metaclass__ = abc.ABCMeta
|
|
||||||
|
|
||||||
keyname = 'public' # or 'private'
|
|
||||||
usage = 'usage: %%prog [options] %(keyname)s_key'
|
|
||||||
description = None
|
|
||||||
operation = 'decrypt'
|
|
||||||
operation_past = 'decrypted'
|
|
||||||
operation_progressive = 'decrypting'
|
|
||||||
input_help = 'Name of the file to %(operation)s. Reads from stdin if ' \
|
|
||||||
'not specified.'
|
|
||||||
output_help = 'Name of the file to write the %(operation_past)s file ' \
|
|
||||||
'to. Written to stdout if this option is not present.'
|
|
||||||
expected_cli_args = 1
|
|
||||||
has_output = True
|
|
||||||
|
|
||||||
key_class = rsa.PublicKey
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
self.usage = self.usage % self.__class__.__dict__
|
|
||||||
self.input_help = self.input_help % self.__class__.__dict__
|
|
||||||
self.output_help = self.output_help % self.__class__.__dict__
|
|
||||||
|
|
||||||
@abc.abstractmethod
|
|
||||||
def perform_operation(self, indata, key, cli_args=None):
|
|
||||||
'''Performs the program's operation.
|
|
||||||
|
|
||||||
Implement in a subclass.
|
|
||||||
|
|
||||||
:returns: the data to write to the output.
|
|
||||||
'''
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
'''Runs the program.'''
|
|
||||||
|
|
||||||
(cli, cli_args) = self.parse_cli()
|
|
||||||
|
|
||||||
key = self.read_key(cli_args[0], cli.keyform)
|
|
||||||
|
|
||||||
indata = self.read_infile(cli.input)
|
|
||||||
|
|
||||||
print(self.operation_progressive.title(), file=sys.stderr)
|
|
||||||
outdata = self.perform_operation(indata, key, cli_args)
|
|
||||||
|
|
||||||
if self.has_output:
|
|
||||||
self.write_outfile(outdata, cli.output)
|
|
||||||
|
|
||||||
def parse_cli(self):
|
|
||||||
'''Parse the CLI options
|
|
||||||
|
|
||||||
:returns: (cli_opts, cli_args)
|
|
||||||
'''
|
|
||||||
|
|
||||||
parser = OptionParser(usage=self.usage, description=self.description)
|
|
||||||
|
|
||||||
parser.add_option('-i', '--input', type='string', help=self.input_help)
|
|
||||||
|
|
||||||
if self.has_output:
|
|
||||||
parser.add_option('-o', '--output', type='string', help=self.output_help)
|
|
||||||
|
|
||||||
parser.add_option('--keyform',
|
|
||||||
help='Key format of the %s key - default PEM' % self.keyname,
|
|
||||||
choices=('PEM', 'DER'), default='PEM')
|
|
||||||
|
|
||||||
(cli, cli_args) = parser.parse_args(sys.argv[1:])
|
|
||||||
|
|
||||||
if len(cli_args) != self.expected_cli_args:
|
|
||||||
parser.print_help()
|
|
||||||
raise SystemExit(1)
|
|
||||||
|
|
||||||
return (cli, cli_args)
|
|
||||||
|
|
||||||
def read_key(self, filename, keyform):
|
|
||||||
'''Reads a public or private key.'''
|
|
||||||
|
|
||||||
print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr)
|
|
||||||
with open(filename, 'rb') as keyfile:
|
|
||||||
keydata = keyfile.read()
|
|
||||||
|
|
||||||
return self.key_class.load_pkcs1(keydata, keyform)
|
|
||||||
|
|
||||||
def read_infile(self, inname):
|
|
||||||
'''Read the input file'''
|
|
||||||
|
|
||||||
if inname:
|
|
||||||
print('Reading input from %s' % inname, file=sys.stderr)
|
|
||||||
with open(inname, 'rb') as infile:
|
|
||||||
return infile.read()
|
|
||||||
|
|
||||||
print('Reading input from stdin', file=sys.stderr)
|
|
||||||
return sys.stdin.read()
|
|
||||||
|
|
||||||
def write_outfile(self, outdata, outname):
|
|
||||||
'''Write the output file'''
|
|
||||||
|
|
||||||
if outname:
|
|
||||||
print('Writing output to %s' % outname, file=sys.stderr)
|
|
||||||
with open(outname, 'wb') as outfile:
|
|
||||||
outfile.write(outdata)
|
|
||||||
else:
|
|
||||||
print('Writing output to stdout', file=sys.stderr)
|
|
||||||
sys.stdout.write(outdata)
|
|
||||||
|
|
||||||
class EncryptOperation(CryptoOperation):
|
|
||||||
'''Encrypts a file.'''
|
|
||||||
|
|
||||||
keyname = 'public'
|
|
||||||
description = ('Encrypts a file. The file must be shorter than the key '
|
|
||||||
'length in order to be encrypted. For larger files, use the '
|
|
||||||
'pyrsa-encrypt-bigfile command.')
|
|
||||||
operation = 'encrypt'
|
|
||||||
operation_past = 'encrypted'
|
|
||||||
operation_progressive = 'encrypting'
|
|
||||||
|
|
||||||
|
|
||||||
def perform_operation(self, indata, pub_key, cli_args=None):
|
|
||||||
'''Encrypts files.'''
|
|
||||||
|
|
||||||
return rsa.encrypt(indata, pub_key)
|
|
||||||
|
|
||||||
class DecryptOperation(CryptoOperation):
|
|
||||||
'''Decrypts a file.'''
|
|
||||||
|
|
||||||
keyname = 'private'
|
|
||||||
description = ('Decrypts a file. The original file must be shorter than '
|
|
||||||
'the key length in order to have been encrypted. For larger '
|
|
||||||
'files, use the pyrsa-decrypt-bigfile command.')
|
|
||||||
operation = 'decrypt'
|
|
||||||
operation_past = 'decrypted'
|
|
||||||
operation_progressive = 'decrypting'
|
|
||||||
key_class = rsa.PrivateKey
|
|
||||||
|
|
||||||
def perform_operation(self, indata, priv_key, cli_args=None):
|
|
||||||
'''Decrypts files.'''
|
|
||||||
|
|
||||||
return rsa.decrypt(indata, priv_key)
|
|
||||||
|
|
||||||
class SignOperation(CryptoOperation):
|
|
||||||
'''Signs a file.'''
|
|
||||||
|
|
||||||
keyname = 'private'
|
|
||||||
usage = 'usage: %%prog [options] private_key hash_method'
|
|
||||||
description = ('Signs a file, outputs the signature. Choose the hash '
|
|
||||||
'method from %s' % ', '.join(HASH_METHODS))
|
|
||||||
operation = 'sign'
|
|
||||||
operation_past = 'signature'
|
|
||||||
operation_progressive = 'Signing'
|
|
||||||
key_class = rsa.PrivateKey
|
|
||||||
expected_cli_args = 2
|
|
||||||
|
|
||||||
output_help = ('Name of the file to write the signature to. Written '
|
|
||||||
'to stdout if this option is not present.')
|
|
||||||
|
|
||||||
def perform_operation(self, indata, priv_key, cli_args):
|
|
||||||
'''Decrypts files.'''
|
|
||||||
|
|
||||||
hash_method = cli_args[1]
|
|
||||||
if hash_method not in HASH_METHODS:
|
|
||||||
raise SystemExit('Invalid hash method, choose one of %s' %
|
|
||||||
', '.join(HASH_METHODS))
|
|
||||||
|
|
||||||
return rsa.sign(indata, priv_key, hash_method)
|
|
||||||
|
|
||||||
class VerifyOperation(CryptoOperation):
|
|
||||||
'''Verify a signature.'''
|
|
||||||
|
|
||||||
keyname = 'public'
|
|
||||||
usage = 'usage: %%prog [options] public_key signature_file'
|
|
||||||
description = ('Verifies a signature, exits with status 0 upon success, '
|
|
||||||
'prints an error message and exits with status 1 upon error.')
|
|
||||||
operation = 'verify'
|
|
||||||
operation_past = 'verified'
|
|
||||||
operation_progressive = 'Verifying'
|
|
||||||
key_class = rsa.PublicKey
|
|
||||||
expected_cli_args = 2
|
|
||||||
has_output = False
|
|
||||||
|
|
||||||
def perform_operation(self, indata, pub_key, cli_args):
|
|
||||||
'''Decrypts files.'''
|
|
||||||
|
|
||||||
signature_file = cli_args[1]
|
|
||||||
|
|
||||||
with open(signature_file, 'rb') as sigfile:
|
|
||||||
signature = sigfile.read()
|
|
||||||
|
|
||||||
try:
|
|
||||||
rsa.verify(indata, signature, pub_key)
|
|
||||||
except rsa.VerificationError:
|
|
||||||
raise SystemExit('Verification failed.')
|
|
||||||
|
|
||||||
print('Verification OK', file=sys.stderr)
|
|
||||||
|
|
||||||
|
|
||||||
class BigfileOperation(CryptoOperation):
|
|
||||||
'''CryptoOperation that doesn't read the entire file into memory.'''
|
|
||||||
|
|
||||||
def __init__(self):
|
|
||||||
CryptoOperation.__init__(self)
|
|
||||||
|
|
||||||
self.file_objects = []
|
|
||||||
|
|
||||||
def __del__(self):
|
|
||||||
'''Closes any open file handles.'''
|
|
||||||
|
|
||||||
for fobj in self.file_objects:
|
|
||||||
fobj.close()
|
|
||||||
|
|
||||||
def __call__(self):
|
|
||||||
'''Runs the program.'''
|
|
||||||
|
|
||||||
(cli, cli_args) = self.parse_cli()
|
|
||||||
|
|
||||||
key = self.read_key(cli_args[0], cli.keyform)
|
|
||||||
|
|
||||||
# Get the file handles
|
|
||||||
infile = self.get_infile(cli.input)
|
|
||||||
outfile = self.get_outfile(cli.output)
|
|
||||||
|
|
||||||
# Call the operation
|
|
||||||
print(self.operation_progressive.title(), file=sys.stderr)
|
|
||||||
self.perform_operation(infile, outfile, key, cli_args)
|
|
||||||
|
|
||||||
def get_infile(self, inname):
|
|
||||||
'''Returns the input file object'''
|
|
||||||
|
|
||||||
if inname:
|
|
||||||
print('Reading input from %s' % inname, file=sys.stderr)
|
|
||||||
fobj = open(inname, 'rb')
|
|
||||||
self.file_objects.append(fobj)
|
|
||||||
else:
|
|
||||||
print('Reading input from stdin', file=sys.stderr)
|
|
||||||
fobj = sys.stdin
|
|
||||||
|
|
||||||
return fobj
|
|
||||||
|
|
||||||
def get_outfile(self, outname):
|
|
||||||
'''Returns the output file object'''
|
|
||||||
|
|
||||||
if outname:
|
|
||||||
print('Will write output to %s' % outname, file=sys.stderr)
|
|
||||||
fobj = open(outname, 'wb')
|
|
||||||
self.file_objects.append(fobj)
|
|
||||||
else:
|
|
||||||
print('Will write output to stdout', file=sys.stderr)
|
|
||||||
fobj = sys.stdout
|
|
||||||
|
|
||||||
return fobj
|
|
||||||
|
|
||||||
class EncryptBigfileOperation(BigfileOperation):
|
|
||||||
'''Encrypts a file to VARBLOCK format.'''
|
|
||||||
|
|
||||||
keyname = 'public'
|
|
||||||
description = ('Encrypts a file to an encrypted VARBLOCK file. The file '
|
|
||||||
'can be larger than the key length, but the output file is only '
|
|
||||||
'compatible with Python-RSA.')
|
|
||||||
operation = 'encrypt'
|
|
||||||
operation_past = 'encrypted'
|
|
||||||
operation_progressive = 'encrypting'
|
|
||||||
|
|
||||||
def perform_operation(self, infile, outfile, pub_key, cli_args=None):
|
|
||||||
'''Encrypts files to VARBLOCK.'''
|
|
||||||
|
|
||||||
return rsa.bigfile.encrypt_bigfile(infile, outfile, pub_key)
|
|
||||||
|
|
||||||
class DecryptBigfileOperation(BigfileOperation):
|
|
||||||
'''Decrypts a file in VARBLOCK format.'''
|
|
||||||
|
|
||||||
keyname = 'private'
|
|
||||||
description = ('Decrypts an encrypted VARBLOCK file that was encrypted '
|
|
||||||
'with pyrsa-encrypt-bigfile')
|
|
||||||
operation = 'decrypt'
|
|
||||||
operation_past = 'decrypted'
|
|
||||||
operation_progressive = 'decrypting'
|
|
||||||
key_class = rsa.PrivateKey
|
|
||||||
|
|
||||||
def perform_operation(self, infile, outfile, priv_key, cli_args=None):
|
|
||||||
'''Decrypts a VARBLOCK file.'''
|
|
||||||
|
|
||||||
return rsa.bigfile.decrypt_bigfile(infile, outfile, priv_key)
|
|
||||||
|
|
||||||
|
|
||||||
encrypt = EncryptOperation()
|
|
||||||
decrypt = DecryptOperation()
|
|
||||||
sign = SignOperation()
|
|
||||||
verify = VerifyOperation()
|
|
||||||
encrypt_bigfile = EncryptBigfileOperation()
|
|
||||||
decrypt_bigfile = DecryptBigfileOperation()
|
|
||||||
|
|
|
@ -1,185 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Common functionality shared by several modules.'''
|
|
||||||
|
|
||||||
|
|
||||||
def bit_size(num):
|
|
||||||
'''
|
|
||||||
Number of bits needed to represent a integer excluding any prefix
|
|
||||||
0 bits.
|
|
||||||
|
|
||||||
As per definition from http://wiki.python.org/moin/BitManipulation and
|
|
||||||
to match the behavior of the Python 3 API.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
>>> bit_size(1023)
|
|
||||||
10
|
|
||||||
>>> bit_size(1024)
|
|
||||||
11
|
|
||||||
>>> bit_size(1025)
|
|
||||||
11
|
|
||||||
|
|
||||||
:param num:
|
|
||||||
Integer value. If num is 0, returns 0. Only the absolute value of the
|
|
||||||
number is considered. Therefore, signed integers will be abs(num)
|
|
||||||
before the number's bit length is determined.
|
|
||||||
:returns:
|
|
||||||
Returns the number of bits in the integer.
|
|
||||||
'''
|
|
||||||
if num == 0:
|
|
||||||
return 0
|
|
||||||
if num < 0:
|
|
||||||
num = -num
|
|
||||||
|
|
||||||
# Make sure this is an int and not a float.
|
|
||||||
num & 1
|
|
||||||
|
|
||||||
hex_num = "%x" % num
|
|
||||||
return ((len(hex_num) - 1) * 4) + {
|
|
||||||
'0':0, '1':1, '2':2, '3':2,
|
|
||||||
'4':3, '5':3, '6':3, '7':3,
|
|
||||||
'8':4, '9':4, 'a':4, 'b':4,
|
|
||||||
'c':4, 'd':4, 'e':4, 'f':4,
|
|
||||||
}[hex_num[0]]
|
|
||||||
|
|
||||||
|
|
||||||
def _bit_size(number):
|
|
||||||
'''
|
|
||||||
Returns the number of bits required to hold a specific long number.
|
|
||||||
'''
|
|
||||||
if number < 0:
|
|
||||||
raise ValueError('Only nonnegative numbers possible: %s' % number)
|
|
||||||
|
|
||||||
if number == 0:
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# This works, even with very large numbers. When using math.log(number, 2),
|
|
||||||
# you'll get rounding errors and it'll fail.
|
|
||||||
bits = 0
|
|
||||||
while number:
|
|
||||||
bits += 1
|
|
||||||
number >>= 1
|
|
||||||
|
|
||||||
return bits
|
|
||||||
|
|
||||||
|
|
||||||
def byte_size(number):
|
|
||||||
'''
|
|
||||||
Returns the number of bytes required to hold a specific long number.
|
|
||||||
|
|
||||||
The number of bytes is rounded up.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
>>> byte_size(1 << 1023)
|
|
||||||
128
|
|
||||||
>>> byte_size((1 << 1024) - 1)
|
|
||||||
128
|
|
||||||
>>> byte_size(1 << 1024)
|
|
||||||
129
|
|
||||||
|
|
||||||
:param number:
|
|
||||||
An unsigned integer
|
|
||||||
:returns:
|
|
||||||
The number of bytes required to hold a specific long number.
|
|
||||||
'''
|
|
||||||
quanta, mod = divmod(bit_size(number), 8)
|
|
||||||
if mod or number == 0:
|
|
||||||
quanta += 1
|
|
||||||
return quanta
|
|
||||||
#return int(math.ceil(bit_size(number) / 8.0))
|
|
||||||
|
|
||||||
|
|
||||||
def extended_gcd(a, b):
|
|
||||||
'''Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
|
|
||||||
'''
|
|
||||||
# r = gcd(a,b) i = multiplicitive inverse of a mod b
|
|
||||||
# or j = multiplicitive inverse of b mod a
|
|
||||||
# Neg return values for i or j are made positive mod b or a respectively
|
|
||||||
# Iterateive Version is faster and uses much less stack space
|
|
||||||
x = 0
|
|
||||||
y = 1
|
|
||||||
lx = 1
|
|
||||||
ly = 0
|
|
||||||
oa = a #Remember original a/b to remove
|
|
||||||
ob = b #negative values from return results
|
|
||||||
while b != 0:
|
|
||||||
q = a // b
|
|
||||||
(a, b) = (b, a % b)
|
|
||||||
(x, lx) = ((lx - (q * x)),x)
|
|
||||||
(y, ly) = ((ly - (q * y)),y)
|
|
||||||
if (lx < 0): lx += ob #If neg wrap modulo orignal b
|
|
||||||
if (ly < 0): ly += oa #If neg wrap modulo orignal a
|
|
||||||
return (a, lx, ly) #Return only positive values
|
|
||||||
|
|
||||||
|
|
||||||
def inverse(x, n):
|
|
||||||
'''Returns x^-1 (mod n)
|
|
||||||
|
|
||||||
>>> inverse(7, 4)
|
|
||||||
3
|
|
||||||
>>> (inverse(143, 4) * 143) % 4
|
|
||||||
1
|
|
||||||
'''
|
|
||||||
|
|
||||||
(divider, inv, _) = extended_gcd(x, n)
|
|
||||||
|
|
||||||
if divider != 1:
|
|
||||||
raise ValueError("x (%d) and n (%d) are not relatively prime" % (x, n))
|
|
||||||
|
|
||||||
return inv
|
|
||||||
|
|
||||||
|
|
||||||
def crt(a_values, modulo_values):
|
|
||||||
'''Chinese Remainder Theorem.
|
|
||||||
|
|
||||||
Calculates x such that x = a[i] (mod m[i]) for each i.
|
|
||||||
|
|
||||||
:param a_values: the a-values of the above equation
|
|
||||||
:param modulo_values: the m-values of the above equation
|
|
||||||
:returns: x such that x = a[i] (mod m[i]) for each i
|
|
||||||
|
|
||||||
|
|
||||||
>>> crt([2, 3], [3, 5])
|
|
||||||
8
|
|
||||||
|
|
||||||
>>> crt([2, 3, 2], [3, 5, 7])
|
|
||||||
23
|
|
||||||
|
|
||||||
>>> crt([2, 3, 0], [7, 11, 15])
|
|
||||||
135
|
|
||||||
'''
|
|
||||||
|
|
||||||
m = 1
|
|
||||||
x = 0
|
|
||||||
|
|
||||||
for modulo in modulo_values:
|
|
||||||
m *= modulo
|
|
||||||
|
|
||||||
for (m_i, a_i) in zip(modulo_values, a_values):
|
|
||||||
M_i = m // m_i
|
|
||||||
inv = inverse(M_i, m_i)
|
|
||||||
|
|
||||||
x = (x + a_i * M_i * inv) % m
|
|
||||||
|
|
||||||
return x
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
|
@ -1,58 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Core mathematical operations.
|
|
||||||
|
|
||||||
This is the actual core RSA implementation, which is only defined
|
|
||||||
mathematically on integers.
|
|
||||||
'''
|
|
||||||
|
|
||||||
|
|
||||||
from rsa._compat import is_integer
|
|
||||||
|
|
||||||
def assert_int(var, name):
|
|
||||||
|
|
||||||
if is_integer(var):
|
|
||||||
return
|
|
||||||
|
|
||||||
raise TypeError('%s should be an integer, not %s' % (name, var.__class__))
|
|
||||||
|
|
||||||
def encrypt_int(message, ekey, n):
|
|
||||||
'''Encrypts a message using encryption key 'ekey', working modulo n'''
|
|
||||||
|
|
||||||
assert_int(message, 'message')
|
|
||||||
assert_int(ekey, 'ekey')
|
|
||||||
assert_int(n, 'n')
|
|
||||||
|
|
||||||
if message < 0:
|
|
||||||
raise ValueError('Only non-negative numbers are supported')
|
|
||||||
|
|
||||||
if message > n:
|
|
||||||
raise OverflowError("The message %i is too long for n=%i" % (message, n))
|
|
||||||
|
|
||||||
return pow(message, ekey, n)
|
|
||||||
|
|
||||||
def decrypt_int(cyphertext, dkey, n):
|
|
||||||
'''Decrypts a cypher text using the decryption key 'dkey', working
|
|
||||||
modulo n'''
|
|
||||||
|
|
||||||
assert_int(cyphertext, 'cyphertext')
|
|
||||||
assert_int(dkey, 'dkey')
|
|
||||||
assert_int(n, 'n')
|
|
||||||
|
|
||||||
message = pow(cyphertext, dkey, n)
|
|
||||||
return message
|
|
||||||
|
|
|
@ -1,612 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''RSA key generation code.
|
|
||||||
|
|
||||||
Create new keys with the newkeys() function. It will give you a PublicKey and a
|
|
||||||
PrivateKey object.
|
|
||||||
|
|
||||||
Loading and saving keys requires the pyasn1 module. This module is imported as
|
|
||||||
late as possible, such that other functionality will remain working in absence
|
|
||||||
of pyasn1.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from rsa._compat import b, bytes_type
|
|
||||||
|
|
||||||
import rsa.prime
|
|
||||||
import rsa.pem
|
|
||||||
import rsa.common
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class AbstractKey(object):
|
|
||||||
'''Abstract superclass for private and public keys.'''
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_pkcs1(cls, keyfile, format='PEM'):
|
|
||||||
r'''Loads a key in PKCS#1 DER or PEM format.
|
|
||||||
|
|
||||||
:param keyfile: contents of a DER- or PEM-encoded file that contains
|
|
||||||
the public key.
|
|
||||||
:param format: the format of the file to load; 'PEM' or 'DER'
|
|
||||||
|
|
||||||
:return: a PublicKey object
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
methods = {
|
|
||||||
'PEM': cls._load_pkcs1_pem,
|
|
||||||
'DER': cls._load_pkcs1_der,
|
|
||||||
}
|
|
||||||
|
|
||||||
if format not in methods:
|
|
||||||
formats = ', '.join(sorted(methods.keys()))
|
|
||||||
raise ValueError('Unsupported format: %r, try one of %s' % (format,
|
|
||||||
formats))
|
|
||||||
|
|
||||||
method = methods[format]
|
|
||||||
return method(keyfile)
|
|
||||||
|
|
||||||
def save_pkcs1(self, format='PEM'):
|
|
||||||
'''Saves the public key in PKCS#1 DER or PEM format.
|
|
||||||
|
|
||||||
:param format: the format to save; 'PEM' or 'DER'
|
|
||||||
:returns: the DER- or PEM-encoded public key.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
methods = {
|
|
||||||
'PEM': self._save_pkcs1_pem,
|
|
||||||
'DER': self._save_pkcs1_der,
|
|
||||||
}
|
|
||||||
|
|
||||||
if format not in methods:
|
|
||||||
formats = ', '.join(sorted(methods.keys()))
|
|
||||||
raise ValueError('Unsupported format: %r, try one of %s' % (format,
|
|
||||||
formats))
|
|
||||||
|
|
||||||
method = methods[format]
|
|
||||||
return method()
|
|
||||||
|
|
||||||
class PublicKey(AbstractKey):
|
|
||||||
'''Represents a public RSA key.
|
|
||||||
|
|
||||||
This key is also known as the 'encryption key'. It contains the 'n' and 'e'
|
|
||||||
values.
|
|
||||||
|
|
||||||
Supports attributes as well as dictionary-like access. Attribute accesss is
|
|
||||||
faster, though.
|
|
||||||
|
|
||||||
>>> PublicKey(5, 3)
|
|
||||||
PublicKey(5, 3)
|
|
||||||
|
|
||||||
>>> key = PublicKey(5, 3)
|
|
||||||
>>> key.n
|
|
||||||
5
|
|
||||||
>>> key['n']
|
|
||||||
5
|
|
||||||
>>> key.e
|
|
||||||
3
|
|
||||||
>>> key['e']
|
|
||||||
3
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
__slots__ = ('n', 'e')
|
|
||||||
|
|
||||||
def __init__(self, n, e):
|
|
||||||
self.n = n
|
|
||||||
self.e = e
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'PublicKey(%i, %i)' % (self.n, self.e)
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if other is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not isinstance(other, PublicKey):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return self.n == other.n and self.e == other.e
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return not (self == other)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_pkcs1_der(cls, keyfile):
|
|
||||||
r'''Loads a key in PKCS#1 DER format.
|
|
||||||
|
|
||||||
@param keyfile: contents of a DER-encoded file that contains the public
|
|
||||||
key.
|
|
||||||
@return: a PublicKey object
|
|
||||||
|
|
||||||
First let's construct a DER encoded key:
|
|
||||||
|
|
||||||
>>> import base64
|
|
||||||
>>> b64der = 'MAwCBQCNGmYtAgMBAAE='
|
|
||||||
>>> der = base64.decodestring(b64der)
|
|
||||||
|
|
||||||
This loads the file:
|
|
||||||
|
|
||||||
>>> PublicKey._load_pkcs1_der(der)
|
|
||||||
PublicKey(2367317549, 65537)
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pyasn1.codec.der import decoder
|
|
||||||
from rsa.asn1 import AsnPubKey
|
|
||||||
|
|
||||||
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
|
|
||||||
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
|
|
||||||
|
|
||||||
def _save_pkcs1_der(self):
|
|
||||||
'''Saves the public key in PKCS#1 DER format.
|
|
||||||
|
|
||||||
@returns: the DER-encoded public key.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pyasn1.codec.der import encoder
|
|
||||||
from rsa.asn1 import AsnPubKey
|
|
||||||
|
|
||||||
# Create the ASN object
|
|
||||||
asn_key = AsnPubKey()
|
|
||||||
asn_key.setComponentByName('modulus', self.n)
|
|
||||||
asn_key.setComponentByName('publicExponent', self.e)
|
|
||||||
|
|
||||||
return encoder.encode(asn_key)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_pkcs1_pem(cls, keyfile):
|
|
||||||
'''Loads a PKCS#1 PEM-encoded public key file.
|
|
||||||
|
|
||||||
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
|
|
||||||
after the "-----END RSA PUBLIC KEY-----" lines is ignored.
|
|
||||||
|
|
||||||
@param keyfile: contents of a PEM-encoded file that contains the public
|
|
||||||
key.
|
|
||||||
@return: a PublicKey object
|
|
||||||
'''
|
|
||||||
|
|
||||||
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
|
|
||||||
return cls._load_pkcs1_der(der)
|
|
||||||
|
|
||||||
def _save_pkcs1_pem(self):
|
|
||||||
'''Saves a PKCS#1 PEM-encoded public key file.
|
|
||||||
|
|
||||||
@return: contents of a PEM-encoded file that contains the public key.
|
|
||||||
'''
|
|
||||||
|
|
||||||
der = self._save_pkcs1_der()
|
|
||||||
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_pkcs1_openssl_pem(cls, keyfile):
|
|
||||||
'''Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
|
|
||||||
|
|
||||||
These files can be recognised in that they start with BEGIN PUBLIC KEY
|
|
||||||
rather than BEGIN RSA PUBLIC KEY.
|
|
||||||
|
|
||||||
The contents of the file before the "-----BEGIN PUBLIC KEY-----" and
|
|
||||||
after the "-----END PUBLIC KEY-----" lines is ignored.
|
|
||||||
|
|
||||||
@param keyfile: contents of a PEM-encoded file that contains the public
|
|
||||||
key, from OpenSSL.
|
|
||||||
@return: a PublicKey object
|
|
||||||
'''
|
|
||||||
|
|
||||||
der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY')
|
|
||||||
return cls.load_pkcs1_openssl_der(der)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load_pkcs1_openssl_der(cls, keyfile):
|
|
||||||
'''Loads a PKCS#1 DER-encoded public key file from OpenSSL.
|
|
||||||
|
|
||||||
@param keyfile: contents of a DER-encoded file that contains the public
|
|
||||||
key, from OpenSSL.
|
|
||||||
@return: a PublicKey object
|
|
||||||
'''
|
|
||||||
|
|
||||||
from rsa.asn1 import OpenSSLPubKey
|
|
||||||
from pyasn1.codec.der import decoder
|
|
||||||
from pyasn1.type import univ
|
|
||||||
|
|
||||||
(keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())
|
|
||||||
|
|
||||||
if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'):
|
|
||||||
raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")
|
|
||||||
|
|
||||||
return cls._load_pkcs1_der(keyinfo['key'][1:])
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class PrivateKey(AbstractKey):
|
|
||||||
'''Represents a private RSA key.
|
|
||||||
|
|
||||||
This key is also known as the 'decryption key'. It contains the 'n', 'e',
|
|
||||||
'd', 'p', 'q' and other values.
|
|
||||||
|
|
||||||
Supports attributes as well as dictionary-like access. Attribute accesss is
|
|
||||||
faster, though.
|
|
||||||
|
|
||||||
>>> PrivateKey(3247, 65537, 833, 191, 17)
|
|
||||||
PrivateKey(3247, 65537, 833, 191, 17)
|
|
||||||
|
|
||||||
exp1, exp2 and coef don't have to be given, they will be calculated:
|
|
||||||
|
|
||||||
>>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
|
|
||||||
>>> pk.exp1
|
|
||||||
55063
|
|
||||||
>>> pk.exp2
|
|
||||||
10095
|
|
||||||
>>> pk.coef
|
|
||||||
50797
|
|
||||||
|
|
||||||
If you give exp1, exp2 or coef, they will be used as-is:
|
|
||||||
|
|
||||||
>>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
|
|
||||||
>>> pk.exp1
|
|
||||||
6
|
|
||||||
>>> pk.exp2
|
|
||||||
7
|
|
||||||
>>> pk.coef
|
|
||||||
8
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
|
|
||||||
|
|
||||||
def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
|
|
||||||
self.n = n
|
|
||||||
self.e = e
|
|
||||||
self.d = d
|
|
||||||
self.p = p
|
|
||||||
self.q = q
|
|
||||||
|
|
||||||
# Calculate the other values if they aren't supplied
|
|
||||||
if exp1 is None:
|
|
||||||
self.exp1 = int(d % (p - 1))
|
|
||||||
else:
|
|
||||||
self.exp1 = exp1
|
|
||||||
|
|
||||||
if exp1 is None:
|
|
||||||
self.exp2 = int(d % (q - 1))
|
|
||||||
else:
|
|
||||||
self.exp2 = exp2
|
|
||||||
|
|
||||||
if coef is None:
|
|
||||||
self.coef = rsa.common.inverse(q, p)
|
|
||||||
else:
|
|
||||||
self.coef = coef
|
|
||||||
|
|
||||||
def __getitem__(self, key):
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
def __repr__(self):
|
|
||||||
return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
|
|
||||||
|
|
||||||
def __eq__(self, other):
|
|
||||||
if other is None:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not isinstance(other, PrivateKey):
|
|
||||||
return False
|
|
||||||
|
|
||||||
return (self.n == other.n and
|
|
||||||
self.e == other.e and
|
|
||||||
self.d == other.d and
|
|
||||||
self.p == other.p and
|
|
||||||
self.q == other.q and
|
|
||||||
self.exp1 == other.exp1 and
|
|
||||||
self.exp2 == other.exp2 and
|
|
||||||
self.coef == other.coef)
|
|
||||||
|
|
||||||
def __ne__(self, other):
|
|
||||||
return not (self == other)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_pkcs1_der(cls, keyfile):
|
|
||||||
r'''Loads a key in PKCS#1 DER format.
|
|
||||||
|
|
||||||
@param keyfile: contents of a DER-encoded file that contains the private
|
|
||||||
key.
|
|
||||||
@return: a PrivateKey object
|
|
||||||
|
|
||||||
First let's construct a DER encoded key:
|
|
||||||
|
|
||||||
>>> import base64
|
|
||||||
>>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
|
|
||||||
>>> der = base64.decodestring(b64der)
|
|
||||||
|
|
||||||
This loads the file:
|
|
||||||
|
|
||||||
>>> PrivateKey._load_pkcs1_der(der)
|
|
||||||
PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pyasn1.codec.der import decoder
|
|
||||||
(priv, _) = decoder.decode(keyfile)
|
|
||||||
|
|
||||||
# ASN.1 contents of DER encoded private key:
|
|
||||||
#
|
|
||||||
# RSAPrivateKey ::= SEQUENCE {
|
|
||||||
# version Version,
|
|
||||||
# modulus INTEGER, -- n
|
|
||||||
# publicExponent INTEGER, -- e
|
|
||||||
# privateExponent INTEGER, -- d
|
|
||||||
# prime1 INTEGER, -- p
|
|
||||||
# prime2 INTEGER, -- q
|
|
||||||
# exponent1 INTEGER, -- d mod (p-1)
|
|
||||||
# exponent2 INTEGER, -- d mod (q-1)
|
|
||||||
# coefficient INTEGER, -- (inverse of q) mod p
|
|
||||||
# otherPrimeInfos OtherPrimeInfos OPTIONAL
|
|
||||||
# }
|
|
||||||
|
|
||||||
if priv[0] != 0:
|
|
||||||
raise ValueError('Unable to read this file, version %s != 0' % priv[0])
|
|
||||||
|
|
||||||
as_ints = tuple(int(x) for x in priv[1:9])
|
|
||||||
return cls(*as_ints)
|
|
||||||
|
|
||||||
def _save_pkcs1_der(self):
|
|
||||||
'''Saves the private key in PKCS#1 DER format.
|
|
||||||
|
|
||||||
@returns: the DER-encoded private key.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from pyasn1.type import univ, namedtype
|
|
||||||
from pyasn1.codec.der import encoder
|
|
||||||
|
|
||||||
class AsnPrivKey(univ.Sequence):
|
|
||||||
componentType = namedtype.NamedTypes(
|
|
||||||
namedtype.NamedType('version', univ.Integer()),
|
|
||||||
namedtype.NamedType('modulus', univ.Integer()),
|
|
||||||
namedtype.NamedType('publicExponent', univ.Integer()),
|
|
||||||
namedtype.NamedType('privateExponent', univ.Integer()),
|
|
||||||
namedtype.NamedType('prime1', univ.Integer()),
|
|
||||||
namedtype.NamedType('prime2', univ.Integer()),
|
|
||||||
namedtype.NamedType('exponent1', univ.Integer()),
|
|
||||||
namedtype.NamedType('exponent2', univ.Integer()),
|
|
||||||
namedtype.NamedType('coefficient', univ.Integer()),
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create the ASN object
|
|
||||||
asn_key = AsnPrivKey()
|
|
||||||
asn_key.setComponentByName('version', 0)
|
|
||||||
asn_key.setComponentByName('modulus', self.n)
|
|
||||||
asn_key.setComponentByName('publicExponent', self.e)
|
|
||||||
asn_key.setComponentByName('privateExponent', self.d)
|
|
||||||
asn_key.setComponentByName('prime1', self.p)
|
|
||||||
asn_key.setComponentByName('prime2', self.q)
|
|
||||||
asn_key.setComponentByName('exponent1', self.exp1)
|
|
||||||
asn_key.setComponentByName('exponent2', self.exp2)
|
|
||||||
asn_key.setComponentByName('coefficient', self.coef)
|
|
||||||
|
|
||||||
return encoder.encode(asn_key)
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def _load_pkcs1_pem(cls, keyfile):
|
|
||||||
'''Loads a PKCS#1 PEM-encoded private key file.
|
|
||||||
|
|
||||||
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
|
|
||||||
after the "-----END RSA PRIVATE KEY-----" lines is ignored.
|
|
||||||
|
|
||||||
@param keyfile: contents of a PEM-encoded file that contains the private
|
|
||||||
key.
|
|
||||||
@return: a PrivateKey object
|
|
||||||
'''
|
|
||||||
|
|
||||||
der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY'))
|
|
||||||
return cls._load_pkcs1_der(der)
|
|
||||||
|
|
||||||
def _save_pkcs1_pem(self):
|
|
||||||
'''Saves a PKCS#1 PEM-encoded private key file.
|
|
||||||
|
|
||||||
@return: contents of a PEM-encoded file that contains the private key.
|
|
||||||
'''
|
|
||||||
|
|
||||||
der = self._save_pkcs1_der()
|
|
||||||
return rsa.pem.save_pem(der, b('RSA PRIVATE KEY'))
|
|
||||||
|
|
||||||
def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
|
|
||||||
''''Returns a tuple of two different primes of nbits bits each.
|
|
||||||
|
|
||||||
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
|
|
||||||
will not be equal.
|
|
||||||
|
|
||||||
:param nbits: the number of bits in each of p and q.
|
|
||||||
:param getprime_func: the getprime function, defaults to
|
|
||||||
:py:func:`rsa.prime.getprime`.
|
|
||||||
|
|
||||||
*Introduced in Python-RSA 3.1*
|
|
||||||
|
|
||||||
:param accurate: whether to enable accurate mode or not.
|
|
||||||
:returns: (p, q), where p > q
|
|
||||||
|
|
||||||
>>> (p, q) = find_p_q(128)
|
|
||||||
>>> from rsa import common
|
|
||||||
>>> common.bit_size(p * q)
|
|
||||||
256
|
|
||||||
|
|
||||||
When not in accurate mode, the number of bits can be slightly less
|
|
||||||
|
|
||||||
>>> (p, q) = find_p_q(128, accurate=False)
|
|
||||||
>>> from rsa import common
|
|
||||||
>>> common.bit_size(p * q) <= 256
|
|
||||||
True
|
|
||||||
>>> common.bit_size(p * q) > 240
|
|
||||||
True
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
total_bits = nbits * 2
|
|
||||||
|
|
||||||
# Make sure that p and q aren't too close or the factoring programs can
|
|
||||||
# factor n.
|
|
||||||
shift = nbits // 16
|
|
||||||
pbits = nbits + shift
|
|
||||||
qbits = nbits - shift
|
|
||||||
|
|
||||||
# Choose the two initial primes
|
|
||||||
log.debug('find_p_q(%i): Finding p', nbits)
|
|
||||||
p = getprime_func(pbits)
|
|
||||||
log.debug('find_p_q(%i): Finding q', nbits)
|
|
||||||
q = getprime_func(qbits)
|
|
||||||
|
|
||||||
def is_acceptable(p, q):
|
|
||||||
'''Returns True iff p and q are acceptable:
|
|
||||||
|
|
||||||
- p and q differ
|
|
||||||
- (p * q) has the right nr of bits (when accurate=True)
|
|
||||||
'''
|
|
||||||
|
|
||||||
if p == q:
|
|
||||||
return False
|
|
||||||
|
|
||||||
if not accurate:
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Make sure we have just the right amount of bits
|
|
||||||
found_size = rsa.common.bit_size(p * q)
|
|
||||||
return total_bits == found_size
|
|
||||||
|
|
||||||
# Keep choosing other primes until they match our requirements.
|
|
||||||
change_p = False
|
|
||||||
while not is_acceptable(p, q):
|
|
||||||
# Change p on one iteration and q on the other
|
|
||||||
if change_p:
|
|
||||||
p = getprime_func(pbits)
|
|
||||||
else:
|
|
||||||
q = getprime_func(qbits)
|
|
||||||
|
|
||||||
change_p = not change_p
|
|
||||||
|
|
||||||
# We want p > q as described on
|
|
||||||
# http://www.di-mgt.com.au/rsa_alg.html#crt
|
|
||||||
return (max(p, q), min(p, q))
|
|
||||||
|
|
||||||
def calculate_keys(p, q, nbits):
|
|
||||||
'''Calculates an encryption and a decryption key given p and q, and
|
|
||||||
returns them as a tuple (e, d)
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
phi_n = (p - 1) * (q - 1)
|
|
||||||
|
|
||||||
# A very common choice for e is 65537
|
|
||||||
e = 65537
|
|
||||||
|
|
||||||
try:
|
|
||||||
d = rsa.common.inverse(e, phi_n)
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
|
|
||||||
(e, phi_n))
|
|
||||||
|
|
||||||
if (e * d) % phi_n != 1:
|
|
||||||
raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
|
|
||||||
"phi_n (%d)" % (e, d, phi_n))
|
|
||||||
|
|
||||||
return (e, d)
|
|
||||||
|
|
||||||
def gen_keys(nbits, getprime_func, accurate=True):
|
|
||||||
'''Generate RSA keys of nbits bits. Returns (p, q, e, d).
|
|
||||||
|
|
||||||
Note: this can take a long time, depending on the key size.
|
|
||||||
|
|
||||||
:param nbits: the total number of bits in ``p`` and ``q``. Both ``p`` and
|
|
||||||
``q`` will use ``nbits/2`` bits.
|
|
||||||
:param getprime_func: either :py:func:`rsa.prime.getprime` or a function
|
|
||||||
with similar signature.
|
|
||||||
'''
|
|
||||||
|
|
||||||
(p, q) = find_p_q(nbits // 2, getprime_func, accurate)
|
|
||||||
(e, d) = calculate_keys(p, q, nbits // 2)
|
|
||||||
|
|
||||||
return (p, q, e, d)
|
|
||||||
|
|
||||||
def newkeys(nbits, accurate=True, poolsize=1):
|
|
||||||
'''Generates public and private keys, and returns them as (pub, priv).
|
|
||||||
|
|
||||||
The public key is also known as the 'encryption key', and is a
|
|
||||||
:py:class:`rsa.PublicKey` object. The private key is also known as the
|
|
||||||
'decryption key' and is a :py:class:`rsa.PrivateKey` object.
|
|
||||||
|
|
||||||
:param nbits: the number of bits required to store ``n = p*q``.
|
|
||||||
:param accurate: when True, ``n`` will have exactly the number of bits you
|
|
||||||
asked for. However, this makes key generation much slower. When False,
|
|
||||||
`n`` may have slightly less bits.
|
|
||||||
:param poolsize: the number of processes to use to generate the prime
|
|
||||||
numbers. If set to a number > 1, a parallel algorithm will be used.
|
|
||||||
This requires Python 2.6 or newer.
|
|
||||||
|
|
||||||
:returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
|
|
||||||
|
|
||||||
The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires
|
|
||||||
Python 2.6 or newer.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
if nbits < 16:
|
|
||||||
raise ValueError('Key too small')
|
|
||||||
|
|
||||||
if poolsize < 1:
|
|
||||||
raise ValueError('Pool size (%i) should be >= 1' % poolsize)
|
|
||||||
|
|
||||||
# Determine which getprime function to use
|
|
||||||
if poolsize > 1:
|
|
||||||
from rsa import parallel
|
|
||||||
import functools
|
|
||||||
|
|
||||||
getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
|
|
||||||
else: getprime_func = rsa.prime.getprime
|
|
||||||
|
|
||||||
# Generate the key components
|
|
||||||
(p, q, e, d) = gen_keys(nbits, getprime_func)
|
|
||||||
|
|
||||||
# Create the key objects
|
|
||||||
n = p * q
|
|
||||||
|
|
||||||
return (
|
|
||||||
PublicKey(n, e),
|
|
||||||
PrivateKey(n, e, d, p, q)
|
|
||||||
)
|
|
||||||
|
|
||||||
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
|
|
||||||
try:
|
|
||||||
for count in range(100):
|
|
||||||
(failures, tests) = doctest.testmod()
|
|
||||||
if failures:
|
|
||||||
break
|
|
||||||
|
|
||||||
if (count and count % 10 == 0) or count == 1:
|
|
||||||
print('%i times' % count)
|
|
||||||
except KeyboardInterrupt:
|
|
||||||
print('Aborted')
|
|
||||||
else:
|
|
||||||
print('Doctests done')
|
|
|
@ -1,94 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Functions for parallel computation on multiple cores.
|
|
||||||
|
|
||||||
Introduced in Python-RSA 3.1.
|
|
||||||
|
|
||||||
.. note::
|
|
||||||
|
|
||||||
Requires Python 2.6 or newer.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import print_function
|
|
||||||
|
|
||||||
import multiprocessing as mp
|
|
||||||
|
|
||||||
import rsa.prime
|
|
||||||
import rsa.randnum
|
|
||||||
|
|
||||||
def _find_prime(nbits, pipe):
|
|
||||||
while True:
|
|
||||||
integer = rsa.randnum.read_random_int(nbits)
|
|
||||||
|
|
||||||
# Make sure it's odd
|
|
||||||
integer |= 1
|
|
||||||
|
|
||||||
# Test for primeness
|
|
||||||
if rsa.prime.is_prime(integer):
|
|
||||||
pipe.send(integer)
|
|
||||||
return
|
|
||||||
|
|
||||||
def getprime(nbits, poolsize):
|
|
||||||
'''Returns a prime number that can be stored in 'nbits' bits.
|
|
||||||
|
|
||||||
Works in multiple threads at the same time.
|
|
||||||
|
|
||||||
>>> p = getprime(128, 3)
|
|
||||||
>>> rsa.prime.is_prime(p-1)
|
|
||||||
False
|
|
||||||
>>> rsa.prime.is_prime(p)
|
|
||||||
True
|
|
||||||
>>> rsa.prime.is_prime(p+1)
|
|
||||||
False
|
|
||||||
|
|
||||||
>>> from rsa import common
|
|
||||||
>>> common.bit_size(p) == 128
|
|
||||||
True
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
(pipe_recv, pipe_send) = mp.Pipe(duplex=False)
|
|
||||||
|
|
||||||
# Create processes
|
|
||||||
procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send))
|
|
||||||
for _ in range(poolsize)]
|
|
||||||
[p.start() for p in procs]
|
|
||||||
|
|
||||||
result = pipe_recv.recv()
|
|
||||||
|
|
||||||
[p.terminate() for p in procs]
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
__all__ = ['getprime']
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print('Running doctests 1000x or until failure')
|
|
||||||
import doctest
|
|
||||||
|
|
||||||
for count in range(100):
|
|
||||||
(failures, tests) = doctest.testmod()
|
|
||||||
if failures:
|
|
||||||
break
|
|
||||||
|
|
||||||
if count and count % 10 == 0:
|
|
||||||
print('%i times' % count)
|
|
||||||
|
|
||||||
print('Doctests done')
|
|
||||||
|
|
|
@ -1,120 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Functions that load and write PEM-encoded files.'''
|
|
||||||
|
|
||||||
import base64
|
|
||||||
from rsa._compat import b, is_bytes
|
|
||||||
|
|
||||||
def _markers(pem_marker):
|
|
||||||
'''
|
|
||||||
Returns the start and end PEM markers
|
|
||||||
'''
|
|
||||||
|
|
||||||
if is_bytes(pem_marker):
|
|
||||||
pem_marker = pem_marker.decode('utf-8')
|
|
||||||
|
|
||||||
return (b('-----BEGIN %s-----' % pem_marker),
|
|
||||||
b('-----END %s-----' % pem_marker))
|
|
||||||
|
|
||||||
def load_pem(contents, pem_marker):
|
|
||||||
'''Loads a PEM file.
|
|
||||||
|
|
||||||
@param contents: the contents of the file to interpret
|
|
||||||
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
|
|
||||||
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
|
|
||||||
'-----END RSA PRIVATE KEY-----' markers.
|
|
||||||
|
|
||||||
@return the base64-decoded content between the start and end markers.
|
|
||||||
|
|
||||||
@raise ValueError: when the content is invalid, for example when the start
|
|
||||||
marker cannot be found.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
(pem_start, pem_end) = _markers(pem_marker)
|
|
||||||
|
|
||||||
pem_lines = []
|
|
||||||
in_pem_part = False
|
|
||||||
|
|
||||||
for line in contents.splitlines():
|
|
||||||
line = line.strip()
|
|
||||||
|
|
||||||
# Skip empty lines
|
|
||||||
if not line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle start marker
|
|
||||||
if line == pem_start:
|
|
||||||
if in_pem_part:
|
|
||||||
raise ValueError('Seen start marker "%s" twice' % pem_start)
|
|
||||||
|
|
||||||
in_pem_part = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Skip stuff before first marker
|
|
||||||
if not in_pem_part:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle end marker
|
|
||||||
if in_pem_part and line == pem_end:
|
|
||||||
in_pem_part = False
|
|
||||||
break
|
|
||||||
|
|
||||||
# Load fields
|
|
||||||
if b(':') in line:
|
|
||||||
continue
|
|
||||||
|
|
||||||
pem_lines.append(line)
|
|
||||||
|
|
||||||
# Do some sanity checks
|
|
||||||
if not pem_lines:
|
|
||||||
raise ValueError('No PEM start marker "%s" found' % pem_start)
|
|
||||||
|
|
||||||
if in_pem_part:
|
|
||||||
raise ValueError('No PEM end marker "%s" found' % pem_end)
|
|
||||||
|
|
||||||
# Base64-decode the contents
|
|
||||||
pem = b('').join(pem_lines)
|
|
||||||
return base64.decodestring(pem)
|
|
||||||
|
|
||||||
|
|
||||||
def save_pem(contents, pem_marker):
|
|
||||||
'''Saves a PEM file.
|
|
||||||
|
|
||||||
@param contents: the contents to encode in PEM format
|
|
||||||
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
|
|
||||||
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
|
|
||||||
'-----END RSA PRIVATE KEY-----' markers.
|
|
||||||
|
|
||||||
@return the base64-encoded content between the start and end markers.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
(pem_start, pem_end) = _markers(pem_marker)
|
|
||||||
|
|
||||||
b64 = base64.encodestring(contents).replace(b('\n'), b(''))
|
|
||||||
pem_lines = [pem_start]
|
|
||||||
|
|
||||||
for block_start in range(0, len(b64), 64):
|
|
||||||
block = b64[block_start:block_start + 64]
|
|
||||||
pem_lines.append(block)
|
|
||||||
|
|
||||||
pem_lines.append(pem_end)
|
|
||||||
pem_lines.append(b(''))
|
|
||||||
|
|
||||||
return b('\n').join(pem_lines)
|
|
||||||
|
|
|
@ -1,391 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Functions for PKCS#1 version 1.5 encryption and signing
|
|
||||||
|
|
||||||
This module implements certain functionality from PKCS#1 version 1.5. For a
|
|
||||||
very clear example, read http://www.di-mgt.com.au/rsa_alg.html#pkcs1schemes
|
|
||||||
|
|
||||||
At least 8 bytes of random padding is used when encrypting a message. This makes
|
|
||||||
these methods much more secure than the ones in the ``rsa`` module.
|
|
||||||
|
|
||||||
WARNING: this module leaks information when decryption or verification fails.
|
|
||||||
The exceptions that are raised contain the Python traceback information, which
|
|
||||||
can be used to deduce where in the process the failure occurred. DO NOT PASS
|
|
||||||
SUCH INFORMATION to your users.
|
|
||||||
'''
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
|
|
||||||
from rsa._compat import b
|
|
||||||
from rsa import common, transform, core, varblock
|
|
||||||
|
|
||||||
# ASN.1 codes that describe the hash algorithm used.
|
|
||||||
HASH_ASN1 = {
|
|
||||||
'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
|
|
||||||
'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
|
|
||||||
'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
|
|
||||||
'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
|
|
||||||
'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'),
|
|
||||||
}
|
|
||||||
|
|
||||||
HASH_METHODS = {
|
|
||||||
'MD5': hashlib.md5,
|
|
||||||
'SHA-1': hashlib.sha1,
|
|
||||||
'SHA-256': hashlib.sha256,
|
|
||||||
'SHA-384': hashlib.sha384,
|
|
||||||
'SHA-512': hashlib.sha512,
|
|
||||||
}
|
|
||||||
|
|
||||||
class CryptoError(Exception):
|
|
||||||
'''Base class for all exceptions in this module.'''
|
|
||||||
|
|
||||||
class DecryptionError(CryptoError):
|
|
||||||
'''Raised when decryption fails.'''
|
|
||||||
|
|
||||||
class VerificationError(CryptoError):
|
|
||||||
'''Raised when verification fails.'''
|
|
||||||
|
|
||||||
def _pad_for_encryption(message, target_length):
|
|
||||||
r'''Pads the message for encryption, returning the padded message.
|
|
||||||
|
|
||||||
:return: 00 02 RANDOM_DATA 00 MESSAGE
|
|
||||||
|
|
||||||
>>> block = _pad_for_encryption('hello', 16)
|
|
||||||
>>> len(block)
|
|
||||||
16
|
|
||||||
>>> block[0:2]
|
|
||||||
'\x00\x02'
|
|
||||||
>>> block[-6:]
|
|
||||||
'\x00hello'
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
max_msglength = target_length - 11
|
|
||||||
msglength = len(message)
|
|
||||||
|
|
||||||
if msglength > max_msglength:
|
|
||||||
raise OverflowError('%i bytes needed for message, but there is only'
|
|
||||||
' space for %i' % (msglength, max_msglength))
|
|
||||||
|
|
||||||
# Get random padding
|
|
||||||
padding = b('')
|
|
||||||
padding_length = target_length - msglength - 3
|
|
||||||
|
|
||||||
# We remove 0-bytes, so we'll end up with less padding than we've asked for,
|
|
||||||
# so keep adding data until we're at the correct length.
|
|
||||||
while len(padding) < padding_length:
|
|
||||||
needed_bytes = padding_length - len(padding)
|
|
||||||
|
|
||||||
# Always read at least 8 bytes more than we need, and trim off the rest
|
|
||||||
# after removing the 0-bytes. This increases the chance of getting
|
|
||||||
# enough bytes, especially when needed_bytes is small
|
|
||||||
new_padding = os.urandom(needed_bytes + 5)
|
|
||||||
new_padding = new_padding.replace(b('\x00'), b(''))
|
|
||||||
padding = padding + new_padding[:needed_bytes]
|
|
||||||
|
|
||||||
assert len(padding) == padding_length
|
|
||||||
|
|
||||||
return b('').join([b('\x00\x02'),
|
|
||||||
padding,
|
|
||||||
b('\x00'),
|
|
||||||
message])
|
|
||||||
|
|
||||||
|
|
||||||
def _pad_for_signing(message, target_length):
|
|
||||||
r'''Pads the message for signing, returning the padded message.
|
|
||||||
|
|
||||||
The padding is always a repetition of FF bytes.
|
|
||||||
|
|
||||||
:return: 00 01 PADDING 00 MESSAGE
|
|
||||||
|
|
||||||
>>> block = _pad_for_signing('hello', 16)
|
|
||||||
>>> len(block)
|
|
||||||
16
|
|
||||||
>>> block[0:2]
|
|
||||||
'\x00\x01'
|
|
||||||
>>> block[-6:]
|
|
||||||
'\x00hello'
|
|
||||||
>>> block[2:-6]
|
|
||||||
'\xff\xff\xff\xff\xff\xff\xff\xff'
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
max_msglength = target_length - 11
|
|
||||||
msglength = len(message)
|
|
||||||
|
|
||||||
if msglength > max_msglength:
|
|
||||||
raise OverflowError('%i bytes needed for message, but there is only'
|
|
||||||
' space for %i' % (msglength, max_msglength))
|
|
||||||
|
|
||||||
padding_length = target_length - msglength - 3
|
|
||||||
|
|
||||||
return b('').join([b('\x00\x01'),
|
|
||||||
padding_length * b('\xff'),
|
|
||||||
b('\x00'),
|
|
||||||
message])
|
|
||||||
|
|
||||||
|
|
||||||
def encrypt(message, pub_key):
|
|
||||||
'''Encrypts the given message using PKCS#1 v1.5
|
|
||||||
|
|
||||||
:param message: the message to encrypt. Must be a byte string no longer than
|
|
||||||
``k-11`` bytes, where ``k`` is the number of bytes needed to encode
|
|
||||||
the ``n`` component of the public key.
|
|
||||||
:param pub_key: the :py:class:`rsa.PublicKey` to encrypt with.
|
|
||||||
:raise OverflowError: when the message is too large to fit in the padded
|
|
||||||
block.
|
|
||||||
|
|
||||||
>>> from rsa import key, common
|
|
||||||
>>> (pub_key, priv_key) = key.newkeys(256)
|
|
||||||
>>> message = 'hello'
|
|
||||||
>>> crypto = encrypt(message, pub_key)
|
|
||||||
|
|
||||||
The crypto text should be just as long as the public key 'n' component:
|
|
||||||
|
|
||||||
>>> len(crypto) == common.byte_size(pub_key.n)
|
|
||||||
True
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
keylength = common.byte_size(pub_key.n)
|
|
||||||
padded = _pad_for_encryption(message, keylength)
|
|
||||||
|
|
||||||
payload = transform.bytes2int(padded)
|
|
||||||
encrypted = core.encrypt_int(payload, pub_key.e, pub_key.n)
|
|
||||||
block = transform.int2bytes(encrypted, keylength)
|
|
||||||
|
|
||||||
return block
|
|
||||||
|
|
||||||
def decrypt(crypto, priv_key):
|
|
||||||
r'''Decrypts the given message using PKCS#1 v1.5
|
|
||||||
|
|
||||||
The decryption is considered 'failed' when the resulting cleartext doesn't
|
|
||||||
start with the bytes 00 02, or when the 00 byte between the padding and
|
|
||||||
the message cannot be found.
|
|
||||||
|
|
||||||
:param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
|
|
||||||
:param priv_key: the :py:class:`rsa.PrivateKey` to decrypt with.
|
|
||||||
:raise DecryptionError: when the decryption fails. No details are given as
|
|
||||||
to why the code thinks the decryption fails, as this would leak
|
|
||||||
information about the private key.
|
|
||||||
|
|
||||||
|
|
||||||
>>> import rsa
|
|
||||||
>>> (pub_key, priv_key) = rsa.newkeys(256)
|
|
||||||
|
|
||||||
It works with strings:
|
|
||||||
|
|
||||||
>>> crypto = encrypt('hello', pub_key)
|
|
||||||
>>> decrypt(crypto, priv_key)
|
|
||||||
'hello'
|
|
||||||
|
|
||||||
And with binary data:
|
|
||||||
|
|
||||||
>>> crypto = encrypt('\x00\x00\x00\x00\x01', pub_key)
|
|
||||||
>>> decrypt(crypto, priv_key)
|
|
||||||
'\x00\x00\x00\x00\x01'
|
|
||||||
|
|
||||||
Altering the encrypted information will *likely* cause a
|
|
||||||
:py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use
|
|
||||||
:py:func:`rsa.sign`.
|
|
||||||
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
|
|
||||||
Never display the stack trace of a
|
|
||||||
:py:class:`rsa.pkcs1.DecryptionError` exception. It shows where in the
|
|
||||||
code the exception occurred, and thus leaks information about the key.
|
|
||||||
It's only a tiny bit of information, but every bit makes cracking the
|
|
||||||
keys easier.
|
|
||||||
|
|
||||||
>>> crypto = encrypt('hello', pub_key)
|
|
||||||
>>> crypto = crypto[0:5] + 'X' + crypto[6:] # change a byte
|
|
||||||
>>> decrypt(crypto, priv_key)
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
DecryptionError: Decryption failed
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
blocksize = common.byte_size(priv_key.n)
|
|
||||||
encrypted = transform.bytes2int(crypto)
|
|
||||||
decrypted = core.decrypt_int(encrypted, priv_key.d, priv_key.n)
|
|
||||||
cleartext = transform.int2bytes(decrypted, blocksize)
|
|
||||||
|
|
||||||
# If we can't find the cleartext marker, decryption failed.
|
|
||||||
if cleartext[0:2] != b('\x00\x02'):
|
|
||||||
raise DecryptionError('Decryption failed')
|
|
||||||
|
|
||||||
# Find the 00 separator between the padding and the message
|
|
||||||
try:
|
|
||||||
sep_idx = cleartext.index(b('\x00'), 2)
|
|
||||||
except ValueError:
|
|
||||||
raise DecryptionError('Decryption failed')
|
|
||||||
|
|
||||||
return cleartext[sep_idx+1:]
|
|
||||||
|
|
||||||
def sign(message, priv_key, hash):
|
|
||||||
'''Signs the message with the private key.
|
|
||||||
|
|
||||||
Hashes the message, then signs the hash with the given key. This is known
|
|
||||||
as a "detached signature", because the message itself isn't altered.
|
|
||||||
|
|
||||||
:param message: the message to sign. Can be an 8-bit string or a file-like
|
|
||||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
|
||||||
file-like object.
|
|
||||||
:param priv_key: the :py:class:`rsa.PrivateKey` to sign with
|
|
||||||
:param hash: the hash method used on the message. Use 'MD5', 'SHA-1',
|
|
||||||
'SHA-256', 'SHA-384' or 'SHA-512'.
|
|
||||||
:return: a message signature block.
|
|
||||||
:raise OverflowError: if the private key is too small to contain the
|
|
||||||
requested hash.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
# Get the ASN1 code for this hash method
|
|
||||||
if hash not in HASH_ASN1:
|
|
||||||
raise ValueError('Invalid hash method: %s' % hash)
|
|
||||||
asn1code = HASH_ASN1[hash]
|
|
||||||
|
|
||||||
# Calculate the hash
|
|
||||||
hash = _hash(message, hash)
|
|
||||||
|
|
||||||
# Encrypt the hash with the private key
|
|
||||||
cleartext = asn1code + hash
|
|
||||||
keylength = common.byte_size(priv_key.n)
|
|
||||||
padded = _pad_for_signing(cleartext, keylength)
|
|
||||||
|
|
||||||
payload = transform.bytes2int(padded)
|
|
||||||
encrypted = core.encrypt_int(payload, priv_key.d, priv_key.n)
|
|
||||||
block = transform.int2bytes(encrypted, keylength)
|
|
||||||
|
|
||||||
return block
|
|
||||||
|
|
||||||
def verify(message, signature, pub_key):
|
|
||||||
'''Verifies that the signature matches the message.
|
|
||||||
|
|
||||||
The hash method is detected automatically from the signature.
|
|
||||||
|
|
||||||
:param message: the signed message. Can be an 8-bit string or a file-like
|
|
||||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
|
||||||
file-like object.
|
|
||||||
:param signature: the signature block, as created with :py:func:`rsa.sign`.
|
|
||||||
:param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message.
|
|
||||||
:raise VerificationError: when the signature doesn't match the message.
|
|
||||||
|
|
||||||
.. warning::
|
|
||||||
|
|
||||||
Never display the stack trace of a
|
|
||||||
:py:class:`rsa.pkcs1.VerificationError` exception. It shows where in
|
|
||||||
the code the exception occurred, and thus leaks information about the
|
|
||||||
key. It's only a tiny bit of information, but every bit makes cracking
|
|
||||||
the keys easier.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
blocksize = common.byte_size(pub_key.n)
|
|
||||||
encrypted = transform.bytes2int(signature)
|
|
||||||
decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
|
|
||||||
clearsig = transform.int2bytes(decrypted, blocksize)
|
|
||||||
|
|
||||||
# If we can't find the signature marker, verification failed.
|
|
||||||
if clearsig[0:2] != b('\x00\x01'):
|
|
||||||
raise VerificationError('Verification failed')
|
|
||||||
|
|
||||||
# Find the 00 separator between the padding and the payload
|
|
||||||
try:
|
|
||||||
sep_idx = clearsig.index(b('\x00'), 2)
|
|
||||||
except ValueError:
|
|
||||||
raise VerificationError('Verification failed')
|
|
||||||
|
|
||||||
# Get the hash and the hash method
|
|
||||||
(method_name, signature_hash) = _find_method_hash(clearsig[sep_idx+1:])
|
|
||||||
message_hash = _hash(message, method_name)
|
|
||||||
|
|
||||||
# Compare the real hash to the hash in the signature
|
|
||||||
if message_hash != signature_hash:
|
|
||||||
raise VerificationError('Verification failed')
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def _hash(message, method_name):
|
|
||||||
'''Returns the message digest.
|
|
||||||
|
|
||||||
:param message: the signed message. Can be an 8-bit string or a file-like
|
|
||||||
object. If ``message`` has a ``read()`` method, it is assumed to be a
|
|
||||||
file-like object.
|
|
||||||
:param method_name: the hash method, must be a key of
|
|
||||||
:py:const:`HASH_METHODS`.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
if method_name not in HASH_METHODS:
|
|
||||||
raise ValueError('Invalid hash method: %s' % method_name)
|
|
||||||
|
|
||||||
method = HASH_METHODS[method_name]
|
|
||||||
hasher = method()
|
|
||||||
|
|
||||||
if hasattr(message, 'read') and hasattr(message.read, '__call__'):
|
|
||||||
# read as 1K blocks
|
|
||||||
for block in varblock.yield_fixedblocks(message, 1024):
|
|
||||||
hasher.update(block)
|
|
||||||
else:
|
|
||||||
# hash the message object itself.
|
|
||||||
hasher.update(message)
|
|
||||||
|
|
||||||
return hasher.digest()
|
|
||||||
|
|
||||||
|
|
||||||
def _find_method_hash(method_hash):
|
|
||||||
'''Finds the hash method and the hash itself.
|
|
||||||
|
|
||||||
:param method_hash: ASN1 code for the hash method concatenated with the
|
|
||||||
hash itself.
|
|
||||||
|
|
||||||
:return: tuple (method, hash) where ``method`` is the used hash method, and
|
|
||||||
``hash`` is the hash itself.
|
|
||||||
|
|
||||||
:raise VerificationFailed: when the hash method cannot be found
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
for (hashname, asn1code) in HASH_ASN1.items():
|
|
||||||
if not method_hash.startswith(asn1code):
|
|
||||||
continue
|
|
||||||
|
|
||||||
return (hashname, method_hash[len(asn1code):])
|
|
||||||
|
|
||||||
raise VerificationError('Verification failed')
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ['encrypt', 'decrypt', 'sign', 'verify',
|
|
||||||
'DecryptionError', 'VerificationError', 'CryptoError']
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print('Running doctests 1000x or until failure')
|
|
||||||
import doctest
|
|
||||||
|
|
||||||
for count in range(1000):
|
|
||||||
(failures, tests) = doctest.testmod()
|
|
||||||
if failures:
|
|
||||||
break
|
|
||||||
|
|
||||||
if count and count % 100 == 0:
|
|
||||||
print('%i times' % count)
|
|
||||||
|
|
||||||
print('Doctests done')
|
|
|
@ -1,166 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Numerical functions related to primes.
|
|
||||||
|
|
||||||
Implementation based on the book Algorithm Design by Michael T. Goodrich and
|
|
||||||
Roberto Tamassia, 2002.
|
|
||||||
'''
|
|
||||||
|
|
||||||
__all__ = [ 'getprime', 'are_relatively_prime']
|
|
||||||
|
|
||||||
import rsa.randnum
|
|
||||||
|
|
||||||
def gcd(p, q):
|
|
||||||
'''Returns the greatest common divisor of p and q
|
|
||||||
|
|
||||||
>>> gcd(48, 180)
|
|
||||||
12
|
|
||||||
'''
|
|
||||||
|
|
||||||
while q != 0:
|
|
||||||
if p < q: (p,q) = (q,p)
|
|
||||||
(p,q) = (q, p % q)
|
|
||||||
return p
|
|
||||||
|
|
||||||
|
|
||||||
def jacobi(a, b):
|
|
||||||
'''Calculates the value of the Jacobi symbol (a/b) where both a and b are
|
|
||||||
positive integers, and b is odd
|
|
||||||
|
|
||||||
:returns: -1, 0 or 1
|
|
||||||
'''
|
|
||||||
|
|
||||||
assert a > 0
|
|
||||||
assert b > 0
|
|
||||||
|
|
||||||
if a == 0: return 0
|
|
||||||
result = 1
|
|
||||||
while a > 1:
|
|
||||||
if a & 1:
|
|
||||||
if ((a-1)*(b-1) >> 2) & 1:
|
|
||||||
result = -result
|
|
||||||
a, b = b % a, a
|
|
||||||
else:
|
|
||||||
if (((b * b) - 1) >> 3) & 1:
|
|
||||||
result = -result
|
|
||||||
a >>= 1
|
|
||||||
if a == 0: return 0
|
|
||||||
return result
|
|
||||||
|
|
||||||
def jacobi_witness(x, n):
|
|
||||||
'''Returns False if n is an Euler pseudo-prime with base x, and
|
|
||||||
True otherwise.
|
|
||||||
'''
|
|
||||||
|
|
||||||
j = jacobi(x, n) % n
|
|
||||||
|
|
||||||
f = pow(x, n >> 1, n)
|
|
||||||
|
|
||||||
if j == f: return False
|
|
||||||
return True
|
|
||||||
|
|
||||||
def randomized_primality_testing(n, k):
|
|
||||||
'''Calculates whether n is composite (which is always correct) or
|
|
||||||
prime (which is incorrect with error probability 2**-k)
|
|
||||||
|
|
||||||
Returns False if the number is composite, and True if it's
|
|
||||||
probably prime.
|
|
||||||
'''
|
|
||||||
|
|
||||||
# 50% of Jacobi-witnesses can report compositness of non-prime numbers
|
|
||||||
|
|
||||||
# The implemented algorithm using the Jacobi witness function has error
|
|
||||||
# probability q <= 0.5, according to Goodrich et. al
|
|
||||||
#
|
|
||||||
# q = 0.5
|
|
||||||
# t = int(math.ceil(k / log(1 / q, 2)))
|
|
||||||
# So t = k / log(2, 2) = k / 1 = k
|
|
||||||
# this means we can use range(k) rather than range(t)
|
|
||||||
|
|
||||||
for _ in range(k):
|
|
||||||
x = rsa.randnum.randint(n-1)
|
|
||||||
if jacobi_witness(x, n): return False
|
|
||||||
|
|
||||||
return True
|
|
||||||
|
|
||||||
def is_prime(number):
|
|
||||||
'''Returns True if the number is prime, and False otherwise.
|
|
||||||
|
|
||||||
>>> is_prime(42)
|
|
||||||
False
|
|
||||||
>>> is_prime(41)
|
|
||||||
True
|
|
||||||
'''
|
|
||||||
|
|
||||||
return randomized_primality_testing(number, 6)
|
|
||||||
|
|
||||||
def getprime(nbits):
|
|
||||||
'''Returns a prime number that can be stored in 'nbits' bits.
|
|
||||||
|
|
||||||
>>> p = getprime(128)
|
|
||||||
>>> is_prime(p-1)
|
|
||||||
False
|
|
||||||
>>> is_prime(p)
|
|
||||||
True
|
|
||||||
>>> is_prime(p+1)
|
|
||||||
False
|
|
||||||
|
|
||||||
>>> from rsa import common
|
|
||||||
>>> common.bit_size(p) == 128
|
|
||||||
True
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
while True:
|
|
||||||
integer = rsa.randnum.read_random_int(nbits)
|
|
||||||
|
|
||||||
# Make sure it's odd
|
|
||||||
integer |= 1
|
|
||||||
|
|
||||||
# Test for primeness
|
|
||||||
if is_prime(integer):
|
|
||||||
return integer
|
|
||||||
|
|
||||||
# Retry if not prime
|
|
||||||
|
|
||||||
|
|
||||||
def are_relatively_prime(a, b):
|
|
||||||
'''Returns True if a and b are relatively prime, and False if they
|
|
||||||
are not.
|
|
||||||
|
|
||||||
>>> are_relatively_prime(2, 3)
|
|
||||||
1
|
|
||||||
>>> are_relatively_prime(2, 4)
|
|
||||||
0
|
|
||||||
'''
|
|
||||||
|
|
||||||
d = gcd(a, b)
|
|
||||||
return (d == 1)
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
print('Running doctests 1000x or until failure')
|
|
||||||
import doctest
|
|
||||||
|
|
||||||
for count in range(1000):
|
|
||||||
(failures, tests) = doctest.testmod()
|
|
||||||
if failures:
|
|
||||||
break
|
|
||||||
|
|
||||||
if count and count % 100 == 0:
|
|
||||||
print('%i times' % count)
|
|
||||||
|
|
||||||
print('Doctests done')
|
|
|
@ -1,85 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Functions for generating random numbers.'''
|
|
||||||
|
|
||||||
# Source inspired by code by Yesudeep Mangalapilly <yesudeep@gmail.com>
|
|
||||||
|
|
||||||
import os
|
|
||||||
|
|
||||||
from rsa import common, transform
|
|
||||||
from rsa._compat import byte
|
|
||||||
|
|
||||||
def read_random_bits(nbits):
|
|
||||||
'''Reads 'nbits' random bits.
|
|
||||||
|
|
||||||
If nbits isn't a whole number of bytes, an extra byte will be appended with
|
|
||||||
only the lower bits set.
|
|
||||||
'''
|
|
||||||
|
|
||||||
nbytes, rbits = divmod(nbits, 8)
|
|
||||||
|
|
||||||
# Get the random bytes
|
|
||||||
randomdata = os.urandom(nbytes)
|
|
||||||
|
|
||||||
# Add the remaining random bits
|
|
||||||
if rbits > 0:
|
|
||||||
randomvalue = ord(os.urandom(1))
|
|
||||||
randomvalue >>= (8 - rbits)
|
|
||||||
randomdata = byte(randomvalue) + randomdata
|
|
||||||
|
|
||||||
return randomdata
|
|
||||||
|
|
||||||
|
|
||||||
def read_random_int(nbits):
|
|
||||||
'''Reads a random integer of approximately nbits bits.
|
|
||||||
'''
|
|
||||||
|
|
||||||
randomdata = read_random_bits(nbits)
|
|
||||||
value = transform.bytes2int(randomdata)
|
|
||||||
|
|
||||||
# Ensure that the number is large enough to just fill out the required
|
|
||||||
# number of bits.
|
|
||||||
value |= 1 << (nbits - 1)
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
def randint(maxvalue):
|
|
||||||
'''Returns a random integer x with 1 <= x <= maxvalue
|
|
||||||
|
|
||||||
May take a very long time in specific situations. If maxvalue needs N bits
|
|
||||||
to store, the closer maxvalue is to (2 ** N) - 1, the faster this function
|
|
||||||
is.
|
|
||||||
'''
|
|
||||||
|
|
||||||
bit_size = common.bit_size(maxvalue)
|
|
||||||
|
|
||||||
tries = 0
|
|
||||||
while True:
|
|
||||||
value = read_random_int(bit_size)
|
|
||||||
if value <= maxvalue:
|
|
||||||
break
|
|
||||||
|
|
||||||
if tries and tries % 10 == 0:
|
|
||||||
# After a lot of tries to get the right number of bits but still
|
|
||||||
# smaller than maxvalue, decrease the number of bits by 1. That'll
|
|
||||||
# dramatically increase the chances to get a large enough number.
|
|
||||||
bit_size -= 1
|
|
||||||
tries += 1
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
|
@ -1,220 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Data transformation functions.
|
|
||||||
|
|
||||||
From bytes to a number, number to bytes, etc.
|
|
||||||
'''
|
|
||||||
|
|
||||||
from __future__ import absolute_import
|
|
||||||
|
|
||||||
try:
|
|
||||||
# We'll use psyco if available on 32-bit architectures to speed up code.
|
|
||||||
# Using psyco (if available) cuts down the execution time on Python 2.5
|
|
||||||
# at least by half.
|
|
||||||
import psyco
|
|
||||||
psyco.full()
|
|
||||||
except ImportError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
import binascii
|
|
||||||
from struct import pack
|
|
||||||
from rsa import common
|
|
||||||
from rsa._compat import is_integer, b, byte, get_word_alignment, ZERO_BYTE, EMPTY_BYTE
|
|
||||||
|
|
||||||
|
|
||||||
def bytes2int(raw_bytes):
|
|
||||||
r'''Converts a list of bytes or an 8-bit string to an integer.
|
|
||||||
|
|
||||||
When using unicode strings, encode it to some encoding like UTF8 first.
|
|
||||||
|
|
||||||
>>> (((128 * 256) + 64) * 256) + 15
|
|
||||||
8405007
|
|
||||||
>>> bytes2int('\x80@\x0f')
|
|
||||||
8405007
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
return int(binascii.hexlify(raw_bytes), 16)
|
|
||||||
|
|
||||||
|
|
||||||
def _int2bytes(number, block_size=None):
|
|
||||||
r'''Converts a number to a string of bytes.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
>>> _int2bytes(123456789)
|
|
||||||
'\x07[\xcd\x15'
|
|
||||||
>>> bytes2int(_int2bytes(123456789))
|
|
||||||
123456789
|
|
||||||
|
|
||||||
>>> _int2bytes(123456789, 6)
|
|
||||||
'\x00\x00\x07[\xcd\x15'
|
|
||||||
>>> bytes2int(_int2bytes(123456789, 128))
|
|
||||||
123456789
|
|
||||||
|
|
||||||
>>> _int2bytes(123456789, 3)
|
|
||||||
Traceback (most recent call last):
|
|
||||||
...
|
|
||||||
OverflowError: Needed 4 bytes for number, but block size is 3
|
|
||||||
|
|
||||||
@param number: the number to convert
|
|
||||||
@param block_size: the number of bytes to output. If the number encoded to
|
|
||||||
bytes is less than this, the block will be zero-padded. When not given,
|
|
||||||
the returned block is not padded.
|
|
||||||
|
|
||||||
@throws OverflowError when block_size is given and the number takes up more
|
|
||||||
bytes than fit into the block.
|
|
||||||
'''
|
|
||||||
# Type checking
|
|
||||||
if not is_integer(number):
|
|
||||||
raise TypeError("You must pass an integer for 'number', not %s" %
|
|
||||||
number.__class__)
|
|
||||||
|
|
||||||
if number < 0:
|
|
||||||
raise ValueError('Negative numbers cannot be used: %i' % number)
|
|
||||||
|
|
||||||
# Do some bounds checking
|
|
||||||
if number == 0:
|
|
||||||
needed_bytes = 1
|
|
||||||
raw_bytes = [ZERO_BYTE]
|
|
||||||
else:
|
|
||||||
needed_bytes = common.byte_size(number)
|
|
||||||
raw_bytes = []
|
|
||||||
|
|
||||||
# You cannot compare None > 0 in Python 3x. It will fail with a TypeError.
|
|
||||||
if block_size and block_size > 0:
|
|
||||||
if needed_bytes > block_size:
|
|
||||||
raise OverflowError('Needed %i bytes for number, but block size '
|
|
||||||
'is %i' % (needed_bytes, block_size))
|
|
||||||
|
|
||||||
# Convert the number to bytes.
|
|
||||||
while number > 0:
|
|
||||||
raw_bytes.insert(0, byte(number & 0xFF))
|
|
||||||
number >>= 8
|
|
||||||
|
|
||||||
# Pad with zeroes to fill the block
|
|
||||||
if block_size and block_size > 0:
|
|
||||||
padding = (block_size - needed_bytes) * ZERO_BYTE
|
|
||||||
else:
|
|
||||||
padding = EMPTY_BYTE
|
|
||||||
|
|
||||||
return padding + EMPTY_BYTE.join(raw_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def bytes_leading(raw_bytes, needle=ZERO_BYTE):
|
|
||||||
'''
|
|
||||||
Finds the number of prefixed byte occurrences in the haystack.
|
|
||||||
|
|
||||||
Useful when you want to deal with padding.
|
|
||||||
|
|
||||||
:param raw_bytes:
|
|
||||||
Raw bytes.
|
|
||||||
:param needle:
|
|
||||||
The byte to count. Default \000.
|
|
||||||
:returns:
|
|
||||||
The number of leading needle bytes.
|
|
||||||
'''
|
|
||||||
leading = 0
|
|
||||||
# Indexing keeps compatibility between Python 2.x and Python 3.x
|
|
||||||
_byte = needle[0]
|
|
||||||
for x in raw_bytes:
|
|
||||||
if x == _byte:
|
|
||||||
leading += 1
|
|
||||||
else:
|
|
||||||
break
|
|
||||||
return leading
|
|
||||||
|
|
||||||
|
|
||||||
def int2bytes(number, fill_size=None, chunk_size=None, overflow=False):
|
|
||||||
'''
|
|
||||||
Convert an unsigned integer to bytes (base-256 representation)::
|
|
||||||
|
|
||||||
Does not preserve leading zeros if you don't specify a chunk size or
|
|
||||||
fill size.
|
|
||||||
|
|
||||||
.. NOTE:
|
|
||||||
You must not specify both fill_size and chunk_size. Only one
|
|
||||||
of them is allowed.
|
|
||||||
|
|
||||||
:param number:
|
|
||||||
Integer value
|
|
||||||
:param fill_size:
|
|
||||||
If the optional fill size is given the length of the resulting
|
|
||||||
byte string is expected to be the fill size and will be padded
|
|
||||||
with prefix zero bytes to satisfy that length.
|
|
||||||
:param chunk_size:
|
|
||||||
If optional chunk size is given and greater than zero, pad the front of
|
|
||||||
the byte string with binary zeros so that the length is a multiple of
|
|
||||||
``chunk_size``.
|
|
||||||
:param overflow:
|
|
||||||
``False`` (default). If this is ``True``, no ``OverflowError``
|
|
||||||
will be raised when the fill_size is shorter than the length
|
|
||||||
of the generated byte sequence. Instead the byte sequence will
|
|
||||||
be returned as is.
|
|
||||||
:returns:
|
|
||||||
Raw bytes (base-256 representation).
|
|
||||||
:raises:
|
|
||||||
``OverflowError`` when fill_size is given and the number takes up more
|
|
||||||
bytes than fit into the block. This requires the ``overflow``
|
|
||||||
argument to this function to be set to ``False`` otherwise, no
|
|
||||||
error will be raised.
|
|
||||||
'''
|
|
||||||
if number < 0:
|
|
||||||
raise ValueError("Number must be an unsigned integer: %d" % number)
|
|
||||||
|
|
||||||
if fill_size and chunk_size:
|
|
||||||
raise ValueError("You can either fill or pad chunks, but not both")
|
|
||||||
|
|
||||||
# Ensure these are integers.
|
|
||||||
number & 1
|
|
||||||
|
|
||||||
raw_bytes = b('')
|
|
||||||
|
|
||||||
# Pack the integer one machine word at a time into bytes.
|
|
||||||
num = number
|
|
||||||
word_bits, _, max_uint, pack_type = get_word_alignment(num)
|
|
||||||
pack_format = ">%s" % pack_type
|
|
||||||
while num > 0:
|
|
||||||
raw_bytes = pack(pack_format, num & max_uint) + raw_bytes
|
|
||||||
num >>= word_bits
|
|
||||||
# Obtain the index of the first non-zero byte.
|
|
||||||
zero_leading = bytes_leading(raw_bytes)
|
|
||||||
if number == 0:
|
|
||||||
raw_bytes = ZERO_BYTE
|
|
||||||
# De-padding.
|
|
||||||
raw_bytes = raw_bytes[zero_leading:]
|
|
||||||
|
|
||||||
length = len(raw_bytes)
|
|
||||||
if fill_size and fill_size > 0:
|
|
||||||
if not overflow and length > fill_size:
|
|
||||||
raise OverflowError(
|
|
||||||
"Need %d bytes for number, but fill size is %d" %
|
|
||||||
(length, fill_size)
|
|
||||||
)
|
|
||||||
raw_bytes = raw_bytes.rjust(fill_size, ZERO_BYTE)
|
|
||||||
elif chunk_size and chunk_size > 0:
|
|
||||||
remainder = length % chunk_size
|
|
||||||
if remainder:
|
|
||||||
padding_size = chunk_size - remainder
|
|
||||||
raw_bytes = raw_bytes.rjust(length + padding_size, ZERO_BYTE)
|
|
||||||
return raw_bytes
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
import doctest
|
|
||||||
doctest.testmod()
|
|
||||||
|
|
|
@ -1,81 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''Utility functions.'''
|
|
||||||
|
|
||||||
from __future__ import with_statement, print_function
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from optparse import OptionParser
|
|
||||||
|
|
||||||
import rsa.key
|
|
||||||
|
|
||||||
def private_to_public():
|
|
||||||
'''Reads a private key and outputs the corresponding public key.'''
|
|
||||||
|
|
||||||
# Parse the CLI options
|
|
||||||
parser = OptionParser(usage='usage: %prog [options]',
|
|
||||||
description='Reads a private key and outputs the '
|
|
||||||
'corresponding public key. Both private and public keys use '
|
|
||||||
'the format described in PKCS#1 v1.5')
|
|
||||||
|
|
||||||
parser.add_option('-i', '--input', dest='infilename', type='string',
|
|
||||||
help='Input filename. Reads from stdin if not specified')
|
|
||||||
parser.add_option('-o', '--output', dest='outfilename', type='string',
|
|
||||||
help='Output filename. Writes to stdout of not specified')
|
|
||||||
|
|
||||||
parser.add_option('--inform', dest='inform',
|
|
||||||
help='key format of input - default PEM',
|
|
||||||
choices=('PEM', 'DER'), default='PEM')
|
|
||||||
|
|
||||||
parser.add_option('--outform', dest='outform',
|
|
||||||
help='key format of output - default PEM',
|
|
||||||
choices=('PEM', 'DER'), default='PEM')
|
|
||||||
|
|
||||||
(cli, cli_args) = parser.parse_args(sys.argv)
|
|
||||||
|
|
||||||
# Read the input data
|
|
||||||
if cli.infilename:
|
|
||||||
print('Reading private key from %s in %s format' % \
|
|
||||||
(cli.infilename, cli.inform), file=sys.stderr)
|
|
||||||
with open(cli.infilename, 'rb') as infile:
|
|
||||||
in_data = infile.read()
|
|
||||||
else:
|
|
||||||
print('Reading private key from stdin in %s format' % cli.inform,
|
|
||||||
file=sys.stderr)
|
|
||||||
in_data = sys.stdin.read().encode('ascii')
|
|
||||||
|
|
||||||
assert type(in_data) == bytes, type(in_data)
|
|
||||||
|
|
||||||
|
|
||||||
# Take the public fields and create a public key
|
|
||||||
priv_key = rsa.key.PrivateKey.load_pkcs1(in_data, cli.inform)
|
|
||||||
pub_key = rsa.key.PublicKey(priv_key.n, priv_key.e)
|
|
||||||
|
|
||||||
# Save to the output file
|
|
||||||
out_data = pub_key.save_pkcs1(cli.outform)
|
|
||||||
|
|
||||||
if cli.outfilename:
|
|
||||||
print('Writing public key to %s in %s format' % \
|
|
||||||
(cli.outfilename, cli.outform), file=sys.stderr)
|
|
||||||
with open(cli.outfilename, 'wb') as outfile:
|
|
||||||
outfile.write(out_data)
|
|
||||||
else:
|
|
||||||
print('Writing public key to stdout in %s format' % cli.outform,
|
|
||||||
file=sys.stderr)
|
|
||||||
sys.stdout.write(out_data.decode('ascii'))
|
|
||||||
|
|
||||||
|
|
|
@ -1,155 +0,0 @@
|
||||||
# -*- coding: utf-8 -*-
|
|
||||||
#
|
|
||||||
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
'''VARBLOCK file support
|
|
||||||
|
|
||||||
The VARBLOCK file format is as follows, where || denotes byte concatenation:
|
|
||||||
|
|
||||||
FILE := VERSION || BLOCK || BLOCK ...
|
|
||||||
|
|
||||||
BLOCK := LENGTH || DATA
|
|
||||||
|
|
||||||
LENGTH := varint-encoded length of the subsequent data. Varint comes from
|
|
||||||
Google Protobuf, and encodes an integer into a variable number of bytes.
|
|
||||||
Each byte uses the 7 lowest bits to encode the value. The highest bit set
|
|
||||||
to 1 indicates the next byte is also part of the varint. The last byte will
|
|
||||||
have this bit set to 0.
|
|
||||||
|
|
||||||
This file format is called the VARBLOCK format, in line with the varint format
|
|
||||||
used to denote the block sizes.
|
|
||||||
|
|
||||||
'''
|
|
||||||
|
|
||||||
from rsa._compat import byte, b
|
|
||||||
|
|
||||||
|
|
||||||
ZERO_BYTE = b('\x00')
|
|
||||||
VARBLOCK_VERSION = 1
|
|
||||||
|
|
||||||
def read_varint(infile):
|
|
||||||
'''Reads a varint from the file.
|
|
||||||
|
|
||||||
When the first byte to be read indicates EOF, (0, 0) is returned. When an
|
|
||||||
EOF occurs when at least one byte has been read, an EOFError exception is
|
|
||||||
raised.
|
|
||||||
|
|
||||||
@param infile: the file-like object to read from. It should have a read()
|
|
||||||
method.
|
|
||||||
@returns (varint, length), the read varint and the number of read bytes.
|
|
||||||
'''
|
|
||||||
|
|
||||||
varint = 0
|
|
||||||
read_bytes = 0
|
|
||||||
|
|
||||||
while True:
|
|
||||||
char = infile.read(1)
|
|
||||||
if len(char) == 0:
|
|
||||||
if read_bytes == 0:
|
|
||||||
return (0, 0)
|
|
||||||
raise EOFError('EOF while reading varint, value is %i so far' %
|
|
||||||
varint)
|
|
||||||
|
|
||||||
byte = ord(char)
|
|
||||||
varint += (byte & 0x7F) << (7 * read_bytes)
|
|
||||||
|
|
||||||
read_bytes += 1
|
|
||||||
|
|
||||||
if not byte & 0x80:
|
|
||||||
return (varint, read_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
def write_varint(outfile, value):
|
|
||||||
'''Writes a varint to a file.
|
|
||||||
|
|
||||||
@param outfile: the file-like object to write to. It should have a write()
|
|
||||||
method.
|
|
||||||
@returns the number of written bytes.
|
|
||||||
'''
|
|
||||||
|
|
||||||
# there is a big difference between 'write the value 0' (this case) and
|
|
||||||
# 'there is nothing left to write' (the false-case of the while loop)
|
|
||||||
|
|
||||||
if value == 0:
|
|
||||||
outfile.write(ZERO_BYTE)
|
|
||||||
return 1
|
|
||||||
|
|
||||||
written_bytes = 0
|
|
||||||
while value > 0:
|
|
||||||
to_write = value & 0x7f
|
|
||||||
value = value >> 7
|
|
||||||
|
|
||||||
if value > 0:
|
|
||||||
to_write |= 0x80
|
|
||||||
|
|
||||||
outfile.write(byte(to_write))
|
|
||||||
written_bytes += 1
|
|
||||||
|
|
||||||
return written_bytes
|
|
||||||
|
|
||||||
|
|
||||||
def yield_varblocks(infile):
|
|
||||||
'''Generator, yields each block in the input file.
|
|
||||||
|
|
||||||
@param infile: file to read, is expected to have the VARBLOCK format as
|
|
||||||
described in the module's docstring.
|
|
||||||
@yields the contents of each block.
|
|
||||||
'''
|
|
||||||
|
|
||||||
# Check the version number
|
|
||||||
first_char = infile.read(1)
|
|
||||||
if len(first_char) == 0:
|
|
||||||
raise EOFError('Unable to read VARBLOCK version number')
|
|
||||||
|
|
||||||
version = ord(first_char)
|
|
||||||
if version != VARBLOCK_VERSION:
|
|
||||||
raise ValueError('VARBLOCK version %i not supported' % version)
|
|
||||||
|
|
||||||
while True:
|
|
||||||
(block_size, read_bytes) = read_varint(infile)
|
|
||||||
|
|
||||||
# EOF at block boundary, that's fine.
|
|
||||||
if read_bytes == 0 and block_size == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
block = infile.read(block_size)
|
|
||||||
|
|
||||||
read_size = len(block)
|
|
||||||
if read_size != block_size:
|
|
||||||
raise EOFError('Block size is %i, but could read only %i bytes' %
|
|
||||||
(block_size, read_size))
|
|
||||||
|
|
||||||
yield block
|
|
||||||
|
|
||||||
|
|
||||||
def yield_fixedblocks(infile, blocksize):
|
|
||||||
'''Generator, yields each block of ``blocksize`` bytes in the input file.
|
|
||||||
|
|
||||||
:param infile: file to read and separate in blocks.
|
|
||||||
:returns: a generator that yields the contents of each block
|
|
||||||
'''
|
|
||||||
|
|
||||||
while True:
|
|
||||||
block = infile.read(blocksize)
|
|
||||||
|
|
||||||
read_bytes = len(block)
|
|
||||||
if read_bytes == 0:
|
|
||||||
break
|
|
||||||
|
|
||||||
yield block
|
|
||||||
|
|
||||||
if read_bytes < blocksize:
|
|
||||||
break
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
"""The Tornado web server and tools."""
|
"""The Tornado web server and tools."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# version is a human-readable version number.
|
# version is a human-readable version number.
|
||||||
|
|
||||||
|
@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
|
||||||
# is zero for an official release, positive for a development branch,
|
# is zero for an official release, positive for a development branch,
|
||||||
# or negative for a release candidate or beta (after the base version
|
# or negative for a release candidate or beta (after the base version
|
||||||
# number has been incremented)
|
# number has been incremented)
|
||||||
version = "4.3"
|
version = "4.5.1"
|
||||||
version_info = (4, 3, 0, 0)
|
version_info = (4, 5, 1, 0)
|
||||||
|
|
|
@ -17,78 +17,69 @@
|
||||||
|
|
||||||
"""Data used by the tornado.locale module."""
|
"""Data used by the tornado.locale module."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
# NOTE: This file is supposed to contain unicode strings, which is
|
|
||||||
# exactly what you'd get with e.g. u"Español" in most python versions.
|
|
||||||
# However, Python 3.2 doesn't support the u"" syntax, so we use a u()
|
|
||||||
# function instead. tornado.util.u cannot be used because it doesn't
|
|
||||||
# support non-ascii characters on python 2.
|
|
||||||
# When we drop support for Python 3.2, we can remove the parens
|
|
||||||
# and make these plain unicode strings.
|
|
||||||
from tornado.escape import to_unicode as u
|
|
||||||
|
|
||||||
LOCALE_NAMES = {
|
LOCALE_NAMES = {
|
||||||
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
|
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
|
||||||
"am_ET": {"name_en": u("Amharic"), "name": u("አማርኛ")},
|
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
|
||||||
"ar_AR": {"name_en": u("Arabic"), "name": u("العربية")},
|
"ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
|
||||||
"bg_BG": {"name_en": u("Bulgarian"), "name": u("Български")},
|
"bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
|
||||||
"bn_IN": {"name_en": u("Bengali"), "name": u("বাংলা")},
|
"bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
|
||||||
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
|
"bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
|
||||||
"ca_ES": {"name_en": u("Catalan"), "name": u("Català")},
|
"ca_ES": {"name_en": u"Catalan", "name": u"Català"},
|
||||||
"cs_CZ": {"name_en": u("Czech"), "name": u("Čeština")},
|
"cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
|
||||||
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
|
"cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
|
||||||
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
|
"da_DK": {"name_en": u"Danish", "name": u"Dansk"},
|
||||||
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
|
"de_DE": {"name_en": u"German", "name": u"Deutsch"},
|
||||||
"el_GR": {"name_en": u("Greek"), "name": u("Ελληνικά")},
|
"el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
|
||||||
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
|
"en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
|
||||||
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
|
"en_US": {"name_en": u"English (US)", "name": u"English (US)"},
|
||||||
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Español (España)")},
|
"es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
|
||||||
"es_LA": {"name_en": u("Spanish"), "name": u("Español")},
|
"es_LA": {"name_en": u"Spanish", "name": u"Español"},
|
||||||
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
|
"et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
|
||||||
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
|
"eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
|
||||||
"fa_IR": {"name_en": u("Persian"), "name": u("فارسی")},
|
"fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
|
||||||
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
|
"fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
|
||||||
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Français (Canada)")},
|
"fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
|
||||||
"fr_FR": {"name_en": u("French"), "name": u("Français")},
|
"fr_FR": {"name_en": u"French", "name": u"Français"},
|
||||||
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
|
"ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
|
||||||
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
|
"gl_ES": {"name_en": u"Galician", "name": u"Galego"},
|
||||||
"he_IL": {"name_en": u("Hebrew"), "name": u("עברית")},
|
"he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
|
||||||
"hi_IN": {"name_en": u("Hindi"), "name": u("हिन्दी")},
|
"hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
|
||||||
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
|
"hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
|
||||||
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
|
"hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
|
||||||
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
|
"id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
|
||||||
"is_IS": {"name_en": u("Icelandic"), "name": u("Íslenska")},
|
"is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
|
||||||
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
|
"it_IT": {"name_en": u"Italian", "name": u"Italiano"},
|
||||||
"ja_JP": {"name_en": u("Japanese"), "name": u("日本語")},
|
"ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
|
||||||
"ko_KR": {"name_en": u("Korean"), "name": u("한국어")},
|
"ko_KR": {"name_en": u"Korean", "name": u"한국어"},
|
||||||
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvių")},
|
"lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
|
||||||
"lv_LV": {"name_en": u("Latvian"), "name": u("Latviešu")},
|
"lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
|
||||||
"mk_MK": {"name_en": u("Macedonian"), "name": u("Македонски")},
|
"mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
|
||||||
"ml_IN": {"name_en": u("Malayalam"), "name": u("മലയാളം")},
|
"ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
|
||||||
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
|
"ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
|
||||||
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokmål)")},
|
"nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
|
||||||
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
|
"nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
|
||||||
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
|
"nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
|
||||||
"pa_IN": {"name_en": u("Punjabi"), "name": u("ਪੰਜਾਬੀ")},
|
"pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
|
||||||
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
|
"pl_PL": {"name_en": u"Polish", "name": u"Polski"},
|
||||||
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Português (Brasil)")},
|
"pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
|
||||||
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Português (Portugal)")},
|
"pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
|
||||||
"ro_RO": {"name_en": u("Romanian"), "name": u("Română")},
|
"ro_RO": {"name_en": u"Romanian", "name": u"Română"},
|
||||||
"ru_RU": {"name_en": u("Russian"), "name": u("Русский")},
|
"ru_RU": {"name_en": u"Russian", "name": u"Русский"},
|
||||||
"sk_SK": {"name_en": u("Slovak"), "name": u("Slovenčina")},
|
"sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
|
||||||
"sl_SI": {"name_en": u("Slovenian"), "name": u("Slovenščina")},
|
"sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
|
||||||
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
|
"sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
|
||||||
"sr_RS": {"name_en": u("Serbian"), "name": u("Српски")},
|
"sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
|
||||||
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
|
"sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
|
||||||
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
|
"sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
|
||||||
"ta_IN": {"name_en": u("Tamil"), "name": u("தமிழ்")},
|
"ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
|
||||||
"te_IN": {"name_en": u("Telugu"), "name": u("తెలుగు")},
|
"te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
|
||||||
"th_TH": {"name_en": u("Thai"), "name": u("ภาษาไทย")},
|
"th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
|
||||||
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
|
"tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
|
||||||
"tr_TR": {"name_en": u("Turkish"), "name": u("Türkçe")},
|
"tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
|
||||||
"uk_UA": {"name_en": u("Ukraini "), "name": u("Українська")},
|
"uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
|
||||||
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Tiếng Việt")},
|
"vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
|
||||||
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("中文(简体)")},
|
"zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
|
||||||
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("中文(繁體)")},
|
"zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
|
||||||
}
|
}
|
||||||
|
|
|
@ -65,7 +65,7 @@ Example usage for Google OAuth:
|
||||||
errors are more consistently reported through the ``Future`` interfaces.
|
errors are more consistently reported through the ``Future`` interfaces.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import binascii
|
import binascii
|
||||||
|
@ -82,22 +82,15 @@ from tornado import escape
|
||||||
from tornado.httputil import url_concat
|
from tornado.httputil import url_concat
|
||||||
from tornado.log import gen_log
|
from tornado.log import gen_log
|
||||||
from tornado.stack_context import ExceptionStackContext
|
from tornado.stack_context import ExceptionStackContext
|
||||||
from tornado.util import u, unicode_type, ArgReplacer
|
from tornado.util import unicode_type, ArgReplacer, PY3
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
import urlparse # py2
|
import urllib.parse as urlparse
|
||||||
except ImportError:
|
import urllib.parse as urllib_parse
|
||||||
import urllib.parse as urlparse # py3
|
long = int
|
||||||
|
else:
|
||||||
try:
|
import urlparse
|
||||||
import urllib.parse as urllib_parse # py3
|
import urllib as urllib_parse
|
||||||
except ImportError:
|
|
||||||
import urllib as urllib_parse # py2
|
|
||||||
|
|
||||||
try:
|
|
||||||
long # py2
|
|
||||||
except NameError:
|
|
||||||
long = int # py3
|
|
||||||
|
|
||||||
|
|
||||||
class AuthError(Exception):
|
class AuthError(Exception):
|
||||||
|
@ -188,7 +181,7 @@ class OpenIdMixin(object):
|
||||||
"""
|
"""
|
||||||
# Verify the OpenID response via direct request to the OP
|
# Verify the OpenID response via direct request to the OP
|
||||||
args = dict((k, v[-1]) for k, v in self.request.arguments.items())
|
args = dict((k, v[-1]) for k, v in self.request.arguments.items())
|
||||||
args["openid.mode"] = u("check_authentication")
|
args["openid.mode"] = u"check_authentication"
|
||||||
url = self._OPENID_ENDPOINT
|
url = self._OPENID_ENDPOINT
|
||||||
if http_client is None:
|
if http_client is None:
|
||||||
http_client = self.get_auth_http_client()
|
http_client = self.get_auth_http_client()
|
||||||
|
@ -255,13 +248,13 @@ class OpenIdMixin(object):
|
||||||
ax_ns = None
|
ax_ns = None
|
||||||
for name in self.request.arguments:
|
for name in self.request.arguments:
|
||||||
if name.startswith("openid.ns.") and \
|
if name.startswith("openid.ns.") and \
|
||||||
self.get_argument(name) == u("http://openid.net/srv/ax/1.0"):
|
self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
|
||||||
ax_ns = name[10:]
|
ax_ns = name[10:]
|
||||||
break
|
break
|
||||||
|
|
||||||
def get_ax_arg(uri):
|
def get_ax_arg(uri):
|
||||||
if not ax_ns:
|
if not ax_ns:
|
||||||
return u("")
|
return u""
|
||||||
prefix = "openid." + ax_ns + ".type."
|
prefix = "openid." + ax_ns + ".type."
|
||||||
ax_name = None
|
ax_name = None
|
||||||
for name in self.request.arguments.keys():
|
for name in self.request.arguments.keys():
|
||||||
|
@ -270,8 +263,8 @@ class OpenIdMixin(object):
|
||||||
ax_name = "openid." + ax_ns + ".value." + part
|
ax_name = "openid." + ax_ns + ".value." + part
|
||||||
break
|
break
|
||||||
if not ax_name:
|
if not ax_name:
|
||||||
return u("")
|
return u""
|
||||||
return self.get_argument(ax_name, u(""))
|
return self.get_argument(ax_name, u"")
|
||||||
|
|
||||||
email = get_ax_arg("http://axschema.org/contact/email")
|
email = get_ax_arg("http://axschema.org/contact/email")
|
||||||
name = get_ax_arg("http://axschema.org/namePerson")
|
name = get_ax_arg("http://axschema.org/namePerson")
|
||||||
|
@ -290,7 +283,7 @@ class OpenIdMixin(object):
|
||||||
if name:
|
if name:
|
||||||
user["name"] = name
|
user["name"] = name
|
||||||
elif name_parts:
|
elif name_parts:
|
||||||
user["name"] = u(" ").join(name_parts)
|
user["name"] = u" ".join(name_parts)
|
||||||
elif email:
|
elif email:
|
||||||
user["name"] = email.split("@")[0]
|
user["name"] = email.split("@")[0]
|
||||||
if email:
|
if email:
|
||||||
|
@ -961,6 +954,20 @@ class FacebookGraphMixin(OAuth2Mixin):
|
||||||
.. testoutput::
|
.. testoutput::
|
||||||
:hide:
|
:hide:
|
||||||
|
|
||||||
|
This method returns a dictionary which may contain the following fields:
|
||||||
|
|
||||||
|
* ``access_token``, a string which may be passed to `facebook_request`
|
||||||
|
* ``session_expires``, an integer encoded as a string representing
|
||||||
|
the time until the access token expires in seconds. This field should
|
||||||
|
be used like ``int(user['session_expires'])``; in a future version of
|
||||||
|
Tornado it will change from a string to an integer.
|
||||||
|
* ``id``, ``name``, ``first_name``, ``last_name``, ``locale``, ``picture``,
|
||||||
|
``link``, plus any fields named in the ``extra_fields`` argument. These
|
||||||
|
fields are copied from the Facebook graph API `user object <https://developers.facebook.com/docs/graph-api/reference/user>`_
|
||||||
|
|
||||||
|
.. versionchanged:: 4.5
|
||||||
|
The ``session_expires`` field was updated to support changes made to the
|
||||||
|
Facebook API in March 2017.
|
||||||
"""
|
"""
|
||||||
http = self.get_auth_http_client()
|
http = self.get_auth_http_client()
|
||||||
args = {
|
args = {
|
||||||
|
@ -985,10 +992,10 @@ class FacebookGraphMixin(OAuth2Mixin):
|
||||||
future.set_exception(AuthError('Facebook auth error: %s' % str(response)))
|
future.set_exception(AuthError('Facebook auth error: %s' % str(response)))
|
||||||
return
|
return
|
||||||
|
|
||||||
args = urlparse.parse_qs(escape.native_str(response.body))
|
args = escape.json_decode(response.body)
|
||||||
session = {
|
session = {
|
||||||
"access_token": args["access_token"][-1],
|
"access_token": args.get("access_token"),
|
||||||
"expires": args.get("expires")
|
"expires_in": args.get("expires_in")
|
||||||
}
|
}
|
||||||
|
|
||||||
self.facebook_request(
|
self.facebook_request(
|
||||||
|
@ -996,6 +1003,9 @@ class FacebookGraphMixin(OAuth2Mixin):
|
||||||
callback=functools.partial(
|
callback=functools.partial(
|
||||||
self._on_get_user_info, future, session, fields),
|
self._on_get_user_info, future, session, fields),
|
||||||
access_token=session["access_token"],
|
access_token=session["access_token"],
|
||||||
|
appsecret_proof=hmac.new(key=client_secret.encode('utf8'),
|
||||||
|
msg=session["access_token"].encode('utf8'),
|
||||||
|
digestmod=hashlib.sha256).hexdigest(),
|
||||||
fields=",".join(fields)
|
fields=",".join(fields)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -1008,7 +1018,12 @@ class FacebookGraphMixin(OAuth2Mixin):
|
||||||
for field in fields:
|
for field in fields:
|
||||||
fieldmap[field] = user.get(field)
|
fieldmap[field] = user.get(field)
|
||||||
|
|
||||||
fieldmap.update({"access_token": session["access_token"], "session_expires": session.get("expires")})
|
# session_expires is converted to str for compatibility with
|
||||||
|
# older versions in which the server used url-encoding and
|
||||||
|
# this code simply returned the string verbatim.
|
||||||
|
# This should change in Tornado 5.0.
|
||||||
|
fieldmap.update({"access_token": session["access_token"],
|
||||||
|
"session_expires": str(session.get("expires_in"))})
|
||||||
future.set_result(fieldmap)
|
future.set_result(fieldmap)
|
||||||
|
|
||||||
@_auth_return_future
|
@_auth_return_future
|
||||||
|
|
|
@ -45,7 +45,7 @@ incorrectly.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
|
@ -83,7 +83,7 @@ if __name__ == "__main__":
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import pkgutil
|
import pkgutil # type: ignore
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import types
|
import types
|
||||||
|
@ -103,16 +103,12 @@ except ImportError:
|
||||||
# os.execv is broken on Windows and can't properly parse command line
|
# os.execv is broken on Windows and can't properly parse command line
|
||||||
# arguments and executable name if they contain whitespaces. subprocess
|
# arguments and executable name if they contain whitespaces. subprocess
|
||||||
# fixes that behavior.
|
# fixes that behavior.
|
||||||
# This distinction is also important because when we use execv, we want to
|
|
||||||
# close the IOLoop and all its file descriptors, to guard against any
|
|
||||||
# file descriptors that were not set CLOEXEC. When execv is not available,
|
|
||||||
# we must not close the IOLoop because we want the process to exit cleanly.
|
|
||||||
_has_execv = sys.platform != 'win32'
|
_has_execv = sys.platform != 'win32'
|
||||||
|
|
||||||
_watched_files = set()
|
_watched_files = set()
|
||||||
_reload_hooks = []
|
_reload_hooks = []
|
||||||
_reload_attempted = False
|
_reload_attempted = False
|
||||||
_io_loops = weakref.WeakKeyDictionary()
|
_io_loops = weakref.WeakKeyDictionary() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
def start(io_loop=None, check_time=500):
|
def start(io_loop=None, check_time=500):
|
||||||
|
@ -127,8 +123,6 @@ def start(io_loop=None, check_time=500):
|
||||||
_io_loops[io_loop] = True
|
_io_loops[io_loop] = True
|
||||||
if len(_io_loops) > 1:
|
if len(_io_loops) > 1:
|
||||||
gen_log.warning("tornado.autoreload started more than once in the same process")
|
gen_log.warning("tornado.autoreload started more than once in the same process")
|
||||||
if _has_execv:
|
|
||||||
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
|
|
||||||
modify_times = {}
|
modify_times = {}
|
||||||
callback = functools.partial(_reload_on_update, modify_times)
|
callback = functools.partial(_reload_on_update, modify_times)
|
||||||
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
|
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
|
||||||
|
@ -249,6 +243,7 @@ def _reload():
|
||||||
# unwind, so just exit uncleanly.
|
# unwind, so just exit uncleanly.
|
||||||
os._exit(0)
|
os._exit(0)
|
||||||
|
|
||||||
|
|
||||||
_USAGE = """\
|
_USAGE = """\
|
||||||
Usage:
|
Usage:
|
||||||
python -m tornado.autoreload -m module.to.run [args...]
|
python -m tornado.autoreload -m module.to.run [args...]
|
||||||
|
|
|
@ -21,7 +21,7 @@ a mostly-compatible `Future` class designed for use from coroutines,
|
||||||
as well as some utility functions for interacting with the
|
as well as some utility functions for interacting with the
|
||||||
`concurrent.futures` package.
|
`concurrent.futures` package.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import platform
|
import platform
|
||||||
|
@ -31,13 +31,18 @@ import sys
|
||||||
|
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.stack_context import ExceptionStackContext, wrap
|
from tornado.stack_context import ExceptionStackContext, wrap
|
||||||
from tornado.util import raise_exc_info, ArgReplacer
|
from tornado.util import raise_exc_info, ArgReplacer, is_finalizing
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from concurrent import futures
|
from concurrent import futures
|
||||||
except ImportError:
|
except ImportError:
|
||||||
futures = None
|
futures = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
except ImportError:
|
||||||
|
typing = None
|
||||||
|
|
||||||
|
|
||||||
# Can the garbage collector handle cycles that include __del__ methods?
|
# Can the garbage collector handle cycles that include __del__ methods?
|
||||||
# This is true in cpython beginning with version 3.4 (PEP 442).
|
# This is true in cpython beginning with version 3.4 (PEP 442).
|
||||||
|
@ -118,8 +123,8 @@ class _TracebackLogger(object):
|
||||||
self.exc_info = None
|
self.exc_info = None
|
||||||
self.formatted_tb = None
|
self.formatted_tb = None
|
||||||
|
|
||||||
def __del__(self):
|
def __del__(self, is_finalizing=is_finalizing):
|
||||||
if self.formatted_tb:
|
if not is_finalizing() and self.formatted_tb:
|
||||||
app_log.error('Future exception was never retrieved: %s',
|
app_log.error('Future exception was never retrieved: %s',
|
||||||
''.join(self.formatted_tb).rstrip())
|
''.join(self.formatted_tb).rstrip())
|
||||||
|
|
||||||
|
@ -229,7 +234,10 @@ class Future(object):
|
||||||
if self._result is not None:
|
if self._result is not None:
|
||||||
return self._result
|
return self._result
|
||||||
if self._exc_info is not None:
|
if self._exc_info is not None:
|
||||||
raise_exc_info(self._exc_info)
|
try:
|
||||||
|
raise_exc_info(self._exc_info)
|
||||||
|
finally:
|
||||||
|
self = None
|
||||||
self._check_done()
|
self._check_done()
|
||||||
return self._result
|
return self._result
|
||||||
|
|
||||||
|
@ -324,8 +332,8 @@ class Future(object):
|
||||||
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
|
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
|
||||||
# the PEP 442.
|
# the PEP 442.
|
||||||
if _GC_CYCLE_FINALIZERS:
|
if _GC_CYCLE_FINALIZERS:
|
||||||
def __del__(self):
|
def __del__(self, is_finalizing=is_finalizing):
|
||||||
if not self._log_traceback:
|
if is_finalizing() or not self._log_traceback:
|
||||||
# set_exception() was not called, or result() or exception()
|
# set_exception() was not called, or result() or exception()
|
||||||
# has consumed the exception
|
# has consumed the exception
|
||||||
return
|
return
|
||||||
|
@ -335,10 +343,11 @@ class Future(object):
|
||||||
app_log.error('Future %r exception was never retrieved: %s',
|
app_log.error('Future %r exception was never retrieved: %s',
|
||||||
self, ''.join(tb).rstrip())
|
self, ''.join(tb).rstrip())
|
||||||
|
|
||||||
|
|
||||||
TracebackFuture = Future
|
TracebackFuture = Future
|
||||||
|
|
||||||
if futures is None:
|
if futures is None:
|
||||||
FUTURES = Future
|
FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]]
|
||||||
else:
|
else:
|
||||||
FUTURES = (futures.Future, Future)
|
FUTURES = (futures.Future, Future)
|
||||||
|
|
||||||
|
@ -359,6 +368,7 @@ class DummyExecutor(object):
|
||||||
def shutdown(self, wait=True):
|
def shutdown(self, wait=True):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
dummy_executor = DummyExecutor()
|
dummy_executor = DummyExecutor()
|
||||||
|
|
||||||
|
|
||||||
|
@ -500,8 +510,9 @@ def chain_future(a, b):
|
||||||
assert future is a
|
assert future is a
|
||||||
if b.done():
|
if b.done():
|
||||||
return
|
return
|
||||||
if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
|
if (isinstance(a, TracebackFuture) and
|
||||||
and a.exc_info() is not None):
|
isinstance(b, TracebackFuture) and
|
||||||
|
a.exc_info() is not None):
|
||||||
b.set_exc_info(a.exc_info())
|
b.set_exc_info(a.exc_info())
|
||||||
elif a.exception() is not None:
|
elif a.exception() is not None:
|
||||||
b.set_exception(a.exception())
|
b.set_exception(a.exception())
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
|
|
||||||
"""Non-blocking HTTP client implementation using pycurl."""
|
"""Non-blocking HTTP client implementation using pycurl."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
import logging
|
import logging
|
||||||
import pycurl
|
import pycurl # type: ignore
|
||||||
import threading
|
import threading
|
||||||
import time
|
import time
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
@ -221,6 +221,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||||
# _process_queue() is called from
|
# _process_queue() is called from
|
||||||
# _finish_pending_requests the exceptions have
|
# _finish_pending_requests the exceptions have
|
||||||
# nowhere to go.
|
# nowhere to go.
|
||||||
|
self._free_list.append(curl)
|
||||||
callback(HTTPResponse(
|
callback(HTTPResponse(
|
||||||
request=request,
|
request=request,
|
||||||
code=599,
|
code=599,
|
||||||
|
@ -277,6 +278,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||||
if curl_log.isEnabledFor(logging.DEBUG):
|
if curl_log.isEnabledFor(logging.DEBUG):
|
||||||
curl.setopt(pycurl.VERBOSE, 1)
|
curl.setopt(pycurl.VERBOSE, 1)
|
||||||
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
|
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
|
||||||
|
if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
|
||||||
|
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
|
||||||
|
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
|
||||||
return curl
|
return curl
|
||||||
|
|
||||||
def _curl_setup_request(self, curl, request, buffer, headers):
|
def _curl_setup_request(self, curl, request, buffer, headers):
|
||||||
|
@ -341,6 +345,15 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||||
credentials = '%s:%s' % (request.proxy_username,
|
credentials = '%s:%s' % (request.proxy_username,
|
||||||
request.proxy_password)
|
request.proxy_password)
|
||||||
curl.setopt(pycurl.PROXYUSERPWD, credentials)
|
curl.setopt(pycurl.PROXYUSERPWD, credentials)
|
||||||
|
|
||||||
|
if (request.proxy_auth_mode is None or
|
||||||
|
request.proxy_auth_mode == "basic"):
|
||||||
|
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
|
||||||
|
elif request.proxy_auth_mode == "digest":
|
||||||
|
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
|
||||||
else:
|
else:
|
||||||
curl.setopt(pycurl.PROXY, '')
|
curl.setopt(pycurl.PROXY, '')
|
||||||
curl.unsetopt(pycurl.PROXYUSERPWD)
|
curl.unsetopt(pycurl.PROXYUSERPWD)
|
||||||
|
@ -461,7 +474,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
|
||||||
request.prepare_curl_callback(curl)
|
request.prepare_curl_callback(curl)
|
||||||
|
|
||||||
def _curl_header_callback(self, headers, header_callback, header_line):
|
def _curl_header_callback(self, headers, header_callback, header_line):
|
||||||
header_line = native_str(header_line)
|
header_line = native_str(header_line.decode('latin1'))
|
||||||
if header_callback is not None:
|
if header_callback is not None:
|
||||||
self.io_loop.add_callback(header_callback, header_line)
|
self.io_loop.add_callback(header_callback, header_line)
|
||||||
# header_line as returned by curl includes the end-of-line characters.
|
# header_line as returned by curl includes the end-of-line characters.
|
||||||
|
|
|
@ -20,34 +20,28 @@ Also includes a few other miscellaneous string manipulation functions that
|
||||||
have crept in over time.
|
have crept in over time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import re
|
|
||||||
import sys
|
|
||||||
|
|
||||||
from tornado.util import unicode_type, basestring_type, u
|
|
||||||
|
|
||||||
try:
|
|
||||||
from urllib.parse import parse_qs as _parse_qs # py3
|
|
||||||
except ImportError:
|
|
||||||
from urlparse import parse_qs as _parse_qs # Python 2.6+
|
|
||||||
|
|
||||||
try:
|
|
||||||
import htmlentitydefs # py2
|
|
||||||
except ImportError:
|
|
||||||
import html.entities as htmlentitydefs # py3
|
|
||||||
|
|
||||||
try:
|
|
||||||
import urllib.parse as urllib_parse # py3
|
|
||||||
except ImportError:
|
|
||||||
import urllib as urllib_parse # py2
|
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import re
|
||||||
|
|
||||||
|
from tornado.util import PY3, unicode_type, basestring_type
|
||||||
|
|
||||||
|
if PY3:
|
||||||
|
from urllib.parse import parse_qs as _parse_qs
|
||||||
|
import html.entities as htmlentitydefs
|
||||||
|
import urllib.parse as urllib_parse
|
||||||
|
unichr = chr
|
||||||
|
else:
|
||||||
|
from urlparse import parse_qs as _parse_qs
|
||||||
|
import htmlentitydefs
|
||||||
|
import urllib as urllib_parse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
unichr
|
import typing # noqa
|
||||||
except NameError:
|
except ImportError:
|
||||||
unichr = chr
|
pass
|
||||||
|
|
||||||
|
|
||||||
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
|
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
|
||||||
_XHTML_ESCAPE_DICT = {'&': '&', '<': '<', '>': '>', '"': '"',
|
_XHTML_ESCAPE_DICT = {'&': '&', '<': '<', '>': '>', '"': '"',
|
||||||
|
@ -116,7 +110,7 @@ def url_escape(value, plus=True):
|
||||||
# python 3 changed things around enough that we need two separate
|
# python 3 changed things around enough that we need two separate
|
||||||
# implementations of url_unescape. We also need our own implementation
|
# implementations of url_unescape. We also need our own implementation
|
||||||
# of parse_qs since python 3's version insists on decoding everything.
|
# of parse_qs since python 3's version insists on decoding everything.
|
||||||
if sys.version_info[0] < 3:
|
if not PY3:
|
||||||
def url_unescape(value, encoding='utf-8', plus=True):
|
def url_unescape(value, encoding='utf-8', plus=True):
|
||||||
"""Decodes the given value from a URL.
|
"""Decodes the given value from a URL.
|
||||||
|
|
||||||
|
@ -191,6 +185,7 @@ _UTF8_TYPES = (bytes, type(None))
|
||||||
|
|
||||||
|
|
||||||
def utf8(value):
|
def utf8(value):
|
||||||
|
# type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
|
||||||
"""Converts a string argument to a byte string.
|
"""Converts a string argument to a byte string.
|
||||||
|
|
||||||
If the argument is already a byte string or None, it is returned unchanged.
|
If the argument is already a byte string or None, it is returned unchanged.
|
||||||
|
@ -204,6 +199,7 @@ def utf8(value):
|
||||||
)
|
)
|
||||||
return value.encode("utf-8")
|
return value.encode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
_TO_UNICODE_TYPES = (unicode_type, type(None))
|
_TO_UNICODE_TYPES = (unicode_type, type(None))
|
||||||
|
|
||||||
|
|
||||||
|
@ -221,6 +217,7 @@ def to_unicode(value):
|
||||||
)
|
)
|
||||||
return value.decode("utf-8")
|
return value.decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
# to_unicode was previously named _unicode not because it was private,
|
# to_unicode was previously named _unicode not because it was private,
|
||||||
# but to avoid conflicts with the built-in unicode() function/type
|
# but to avoid conflicts with the built-in unicode() function/type
|
||||||
_unicode = to_unicode
|
_unicode = to_unicode
|
||||||
|
@ -269,6 +266,7 @@ def recursive_unicode(obj):
|
||||||
else:
|
else:
|
||||||
return obj
|
return obj
|
||||||
|
|
||||||
|
|
||||||
# I originally used the regex from
|
# I originally used the regex from
|
||||||
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
|
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
|
||||||
# but it gets all exponential on certain patterns (such as too many trailing
|
# but it gets all exponential on certain patterns (such as too many trailing
|
||||||
|
@ -366,7 +364,7 @@ def linkify(text, shorten=False, extra_params="",
|
||||||
# have a status bar, such as Safari by default)
|
# have a status bar, such as Safari by default)
|
||||||
params += ' title="%s"' % href
|
params += ' title="%s"' % href
|
||||||
|
|
||||||
return u('<a href="%s"%s>%s</a>') % (href, params, url)
|
return u'<a href="%s"%s>%s</a>' % (href, params, url)
|
||||||
|
|
||||||
# First HTML-escape so that our strings are all safe.
|
# First HTML-escape so that our strings are all safe.
|
||||||
# The regex is modified to avoid character entites other than & so
|
# The regex is modified to avoid character entites other than & so
|
||||||
|
@ -396,4 +394,5 @@ def _build_unicode_map():
|
||||||
unicode_map[name] = unichr(value)
|
unicode_map[name] = unichr(value)
|
||||||
return unicode_map
|
return unicode_map
|
||||||
|
|
||||||
|
|
||||||
_HTML_UNICODE_MAP = _build_unicode_map()
|
_HTML_UNICODE_MAP = _build_unicode_map()
|
||||||
|
|
|
@ -74,7 +74,7 @@ See the `convert_yielded` function to extend this mechanism.
|
||||||
via ``singledispatch``.
|
via ``singledispatch``.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import functools
|
import functools
|
||||||
|
@ -83,16 +83,18 @@ import os
|
||||||
import sys
|
import sys
|
||||||
import textwrap
|
import textwrap
|
||||||
import types
|
import types
|
||||||
|
import weakref
|
||||||
|
|
||||||
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
|
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado import stack_context
|
from tornado import stack_context
|
||||||
from tornado.util import raise_exc_info
|
from tornado.util import PY3, raise_exc_info
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
from functools import singledispatch # py34+
|
# py34+
|
||||||
|
from functools import singledispatch # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from singledispatch import singledispatch # backport
|
from singledispatch import singledispatch # backport
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -108,12 +110,14 @@ except ImportError:
|
||||||
|
|
||||||
try:
|
try:
|
||||||
try:
|
try:
|
||||||
from collections.abc import Generator as GeneratorType # py35+
|
# py35+
|
||||||
|
from collections.abc import Generator as GeneratorType # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from backports_abc import Generator as GeneratorType
|
from backports_abc import Generator as GeneratorType # type: ignore
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from inspect import isawaitable # py35+
|
# py35+
|
||||||
|
from inspect import isawaitable # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
from backports_abc import isawaitable
|
from backports_abc import isawaitable
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
@ -121,12 +125,12 @@ except ImportError:
|
||||||
raise
|
raise
|
||||||
from types import GeneratorType
|
from types import GeneratorType
|
||||||
|
|
||||||
def isawaitable(x):
|
def isawaitable(x): # type: ignore
|
||||||
return False
|
return False
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
import builtins # py3
|
import builtins
|
||||||
except ImportError:
|
else:
|
||||||
import __builtin__ as builtins
|
import __builtin__ as builtins
|
||||||
|
|
||||||
|
|
||||||
|
@ -242,6 +246,26 @@ def coroutine(func, replace_callback=True):
|
||||||
return _make_coroutine_wrapper(func, replace_callback=True)
|
return _make_coroutine_wrapper(func, replace_callback=True)
|
||||||
|
|
||||||
|
|
||||||
|
# Ties lifetime of runners to their result futures. Github Issue #1769
|
||||||
|
# Generators, like any object in Python, must be strong referenced
|
||||||
|
# in order to not be cleaned up by the garbage collector. When using
|
||||||
|
# coroutines, the Runner object is what strong-refs the inner
|
||||||
|
# generator. However, the only item that strong-reffed the Runner
|
||||||
|
# was the last Future that the inner generator yielded (via the
|
||||||
|
# Future's internal done_callback list). Usually this is enough, but
|
||||||
|
# it is also possible for this Future to not have any strong references
|
||||||
|
# other than other objects referenced by the Runner object (usually
|
||||||
|
# when using other callback patterns and/or weakrefs). In this
|
||||||
|
# situation, if a garbage collection ran, a cycle would be detected and
|
||||||
|
# Runner objects could be destroyed along with their inner generators
|
||||||
|
# and everything in their local scope.
|
||||||
|
# This map provides strong references to Runner objects as long as
|
||||||
|
# their result future objects also have strong references (typically
|
||||||
|
# from the parent coroutine's Runner). This keeps the coroutine's
|
||||||
|
# Runner alive.
|
||||||
|
_futures_to_runners = weakref.WeakKeyDictionary()
|
||||||
|
|
||||||
|
|
||||||
def _make_coroutine_wrapper(func, replace_callback):
|
def _make_coroutine_wrapper(func, replace_callback):
|
||||||
"""The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
|
"""The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
|
||||||
|
|
||||||
|
@ -251,10 +275,11 @@ def _make_coroutine_wrapper(func, replace_callback):
|
||||||
"""
|
"""
|
||||||
# On Python 3.5, set the coroutine flag on our generator, to allow it
|
# On Python 3.5, set the coroutine flag on our generator, to allow it
|
||||||
# to be used with 'await'.
|
# to be used with 'await'.
|
||||||
|
wrapped = func
|
||||||
if hasattr(types, 'coroutine'):
|
if hasattr(types, 'coroutine'):
|
||||||
func = types.coroutine(func)
|
func = types.coroutine(func)
|
||||||
|
|
||||||
@functools.wraps(func)
|
@functools.wraps(wrapped)
|
||||||
def wrapper(*args, **kwargs):
|
def wrapper(*args, **kwargs):
|
||||||
future = TracebackFuture()
|
future = TracebackFuture()
|
||||||
|
|
||||||
|
@ -291,7 +316,8 @@ def _make_coroutine_wrapper(func, replace_callback):
|
||||||
except Exception:
|
except Exception:
|
||||||
future.set_exc_info(sys.exc_info())
|
future.set_exc_info(sys.exc_info())
|
||||||
else:
|
else:
|
||||||
Runner(result, future, yielded)
|
_futures_to_runners[future] = Runner(result, future, yielded)
|
||||||
|
yielded = None
|
||||||
try:
|
try:
|
||||||
return future
|
return future
|
||||||
finally:
|
finally:
|
||||||
|
@ -306,9 +332,21 @@ def _make_coroutine_wrapper(func, replace_callback):
|
||||||
future = None
|
future = None
|
||||||
future.set_result(result)
|
future.set_result(result)
|
||||||
return future
|
return future
|
||||||
|
|
||||||
|
wrapper.__wrapped__ = wrapped
|
||||||
|
wrapper.__tornado_coroutine__ = True
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
|
def is_coroutine_function(func):
|
||||||
|
"""Return whether *func* is a coroutine function, i.e. a function
|
||||||
|
wrapped with `~.gen.coroutine`.
|
||||||
|
|
||||||
|
.. versionadded:: 4.5
|
||||||
|
"""
|
||||||
|
return getattr(func, '__tornado_coroutine__', False)
|
||||||
|
|
||||||
|
|
||||||
class Return(Exception):
|
class Return(Exception):
|
||||||
"""Special exception to return a value from a `coroutine`.
|
"""Special exception to return a value from a `coroutine`.
|
||||||
|
|
||||||
|
@ -682,6 +720,7 @@ def multi(children, quiet_exceptions=()):
|
||||||
else:
|
else:
|
||||||
return multi_future(children, quiet_exceptions=quiet_exceptions)
|
return multi_future(children, quiet_exceptions=quiet_exceptions)
|
||||||
|
|
||||||
|
|
||||||
Multi = multi
|
Multi = multi
|
||||||
|
|
||||||
|
|
||||||
|
@ -830,7 +869,7 @@ def maybe_future(x):
|
||||||
|
|
||||||
|
|
||||||
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
|
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
|
||||||
"""Wraps a `.Future` in a timeout.
|
"""Wraps a `.Future` (or other yieldable object) in a timeout.
|
||||||
|
|
||||||
Raises `TimeoutError` if the input future does not complete before
|
Raises `TimeoutError` if the input future does not complete before
|
||||||
``timeout``, which may be specified in any form allowed by
|
``timeout``, which may be specified in any form allowed by
|
||||||
|
@ -841,15 +880,18 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
|
||||||
will be logged unless it is of a type contained in ``quiet_exceptions``
|
will be logged unless it is of a type contained in ``quiet_exceptions``
|
||||||
(which may be an exception type or a sequence of types).
|
(which may be an exception type or a sequence of types).
|
||||||
|
|
||||||
Currently only supports Futures, not other `YieldPoint` classes.
|
Does not support `YieldPoint` subclasses.
|
||||||
|
|
||||||
.. versionadded:: 4.0
|
.. versionadded:: 4.0
|
||||||
|
|
||||||
.. versionchanged:: 4.1
|
.. versionchanged:: 4.1
|
||||||
Added the ``quiet_exceptions`` argument and the logging of unhandled
|
Added the ``quiet_exceptions`` argument and the logging of unhandled
|
||||||
exceptions.
|
exceptions.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.4
|
||||||
|
Added support for yieldable objects other than `.Future`.
|
||||||
"""
|
"""
|
||||||
# TODO: allow yield points in addition to futures?
|
# TODO: allow YieldPoints in addition to other yieldables?
|
||||||
# Tricky to do with stack_context semantics.
|
# Tricky to do with stack_context semantics.
|
||||||
#
|
#
|
||||||
# It's tempting to optimize this by cancelling the input future on timeout
|
# It's tempting to optimize this by cancelling the input future on timeout
|
||||||
|
@ -857,6 +899,7 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
|
||||||
# one waiting on the input future, so cancelling it might disrupt other
|
# one waiting on the input future, so cancelling it might disrupt other
|
||||||
# callers and B) concurrent futures can only be cancelled while they are
|
# callers and B) concurrent futures can only be cancelled while they are
|
||||||
# in the queue, so cancellation cannot reliably bound our waiting time.
|
# in the queue, so cancellation cannot reliably bound our waiting time.
|
||||||
|
future = convert_yielded(future)
|
||||||
result = Future()
|
result = Future()
|
||||||
chain_future(future, result)
|
chain_future(future, result)
|
||||||
if io_loop is None:
|
if io_loop is None:
|
||||||
|
@ -923,6 +966,9 @@ coroutines that are likely to yield Futures that are ready instantly.
|
||||||
Usage: ``yield gen.moment``
|
Usage: ``yield gen.moment``
|
||||||
|
|
||||||
.. versionadded:: 4.0
|
.. versionadded:: 4.0
|
||||||
|
|
||||||
|
.. deprecated:: 4.5
|
||||||
|
``yield None`` is now equivalent to ``yield gen.moment``.
|
||||||
"""
|
"""
|
||||||
moment.set_result(None)
|
moment.set_result(None)
|
||||||
|
|
||||||
|
@ -953,6 +999,7 @@ class Runner(object):
|
||||||
# of the coroutine.
|
# of the coroutine.
|
||||||
self.stack_context_deactivate = None
|
self.stack_context_deactivate = None
|
||||||
if self.handle_yield(first_yielded):
|
if self.handle_yield(first_yielded):
|
||||||
|
gen = result_future = first_yielded = None
|
||||||
self.run()
|
self.run()
|
||||||
|
|
||||||
def register_callback(self, key):
|
def register_callback(self, key):
|
||||||
|
@ -1009,10 +1056,15 @@ class Runner(object):
|
||||||
except Exception:
|
except Exception:
|
||||||
self.had_exception = True
|
self.had_exception = True
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
|
future = None
|
||||||
|
|
||||||
if exc_info is not None:
|
if exc_info is not None:
|
||||||
yielded = self.gen.throw(*exc_info)
|
try:
|
||||||
exc_info = None
|
yielded = self.gen.throw(*exc_info)
|
||||||
|
finally:
|
||||||
|
# Break up a reference to itself
|
||||||
|
# for faster GC on CPython.
|
||||||
|
exc_info = None
|
||||||
else:
|
else:
|
||||||
yielded = self.gen.send(value)
|
yielded = self.gen.send(value)
|
||||||
|
|
||||||
|
@ -1045,6 +1097,7 @@ class Runner(object):
|
||||||
return
|
return
|
||||||
if not self.handle_yield(yielded):
|
if not self.handle_yield(yielded):
|
||||||
return
|
return
|
||||||
|
yielded = None
|
||||||
finally:
|
finally:
|
||||||
self.running = False
|
self.running = False
|
||||||
|
|
||||||
|
@ -1093,8 +1146,12 @@ class Runner(object):
|
||||||
self.future.set_exc_info(sys.exc_info())
|
self.future.set_exc_info(sys.exc_info())
|
||||||
|
|
||||||
if not self.future.done() or self.future is moment:
|
if not self.future.done() or self.future is moment:
|
||||||
|
def inner(f):
|
||||||
|
# Break a reference cycle to speed GC.
|
||||||
|
f = None # noqa
|
||||||
|
self.run()
|
||||||
self.io_loop.add_future(
|
self.io_loop.add_future(
|
||||||
self.future, lambda f: self.run())
|
self.future, inner)
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
@ -1116,6 +1173,7 @@ class Runner(object):
|
||||||
self.stack_context_deactivate()
|
self.stack_context_deactivate()
|
||||||
self.stack_context_deactivate = None
|
self.stack_context_deactivate = None
|
||||||
|
|
||||||
|
|
||||||
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
|
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
|
||||||
|
|
||||||
|
|
||||||
|
@ -1135,6 +1193,7 @@ def _argument_adapter(callback):
|
||||||
callback(None)
|
callback(None)
|
||||||
return wrapper
|
return wrapper
|
||||||
|
|
||||||
|
|
||||||
# Convert Awaitables into Futures. It is unfortunately possible
|
# Convert Awaitables into Futures. It is unfortunately possible
|
||||||
# to have infinite recursion here if those Awaitables assume that
|
# to have infinite recursion here if those Awaitables assume that
|
||||||
# we're using a different coroutine runner and yield objects
|
# we're using a different coroutine runner and yield objects
|
||||||
|
@ -1212,7 +1271,9 @@ def convert_yielded(yielded):
|
||||||
.. versionadded:: 4.1
|
.. versionadded:: 4.1
|
||||||
"""
|
"""
|
||||||
# Lists and dicts containing YieldPoints were handled earlier.
|
# Lists and dicts containing YieldPoints were handled earlier.
|
||||||
if isinstance(yielded, (list, dict)):
|
if yielded is None:
|
||||||
|
return moment
|
||||||
|
elif isinstance(yielded, (list, dict)):
|
||||||
return multi(yielded)
|
return multi(yielded)
|
||||||
elif is_future(yielded):
|
elif is_future(yielded):
|
||||||
return yielded
|
return yielded
|
||||||
|
@ -1221,6 +1282,7 @@ def convert_yielded(yielded):
|
||||||
else:
|
else:
|
||||||
raise BadYieldError("yielded unknown object %r" % (yielded,))
|
raise BadYieldError("yielded unknown object %r" % (yielded,))
|
||||||
|
|
||||||
|
|
||||||
if singledispatch is not None:
|
if singledispatch is not None:
|
||||||
convert_yielded = singledispatch(convert_yielded)
|
convert_yielded = singledispatch(convert_yielded)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
.. versionadded:: 4.0
|
.. versionadded:: 4.0
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
|
@ -30,7 +30,7 @@ from tornado import httputil
|
||||||
from tornado import iostream
|
from tornado import iostream
|
||||||
from tornado.log import gen_log, app_log
|
from tornado.log import gen_log, app_log
|
||||||
from tornado import stack_context
|
from tornado import stack_context
|
||||||
from tornado.util import GzipDecompressor
|
from tornado.util import GzipDecompressor, PY3
|
||||||
|
|
||||||
|
|
||||||
class _QuietException(Exception):
|
class _QuietException(Exception):
|
||||||
|
@ -257,6 +257,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
if need_delegate_close:
|
if need_delegate_close:
|
||||||
with _ExceptionLoggingContext(app_log):
|
with _ExceptionLoggingContext(app_log):
|
||||||
delegate.on_connection_close()
|
delegate.on_connection_close()
|
||||||
|
header_future = None
|
||||||
self._clear_callbacks()
|
self._clear_callbacks()
|
||||||
raise gen.Return(True)
|
raise gen.Return(True)
|
||||||
|
|
||||||
|
@ -342,7 +343,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
'Transfer-Encoding' not in headers)
|
'Transfer-Encoding' not in headers)
|
||||||
else:
|
else:
|
||||||
self._response_start_line = start_line
|
self._response_start_line = start_line
|
||||||
lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
|
lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
|
||||||
self._chunking_output = (
|
self._chunking_output = (
|
||||||
# TODO: should this use
|
# TODO: should this use
|
||||||
# self._request_start_line.version or
|
# self._request_start_line.version or
|
||||||
|
@ -351,7 +352,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
# 304 responses have no body (not even a zero-length body), and so
|
# 304 responses have no body (not even a zero-length body), and so
|
||||||
# should not have either Content-Length or Transfer-Encoding.
|
# should not have either Content-Length or Transfer-Encoding.
|
||||||
# headers.
|
# headers.
|
||||||
start_line.code != 304 and
|
start_line.code not in (204, 304) and
|
||||||
# No need to chunk the output if a Content-Length is specified.
|
# No need to chunk the output if a Content-Length is specified.
|
||||||
'Content-Length' not in headers and
|
'Content-Length' not in headers and
|
||||||
# Applications are discouraged from touching Transfer-Encoding,
|
# Applications are discouraged from touching Transfer-Encoding,
|
||||||
|
@ -359,8 +360,8 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
'Transfer-Encoding' not in headers)
|
'Transfer-Encoding' not in headers)
|
||||||
# If a 1.0 client asked for keep-alive, add the header.
|
# If a 1.0 client asked for keep-alive, add the header.
|
||||||
if (self._request_start_line.version == 'HTTP/1.0' and
|
if (self._request_start_line.version == 'HTTP/1.0' and
|
||||||
(self._request_headers.get('Connection', '').lower()
|
(self._request_headers.get('Connection', '').lower() ==
|
||||||
== 'keep-alive')):
|
'keep-alive')):
|
||||||
headers['Connection'] = 'Keep-Alive'
|
headers['Connection'] = 'Keep-Alive'
|
||||||
if self._chunking_output:
|
if self._chunking_output:
|
||||||
headers['Transfer-Encoding'] = 'chunked'
|
headers['Transfer-Encoding'] = 'chunked'
|
||||||
|
@ -372,7 +373,14 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
self._expected_content_remaining = int(headers['Content-Length'])
|
self._expected_content_remaining = int(headers['Content-Length'])
|
||||||
else:
|
else:
|
||||||
self._expected_content_remaining = None
|
self._expected_content_remaining = None
|
||||||
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
|
# TODO: headers are supposed to be of type str, but we still have some
|
||||||
|
# cases that let bytes slip through. Remove these native_str calls when those
|
||||||
|
# are fixed.
|
||||||
|
header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
|
||||||
|
if PY3:
|
||||||
|
lines.extend(l.encode('latin1') for l in header_lines)
|
||||||
|
else:
|
||||||
|
lines.extend(header_lines)
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if b'\n' in line:
|
if b'\n' in line:
|
||||||
raise ValueError('Newline in header: ' + repr(line))
|
raise ValueError('Newline in header: ' + repr(line))
|
||||||
|
@ -479,9 +487,11 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
connection_header = connection_header.lower()
|
connection_header = connection_header.lower()
|
||||||
if start_line.version == "HTTP/1.1":
|
if start_line.version == "HTTP/1.1":
|
||||||
return connection_header != "close"
|
return connection_header != "close"
|
||||||
elif ("Content-Length" in headers
|
elif ("Content-Length" in headers or
|
||||||
or headers.get("Transfer-Encoding", "").lower() == "chunked"
|
headers.get("Transfer-Encoding", "").lower() == "chunked" or
|
||||||
or start_line.method in ("HEAD", "GET")):
|
getattr(start_line, 'method', None) in ("HEAD", "GET")):
|
||||||
|
# start_line may be a request or response start line; only
|
||||||
|
# the former has a method attribute.
|
||||||
return connection_header == "keep-alive"
|
return connection_header == "keep-alive"
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -531,7 +541,13 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
"Multiple unequal Content-Lengths: %r" %
|
"Multiple unequal Content-Lengths: %r" %
|
||||||
headers["Content-Length"])
|
headers["Content-Length"])
|
||||||
headers["Content-Length"] = pieces[0]
|
headers["Content-Length"] = pieces[0]
|
||||||
content_length = int(headers["Content-Length"])
|
|
||||||
|
try:
|
||||||
|
content_length = int(headers["Content-Length"])
|
||||||
|
except ValueError:
|
||||||
|
# Handles non-integer Content-Length value.
|
||||||
|
raise httputil.HTTPInputError(
|
||||||
|
"Only integer Content-Length is allowed: %s" % headers["Content-Length"])
|
||||||
|
|
||||||
if content_length > self._max_body_size:
|
if content_length > self._max_body_size:
|
||||||
raise httputil.HTTPInputError("Content-Length too long")
|
raise httputil.HTTPInputError("Content-Length too long")
|
||||||
|
@ -550,7 +566,7 @@ class HTTP1Connection(httputil.HTTPConnection):
|
||||||
|
|
||||||
if content_length is not None:
|
if content_length is not None:
|
||||||
return self._read_fixed_body(content_length, delegate)
|
return self._read_fixed_body(content_length, delegate)
|
||||||
if headers.get("Transfer-Encoding") == "chunked":
|
if headers.get("Transfer-Encoding", "").lower() == "chunked":
|
||||||
return self._read_chunked_body(delegate)
|
return self._read_chunked_body(delegate)
|
||||||
if self.is_client:
|
if self.is_client:
|
||||||
return self._read_body_until_close(delegate)
|
return self._read_body_until_close(delegate)
|
||||||
|
|
|
@ -25,7 +25,7 @@ to switch to ``curl_httpclient`` for reasons such as the following:
|
||||||
Note that if you are using ``curl_httpclient``, it is highly
|
Note that if you are using ``curl_httpclient``, it is highly
|
||||||
recommended that you use a recent version of ``libcurl`` and
|
recommended that you use a recent version of ``libcurl`` and
|
||||||
``pycurl``. Currently the minimum supported version of libcurl is
|
``pycurl``. Currently the minimum supported version of libcurl is
|
||||||
7.21.1, and the minimum version of pycurl is 7.18.2. It is highly
|
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
|
||||||
recommended that your ``libcurl`` installation is built with
|
recommended that your ``libcurl`` installation is built with
|
||||||
asynchronous DNS resolver (threaded or c-ares), otherwise you may
|
asynchronous DNS resolver (threaded or c-ares), otherwise you may
|
||||||
encounter various problems with request timeouts (for more
|
encounter various problems with request timeouts (for more
|
||||||
|
@ -38,7 +38,7 @@ To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
|
||||||
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
|
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import time
|
import time
|
||||||
|
@ -61,7 +61,7 @@ class HTTPClient(object):
|
||||||
http_client = httpclient.HTTPClient()
|
http_client = httpclient.HTTPClient()
|
||||||
try:
|
try:
|
||||||
response = http_client.fetch("http://www.google.com/")
|
response = http_client.fetch("http://www.google.com/")
|
||||||
print response.body
|
print(response.body)
|
||||||
except httpclient.HTTPError as e:
|
except httpclient.HTTPError as e:
|
||||||
# HTTPError is raised for non-200 responses; the response
|
# HTTPError is raised for non-200 responses; the response
|
||||||
# can be found in e.response.
|
# can be found in e.response.
|
||||||
|
@ -108,14 +108,14 @@ class AsyncHTTPClient(Configurable):
|
||||||
|
|
||||||
Example usage::
|
Example usage::
|
||||||
|
|
||||||
def handle_request(response):
|
def handle_response(response):
|
||||||
if response.error:
|
if response.error:
|
||||||
print "Error:", response.error
|
print("Error: %s" % response.error)
|
||||||
else:
|
else:
|
||||||
print response.body
|
print(response.body)
|
||||||
|
|
||||||
http_client = AsyncHTTPClient()
|
http_client = AsyncHTTPClient()
|
||||||
http_client.fetch("http://www.google.com/", handle_request)
|
http_client.fetch("http://www.google.com/", handle_response)
|
||||||
|
|
||||||
The constructor for this class is magic in several respects: It
|
The constructor for this class is magic in several respects: It
|
||||||
actually creates an instance of an implementation-specific
|
actually creates an instance of an implementation-specific
|
||||||
|
@ -211,10 +211,12 @@ class AsyncHTTPClient(Configurable):
|
||||||
kwargs: ``HTTPRequest(request, **kwargs)``
|
kwargs: ``HTTPRequest(request, **kwargs)``
|
||||||
|
|
||||||
This method returns a `.Future` whose result is an
|
This method returns a `.Future` whose result is an
|
||||||
`HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
|
`HTTPResponse`. By default, the ``Future`` will raise an
|
||||||
if the request returned a non-200 response code. Instead, if
|
`HTTPError` if the request returned a non-200 response code
|
||||||
``raise_error`` is set to False, the response will always be
|
(other errors may also be raised if the server could not be
|
||||||
returned regardless of the response code.
|
contacted). Instead, if ``raise_error`` is set to False, the
|
||||||
|
response will always be returned regardless of the response
|
||||||
|
code.
|
||||||
|
|
||||||
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
|
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
|
||||||
In the callback interface, `HTTPError` is not automatically raised.
|
In the callback interface, `HTTPError` is not automatically raised.
|
||||||
|
@ -225,6 +227,9 @@ class AsyncHTTPClient(Configurable):
|
||||||
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
|
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
|
||||||
if not isinstance(request, HTTPRequest):
|
if not isinstance(request, HTTPRequest):
|
||||||
request = HTTPRequest(url=request, **kwargs)
|
request = HTTPRequest(url=request, **kwargs)
|
||||||
|
else:
|
||||||
|
if kwargs:
|
||||||
|
raise ValueError("kwargs can't be used if request is an HTTPRequest object")
|
||||||
# We may modify this (to add Host, Accept-Encoding, etc),
|
# We may modify this (to add Host, Accept-Encoding, etc),
|
||||||
# so make sure we don't modify the caller's object. This is also
|
# so make sure we don't modify the caller's object. This is also
|
||||||
# where normal dicts get converted to HTTPHeaders objects.
|
# where normal dicts get converted to HTTPHeaders objects.
|
||||||
|
@ -305,10 +310,10 @@ class HTTPRequest(object):
|
||||||
network_interface=None, streaming_callback=None,
|
network_interface=None, streaming_callback=None,
|
||||||
header_callback=None, prepare_curl_callback=None,
|
header_callback=None, prepare_curl_callback=None,
|
||||||
proxy_host=None, proxy_port=None, proxy_username=None,
|
proxy_host=None, proxy_port=None, proxy_username=None,
|
||||||
proxy_password=None, allow_nonstandard_methods=None,
|
proxy_password=None, proxy_auth_mode=None,
|
||||||
validate_cert=None, ca_certs=None,
|
allow_nonstandard_methods=None, validate_cert=None,
|
||||||
allow_ipv6=None,
|
ca_certs=None, allow_ipv6=None, client_key=None,
|
||||||
client_key=None, client_cert=None, body_producer=None,
|
client_cert=None, body_producer=None,
|
||||||
expect_100_continue=False, decompress_response=None,
|
expect_100_continue=False, decompress_response=None,
|
||||||
ssl_options=None):
|
ssl_options=None):
|
||||||
r"""All parameters except ``url`` are optional.
|
r"""All parameters except ``url`` are optional.
|
||||||
|
@ -336,13 +341,15 @@ class HTTPRequest(object):
|
||||||
Allowed values are implementation-defined; ``curl_httpclient``
|
Allowed values are implementation-defined; ``curl_httpclient``
|
||||||
supports "basic" and "digest"; ``simple_httpclient`` only supports
|
supports "basic" and "digest"; ``simple_httpclient`` only supports
|
||||||
"basic"
|
"basic"
|
||||||
:arg float connect_timeout: Timeout for initial connection in seconds
|
:arg float connect_timeout: Timeout for initial connection in seconds,
|
||||||
:arg float request_timeout: Timeout for entire request in seconds
|
default 20 seconds
|
||||||
|
:arg float request_timeout: Timeout for entire request in seconds,
|
||||||
|
default 20 seconds
|
||||||
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
|
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
|
||||||
:type if_modified_since: `datetime` or `float`
|
:type if_modified_since: `datetime` or `float`
|
||||||
:arg bool follow_redirects: Should redirects be followed automatically
|
:arg bool follow_redirects: Should redirects be followed automatically
|
||||||
or return the 3xx response?
|
or return the 3xx response? Default True.
|
||||||
:arg int max_redirects: Limit for ``follow_redirects``
|
:arg int max_redirects: Limit for ``follow_redirects``, default 5.
|
||||||
:arg string user_agent: String to send as ``User-Agent`` header
|
:arg string user_agent: String to send as ``User-Agent`` header
|
||||||
:arg bool decompress_response: Request a compressed response from
|
:arg bool decompress_response: Request a compressed response from
|
||||||
the server and decompress it after downloading. Default is True.
|
the server and decompress it after downloading. Default is True.
|
||||||
|
@ -367,16 +374,18 @@ class HTTPRequest(object):
|
||||||
a ``pycurl.Curl`` object to allow the application to make additional
|
a ``pycurl.Curl`` object to allow the application to make additional
|
||||||
``setopt`` calls.
|
``setopt`` calls.
|
||||||
:arg string proxy_host: HTTP proxy hostname. To use proxies,
|
:arg string proxy_host: HTTP proxy hostname. To use proxies,
|
||||||
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
|
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
|
||||||
``proxy_pass`` are optional. Proxies are currently only supported
|
``proxy_pass`` and ``proxy_auth_mode`` are optional. Proxies are
|
||||||
with ``curl_httpclient``.
|
currently only supported with ``curl_httpclient``.
|
||||||
:arg int proxy_port: HTTP proxy port
|
:arg int proxy_port: HTTP proxy port
|
||||||
:arg string proxy_username: HTTP proxy username
|
:arg string proxy_username: HTTP proxy username
|
||||||
:arg string proxy_password: HTTP proxy password
|
:arg string proxy_password: HTTP proxy password
|
||||||
|
:arg string proxy_auth_mode: HTTP proxy Authentication mode;
|
||||||
|
default is "basic". supports "basic" and "digest"
|
||||||
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
|
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
|
||||||
argument?
|
argument? Default is False.
|
||||||
:arg bool validate_cert: For HTTPS requests, validate the server's
|
:arg bool validate_cert: For HTTPS requests, validate the server's
|
||||||
certificate?
|
certificate? Default is True.
|
||||||
:arg string ca_certs: filename of CA certificates in PEM format,
|
:arg string ca_certs: filename of CA certificates in PEM format,
|
||||||
or None to use defaults. See note below when used with
|
or None to use defaults. See note below when used with
|
||||||
``curl_httpclient``.
|
``curl_httpclient``.
|
||||||
|
@ -414,6 +423,9 @@ class HTTPRequest(object):
|
||||||
|
|
||||||
.. versionadded:: 4.2
|
.. versionadded:: 4.2
|
||||||
The ``ssl_options`` argument.
|
The ``ssl_options`` argument.
|
||||||
|
|
||||||
|
.. versionadded:: 4.5
|
||||||
|
The ``proxy_auth_mode`` argument.
|
||||||
"""
|
"""
|
||||||
# Note that some of these attributes go through property setters
|
# Note that some of these attributes go through property setters
|
||||||
# defined below.
|
# defined below.
|
||||||
|
@ -425,6 +437,7 @@ class HTTPRequest(object):
|
||||||
self.proxy_port = proxy_port
|
self.proxy_port = proxy_port
|
||||||
self.proxy_username = proxy_username
|
self.proxy_username = proxy_username
|
||||||
self.proxy_password = proxy_password
|
self.proxy_password = proxy_password
|
||||||
|
self.proxy_auth_mode = proxy_auth_mode
|
||||||
self.url = url
|
self.url = url
|
||||||
self.method = method
|
self.method = method
|
||||||
self.body = body
|
self.body = body
|
||||||
|
@ -525,7 +538,7 @@ class HTTPResponse(object):
|
||||||
|
|
||||||
* buffer: ``cStringIO`` object for response body
|
* buffer: ``cStringIO`` object for response body
|
||||||
|
|
||||||
* body: response body as string (created on demand from ``self.buffer``)
|
* body: response body as bytes (created on demand from ``self.buffer``)
|
||||||
|
|
||||||
* error: Exception object, if any
|
* error: Exception object, if any
|
||||||
|
|
||||||
|
@ -567,7 +580,8 @@ class HTTPResponse(object):
|
||||||
self.request_time = request_time
|
self.request_time = request_time
|
||||||
self.time_info = time_info or {}
|
self.time_info = time_info or {}
|
||||||
|
|
||||||
def _get_body(self):
|
@property
|
||||||
|
def body(self):
|
||||||
if self.buffer is None:
|
if self.buffer is None:
|
||||||
return None
|
return None
|
||||||
elif self._body is None:
|
elif self._body is None:
|
||||||
|
@ -575,8 +589,6 @@ class HTTPResponse(object):
|
||||||
|
|
||||||
return self._body
|
return self._body
|
||||||
|
|
||||||
body = property(_get_body)
|
|
||||||
|
|
||||||
def rethrow(self):
|
def rethrow(self):
|
||||||
"""If there was an error on the request, raise an `HTTPError`."""
|
"""If there was an error on the request, raise an `HTTPError`."""
|
||||||
if self.error:
|
if self.error:
|
||||||
|
@ -610,6 +622,12 @@ class HTTPError(Exception):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return "HTTP %d: %s" % (self.code, self.message)
|
return "HTTP %d: %s" % (self.code, self.message)
|
||||||
|
|
||||||
|
# There is a cyclic reference between self and self.response,
|
||||||
|
# which breaks the default __repr__ implementation.
|
||||||
|
# (especially on pypy, which doesn't have the same recursion
|
||||||
|
# detection as cpython).
|
||||||
|
__repr__ = __str__
|
||||||
|
|
||||||
|
|
||||||
class _RequestProxy(object):
|
class _RequestProxy(object):
|
||||||
"""Combines an object with a dictionary of defaults.
|
"""Combines an object with a dictionary of defaults.
|
||||||
|
@ -655,5 +673,6 @@ def main():
|
||||||
print(native_str(response.body))
|
print(native_str(response.body))
|
||||||
client.close()
|
client.close()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
|
@ -26,7 +26,7 @@ class except to start a server at the beginning of the process
|
||||||
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
|
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
@ -62,6 +62,13 @@ class HTTPServer(TCPServer, Configurable,
|
||||||
if Tornado is run behind an SSL-decoding proxy that does not set one of
|
if Tornado is run behind an SSL-decoding proxy that does not set one of
|
||||||
the supported ``xheaders``.
|
the supported ``xheaders``.
|
||||||
|
|
||||||
|
By default, when parsing the ``X-Forwarded-For`` header, Tornado will
|
||||||
|
select the last (i.e., the closest) address on the list of hosts as the
|
||||||
|
remote host IP address. To select the next server in the chain, a list of
|
||||||
|
trusted downstream hosts may be passed as the ``trusted_downstream``
|
||||||
|
argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
|
||||||
|
header.
|
||||||
|
|
||||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||||
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
versions of Python ``ssl_options`` may also be a dictionary of keyword
|
||||||
|
@ -124,6 +131,9 @@ class HTTPServer(TCPServer, Configurable,
|
||||||
|
|
||||||
.. versionchanged:: 4.2
|
.. versionchanged:: 4.2
|
||||||
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
|
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.5
|
||||||
|
Added the ``trusted_downstream`` argument.
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
# Ignore args to __init__; real initialization belongs in
|
# Ignore args to __init__; real initialization belongs in
|
||||||
|
@ -138,7 +148,8 @@ class HTTPServer(TCPServer, Configurable,
|
||||||
decompress_request=False,
|
decompress_request=False,
|
||||||
chunk_size=None, max_header_size=None,
|
chunk_size=None, max_header_size=None,
|
||||||
idle_connection_timeout=None, body_timeout=None,
|
idle_connection_timeout=None, body_timeout=None,
|
||||||
max_body_size=None, max_buffer_size=None):
|
max_body_size=None, max_buffer_size=None,
|
||||||
|
trusted_downstream=None):
|
||||||
self.request_callback = request_callback
|
self.request_callback = request_callback
|
||||||
self.no_keep_alive = no_keep_alive
|
self.no_keep_alive = no_keep_alive
|
||||||
self.xheaders = xheaders
|
self.xheaders = xheaders
|
||||||
|
@ -149,11 +160,13 @@ class HTTPServer(TCPServer, Configurable,
|
||||||
max_header_size=max_header_size,
|
max_header_size=max_header_size,
|
||||||
header_timeout=idle_connection_timeout or 3600,
|
header_timeout=idle_connection_timeout or 3600,
|
||||||
max_body_size=max_body_size,
|
max_body_size=max_body_size,
|
||||||
body_timeout=body_timeout)
|
body_timeout=body_timeout,
|
||||||
|
no_keep_alive=no_keep_alive)
|
||||||
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
|
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
|
||||||
max_buffer_size=max_buffer_size,
|
max_buffer_size=max_buffer_size,
|
||||||
read_chunk_size=chunk_size)
|
read_chunk_size=chunk_size)
|
||||||
self._connections = set()
|
self._connections = set()
|
||||||
|
self.trusted_downstream = trusted_downstream
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def configurable_base(cls):
|
def configurable_base(cls):
|
||||||
|
@ -172,21 +185,55 @@ class HTTPServer(TCPServer, Configurable,
|
||||||
|
|
||||||
def handle_stream(self, stream, address):
|
def handle_stream(self, stream, address):
|
||||||
context = _HTTPRequestContext(stream, address,
|
context = _HTTPRequestContext(stream, address,
|
||||||
self.protocol)
|
self.protocol,
|
||||||
|
self.trusted_downstream)
|
||||||
conn = HTTP1ServerConnection(
|
conn = HTTP1ServerConnection(
|
||||||
stream, self.conn_params, context)
|
stream, self.conn_params, context)
|
||||||
self._connections.add(conn)
|
self._connections.add(conn)
|
||||||
conn.start_serving(self)
|
conn.start_serving(self)
|
||||||
|
|
||||||
def start_request(self, server_conn, request_conn):
|
def start_request(self, server_conn, request_conn):
|
||||||
return _ServerRequestAdapter(self, server_conn, request_conn)
|
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
|
||||||
|
delegate = self.request_callback.start_request(server_conn, request_conn)
|
||||||
|
else:
|
||||||
|
delegate = _CallableAdapter(self.request_callback, request_conn)
|
||||||
|
|
||||||
|
if self.xheaders:
|
||||||
|
delegate = _ProxyAdapter(delegate, request_conn)
|
||||||
|
|
||||||
|
return delegate
|
||||||
|
|
||||||
def on_close(self, server_conn):
|
def on_close(self, server_conn):
|
||||||
self._connections.remove(server_conn)
|
self._connections.remove(server_conn)
|
||||||
|
|
||||||
|
|
||||||
|
class _CallableAdapter(httputil.HTTPMessageDelegate):
|
||||||
|
def __init__(self, request_callback, request_conn):
|
||||||
|
self.connection = request_conn
|
||||||
|
self.request_callback = request_callback
|
||||||
|
self.request = None
|
||||||
|
self.delegate = None
|
||||||
|
self._chunks = []
|
||||||
|
|
||||||
|
def headers_received(self, start_line, headers):
|
||||||
|
self.request = httputil.HTTPServerRequest(
|
||||||
|
connection=self.connection, start_line=start_line,
|
||||||
|
headers=headers)
|
||||||
|
|
||||||
|
def data_received(self, chunk):
|
||||||
|
self._chunks.append(chunk)
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self.request.body = b''.join(self._chunks)
|
||||||
|
self.request._parse_body()
|
||||||
|
self.request_callback(self.request)
|
||||||
|
|
||||||
|
def on_connection_close(self):
|
||||||
|
self._chunks = None
|
||||||
|
|
||||||
|
|
||||||
class _HTTPRequestContext(object):
|
class _HTTPRequestContext(object):
|
||||||
def __init__(self, stream, address, protocol):
|
def __init__(self, stream, address, protocol, trusted_downstream=None):
|
||||||
self.address = address
|
self.address = address
|
||||||
# Save the socket's address family now so we know how to
|
# Save the socket's address family now so we know how to
|
||||||
# interpret self.address even after the stream is closed
|
# interpret self.address even after the stream is closed
|
||||||
|
@ -210,6 +257,7 @@ class _HTTPRequestContext(object):
|
||||||
self.protocol = "http"
|
self.protocol = "http"
|
||||||
self._orig_remote_ip = self.remote_ip
|
self._orig_remote_ip = self.remote_ip
|
||||||
self._orig_protocol = self.protocol
|
self._orig_protocol = self.protocol
|
||||||
|
self.trusted_downstream = set(trusted_downstream or [])
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
if self.address_family in (socket.AF_INET, socket.AF_INET6):
|
if self.address_family in (socket.AF_INET, socket.AF_INET6):
|
||||||
|
@ -226,7 +274,10 @@ class _HTTPRequestContext(object):
|
||||||
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
|
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
|
||||||
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
# Squid uses X-Forwarded-For, others use X-Real-Ip
|
||||||
ip = headers.get("X-Forwarded-For", self.remote_ip)
|
ip = headers.get("X-Forwarded-For", self.remote_ip)
|
||||||
ip = ip.split(',')[-1].strip()
|
# Skip trusted downstream hosts in X-Forwarded-For list
|
||||||
|
for ip in (cand.strip() for cand in reversed(ip.split(','))):
|
||||||
|
if ip not in self.trusted_downstream:
|
||||||
|
break
|
||||||
ip = headers.get("X-Real-Ip", ip)
|
ip = headers.get("X-Real-Ip", ip)
|
||||||
if netutil.is_valid_ip(ip):
|
if netutil.is_valid_ip(ip):
|
||||||
self.remote_ip = ip
|
self.remote_ip = ip
|
||||||
|
@ -247,58 +298,28 @@ class _HTTPRequestContext(object):
|
||||||
self.protocol = self._orig_protocol
|
self.protocol = self._orig_protocol
|
||||||
|
|
||||||
|
|
||||||
class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
|
class _ProxyAdapter(httputil.HTTPMessageDelegate):
|
||||||
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
|
def __init__(self, delegate, request_conn):
|
||||||
by our clients.
|
|
||||||
"""
|
|
||||||
def __init__(self, server, server_conn, request_conn):
|
|
||||||
self.server = server
|
|
||||||
self.connection = request_conn
|
self.connection = request_conn
|
||||||
self.request = None
|
self.delegate = delegate
|
||||||
if isinstance(server.request_callback,
|
|
||||||
httputil.HTTPServerConnectionDelegate):
|
|
||||||
self.delegate = server.request_callback.start_request(
|
|
||||||
server_conn, request_conn)
|
|
||||||
self._chunks = None
|
|
||||||
else:
|
|
||||||
self.delegate = None
|
|
||||||
self._chunks = []
|
|
||||||
|
|
||||||
def headers_received(self, start_line, headers):
|
def headers_received(self, start_line, headers):
|
||||||
if self.server.xheaders:
|
self.connection.context._apply_xheaders(headers)
|
||||||
self.connection.context._apply_xheaders(headers)
|
return self.delegate.headers_received(start_line, headers)
|
||||||
if self.delegate is None:
|
|
||||||
self.request = httputil.HTTPServerRequest(
|
|
||||||
connection=self.connection, start_line=start_line,
|
|
||||||
headers=headers)
|
|
||||||
else:
|
|
||||||
return self.delegate.headers_received(start_line, headers)
|
|
||||||
|
|
||||||
def data_received(self, chunk):
|
def data_received(self, chunk):
|
||||||
if self.delegate is None:
|
return self.delegate.data_received(chunk)
|
||||||
self._chunks.append(chunk)
|
|
||||||
else:
|
|
||||||
return self.delegate.data_received(chunk)
|
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
if self.delegate is None:
|
self.delegate.finish()
|
||||||
self.request.body = b''.join(self._chunks)
|
|
||||||
self.request._parse_body()
|
|
||||||
self.server.request_callback(self.request)
|
|
||||||
else:
|
|
||||||
self.delegate.finish()
|
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def on_connection_close(self):
|
def on_connection_close(self):
|
||||||
if self.delegate is None:
|
self.delegate.on_connection_close()
|
||||||
self._chunks = None
|
|
||||||
else:
|
|
||||||
self.delegate.on_connection_close()
|
|
||||||
self._cleanup()
|
self._cleanup()
|
||||||
|
|
||||||
def _cleanup(self):
|
def _cleanup(self):
|
||||||
if self.server.xheaders:
|
self.connection.context._unapply_xheaders()
|
||||||
self.connection.context._unapply_xheaders()
|
|
||||||
|
|
||||||
|
|
||||||
HTTPRequest = httputil.HTTPServerRequest
|
HTTPRequest = httputil.HTTPServerRequest
|
||||||
|
|
|
@ -20,7 +20,7 @@ This module also defines the `HTTPServerRequest` class which is exposed
|
||||||
via `tornado.web.RequestHandler.request`.
|
via `tornado.web.RequestHandler.request`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import calendar
|
import calendar
|
||||||
import collections
|
import collections
|
||||||
|
@ -33,33 +33,37 @@ import time
|
||||||
|
|
||||||
from tornado.escape import native_str, parse_qs_bytes, utf8
|
from tornado.escape import native_str, parse_qs_bytes, utf8
|
||||||
from tornado.log import gen_log
|
from tornado.log import gen_log
|
||||||
from tornado.util import ObjectDict
|
from tornado.util import ObjectDict, PY3
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
import Cookie # py2
|
import http.cookies as Cookie
|
||||||
except ImportError:
|
from http.client import responses
|
||||||
import http.cookies as Cookie # py3
|
from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
|
||||||
|
else:
|
||||||
|
import Cookie
|
||||||
|
from httplib import responses
|
||||||
|
from urllib import urlencode
|
||||||
|
from urlparse import urlparse, urlunparse, parse_qsl
|
||||||
|
|
||||||
try:
|
|
||||||
from httplib import responses # py2
|
|
||||||
except ImportError:
|
|
||||||
from http.client import responses # py3
|
|
||||||
|
|
||||||
# responses is unused in this file, but we re-export it to other files.
|
# responses is unused in this file, but we re-export it to other files.
|
||||||
# Reference it so pyflakes doesn't complain.
|
# Reference it so pyflakes doesn't complain.
|
||||||
responses
|
responses
|
||||||
|
|
||||||
try:
|
|
||||||
from urllib import urlencode # py2
|
|
||||||
except ImportError:
|
|
||||||
from urllib.parse import urlencode # py3
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
from ssl import SSLError
|
from ssl import SSLError
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# ssl is unavailable on app engine.
|
# ssl is unavailable on app engine.
|
||||||
class SSLError(Exception):
|
class _SSLError(Exception):
|
||||||
pass
|
pass
|
||||||
|
# Hack around a mypy limitation. We can't simply put "type: ignore"
|
||||||
|
# on the class definition itself; must go through an assignment.
|
||||||
|
SSLError = _SSLError # type: ignore
|
||||||
|
|
||||||
|
try:
|
||||||
|
import typing
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
|
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
|
||||||
|
@ -95,6 +99,7 @@ class _NormalizedHeaderCache(dict):
|
||||||
del self[old_key]
|
del self[old_key]
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
_normalized_headers = _NormalizedHeaderCache(1000)
|
_normalized_headers = _NormalizedHeaderCache(1000)
|
||||||
|
|
||||||
|
|
||||||
|
@ -127,8 +132,8 @@ class HTTPHeaders(collections.MutableMapping):
|
||||||
Set-Cookie: C=D
|
Set-Cookie: C=D
|
||||||
"""
|
"""
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self._dict = {}
|
self._dict = {} # type: typing.Dict[str, str]
|
||||||
self._as_list = {}
|
self._as_list = {} # type: typing.Dict[str, typing.List[str]]
|
||||||
self._last_key = None
|
self._last_key = None
|
||||||
if (len(args) == 1 and len(kwargs) == 0 and
|
if (len(args) == 1 and len(kwargs) == 0 and
|
||||||
isinstance(args[0], HTTPHeaders)):
|
isinstance(args[0], HTTPHeaders)):
|
||||||
|
@ -142,6 +147,7 @@ class HTTPHeaders(collections.MutableMapping):
|
||||||
# new public methods
|
# new public methods
|
||||||
|
|
||||||
def add(self, name, value):
|
def add(self, name, value):
|
||||||
|
# type: (str, str) -> None
|
||||||
"""Adds a new value for the given key."""
|
"""Adds a new value for the given key."""
|
||||||
norm_name = _normalized_headers[name]
|
norm_name = _normalized_headers[name]
|
||||||
self._last_key = norm_name
|
self._last_key = norm_name
|
||||||
|
@ -158,6 +164,7 @@ class HTTPHeaders(collections.MutableMapping):
|
||||||
return self._as_list.get(norm_name, [])
|
return self._as_list.get(norm_name, [])
|
||||||
|
|
||||||
def get_all(self):
|
def get_all(self):
|
||||||
|
# type: () -> typing.Iterable[typing.Tuple[str, str]]
|
||||||
"""Returns an iterable of all (name, value) pairs.
|
"""Returns an iterable of all (name, value) pairs.
|
||||||
|
|
||||||
If a header has multiple values, multiple pairs will be
|
If a header has multiple values, multiple pairs will be
|
||||||
|
@ -206,6 +213,7 @@ class HTTPHeaders(collections.MutableMapping):
|
||||||
self._as_list[norm_name] = [value]
|
self._as_list[norm_name] = [value]
|
||||||
|
|
||||||
def __getitem__(self, name):
|
def __getitem__(self, name):
|
||||||
|
# type: (str) -> str
|
||||||
return self._dict[_normalized_headers[name]]
|
return self._dict[_normalized_headers[name]]
|
||||||
|
|
||||||
def __delitem__(self, name):
|
def __delitem__(self, name):
|
||||||
|
@ -228,6 +236,14 @@ class HTTPHeaders(collections.MutableMapping):
|
||||||
# the appearance that HTTPHeaders is a single container.
|
# the appearance that HTTPHeaders is a single container.
|
||||||
__copy__ = copy
|
__copy__ = copy
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
lines = []
|
||||||
|
for name, value in self.get_all():
|
||||||
|
lines.append("%s: %s\n" % (name, value))
|
||||||
|
return "".join(lines)
|
||||||
|
|
||||||
|
__unicode__ = __str__
|
||||||
|
|
||||||
|
|
||||||
class HTTPServerRequest(object):
|
class HTTPServerRequest(object):
|
||||||
"""A single HTTP request.
|
"""A single HTTP request.
|
||||||
|
@ -323,7 +339,7 @@ class HTTPServerRequest(object):
|
||||||
"""
|
"""
|
||||||
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
|
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
|
||||||
body=None, host=None, files=None, connection=None,
|
body=None, host=None, files=None, connection=None,
|
||||||
start_line=None):
|
start_line=None, server_connection=None):
|
||||||
if start_line is not None:
|
if start_line is not None:
|
||||||
method, uri, version = start_line
|
method, uri, version = start_line
|
||||||
self.method = method
|
self.method = method
|
||||||
|
@ -338,8 +354,10 @@ class HTTPServerRequest(object):
|
||||||
self.protocol = getattr(context, 'protocol', "http")
|
self.protocol = getattr(context, 'protocol', "http")
|
||||||
|
|
||||||
self.host = host or self.headers.get("Host") or "127.0.0.1"
|
self.host = host or self.headers.get("Host") or "127.0.0.1"
|
||||||
|
self.host_name = split_host_and_port(self.host.lower())[0]
|
||||||
self.files = files or {}
|
self.files = files or {}
|
||||||
self.connection = connection
|
self.connection = connection
|
||||||
|
self.server_connection = server_connection
|
||||||
self._start_time = time.time()
|
self._start_time = time.time()
|
||||||
self._finish_time = None
|
self._finish_time = None
|
||||||
|
|
||||||
|
@ -365,10 +383,18 @@ class HTTPServerRequest(object):
|
||||||
self._cookies = Cookie.SimpleCookie()
|
self._cookies = Cookie.SimpleCookie()
|
||||||
if "Cookie" in self.headers:
|
if "Cookie" in self.headers:
|
||||||
try:
|
try:
|
||||||
self._cookies.load(
|
parsed = parse_cookie(self.headers["Cookie"])
|
||||||
native_str(self.headers["Cookie"]))
|
|
||||||
except Exception:
|
except Exception:
|
||||||
self._cookies = {}
|
pass
|
||||||
|
else:
|
||||||
|
for k, v in parsed.items():
|
||||||
|
try:
|
||||||
|
self._cookies[k] = v
|
||||||
|
except Exception:
|
||||||
|
# SimpleCookie imposes some restrictions on keys;
|
||||||
|
# parse_cookie does not. Discard any cookies
|
||||||
|
# with disallowed keys.
|
||||||
|
pass
|
||||||
return self._cookies
|
return self._cookies
|
||||||
|
|
||||||
def write(self, chunk, callback=None):
|
def write(self, chunk, callback=None):
|
||||||
|
@ -577,11 +603,28 @@ def url_concat(url, args):
|
||||||
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
|
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
|
||||||
'http://example.com/foo?a=b&c=d&c=d2'
|
'http://example.com/foo?a=b&c=d&c=d2'
|
||||||
"""
|
"""
|
||||||
if not args:
|
if args is None:
|
||||||
return url
|
return url
|
||||||
if url[-1] not in ('?', '&'):
|
parsed_url = urlparse(url)
|
||||||
url += '&' if ('?' in url) else '?'
|
if isinstance(args, dict):
|
||||||
return url + urlencode(args)
|
parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
|
||||||
|
parsed_query.extend(args.items())
|
||||||
|
elif isinstance(args, list) or isinstance(args, tuple):
|
||||||
|
parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
|
||||||
|
parsed_query.extend(args)
|
||||||
|
else:
|
||||||
|
err = "'args' parameter should be dict, list or tuple. Not {0}".format(
|
||||||
|
type(args))
|
||||||
|
raise TypeError(err)
|
||||||
|
final_query = urlencode(parsed_query)
|
||||||
|
url = urlunparse((
|
||||||
|
parsed_url[0],
|
||||||
|
parsed_url[1],
|
||||||
|
parsed_url[2],
|
||||||
|
parsed_url[3],
|
||||||
|
final_query,
|
||||||
|
parsed_url[5]))
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
class HTTPFile(ObjectDict):
|
class HTTPFile(ObjectDict):
|
||||||
|
@ -743,7 +786,7 @@ def parse_multipart_form_data(boundary, data, arguments, files):
|
||||||
name = disp_params["name"]
|
name = disp_params["name"]
|
||||||
if disp_params.get("filename"):
|
if disp_params.get("filename"):
|
||||||
ctype = headers.get("Content-Type", "application/unknown")
|
ctype = headers.get("Content-Type", "application/unknown")
|
||||||
files.setdefault(name, []).append(HTTPFile(
|
files.setdefault(name, []).append(HTTPFile( # type: ignore
|
||||||
filename=disp_params["filename"], body=value,
|
filename=disp_params["filename"], body=value,
|
||||||
content_type=ctype))
|
content_type=ctype))
|
||||||
else:
|
else:
|
||||||
|
@ -895,3 +938,84 @@ def split_host_and_port(netloc):
|
||||||
host = netloc
|
host = netloc
|
||||||
port = None
|
port = None
|
||||||
return (host, port)
|
return (host, port)
|
||||||
|
|
||||||
|
|
||||||
|
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
|
||||||
|
_QuotePatt = re.compile(r"[\\].")
|
||||||
|
_nulljoin = ''.join
|
||||||
|
|
||||||
|
|
||||||
|
def _unquote_cookie(str):
|
||||||
|
"""Handle double quotes and escaping in cookie values.
|
||||||
|
|
||||||
|
This method is copied verbatim from the Python 3.5 standard
|
||||||
|
library (http.cookies._unquote) so we don't have to depend on
|
||||||
|
non-public interfaces.
|
||||||
|
"""
|
||||||
|
# If there aren't any doublequotes,
|
||||||
|
# then there can't be any special characters. See RFC 2109.
|
||||||
|
if str is None or len(str) < 2:
|
||||||
|
return str
|
||||||
|
if str[0] != '"' or str[-1] != '"':
|
||||||
|
return str
|
||||||
|
|
||||||
|
# We have to assume that we must decode this string.
|
||||||
|
# Down to work.
|
||||||
|
|
||||||
|
# Remove the "s
|
||||||
|
str = str[1:-1]
|
||||||
|
|
||||||
|
# Check for special sequences. Examples:
|
||||||
|
# \012 --> \n
|
||||||
|
# \" --> "
|
||||||
|
#
|
||||||
|
i = 0
|
||||||
|
n = len(str)
|
||||||
|
res = []
|
||||||
|
while 0 <= i < n:
|
||||||
|
o_match = _OctalPatt.search(str, i)
|
||||||
|
q_match = _QuotePatt.search(str, i)
|
||||||
|
if not o_match and not q_match: # Neither matched
|
||||||
|
res.append(str[i:])
|
||||||
|
break
|
||||||
|
# else:
|
||||||
|
j = k = -1
|
||||||
|
if o_match:
|
||||||
|
j = o_match.start(0)
|
||||||
|
if q_match:
|
||||||
|
k = q_match.start(0)
|
||||||
|
if q_match and (not o_match or k < j): # QuotePatt matched
|
||||||
|
res.append(str[i:k])
|
||||||
|
res.append(str[k + 1])
|
||||||
|
i = k + 2
|
||||||
|
else: # OctalPatt matched
|
||||||
|
res.append(str[i:j])
|
||||||
|
res.append(chr(int(str[j + 1:j + 4], 8)))
|
||||||
|
i = j + 4
|
||||||
|
return _nulljoin(res)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_cookie(cookie):
|
||||||
|
"""Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
|
||||||
|
|
||||||
|
This function attempts to mimic browser cookie parsing behavior;
|
||||||
|
it specifically does not follow any of the cookie-related RFCs
|
||||||
|
(because browsers don't either).
|
||||||
|
|
||||||
|
The algorithm used is identical to that used by Django version 1.9.10.
|
||||||
|
|
||||||
|
.. versionadded:: 4.4.2
|
||||||
|
"""
|
||||||
|
cookiedict = {}
|
||||||
|
for chunk in cookie.split(str(';')):
|
||||||
|
if str('=') in chunk:
|
||||||
|
key, val = chunk.split(str('='), 1)
|
||||||
|
else:
|
||||||
|
# Assume an empty name per
|
||||||
|
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
|
||||||
|
key, val = str(''), chunk
|
||||||
|
key, val = key.strip(), val.strip()
|
||||||
|
if key or val:
|
||||||
|
# unquote using Python's algorithm.
|
||||||
|
cookiedict[key] = _unquote_cookie(val)
|
||||||
|
return cookiedict
|
||||||
|
|
|
@ -26,8 +26,9 @@ In addition to I/O events, the `IOLoop` can also schedule time-based events.
|
||||||
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
|
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import collections
|
||||||
import datetime
|
import datetime
|
||||||
import errno
|
import errno
|
||||||
import functools
|
import functools
|
||||||
|
@ -45,20 +46,20 @@ import math
|
||||||
|
|
||||||
from tornado.concurrent import TracebackFuture, is_future
|
from tornado.concurrent import TracebackFuture, is_future
|
||||||
from tornado.log import app_log, gen_log
|
from tornado.log import app_log, gen_log
|
||||||
|
from tornado.platform.auto import set_close_exec, Waker
|
||||||
from tornado import stack_context
|
from tornado import stack_context
|
||||||
from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
|
from tornado.util import PY3, Configurable, errno_from_exception, timedelta_to_seconds
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import signal
|
import signal
|
||||||
except ImportError:
|
except ImportError:
|
||||||
signal = None
|
signal = None
|
||||||
|
|
||||||
try:
|
|
||||||
import thread # py2
|
|
||||||
except ImportError:
|
|
||||||
import _thread as thread # py3
|
|
||||||
|
|
||||||
from tornado.platform.auto import set_close_exec, Waker
|
if PY3:
|
||||||
|
import _thread as thread
|
||||||
|
else:
|
||||||
|
import thread
|
||||||
|
|
||||||
|
|
||||||
_POLL_TIMEOUT = 3600.0
|
_POLL_TIMEOUT = 3600.0
|
||||||
|
@ -172,6 +173,10 @@ class IOLoop(Configurable):
|
||||||
This is normally not necessary as `instance()` will create
|
This is normally not necessary as `instance()` will create
|
||||||
an `IOLoop` on demand, but you may want to call `install` to use
|
an `IOLoop` on demand, but you may want to call `install` to use
|
||||||
a custom subclass of `IOLoop`.
|
a custom subclass of `IOLoop`.
|
||||||
|
|
||||||
|
When using an `IOLoop` subclass, `install` must be called prior
|
||||||
|
to creating any objects that implicitly create their own
|
||||||
|
`IOLoop` (e.g., :class:`tornado.httpclient.AsyncHTTPClient`).
|
||||||
"""
|
"""
|
||||||
assert not IOLoop.initialized()
|
assert not IOLoop.initialized()
|
||||||
IOLoop._instance = self
|
IOLoop._instance = self
|
||||||
|
@ -612,10 +617,14 @@ class IOLoop(Configurable):
|
||||||
# result, which should just be ignored.
|
# result, which should just be ignored.
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
self.add_future(ret, lambda f: f.result())
|
self.add_future(ret, self._discard_future_result)
|
||||||
except Exception:
|
except Exception:
|
||||||
self.handle_callback_exception(callback)
|
self.handle_callback_exception(callback)
|
||||||
|
|
||||||
|
def _discard_future_result(self, future):
|
||||||
|
"""Avoid unhandled-exception warnings from spawned coroutines."""
|
||||||
|
future.result()
|
||||||
|
|
||||||
def handle_callback_exception(self, callback):
|
def handle_callback_exception(self, callback):
|
||||||
"""This method is called whenever a callback run by the `IOLoop`
|
"""This method is called whenever a callback run by the `IOLoop`
|
||||||
throws an exception.
|
throws an exception.
|
||||||
|
@ -685,8 +694,7 @@ class PollIOLoop(IOLoop):
|
||||||
self.time_func = time_func or time.time
|
self.time_func = time_func or time.time
|
||||||
self._handlers = {}
|
self._handlers = {}
|
||||||
self._events = {}
|
self._events = {}
|
||||||
self._callbacks = []
|
self._callbacks = collections.deque()
|
||||||
self._callback_lock = threading.Lock()
|
|
||||||
self._timeouts = []
|
self._timeouts = []
|
||||||
self._cancellations = 0
|
self._cancellations = 0
|
||||||
self._running = False
|
self._running = False
|
||||||
|
@ -704,11 +712,10 @@ class PollIOLoop(IOLoop):
|
||||||
self.READ)
|
self.READ)
|
||||||
|
|
||||||
def close(self, all_fds=False):
|
def close(self, all_fds=False):
|
||||||
with self._callback_lock:
|
self._closing = True
|
||||||
self._closing = True
|
|
||||||
self.remove_handler(self._waker.fileno())
|
self.remove_handler(self._waker.fileno())
|
||||||
if all_fds:
|
if all_fds:
|
||||||
for fd, handler in self._handlers.values():
|
for fd, handler in list(self._handlers.values()):
|
||||||
self.close_fd(fd)
|
self.close_fd(fd)
|
||||||
self._waker.close()
|
self._waker.close()
|
||||||
self._impl.close()
|
self._impl.close()
|
||||||
|
@ -792,9 +799,7 @@ class PollIOLoop(IOLoop):
|
||||||
while True:
|
while True:
|
||||||
# Prevent IO event starvation by delaying new callbacks
|
# Prevent IO event starvation by delaying new callbacks
|
||||||
# to the next iteration of the event loop.
|
# to the next iteration of the event loop.
|
||||||
with self._callback_lock:
|
ncallbacks = len(self._callbacks)
|
||||||
callbacks = self._callbacks
|
|
||||||
self._callbacks = []
|
|
||||||
|
|
||||||
# Add any timeouts that have come due to the callback list.
|
# Add any timeouts that have come due to the callback list.
|
||||||
# Do not run anything until we have determined which ones
|
# Do not run anything until we have determined which ones
|
||||||
|
@ -814,8 +819,8 @@ class PollIOLoop(IOLoop):
|
||||||
due_timeouts.append(heapq.heappop(self._timeouts))
|
due_timeouts.append(heapq.heappop(self._timeouts))
|
||||||
else:
|
else:
|
||||||
break
|
break
|
||||||
if (self._cancellations > 512
|
if (self._cancellations > 512 and
|
||||||
and self._cancellations > (len(self._timeouts) >> 1)):
|
self._cancellations > (len(self._timeouts) >> 1)):
|
||||||
# Clean up the timeout queue when it gets large and it's
|
# Clean up the timeout queue when it gets large and it's
|
||||||
# more than half cancellations.
|
# more than half cancellations.
|
||||||
self._cancellations = 0
|
self._cancellations = 0
|
||||||
|
@ -823,14 +828,14 @@ class PollIOLoop(IOLoop):
|
||||||
if x.callback is not None]
|
if x.callback is not None]
|
||||||
heapq.heapify(self._timeouts)
|
heapq.heapify(self._timeouts)
|
||||||
|
|
||||||
for callback in callbacks:
|
for i in range(ncallbacks):
|
||||||
self._run_callback(callback)
|
self._run_callback(self._callbacks.popleft())
|
||||||
for timeout in due_timeouts:
|
for timeout in due_timeouts:
|
||||||
if timeout.callback is not None:
|
if timeout.callback is not None:
|
||||||
self._run_callback(timeout.callback)
|
self._run_callback(timeout.callback)
|
||||||
# Closures may be holding on to a lot of memory, so allow
|
# Closures may be holding on to a lot of memory, so allow
|
||||||
# them to be freed before we go into our poll wait.
|
# them to be freed before we go into our poll wait.
|
||||||
callbacks = callback = due_timeouts = timeout = None
|
due_timeouts = timeout = None
|
||||||
|
|
||||||
if self._callbacks:
|
if self._callbacks:
|
||||||
# If any callbacks or timeouts called add_callback,
|
# If any callbacks or timeouts called add_callback,
|
||||||
|
@ -874,7 +879,7 @@ class PollIOLoop(IOLoop):
|
||||||
# Pop one fd at a time from the set of pending fds and run
|
# Pop one fd at a time from the set of pending fds and run
|
||||||
# its handler. Since that handler may perform actions on
|
# its handler. Since that handler may perform actions on
|
||||||
# other file descriptors, there may be reentrant calls to
|
# other file descriptors, there may be reentrant calls to
|
||||||
# this IOLoop that update self._events
|
# this IOLoop that modify self._events
|
||||||
self._events.update(event_pairs)
|
self._events.update(event_pairs)
|
||||||
while self._events:
|
while self._events:
|
||||||
fd, events = self._events.popitem()
|
fd, events = self._events.popitem()
|
||||||
|
@ -926,36 +931,20 @@ class PollIOLoop(IOLoop):
|
||||||
self._cancellations += 1
|
self._cancellations += 1
|
||||||
|
|
||||||
def add_callback(self, callback, *args, **kwargs):
|
def add_callback(self, callback, *args, **kwargs):
|
||||||
|
if self._closing:
|
||||||
|
return
|
||||||
|
# Blindly insert into self._callbacks. This is safe even
|
||||||
|
# from signal handlers because deque.append is atomic.
|
||||||
|
self._callbacks.append(functools.partial(
|
||||||
|
stack_context.wrap(callback), *args, **kwargs))
|
||||||
if thread.get_ident() != self._thread_ident:
|
if thread.get_ident() != self._thread_ident:
|
||||||
# If we're not on the IOLoop's thread, we need to synchronize
|
# This will write one byte but Waker.consume() reads many
|
||||||
# with other threads, or waking logic will induce a race.
|
# at once, so it's ok to write even when not strictly
|
||||||
with self._callback_lock:
|
# necessary.
|
||||||
if self._closing:
|
self._waker.wake()
|
||||||
return
|
|
||||||
list_empty = not self._callbacks
|
|
||||||
self._callbacks.append(functools.partial(
|
|
||||||
stack_context.wrap(callback), *args, **kwargs))
|
|
||||||
if list_empty:
|
|
||||||
# If we're not in the IOLoop's thread, and we added the
|
|
||||||
# first callback to an empty list, we may need to wake it
|
|
||||||
# up (it may wake up on its own, but an occasional extra
|
|
||||||
# wake is harmless). Waking up a polling IOLoop is
|
|
||||||
# relatively expensive, so we try to avoid it when we can.
|
|
||||||
self._waker.wake()
|
|
||||||
else:
|
else:
|
||||||
if self._closing:
|
# If we're on the IOLoop's thread, we don't need to wake anyone.
|
||||||
return
|
pass
|
||||||
# If we're on the IOLoop's thread, we don't need the lock,
|
|
||||||
# since we don't need to wake anyone, just add the
|
|
||||||
# callback. Blindly insert into self._callbacks. This is
|
|
||||||
# safe even from signal handlers because the GIL makes
|
|
||||||
# list.append atomic. One subtlety is that if the signal
|
|
||||||
# is interrupting another thread holding the
|
|
||||||
# _callback_lock block in IOLoop.start, we may modify
|
|
||||||
# either the old or new version of self._callbacks, but
|
|
||||||
# either way will work.
|
|
||||||
self._callbacks.append(functools.partial(
|
|
||||||
stack_context.wrap(callback), *args, **kwargs))
|
|
||||||
|
|
||||||
def add_callback_from_signal(self, callback, *args, **kwargs):
|
def add_callback_from_signal(self, callback, *args, **kwargs):
|
||||||
with stack_context.NullContext():
|
with stack_context.NullContext():
|
||||||
|
@ -966,26 +955,24 @@ class _Timeout(object):
|
||||||
"""An IOLoop timeout, a UNIX timestamp and a callback"""
|
"""An IOLoop timeout, a UNIX timestamp and a callback"""
|
||||||
|
|
||||||
# Reduce memory overhead when there are lots of pending callbacks
|
# Reduce memory overhead when there are lots of pending callbacks
|
||||||
__slots__ = ['deadline', 'callback', 'tiebreaker']
|
__slots__ = ['deadline', 'callback', 'tdeadline']
|
||||||
|
|
||||||
def __init__(self, deadline, callback, io_loop):
|
def __init__(self, deadline, callback, io_loop):
|
||||||
if not isinstance(deadline, numbers.Real):
|
if not isinstance(deadline, numbers.Real):
|
||||||
raise TypeError("Unsupported deadline %r" % deadline)
|
raise TypeError("Unsupported deadline %r" % deadline)
|
||||||
self.deadline = deadline
|
self.deadline = deadline
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.tiebreaker = next(io_loop._timeout_counter)
|
self.tdeadline = (deadline, next(io_loop._timeout_counter))
|
||||||
|
|
||||||
# Comparison methods to sort by deadline, with object id as a tiebreaker
|
# Comparison methods to sort by deadline, with object id as a tiebreaker
|
||||||
# to guarantee a consistent ordering. The heapq module uses __le__
|
# to guarantee a consistent ordering. The heapq module uses __le__
|
||||||
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
|
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
|
||||||
# use __lt__).
|
# use __lt__).
|
||||||
def __lt__(self, other):
|
def __lt__(self, other):
|
||||||
return ((self.deadline, self.tiebreaker) <
|
return self.tdeadline < other.tdeadline
|
||||||
(other.deadline, other.tiebreaker))
|
|
||||||
|
|
||||||
def __le__(self, other):
|
def __le__(self, other):
|
||||||
return ((self.deadline, self.tiebreaker) <=
|
return self.tdeadline <= other.tdeadline
|
||||||
(other.deadline, other.tiebreaker))
|
|
||||||
|
|
||||||
|
|
||||||
class PeriodicCallback(object):
|
class PeriodicCallback(object):
|
||||||
|
@ -1048,6 +1035,7 @@ class PeriodicCallback(object):
|
||||||
|
|
||||||
if self._next_timeout <= current_time:
|
if self._next_timeout <= current_time:
|
||||||
callback_time_sec = self.callback_time / 1000.0
|
callback_time_sec = self.callback_time / 1000.0
|
||||||
self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
|
self._next_timeout += (math.floor((current_time - self._next_timeout) /
|
||||||
|
callback_time_sec) + 1) * callback_time_sec
|
||||||
|
|
||||||
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
|
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)
|
||||||
|
|
|
@ -24,7 +24,7 @@ Contents:
|
||||||
* `PipeIOStream`: Pipe-based IOStream implementation.
|
* `PipeIOStream`: Pipe-based IOStream implementation.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import errno
|
import errno
|
||||||
|
@ -58,7 +58,7 @@ except ImportError:
|
||||||
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
|
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
|
||||||
|
|
||||||
if hasattr(errno, "WSAEWOULDBLOCK"):
|
if hasattr(errno, "WSAEWOULDBLOCK"):
|
||||||
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
|
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
|
||||||
|
|
||||||
# These errnos indicate that a connection has been abruptly terminated.
|
# These errnos indicate that a connection has been abruptly terminated.
|
||||||
# They should be caught and handled less noisily than other errors.
|
# They should be caught and handled less noisily than other errors.
|
||||||
|
@ -66,7 +66,7 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
|
||||||
errno.ETIMEDOUT)
|
errno.ETIMEDOUT)
|
||||||
|
|
||||||
if hasattr(errno, "WSAECONNRESET"):
|
if hasattr(errno, "WSAECONNRESET"):
|
||||||
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
|
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore
|
||||||
|
|
||||||
if sys.platform == 'darwin':
|
if sys.platform == 'darwin':
|
||||||
# OSX appears to have a race condition that causes send(2) to return
|
# OSX appears to have a race condition that causes send(2) to return
|
||||||
|
@ -74,13 +74,15 @@ if sys.platform == 'darwin':
|
||||||
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
|
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
|
||||||
# Since the socket is being closed anyway, treat this as an ECONNRESET
|
# Since the socket is being closed anyway, treat this as an ECONNRESET
|
||||||
# instead of an unexpected error.
|
# instead of an unexpected error.
|
||||||
_ERRNO_CONNRESET += (errno.EPROTOTYPE,)
|
_ERRNO_CONNRESET += (errno.EPROTOTYPE,) # type: ignore
|
||||||
|
|
||||||
# More non-portable errnos:
|
# More non-portable errnos:
|
||||||
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
|
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
|
||||||
|
|
||||||
if hasattr(errno, "WSAEINPROGRESS"):
|
if hasattr(errno, "WSAEINPROGRESS"):
|
||||||
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
|
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore
|
||||||
|
|
||||||
|
_WINDOWS = sys.platform.startswith('win')
|
||||||
|
|
||||||
|
|
||||||
class StreamClosedError(IOError):
|
class StreamClosedError(IOError):
|
||||||
|
@ -158,11 +160,16 @@ class BaseIOStream(object):
|
||||||
self.max_buffer_size // 2)
|
self.max_buffer_size // 2)
|
||||||
self.max_write_buffer_size = max_write_buffer_size
|
self.max_write_buffer_size = max_write_buffer_size
|
||||||
self.error = None
|
self.error = None
|
||||||
self._read_buffer = collections.deque()
|
self._read_buffer = bytearray()
|
||||||
self._write_buffer = collections.deque()
|
self._read_buffer_pos = 0
|
||||||
self._read_buffer_size = 0
|
self._read_buffer_size = 0
|
||||||
|
self._write_buffer = bytearray()
|
||||||
|
self._write_buffer_pos = 0
|
||||||
self._write_buffer_size = 0
|
self._write_buffer_size = 0
|
||||||
self._write_buffer_frozen = False
|
self._write_buffer_frozen = False
|
||||||
|
self._total_write_index = 0
|
||||||
|
self._total_write_done_index = 0
|
||||||
|
self._pending_writes_while_frozen = []
|
||||||
self._read_delimiter = None
|
self._read_delimiter = None
|
||||||
self._read_regex = None
|
self._read_regex = None
|
||||||
self._read_max_bytes = None
|
self._read_max_bytes = None
|
||||||
|
@ -173,7 +180,7 @@ class BaseIOStream(object):
|
||||||
self._read_future = None
|
self._read_future = None
|
||||||
self._streaming_callback = None
|
self._streaming_callback = None
|
||||||
self._write_callback = None
|
self._write_callback = None
|
||||||
self._write_future = None
|
self._write_futures = collections.deque()
|
||||||
self._close_callback = None
|
self._close_callback = None
|
||||||
self._connect_callback = None
|
self._connect_callback = None
|
||||||
self._connect_future = None
|
self._connect_future = None
|
||||||
|
@ -367,36 +374,37 @@ class BaseIOStream(object):
|
||||||
|
|
||||||
If no ``callback`` is given, this method returns a `.Future` that
|
If no ``callback`` is given, this method returns a `.Future` that
|
||||||
resolves (with a result of ``None``) when the write has been
|
resolves (with a result of ``None``) when the write has been
|
||||||
completed. If `write` is called again before that `.Future` has
|
completed.
|
||||||
resolved, the previous future will be orphaned and will never resolve.
|
|
||||||
|
The ``data`` argument may be of type `bytes` or `memoryview`.
|
||||||
|
|
||||||
.. versionchanged:: 4.0
|
.. versionchanged:: 4.0
|
||||||
Now returns a `.Future` if no callback is given.
|
Now returns a `.Future` if no callback is given.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.5
|
||||||
|
Added support for `memoryview` arguments.
|
||||||
"""
|
"""
|
||||||
assert isinstance(data, bytes)
|
|
||||||
self._check_closed()
|
self._check_closed()
|
||||||
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
|
|
||||||
# so never put empty strings in the buffer.
|
|
||||||
if data:
|
if data:
|
||||||
if (self.max_write_buffer_size is not None and
|
if (self.max_write_buffer_size is not None and
|
||||||
self._write_buffer_size + len(data) > self.max_write_buffer_size):
|
self._write_buffer_size + len(data) > self.max_write_buffer_size):
|
||||||
raise StreamBufferFullError("Reached maximum write buffer size")
|
raise StreamBufferFullError("Reached maximum write buffer size")
|
||||||
# Break up large contiguous strings before inserting them in the
|
if self._write_buffer_frozen:
|
||||||
# write buffer, so we don't have to recopy the entire thing
|
self._pending_writes_while_frozen.append(data)
|
||||||
# as we slice off pieces to send to the socket.
|
else:
|
||||||
WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
|
self._write_buffer += data
|
||||||
for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
|
self._write_buffer_size += len(data)
|
||||||
self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
|
self._total_write_index += len(data)
|
||||||
self._write_buffer_size += len(data)
|
|
||||||
if callback is not None:
|
if callback is not None:
|
||||||
self._write_callback = stack_context.wrap(callback)
|
self._write_callback = stack_context.wrap(callback)
|
||||||
future = None
|
future = None
|
||||||
else:
|
else:
|
||||||
future = self._write_future = TracebackFuture()
|
future = TracebackFuture()
|
||||||
future.add_done_callback(lambda f: f.exception())
|
future.add_done_callback(lambda f: f.exception())
|
||||||
|
self._write_futures.append((self._total_write_index, future))
|
||||||
if not self._connecting:
|
if not self._connecting:
|
||||||
self._handle_write()
|
self._handle_write()
|
||||||
if self._write_buffer:
|
if self._write_buffer_size:
|
||||||
self._add_io_state(self.io_loop.WRITE)
|
self._add_io_state(self.io_loop.WRITE)
|
||||||
self._maybe_add_error_listener()
|
self._maybe_add_error_listener()
|
||||||
return future
|
return future
|
||||||
|
@ -445,9 +453,8 @@ class BaseIOStream(object):
|
||||||
if self._read_future is not None:
|
if self._read_future is not None:
|
||||||
futures.append(self._read_future)
|
futures.append(self._read_future)
|
||||||
self._read_future = None
|
self._read_future = None
|
||||||
if self._write_future is not None:
|
futures += [future for _, future in self._write_futures]
|
||||||
futures.append(self._write_future)
|
self._write_futures.clear()
|
||||||
self._write_future = None
|
|
||||||
if self._connect_future is not None:
|
if self._connect_future is not None:
|
||||||
futures.append(self._connect_future)
|
futures.append(self._connect_future)
|
||||||
self._connect_future = None
|
self._connect_future = None
|
||||||
|
@ -466,6 +473,7 @@ class BaseIOStream(object):
|
||||||
# if the IOStream object is kept alive by a reference cycle.
|
# if the IOStream object is kept alive by a reference cycle.
|
||||||
# TODO: Clear the read buffer too; it currently breaks some tests.
|
# TODO: Clear the read buffer too; it currently breaks some tests.
|
||||||
self._write_buffer = None
|
self._write_buffer = None
|
||||||
|
self._write_buffer_size = 0
|
||||||
|
|
||||||
def reading(self):
|
def reading(self):
|
||||||
"""Returns true if we are currently reading from the stream."""
|
"""Returns true if we are currently reading from the stream."""
|
||||||
|
@ -473,7 +481,7 @@ class BaseIOStream(object):
|
||||||
|
|
||||||
def writing(self):
|
def writing(self):
|
||||||
"""Returns true if we are currently writing to the stream."""
|
"""Returns true if we are currently writing to the stream."""
|
||||||
return bool(self._write_buffer)
|
return self._write_buffer_size > 0
|
||||||
|
|
||||||
def closed(self):
|
def closed(self):
|
||||||
"""Returns true if the stream has been closed."""
|
"""Returns true if the stream has been closed."""
|
||||||
|
@ -743,7 +751,7 @@ class BaseIOStream(object):
|
||||||
break
|
break
|
||||||
if chunk is None:
|
if chunk is None:
|
||||||
return 0
|
return 0
|
||||||
self._read_buffer.append(chunk)
|
self._read_buffer += chunk
|
||||||
self._read_buffer_size += len(chunk)
|
self._read_buffer_size += len(chunk)
|
||||||
if self._read_buffer_size > self.max_buffer_size:
|
if self._read_buffer_size > self.max_buffer_size:
|
||||||
gen_log.error("Reached maximum read buffer size")
|
gen_log.error("Reached maximum read buffer size")
|
||||||
|
@ -791,30 +799,25 @@ class BaseIOStream(object):
|
||||||
# since large merges are relatively expensive and get undone in
|
# since large merges are relatively expensive and get undone in
|
||||||
# _consume().
|
# _consume().
|
||||||
if self._read_buffer:
|
if self._read_buffer:
|
||||||
while True:
|
loc = self._read_buffer.find(self._read_delimiter,
|
||||||
loc = self._read_buffer[0].find(self._read_delimiter)
|
self._read_buffer_pos)
|
||||||
if loc != -1:
|
if loc != -1:
|
||||||
delimiter_len = len(self._read_delimiter)
|
loc -= self._read_buffer_pos
|
||||||
self._check_max_bytes(self._read_delimiter,
|
delimiter_len = len(self._read_delimiter)
|
||||||
loc + delimiter_len)
|
self._check_max_bytes(self._read_delimiter,
|
||||||
return loc + delimiter_len
|
loc + delimiter_len)
|
||||||
if len(self._read_buffer) == 1:
|
return loc + delimiter_len
|
||||||
break
|
|
||||||
_double_prefix(self._read_buffer)
|
|
||||||
self._check_max_bytes(self._read_delimiter,
|
self._check_max_bytes(self._read_delimiter,
|
||||||
len(self._read_buffer[0]))
|
self._read_buffer_size)
|
||||||
elif self._read_regex is not None:
|
elif self._read_regex is not None:
|
||||||
if self._read_buffer:
|
if self._read_buffer:
|
||||||
while True:
|
m = self._read_regex.search(self._read_buffer,
|
||||||
m = self._read_regex.search(self._read_buffer[0])
|
self._read_buffer_pos)
|
||||||
if m is not None:
|
if m is not None:
|
||||||
self._check_max_bytes(self._read_regex, m.end())
|
loc = m.end() - self._read_buffer_pos
|
||||||
return m.end()
|
self._check_max_bytes(self._read_regex, loc)
|
||||||
if len(self._read_buffer) == 1:
|
return loc
|
||||||
break
|
self._check_max_bytes(self._read_regex, self._read_buffer_size)
|
||||||
_double_prefix(self._read_buffer)
|
|
||||||
self._check_max_bytes(self._read_regex,
|
|
||||||
len(self._read_buffer[0]))
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _check_max_bytes(self, delimiter, size):
|
def _check_max_bytes(self, delimiter, size):
|
||||||
|
@ -824,35 +827,56 @@ class BaseIOStream(object):
|
||||||
"delimiter %r not found within %d bytes" % (
|
"delimiter %r not found within %d bytes" % (
|
||||||
delimiter, self._read_max_bytes))
|
delimiter, self._read_max_bytes))
|
||||||
|
|
||||||
|
def _freeze_write_buffer(self, size):
|
||||||
|
self._write_buffer_frozen = size
|
||||||
|
|
||||||
|
def _unfreeze_write_buffer(self):
|
||||||
|
self._write_buffer_frozen = False
|
||||||
|
self._write_buffer += b''.join(self._pending_writes_while_frozen)
|
||||||
|
self._write_buffer_size += sum(map(len, self._pending_writes_while_frozen))
|
||||||
|
self._pending_writes_while_frozen[:] = []
|
||||||
|
|
||||||
|
def _got_empty_write(self, size):
|
||||||
|
"""
|
||||||
|
Called when a non-blocking write() failed writing anything.
|
||||||
|
Can be overridden in subclasses.
|
||||||
|
"""
|
||||||
|
|
||||||
def _handle_write(self):
|
def _handle_write(self):
|
||||||
while self._write_buffer:
|
while self._write_buffer_size:
|
||||||
|
assert self._write_buffer_size >= 0
|
||||||
try:
|
try:
|
||||||
if not self._write_buffer_frozen:
|
start = self._write_buffer_pos
|
||||||
|
if self._write_buffer_frozen:
|
||||||
|
size = self._write_buffer_frozen
|
||||||
|
elif _WINDOWS:
|
||||||
# On windows, socket.send blows up if given a
|
# On windows, socket.send blows up if given a
|
||||||
# write buffer that's too large, instead of just
|
# write buffer that's too large, instead of just
|
||||||
# returning the number of bytes it was able to
|
# returning the number of bytes it was able to
|
||||||
# process. Therefore we must not call socket.send
|
# process. Therefore we must not call socket.send
|
||||||
# with more than 128KB at a time.
|
# with more than 128KB at a time.
|
||||||
_merge_prefix(self._write_buffer, 128 * 1024)
|
size = 128 * 1024
|
||||||
num_bytes = self.write_to_fd(self._write_buffer[0])
|
else:
|
||||||
|
size = self._write_buffer_size
|
||||||
|
num_bytes = self.write_to_fd(
|
||||||
|
memoryview(self._write_buffer)[start:start + size])
|
||||||
if num_bytes == 0:
|
if num_bytes == 0:
|
||||||
# With OpenSSL, if we couldn't write the entire buffer,
|
self._got_empty_write(size)
|
||||||
# the very same string object must be used on the
|
|
||||||
# next call to send. Therefore we suppress
|
|
||||||
# merging the write buffer after an incomplete send.
|
|
||||||
# A cleaner solution would be to set
|
|
||||||
# SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
|
|
||||||
# not yet accessible from python
|
|
||||||
# (http://bugs.python.org/issue8240)
|
|
||||||
self._write_buffer_frozen = True
|
|
||||||
break
|
break
|
||||||
self._write_buffer_frozen = False
|
self._write_buffer_pos += num_bytes
|
||||||
_merge_prefix(self._write_buffer, num_bytes)
|
|
||||||
self._write_buffer.popleft()
|
|
||||||
self._write_buffer_size -= num_bytes
|
self._write_buffer_size -= num_bytes
|
||||||
|
# Amortized O(1) shrink
|
||||||
|
# (this heuristic is implemented natively in Python 3.4+
|
||||||
|
# but is replicated here for Python 2)
|
||||||
|
if self._write_buffer_pos > self._write_buffer_size:
|
||||||
|
del self._write_buffer[:self._write_buffer_pos]
|
||||||
|
self._write_buffer_pos = 0
|
||||||
|
if self._write_buffer_frozen:
|
||||||
|
self._unfreeze_write_buffer()
|
||||||
|
self._total_write_done_index += num_bytes
|
||||||
except (socket.error, IOError, OSError) as e:
|
except (socket.error, IOError, OSError) as e:
|
||||||
if e.args[0] in _ERRNO_WOULDBLOCK:
|
if e.args[0] in _ERRNO_WOULDBLOCK:
|
||||||
self._write_buffer_frozen = True
|
self._got_empty_write(size)
|
||||||
break
|
break
|
||||||
else:
|
else:
|
||||||
if not self._is_connreset(e):
|
if not self._is_connreset(e):
|
||||||
|
@ -863,22 +887,38 @@ class BaseIOStream(object):
|
||||||
self.fileno(), e)
|
self.fileno(), e)
|
||||||
self.close(exc_info=True)
|
self.close(exc_info=True)
|
||||||
return
|
return
|
||||||
if not self._write_buffer:
|
|
||||||
|
while self._write_futures:
|
||||||
|
index, future = self._write_futures[0]
|
||||||
|
if index > self._total_write_done_index:
|
||||||
|
break
|
||||||
|
self._write_futures.popleft()
|
||||||
|
future.set_result(None)
|
||||||
|
|
||||||
|
if not self._write_buffer_size:
|
||||||
if self._write_callback:
|
if self._write_callback:
|
||||||
callback = self._write_callback
|
callback = self._write_callback
|
||||||
self._write_callback = None
|
self._write_callback = None
|
||||||
self._run_callback(callback)
|
self._run_callback(callback)
|
||||||
if self._write_future:
|
|
||||||
future = self._write_future
|
|
||||||
self._write_future = None
|
|
||||||
future.set_result(None)
|
|
||||||
|
|
||||||
def _consume(self, loc):
|
def _consume(self, loc):
|
||||||
|
# Consume loc bytes from the read buffer and return them
|
||||||
if loc == 0:
|
if loc == 0:
|
||||||
return b""
|
return b""
|
||||||
_merge_prefix(self._read_buffer, loc)
|
assert loc <= self._read_buffer_size
|
||||||
|
# Slice the bytearray buffer into bytes, without intermediate copying
|
||||||
|
b = (memoryview(self._read_buffer)
|
||||||
|
[self._read_buffer_pos:self._read_buffer_pos + loc]
|
||||||
|
).tobytes()
|
||||||
|
self._read_buffer_pos += loc
|
||||||
self._read_buffer_size -= loc
|
self._read_buffer_size -= loc
|
||||||
return self._read_buffer.popleft()
|
# Amortized O(1) shrink
|
||||||
|
# (this heuristic is implemented natively in Python 3.4+
|
||||||
|
# but is replicated here for Python 2)
|
||||||
|
if self._read_buffer_pos > self._read_buffer_size:
|
||||||
|
del self._read_buffer[:self._read_buffer_pos]
|
||||||
|
self._read_buffer_pos = 0
|
||||||
|
return b
|
||||||
|
|
||||||
def _check_closed(self):
|
def _check_closed(self):
|
||||||
if self.closed():
|
if self.closed():
|
||||||
|
@ -1124,7 +1164,7 @@ class IOStream(BaseIOStream):
|
||||||
suitably-configured `ssl.SSLContext` to disable.
|
suitably-configured `ssl.SSLContext` to disable.
|
||||||
"""
|
"""
|
||||||
if (self._read_callback or self._read_future or
|
if (self._read_callback or self._read_future or
|
||||||
self._write_callback or self._write_future or
|
self._write_callback or self._write_futures or
|
||||||
self._connect_callback or self._connect_future or
|
self._connect_callback or self._connect_future or
|
||||||
self._pending_callbacks or self._closed or
|
self._pending_callbacks or self._closed or
|
||||||
self._read_buffer or self._write_buffer):
|
self._read_buffer or self._write_buffer):
|
||||||
|
@ -1251,6 +1291,17 @@ class SSLIOStream(IOStream):
|
||||||
def writing(self):
|
def writing(self):
|
||||||
return self._handshake_writing or super(SSLIOStream, self).writing()
|
return self._handshake_writing or super(SSLIOStream, self).writing()
|
||||||
|
|
||||||
|
def _got_empty_write(self, size):
|
||||||
|
# With OpenSSL, if we couldn't write the entire buffer,
|
||||||
|
# the very same string object must be used on the
|
||||||
|
# next call to send. Therefore we suppress
|
||||||
|
# merging the write buffer after an incomplete send.
|
||||||
|
# A cleaner solution would be to set
|
||||||
|
# SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
|
||||||
|
# not yet accessible from python
|
||||||
|
# (http://bugs.python.org/issue8240)
|
||||||
|
self._freeze_write_buffer(size)
|
||||||
|
|
||||||
def _do_ssl_handshake(self):
|
def _do_ssl_handshake(self):
|
||||||
# Based on code from test_ssl.py in the python stdlib
|
# Based on code from test_ssl.py in the python stdlib
|
||||||
try:
|
try:
|
||||||
|
@ -1498,53 +1549,6 @@ class PipeIOStream(BaseIOStream):
|
||||||
return chunk
|
return chunk
|
||||||
|
|
||||||
|
|
||||||
def _double_prefix(deque):
|
|
||||||
"""Grow by doubling, but don't split the second chunk just because the
|
|
||||||
first one is small.
|
|
||||||
"""
|
|
||||||
new_len = max(len(deque[0]) * 2,
|
|
||||||
(len(deque[0]) + len(deque[1])))
|
|
||||||
_merge_prefix(deque, new_len)
|
|
||||||
|
|
||||||
|
|
||||||
def _merge_prefix(deque, size):
|
|
||||||
"""Replace the first entries in a deque of strings with a single
|
|
||||||
string of up to size bytes.
|
|
||||||
|
|
||||||
>>> d = collections.deque(['abc', 'de', 'fghi', 'j'])
|
|
||||||
>>> _merge_prefix(d, 5); print(d)
|
|
||||||
deque(['abcde', 'fghi', 'j'])
|
|
||||||
|
|
||||||
Strings will be split as necessary to reach the desired size.
|
|
||||||
>>> _merge_prefix(d, 7); print(d)
|
|
||||||
deque(['abcdefg', 'hi', 'j'])
|
|
||||||
|
|
||||||
>>> _merge_prefix(d, 3); print(d)
|
|
||||||
deque(['abc', 'defg', 'hi', 'j'])
|
|
||||||
|
|
||||||
>>> _merge_prefix(d, 100); print(d)
|
|
||||||
deque(['abcdefghij'])
|
|
||||||
"""
|
|
||||||
if len(deque) == 1 and len(deque[0]) <= size:
|
|
||||||
return
|
|
||||||
prefix = []
|
|
||||||
remaining = size
|
|
||||||
while deque and remaining > 0:
|
|
||||||
chunk = deque.popleft()
|
|
||||||
if len(chunk) > remaining:
|
|
||||||
deque.appendleft(chunk[remaining:])
|
|
||||||
chunk = chunk[:remaining]
|
|
||||||
prefix.append(chunk)
|
|
||||||
remaining -= len(chunk)
|
|
||||||
# This data structure normally just contains byte strings, but
|
|
||||||
# the unittest gets messy if it doesn't use the default str() type,
|
|
||||||
# so do the merge based on the type of data that's actually present.
|
|
||||||
if prefix:
|
|
||||||
deque.appendleft(type(prefix[0])().join(prefix))
|
|
||||||
if not deque:
|
|
||||||
deque.appendleft(b"")
|
|
||||||
|
|
||||||
|
|
||||||
def doctests():
|
def doctests():
|
||||||
import doctest
|
import doctest
|
||||||
return doctest.DocTestSuite()
|
return doctest.DocTestSuite()
|
||||||
|
|
|
@ -19,7 +19,7 @@
|
||||||
To load a locale and generate a translated string::
|
To load a locale and generate a translated string::
|
||||||
|
|
||||||
user_locale = tornado.locale.get("es_LA")
|
user_locale = tornado.locale.get("es_LA")
|
||||||
print user_locale.translate("Sign out")
|
print(user_locale.translate("Sign out"))
|
||||||
|
|
||||||
`tornado.locale.get()` returns the closest matching locale, not necessarily the
|
`tornado.locale.get()` returns the closest matching locale, not necessarily the
|
||||||
specific locale you requested. You can support pluralization with
|
specific locale you requested. You can support pluralization with
|
||||||
|
@ -28,7 +28,7 @@ additional arguments to `~Locale.translate()`, e.g.::
|
||||||
people = [...]
|
people = [...]
|
||||||
message = user_locale.translate(
|
message = user_locale.translate(
|
||||||
"%(list)s is online", "%(list)s are online", len(people))
|
"%(list)s is online", "%(list)s are online", len(people))
|
||||||
print message % {"list": user_locale.list(people)}
|
print(message % {"list": user_locale.list(people)})
|
||||||
|
|
||||||
The first string is chosen if ``len(people) == 1``, otherwise the second
|
The first string is chosen if ``len(people) == 1``, otherwise the second
|
||||||
string is chosen.
|
string is chosen.
|
||||||
|
@ -39,7 +39,7 @@ supported by `gettext` and related tools). If neither method is called,
|
||||||
the `Locale.translate` method will simply return the original string.
|
the `Locale.translate` method will simply return the original string.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import codecs
|
import codecs
|
||||||
import csv
|
import csv
|
||||||
|
@ -51,12 +51,12 @@ import re
|
||||||
|
|
||||||
from tornado import escape
|
from tornado import escape
|
||||||
from tornado.log import gen_log
|
from tornado.log import gen_log
|
||||||
from tornado.util import u
|
from tornado.util import PY3
|
||||||
|
|
||||||
from tornado._locale_data import LOCALE_NAMES
|
from tornado._locale_data import LOCALE_NAMES
|
||||||
|
|
||||||
_default_locale = "en_US"
|
_default_locale = "en_US"
|
||||||
_translations = {}
|
_translations = {} # type: dict
|
||||||
_supported_locales = frozenset([_default_locale])
|
_supported_locales = frozenset([_default_locale])
|
||||||
_use_gettext = False
|
_use_gettext = False
|
||||||
CONTEXT_SEPARATOR = "\x04"
|
CONTEXT_SEPARATOR = "\x04"
|
||||||
|
@ -148,11 +148,11 @@ def load_translations(directory, encoding=None):
|
||||||
# in most cases but is common with CSV files because Excel
|
# in most cases but is common with CSV files because Excel
|
||||||
# cannot read utf-8 files without a BOM.
|
# cannot read utf-8 files without a BOM.
|
||||||
encoding = 'utf-8-sig'
|
encoding = 'utf-8-sig'
|
||||||
try:
|
if PY3:
|
||||||
# python 3: csv.reader requires a file open in text mode.
|
# python 3: csv.reader requires a file open in text mode.
|
||||||
# Force utf8 to avoid dependence on $LANG environment variable.
|
# Force utf8 to avoid dependence on $LANG environment variable.
|
||||||
f = open(full_path, "r", encoding=encoding)
|
f = open(full_path, "r", encoding=encoding)
|
||||||
except TypeError:
|
else:
|
||||||
# python 2: csv can only handle byte strings (in ascii-compatible
|
# python 2: csv can only handle byte strings (in ascii-compatible
|
||||||
# encodings), which we decode below. Transcode everything into
|
# encodings), which we decode below. Transcode everything into
|
||||||
# utf8 before passing it to csv.reader.
|
# utf8 before passing it to csv.reader.
|
||||||
|
@ -187,7 +187,7 @@ def load_gettext_translations(directory, domain):
|
||||||
|
|
||||||
{directory}/{lang}/LC_MESSAGES/{domain}.mo
|
{directory}/{lang}/LC_MESSAGES/{domain}.mo
|
||||||
|
|
||||||
Three steps are required to have you app translated:
|
Three steps are required to have your app translated:
|
||||||
|
|
||||||
1. Generate POT translation file::
|
1. Generate POT translation file::
|
||||||
|
|
||||||
|
@ -274,7 +274,7 @@ class Locale(object):
|
||||||
|
|
||||||
def __init__(self, code, translations):
|
def __init__(self, code, translations):
|
||||||
self.code = code
|
self.code = code
|
||||||
self.name = LOCALE_NAMES.get(code, {}).get("name", u("Unknown"))
|
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
|
||||||
self.rtl = False
|
self.rtl = False
|
||||||
for prefix in ["fa", "ar", "he"]:
|
for prefix in ["fa", "ar", "he"]:
|
||||||
if self.code.startswith(prefix):
|
if self.code.startswith(prefix):
|
||||||
|
@ -376,7 +376,7 @@ class Locale(object):
|
||||||
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
|
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
|
||||||
elif self.code == "zh_CN":
|
elif self.code == "zh_CN":
|
||||||
str_time = "%s%d:%02d" % (
|
str_time = "%s%d:%02d" % (
|
||||||
(u('\u4e0a\u5348'), u('\u4e0b\u5348'))[local_date.hour >= 12],
|
(u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
|
||||||
local_date.hour % 12 or 12, local_date.minute)
|
local_date.hour % 12 or 12, local_date.minute)
|
||||||
else:
|
else:
|
||||||
str_time = "%d:%02d %s" % (
|
str_time = "%d:%02d %s" % (
|
||||||
|
@ -422,7 +422,7 @@ class Locale(object):
|
||||||
return ""
|
return ""
|
||||||
if len(parts) == 1:
|
if len(parts) == 1:
|
||||||
return parts[0]
|
return parts[0]
|
||||||
comma = u(' \u0648 ') if self.code.startswith("fa") else u(", ")
|
comma = u' \u0648 ' if self.code.startswith("fa") else u", "
|
||||||
return _("%(commas)s and %(last)s") % {
|
return _("%(commas)s and %(last)s") % {
|
||||||
"commas": comma.join(parts[:-1]),
|
"commas": comma.join(parts[:-1]),
|
||||||
"last": parts[len(parts) - 1],
|
"last": parts[len(parts) - 1],
|
||||||
|
|
|
@ -12,15 +12,15 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
|
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
|
|
||||||
from tornado import gen, ioloop
|
from tornado import gen, ioloop
|
||||||
from tornado.concurrent import Future
|
from tornado.concurrent import Future
|
||||||
|
|
||||||
|
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
|
||||||
|
|
||||||
|
|
||||||
class _TimeoutGarbageCollector(object):
|
class _TimeoutGarbageCollector(object):
|
||||||
"""Base class for objects that periodically clean up timed-out waiters.
|
"""Base class for objects that periodically clean up timed-out waiters.
|
||||||
|
@ -465,7 +465,7 @@ class Lock(object):
|
||||||
...
|
...
|
||||||
... # Now the lock is released.
|
... # Now the lock is released.
|
||||||
|
|
||||||
.. versionchanged:: 3.5
|
.. versionchanged:: 4.3
|
||||||
Added ``async with`` support in Python 3.5.
|
Added ``async with`` support in Python 3.5.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -28,7 +28,7 @@ These streams may be configured independently using the standard library's
|
||||||
`logging` module. For example, you may wish to send ``tornado.access`` logs
|
`logging` module. For example, you may wish to send ``tornado.access`` logs
|
||||||
to a separate file for analysis.
|
to a separate file for analysis.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import logging.handlers
|
import logging.handlers
|
||||||
|
@ -38,7 +38,12 @@ from tornado.escape import _unicode
|
||||||
from tornado.util import unicode_type, basestring_type
|
from tornado.util import unicode_type, basestring_type
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import curses
|
import colorama
|
||||||
|
except ImportError:
|
||||||
|
colorama = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
import curses # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
curses = None
|
curses = None
|
||||||
|
|
||||||
|
@ -49,15 +54,21 @@ gen_log = logging.getLogger("tornado.general")
|
||||||
|
|
||||||
|
|
||||||
def _stderr_supports_color():
|
def _stderr_supports_color():
|
||||||
color = False
|
try:
|
||||||
if curses and hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
|
if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
|
||||||
try:
|
if curses:
|
||||||
curses.setupterm()
|
curses.setupterm()
|
||||||
if curses.tigetnum("colors") > 0:
|
if curses.tigetnum("colors") > 0:
|
||||||
color = True
|
return True
|
||||||
except Exception:
|
elif colorama:
|
||||||
pass
|
if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
|
||||||
return color
|
object()):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
# Very broad exception handling because it's always better to
|
||||||
|
# fall back to non-colored logs than to break at startup.
|
||||||
|
pass
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
def _safe_unicode(s):
|
def _safe_unicode(s):
|
||||||
|
@ -77,8 +88,19 @@ class LogFormatter(logging.Formatter):
|
||||||
* Robust against str/bytes encoding problems.
|
* Robust against str/bytes encoding problems.
|
||||||
|
|
||||||
This formatter is enabled automatically by
|
This formatter is enabled automatically by
|
||||||
`tornado.options.parse_command_line` (unless ``--logging=none`` is
|
`tornado.options.parse_command_line` or `tornado.options.parse_config_file`
|
||||||
used).
|
(unless ``--logging=none`` is used).
|
||||||
|
|
||||||
|
Color support on Windows versions that do not support ANSI color codes is
|
||||||
|
enabled by use of the colorama__ library. Applications that wish to use
|
||||||
|
this must first initialize colorama with a call to ``colorama.init``.
|
||||||
|
See the colorama documentation for details.
|
||||||
|
|
||||||
|
__ https://pypi.python.org/pypi/colorama
|
||||||
|
|
||||||
|
.. versionchanged:: 4.5
|
||||||
|
Added support for ``colorama``. Changed the constructor
|
||||||
|
signature to be compatible with `logging.config.dictConfig`.
|
||||||
"""
|
"""
|
||||||
DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
|
DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
|
||||||
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
|
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
|
||||||
|
@ -89,8 +111,8 @@ class LogFormatter(logging.Formatter):
|
||||||
logging.ERROR: 1, # Red
|
logging.ERROR: 1, # Red
|
||||||
}
|
}
|
||||||
|
|
||||||
def __init__(self, color=True, fmt=DEFAULT_FORMAT,
|
def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
|
||||||
datefmt=DEFAULT_DATE_FORMAT, colors=DEFAULT_COLORS):
|
style='%', color=True, colors=DEFAULT_COLORS):
|
||||||
r"""
|
r"""
|
||||||
:arg bool color: Enables color support.
|
:arg bool color: Enables color support.
|
||||||
:arg string fmt: Log message format.
|
:arg string fmt: Log message format.
|
||||||
|
@ -111,21 +133,28 @@ class LogFormatter(logging.Formatter):
|
||||||
|
|
||||||
self._colors = {}
|
self._colors = {}
|
||||||
if color and _stderr_supports_color():
|
if color and _stderr_supports_color():
|
||||||
# The curses module has some str/bytes confusion in
|
if curses is not None:
|
||||||
# python3. Until version 3.2.3, most methods return
|
# The curses module has some str/bytes confusion in
|
||||||
# bytes, but only accept strings. In addition, we want to
|
# python3. Until version 3.2.3, most methods return
|
||||||
# output these strings with the logging module, which
|
# bytes, but only accept strings. In addition, we want to
|
||||||
# works with unicode strings. The explicit calls to
|
# output these strings with the logging module, which
|
||||||
# unicode() below are harmless in python2 but will do the
|
# works with unicode strings. The explicit calls to
|
||||||
# right conversion in python 3.
|
# unicode() below are harmless in python2 but will do the
|
||||||
fg_color = (curses.tigetstr("setaf") or
|
# right conversion in python 3.
|
||||||
curses.tigetstr("setf") or "")
|
fg_color = (curses.tigetstr("setaf") or
|
||||||
if (3, 0) < sys.version_info < (3, 2, 3):
|
curses.tigetstr("setf") or "")
|
||||||
fg_color = unicode_type(fg_color, "ascii")
|
if (3, 0) < sys.version_info < (3, 2, 3):
|
||||||
|
fg_color = unicode_type(fg_color, "ascii")
|
||||||
|
|
||||||
for levelno, code in colors.items():
|
for levelno, code in colors.items():
|
||||||
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
|
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
|
||||||
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
|
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
|
||||||
|
else:
|
||||||
|
# If curses is not present (currently we'll only get here for
|
||||||
|
# colorama on windows), assume hard-coded ANSI color codes.
|
||||||
|
for levelno, code in colors.items():
|
||||||
|
self._colors[levelno] = '\033[2;3%dm' % code
|
||||||
|
self._normal = '\033[0m'
|
||||||
else:
|
else:
|
||||||
self._normal = ''
|
self._normal = ''
|
||||||
|
|
||||||
|
@ -183,7 +212,8 @@ def enable_pretty_logging(options=None, logger=None):
|
||||||
and `tornado.options.parse_config_file`.
|
and `tornado.options.parse_config_file`.
|
||||||
"""
|
"""
|
||||||
if options is None:
|
if options is None:
|
||||||
from tornado.options import options
|
import tornado.options
|
||||||
|
options = tornado.options.options
|
||||||
if options.logging is None or options.logging.lower() == 'none':
|
if options.logging is None or options.logging.lower() == 'none':
|
||||||
return
|
return
|
||||||
if logger is None:
|
if logger is None:
|
||||||
|
@ -228,7 +258,8 @@ def define_logging_options(options=None):
|
||||||
"""
|
"""
|
||||||
if options is None:
|
if options is None:
|
||||||
# late import to prevent cycle
|
# late import to prevent cycle
|
||||||
from tornado.options import options
|
import tornado.options
|
||||||
|
options = tornado.options.options
|
||||||
options.define("logging", default="info",
|
options.define("logging", default="info",
|
||||||
help=("Set the Python log level. If 'none', tornado won't touch the "
|
help=("Set the Python log level. If 'none', tornado won't touch the "
|
||||||
"logging configuration."),
|
"logging configuration."),
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
"""Miscellaneous network utility code."""
|
"""Miscellaneous network utility code."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
|
@ -27,7 +27,7 @@ import stat
|
||||||
from tornado.concurrent import dummy_executor, run_on_executor
|
from tornado.concurrent import dummy_executor, run_on_executor
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado.platform.auto import set_close_exec
|
from tornado.platform.auto import set_close_exec
|
||||||
from tornado.util import u, Configurable, errno_from_exception
|
from tornado.util import PY3, Configurable, errno_from_exception
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
|
@ -44,20 +44,18 @@ except ImportError:
|
||||||
else:
|
else:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
xrange # py2
|
xrange = range
|
||||||
except NameError:
|
|
||||||
xrange = range # py3
|
|
||||||
|
|
||||||
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
|
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
|
||||||
ssl_match_hostname = ssl.match_hostname
|
ssl_match_hostname = ssl.match_hostname
|
||||||
SSLCertificateError = ssl.CertificateError
|
SSLCertificateError = ssl.CertificateError
|
||||||
elif ssl is None:
|
elif ssl is None:
|
||||||
ssl_match_hostname = SSLCertificateError = None
|
ssl_match_hostname = SSLCertificateError = None # type: ignore
|
||||||
else:
|
else:
|
||||||
import backports.ssl_match_hostname
|
import backports.ssl_match_hostname
|
||||||
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
|
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
|
||||||
SSLCertificateError = backports.ssl_match_hostname.CertificateError
|
SSLCertificateError = backports.ssl_match_hostname.CertificateError # type: ignore
|
||||||
|
|
||||||
if hasattr(ssl, 'SSLContext'):
|
if hasattr(ssl, 'SSLContext'):
|
||||||
if hasattr(ssl, 'create_default_context'):
|
if hasattr(ssl, 'create_default_context'):
|
||||||
|
@ -96,7 +94,10 @@ else:
|
||||||
# module-import time, the import lock is already held by the main thread,
|
# module-import time, the import lock is already held by the main thread,
|
||||||
# leading to deadlock. Avoid it by caching the idna encoder on the main
|
# leading to deadlock. Avoid it by caching the idna encoder on the main
|
||||||
# thread now.
|
# thread now.
|
||||||
u('foo').encode('idna')
|
u'foo'.encode('idna')
|
||||||
|
|
||||||
|
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
|
||||||
|
u'foo'.encode('latin1')
|
||||||
|
|
||||||
# These errnos indicate that a non-blocking operation must be retried
|
# These errnos indicate that a non-blocking operation must be retried
|
||||||
# at a later time. On most platforms they're the same value, but on
|
# at a later time. On most platforms they're the same value, but on
|
||||||
|
@ -104,7 +105,7 @@ u('foo').encode('idna')
|
||||||
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
|
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
|
||||||
|
|
||||||
if hasattr(errno, "WSAEWOULDBLOCK"):
|
if hasattr(errno, "WSAEWOULDBLOCK"):
|
||||||
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
|
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
|
||||||
|
|
||||||
# Default backlog used when calling sock.listen()
|
# Default backlog used when calling sock.listen()
|
||||||
_DEFAULT_BACKLOG = 128
|
_DEFAULT_BACKLOG = 128
|
||||||
|
@ -131,7 +132,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
|
||||||
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
|
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
|
||||||
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
|
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
|
||||||
|
|
||||||
``resuse_port`` option sets ``SO_REUSEPORT`` option for every socket
|
``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
|
||||||
in the list. If your platform doesn't support this option ValueError will
|
in the list. If your platform doesn't support this option ValueError will
|
||||||
be raised.
|
be raised.
|
||||||
"""
|
"""
|
||||||
|
@ -199,6 +200,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
|
||||||
sockets.append(sock)
|
sockets.append(sock)
|
||||||
return sockets
|
return sockets
|
||||||
|
|
||||||
|
|
||||||
if hasattr(socket, 'AF_UNIX'):
|
if hasattr(socket, 'AF_UNIX'):
|
||||||
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
|
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
|
||||||
"""Creates a listening unix socket.
|
"""Creates a listening unix socket.
|
||||||
|
@ -334,6 +336,11 @@ class Resolver(Configurable):
|
||||||
port)`` pair for IPv4; additional fields may be present for
|
port)`` pair for IPv4; additional fields may be present for
|
||||||
IPv6). If a ``callback`` is passed, it will be run with the
|
IPv6). If a ``callback`` is passed, it will be run with the
|
||||||
result as an argument when it is complete.
|
result as an argument when it is complete.
|
||||||
|
|
||||||
|
:raises IOError: if the address cannot be resolved.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.4
|
||||||
|
Standardized all implementations to raise `IOError`.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
@ -413,8 +420,8 @@ class ThreadedResolver(ExecutorResolver):
|
||||||
All ``ThreadedResolvers`` share a single thread pool, whose
|
All ``ThreadedResolvers`` share a single thread pool, whose
|
||||||
size is set by the first one to be created.
|
size is set by the first one to be created.
|
||||||
"""
|
"""
|
||||||
_threadpool = None
|
_threadpool = None # type: ignore
|
||||||
_threadpool_pid = None
|
_threadpool_pid = None # type: int
|
||||||
|
|
||||||
def initialize(self, io_loop=None, num_threads=10):
|
def initialize(self, io_loop=None, num_threads=10):
|
||||||
threadpool = ThreadedResolver._create_threadpool(num_threads)
|
threadpool = ThreadedResolver._create_threadpool(num_threads)
|
||||||
|
@ -518,4 +525,4 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
|
||||||
else:
|
else:
|
||||||
return context.wrap_socket(socket, **kwargs)
|
return context.wrap_socket(socket, **kwargs)
|
||||||
else:
|
else:
|
||||||
return ssl.wrap_socket(socket, **dict(context, **kwargs))
|
return ssl.wrap_socket(socket, **dict(context, **kwargs)) # type: ignore
|
||||||
|
|
|
@ -41,6 +41,12 @@ either::
|
||||||
# or
|
# or
|
||||||
tornado.options.parse_config_file("/etc/server.conf")
|
tornado.options.parse_config_file("/etc/server.conf")
|
||||||
|
|
||||||
|
.. note:
|
||||||
|
|
||||||
|
When using tornado.options.parse_command_line or
|
||||||
|
tornado.options.parse_config_file, the only options that are set are
|
||||||
|
ones that were previously defined with tornado.options.define.
|
||||||
|
|
||||||
Command line formats are what you would expect (``--myoption=myvalue``).
|
Command line formats are what you would expect (``--myoption=myvalue``).
|
||||||
Config files are just Python files. Global names become options, e.g.::
|
Config files are just Python files. Global names become options, e.g.::
|
||||||
|
|
||||||
|
@ -76,7 +82,7 @@ instances to define isolated sets of options, such as for subcommands.
|
||||||
underscores.
|
underscores.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import numbers
|
import numbers
|
||||||
|
@ -132,8 +138,10 @@ class OptionParser(object):
|
||||||
return name in self._options
|
return name in self._options
|
||||||
|
|
||||||
def __getitem__(self, name):
|
def __getitem__(self, name):
|
||||||
name = self._normalize_name(name)
|
return self.__getattr__(name)
|
||||||
return self._options[name].value()
|
|
||||||
|
def __setitem__(self, name, value):
|
||||||
|
return self.__setattr__(name, value)
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
"""A sequence of (name, value) pairs.
|
"""A sequence of (name, value) pairs.
|
||||||
|
@ -300,8 +308,12 @@ class OptionParser(object):
|
||||||
.. versionchanged:: 4.1
|
.. versionchanged:: 4.1
|
||||||
Config files are now always interpreted as utf-8 instead of
|
Config files are now always interpreted as utf-8 instead of
|
||||||
the system default encoding.
|
the system default encoding.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.4
|
||||||
|
The special variable ``__file__`` is available inside config
|
||||||
|
files, specifying the absolute path to the config file itself.
|
||||||
"""
|
"""
|
||||||
config = {}
|
config = {'__file__': os.path.abspath(path)}
|
||||||
with open(path, 'rb') as f:
|
with open(path, 'rb') as f:
|
||||||
exec_in(native_str(f.read()), config, config)
|
exec_in(native_str(f.read()), config, config)
|
||||||
for name in config:
|
for name in config:
|
||||||
|
|
|
@ -14,12 +14,12 @@ loops.
|
||||||
|
|
||||||
.. note::
|
.. note::
|
||||||
|
|
||||||
Tornado requires the `~asyncio.BaseEventLoop.add_reader` family of methods,
|
Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
|
||||||
so it is not compatible with the `~asyncio.ProactorEventLoop` on Windows.
|
methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
|
||||||
Use the `~asyncio.SelectorEventLoop` instead.
|
Windows. Use the `~asyncio.SelectorEventLoop` instead.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
import functools
|
import functools
|
||||||
|
|
||||||
import tornado.concurrent
|
import tornado.concurrent
|
||||||
|
@ -30,11 +30,11 @@ from tornado import stack_context
|
||||||
try:
|
try:
|
||||||
# Import the real asyncio module for py33+ first. Older versions of the
|
# Import the real asyncio module for py33+ first. Older versions of the
|
||||||
# trollius backport also use this name.
|
# trollius backport also use this name.
|
||||||
import asyncio
|
import asyncio # type: ignore
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
# Asyncio itself isn't available; see if trollius is (backport to py26+).
|
# Asyncio itself isn't available; see if trollius is (backport to py26+).
|
||||||
try:
|
try:
|
||||||
import trollius as asyncio
|
import trollius as asyncio # type: ignore
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Re-raise the original asyncio error, not the trollius one.
|
# Re-raise the original asyncio error, not the trollius one.
|
||||||
raise e
|
raise e
|
||||||
|
@ -141,6 +141,8 @@ class BaseAsyncIOLoop(IOLoop):
|
||||||
|
|
||||||
def add_callback(self, callback, *args, **kwargs):
|
def add_callback(self, callback, *args, **kwargs):
|
||||||
if self.closing:
|
if self.closing:
|
||||||
|
# TODO: this is racy; we need a lock to ensure that the
|
||||||
|
# loop isn't closed during call_soon_threadsafe.
|
||||||
raise RuntimeError("IOLoop is closing")
|
raise RuntimeError("IOLoop is closing")
|
||||||
self.asyncio_loop.call_soon_threadsafe(
|
self.asyncio_loop.call_soon_threadsafe(
|
||||||
self._run_callback,
|
self._run_callback,
|
||||||
|
@ -158,6 +160,9 @@ class AsyncIOMainLoop(BaseAsyncIOLoop):
|
||||||
import asyncio
|
import asyncio
|
||||||
AsyncIOMainLoop().install()
|
AsyncIOMainLoop().install()
|
||||||
asyncio.get_event_loop().run_forever()
|
asyncio.get_event_loop().run_forever()
|
||||||
|
|
||||||
|
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
|
||||||
|
installing alternative IOLoops.
|
||||||
"""
|
"""
|
||||||
def initialize(self, **kwargs):
|
def initialize(self, **kwargs):
|
||||||
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
|
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
|
||||||
|
@ -212,5 +217,6 @@ def to_asyncio_future(tornado_future):
|
||||||
tornado.concurrent.chain_future(tornado_future, af)
|
tornado.concurrent.chain_future(tornado_future, af)
|
||||||
return af
|
return af
|
||||||
|
|
||||||
|
|
||||||
if hasattr(convert_yielded, 'register'):
|
if hasattr(convert_yielded, 'register'):
|
||||||
convert_yielded.register(asyncio.Future, to_tornado_future)
|
convert_yielded.register(asyncio.Future, to_tornado_future) # type: ignore
|
||||||
|
|
|
@ -23,7 +23,7 @@ Most code that needs access to this functionality should do e.g.::
|
||||||
from tornado.platform.auto import set_close_exec
|
from tornado.platform.auto import set_close_exec
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import os
|
import os
|
||||||
|
|
||||||
|
@ -47,8 +47,13 @@ try:
|
||||||
except ImportError:
|
except ImportError:
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
from time import monotonic as monotonic_time
|
# monotonic can provide a monotonic function in versions of python before
|
||||||
|
# 3.3, too.
|
||||||
|
from monotonic import monotonic as monotonic_time
|
||||||
except ImportError:
|
except ImportError:
|
||||||
monotonic_time = None
|
try:
|
||||||
|
from time import monotonic as monotonic_time
|
||||||
|
except ImportError:
|
||||||
|
monotonic_time = None
|
||||||
|
|
||||||
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
|
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
import pycares
|
import pycares # type: ignore
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
|
@ -61,8 +61,8 @@ class CaresResolver(Resolver):
|
||||||
assert not callback_args.kwargs
|
assert not callback_args.kwargs
|
||||||
result, error = callback_args.args
|
result, error = callback_args.args
|
||||||
if error:
|
if error:
|
||||||
raise Exception('C-Ares returned error %s: %s while resolving %s' %
|
raise IOError('C-Ares returned error %s: %s while resolving %s' %
|
||||||
(error, pycares.errno.strerror(error), host))
|
(error, pycares.errno.strerror(error), host))
|
||||||
addresses = result.addresses
|
addresses = result.addresses
|
||||||
addrinfo = []
|
addrinfo = []
|
||||||
for address in addresses:
|
for address in addresses:
|
||||||
|
@ -73,7 +73,7 @@ class CaresResolver(Resolver):
|
||||||
else:
|
else:
|
||||||
address_family = socket.AF_UNSPEC
|
address_family = socket.AF_UNSPEC
|
||||||
if family != socket.AF_UNSPEC and family != address_family:
|
if family != socket.AF_UNSPEC and family != address_family:
|
||||||
raise Exception('Requested socket family %d but got %d' %
|
raise IOError('Requested socket family %d but got %d' %
|
||||||
(family, address_family))
|
(family, address_family))
|
||||||
addrinfo.append((address_family, (address, port)))
|
addrinfo.append((address_family, (address, port)))
|
||||||
raise gen.Return(addrinfo)
|
raise gen.Return(addrinfo)
|
||||||
|
|
|
@ -1,10 +1,27 @@
|
||||||
"""Lowest-common-denominator implementations of platform functionality."""
|
"""Lowest-common-denominator implementations of platform functionality."""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
import socket
|
import socket
|
||||||
|
import time
|
||||||
|
|
||||||
from tornado.platform import interface
|
from tornado.platform import interface
|
||||||
|
from tornado.util import errno_from_exception
|
||||||
|
|
||||||
|
|
||||||
|
def try_close(f):
|
||||||
|
# Avoid issue #875 (race condition when using the file in another
|
||||||
|
# thread).
|
||||||
|
for i in range(10):
|
||||||
|
try:
|
||||||
|
f.close()
|
||||||
|
except IOError:
|
||||||
|
# Yield to another thread
|
||||||
|
time.sleep(1e-3)
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
# Try a last time and let raise
|
||||||
|
f.close()
|
||||||
|
|
||||||
|
|
||||||
class Waker(interface.Waker):
|
class Waker(interface.Waker):
|
||||||
|
@ -45,7 +62,7 @@ class Waker(interface.Waker):
|
||||||
break # success
|
break # success
|
||||||
except socket.error as detail:
|
except socket.error as detail:
|
||||||
if (not hasattr(errno, 'WSAEADDRINUSE') or
|
if (not hasattr(errno, 'WSAEADDRINUSE') or
|
||||||
detail[0] != errno.WSAEADDRINUSE):
|
errno_from_exception(detail) != errno.WSAEADDRINUSE):
|
||||||
# "Address already in use" is the only error
|
# "Address already in use" is the only error
|
||||||
# I've seen on two WinXP Pro SP2 boxes, under
|
# I've seen on two WinXP Pro SP2 boxes, under
|
||||||
# Pythons 2.3.5 and 2.4.1.
|
# Pythons 2.3.5 and 2.4.1.
|
||||||
|
@ -75,7 +92,7 @@ class Waker(interface.Waker):
|
||||||
def wake(self):
|
def wake(self):
|
||||||
try:
|
try:
|
||||||
self.writer.send(b"x")
|
self.writer.send(b"x")
|
||||||
except (IOError, socket.error):
|
except (IOError, socket.error, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def consume(self):
|
def consume(self):
|
||||||
|
@ -89,4 +106,4 @@ class Waker(interface.Waker):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.reader.close()
|
self.reader.close()
|
||||||
self.writer.close()
|
try_close(self.writer)
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
"""EPoll-based IOLoop implementation for Linux systems."""
|
"""EPoll-based IOLoop implementation for Linux systems."""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import select
|
import select
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ for other tornado.platform modules. Most code should import the appropriate
|
||||||
implementation from `tornado.platform.auto`.
|
implementation from `tornado.platform.auto`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
|
||||||
def set_close_exec(fd):
|
def set_close_exec(fd):
|
||||||
|
@ -61,3 +61,7 @@ class Waker(object):
|
||||||
def close(self):
|
def close(self):
|
||||||
"""Closes the waker's file descriptor(s)."""
|
"""Closes the waker's file descriptor(s)."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
def monotonic_time():
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
|
@ -14,7 +14,7 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
|
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import select
|
import select
|
||||||
|
|
||||||
|
|
|
@ -16,12 +16,12 @@
|
||||||
|
|
||||||
"""Posix implementations of platform-specific functionality."""
|
"""Posix implementations of platform-specific functionality."""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import fcntl
|
import fcntl
|
||||||
import os
|
import os
|
||||||
|
|
||||||
from tornado.platform import interface
|
from tornado.platform import common, interface
|
||||||
|
|
||||||
|
|
||||||
def set_close_exec(fd):
|
def set_close_exec(fd):
|
||||||
|
@ -53,7 +53,7 @@ class Waker(interface.Waker):
|
||||||
def wake(self):
|
def wake(self):
|
||||||
try:
|
try:
|
||||||
self.writer.write(b"x")
|
self.writer.write(b"x")
|
||||||
except IOError:
|
except (IOError, ValueError):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def consume(self):
|
def consume(self):
|
||||||
|
@ -67,4 +67,4 @@ class Waker(interface.Waker):
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.reader.close()
|
self.reader.close()
|
||||||
self.writer.close()
|
common.try_close(self.writer)
|
||||||
|
|
|
@ -17,7 +17,7 @@
|
||||||
|
|
||||||
Used as a fallback for systems that don't support epoll or kqueue.
|
Used as a fallback for systems that don't support epoll or kqueue.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import select
|
import select
|
||||||
|
|
||||||
|
|
|
@ -21,7 +21,7 @@ depending on which library's underlying event loop you want to use.
|
||||||
This module has been tested with Twisted versions 11.0.0 and newer.
|
This module has been tested with Twisted versions 11.0.0 and newer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import functools
|
import functools
|
||||||
|
@ -29,19 +29,18 @@ import numbers
|
||||||
import socket
|
import socket
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
import twisted.internet.abstract
|
import twisted.internet.abstract # type: ignore
|
||||||
from twisted.internet.defer import Deferred
|
from twisted.internet.defer import Deferred # type: ignore
|
||||||
from twisted.internet.posixbase import PosixReactorBase
|
from twisted.internet.posixbase import PosixReactorBase # type: ignore
|
||||||
from twisted.internet.interfaces import \
|
from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore
|
||||||
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
|
from twisted.python import failure, log # type: ignore
|
||||||
from twisted.python import failure, log
|
from twisted.internet import error # type: ignore
|
||||||
from twisted.internet import error
|
import twisted.names.cache # type: ignore
|
||||||
import twisted.names.cache
|
import twisted.names.client # type: ignore
|
||||||
import twisted.names.client
|
import twisted.names.hosts # type: ignore
|
||||||
import twisted.names.hosts
|
import twisted.names.resolve # type: ignore
|
||||||
import twisted.names.resolve
|
|
||||||
|
|
||||||
from zope.interface import implementer
|
from zope.interface import implementer # type: ignore
|
||||||
|
|
||||||
from tornado.concurrent import Future
|
from tornado.concurrent import Future
|
||||||
from tornado.escape import utf8
|
from tornado.escape import utf8
|
||||||
|
@ -354,7 +353,7 @@ def install(io_loop=None):
|
||||||
if not io_loop:
|
if not io_loop:
|
||||||
io_loop = tornado.ioloop.IOLoop.current()
|
io_loop = tornado.ioloop.IOLoop.current()
|
||||||
reactor = TornadoReactor(io_loop)
|
reactor = TornadoReactor(io_loop)
|
||||||
from twisted.internet.main import installReactor
|
from twisted.internet.main import installReactor # type: ignore
|
||||||
installReactor(reactor)
|
installReactor(reactor)
|
||||||
return reactor
|
return reactor
|
||||||
|
|
||||||
|
@ -408,11 +407,14 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
|
||||||
Not compatible with `tornado.process.Subprocess.set_exit_callback`
|
Not compatible with `tornado.process.Subprocess.set_exit_callback`
|
||||||
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
|
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
|
||||||
with each other.
|
with each other.
|
||||||
|
|
||||||
|
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
|
||||||
|
installing alternative IOLoops.
|
||||||
"""
|
"""
|
||||||
def initialize(self, reactor=None, **kwargs):
|
def initialize(self, reactor=None, **kwargs):
|
||||||
super(TwistedIOLoop, self).initialize(**kwargs)
|
super(TwistedIOLoop, self).initialize(**kwargs)
|
||||||
if reactor is None:
|
if reactor is None:
|
||||||
import twisted.internet.reactor
|
import twisted.internet.reactor # type: ignore
|
||||||
reactor = twisted.internet.reactor
|
reactor = twisted.internet.reactor
|
||||||
self.reactor = reactor
|
self.reactor = reactor
|
||||||
self.fds = {}
|
self.fds = {}
|
||||||
|
@ -554,7 +556,10 @@ class TwistedResolver(Resolver):
|
||||||
deferred = self.resolver.getHostByName(utf8(host))
|
deferred = self.resolver.getHostByName(utf8(host))
|
||||||
resolved = yield gen.Task(deferred.addBoth)
|
resolved = yield gen.Task(deferred.addBoth)
|
||||||
if isinstance(resolved, failure.Failure):
|
if isinstance(resolved, failure.Failure):
|
||||||
resolved.raiseException()
|
try:
|
||||||
|
resolved.raiseException()
|
||||||
|
except twisted.names.error.DomainError as e:
|
||||||
|
raise IOError(e)
|
||||||
elif twisted.internet.abstract.isIPAddress(resolved):
|
elif twisted.internet.abstract.isIPAddress(resolved):
|
||||||
resolved_family = socket.AF_INET
|
resolved_family = socket.AF_INET
|
||||||
elif twisted.internet.abstract.isIPv6Address(resolved):
|
elif twisted.internet.abstract.isIPv6Address(resolved):
|
||||||
|
@ -569,8 +574,9 @@ class TwistedResolver(Resolver):
|
||||||
]
|
]
|
||||||
raise gen.Return(result)
|
raise gen.Return(result)
|
||||||
|
|
||||||
|
|
||||||
if hasattr(gen.convert_yielded, 'register'):
|
if hasattr(gen.convert_yielded, 'register'):
|
||||||
@gen.convert_yielded.register(Deferred)
|
@gen.convert_yielded.register(Deferred) # type: ignore
|
||||||
def _(d):
|
def _(d):
|
||||||
f = Future()
|
f = Future()
|
||||||
|
|
||||||
|
|
|
@ -2,9 +2,9 @@
|
||||||
# for production use.
|
# for production use.
|
||||||
|
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
import ctypes
|
import ctypes # type: ignore
|
||||||
import ctypes.wintypes
|
import ctypes.wintypes # type: ignore
|
||||||
|
|
||||||
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
|
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
|
||||||
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
|
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
|
||||||
|
@ -17,4 +17,4 @@ HANDLE_FLAG_INHERIT = 0x00000001
|
||||||
def set_close_exec(fd):
|
def set_close_exec(fd):
|
||||||
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
|
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
|
||||||
if not success:
|
if not success:
|
||||||
raise ctypes.GetLastError()
|
raise ctypes.WinError()
|
||||||
|
|
|
@ -18,7 +18,7 @@
|
||||||
the server into multiple processes and managing subprocesses.
|
the server into multiple processes and managing subprocesses.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
|
@ -35,7 +35,7 @@ from tornado.iostream import PipeIOStream
|
||||||
from tornado.log import gen_log
|
from tornado.log import gen_log
|
||||||
from tornado.platform.auto import set_close_exec
|
from tornado.platform.auto import set_close_exec
|
||||||
from tornado import stack_context
|
from tornado import stack_context
|
||||||
from tornado.util import errno_from_exception
|
from tornado.util import errno_from_exception, PY3
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
|
@ -43,11 +43,8 @@ except ImportError:
|
||||||
# Multiprocessing is not available on Google App Engine.
|
# Multiprocessing is not available on Google App Engine.
|
||||||
multiprocessing = None
|
multiprocessing = None
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
long # py2
|
long = int
|
||||||
except NameError:
|
|
||||||
long = int # py3
|
|
||||||
|
|
||||||
|
|
||||||
# Re-export this exception for convenience.
|
# Re-export this exception for convenience.
|
||||||
try:
|
try:
|
||||||
|
@ -70,7 +67,7 @@ def cpu_count():
|
||||||
pass
|
pass
|
||||||
try:
|
try:
|
||||||
return os.sysconf("SC_NPROCESSORS_CONF")
|
return os.sysconf("SC_NPROCESSORS_CONF")
|
||||||
except ValueError:
|
except (AttributeError, ValueError):
|
||||||
pass
|
pass
|
||||||
gen_log.error("Could not detect number of processors; assuming 1")
|
gen_log.error("Could not detect number of processors; assuming 1")
|
||||||
return 1
|
return 1
|
||||||
|
@ -147,6 +144,7 @@ def fork_processes(num_processes, max_restarts=100):
|
||||||
else:
|
else:
|
||||||
children[pid] = i
|
children[pid] = i
|
||||||
return None
|
return None
|
||||||
|
|
||||||
for i in range(num_processes):
|
for i in range(num_processes):
|
||||||
id = start_child(i)
|
id = start_child(i)
|
||||||
if id is not None:
|
if id is not None:
|
||||||
|
@ -204,13 +202,19 @@ class Subprocess(object):
|
||||||
attribute of the resulting Subprocess a `.PipeIOStream`.
|
attribute of the resulting Subprocess a `.PipeIOStream`.
|
||||||
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
|
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
|
||||||
|
|
||||||
|
The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
|
||||||
|
``wait_for_exit`` methods do not work on Windows. There is
|
||||||
|
therefore no reason to use this class instead of
|
||||||
|
``subprocess.Popen`` on that platform.
|
||||||
|
|
||||||
.. versionchanged:: 4.1
|
.. versionchanged:: 4.1
|
||||||
The ``io_loop`` argument is deprecated.
|
The ``io_loop`` argument is deprecated.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
STREAM = object()
|
STREAM = object()
|
||||||
|
|
||||||
_initialized = False
|
_initialized = False
|
||||||
_waiting = {}
|
_waiting = {} # type: ignore
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
|
self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
|
||||||
|
@ -351,6 +355,10 @@ class Subprocess(object):
|
||||||
else:
|
else:
|
||||||
assert os.WIFEXITED(status)
|
assert os.WIFEXITED(status)
|
||||||
self.returncode = os.WEXITSTATUS(status)
|
self.returncode = os.WEXITSTATUS(status)
|
||||||
|
# We've taken over wait() duty from the subprocess.Popen
|
||||||
|
# object. If we don't inform it of the process's return code,
|
||||||
|
# it will log a warning at destruction in python 3.6+.
|
||||||
|
self.proc.returncode = self.returncode
|
||||||
if self._exit_callback:
|
if self._exit_callback:
|
||||||
callback = self._exit_callback
|
callback = self._exit_callback
|
||||||
self._exit_callback = None
|
self._exit_callback = None
|
||||||
|
|
|
@ -12,9 +12,17 @@
|
||||||
# License for the specific language governing permissions and limitations
|
# License for the specific language governing permissions and limitations
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
"""Asynchronous queues for coroutines.
|
||||||
|
|
||||||
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
|
.. warning::
|
||||||
|
|
||||||
|
Unlike the standard library's `queue` module, the classes defined here
|
||||||
|
are *not* thread-safe. To use these queues from another thread,
|
||||||
|
use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
|
||||||
|
before calling any queue methods.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import collections
|
import collections
|
||||||
import heapq
|
import heapq
|
||||||
|
@ -23,6 +31,8 @@ from tornado import gen, ioloop
|
||||||
from tornado.concurrent import Future
|
from tornado.concurrent import Future
|
||||||
from tornado.locks import Event
|
from tornado.locks import Event
|
||||||
|
|
||||||
|
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
|
||||||
|
|
||||||
|
|
||||||
class QueueEmpty(Exception):
|
class QueueEmpty(Exception):
|
||||||
"""Raised by `.Queue.get_nowait` when the queue has no items."""
|
"""Raised by `.Queue.get_nowait` when the queue has no items."""
|
||||||
|
|
|
@ -0,0 +1,625 @@
|
||||||
|
# Copyright 2015 The Tornado Authors
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License"); you may
|
||||||
|
# not use this file except in compliance with the License. You may obtain
|
||||||
|
# a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
|
||||||
|
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
|
||||||
|
# License for the specific language governing permissions and limitations
|
||||||
|
# under the License.
|
||||||
|
|
||||||
|
"""Flexible routing implementation.
|
||||||
|
|
||||||
|
Tornado routes HTTP requests to appropriate handlers using `Router`
|
||||||
|
class implementations. The `tornado.web.Application` class is a
|
||||||
|
`Router` implementation and may be used directly, or the classes in
|
||||||
|
this module may be used for additional flexibility. The `RuleRouter`
|
||||||
|
class can match on more criteria than `.Application`, or the `Router`
|
||||||
|
interface can be subclassed for maximum customization.
|
||||||
|
|
||||||
|
`Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
|
||||||
|
to provide additional routing capabilities. This also means that any
|
||||||
|
`Router` implementation can be used directly as a ``request_callback``
|
||||||
|
for `~.httpserver.HTTPServer` constructor.
|
||||||
|
|
||||||
|
`Router` subclass must implement a ``find_handler`` method to provide
|
||||||
|
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
|
||||||
|
request:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
class CustomRouter(Router):
|
||||||
|
def find_handler(self, request, **kwargs):
|
||||||
|
# some routing logic providing a suitable HTTPMessageDelegate instance
|
||||||
|
return MessageDelegate(request.connection)
|
||||||
|
|
||||||
|
class MessageDelegate(HTTPMessageDelegate):
|
||||||
|
def __init__(self, connection):
|
||||||
|
self.connection = connection
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self.connection.write_headers(
|
||||||
|
ResponseStartLine("HTTP/1.1", 200, "OK"),
|
||||||
|
HTTPHeaders({"Content-Length": "2"}),
|
||||||
|
b"OK")
|
||||||
|
self.connection.finish()
|
||||||
|
|
||||||
|
router = CustomRouter()
|
||||||
|
server = HTTPServer(router)
|
||||||
|
|
||||||
|
The main responsibility of `Router` implementation is to provide a
|
||||||
|
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
|
||||||
|
that will handle this request. In the example above we can see that
|
||||||
|
routing is possible even without instantiating an `~.web.Application`.
|
||||||
|
|
||||||
|
For routing to `~.web.RequestHandler` implementations we need an
|
||||||
|
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
|
||||||
|
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
|
||||||
|
for a given request and `~.web.RequestHandler`.
|
||||||
|
|
||||||
|
Here is a simple example of how we can we route to
|
||||||
|
`~.web.RequestHandler` subclasses by HTTP method:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
resources = {}
|
||||||
|
|
||||||
|
class GetResource(RequestHandler):
|
||||||
|
def get(self, path):
|
||||||
|
if path not in resources:
|
||||||
|
raise HTTPError(404)
|
||||||
|
|
||||||
|
self.finish(resources[path])
|
||||||
|
|
||||||
|
class PostResource(RequestHandler):
|
||||||
|
def post(self, path):
|
||||||
|
resources[path] = self.request.body
|
||||||
|
|
||||||
|
class HTTPMethodRouter(Router):
|
||||||
|
def __init__(self, app):
|
||||||
|
self.app = app
|
||||||
|
|
||||||
|
def find_handler(self, request, **kwargs):
|
||||||
|
handler = GetResource if request.method == "GET" else PostResource
|
||||||
|
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
|
||||||
|
|
||||||
|
router = HTTPMethodRouter(Application())
|
||||||
|
server = HTTPServer(router)
|
||||||
|
|
||||||
|
`ReversibleRouter` interface adds the ability to distinguish between
|
||||||
|
the routes and reverse them to the original urls using route's name
|
||||||
|
and additional arguments. `~.web.Application` is itself an
|
||||||
|
implementation of `ReversibleRouter` class.
|
||||||
|
|
||||||
|
`RuleRouter` and `ReversibleRuleRouter` are implementations of
|
||||||
|
`Router` and `ReversibleRouter` interfaces and can be used for
|
||||||
|
creating rule-based routing configurations.
|
||||||
|
|
||||||
|
Rules are instances of `Rule` class. They contain a `Matcher`, which
|
||||||
|
provides the logic for determining whether the rule is a match for a
|
||||||
|
particular request and a target, which can be one of the following.
|
||||||
|
|
||||||
|
1) An instance of `~.httputil.HTTPServerConnectionDelegate`:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
router = RuleRouter([
|
||||||
|
Rule(PathMatches("/handler"), ConnectionDelegate()),
|
||||||
|
# ... more rules
|
||||||
|
])
|
||||||
|
|
||||||
|
class ConnectionDelegate(HTTPServerConnectionDelegate):
|
||||||
|
def start_request(self, server_conn, request_conn):
|
||||||
|
return MessageDelegate(request_conn)
|
||||||
|
|
||||||
|
2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
router = RuleRouter([
|
||||||
|
Rule(PathMatches("/callable"), request_callable)
|
||||||
|
])
|
||||||
|
|
||||||
|
def request_callable(request):
|
||||||
|
request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
|
||||||
|
request.finish()
|
||||||
|
|
||||||
|
3) Another `Router` instance:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
router = RuleRouter([
|
||||||
|
Rule(PathMatches("/router.*"), CustomRouter())
|
||||||
|
])
|
||||||
|
|
||||||
|
Of course a nested `RuleRouter` or a `~.web.Application` is allowed:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
router = RuleRouter([
|
||||||
|
Rule(HostMatches("example.com"), RuleRouter([
|
||||||
|
Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
|
||||||
|
]))
|
||||||
|
])
|
||||||
|
|
||||||
|
server = HTTPServer(router)
|
||||||
|
|
||||||
|
In the example below `RuleRouter` is used to route between applications:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
app1 = Application([
|
||||||
|
(r"/app1/handler", Handler1),
|
||||||
|
# other handlers ...
|
||||||
|
])
|
||||||
|
|
||||||
|
app2 = Application([
|
||||||
|
(r"/app2/handler", Handler2),
|
||||||
|
# other handlers ...
|
||||||
|
])
|
||||||
|
|
||||||
|
router = RuleRouter([
|
||||||
|
Rule(PathMatches("/app1.*"), app1),
|
||||||
|
Rule(PathMatches("/app2.*"), app2)
|
||||||
|
])
|
||||||
|
|
||||||
|
server = HTTPServer(router)
|
||||||
|
|
||||||
|
For more information on application-level routing see docs for `~.web.Application`.
|
||||||
|
|
||||||
|
.. versionadded:: 4.5
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
|
import re
|
||||||
|
from functools import partial
|
||||||
|
|
||||||
|
from tornado import httputil
|
||||||
|
from tornado.httpserver import _CallableAdapter
|
||||||
|
from tornado.escape import url_escape, url_unescape, utf8
|
||||||
|
from tornado.log import app_log
|
||||||
|
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
|
||||||
|
|
||||||
|
try:
|
||||||
|
import typing # noqa
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class Router(httputil.HTTPServerConnectionDelegate):
|
||||||
|
"""Abstract router interface."""
|
||||||
|
|
||||||
|
def find_handler(self, request, **kwargs):
|
||||||
|
# type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
|
||||||
|
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
|
||||||
|
that can serve the request.
|
||||||
|
Routing implementations may pass additional kwargs to extend the routing logic.
|
||||||
|
|
||||||
|
:arg httputil.HTTPServerRequest request: current HTTP request.
|
||||||
|
:arg kwargs: additional keyword arguments passed by routing implementation.
|
||||||
|
:returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
|
||||||
|
process the request.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def start_request(self, server_conn, request_conn):
|
||||||
|
return _RoutingDelegate(self, server_conn, request_conn)
|
||||||
|
|
||||||
|
|
||||||
|
class ReversibleRouter(Router):
|
||||||
|
"""Abstract router interface for routers that can handle named routes
|
||||||
|
and support reversing them to original urls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def reverse_url(self, name, *args):
|
||||||
|
"""Returns url string for a given route name and arguments
|
||||||
|
or ``None`` if no match is found.
|
||||||
|
|
||||||
|
:arg str name: route name.
|
||||||
|
:arg args: url parameters.
|
||||||
|
:returns: parametrized url string for a given route name (or ``None``).
|
||||||
|
"""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
|
class _RoutingDelegate(httputil.HTTPMessageDelegate):
|
||||||
|
def __init__(self, router, server_conn, request_conn):
|
||||||
|
self.server_conn = server_conn
|
||||||
|
self.request_conn = request_conn
|
||||||
|
self.delegate = None
|
||||||
|
self.router = router # type: Router
|
||||||
|
|
||||||
|
def headers_received(self, start_line, headers):
|
||||||
|
request = httputil.HTTPServerRequest(
|
||||||
|
connection=self.request_conn,
|
||||||
|
server_connection=self.server_conn,
|
||||||
|
start_line=start_line, headers=headers)
|
||||||
|
|
||||||
|
self.delegate = self.router.find_handler(request)
|
||||||
|
return self.delegate.headers_received(start_line, headers)
|
||||||
|
|
||||||
|
def data_received(self, chunk):
|
||||||
|
return self.delegate.data_received(chunk)
|
||||||
|
|
||||||
|
def finish(self):
|
||||||
|
self.delegate.finish()
|
||||||
|
|
||||||
|
def on_connection_close(self):
|
||||||
|
self.delegate.on_connection_close()
|
||||||
|
|
||||||
|
|
||||||
|
class RuleRouter(Router):
|
||||||
|
"""Rule-based router implementation."""
|
||||||
|
|
||||||
|
def __init__(self, rules=None):
|
||||||
|
"""Constructs a router from an ordered list of rules::
|
||||||
|
|
||||||
|
RuleRouter([
|
||||||
|
Rule(PathMatches("/handler"), Target),
|
||||||
|
# ... more rules
|
||||||
|
])
|
||||||
|
|
||||||
|
You can also omit explicit `Rule` constructor and use tuples of arguments::
|
||||||
|
|
||||||
|
RuleRouter([
|
||||||
|
(PathMatches("/handler"), Target),
|
||||||
|
])
|
||||||
|
|
||||||
|
`PathMatches` is a default matcher, so the example above can be simplified::
|
||||||
|
|
||||||
|
RuleRouter([
|
||||||
|
("/handler", Target),
|
||||||
|
])
|
||||||
|
|
||||||
|
In the examples above, ``Target`` can be a nested `Router` instance, an instance of
|
||||||
|
`~.httputil.HTTPServerConnectionDelegate` or an old-style callable, accepting a request argument.
|
||||||
|
|
||||||
|
:arg rules: a list of `Rule` instances or tuples of `Rule`
|
||||||
|
constructor arguments.
|
||||||
|
"""
|
||||||
|
self.rules = [] # type: typing.List[Rule]
|
||||||
|
if rules:
|
||||||
|
self.add_rules(rules)
|
||||||
|
|
||||||
|
def add_rules(self, rules):
|
||||||
|
"""Appends new rules to the router.
|
||||||
|
|
||||||
|
:arg rules: a list of Rule instances (or tuples of arguments, which are
|
||||||
|
passed to Rule constructor).
|
||||||
|
"""
|
||||||
|
for rule in rules:
|
||||||
|
if isinstance(rule, (tuple, list)):
|
||||||
|
assert len(rule) in (2, 3, 4)
|
||||||
|
if isinstance(rule[0], basestring_type):
|
||||||
|
rule = Rule(PathMatches(rule[0]), *rule[1:])
|
||||||
|
else:
|
||||||
|
rule = Rule(*rule)
|
||||||
|
|
||||||
|
self.rules.append(self.process_rule(rule))
|
||||||
|
|
||||||
|
def process_rule(self, rule):
|
||||||
|
"""Override this method for additional preprocessing of each rule.
|
||||||
|
|
||||||
|
:arg Rule rule: a rule to be processed.
|
||||||
|
:returns: the same or modified Rule instance.
|
||||||
|
"""
|
||||||
|
return rule
|
||||||
|
|
||||||
|
def find_handler(self, request, **kwargs):
|
||||||
|
for rule in self.rules:
|
||||||
|
target_params = rule.matcher.match(request)
|
||||||
|
if target_params is not None:
|
||||||
|
if rule.target_kwargs:
|
||||||
|
target_params['target_kwargs'] = rule.target_kwargs
|
||||||
|
|
||||||
|
delegate = self.get_target_delegate(
|
||||||
|
rule.target, request, **target_params)
|
||||||
|
|
||||||
|
if delegate is not None:
|
||||||
|
return delegate
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_target_delegate(self, target, request, **target_params):
|
||||||
|
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
|
||||||
|
Rule's target. This method is called by `~.find_handler` and can be
|
||||||
|
extended to provide additional target types.
|
||||||
|
|
||||||
|
:arg target: a Rule's target.
|
||||||
|
:arg httputil.HTTPServerRequest request: current request.
|
||||||
|
:arg target_params: additional parameters that can be useful
|
||||||
|
for `~.httputil.HTTPMessageDelegate` creation.
|
||||||
|
"""
|
||||||
|
if isinstance(target, Router):
|
||||||
|
return target.find_handler(request, **target_params)
|
||||||
|
|
||||||
|
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
|
||||||
|
return target.start_request(request.server_connection, request.connection)
|
||||||
|
|
||||||
|
elif callable(target):
|
||||||
|
return _CallableAdapter(
|
||||||
|
partial(target, **target_params), request.connection
|
||||||
|
)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
|
||||||
|
"""A rule-based router that implements ``reverse_url`` method.
|
||||||
|
|
||||||
|
Each rule added to this router may have a ``name`` attribute that can be
|
||||||
|
used to reconstruct an original uri. The actual reconstruction takes place
|
||||||
|
in a rule's matcher (see `Matcher.reverse`).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, rules=None):
|
||||||
|
self.named_rules = {} # type: typing.Dict[str]
|
||||||
|
super(ReversibleRuleRouter, self).__init__(rules)
|
||||||
|
|
||||||
|
def process_rule(self, rule):
|
||||||
|
rule = super(ReversibleRuleRouter, self).process_rule(rule)
|
||||||
|
|
||||||
|
if rule.name:
|
||||||
|
if rule.name in self.named_rules:
|
||||||
|
app_log.warning(
|
||||||
|
"Multiple handlers named %s; replacing previous value",
|
||||||
|
rule.name)
|
||||||
|
self.named_rules[rule.name] = rule
|
||||||
|
|
||||||
|
return rule
|
||||||
|
|
||||||
|
def reverse_url(self, name, *args):
|
||||||
|
if name in self.named_rules:
|
||||||
|
return self.named_rules[name].matcher.reverse(*args)
|
||||||
|
|
||||||
|
for rule in self.rules:
|
||||||
|
if isinstance(rule.target, ReversibleRouter):
|
||||||
|
reversed_url = rule.target.reverse_url(name, *args)
|
||||||
|
if reversed_url is not None:
|
||||||
|
return reversed_url
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class Rule(object):
|
||||||
|
"""A routing rule."""
|
||||||
|
|
||||||
|
def __init__(self, matcher, target, target_kwargs=None, name=None):
|
||||||
|
"""Constructs a Rule instance.
|
||||||
|
|
||||||
|
:arg Matcher matcher: a `Matcher` instance used for determining
|
||||||
|
whether the rule should be considered a match for a specific
|
||||||
|
request.
|
||||||
|
:arg target: a Rule's target (typically a ``RequestHandler`` or
|
||||||
|
`~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
|
||||||
|
depending on routing implementation).
|
||||||
|
:arg dict target_kwargs: a dict of parameters that can be useful
|
||||||
|
at the moment of target instantiation (for example, ``status_code``
|
||||||
|
for a ``RequestHandler`` subclass). They end up in
|
||||||
|
``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
|
||||||
|
method.
|
||||||
|
:arg str name: the name of the rule that can be used to find it
|
||||||
|
in `ReversibleRouter.reverse_url` implementation.
|
||||||
|
"""
|
||||||
|
if isinstance(target, str):
|
||||||
|
# import the Module and instantiate the class
|
||||||
|
# Must be a fully qualified name (module.ClassName)
|
||||||
|
target = import_object(target)
|
||||||
|
|
||||||
|
self.matcher = matcher # type: Matcher
|
||||||
|
self.target = target
|
||||||
|
self.target_kwargs = target_kwargs if target_kwargs else {}
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def reverse(self, *args):
|
||||||
|
return self.matcher.reverse(*args)
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '%s(%r, %s, kwargs=%r, name=%r)' % \
|
||||||
|
(self.__class__.__name__, self.matcher,
|
||||||
|
self.target, self.target_kwargs, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
class Matcher(object):
|
||||||
|
"""Represents a matcher for request features."""
|
||||||
|
|
||||||
|
def match(self, request):
|
||||||
|
"""Matches current instance against the request.
|
||||||
|
|
||||||
|
:arg httputil.HTTPServerRequest request: current HTTP request
|
||||||
|
:returns: a dict of parameters to be passed to the target handler
|
||||||
|
(for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
|
||||||
|
can be passed for proper `~.web.RequestHandler` instantiation).
|
||||||
|
An empty dict is a valid (and common) return value to indicate a match
|
||||||
|
when the argument-passing features are not used.
|
||||||
|
``None`` must be returned to indicate that there is no match."""
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def reverse(self, *args):
|
||||||
|
"""Reconstructs full url from matcher instance and additional arguments."""
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class AnyMatches(Matcher):
|
||||||
|
"""Matches any request."""
|
||||||
|
|
||||||
|
def match(self, request):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
class HostMatches(Matcher):
|
||||||
|
"""Matches requests from hosts specified by ``host_pattern`` regex."""
|
||||||
|
|
||||||
|
def __init__(self, host_pattern):
|
||||||
|
if isinstance(host_pattern, basestring_type):
|
||||||
|
if not host_pattern.endswith("$"):
|
||||||
|
host_pattern += "$"
|
||||||
|
self.host_pattern = re.compile(host_pattern)
|
||||||
|
else:
|
||||||
|
self.host_pattern = host_pattern
|
||||||
|
|
||||||
|
def match(self, request):
|
||||||
|
if self.host_pattern.match(request.host_name):
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class DefaultHostMatches(Matcher):
|
||||||
|
"""Matches requests from host that is equal to application's default_host.
|
||||||
|
Always returns no match if ``X-Real-Ip`` header is present.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, application, host_pattern):
|
||||||
|
self.application = application
|
||||||
|
self.host_pattern = host_pattern
|
||||||
|
|
||||||
|
def match(self, request):
|
||||||
|
# Look for default host if not behind load balancer (for debugging)
|
||||||
|
if "X-Real-Ip" not in request.headers:
|
||||||
|
if self.host_pattern.match(self.application.default_host):
|
||||||
|
return {}
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class PathMatches(Matcher):
|
||||||
|
"""Matches requests with paths specified by ``path_pattern`` regex."""
|
||||||
|
|
||||||
|
def __init__(self, path_pattern):
|
||||||
|
if isinstance(path_pattern, basestring_type):
|
||||||
|
if not path_pattern.endswith('$'):
|
||||||
|
path_pattern += '$'
|
||||||
|
self.regex = re.compile(path_pattern)
|
||||||
|
else:
|
||||||
|
self.regex = path_pattern
|
||||||
|
|
||||||
|
assert len(self.regex.groupindex) in (0, self.regex.groups), \
|
||||||
|
("groups in url regexes must either be all named or all "
|
||||||
|
"positional: %r" % self.regex.pattern)
|
||||||
|
|
||||||
|
self._path, self._group_count = self._find_groups()
|
||||||
|
|
||||||
|
def match(self, request):
|
||||||
|
match = self.regex.match(request.path)
|
||||||
|
if match is None:
|
||||||
|
return None
|
||||||
|
if not self.regex.groups:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
path_args, path_kwargs = [], {}
|
||||||
|
|
||||||
|
# Pass matched groups to the handler. Since
|
||||||
|
# match.groups() includes both named and
|
||||||
|
# unnamed groups, we want to use either groups
|
||||||
|
# or groupdict but not both.
|
||||||
|
if self.regex.groupindex:
|
||||||
|
path_kwargs = dict(
|
||||||
|
(str(k), _unquote_or_none(v))
|
||||||
|
for (k, v) in match.groupdict().items())
|
||||||
|
else:
|
||||||
|
path_args = [_unquote_or_none(s) for s in match.groups()]
|
||||||
|
|
||||||
|
return dict(path_args=path_args, path_kwargs=path_kwargs)
|
||||||
|
|
||||||
|
def reverse(self, *args):
|
||||||
|
if self._path is None:
|
||||||
|
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
|
||||||
|
assert len(args) == self._group_count, "required number of arguments " \
|
||||||
|
"not found"
|
||||||
|
if not len(args):
|
||||||
|
return self._path
|
||||||
|
converted_args = []
|
||||||
|
for a in args:
|
||||||
|
if not isinstance(a, (unicode_type, bytes)):
|
||||||
|
a = str(a)
|
||||||
|
converted_args.append(url_escape(utf8(a), plus=False))
|
||||||
|
return self._path % tuple(converted_args)
|
||||||
|
|
||||||
|
def _find_groups(self):
|
||||||
|
"""Returns a tuple (reverse string, group count) for a url.
|
||||||
|
|
||||||
|
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
|
||||||
|
would return ('/%s/%s/', 2).
|
||||||
|
"""
|
||||||
|
pattern = self.regex.pattern
|
||||||
|
if pattern.startswith('^'):
|
||||||
|
pattern = pattern[1:]
|
||||||
|
if pattern.endswith('$'):
|
||||||
|
pattern = pattern[:-1]
|
||||||
|
|
||||||
|
if self.regex.groups != pattern.count('('):
|
||||||
|
# The pattern is too complicated for our simplistic matching,
|
||||||
|
# so we can't support reversing it.
|
||||||
|
return None, None
|
||||||
|
|
||||||
|
pieces = []
|
||||||
|
for fragment in pattern.split('('):
|
||||||
|
if ')' in fragment:
|
||||||
|
paren_loc = fragment.index(')')
|
||||||
|
if paren_loc >= 0:
|
||||||
|
pieces.append('%s' + fragment[paren_loc + 1:])
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
unescaped_fragment = re_unescape(fragment)
|
||||||
|
except ValueError as exc:
|
||||||
|
# If we can't unescape part of it, we can't
|
||||||
|
# reverse this url.
|
||||||
|
return (None, None)
|
||||||
|
pieces.append(unescaped_fragment)
|
||||||
|
|
||||||
|
return ''.join(pieces), self.regex.groups
|
||||||
|
|
||||||
|
|
||||||
|
class URLSpec(Rule):
|
||||||
|
"""Specifies mappings between URLs and handlers.
|
||||||
|
|
||||||
|
.. versionchanged: 4.5
|
||||||
|
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
|
||||||
|
backwards compatibility.
|
||||||
|
"""
|
||||||
|
def __init__(self, pattern, handler, kwargs=None, name=None):
|
||||||
|
"""Parameters:
|
||||||
|
|
||||||
|
* ``pattern``: Regular expression to be matched. Any capturing
|
||||||
|
groups in the regex will be passed in to the handler's
|
||||||
|
get/post/etc methods as arguments (by keyword if named, by
|
||||||
|
position if unnamed. Named and unnamed capturing groups may
|
||||||
|
may not be mixed in the same rule).
|
||||||
|
|
||||||
|
* ``handler``: `~.web.RequestHandler` subclass to be invoked.
|
||||||
|
|
||||||
|
* ``kwargs`` (optional): A dictionary of additional arguments
|
||||||
|
to be passed to the handler's constructor.
|
||||||
|
|
||||||
|
* ``name`` (optional): A name for this handler. Used by
|
||||||
|
`~.web.Application.reverse_url`.
|
||||||
|
|
||||||
|
"""
|
||||||
|
super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)
|
||||||
|
|
||||||
|
self.regex = self.matcher.regex
|
||||||
|
self.handler_class = self.target
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return '%s(%r, %s, kwargs=%r, name=%r)' % \
|
||||||
|
(self.__class__.__name__, self.regex.pattern,
|
||||||
|
self.handler_class, self.kwargs, self.name)
|
||||||
|
|
||||||
|
|
||||||
|
def _unquote_or_none(s):
|
||||||
|
"""None-safe wrapper around url_unescape to handle unmatched optional
|
||||||
|
groups correctly.
|
||||||
|
|
||||||
|
Note that args are passed as bytes so the handler can decide what
|
||||||
|
encoding to use.
|
||||||
|
"""
|
||||||
|
if s is None:
|
||||||
|
return s
|
||||||
|
return url_unescape(s, encoding=None, plus=False)
|
|
@ -1,5 +1,5 @@
|
||||||
#!/usr/bin/env python
|
#!/usr/bin/env python
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
from tornado.escape import utf8, _unicode
|
from tornado.escape import utf8, _unicode
|
||||||
from tornado import gen
|
from tornado import gen
|
||||||
|
@ -11,6 +11,7 @@ from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
|
||||||
from tornado.log import gen_log
|
from tornado.log import gen_log
|
||||||
from tornado import stack_context
|
from tornado import stack_context
|
||||||
from tornado.tcpclient import TCPClient
|
from tornado.tcpclient import TCPClient
|
||||||
|
from tornado.util import PY3
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import collections
|
import collections
|
||||||
|
@ -22,10 +23,10 @@ import sys
|
||||||
from io import BytesIO
|
from io import BytesIO
|
||||||
|
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
import urlparse # py2
|
import urllib.parse as urlparse
|
||||||
except ImportError:
|
else:
|
||||||
import urllib.parse as urlparse # py3
|
import urlparse
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import ssl
|
import ssl
|
||||||
|
@ -126,7 +127,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
||||||
timeout_handle = self.io_loop.add_timeout(
|
timeout_handle = self.io_loop.add_timeout(
|
||||||
self.io_loop.time() + min(request.connect_timeout,
|
self.io_loop.time() + min(request.connect_timeout,
|
||||||
request.request_timeout),
|
request.request_timeout),
|
||||||
functools.partial(self._on_timeout, key))
|
functools.partial(self._on_timeout, key, "in request queue"))
|
||||||
else:
|
else:
|
||||||
timeout_handle = None
|
timeout_handle = None
|
||||||
self.waiting[key] = (request, callback, timeout_handle)
|
self.waiting[key] = (request, callback, timeout_handle)
|
||||||
|
@ -167,11 +168,20 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
|
||||||
self.io_loop.remove_timeout(timeout_handle)
|
self.io_loop.remove_timeout(timeout_handle)
|
||||||
del self.waiting[key]
|
del self.waiting[key]
|
||||||
|
|
||||||
def _on_timeout(self, key):
|
def _on_timeout(self, key, info=None):
|
||||||
|
"""Timeout callback of request.
|
||||||
|
|
||||||
|
Construct a timeout HTTPResponse when a timeout occurs.
|
||||||
|
|
||||||
|
:arg object key: A simple object to mark the request.
|
||||||
|
:info string key: More detailed timeout information.
|
||||||
|
"""
|
||||||
request, callback, timeout_handle = self.waiting[key]
|
request, callback, timeout_handle = self.waiting[key]
|
||||||
self.queue.remove((key, request, callback))
|
self.queue.remove((key, request, callback))
|
||||||
|
|
||||||
|
error_message = "Timeout {0}".format(info) if info else "Timeout"
|
||||||
timeout_response = HTTPResponse(
|
timeout_response = HTTPResponse(
|
||||||
request, 599, error=HTTPError(599, "Timeout"),
|
request, 599, error=HTTPError(599, error_message),
|
||||||
request_time=self.io_loop.time() - request.start_time)
|
request_time=self.io_loop.time() - request.start_time)
|
||||||
self.io_loop.add_callback(callback, timeout_response)
|
self.io_loop.add_callback(callback, timeout_response)
|
||||||
del self.waiting[key]
|
del self.waiting[key]
|
||||||
|
@ -229,7 +239,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
||||||
if timeout:
|
if timeout:
|
||||||
self._timeout = self.io_loop.add_timeout(
|
self._timeout = self.io_loop.add_timeout(
|
||||||
self.start_time + timeout,
|
self.start_time + timeout,
|
||||||
stack_context.wrap(self._on_timeout))
|
stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
|
||||||
self.tcp_client.connect(host, port, af=af,
|
self.tcp_client.connect(host, port, af=af,
|
||||||
ssl_options=ssl_options,
|
ssl_options=ssl_options,
|
||||||
max_buffer_size=self.max_buffer_size,
|
max_buffer_size=self.max_buffer_size,
|
||||||
|
@ -284,10 +294,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
||||||
return ssl_options
|
return ssl_options
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _on_timeout(self):
|
def _on_timeout(self, info=None):
|
||||||
|
"""Timeout callback of _HTTPConnection instance.
|
||||||
|
|
||||||
|
Raise a timeout HTTPError when a timeout occurs.
|
||||||
|
|
||||||
|
:info string key: More detailed timeout information.
|
||||||
|
"""
|
||||||
self._timeout = None
|
self._timeout = None
|
||||||
|
error_message = "Timeout {0}".format(info) if info else "Timeout"
|
||||||
if self.final_callback is not None:
|
if self.final_callback is not None:
|
||||||
raise HTTPError(599, "Timeout")
|
raise HTTPError(599, error_message)
|
||||||
|
|
||||||
def _remove_timeout(self):
|
def _remove_timeout(self):
|
||||||
if self._timeout is not None:
|
if self._timeout is not None:
|
||||||
|
@ -307,13 +324,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
||||||
if self.request.request_timeout:
|
if self.request.request_timeout:
|
||||||
self._timeout = self.io_loop.add_timeout(
|
self._timeout = self.io_loop.add_timeout(
|
||||||
self.start_time + self.request.request_timeout,
|
self.start_time + self.request.request_timeout,
|
||||||
stack_context.wrap(self._on_timeout))
|
stack_context.wrap(functools.partial(self._on_timeout, "during request")))
|
||||||
if (self.request.method not in self._SUPPORTED_METHODS and
|
if (self.request.method not in self._SUPPORTED_METHODS and
|
||||||
not self.request.allow_nonstandard_methods):
|
not self.request.allow_nonstandard_methods):
|
||||||
raise KeyError("unknown method %s" % self.request.method)
|
raise KeyError("unknown method %s" % self.request.method)
|
||||||
for key in ('network_interface',
|
for key in ('network_interface',
|
||||||
'proxy_host', 'proxy_port',
|
'proxy_host', 'proxy_port',
|
||||||
'proxy_username', 'proxy_password'):
|
'proxy_username', 'proxy_password',
|
||||||
|
'proxy_auth_mode'):
|
||||||
if getattr(self.request, key, None):
|
if getattr(self.request, key, None):
|
||||||
raise NotImplementedError('%s not supported' % key)
|
raise NotImplementedError('%s not supported' % key)
|
||||||
if "Connection" not in self.request.headers:
|
if "Connection" not in self.request.headers:
|
||||||
|
@ -481,7 +499,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
|
||||||
def _should_follow_redirect(self):
|
def _should_follow_redirect(self):
|
||||||
return (self.request.follow_redirects and
|
return (self.request.follow_redirects and
|
||||||
self.request.max_redirects > 0 and
|
self.request.max_redirects > 0 and
|
||||||
self.code in (301, 302, 303, 307))
|
self.code in (301, 302, 303, 307, 308))
|
||||||
|
|
||||||
def finish(self):
|
def finish(self):
|
||||||
data = b''.join(self.chunks)
|
data = b''.join(self.chunks)
|
||||||
|
|
|
@ -67,7 +67,7 @@ Here are a few rules of thumb for when it's necessary:
|
||||||
block that references your `StackContext`.
|
block that references your `StackContext`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import sys
|
import sys
|
||||||
import threading
|
import threading
|
||||||
|
@ -82,6 +82,8 @@ class StackContextInconsistentError(Exception):
|
||||||
class _State(threading.local):
|
class _State(threading.local):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.contexts = (tuple(), None)
|
self.contexts = (tuple(), None)
|
||||||
|
|
||||||
|
|
||||||
_state = _State()
|
_state = _State()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -16,7 +16,7 @@
|
||||||
|
|
||||||
"""A non-blocking TCP connection factory.
|
"""A non-blocking TCP connection factory.
|
||||||
"""
|
"""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
import socket
|
import socket
|
||||||
|
@ -155,16 +155,30 @@ class TCPClient(object):
|
||||||
|
|
||||||
@gen.coroutine
|
@gen.coroutine
|
||||||
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
|
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
|
||||||
max_buffer_size=None):
|
max_buffer_size=None, source_ip=None, source_port=None):
|
||||||
"""Connect to the given host and port.
|
"""Connect to the given host and port.
|
||||||
|
|
||||||
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
|
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
|
||||||
``ssl_options`` is not None).
|
``ssl_options`` is not None).
|
||||||
|
|
||||||
|
Using the ``source_ip`` kwarg, one can specify the source
|
||||||
|
IP address to use when establishing the connection.
|
||||||
|
In case the user needs to resolve and
|
||||||
|
use a specific interface, it has to be handled outside
|
||||||
|
of Tornado as this depends very much on the platform.
|
||||||
|
|
||||||
|
Similarly, when the user requires a certain source port, it can
|
||||||
|
be specified using the ``source_port`` arg.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.5
|
||||||
|
Added the ``source_ip`` and ``source_port`` arguments.
|
||||||
"""
|
"""
|
||||||
addrinfo = yield self.resolver.resolve(host, port, af)
|
addrinfo = yield self.resolver.resolve(host, port, af)
|
||||||
connector = _Connector(
|
connector = _Connector(
|
||||||
addrinfo, self.io_loop,
|
addrinfo, self.io_loop,
|
||||||
functools.partial(self._create_stream, max_buffer_size))
|
functools.partial(self._create_stream, max_buffer_size,
|
||||||
|
source_ip=source_ip, source_port=source_port)
|
||||||
|
)
|
||||||
af, addr, stream = yield connector.start()
|
af, addr, stream = yield connector.start()
|
||||||
# TODO: For better performance we could cache the (af, addr)
|
# TODO: For better performance we could cache the (af, addr)
|
||||||
# information here and re-use it on subsequent connections to
|
# information here and re-use it on subsequent connections to
|
||||||
|
@ -174,10 +188,35 @@ class TCPClient(object):
|
||||||
server_hostname=host)
|
server_hostname=host)
|
||||||
raise gen.Return(stream)
|
raise gen.Return(stream)
|
||||||
|
|
||||||
def _create_stream(self, max_buffer_size, af, addr):
|
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
|
||||||
|
source_port=None):
|
||||||
# Always connect in plaintext; we'll convert to ssl if necessary
|
# Always connect in plaintext; we'll convert to ssl if necessary
|
||||||
# after one connection has completed.
|
# after one connection has completed.
|
||||||
stream = IOStream(socket.socket(af),
|
source_port_bind = source_port if isinstance(source_port, int) else 0
|
||||||
io_loop=self.io_loop,
|
source_ip_bind = source_ip
|
||||||
max_buffer_size=max_buffer_size)
|
if source_port_bind and not source_ip:
|
||||||
return stream.connect(addr)
|
# User required a specific port, but did not specify
|
||||||
|
# a certain source IP, will bind to the default loopback.
|
||||||
|
source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
|
||||||
|
# Trying to use the same address family as the requested af socket:
|
||||||
|
# - 127.0.0.1 for IPv4
|
||||||
|
# - ::1 for IPv6
|
||||||
|
socket_obj = socket.socket(af)
|
||||||
|
if source_port_bind or source_ip_bind:
|
||||||
|
# If the user requires binding also to a specific IP/port.
|
||||||
|
try:
|
||||||
|
socket_obj.bind((source_ip_bind, source_port_bind))
|
||||||
|
except socket.error:
|
||||||
|
socket_obj.close()
|
||||||
|
# Fail loudly if unable to use the IP/port.
|
||||||
|
raise
|
||||||
|
try:
|
||||||
|
stream = IOStream(socket_obj,
|
||||||
|
io_loop=self.io_loop,
|
||||||
|
max_buffer_size=max_buffer_size)
|
||||||
|
except socket.error as e:
|
||||||
|
fu = Future()
|
||||||
|
fu.set_exception(e)
|
||||||
|
return fu
|
||||||
|
else:
|
||||||
|
return stream.connect(addr)
|
||||||
|
|
|
@ -15,12 +15,13 @@
|
||||||
# under the License.
|
# under the License.
|
||||||
|
|
||||||
"""A non-blocking, single-threaded TCP server."""
|
"""A non-blocking, single-threaded TCP server."""
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import errno
|
import errno
|
||||||
import os
|
import os
|
||||||
import socket
|
import socket
|
||||||
|
|
||||||
|
from tornado import gen
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.ioloop import IOLoop
|
from tornado.ioloop import IOLoop
|
||||||
from tornado.iostream import IOStream, SSLIOStream
|
from tornado.iostream import IOStream, SSLIOStream
|
||||||
|
@ -39,7 +40,21 @@ class TCPServer(object):
|
||||||
r"""A non-blocking, single-threaded TCP server.
|
r"""A non-blocking, single-threaded TCP server.
|
||||||
|
|
||||||
To use `TCPServer`, define a subclass which overrides the `handle_stream`
|
To use `TCPServer`, define a subclass which overrides the `handle_stream`
|
||||||
method.
|
method. For example, a simple echo server could be defined like this::
|
||||||
|
|
||||||
|
from tornado.tcpserver import TCPServer
|
||||||
|
from tornado.iostream import StreamClosedError
|
||||||
|
from tornado import gen
|
||||||
|
|
||||||
|
class EchoServer(TCPServer):
|
||||||
|
@gen.coroutine
|
||||||
|
def handle_stream(self, stream, address):
|
||||||
|
while True:
|
||||||
|
try:
|
||||||
|
data = yield stream.read_until(b"\n")
|
||||||
|
yield stream.write(data)
|
||||||
|
except StreamClosedError:
|
||||||
|
break
|
||||||
|
|
||||||
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
To make this server serve SSL traffic, send the ``ssl_options`` keyword
|
||||||
argument with an `ssl.SSLContext` object. For compatibility with older
|
argument with an `ssl.SSLContext` object. For compatibility with older
|
||||||
|
@ -95,6 +110,7 @@ class TCPServer(object):
|
||||||
self._sockets = {} # fd -> socket object
|
self._sockets = {} # fd -> socket object
|
||||||
self._pending_sockets = []
|
self._pending_sockets = []
|
||||||
self._started = False
|
self._started = False
|
||||||
|
self._stopped = False
|
||||||
self.max_buffer_size = max_buffer_size
|
self.max_buffer_size = max_buffer_size
|
||||||
self.read_chunk_size = read_chunk_size
|
self.read_chunk_size = read_chunk_size
|
||||||
|
|
||||||
|
@ -147,7 +163,8 @@ class TCPServer(object):
|
||||||
"""Singular version of `add_sockets`. Takes a single socket object."""
|
"""Singular version of `add_sockets`. Takes a single socket object."""
|
||||||
self.add_sockets([socket])
|
self.add_sockets([socket])
|
||||||
|
|
||||||
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
|
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
|
||||||
|
reuse_port=False):
|
||||||
"""Binds this server to the given port on the given address.
|
"""Binds this server to the given port on the given address.
|
||||||
|
|
||||||
To start the server, call `start`. If you want to run this server
|
To start the server, call `start`. If you want to run this server
|
||||||
|
@ -162,13 +179,17 @@ class TCPServer(object):
|
||||||
both will be used if available.
|
both will be used if available.
|
||||||
|
|
||||||
The ``backlog`` argument has the same meaning as for
|
The ``backlog`` argument has the same meaning as for
|
||||||
`socket.listen <socket.socket.listen>`.
|
`socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
|
||||||
|
has the same meaning as for `.bind_sockets`.
|
||||||
|
|
||||||
This method may be called multiple times prior to `start` to listen
|
This method may be called multiple times prior to `start` to listen
|
||||||
on multiple ports or interfaces.
|
on multiple ports or interfaces.
|
||||||
|
|
||||||
|
.. versionchanged:: 4.4
|
||||||
|
Added the ``reuse_port`` argument.
|
||||||
"""
|
"""
|
||||||
sockets = bind_sockets(port, address=address, family=family,
|
sockets = bind_sockets(port, address=address, family=family,
|
||||||
backlog=backlog)
|
backlog=backlog, reuse_port=reuse_port)
|
||||||
if self._started:
|
if self._started:
|
||||||
self.add_sockets(sockets)
|
self.add_sockets(sockets)
|
||||||
else:
|
else:
|
||||||
|
@ -208,7 +229,11 @@ class TCPServer(object):
|
||||||
Requests currently in progress may still continue after the
|
Requests currently in progress may still continue after the
|
||||||
server is stopped.
|
server is stopped.
|
||||||
"""
|
"""
|
||||||
|
if self._stopped:
|
||||||
|
return
|
||||||
|
self._stopped = True
|
||||||
for fd, sock in self._sockets.items():
|
for fd, sock in self._sockets.items():
|
||||||
|
assert sock.fileno() == fd
|
||||||
self.io_loop.remove_handler(fd)
|
self.io_loop.remove_handler(fd)
|
||||||
sock.close()
|
sock.close()
|
||||||
|
|
||||||
|
@ -266,8 +291,10 @@ class TCPServer(object):
|
||||||
stream = IOStream(connection, io_loop=self.io_loop,
|
stream = IOStream(connection, io_loop=self.io_loop,
|
||||||
max_buffer_size=self.max_buffer_size,
|
max_buffer_size=self.max_buffer_size,
|
||||||
read_chunk_size=self.read_chunk_size)
|
read_chunk_size=self.read_chunk_size)
|
||||||
|
|
||||||
future = self.handle_stream(stream, address)
|
future = self.handle_stream(stream, address)
|
||||||
if future is not None:
|
if future is not None:
|
||||||
self.io_loop.add_future(future, lambda f: f.result())
|
self.io_loop.add_future(gen.convert_yielded(future),
|
||||||
|
lambda f: f.result())
|
||||||
except Exception:
|
except Exception:
|
||||||
app_log.error("Error in connection callback", exc_info=True)
|
app_log.error("Error in connection callback", exc_info=True)
|
||||||
|
|
|
@ -19,13 +19,13 @@
|
||||||
Basic usage looks like::
|
Basic usage looks like::
|
||||||
|
|
||||||
t = template.Template("<html>{{ myvalue }}</html>")
|
t = template.Template("<html>{{ myvalue }}</html>")
|
||||||
print t.generate(myvalue="XXX")
|
print(t.generate(myvalue="XXX"))
|
||||||
|
|
||||||
`Loader` is a class that loads templates from a root directory and caches
|
`Loader` is a class that loads templates from a root directory and caches
|
||||||
the compiled templates::
|
the compiled templates::
|
||||||
|
|
||||||
loader = template.Loader("/home/btaylor")
|
loader = template.Loader("/home/btaylor")
|
||||||
print loader.load("test.html").generate(myvalue="XXX")
|
print(loader.load("test.html").generate(myvalue="XXX"))
|
||||||
|
|
||||||
We compile all templates to raw Python. Error-reporting is currently... uh,
|
We compile all templates to raw Python. Error-reporting is currently... uh,
|
||||||
interesting. Syntax for the templates::
|
interesting. Syntax for the templates::
|
||||||
|
@ -94,12 +94,15 @@ Syntax Reference
|
||||||
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
|
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
|
||||||
The contents may be any python expression, which will be escaped according
|
The contents may be any python expression, which will be escaped according
|
||||||
to the current autoescape setting and inserted into the output. Other
|
to the current autoescape setting and inserted into the output. Other
|
||||||
template directives use ``{% %}``. These tags may be escaped as ``{{!``
|
template directives use ``{% %}``.
|
||||||
and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
|
|
||||||
|
|
||||||
To comment out a section so that it is omitted from the output, surround it
|
To comment out a section so that it is omitted from the output, surround it
|
||||||
with ``{# ... #}``.
|
with ``{# ... #}``.
|
||||||
|
|
||||||
|
These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
|
||||||
|
if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.
|
||||||
|
|
||||||
|
|
||||||
``{% apply *function* %}...{% end %}``
|
``{% apply *function* %}...{% end %}``
|
||||||
Applies a function to the output of all template code between ``apply``
|
Applies a function to the output of all template code between ``apply``
|
||||||
and ``end``::
|
and ``end``::
|
||||||
|
@ -193,7 +196,7 @@ with ``{# ... #}``.
|
||||||
`filter_whitespace` for available options. New in Tornado 4.3.
|
`filter_whitespace` for available options. New in Tornado 4.3.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import absolute_import, division, print_function, with_statement
|
from __future__ import absolute_import, division, print_function
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import linecache
|
import linecache
|
||||||
|
@ -204,12 +207,12 @@ import threading
|
||||||
|
|
||||||
from tornado import escape
|
from tornado import escape
|
||||||
from tornado.log import app_log
|
from tornado.log import app_log
|
||||||
from tornado.util import ObjectDict, exec_in, unicode_type
|
from tornado.util import ObjectDict, exec_in, unicode_type, PY3
|
||||||
|
|
||||||
try:
|
if PY3:
|
||||||
from cStringIO import StringIO # py2
|
from io import StringIO
|
||||||
except ImportError:
|
else:
|
||||||
from io import StringIO # py3
|
from cStringIO import StringIO
|
||||||
|
|
||||||
_DEFAULT_AUTOESCAPE = "xhtml_escape"
|
_DEFAULT_AUTOESCAPE = "xhtml_escape"
|
||||||
_UNSET = object()
|
_UNSET = object()
|
||||||
|
@ -665,7 +668,7 @@ class ParseError(Exception):
|
||||||
.. versionchanged:: 4.3
|
.. versionchanged:: 4.3
|
||||||
Added ``filename`` and ``lineno`` attributes.
|
Added ``filename`` and ``lineno`` attributes.
|
||||||
"""
|
"""
|
||||||
def __init__(self, message, filename, lineno):
|
def __init__(self, message, filename=None, lineno=0):
|
||||||
self.message = message
|
self.message = message
|
||||||
# The names "filename" and "lineno" are chosen for consistency
|
# The names "filename" and "lineno" are chosen for consistency
|
||||||
# with python SyntaxError.
|
# with python SyntaxError.
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue