web端做了很大改动,尚未完成。

pull/105/head
Apex Liu 2017-10-31 14:52:03 +08:00
parent 2d0ce5da20
commit 51b143c828
494 changed files with 87356 additions and 64150 deletions

View File

@ -1,8 +1,8 @@
# mako/__init__.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
__version__ = '1.0.3'
__version__ = '1.0.6'

View File

@ -1,5 +1,5 @@
# mako/_ast_util.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/ast.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/cache.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/cmd.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/codegen.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -5,6 +5,7 @@ py3k = sys.version_info >= (3, 0)
py33 = sys.version_info >= (3, 3)
py2k = sys.version_info < (3,)
py26 = sys.version_info >= (2, 6)
py27 = sys.version_info >= (2, 7)
jython = sys.platform.startswith('java')
win32 = sys.platform.startswith('win')
pypy = hasattr(sys, 'pypy_version_info')

View File

@ -1,5 +1,5 @@
# mako/exceptions.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# ext/autohandler.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# ext/babelplugin.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# ext/preprocessors.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# ext/pygmentplugin.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# ext/turbogears.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/filters.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/lexer.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -95,31 +95,37 @@ class Lexer(object):
# (match and "TRUE" or "FALSE")
return match
def parse_until_text(self, *text):
def parse_until_text(self, watch_nesting, *text):
startpos = self.match_position
text_re = r'|'.join(text)
brace_level = 0
paren_level = 0
bracket_level = 0
while True:
match = self.match(r'#.*\n')
if match:
continue
match = self.match(r'(\"\"\"|\'\'\'|\"|\')((?<!\\)\\\1|.)*?\1',
match = self.match(r'(\"\"\"|\'\'\'|\"|\')[^\\]*?(\\.[^\\]*?)*\1',
re.S)
if match:
continue
match = self.match(r'(%s)' % text_re)
if match:
if match.group(1) == '}' and brace_level > 0:
brace_level -= 1
continue
if match and not (watch_nesting
and (brace_level > 0 or paren_level > 0
or bracket_level > 0)):
return \
self.text[startpos:
self.match_position - len(match.group(1))],\
match.group(1)
match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
elif not match:
match = self.match(r"(.*?)(?=\"|\'|#|%s)" % text_re, re.S)
if match:
brace_level += match.group(1).count('{')
brace_level -= match.group(1).count('}')
paren_level += match.group(1).count('(')
paren_level -= match.group(1).count(')')
bracket_level += match.group(1).count('[')
bracket_level -= match.group(1).count(']')
continue
raise exceptions.SyntaxException(
"Expected: %s" %
@ -368,7 +374,7 @@ class Lexer(object):
match = self.match(r"<%(!)?")
if match:
line, pos = self.matched_lineno, self.matched_charpos
text, end = self.parse_until_text(r'%>')
text, end = self.parse_until_text(False, r'%>')
# the trailing newline helps
# compiler.parse() not complain about indentation
text = adjust_whitespace(text) + "\n"
@ -384,9 +390,9 @@ class Lexer(object):
match = self.match(r"\${")
if match:
line, pos = self.matched_lineno, self.matched_charpos
text, end = self.parse_until_text(r'\|', r'}')
text, end = self.parse_until_text(True, r'\|', r'}')
if end == '|':
escapes, end = self.parse_until_text(r'}')
escapes, end = self.parse_until_text(True, r'}')
else:
escapes = ""
text = text.replace('\r\n', '\n')

View File

@ -1,5 +1,5 @@
# mako/lookup.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -96,7 +96,7 @@ class TemplateLookup(TemplateCollection):
.. sourcecode:: python
lookup = TemplateLookup(["/path/to/templates"])
some_template = lookup.get_template("/admin_index.mako")
some_template = lookup.get_template("/index.html")
The :class:`.TemplateLookup` can also be given :class:`.Template` objects
programatically using :meth:`.put_string` or :meth:`.put_template`:
@ -180,7 +180,8 @@ class TemplateLookup(TemplateCollection):
enable_loop=True,
input_encoding=None,
preprocessor=None,
lexer_cls=None):
lexer_cls=None,
include_error_handler=None):
self.directories = [posixpath.normpath(d) for d in
util.to_list(directories, ())
@ -203,6 +204,7 @@ class TemplateLookup(TemplateCollection):
self.template_args = {
'format_exceptions': format_exceptions,
'error_handler': error_handler,
'include_error_handler': include_error_handler,
'disable_unicode': disable_unicode,
'bytestring_passthrough': bytestring_passthrough,
'output_encoding': output_encoding,

View File

@ -1,5 +1,5 @@
# mako/parsetree.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/pygen.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/pyparser.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1,5 +1,5 @@
# mako/runtime.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -749,7 +749,16 @@ def _include_file(context, uri, calling_uri, **kwargs):
(callable_, ctx) = _populate_self_namespace(
context._clean_inheritance_tokens(),
template)
callable_(ctx, **_kwargs_for_include(callable_, context._data, **kwargs))
kwargs = _kwargs_for_include(callable_, context._data, **kwargs)
if template.include_error_handler:
try:
callable_(ctx, **kwargs)
except Exception:
result = template.include_error_handler(ctx, compat.exception_as())
if not result:
compat.reraise(*sys.exc_info())
else:
callable_(ctx, **kwargs)
def _inherit_from(context, uri, calling_uri):

View File

@ -1,5 +1,5 @@
# mako/template.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php
@ -109,6 +109,11 @@ class Template(object):
completes. Is used to provide custom error-rendering
functions.
.. seealso::
:paramref:`.Template.include_error_handler` - include-specific
error handler function
:param format_exceptions: if ``True``, exceptions which occur during
the render phase of this template will be caught and
formatted into an HTML error page, which then becomes the
@ -129,6 +134,16 @@ class Template(object):
import will not appear as the first executed statement in the generated
code and will therefore not have the desired effect.
:param include_error_handler: An error handler that runs when this template
is included within another one via the ``<%include>`` tag, and raises an
error. Compare to the :paramref:`.Template.error_handler` option.
.. versionadded:: 1.0.6
.. seealso::
:paramref:`.Template.error_handler` - top-level error handler function
:param input_encoding: Encoding of the template's source code. Can
be used in lieu of the coding comment. See
:ref:`usage_unicode` as well as :ref:`unicode_toplevel` for
@ -171,7 +186,7 @@ class Template(object):
from mako.template import Template
mytemplate = Template(
filename="admin_index.mako",
filename="index.html",
module_directory="/path/to/modules",
module_writer=module_writer
)
@ -243,7 +258,8 @@ class Template(object):
future_imports=None,
enable_loop=True,
preprocessor=None,
lexer_cls=None):
lexer_cls=None,
include_error_handler=None):
if uri:
self.module_id = re.sub(r'\W', "_", uri)
self.uri = uri
@ -329,6 +345,7 @@ class Template(object):
self.callable_ = self.module.render_body
self.format_exceptions = format_exceptions
self.error_handler = error_handler
self.include_error_handler = include_error_handler
self.lookup = lookup
self.module_directory = module_directory
@ -475,6 +492,14 @@ class Template(object):
return DefTemplate(self, getattr(self.module, "render_%s" % name))
def list_defs(self):
"""return a list of defs in the template.
.. versionadded:: 1.0.4
"""
return [i[7:] for i in dir(self.module) if i[:7] == 'render_']
def _get_def_callable(self, name):
return getattr(self.module, "render_%s" % name)
@ -520,6 +545,7 @@ class ModuleTemplate(Template):
cache_type=None,
cache_dir=None,
cache_url=None,
include_error_handler=None,
):
self.module_id = re.sub(r'\W', "_", module._template_uri)
self.uri = module._template_uri
@ -551,6 +577,7 @@ class ModuleTemplate(Template):
self.callable_ = self.module.render_body
self.format_exceptions = format_exceptions
self.error_handler = error_handler
self.include_error_handler = include_error_handler
self.lookup = lookup
self._setup_cache_args(
cache_impl, cache_enabled, cache_args,
@ -571,6 +598,7 @@ class DefTemplate(Template):
self.encoding_errors = parent.encoding_errors
self.format_exceptions = parent.format_exceptions
self.error_handler = parent.error_handler
self.include_error_handler = parent.include_error_handler
self.enable_loop = parent.enable_loop
self.lookup = parent.lookup
self.bytestring_passthrough = parent.bytestring_passthrough

View File

@ -1,5 +1,5 @@
# mako/util.py
# Copyright (C) 2006-2015 the Mako authors and contributors <see AUTHORS file>
# Copyright (C) 2006-2016 the Mako authors and contributors <see AUTHORS file>
#
# This module is part of Mako and is released under
# the MIT License: http://www.opensource.org/licenses/mit-license.php

View File

@ -1 +0,0 @@
__version__ = '1.3.5'

View File

@ -1,12 +0,0 @@
# API Backwards compatibility
from pymemcache.client.base import Client # noqa
from pymemcache.client.base import PooledClient # noqa
from pymemcache.exceptions import MemcacheError # noqa
from pymemcache.exceptions import MemcacheClientError # noqa
from pymemcache.exceptions import MemcacheUnknownCommandError # noqa
from pymemcache.exceptions import MemcacheIllegalInputError # noqa
from pymemcache.exceptions import MemcacheServerError # noqa
from pymemcache.exceptions import MemcacheUnknownError # noqa
from pymemcache.exceptions import MemcacheUnexpectedCloseError # noqa

File diff suppressed because it is too large Load Diff

View File

@ -1,333 +0,0 @@
import socket
import time
import logging
from pymemcache.client.base import Client, PooledClient, _check_key
from pymemcache.client.rendezvous import RendezvousHash
logger = logging.getLogger(__name__)
class HashClient(object):
"""
A client for communicating with a cluster of memcached servers
"""
def __init__(
self,
servers,
hasher=RendezvousHash,
serializer=None,
deserializer=None,
connect_timeout=None,
timeout=None,
no_delay=False,
socket_module=socket,
key_prefix=b'',
max_pool_size=None,
lock_generator=None,
retry_attempts=2,
retry_timeout=1,
dead_timeout=60,
use_pooling=False,
ignore_exc=False,
):
"""
Constructor.
Args:
servers: list(tuple(hostname, port))
hasher: optional class three functions ``get_node``, ``add_node``,
and ``remove_node``
defaults to Rendezvous (HRW) hash.
use_pooling: use py:class:`.PooledClient` as the default underlying
class. ``max_pool_size`` and ``lock_generator`` can
be used with this. default: False
retry_attempts: Amount of times a client should be tried before it
is marked dead and removed from the pool.
retry_timeout (float): Time in seconds that should pass between retry
attempts.
dead_timeout (float): Time in seconds before attempting to add a node
back in the pool.
Further arguments are interpreted as for :py:class:`.Client`
constructor.
The default ``hasher`` is using a pure python implementation that can
be significantly improved performance wise by switching to a C based
version. We recommend using ``python-clandestined`` if having a C
dependency is acceptable.
"""
self.clients = {}
self.retry_attempts = retry_attempts
self.retry_timeout = retry_timeout
self.dead_timeout = dead_timeout
self.use_pooling = use_pooling
self.key_prefix = key_prefix
self.ignore_exc = ignore_exc
self._failed_clients = {}
self._dead_clients = {}
self._last_dead_check_time = time.time()
self.hasher = hasher()
self.default_kwargs = {
'connect_timeout': connect_timeout,
'timeout': timeout,
'no_delay': no_delay,
'socket_module': socket_module,
'key_prefix': key_prefix,
'serializer': serializer,
'deserializer': deserializer,
}
if use_pooling is True:
self.default_kwargs.update({
'max_pool_size': max_pool_size,
'lock_generator': lock_generator
})
for server, port in servers:
self.add_server(server, port)
def add_server(self, server, port):
key = '%s:%s' % (server, port)
if self.use_pooling:
client = PooledClient(
(server, port),
**self.default_kwargs
)
else:
client = Client((server, port), **self.default_kwargs)
self.clients[key] = client
self.hasher.add_node(key)
def remove_server(self, server, port):
dead_time = time.time()
self._failed_clients.pop((server, port))
self._dead_clients[(server, port)] = dead_time
key = '%s:%s' % (server, port)
self.hasher.remove_node(key)
def _get_client(self, key):
_check_key(key, self.key_prefix)
if len(self._dead_clients) > 0:
current_time = time.time()
ldc = self._last_dead_check_time
# we have dead clients and we have reached the
# timeout retry
if current_time - ldc > self.dead_timeout:
for server, dead_time in self._dead_clients.items():
if current_time - dead_time > self.dead_timeout:
logger.debug(
'bringing server back into rotation %s',
server
)
self.add_server(*server)
self._last_dead_check_time = current_time
server = self.hasher.get_node(key)
# We've ran out of servers to try
if server is None:
if self.ignore_exc is True:
return
raise Exception('All servers seem to be down right now')
client = self.clients[server]
return client
def _safely_run_func(self, client, func, default_val, *args, **kwargs):
try:
if client.server in self._failed_clients:
# This server is currently failing, lets check if it is in
# retry or marked as dead
failed_metadata = self._failed_clients[client.server]
# we haven't tried our max amount yet, if it has been enough
# time lets just retry using it
if failed_metadata['attempts'] < self.retry_attempts:
failed_time = failed_metadata['failed_time']
if time.time() - failed_time > self.retry_timeout:
logger.debug(
'retrying failed server: %s', client.server
)
result = func(*args, **kwargs)
# we were successful, lets remove it from the failed
# clients
self._failed_clients.pop(client.server)
return result
return default_val
else:
# We've reached our max retry attempts, we need to mark
# the sever as dead
logger.debug('marking server as dead: %s', client.server)
self.remove_server(*client.server)
result = func(*args, **kwargs)
return result
# Connecting to the server fail, we should enter
# retry mode
except socket.error:
# This client has never failed, lets mark it for failure
if (
client.server not in self._failed_clients and
self.retry_attempts > 0
):
self._failed_clients[client.server] = {
'failed_time': time.time(),
'attempts': 0,
}
# We aren't allowing any retries, we should mark the server as
# dead immediately
elif (
client.server not in self._failed_clients and
self.retry_attempts <= 0
):
self._failed_clients[client.server] = {
'failed_time': time.time(),
'attempts': 0,
}
logger.debug("marking server as dead %s", client.server)
self.remove_server(*client.server)
# This client has failed previously, we need to update the metadata
# to reflect that we have attempted it again
else:
failed_metadata = self._failed_clients[client.server]
failed_metadata['attempts'] += 1
failed_metadata['failed_time'] = time.time()
self._failed_clients[client.server] = failed_metadata
# if we haven't enabled ignore_exc, don't move on gracefully, just
# raise the exception
if not self.ignore_exc:
raise
return default_val
except:
# any exceptions that aren't socket.error we need to handle
# gracefully as well
if not self.ignore_exc:
raise
return default_val
def _run_cmd(self, cmd, key, default_val, *args, **kwargs):
client = self._get_client(key)
if client is None:
return False
func = getattr(client, cmd)
args = list(args)
args.insert(0, key)
return self._safely_run_func(
client, func, default_val, *args, **kwargs
)
def set(self, key, *args, **kwargs):
return self._run_cmd('set', key, False, *args, **kwargs)
def get(self, key, *args, **kwargs):
return self._run_cmd('get', key, None, *args, **kwargs)
def incr(self, key, *args, **kwargs):
return self._run_cmd('incr', key, False, *args, **kwargs)
def decr(self, key, *args, **kwargs):
return self._run_cmd('decr', key, False, *args, **kwargs)
def set_many(self, values, *args, **kwargs):
client_batches = {}
end = []
for key, value in values.items():
client = self._get_client(key)
if client is None:
end.append(False)
continue
if client.server not in client_batches:
client_batches[client.server] = {}
client_batches[client.server][key] = value
for server, values in client_batches.items():
client = self.clients['%s:%s' % server]
new_args = list(args)
new_args.insert(0, values)
result = self._safely_run_func(
client,
client.set_many, False, *new_args, **kwargs
)
end.append(result)
return all(end)
set_multi = set_many
def get_many(self, keys, *args, **kwargs):
client_batches = {}
end = {}
for key in keys:
client = self._get_client(key)
if client is None:
end[key] = False
continue
if client.server not in client_batches:
client_batches[client.server] = []
client_batches[client.server].append(key)
for server, keys in client_batches.items():
client = self.clients['%s:%s' % server]
new_args = list(args)
new_args.insert(0, keys)
result = self._safely_run_func(
client,
client.get_many, {}, *new_args, **kwargs
)
end.update(result)
return end
get_multi = get_many
def gets(self, key, *args, **kwargs):
return self._run_cmd('gets', key, None, *args, **kwargs)
def add(self, key, *args, **kwargs):
return self._run_cmd('add', key, False, *args, **kwargs)
def prepend(self, key, *args, **kwargs):
return self._run_cmd('prepend', key, False, *args, **kwargs)
def append(self, key, *args, **kwargs):
return self._run_cmd('append', key, False, *args, **kwargs)
def delete(self, key, *args, **kwargs):
return self._run_cmd('delete', key, False, *args, **kwargs)
def delete_many(self, keys, *args, **kwargs):
for key in keys:
self._run_cmd('delete', key, False, *args, **kwargs)
return True
delete_multi = delete_many
def cas(self, key, *args, **kwargs):
return self._run_cmd('cas', key, False, *args, **kwargs)
def replace(self, key, *args, **kwargs):
return self._run_cmd('replace', key, False, *args, **kwargs)
def flush_all(self):
for _, client in self.clients.items():
self._safely_run_func(client, client.flush_all, False)

View File

@ -1,51 +0,0 @@
def murmur3_32(data, seed=0):
"""MurmurHash3 was written by Austin Appleby, and is placed in the
public domain. The author hereby disclaims copyright to this source
code."""
c1 = 0xcc9e2d51
c2 = 0x1b873593
length = len(data)
h1 = seed
roundedEnd = (length & 0xfffffffc) # round down to 4 byte block
for i in range(0, roundedEnd, 4):
# little endian load order
k1 = (ord(data[i]) & 0xff) | ((ord(data[i + 1]) & 0xff) << 8) | \
((ord(data[i + 2]) & 0xff) << 16) | (ord(data[i + 3]) << 24)
k1 *= c1
k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17) # ROTL32(k1,15)
k1 *= c2
h1 ^= k1
h1 = (h1 << 13) | ((h1 & 0xffffffff) >> 19) # ROTL32(h1,13)
h1 = h1 * 5 + 0xe6546b64
# tail
k1 = 0
val = length & 0x03
if val == 3:
k1 = (ord(data[roundedEnd + 2]) & 0xff) << 16
# fallthrough
if val in [2, 3]:
k1 |= (ord(data[roundedEnd + 1]) & 0xff) << 8
# fallthrough
if val in [1, 2, 3]:
k1 |= ord(data[roundedEnd]) & 0xff
k1 *= c1
k1 = (k1 << 15) | ((k1 & 0xffffffff) >> 17) # ROTL32(k1,15)
k1 *= c2
h1 ^= k1
# finalization
h1 ^= length
# fmix(h1)
h1 ^= ((h1 & 0xffffffff) >> 16)
h1 *= 0x85ebca6b
h1 ^= ((h1 & 0xffffffff) >> 13)
h1 *= 0xc2b2ae35
h1 ^= ((h1 & 0xffffffff) >> 16)
return h1 & 0xffffffff

View File

@ -1,46 +0,0 @@
from pymemcache.client.murmur3 import murmur3_32
class RendezvousHash(object):
"""
Implements the Highest Random Weight (HRW) hashing algorithm most
commonly referred to as rendezvous hashing.
Originally developed as part of python-clandestined.
Copyright (c) 2014 Ernest W. Durbin III
"""
def __init__(self, nodes=None, seed=0, hash_function=murmur3_32):
"""
Constructor.
"""
self.nodes = []
self.seed = seed
if nodes is not None:
self.nodes = nodes
self.hash_function = lambda x: hash_function(x, seed)
def add_node(self, node):
if node not in self.nodes:
self.nodes.append(node)
def remove_node(self, node):
if node in self.nodes:
self.nodes.remove(node)
else:
raise ValueError("No such node %s to remove" % (node))
def get_node(self, key):
high_score = -1
winner = None
for node in self.nodes:
score = self.hash_function(
"%s-%s" % (str(node), str(key)))
if score > high_score:
(high_score, winner) = (score, node)
elif score == high_score:
(high_score, winner) = (score, max(str(node), str(winner)))
return winner

View File

@ -1,40 +0,0 @@
class MemcacheError(Exception):
"Base exception class"
pass
class MemcacheClientError(MemcacheError):
"""Raised when memcached fails to parse the arguments to a request, likely
due to a malformed key and/or value, a bug in this library, or a version
mismatch with memcached."""
pass
class MemcacheUnknownCommandError(MemcacheClientError):
"""Raised when memcached fails to parse a request, likely due to a bug in
this library or a version mismatch with memcached."""
pass
class MemcacheIllegalInputError(MemcacheClientError):
"""Raised when a key or value is not legal for Memcache (see the class docs
for Client for more details)."""
pass
class MemcacheServerError(MemcacheError):
"""Raised when memcached reports a failure while processing a request,
likely due to a bug or transient issue in memcached."""
pass
class MemcacheUnknownError(MemcacheError):
"""Raised when this library receives a response from memcached that it
cannot parse, likely due to a bug in this library or a version mismatch
with memcached."""
pass
class MemcacheUnexpectedCloseError(MemcacheServerError):
"Raised when the connection with memcached closes unexpectedly."
pass

View File

@ -1,123 +0,0 @@
# Copyright 2012 Pinterest.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
A client for falling back to older memcached servers when performing reads.
It is sometimes necessary to deploy memcached on new servers, or with a
different configuration. In theses cases, it is undesirable to start up an
empty memcached server and point traffic to it, since the cache will be cold,
and the backing store will have a large increase in traffic.
This class attempts to solve that problem by providing an interface identical
to the Client interface, but which can fall back to older memcached servers
when reads to the primary server fail. The approach for upgrading memcached
servers or configuration then becomes:
1. Deploy a new host (or fleet) with memcached, possibly with a new
configuration.
2. From your application servers, use FallbackClient to write and read from
the new cluster, and to read from the old cluster when there is a miss in
the new cluster.
3. Wait until the new cache is warm enough to support the load.
4. Switch from FallbackClient to a regular Client library for doing all
reads and writes to the new cluster.
5. Take down the old cluster.
Best Practices:
---------------
- Make sure that the old client has "ignore_exc" set to True, so that it
treats failures like cache misses. That will allow you to take down the
old cluster before you switch away from FallbackClient.
"""
class FallbackClient(object):
def __init__(self, caches):
assert len(caches) > 0
self.caches = caches
def close(self):
"Close each of the memcached clients"
for cache in self.caches:
cache.close()
def set(self, key, value, expire=0, noreply=True):
self.caches[0].set(key, value, expire, noreply)
def add(self, key, value, expire=0, noreply=True):
self.caches[0].add(key, value, expire, noreply)
def replace(self, key, value, expire=0, noreply=True):
self.caches[0].replace(key, value, expire, noreply)
def append(self, key, value, expire=0, noreply=True):
self.caches[0].append(key, value, expire, noreply)
def prepend(self, key, value, expire=0, noreply=True):
self.caches[0].prepend(key, value, expire, noreply)
def cas(self, key, value, cas, expire=0, noreply=True):
self.caches[0].cas(key, value, cas, expire, noreply)
def get(self, key):
for cache in self.caches:
result = cache.get(key)
if result is not None:
return result
return None
def get_many(self, keys):
for cache in self.caches:
result = cache.get_many(keys)
if result:
return result
return []
def gets(self, key):
for cache in self.caches:
result = cache.gets(key)
if result is not None:
return result
return None
def gets_many(self, keys):
for cache in self.caches:
result = cache.gets_many(keys)
if result:
return result
return []
def delete(self, key, noreply=True):
self.caches[0].delete(key, noreply)
def incr(self, key, value, noreply=True):
self.caches[0].incr(key, value, noreply)
def decr(self, key, value, noreply=True):
self.caches[0].decr(key, value, noreply)
def touch(self, key, expire=0, noreply=True):
self.caches[0].touch(key, expire, noreply)
def stats(self):
# TODO: ??
pass
def flush_all(self, delay=0, noreply=True):
self.caches[0].flush_all(delay, noreply)
def quit(self):
# TODO: ??
pass

View File

@ -1,114 +0,0 @@
# Copyright 2015 Yahoo.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import contextlib
import sys
import threading
import six
class ObjectPool(object):
"""A pool of objects that release/creates/destroys as needed."""
def __init__(self, obj_creator,
after_remove=None, max_size=None,
lock_generator=None):
self._used_objs = collections.deque()
self._free_objs = collections.deque()
self._obj_creator = obj_creator
if lock_generator is None:
self._lock = threading.Lock()
else:
self._lock = lock_generator()
self._after_remove = after_remove
max_size = max_size or 2 ** 31
if not isinstance(max_size, six.integer_types) or max_size < 0:
raise ValueError('"max_size" must be a positive integer')
self.max_size = max_size
@property
def used(self):
return tuple(self._used_objs)
@property
def free(self):
return tuple(self._free_objs)
@contextlib.contextmanager
def get_and_release(self, destroy_on_fail=False):
obj = self.get()
try:
yield obj
except Exception:
exc_info = sys.exc_info()
if not destroy_on_fail:
self.release(obj)
else:
self.destroy(obj)
six.reraise(exc_info[0], exc_info[1], exc_info[2])
self.release(obj)
def get(self):
with self._lock:
if not self._free_objs:
curr_count = len(self._used_objs)
if curr_count >= self.max_size:
raise RuntimeError("Too many objects,"
" %s >= %s" % (curr_count,
self.max_size))
obj = self._obj_creator()
self._used_objs.append(obj)
return obj
else:
obj = self._free_objs.pop()
self._used_objs.append(obj)
return obj
def destroy(self, obj, silent=True):
was_dropped = False
with self._lock:
try:
self._used_objs.remove(obj)
was_dropped = True
except ValueError:
if not silent:
raise
if was_dropped and self._after_remove is not None:
self._after_remove(obj)
def release(self, obj, silent=True):
with self._lock:
try:
self._used_objs.remove(obj)
self._free_objs.append(obj)
except ValueError:
if not silent:
raise
def clear(self):
if self._after_remove is not None:
needs_destroy = []
with self._lock:
needs_destroy.extend(self._used_objs)
needs_destroy.extend(self._free_objs)
self._free_objs.clear()
self._used_objs.clear()
for obj in needs_destroy:
self._after_remove(obj)
else:
with self._lock:
self._free_objs.clear()
self._used_objs.clear()

View File

@ -1,69 +0,0 @@
# Copyright 2012 Pinterest.com
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import pickle
try:
from cStringIO import StringIO
except ImportError:
from StringIO import StringIO
FLAG_PICKLE = 1 << 0
FLAG_INTEGER = 1 << 1
FLAG_LONG = 1 << 2
def python_memcache_serializer(key, value):
flags = 0
if isinstance(value, str):
pass
elif isinstance(value, int):
flags |= FLAG_INTEGER
value = "%d" % value
elif isinstance(value, long):
flags |= FLAG_LONG
value = "%d" % value
else:
flags |= FLAG_PICKLE
output = StringIO()
pickler = pickle.Pickler(output, 0)
pickler.dump(value)
value = output.getvalue()
return value, flags
def python_memcache_deserializer(key, value, flags):
if flags == 0:
return value
if flags & FLAG_INTEGER:
return int(value)
if flags & FLAG_LONG:
return long(value)
if flags & FLAG_PICKLE:
try:
buf = StringIO(value)
unpickler = pickle.Unpickler(buf)
return unpickler.load()
except Exception:
logging.info('Pickle error', exc_info=True)
return None
return value

View File

@ -1,7 +1,7 @@
'''
"""
PyMySQL: A pure-Python MySQL client library.
Copyright (c) 2010, 2013 PyMySQL contributors
Copyright (c) 2010-2016 PyMySQL contributors
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
@ -20,30 +20,29 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
'''
VERSION = (0, 6, 7, None)
from ._compat import text_type, JYTHON, IRONPYTHON
from .constants import FIELD_TYPE
from .converters import escape_dict, escape_sequence, escape_string
from .err import Warning, Error, InterfaceError, DataError, \
DatabaseError, OperationalError, IntegrityError, InternalError, \
NotSupportedError, ProgrammingError, MySQLError
from .times import Date, Time, Timestamp, \
DateFromTicks, TimeFromTicks, TimestampFromTicks
"""
import sys
from ._compat import PY2
from .constants import FIELD_TYPE
from .converters import escape_dict, escape_sequence, escape_string
from .err import (
Warning, Error, InterfaceError, DataError,
DatabaseError, OperationalError, IntegrityError, InternalError,
NotSupportedError, ProgrammingError, MySQLError)
from .times import (
Date, Time, Timestamp,
DateFromTicks, TimeFromTicks, TimestampFromTicks)
VERSION = (0, 7, 11, None)
threadsafety = 1
apilevel = "2.0"
paramstyle = "format"
paramstyle = "pyformat"
class DBAPISet(frozenset):
def __ne__(self, other):
if isinstance(other, set):
return frozenset.__ne__(self, other)
@ -73,11 +72,14 @@ TIMESTAMP = DBAPISet([FIELD_TYPE.TIMESTAMP, FIELD_TYPE.DATETIME])
DATETIME = TIMESTAMP
ROWID = DBAPISet()
def Binary(x):
"""Return x as a binary type."""
if isinstance(x, text_type) and not (JYTHON or IRONPYTHON):
return x.encode()
return bytes(x)
if PY2:
return bytearray(x)
else:
return bytes(x)
def Connect(*args, **kwargs):
"""
@ -87,27 +89,26 @@ def Connect(*args, **kwargs):
from .connections import Connection
return Connection(*args, **kwargs)
from pymysql import connections as _orig_conn
from . import connections as _orig_conn
if _orig_conn.Connection.__init__.__doc__ is not None:
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__ + ("""
See connections.Connection.__init__() for information about defaults.
""")
Connect.__doc__ = _orig_conn.Connection.__init__.__doc__
del _orig_conn
def get_client_info(): # for MySQLdb compatibility
return '.'.join(map(str, VERSION))
connect = Connection = Connect
# we include a doctored version_info here for MySQLdb compatibility
version_info = (1,2,2,"final",0)
version_info = (1,2,6,"final",0)
NULL = "NULL"
__version__ = get_client_info()
def thread_safe():
return True # match MySQLdb.thread_safe()
return True # match MySQLdb.thread_safe()
def install_as_MySQLdb():
"""
@ -116,6 +117,7 @@ def install_as_MySQLdb():
"""
sys.modules["MySQLdb"] = sys.modules["_mysql"] = sys.modules["pymysql"]
__all__ = [
'BINARY', 'Binary', 'Connect', 'Connection', 'DATE', 'Date',
'Time', 'Timestamp', 'DateFromTicks', 'TimeFromTicks', 'TimestampFromTicks',
@ -128,6 +130,5 @@ __all__ = [
'paramstyle', 'threadsafety', 'version_info',
"install_as_MySQLdb",
"NULL","__version__",
]
"NULL", "__version__",
]

View File

@ -7,12 +7,15 @@ IRONPYTHON = sys.platform == 'cli'
CPYTHON = not PYPY and not JYTHON and not IRONPYTHON
if PY2:
import __builtin__
range_type = xrange
text_type = unicode
long_type = long
str_type = basestring
unichr = __builtin__.unichr
else:
range_type = range
text_type = str
long_type = int
str_type = str
unichr = chr

View File

@ -11,6 +11,10 @@ class Charset(object):
self.id, self.name, self.collation = id, name, collation
self.is_default = is_default == 'Yes'
def __repr__(self):
return "Charset(id=%s, name=%r, collation=%r)" % (
self.id, self.name, self.collation)
@property
def encoding(self):
name = self.name
@ -249,6 +253,10 @@ _charsets.add(Charset(240, 'utf8mb4', 'utf8mb4_persian_ci', ''))
_charsets.add(Charset(241, 'utf8mb4', 'utf8mb4_esperanto_ci', ''))
_charsets.add(Charset(242, 'utf8mb4', 'utf8mb4_hungarian_ci', ''))
_charsets.add(Charset(243, 'utf8mb4', 'utf8mb4_sinhala_ci', ''))
_charsets.add(Charset(244, 'utf8mb4', 'utf8mb4_german2_ci', ''))
_charsets.add(Charset(245, 'utf8mb4', 'utf8mb4_croatian_ci', ''))
_charsets.add(Charset(246, 'utf8mb4', 'utf8mb4_unicode_520_ci', ''))
_charsets.add(Charset(247, 'utf8mb4', 'utf8mb4_vietnamese_ci', ''))
charset_by_name = _charsets.by_name

View File

@ -17,9 +17,8 @@ import traceback
import warnings
from .charset import MBLENGTH, charset_by_name, charset_by_id
from .constants import CLIENT, COMMAND, FIELD_TYPE, SERVER_STATUS
from .converters import (
escape_item, encoders, decoders, escape_string, through)
from .constants import CLIENT, COMMAND, CR, FIELD_TYPE, SERVER_STATUS
from .converters import escape_item, escape_string, through, conversions as _conv
from .cursors import Cursor
from .optionfile import Parser
from .util import byte2int, int2byte
@ -36,7 +35,8 @@ try:
import getpass
DEFAULT_USER = getpass.getuser()
del getpass
except ImportError:
except (ImportError, KeyError):
# KeyError occurs when there's no entry in OS database for a current user.
DEFAULT_USER = None
@ -117,26 +117,24 @@ def dump_packet(data): # pragma: no cover
try:
print("packet length:", len(data))
print("method call[1]:", sys._getframe(1).f_code.co_name)
print("method call[2]:", sys._getframe(2).f_code.co_name)
print("method call[3]:", sys._getframe(3).f_code.co_name)
print("method call[4]:", sys._getframe(4).f_code.co_name)
print("method call[5]:", sys._getframe(5).f_code.co_name)
print("-" * 88)
for i in range(1, 6):
f = sys._getframe(i)
print("call[%d]: %s (line %d)" % (i, f.f_code.co_name, f.f_lineno))
print("-" * 66)
except ValueError:
pass
dump_data = [data[i:i+16] for i in range_type(0, min(len(data), 256), 16)]
for d in dump_data:
print(' '.join(map(lambda x: "{:02X}".format(byte2int(x)), d)) +
' ' * (16 - len(d)) + ' ' * 2 +
' '.join(map(lambda x: "{}".format(is_ascii(x)), d)))
print("-" * 88)
''.join(map(lambda x: "{}".format(is_ascii(x)), d)))
print("-" * 66)
print()
def _scramble(password, message):
if not password:
return b'\0'
return b''
if DEBUG: print('password=' + str(password))
stage1 = sha_new(password).digest()
stage2 = sha_new(stage1).digest()
@ -149,7 +147,7 @@ def _scramble(password, message):
def _my_crypt(message1, message2):
length = len(message1)
result = struct.pack('B', length)
result = b''
for i in range_type(length):
x = (struct.unpack('B', message1[i:i+1])[0] ^
struct.unpack('B', message2[i:i+1])[0])
@ -196,7 +194,8 @@ def _hash_password_323(password):
add = 7
nr2 = 0x12345671
for c in [byte2int(x) for x in password if x not in (' ', '\t')]:
# x in py3 is numbers, p27 is chars
for c in [byte2int(x) for x in password if x not in (' ', '\t', 32, 9)]:
nr ^= (((nr & 63) + add) * c) + (nr << 8) & 0xFFFFFFFF
nr2 = (nr2 + ((nr2 << 8) ^ nr)) & 0xFFFFFFFF
add = (add + c) & 0xFFFFFFFF
@ -209,6 +208,20 @@ def _hash_password_323(password):
def pack_int24(n):
return struct.pack('<I', n)[:3]
# https://dev.mysql.com/doc/internals/en/integer.html#packet-Protocol::LengthEncodedInteger
def lenenc_int(i):
if (i < 0):
raise ValueError("Encoding %d is less than 0 - no representation in LengthEncodedInteger" % i)
elif (i < 0xfb):
return int2byte(i)
elif (i < (1 << 16)):
return b'\xfc' + struct.pack('<H', i)
elif (i < (1 << 24)):
return b'\xfd' + struct.pack('<I', i)[:3]
elif (i < (1 << 64)):
return b'\xfe' + struct.pack('<Q', i)
else:
raise ValueError("Encoding %x is larger than %x - no representation in LengthEncodedInteger" % (i, (1 << 64)))
class MysqlPacket(object):
"""Representation of a MySQL response packet.
@ -303,6 +316,14 @@ class MysqlPacket(object):
self._position += 8
return result
def read_string(self):
end_pos = self._data.find(b'\0', self._position)
if end_pos < 0:
return None
result = self._data[self._position:end_pos]
self._position = end_pos + 1
return result
def read_length_encoded_integer(self):
"""Read a 'Length Coded Binary' number from the data buffer.
@ -340,13 +361,18 @@ class MysqlPacket(object):
return result
def is_ok_packet(self):
return self._data[0:1] == b'\0'
# https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html
return self._data[0:1] == b'\0' and len(self._data) >= 7
def is_eof_packet(self):
# http://dev.mysql.com/doc/internals/en/generic-response-packets.html#packet-EOF_Packet
# Caution: \xFE may be LengthEncodedInteger.
# If \xFE is LengthEncodedInteger header, 8bytes followed.
return len(self._data) < 9 and self._data[0:1] == b'\xfe'
return self._data[0:1] == b'\xfe' and len(self._data) < 9
def is_auth_switch_request(self):
# http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
return self._data[0:1] == b'\xfe'
def is_resultset_packet(self):
field_count = ord(self._data[0:1])
@ -379,9 +405,9 @@ class FieldDescriptorPacket(MysqlPacket):
def __init__(self, data, encoding):
MysqlPacket.__init__(self, data, encoding)
self.__parse_field_descriptor(encoding)
self._parse_field_descriptor(encoding)
def __parse_field_descriptor(self, encoding):
def _parse_field_descriptor(self, encoding):
"""Parse the 'Field Descriptor' (Metadata) packet.
This is compatible with MySQL 4.1+ (not compatible with MySQL 4.0).
@ -494,20 +520,23 @@ class Connection(object):
The proper way to get an instance of this class is to call
connect().
"""
socket = None
_sock = None
_auth_plugin_name = ''
_closed = False
def __init__(self, host="localhost", user=None, password="",
database=None, port=3306, unix_socket=None,
def __init__(self, host=None, user=None, password="",
database=None, port=0, unix_socket=None,
charset='', sql_mode=None,
read_default_file=None, conv=decoders, use_unicode=None,
read_default_file=None, conv=None, use_unicode=None,
client_flag=0, cursorclass=Cursor, init_command=None,
connect_timeout=None, ssl=None, read_default_group=None,
connect_timeout=10, ssl=None, read_default_group=None,
compress=None, named_pipe=None, no_delay=None,
autocommit=False, db=None, passwd=None, local_infile=False,
max_allowed_packet=16*1024*1024, defer_connect=False):
max_allowed_packet=16*1024*1024, defer_connect=False,
auth_plugin_map={}, read_timeout=None, write_timeout=None,
bind_address=None):
"""
Establish a connection to the MySQL database. Accepts several
arguments:
@ -516,15 +545,19 @@ class Connection(object):
user: Username to log in as
password: Password to use.
database: Database to use, None to not use a particular one.
port: MySQL port to use, default is usually OK.
port: MySQL port to use, default is usually OK. (default: 3306)
bind_address: When the client has multiple network interfaces, specify
the interface from which to connect to the host. Argument can be
a hostname or an IP address.
unix_socket: Optionally, you can use a unix socket rather than TCP/IP.
charset: Charset you want to use.
sql_mode: Default SQL_MODE to use.
read_default_file:
Specifies my.cnf file to read these parameters from under the [client] section.
conv:
Decoders dictionary to use instead of the default one.
This is used to provide custom marshalling of types. See converters.
Conversion dictionary to use instead of the default one.
This is used to provide custom marshalling and unmarshaling of types.
See converters.
use_unicode:
Whether or not to default to unicode strings.
This option defaults to true for Py3k.
@ -532,27 +565,29 @@ class Connection(object):
cursorclass: Custom cursor class to use.
init_command: Initial SQL statement to run when connection is established.
connect_timeout: Timeout before throwing an exception when connecting.
(default: 10, min: 1, max: 31536000)
ssl:
A dict of arguments similar to mysql_ssl_set()'s parameters.
For now the capath and cipher arguments are not supported.
read_default_group: Group to read from in the configuration file.
compress; Not supported
named_pipe: Not supported
no_delay: Disable Nagle's algorithm on the socket. (deprecated, default: True)
autocommit: Autocommit mode. None means use server default. (default: False)
local_infile: Boolean to enable the use of LOAD DATA LOCAL command. (default: False)
max_allowed_packet: Max size of packet sent to server in bytes. (default: 16MB)
Only used to limit size of "LOAD LOCAL INFILE" data packet smaller than default (16KB).
defer_connect: Don't explicitly connect on contruction - wait for connect call.
(default: False)
auth_plugin_map: A dict of plugin names to a class that processes that plugin.
The class will take the Connection object as the argument to the constructor.
The class needs an authenticate method taking an authentication packet as
an argument. For the dialog plugin, a prompt(echo, prompt) method can be used
(if no authenticate method) for returning a string from the user. (experimental)
db: Alias for database. (for compatibility to MySQLdb)
passwd: Alias for password. (for compatibility to MySQLdb)
"""
if no_delay is not None:
warnings.warn("no_delay option is deprecated", DeprecationWarning)
no_delay = bool(no_delay)
else:
no_delay = True
if use_unicode is None and sys.version_info[0] > 2:
use_unicode = True
@ -565,24 +600,10 @@ class Connection(object):
if compress or named_pipe:
raise NotImplementedError("compress and named_pipe arguments are not supported")
if local_infile:
self._local_infile = bool(local_infile)
if self._local_infile:
client_flag |= CLIENT.LOCAL_FILES
if ssl and ('capath' in ssl or 'cipher' in ssl):
raise NotImplementedError('ssl options capath and cipher are not supported')
self.ssl = False
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
self.ssl = True
client_flag |= CLIENT.SSL
for k in ('key', 'cert', 'ca'):
v = None
if k in ssl:
v = ssl[k]
setattr(self, k, v)
if read_default_group and not read_default_file:
if sys.platform.startswith("win"):
read_default_file = "c:\\my.ini"
@ -610,15 +631,40 @@ class Connection(object):
database = _config("database", database)
unix_socket = _config("socket", unix_socket)
port = int(_config("port", port))
bind_address = _config("bind-address", bind_address)
charset = _config("default-character-set", charset)
if not ssl:
ssl = {}
if isinstance(ssl, dict):
for key in ["ca", "capath", "cert", "key", "cipher"]:
value = _config("ssl-" + key, ssl.get(key))
if value:
ssl[key] = value
self.host = host
self.port = port
self.ssl = False
if ssl:
if not SSL_ENABLED:
raise NotImplementedError("ssl module not found")
self.ssl = True
client_flag |= CLIENT.SSL
self.ctx = self._create_ssl_ctx(ssl)
self.host = host or "localhost"
self.port = port or 3306
self.user = user or DEFAULT_USER
self.password = password or ""
self.db = database
self.no_delay = no_delay
self.unix_socket = unix_socket
self.bind_address = bind_address
if not (0 < connect_timeout <= 31536000):
raise ValueError("connect_timeout should be >0 and <=31536000")
self.connect_timeout = connect_timeout or None
if read_timeout is not None and read_timeout <= 0:
raise ValueError("read_timeout should be >= 0")
self._read_timeout = read_timeout
if write_timeout is not None and write_timeout <= 0:
raise ValueError("write_timeout should be >= 0")
self._write_timeout = write_timeout
if charset:
self.charset = charset
self.use_unicode = True
@ -631,13 +677,12 @@ class Connection(object):
self.encoding = charset_by_name(self.charset).encoding
client_flag |= CLIENT.CAPABILITIES | CLIENT.MULTI_STATEMENTS
client_flag |= CLIENT.CAPABILITIES
if self.db:
client_flag |= CLIENT.CONNECT_WITH_DB
self.client_flag = client_flag
self.cursorclass = cursorclass
self.connect_timeout = connect_timeout
self._result = None
self._affected_rows = 0
@ -646,44 +691,68 @@ class Connection(object):
#: specified autocommit mode. None means use server default.
self.autocommit_mode = autocommit
self.encoders = encoders # Need for MySQLdb compatibility.
self.decoders = conv
if conv is None:
conv = _conv
# Need for MySQLdb compatibility.
self.encoders = dict([(k, v) for (k, v) in conv.items() if type(k) is not int])
self.decoders = dict([(k, v) for (k, v) in conv.items() if type(k) is int])
self.sql_mode = sql_mode
self.init_command = init_command
self.max_allowed_packet = max_allowed_packet
self._auth_plugin_map = auth_plugin_map
if defer_connect:
self.socket = None
self._sock = None
else:
self.connect()
def _create_ssl_ctx(self, sslp):
if isinstance(sslp, ssl.SSLContext):
return sslp
ca = sslp.get('ca')
capath = sslp.get('capath')
hasnoca = ca is None and capath is None
ctx = ssl.create_default_context(cafile=ca, capath=capath)
ctx.check_hostname = not hasnoca and sslp.get('check_hostname', True)
ctx.verify_mode = ssl.CERT_NONE if hasnoca else ssl.CERT_REQUIRED
if 'cert' in sslp:
ctx.load_cert_chain(sslp['cert'], keyfile=sslp.get('key'))
if 'cipher' in sslp:
ctx.set_ciphers(sslp['cipher'])
ctx.options |= ssl.OP_NO_SSLv2
ctx.options |= ssl.OP_NO_SSLv3
return ctx
def close(self):
"""Send the quit message and close the socket"""
if self.socket is None:
if self._closed:
raise err.Error("Already closed")
self._closed = True
if self._sock is None:
return
send_data = struct.pack('<iB', 1, COMMAND.COM_QUIT)
try:
self._write_bytes(send_data)
except Exception:
pass
finally:
sock = self.socket
self.socket = None
self._rfile = None
sock.close()
self._force_close()
@property
def open(self):
return self.socket is not None
return self._sock is not None
def __del__(self):
if self.socket:
def _force_close(self):
"""Close connection without QUIT message"""
if self._sock:
try:
self.socket.close()
self._sock.close()
except:
pass
self.socket = None
self._sock = None
self._rfile = None
__del__ = _force_close
def autocommit(self, value):
self.autocommit_mode = bool(value)
current = self.get_autocommit()
@ -731,19 +800,25 @@ class Connection(object):
return result.rows
def select_db(self, db):
'''Set current db'''
"""Set current db"""
self._execute_command(COMMAND.COM_INIT_DB, db)
self._read_ok_packet()
def escape(self, obj, mapping=None):
"""Escape whatever value you pass to it"""
"""Escape whatever value you pass to it.
Non-standard, for internal use; do not use this in your applications.
"""
if isinstance(obj, str_type):
return "'" + self.escape_string(obj) + "'"
return escape_item(obj, self.charset, mapping=mapping)
def literal(self, obj):
'''Alias for escape()'''
return self.escape(obj)
"""Alias for escape()
Non-standard, for internal use; do not use this in your applications.
"""
return self.escape(obj, self.encoders)
def escape_string(self, s):
if (self.server_status &
@ -795,7 +870,7 @@ class Connection(object):
def ping(self, reconnect=True):
"""Check if the server is alive"""
if self.socket is None:
if self._sock is None:
if reconnect:
self.connect()
reconnect = False
@ -821,6 +896,7 @@ class Connection(object):
self.encoding = encoding
def connect(self, sock=None):
self._closed = False
try:
if sock is None:
if self.unix_socket and self.host in ('localhost', '127.0.0.1'):
@ -830,10 +906,14 @@ class Connection(object):
self.host_info = "Localhost via UNIX socket"
if DEBUG: print('connected using unix_socket')
else:
kwargs = {}
if self.bind_address is not None:
kwargs['source_address'] = (self.bind_address, 0)
while True:
try:
sock = socket.create_connection(
(self.host, self.port), self.connect_timeout)
(self.host, self.port), self.connect_timeout,
**kwargs)
break
except (OSError, IOError) as e:
if e.errno == errno.EINTR:
@ -841,12 +921,13 @@ class Connection(object):
raise
self.host_info = "socket %s:%d" % (self.host, self.port)
if DEBUG: print('connected using socket')
if self.no_delay:
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
# sock.settimeout(None)
sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
sock.settimeout(None)
sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1)
self.socket = sock
self._sock = sock
self._rfile = _makefile(sock, 'rb')
self._next_seq_id = 0
self._get_server_information()
self._request_authentication()
@ -886,6 +967,17 @@ class Connection(object):
# So just reraise it.
raise
def write_packet(self, payload):
"""Writes an entire "mysql packet" in its entirety to the network
addings its length and sequence number.
"""
# Internal note: when you build packet manualy and calls _write_bytes()
# directly, you should set self._next_seq_id properly.
data = pack_int24(len(payload)) + int2byte(self._next_seq_id) + payload
if DEBUG: dump_packet(data)
self._write_bytes(data)
self._next_seq_id = (self._next_seq_id + 1) % 256
def _read_packet(self, packet_type=MysqlPacket):
"""Read an entire "mysql packet" in its entirety from the network
and return a MysqlPacket type that represents the results.
@ -894,19 +986,36 @@ class Connection(object):
while True:
packet_header = self._read_bytes(4)
if DEBUG: dump_packet(packet_header)
btrl, btrh, packet_number = struct.unpack('<HBB', packet_header)
bytes_to_read = btrl + (btrh << 16)
#TODO: check sequence id
if packet_number != self._next_seq_id:
self._force_close()
if packet_number == 0:
# MariaDB sends error packet with seqno==0 when shutdown
raise err.OperationalError(
CR.CR_SERVER_LOST,
"Lost connection to MySQL server during query")
raise err.InternalError(
"Packet sequence number wrong - got %d expected %d"
% (packet_number, self._next_seq_id))
self._next_seq_id = (self._next_seq_id + 1) % 256
recv_data = self._read_bytes(bytes_to_read)
if DEBUG: dump_packet(recv_data)
buff += recv_data
# https://dev.mysql.com/doc/internals/en/sending-more-than-16mbyte.html
if bytes_to_read == 0xffffff:
continue
if bytes_to_read < MAX_PACKET_LEN:
break
packet = packet_type(buff, self.encoding)
packet.check_error()
return packet
def _read_bytes(self, num_bytes):
self._sock.settimeout(self._read_timeout)
while True:
try:
data = self._rfile.read(num_bytes)
@ -914,19 +1023,25 @@ class Connection(object):
except (IOError, OSError) as e:
if e.errno == errno.EINTR:
continue
self._force_close()
raise err.OperationalError(
2013,
CR.CR_SERVER_LOST,
"Lost connection to MySQL server during query (%s)" % (e,))
if len(data) < num_bytes:
self._force_close()
raise err.OperationalError(
2013, "Lost connection to MySQL server during query")
CR.CR_SERVER_LOST, "Lost connection to MySQL server during query")
return data
def _write_bytes(self, data):
self._sock.settimeout(self._write_timeout)
try:
self.socket.sendall(data)
self._sock.sendall(data)
except IOError as e:
raise err.OperationalError(2006, "MySQL server has gone away (%r)" % (e,))
self._force_close()
raise err.OperationalError(
CR.CR_SERVER_GONE_ERROR,
"MySQL server has gone away (%r)" % (e,))
def _read_query_result(self, unbuffered=False):
if unbuffered:
@ -952,42 +1067,45 @@ class Connection(object):
return 0
def _execute_command(self, command, sql):
if not self.socket:
if not self._sock:
raise err.InterfaceError("(0, '')")
# If the last query was unbuffered, make sure it finishes before
# sending new commands
if self._result is not None and self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
if self._result is not None:
if self._result.unbuffered_active:
warnings.warn("Previous unbuffered result was left incomplete")
self._result._finish_unbuffered_query()
while self._result.has_next:
self.next_result()
self._result = None
if isinstance(sql, text_type):
sql = sql.encode(self.encoding)
chunk_size = min(self.max_allowed_packet, len(sql) + 1) # +1 is for command
packet_size = min(MAX_PACKET_LEN, len(sql) + 1) # +1 is for command
prelude = struct.pack('<iB', chunk_size, command)
self._write_bytes(prelude + sql[:chunk_size-1])
if DEBUG: dump_packet(prelude + sql)
# tiny optimization: build first packet manually instead of
# calling self..write_packet()
prelude = struct.pack('<iB', packet_size, command)
packet = prelude + sql[:packet_size-1]
self._write_bytes(packet)
if DEBUG: dump_packet(packet)
self._next_seq_id = 1
if chunk_size < self.max_allowed_packet:
if packet_size < MAX_PACKET_LEN:
return
seq_id = 1
sql = sql[chunk_size-1:]
sql = sql[packet_size-1:]
while True:
chunk_size = min(self.max_allowed_packet, len(sql))
prelude = struct.pack('<i', chunk_size)[:3]
data = prelude + int2byte(seq_id%256) + sql[:chunk_size]
self._write_bytes(data)
if DEBUG: dump_packet(data)
sql = sql[chunk_size:]
if not sql and chunk_size < self.max_allowed_packet:
packet_size = min(MAX_PACKET_LEN, len(sql))
self.write_packet(sql[:packet_size])
sql = sql[packet_size:]
if not sql and packet_size < MAX_PACKET_LEN:
break
seq_id += 1
def _request_authentication(self):
self.client_flag |= CLIENT.CAPABILITIES
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse
if int(self.server_version.split('.', 1)[0]) >= 5:
self.client_flag |= CLIENT.MULTI_RESULTS
@ -1000,48 +1118,114 @@ class Connection(object):
data_init = struct.pack('<iIB23s', self.client_flag, 1, charset_id, b'')
next_packet = 1
if self.ssl and self.server_capabilities & CLIENT.SSL:
self.write_packet(data_init)
if self.ssl:
data = pack_int24(len(data_init)) + int2byte(next_packet) + data_init
next_packet += 1
self._sock = self.ctx.wrap_socket(self._sock, server_hostname=self.host)
self._rfile = _makefile(self._sock, 'rb')
if DEBUG: dump_packet(data)
self._write_bytes(data)
data = data_init + self.user + b'\0'
cert_reqs = ssl.CERT_NONE if self.ca is None else ssl.CERT_REQUIRED
self.socket = ssl.wrap_socket(self.socket, keyfile=self.key,
certfile=self.cert,
ssl_version=ssl.PROTOCOL_TLSv1,
cert_reqs=cert_reqs,
ca_certs=self.ca)
self._rfile = _makefile(self.socket, 'rb')
authresp = b''
if self._auth_plugin_name in ('', 'mysql_native_password'):
authresp = _scramble(self.password.encode('latin1'), self.salt)
data = data_init + self.user + b'\0' + \
_scramble(self.password.encode('latin1'), self.salt)
if self.server_capabilities & CLIENT.PLUGIN_AUTH_LENENC_CLIENT_DATA:
data += lenenc_int(len(authresp)) + authresp
elif self.server_capabilities & CLIENT.SECURE_CONNECTION:
data += struct.pack('B', len(authresp)) + authresp
else: # pragma: no cover - not testing against servers without secure auth (>=5.0)
data += authresp + b'\0'
if self.db:
if self.db and self.server_capabilities & CLIENT.CONNECT_WITH_DB:
if isinstance(self.db, text_type):
self.db = self.db.encode(self.encoding)
data += self.db + int2byte(0)
data += self.db + b'\0'
data = pack_int24(len(data)) + int2byte(next_packet) + data
next_packet += 2
if DEBUG: dump_packet(data)
self._write_bytes(data)
if self.server_capabilities & CLIENT.PLUGIN_AUTH:
name = self._auth_plugin_name
if isinstance(name, text_type):
name = name.encode('ascii')
data += name + b'\0'
self.write_packet(data)
auth_packet = self._read_packet()
# if old_passwords is enabled the packet will be 1 byte long and
# have the octet 254
# if authentication method isn't accepted the first byte
# will have the octet 254
if auth_packet.is_auth_switch_request():
# https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchRequest
auth_packet.read_uint8() # 0xfe packet identifier
plugin_name = auth_packet.read_string()
if self.server_capabilities & CLIENT.PLUGIN_AUTH and plugin_name is not None:
auth_packet = self._process_auth(plugin_name, auth_packet)
else:
# send legacy handshake
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
self.write_packet(data)
auth_packet = self._read_packet()
if auth_packet.is_eof_packet():
# send legacy handshake
data = _scramble_323(self.password.encode('latin1'), self.salt) + b'\0'
data = pack_int24(len(data)) + int2byte(next_packet) + data
self._write_bytes(data)
auth_packet = self._read_packet()
def _process_auth(self, plugin_name, auth_packet):
plugin_class = self._auth_plugin_map.get(plugin_name)
if not plugin_class:
plugin_class = self._auth_plugin_map.get(plugin_name.decode('ascii'))
if plugin_class:
try:
handler = plugin_class(self)
return handler.authenticate(auth_packet)
except AttributeError:
if plugin_name != b'dialog':
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r missing authenticate method" % (plugin_name, plugin_class))
except TypeError:
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r cannot be constructed with connection object" % (plugin_name, plugin_class))
else:
handler = None
if plugin_name == b"mysql_native_password":
# https://dev.mysql.com/doc/internals/en/secure-password-authentication.html#packet-Authentication::Native41
data = _scramble(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
elif plugin_name == b"mysql_old_password":
# https://dev.mysql.com/doc/internals/en/old-password-authentication.html
data = _scramble_323(self.password.encode('latin1'), auth_packet.read_all()) + b'\0'
elif plugin_name == b"mysql_clear_password":
# https://dev.mysql.com/doc/internals/en/clear-text-authentication.html
data = self.password.encode('latin1') + b'\0'
elif plugin_name == b"dialog":
pkt = auth_packet
while True:
flag = pkt.read_uint8()
echo = (flag & 0x06) == 0x02
last = (flag & 0x01) == 0x01
prompt = pkt.read_all()
if prompt == b"Password: ":
self.write_packet(self.password.encode('latin1') + b'\0')
elif handler:
resp = 'no response - TypeError within plugin.prompt method'
try:
resp = handler.prompt(echo, prompt)
self.write_packet(resp + b'\0')
except AttributeError:
raise err.OperationalError(2059, "Authentication plugin '%s'" \
" not loaded: - %r missing prompt method" % (plugin_name, handler))
except TypeError:
raise err.OperationalError(2061, "Authentication plugin '%s'" \
" %r didn't respond with string. Returned '%r' to prompt %r" % (plugin_name, handler, resp, prompt))
else:
raise err.OperationalError(2059, "Authentication plugin '%s' (%r) not configured" % (plugin_name, handler))
pkt = self._read_packet()
pkt.check_error()
if pkt.is_ok_packet() or last:
break
return pkt
else:
raise err.OperationalError(2059, "Authentication plugin '%s' not configured" % plugin_name)
self.write_packet(data)
pkt = self._read_packet()
pkt.check_error()
return pkt
# _mysql support
def thread_id(self):
@ -1065,7 +1249,7 @@ class Connection(object):
self.protocol_version = byte2int(data[i:i+1])
i += 1
server_end = data.find(int2byte(0), i)
server_end = data.find(b'\0', i)
self.server_version = data[i:server_end].decode('latin1')
i = server_end + 1
@ -1097,7 +1281,22 @@ class Connection(object):
if len(data) >= i + salt_len:
# salt_len includes auth_plugin_data_part_1 and filler
self.salt += data[i:i+salt_len]
# TODO: AUTH PLUGIN NAME may appeare here.
i += salt_len
i+=1
# AUTH PLUGIN NAME may appear here.
if self.server_capabilities & CLIENT.PLUGIN_AUTH and len(data) >= i:
# Due to Bug#59453 the auth-plugin-name is missing the terminating
# NUL-char in versions prior to 5.5.10 and 5.6.2.
# ref: https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake
# didn't use version checks as mariadb is corrected and reports
# earlier than those two.
server_end = data.find(b'\0', i)
if server_end < 0: # pragma: no cover - very specific upstream bug
# not found \0 and last field so take it all
self._auth_plugin_name = data[i:].decode('latin1')
else:
self._auth_plugin_name = data[i:server_end].decode('latin1')
def get_server_info(self):
return self.server_version
@ -1117,6 +1316,9 @@ class Connection(object):
class MySQLResult(object):
def __init__(self, connection):
"""
:type connection: Connection
"""
self.connection = connection
self.affected_rows = None
self.insert_id = None
@ -1144,7 +1346,7 @@ class MySQLResult(object):
else:
self._read_result_packet(first_packet)
finally:
self.connection = False
self.connection = None
def init_unbuffered_query(self):
self.unbuffered_active = True
@ -1154,6 +1356,10 @@ class MySQLResult(object):
self._read_ok_packet(first_packet)
self.unbuffered_active = False
self.connection = None
elif first_packet.is_load_local_packet():
self._read_load_local_packet(first_packet)
self.unbuffered_active = False
self.connection = None
else:
self.field_count = first_packet.read_length_encoded_integer()
self._get_descriptions()
@ -1173,22 +1379,33 @@ class MySQLResult(object):
self.has_next = ok_packet.has_next
def _read_load_local_packet(self, first_packet):
if not self.connection._local_infile:
raise RuntimeError(
"**WARN**: Received LOAD_LOCAL packet but local_infile option is false.")
load_packet = LoadLocalPacketWrapper(first_packet)
sender = LoadLocalFile(load_packet.filename, self.connection)
sender.send_data()
try:
sender.send_data()
except:
self.connection._read_packet() # skip ok packet
raise
ok_packet = self.connection._read_packet()
if not ok_packet.is_ok_packet():
if not ok_packet.is_ok_packet(): # pragma: no cover - upstream induced protocol error
raise err.OperationalError(2014, "Commands Out of Sync")
self._read_ok_packet(ok_packet)
def _check_packet_is_eof(self, packet):
if packet.is_eof_packet():
eof_packet = EOFPacketWrapper(packet)
self.warning_count = eof_packet.warning_count
self.has_next = eof_packet.has_next
return True
return False
if not packet.is_eof_packet():
return False
#TODO: Support CLIENT.DEPRECATE_EOF
# 1) Add DEPRECATE_EOF to CAPABILITIES
# 2) Mask CAPABILITIES with server_capabilities
# 3) if server_capabilities & CLIENT.DEPRECATE_EOF: use OKPacketWrapper instead of EOFPacketWrapper
wp = EOFPacketWrapper(packet)
self.warning_count = wp.warning_count
self.has_next = wp.has_next
return True
def _read_result_packet(self, first_packet):
self.field_count = first_packet.read_length_encoded_integer()
@ -1239,7 +1456,12 @@ class MySQLResult(object):
def _read_row_from_packet(self, packet):
row = []
for encoding, converter in self.converters:
data = packet.read_length_coded_string()
try:
data = packet.read_length_coded_string()
except IndexError:
# No more columns in this row
# See https://github.com/PyMySQL/PyMySQL/pull/434
break
if data is not None:
if encoding is not None:
data = data.decode(encoding)
@ -1254,21 +1476,30 @@ class MySQLResult(object):
self.fields = []
self.converters = []
use_unicode = self.connection.use_unicode
conn_encoding = self.connection.encoding
description = []
for i in range_type(self.field_count):
field = self.connection._read_packet(FieldDescriptorPacket)
self.fields.append(field)
description.append(field.description())
field_type = field.type_code
if use_unicode:
if field_type in TEXT_TYPES:
charset = charset_by_id(field.charsetnr)
if charset.is_binary:
if field_type == FIELD_TYPE.JSON:
# When SELECT from JSON column: charset = binary
# When SELECT CAST(... AS JSON): charset = connection encoding
# This behavior is different from TEXT / BLOB.
# We should decode result by connection encoding regardless charsetnr.
# See https://github.com/PyMySQL/PyMySQL/issues/488
encoding = conn_encoding # SELECT CAST(... AS JSON)
elif field_type in TEXT_TYPES:
if field.charsetnr == 63: # binary
# TEXTs with charset=binary means BINARY types.
encoding = None
else:
encoding = charset.encoding
encoding = conn_encoding
else:
# Integers, Dates and Times, and other basic data is encoded in ascii
encoding = 'ascii'
else:
encoding = None
@ -1290,28 +1521,20 @@ class LoadLocalFile(object):
def send_data(self):
"""Send data packets from the local file to the server"""
if not self.connection.socket:
if not self.connection._sock:
raise err.InterfaceError("(0, '')")
conn = self.connection
# sequence id is 2 as we already sent a query packet
seq_id = 2
try:
with open(self.filename, 'rb') as open_file:
chunk_size = self.connection.max_allowed_packet
packet = b""
packet_size = min(conn.max_allowed_packet, 16*1024) # 16KB is efficient enough
while True:
chunk = open_file.read(chunk_size)
chunk = open_file.read(packet_size)
if not chunk:
break
packet = struct.pack('<i', len(chunk))[:3] + int2byte(seq_id)
format_str = '!{0}s'.format(len(chunk))
packet += struct.pack(format_str, chunk)
self.connection._write_bytes(packet)
seq_id += 1
conn.write_packet(chunk)
except IOError:
raise err.OperationalError(1017, "Can't find file '{0}'".format(self.filename))
finally:
# send the empty packet to signify we are done sending data
packet = struct.pack('<i', 0)[:3] + int2byte(seq_id)
self.connection._write_bytes(packet)
conn.write_packet(b'')

View File

@ -1,3 +1,4 @@
# https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags
LONG_PASSWORD = 1
FOUND_ROWS = 1 << 1
LONG_FLAG = 1 << 2
@ -15,5 +16,16 @@ TRANSACTIONS = 1 << 13
SECURE_CONNECTION = 1 << 15
MULTI_STATEMENTS = 1 << 16
MULTI_RESULTS = 1 << 17
CAPABILITIES = (LONG_PASSWORD | LONG_FLAG | TRANSACTIONS |
PROTOCOL_41 | SECURE_CONNECTION)
PS_MULTI_RESULTS = 1 << 18
PLUGIN_AUTH = 1 << 19
PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21
CAPABILITIES = (
LONG_PASSWORD | LONG_FLAG | PROTOCOL_41 | TRANSACTIONS
| SECURE_CONNECTION | MULTI_STATEMENTS | MULTI_RESULTS
| PLUGIN_AUTH | PLUGIN_AUTH_LENENC_CLIENT_DATA)
# Not done yet
CONNECT_ATTRS = 1 << 20
HANDLE_EXPIRED_PASSWORDS = 1 << 22
SESSION_TRACK = 1 << 23
DEPRECATE_EOF = 1 << 24

View File

@ -1,3 +1,4 @@
# flake8: noqa
# errmsg.h
CR_ERROR_FIRST = 2000
CR_UNKNOWN_ERROR = 2000

View File

@ -17,6 +17,7 @@ YEAR = 13
NEWDATE = 14
VARCHAR = 15
BIT = 16
JSON = 245
NEWDECIMAL = 246
ENUM = 247
SET = 248

View File

@ -1,7 +1,5 @@
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON
from ._compat import PY2, text_type, long_type, JYTHON, IRONPYTHON, unichr
import sys
import binascii
import datetime
from decimal import Decimal
import re
@ -11,10 +9,6 @@ from .constants import FIELD_TYPE, FLAG
from .charset import charset_by_id, charset_to_encoding
ESCAPE_REGEX = re.compile(r"[\0\n\r\032\'\"\\]")
ESCAPE_MAP = {'\0': '\\0', '\n': '\\n', '\r': '\\r', '\032': '\\Z',
'\'': '\\\'', '"': '\\"', '\\': '\\\\'}
def escape_item(val, charset, mapping=None):
if mapping is None:
mapping = encoders
@ -48,8 +42,7 @@ def escape_sequence(val, charset, mapping=None):
return "(" + ",".join(n) + ")"
def escape_set(val, charset, mapping=None):
val = map(lambda x: escape_item(x, charset, mapping), val)
return ','.join(val)
return ','.join([escape_item(x, charset, mapping) for x in val])
def escape_bool(value, mapping=None):
return str(int(value))
@ -63,19 +56,61 @@ def escape_int(value, mapping=None):
def escape_float(value, mapping=None):
return ('%.15g' % value)
def escape_string(value, mapping=None):
return ("%s" % (ESCAPE_REGEX.sub(
lambda match: ESCAPE_MAP.get(match.group(0)), value),))
_escape_table = [unichr(x) for x in range(128)]
_escape_table[0] = u'\\0'
_escape_table[ord('\\')] = u'\\\\'
_escape_table[ord('\n')] = u'\\n'
_escape_table[ord('\r')] = u'\\r'
_escape_table[ord('\032')] = u'\\Z'
_escape_table[ord('"')] = u'\\"'
_escape_table[ord("'")] = u"\\'"
def _escape_unicode(value, mapping=None):
"""escapes *value* without adding quote.
Value should be unicode
"""
return value.translate(_escape_table)
if PY2:
def escape_string(value, mapping=None):
"""escape_string escapes *value* but not surround it with quotes.
Value should be bytes or unicode.
"""
if isinstance(value, unicode):
return _escape_unicode(value)
assert isinstance(value, (bytes, bytearray))
value = value.replace('\\', '\\\\')
value = value.replace('\0', '\\0')
value = value.replace('\n', '\\n')
value = value.replace('\r', '\\r')
value = value.replace('\032', '\\Z')
value = value.replace("'", "\\'")
value = value.replace('"', '\\"')
return value
def escape_bytes(value, mapping=None):
assert isinstance(value, (bytes, bytearray))
return b"_binary'%s'" % escape_string(value)
else:
escape_string = _escape_unicode
# On Python ~3.5, str.decode('ascii', 'surrogateescape') is slow.
# (fixed in Python 3.6, http://bugs.python.org/issue24870)
# Workaround is str.decode('latin1') then translate 0x80-0xff into 0udc80-0udcff.
# We can escape special chars and surrogateescape at once.
_escape_bytes_table = _escape_table + [chr(i) for i in range(0xdc80, 0xdd00)]
def escape_bytes(value, mapping=None):
return "_binary'%s'" % value.decode('latin1').translate(_escape_bytes_table)
def escape_str(value, mapping=None):
return "'%s'" % escape_string(value, mapping)
def escape_unicode(value, mapping=None):
return escape_str(value, mapping)
return u"'%s'" % _escape_unicode(value)
def escape_bytes(value, mapping=None):
# escape_bytes is calld only on Python 3.
return escape_str(value.decode('ascii', 'surrogateescape'), mapping)
def escape_str(value, mapping=None):
return "'%s'" % escape_string(str(value), mapping)
def escape_None(value, mapping=None):
return 'NULL'
@ -111,6 +146,16 @@ def escape_date(obj, mapping=None):
def escape_struct_time(obj, mapping=None):
return escape_datetime(datetime.datetime(*obj[:6]))
def _convert_second_fraction(s):
if not s:
return 0
# Pad zeros to ensure the fraction length in microseconds
s = s.ljust(6, '0')
return int(s[:6])
DATETIME_RE = re.compile(r"(\d{1,4})-(\d{1,2})-(\d{1,2})[T ](\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
def convert_datetime(obj):
"""Returns a DATETIME or TIMESTAMP column value as a datetime object:
@ -127,23 +172,22 @@ def convert_datetime(obj):
True
"""
if ' ' in obj:
sep = ' '
elif 'T' in obj:
sep = 'T'
else:
if not PY2 and isinstance(obj, (bytes, bytearray)):
obj = obj.decode('ascii')
m = DATETIME_RE.match(obj)
if not m:
return convert_date(obj)
try:
ymd, hms = obj.split(sep, 1)
usecs = '0'
if '.' in hms:
hms, usecs = hms.split('.')
usecs = float('0.' + usecs) * 1e6
return datetime.datetime(*[ int(x) for x in ymd.split('-')+hms.split(':')+[usecs] ])
groups = list(m.groups())
groups[-1] = _convert_second_fraction(groups[-1])
return datetime.datetime(*[ int(x) for x in groups ])
except ValueError:
return convert_date(obj)
TIMEDELTA_RE = re.compile(r"(-)?(\d{1,3}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
def convert_timedelta(obj):
"""Returns a TIME column as a timedelta object:
@ -162,16 +206,19 @@ def convert_timedelta(obj):
can accept values as (+|-)DD HH:MM:SS. The latter format will not
be parsed correctly by this function.
"""
if not PY2 and isinstance(obj, (bytes, bytearray)):
obj = obj.decode('ascii')
m = TIMEDELTA_RE.match(obj)
if not m:
return None
try:
microseconds = 0
if "." in obj:
(obj, tail) = obj.split('.')
microseconds = float('0.' + tail) * 1e6
hours, minutes, seconds = obj.split(':')
negate = 1
if hours.startswith("-"):
hours = hours[1:]
negate = -1
groups = list(m.groups())
groups[-1] = _convert_second_fraction(groups[-1])
negate = -1 if groups[0] else 1
hours, minutes, seconds, microseconds = groups[1:]
tdelta = datetime.timedelta(
hours = int(hours),
minutes = int(minutes),
@ -182,6 +229,9 @@ def convert_timedelta(obj):
except ValueError:
return None
TIME_RE = re.compile(r"(\d{1,2}):(\d{1,2}):(\d{1,2})(?:.(\d{1,6}))?")
def convert_time(obj):
"""Returns a TIME column as a time object:
@ -204,17 +254,23 @@ def convert_time(obj):
to be treated as time-of-day and not a time offset, then you can
use set this function as the converter for FIELD_TYPE.TIME.
"""
if not PY2 and isinstance(obj, (bytes, bytearray)):
obj = obj.decode('ascii')
m = TIME_RE.match(obj)
if not m:
return None
try:
microseconds = 0
if "." in obj:
(obj, tail) = obj.split('.')
microseconds = float('0.' + tail) * 1e6
hours, minutes, seconds = obj.split(':')
groups = list(m.groups())
groups[-1] = _convert_second_fraction(groups[-1])
hours, minutes, seconds, microseconds = groups
return datetime.time(hour=int(hours), minute=int(minutes),
second=int(seconds), microsecond=int(microseconds))
except ValueError:
return None
def convert_date(obj):
"""Returns a DATE column as a date object:
@ -229,6 +285,8 @@ def convert_date(obj):
True
"""
if not PY2 and isinstance(obj, (bytes, bytearray)):
obj = obj.decode('ascii')
try:
return datetime.date(*[ int(x) for x in obj.split('-', 2) ])
except ValueError:
@ -256,6 +314,8 @@ def convert_mysql_timestamp(timestamp):
True
"""
if not PY2 and isinstance(timestamp, (bytes, bytearray)):
timestamp = timestamp.decode('ascii')
if timestamp[4] == '-':
return convert_datetime(timestamp)
timestamp += "0"*(14-len(timestamp)) # padding
@ -268,6 +328,8 @@ def convert_mysql_timestamp(timestamp):
return None
def convert_set(s):
if isinstance(s, (bytes, bytearray)):
return set(s.split(b","))
return set(s.split(","))
@ -278,7 +340,7 @@ def through(x):
#def convert_bit(b):
# b = "\x00" * (8 - len(b)) + b # pad w/ zeroes
# return struct.unpack(">Q", b)[0]
#
#
# the snippet above is right, but MySQLdb doesn't process bits,
# so we shouldn't either
convert_bit = through
@ -309,7 +371,9 @@ encoders = {
tuple: escape_sequence,
list: escape_sequence,
set: escape_sequence,
frozenset: escape_sequence,
dict: escape_dict,
bytearray: escape_bytes,
type(None): escape_None,
datetime.date: escape_date,
datetime.datetime: escape_datetime,
@ -350,7 +414,6 @@ decoders = {
# for MySQLdb compatibility
conversions = decoders
def Thing2Literal(obj):
return escape_str(str(obj))
conversions = encoders.copy()
conversions.update(decoders)
Thing2Literal = escape_str

View File

@ -5,33 +5,37 @@ import re
import warnings
from ._compat import range_type, text_type, PY2
from . import err
#: Regular expression for :meth:`Cursor.executemany`.
#: executemany only suports simple bulk insert.
#: You can use it to load large dataset.
RE_INSERT_VALUES = re.compile(r"""(INSERT\s.+\sVALUES\s+)(\(\s*%s\s*(?:,\s*%s\s*)*\))(\s*(?:ON DUPLICATE.*)?)\Z""",
re.IGNORECASE | re.DOTALL)
RE_INSERT_VALUES = re.compile(
r"\s*((?:INSERT|REPLACE)\s.+\sVALUES?\s+)" +
r"(\(\s*(?:%s|%\(.+\)s)\s*(?:,\s*(?:%s|%\(.+\)s)\s*)*\))" +
r"(\s*(?:ON DUPLICATE.*)?);?\s*\Z",
re.IGNORECASE | re.DOTALL)
class Cursor(object):
'''
"""
This is the object you use to interact with the database.
'''
"""
#: Max stetement size which :meth:`executemany` generates.
#: Max statement size which :meth:`executemany` generates.
#:
#: Max size of allowed statement is max_allowed_packet - packet_header_size.
#: Default value of max_allowed_packet is 1048576.
max_stmt_length = 1024000
_defer_warnings = False
def __init__(self, connection):
'''
"""
Do not create an instance of a Cursor yourself. Call
connections.Connection.cursor().
'''
"""
self.connection = connection
self.description = None
self.rownumber = 0
@ -40,11 +44,12 @@ class Cursor(object):
self._executed = None
self._result = None
self._rows = None
self._warnings_handled = False
def close(self):
'''
"""
Closing a cursor just exhausts all remaining data.
'''
"""
conn = self.connection
if conn is None:
return
@ -83,6 +88,9 @@ class Cursor(object):
"""Get the next query set"""
conn = self._get_db()
current_result = self._result
# for unbuffered queries warnings are only available once whole result has been read
if unbuffered:
self._show_warnings()
if current_result is None or current_result is not conn._result:
return None
if not current_result.has_next:
@ -107,17 +115,17 @@ class Cursor(object):
if isinstance(args, (tuple, list)):
if PY2:
args = tuple(map(ensure_bytes, args))
return tuple(conn.escape(arg) for arg in args)
return tuple(conn.literal(arg) for arg in args)
elif isinstance(args, dict):
if PY2:
args = dict((ensure_bytes(key), ensure_bytes(val)) for
(key, val) in args.items())
return dict((key, conn.escape(val)) for (key, val) in args.items())
return dict((key, conn.literal(val)) for (key, val) in args.items())
else:
# If it's not a dictionary let's try escaping it anyways.
# Worst case it will throw a Value error
if PY2:
ensure_bytes(args)
args = ensure_bytes(args)
return conn.escape(args)
def mogrify(self, query, args=None):
@ -137,7 +145,19 @@ class Cursor(object):
return query
def execute(self, query, args=None):
'''Execute a query'''
"""Execute a query
:param str query: Query to execute.
:param args: parameters used with query. (optional)
:type args: tuple, list or dict
:return: Number of affected rows
:rtype: int
If args is a list or tuple, %s can be used as a placeholder in the query.
If args is a dict, %(name)s can be used as a placeholder in the query.
"""
while self.nextset():
pass
@ -148,17 +168,23 @@ class Cursor(object):
return result
def executemany(self, query, args):
# type: (str, list) -> int
"""Run several data against one query
PyMySQL can execute bulkinsert for query like 'INSERT ... VALUES (%s)'.
In other form of queries, just run :meth:`execute` many times.
:param query: query to execute on server
:param args: Sequence of sequences or mappings. It is used as parameter.
:return: Number of rows affected, if any.
This method improves performance on multiple-row INSERT and
REPLACE. Otherwise it is equivalent to looping over args with
execute().
"""
if not args:
return
m = RE_INSERT_VALUES.match(query)
if m:
q_prefix = m.group(1)
q_prefix = m.group(1) % ()
q_values = m.group(2).rstrip()
q_postfix = m.group(3) or ''
assert q_values[0] == '(' and q_values[-1] == ')'
@ -247,7 +273,7 @@ class Cursor(object):
return args
def fetchone(self):
''' Fetch the next row '''
"""Fetch the next row"""
self._check_executed()
if self._rows is None or self.rownumber >= len(self._rows):
return None
@ -256,7 +282,7 @@ class Cursor(object):
return result
def fetchmany(self, size=None):
''' Fetch several rows '''
"""Fetch several rows"""
self._check_executed()
if self._rows is None:
return ()
@ -266,7 +292,7 @@ class Cursor(object):
return result
def fetchall(self):
''' Fetch all the rows '''
"""Fetch all the rows"""
self._check_executed()
if self._rows is None:
return ()
@ -307,14 +333,18 @@ class Cursor(object):
self.description = result.description
self.lastrowid = result.insert_id
self._rows = result.rows
self._warnings_handled = False
if result.warning_count > 0:
self._show_warnings(conn)
if not self._defer_warnings:
self._show_warnings()
def _show_warnings(self, conn):
if self._result and self._result.has_next:
def _show_warnings(self):
if self._warnings_handled:
return
ws = conn.show_warnings()
self._warnings_handled = True
if self._result and (self._result.has_next or not self._result.warning_count):
return
ws = self._get_db().show_warnings()
if ws is None:
return
for w in ws:
@ -322,7 +352,7 @@ class Cursor(object):
if PY2:
if isinstance(msg, unicode):
msg = msg.encode('utf-8', 'replace')
warnings.warn(str(msg), err.Warning, 4)
warnings.warn(err.Warning(*w[1:3]), stacklevel=4)
def __iter__(self):
return iter(self.fetchone, None)
@ -373,8 +403,8 @@ class SSCursor(Cursor):
or for connections to remote servers over a slow network.
Instead of copying every row of data into a buffer, this will fetch
rows as needed. The upside of this, is the client uses much less memory,
and rows are returned much faster when traveling over a slow network,
rows as needed. The upside of this is the client uses much less memory,
and rows are returned much faster when traveling over a slow network
or if the result set is very big.
There are limitations, though. The MySQL protocol doesn't support
@ -383,6 +413,8 @@ class SSCursor(Cursor):
possible to scroll backwards, as only the current row is held in memory.
"""
_defer_warnings = True
def _conv_row(self, row):
return row
@ -411,14 +443,15 @@ class SSCursor(Cursor):
return self._nextset(unbuffered=True)
def read_next(self):
""" Read next row """
"""Read next row"""
return self._conv_row(self._result._read_rowdata_packet_unbuffered())
def fetchone(self):
""" Fetch next row """
"""Fetch next row"""
self._check_executed()
row = self.read_next()
if row is None:
self._show_warnings()
return None
self.rownumber += 1
return row
@ -443,7 +476,7 @@ class SSCursor(Cursor):
return self.fetchall_unbuffered()
def fetchmany(self, size=None):
""" Fetch many """
"""Fetch many"""
self._check_executed()
if size is None:
size = self.arraysize
@ -452,6 +485,7 @@ class SSCursor(Cursor):
for i in range_type(size):
row = self.read_next()
if row is None:
self._show_warnings()
break
rows.append(row)
self.rownumber += 1
@ -482,4 +516,4 @@ class SSCursor(Cursor):
class SSDictCursor(DictCursorMixin, SSCursor):
""" An unbuffered cursor, which returns results as a dictionary """
"""An unbuffered cursor, which returns results as a dictionary"""

View File

@ -68,10 +68,12 @@ class NotSupportedError(DatabaseError):
error_map = {}
def _map_error(exc, *errors):
for error in errors:
error_map[error] = exc
_map_error(ProgrammingError, ER.DB_CREATE_EXISTS, ER.SYNTAX_ERROR,
ER.PARSE_ERROR, ER.NO_SUCH_TABLE, ER.WRONG_DB_NAME,
ER.WRONG_TABLE_NAME, ER.FIELD_SPECIFIED_TWICE,
@ -89,32 +91,17 @@ _map_error(OperationalError, ER.DBACCESS_DENIED_ERROR, ER.ACCESS_DENIED_ERROR,
ER.CON_COUNT_ERROR, ER.TABLEACCESS_DENIED_ERROR,
ER.COLUMNACCESS_DENIED_ERROR)
del _map_error, ER
def _get_error_info(data):
def raise_mysql_exception(data):
errno = struct.unpack('<h', data[1:3])[0]
is_41 = data[3:4] == b"#"
if is_41:
# version 4.1
sqlstate = data[4:9].decode("utf8", 'replace')
errorvalue = data[9:].decode("utf8", 'replace')
return (errno, sqlstate, errorvalue)
# client protocol 4.1
errval = data[9:].decode('utf-8', 'replace')
else:
# version 4.0
return (errno, None, data[3:].decode("utf8", 'replace'))
def _check_mysql_exception(errinfo):
errno, sqlstate, errorvalue = errinfo
errorclass = error_map.get(errno, None)
if errorclass:
raise errorclass(errno, errorvalue)
# couldn't find the right error number
raise InternalError(errno, errorvalue)
def raise_mysql_exception(data):
errinfo = _get_error_info(data)
_check_mysql_exception(errinfo)
errval = data[3:].decode('utf-8', 'replace')
errorclass = error_map.get(errno, InternalError)
raise errorclass(errno, errval)

View File

@ -1,16 +1,20 @@
from time import localtime
from datetime import date, datetime, time, timedelta
Date = date
Time = time
TimeDelta = timedelta
Timestamp = datetime
def DateFromTicks(ticks):
return date(*localtime(ticks)[:3])
def TimeFromTicks(ticks):
return time(*localtime(ticks)[3:6])
def TimestampFromTicks(ticks):
return datetime(*localtime(ticks)[:6])

View File

@ -1,14 +1,17 @@
import struct
def byte2int(b):
if isinstance(b, int):
return b
else:
return struct.unpack("!B", b)[0]
def int2byte(i):
return struct.pack("!B", i)
def join_bytes(bs):
if len(bs) == 0:
return ""

View File

@ -1,45 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RSA module
Module for calculating large primes, and RSA encryption, decryption, signing
and verification. Includes generating public and private keys.
WARNING: this implementation does not use random padding, compression of the
cleartext input to prevent repetitions, or other common security improvements.
Use with care.
If you want to have a more secure implementation, use the functions from the
``rsa.pkcs1`` module.
"""
__author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly"
__date__ = "2015-07-29"
__version__ = '3.2'
from rsa.key import newkeys, PrivateKey, PublicKey
from rsa.pkcs1 import encrypt, decrypt, sign, verify, DecryptionError, \
VerificationError
# Do doctest if we're run directly
if __name__ == "__main__":
import doctest
doctest.testmod()
__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify", 'PublicKey',
'PrivateKey', 'DecryptionError', 'VerificationError']

View File

@ -1,160 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Python compatibility wrappers."""
from __future__ import absolute_import
import sys
from struct import pack
try:
MAX_INT = sys.maxsize
except AttributeError:
MAX_INT = sys.maxint
MAX_INT64 = (1 << 63) - 1
MAX_INT32 = (1 << 31) - 1
MAX_INT16 = (1 << 15) - 1
# Determine the word size of the processor.
if MAX_INT == MAX_INT64:
# 64-bit processor.
MACHINE_WORD_SIZE = 64
elif MAX_INT == MAX_INT32:
# 32-bit processor.
MACHINE_WORD_SIZE = 32
else:
# Else we just assume 64-bit processor keeping up with modern times.
MACHINE_WORD_SIZE = 64
try:
# < Python3
unicode_type = unicode
have_python3 = False
except NameError:
# Python3.
unicode_type = str
have_python3 = True
# Fake byte literals.
if str is unicode_type:
def byte_literal(s):
return s.encode('latin1')
else:
def byte_literal(s):
return s
# ``long`` is no more. Do type detection using this instead.
try:
integer_types = (int, long)
except NameError:
integer_types = (int,)
b = byte_literal
try:
# Python 2.6 or higher.
bytes_type = bytes
except NameError:
# Python 2.5
bytes_type = str
# To avoid calling b() multiple times in tight loops.
ZERO_BYTE = b('\x00')
EMPTY_BYTE = b('')
def is_bytes(obj):
"""
Determines whether the given value is a byte string.
:param obj:
The value to test.
:returns:
``True`` if ``value`` is a byte string; ``False`` otherwise.
"""
return isinstance(obj, bytes_type)
def is_integer(obj):
"""
Determines whether the given value is an integer.
:param obj:
The value to test.
:returns:
``True`` if ``value`` is an integer; ``False`` otherwise.
"""
return isinstance(obj, integer_types)
def byte(num):
"""
Converts a number between 0 and 255 (both inclusive) to a base-256 (byte)
representation.
Use it as a replacement for ``chr`` where you are expecting a byte
because this will work on all current versions of Python::
:param num:
An unsigned integer between 0 and 255 (both inclusive).
:returns:
A single byte.
"""
return pack("B", num)
def get_word_alignment(num, force_arch=64,
_machine_word_size=MACHINE_WORD_SIZE):
"""
Returns alignment details for the given number based on the platform
Python is running on.
:param num:
Unsigned integral number.
:param force_arch:
If you don't want to use 64-bit unsigned chunks, set this to
anything other than 64. 32-bit chunks will be preferred then.
Default 64 will be used when on a 64-bit machine.
:param _machine_word_size:
(Internal) The machine word size used for alignment.
:returns:
4-tuple::
(word_bits, word_bytes,
max_uint, packing_format_type)
"""
max_uint64 = 0xffffffffffffffff
max_uint32 = 0xffffffff
max_uint16 = 0xffff
max_uint8 = 0xff
if force_arch == 64 and _machine_word_size >= 64 and num > max_uint32:
# 64-bit unsigned integer.
return 64, 8, max_uint64, "Q"
elif num > max_uint16:
# 32-bit unsigned integer
return 32, 4, max_uint32, "L"
elif num > max_uint8:
# 16-bit unsigned integer.
return 16, 2, max_uint16, "H"
else:
# 8-bit unsigned integer.
return 8, 1, max_uint8, "B"

View File

@ -1,442 +0,0 @@
"""RSA module
pri = k[1] //Private part of keys d,p,q
Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.
WARNING: this code implements the mathematics of RSA. It is not suitable for
real-world secure cryptography purposes. It has not been reviewed by a security
expert. It does not include padding of data. There are many ways in which the
output of this module, when used without any modification, can be sucessfully
attacked.
"""
__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer"
__date__ = "2010-02-05"
__version__ = '1.3.3'
# NOTE: Python's modulo can return negative numbers. We compensate for
# this behaviour using the abs() function
from cPickle import dumps, loads
import base64
import math
import os
import random
import sys
import types
import zlib
from rsa._compat import byte
# Display a warning that this insecure version is imported.
import warnings
warnings.warn('Insecure version of the RSA module is imported as %s, be careful'
% __name__)
def gcd(p, q):
"""Returns the greatest common divisor of p and q
>>> gcd(42, 6)
6
"""
if p<q: return gcd(q, p)
if q == 0: return p
return gcd(q, abs(p%q))
def bytes2int(bytes):
"""Converts a list of bytes or a string to an integer
>>> (128*256 + 64)*256 + + 15
8405007
>>> l = [128, 64, 15]
>>> bytes2int(l)
8405007
"""
if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
raise TypeError("You must pass a string or a list")
# Convert byte stream to integer
integer = 0
for byte in bytes:
integer *= 256
if type(byte) is types.StringType: byte = ord(byte)
integer += byte
return integer
def int2bytes(number):
"""Converts a number to a string of bytes
>>> bytes2int(int2bytes(123456789))
123456789
"""
if not (type(number) is types.LongType or type(number) is types.IntType):
raise TypeError("You must pass a long or an int")
string = ""
while number > 0:
string = "%s%s" % (byte(number & 0xFF), string)
number /= 256
return string
def fast_exponentiation(a, p, n):
"""Calculates r = a^p mod n
"""
result = a % n
remainders = []
while p != 1:
remainders.append(p & 1)
p = p >> 1
while remainders:
rem = remainders.pop()
result = ((a ** rem) * result ** 2) % n
return result
def read_random_int(nbits):
"""Reads a random integer of approximately nbits bits rounded up
to whole bytes"""
nbytes = ceil(nbits/8.)
randomdata = os.urandom(nbytes)
return bytes2int(randomdata)
def ceil(x):
"""ceil(x) -> int(math.ceil(x))"""
return int(math.ceil(x))
def randint(minvalue, maxvalue):
"""Returns a random integer x with minvalue <= x <= maxvalue"""
# Safety - get a lot of random data even if the range is fairly
# small
min_nbits = 32
# The range of the random numbers we need to generate
range = maxvalue - minvalue
# Which is this number of bytes
rangebytes = ceil(math.log(range, 2) / 8.)
# Convert to bits, but make sure it's always at least min_nbits*2
rangebits = max(rangebytes * 8, min_nbits * 2)
# Take a random number of bits between min_nbits and rangebits
nbits = random.randint(min_nbits, rangebits)
return (read_random_int(nbits) % range) + minvalue
def fermat_little_theorem(p):
"""Returns 1 if p may be prime, and something else if p definitely
is not prime"""
a = randint(1, p-1)
return fast_exponentiation(a, p-1, p)
def jacobi(a, b):
"""Calculates the value of the Jacobi symbol (a/b)
"""
if a % b == 0:
return 0
result = 1
while a > 1:
if a & 1:
if ((a-1)*(b-1) >> 2) & 1:
result = -result
b, a = a, b % a
else:
if ((b ** 2 - 1) >> 3) & 1:
result = -result
a = a >> 1
return result
def jacobi_witness(x, n):
"""Returns False if n is an Euler pseudo-prime with base x, and
True otherwise.
"""
j = jacobi(x, n) % n
f = fast_exponentiation(x, (n-1)/2, n)
if j == f: return False
return True
def randomized_primality_testing(n, k):
"""Calculates whether n is composite (which is always correct) or
prime (which is incorrect with error probability 2**-k)
Returns False if the number if composite, and True if it's
probably prime.
"""
q = 0.5 # Property of the jacobi_witness function
# t = int(math.ceil(k / math.log(1/q, 2)))
t = ceil(k / math.log(1/q, 2))
for i in range(t+1):
x = randint(1, n-1)
if jacobi_witness(x, n): return False
return True
def is_prime(number):
"""Returns True if the number is prime, and False otherwise.
>>> is_prime(42)
0
>>> is_prime(41)
1
"""
"""
if not fermat_little_theorem(number) == 1:
# Not prime, according to Fermat's little theorem
return False
"""
if randomized_primality_testing(number, 5):
# Prime, according to Jacobi
return True
# Not prime
return False
def getprime(nbits):
"""Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
other words: nbits is rounded up to whole bytes.
>>> p = getprime(8)
>>> is_prime(p-1)
0
>>> is_prime(p)
1
>>> is_prime(p+1)
0
"""
nbytes = int(math.ceil(nbits/8.))
while True:
integer = read_random_int(nbits)
# Make sure it's odd
integer |= 1
# Test for primeness
if is_prime(integer): break
# Retry if not prime
return integer
def are_relatively_prime(a, b):
"""Returns True if a and b are relatively prime, and False if they
are not.
>>> are_relatively_prime(2, 3)
1
>>> are_relatively_prime(2, 4)
0
"""
d = gcd(a, b)
return (d == 1)
def find_p_q(nbits):
"""Returns a tuple of two different primes of nbits bits"""
p = getprime(nbits)
while True:
q = getprime(nbits)
if not q == p: break
return (p, q)
def extended_euclid_gcd(a, b):
"""Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb
"""
if b == 0:
return (a, 1, 0)
q = abs(a % b)
r = long(a / b)
(d, k, l) = extended_euclid_gcd(b, q)
return (d, l, k - l*r)
# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
"""Calculates an encryption and a decryption key for p and q, and
returns them as a tuple (e, d)"""
n = p * q
phi_n = (p-1) * (q-1)
while True:
# Make sure e has enough bits so we ensure "wrapping" through
# modulo n
e = getprime(max(8, nbits/2))
if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break
(d, i, j) = extended_euclid_gcd(e, phi_n)
if not d == 1:
raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
if not (e * i) % phi_n == 1:
raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))
return (e, i)
def gen_keys(nbits):
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
"""
while True:
(p, q) = find_p_q(nbits)
(e, d) = calculate_keys(p, q, nbits)
# For some reason, d is sometimes negative. We don't know how
# to fix it (yet), so we keep trying until everything is shiny
if d > 0: break
return (p, q, e, d)
def gen_pubpriv_keys(nbits):
"""Generates public and private keys, and returns them as (pub,
priv).
The public key consists of a dict {e: ..., , n: ....). The private
key consists of a dict {d: ...., p: ...., q: ....).
"""
(p, q, e, d) = gen_keys(nbits)
return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} )
def encrypt_int(message, ekey, n):
"""Encrypts a message using encryption key 'ekey', working modulo
n"""
if type(message) is types.IntType:
return encrypt_int(long(message), ekey, n)
if not type(message) is types.LongType:
raise TypeError("You must pass a long or an int")
if message > 0 and \
math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)):
raise OverflowError("The message is too long")
return fast_exponentiation(message, ekey, n)
def decrypt_int(cyphertext, dkey, n):
"""Decrypts a cypher text using the decryption key 'dkey', working
modulo n"""
return encrypt_int(cyphertext, dkey, n)
def sign_int(message, dkey, n):
"""Signs 'message' using key 'dkey', working modulo n"""
return decrypt_int(message, dkey, n)
def verify_int(signed, ekey, n):
"""verifies 'signed' using key 'ekey', working modulo n"""
return encrypt_int(signed, ekey, n)
def picklechops(chops):
"""Pickles and base64encodes it's argument chops"""
value = zlib.compress(dumps(chops))
encoded = base64.encodestring(value)
return encoded.strip()
def unpicklechops(string):
"""base64decodes and unpickes it's argument string into chops"""
return loads(zlib.decompress(base64.decodestring(string)))
def chopstring(message, key, n, funcref):
"""Splits 'message' into chops that are at most as long as n,
converts these into integers, and calls funcref(integer, key, n)
for each chop.
Used by 'encrypt' and 'sign'.
"""
msglen = len(message)
mbits = msglen * 8
nbits = int(math.floor(math.log(n, 2)))
nbytes = nbits / 8
blocks = msglen / nbytes
if msglen % nbytes > 0:
blocks += 1
cypher = []
for bindex in range(blocks):
offset = bindex * nbytes
block = message[offset:offset+nbytes]
value = bytes2int(block)
cypher.append(funcref(value, key, n))
return picklechops(cypher)
def gluechops(chops, key, n, funcref):
"""Glues chops back together into a string. calls
funcref(integer, key, n) for each chop.
Used by 'decrypt' and 'verify'.
"""
message = ""
chops = unpicklechops(chops)
for cpart in chops:
mpart = funcref(cpart, key, n)
message += int2bytes(mpart)
return message
def encrypt(message, key):
"""Encrypts a string 'message' with the public key 'key'"""
return chopstring(message, key['e'], key['n'], encrypt_int)
def sign(message, key):
"""Signs a string 'message' with the private key 'key'"""
return chopstring(message, key['d'], key['p']*key['q'], decrypt_int)
def decrypt(cypher, key):
"""Decrypts a cypher with the private key 'key'"""
return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
def verify(cypher, key):
"""Verifies a cypher with the public key 'key'"""
return gluechops(cypher, key['e'], key['n'], encrypt_int)
# Do doctest if we're not imported
if __name__ == "__main__":
import doctest
doctest.testmod()
__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"]

View File

@ -1,529 +0,0 @@
"""RSA module
Module for calculating large primes, and RSA encryption, decryption,
signing and verification. Includes generating public and private keys.
WARNING: this implementation does not use random padding, compression of the
cleartext input to prevent repetitions, or other common security improvements.
Use with care.
"""
__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead"
__date__ = "2010-02-08"
__version__ = '2.0'
import math
import os
import random
import sys
import types
from rsa._compat import byte
# Display a warning that this insecure version is imported.
import warnings
warnings.warn('Insecure version of the RSA module is imported as %s' % __name__)
def bit_size(number):
"""Returns the number of bits required to hold a specific long number"""
return int(math.ceil(math.log(number,2)))
def gcd(p, q):
"""Returns the greatest common divisor of p and q
>>> gcd(48, 180)
12
"""
# Iterateive Version is faster and uses much less stack space
while q != 0:
if p < q: (p,q) = (q,p)
(p,q) = (q, p % q)
return p
def bytes2int(bytes):
"""Converts a list of bytes or a string to an integer
>>> (((128 * 256) + 64) * 256) + 15
8405007
>>> l = [128, 64, 15]
>>> bytes2int(l) #same as bytes2int('\x80@\x0f')
8405007
"""
if not (type(bytes) is types.ListType or type(bytes) is types.StringType):
raise TypeError("You must pass a string or a list")
# Convert byte stream to integer
integer = 0
for byte in bytes:
integer *= 256
if type(byte) is types.StringType: byte = ord(byte)
integer += byte
return integer
def int2bytes(number):
"""
Converts a number to a string of bytes
"""
if not (type(number) is types.LongType or type(number) is types.IntType):
raise TypeError("You must pass a long or an int")
string = ""
while number > 0:
string = "%s%s" % (byte(number & 0xFF), string)
number /= 256
return string
def to64(number):
"""Converts a number in the range of 0 to 63 into base 64 digit
character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'.
>>> to64(10)
'A'
"""
if not (type(number) is types.LongType or type(number) is types.IntType):
raise TypeError("You must pass a long or an int")
if 0 <= number <= 9: #00-09 translates to '0' - '9'
return byte(number + 48)
if 10 <= number <= 35:
return byte(number + 55) #10-35 translates to 'A' - 'Z'
if 36 <= number <= 61:
return byte(number + 61) #36-61 translates to 'a' - 'z'
if number == 62: # 62 translates to '-' (minus)
return byte(45)
if number == 63: # 63 translates to '_' (underscore)
return byte(95)
raise ValueError('Invalid Base64 value: %i' % number)
def from64(number):
"""Converts an ordinal character value in the range of
0-9,A-Z,a-z,-,_ to a number in the range of 0-63.
>>> from64(49)
1
"""
if not (type(number) is types.LongType or type(number) is types.IntType):
raise TypeError("You must pass a long or an int")
if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9
return(number - 48)
if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35
return(number - 55)
if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61
return(number - 61)
if number == 45: #ord('-') translates to 62
return(62)
if number == 95: #ord('_') translates to 63
return(63)
raise ValueError('Invalid Base64 value: %i' % number)
def int2str64(number):
"""Converts a number to a string of base64 encoded characters in
the range of '0'-'9','A'-'Z,'a'-'z','-','_'.
>>> int2str64(123456789)
'7MyqL'
"""
if not (type(number) is types.LongType or type(number) is types.IntType):
raise TypeError("You must pass a long or an int")
string = ""
while number > 0:
string = "%s%s" % (to64(number & 0x3F), string)
number /= 64
return string
def str642int(string):
"""Converts a base64 encoded string into an integer.
The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_'
>>> str642int('7MyqL')
123456789
"""
if not (type(string) is types.ListType or type(string) is types.StringType):
raise TypeError("You must pass a string or a list")
integer = 0
for byte in string:
integer *= 64
if type(byte) is types.StringType: byte = ord(byte)
integer += from64(byte)
return integer
def read_random_int(nbits):
"""Reads a random integer of approximately nbits bits rounded up
to whole bytes"""
nbytes = int(math.ceil(nbits/8.))
randomdata = os.urandom(nbytes)
return bytes2int(randomdata)
def randint(minvalue, maxvalue):
"""Returns a random integer x with minvalue <= x <= maxvalue"""
# Safety - get a lot of random data even if the range is fairly
# small
min_nbits = 32
# The range of the random numbers we need to generate
range = (maxvalue - minvalue) + 1
# Which is this number of bytes
rangebytes = ((bit_size(range) + 7) / 8)
# Convert to bits, but make sure it's always at least min_nbits*2
rangebits = max(rangebytes * 8, min_nbits * 2)
# Take a random number of bits between min_nbits and rangebits
nbits = random.randint(min_nbits, rangebits)
return (read_random_int(nbits) % range) + minvalue
def jacobi(a, b):
"""Calculates the value of the Jacobi symbol (a/b)
where both a and b are positive integers, and b is odd
"""
if a == 0: return 0
result = 1
while a > 1:
if a & 1:
if ((a-1)*(b-1) >> 2) & 1:
result = -result
a, b = b % a, a
else:
if (((b * b) - 1) >> 3) & 1:
result = -result
a >>= 1
if a == 0: return 0
return result
def jacobi_witness(x, n):
"""Returns False if n is an Euler pseudo-prime with base x, and
True otherwise.
"""
j = jacobi(x, n) % n
f = pow(x, (n-1)/2, n)
if j == f: return False
return True
def randomized_primality_testing(n, k):
"""Calculates whether n is composite (which is always correct) or
prime (which is incorrect with error probability 2**-k)
Returns False if the number is composite, and True if it's
probably prime.
"""
# 50% of Jacobi-witnesses can report compositness of non-prime numbers
for i in range(k):
x = randint(1, n-1)
if jacobi_witness(x, n): return False
return True
def is_prime(number):
"""Returns True if the number is prime, and False otherwise.
>>> is_prime(42)
0
>>> is_prime(41)
1
"""
if randomized_primality_testing(number, 6):
# Prime, according to Jacobi
return True
# Not prime
return False
def getprime(nbits):
"""Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In
other words: nbits is rounded up to whole bytes.
>>> p = getprime(8)
>>> is_prime(p-1)
0
>>> is_prime(p)
1
>>> is_prime(p+1)
0
"""
while True:
integer = read_random_int(nbits)
# Make sure it's odd
integer |= 1
# Test for primeness
if is_prime(integer): break
# Retry if not prime
return integer
def are_relatively_prime(a, b):
"""Returns True if a and b are relatively prime, and False if they
are not.
>>> are_relatively_prime(2, 3)
1
>>> are_relatively_prime(2, 4)
0
"""
d = gcd(a, b)
return (d == 1)
def find_p_q(nbits):
"""Returns a tuple of two different primes of nbits bits"""
pbits = nbits + (nbits/16) #Make sure that p and q aren't too close
qbits = nbits - (nbits/16) #or the factoring programs can factor n
p = getprime(pbits)
while True:
q = getprime(qbits)
#Make sure p and q are different.
if not q == p: break
return (p, q)
def extended_gcd(a, b):
"""Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
"""
# r = gcd(a,b) i = multiplicitive inverse of a mod b
# or j = multiplicitive inverse of b mod a
# Neg return values for i or j are made positive mod b or a respectively
# Iterateive Version is faster and uses much less stack space
x = 0
y = 1
lx = 1
ly = 0
oa = a #Remember original a/b to remove
ob = b #negative values from return results
while b != 0:
q = long(a/b)
(a, b) = (b, a % b)
(x, lx) = ((lx - (q * x)),x)
(y, ly) = ((ly - (q * y)),y)
if (lx < 0): lx += ob #If neg wrap modulo orignal b
if (ly < 0): ly += oa #If neg wrap modulo orignal a
return (a, lx, ly) #Return only positive values
# Main function: calculate encryption and decryption keys
def calculate_keys(p, q, nbits):
"""Calculates an encryption and a decryption key for p and q, and
returns them as a tuple (e, d)"""
n = p * q
phi_n = (p-1) * (q-1)
while True:
# Make sure e has enough bits so we ensure "wrapping" through
# modulo n
e = max(65537,getprime(nbits/4))
if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break
(d, i, j) = extended_gcd(e, phi_n)
if not d == 1:
raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n))
if (i < 0):
raise Exception("New extended_gcd shouldn't return negative values")
if not (e * i) % phi_n == 1:
raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n))
return (e, i)
def gen_keys(nbits):
"""Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
"""
(p, q) = find_p_q(nbits)
(e, d) = calculate_keys(p, q, nbits)
return (p, q, e, d)
def newkeys(nbits):
"""Generates public and private keys, and returns them as (pub,
priv).
The public key consists of a dict {e: ..., , n: ....). The private
key consists of a dict {d: ...., p: ...., q: ....).
"""
nbits = max(9,nbits) # Don't let nbits go below 9 bits
(p, q, e, d) = gen_keys(nbits)
return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} )
def encrypt_int(message, ekey, n):
"""Encrypts a message using encryption key 'ekey', working modulo n"""
if type(message) is types.IntType:
message = long(message)
if not type(message) is types.LongType:
raise TypeError("You must pass a long or int")
if message < 0 or message > n:
raise OverflowError("The message is too long")
#Note: Bit exponents start at zero (bit counts start at 1) this is correct
safebit = bit_size(n) - 2 #compute safe bit (MSB - 1)
message += (1 << safebit) #add safebit to ensure folding
return pow(message, ekey, n)
def decrypt_int(cyphertext, dkey, n):
"""Decrypts a cypher text using the decryption key 'dkey', working
modulo n"""
message = pow(cyphertext, dkey, n)
safebit = bit_size(n) - 2 #compute safe bit (MSB - 1)
message -= (1 << safebit) #remove safebit before decode
return message
def encode64chops(chops):
"""base64encodes chops and combines them into a ',' delimited string"""
chips = [] #chips are character chops
for value in chops:
chips.append(int2str64(value))
#delimit chops with comma
encoded = ','.join(chips)
return encoded
def decode64chops(string):
"""base64decodes and makes a ',' delimited string into chops"""
chips = string.split(',') #split chops at commas
chops = []
for string in chips: #make char chops (chips) into chops
chops.append(str642int(string))
return chops
def chopstring(message, key, n, funcref):
"""Chops the 'message' into integers that fit into n,
leaving room for a safebit to be added to ensure that all
messages fold during exponentiation. The MSB of the number n
is not independant modulo n (setting it could cause overflow), so
use the next lower bit for the safebit. Therefore reserve 2-bits
in the number n for non-data bits. Calls specified encryption
function for each chop.
Used by 'encrypt' and 'sign'.
"""
msglen = len(message)
mbits = msglen * 8
#Set aside 2-bits so setting of safebit won't overflow modulo n.
nbits = bit_size(n) - 2 # leave room for safebit
nbytes = nbits / 8
blocks = msglen / nbytes
if msglen % nbytes > 0:
blocks += 1
cypher = []
for bindex in range(blocks):
offset = bindex * nbytes
block = message[offset:offset+nbytes]
value = bytes2int(block)
cypher.append(funcref(value, key, n))
return encode64chops(cypher) #Encode encrypted ints to base64 strings
def gluechops(string, key, n, funcref):
"""Glues chops back together into a string. calls
funcref(integer, key, n) for each chop.
Used by 'decrypt' and 'verify'.
"""
message = ""
chops = decode64chops(string) #Decode base64 strings into integer chops
for cpart in chops:
mpart = funcref(cpart, key, n) #Decrypt each chop
message += int2bytes(mpart) #Combine decrypted strings into a msg
return message
def encrypt(message, key):
"""Encrypts a string 'message' with the public key 'key'"""
if 'n' not in key:
raise Exception("You must use the public key with encrypt")
return chopstring(message, key['e'], key['n'], encrypt_int)
def sign(message, key):
"""Signs a string 'message' with the private key 'key'"""
if 'p' not in key:
raise Exception("You must use the private key with sign")
return chopstring(message, key['d'], key['p']*key['q'], encrypt_int)
def decrypt(cypher, key):
"""Decrypts a string 'cypher' with the private key 'key'"""
if 'p' not in key:
raise Exception("You must use the private key with decrypt")
return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int)
def verify(cypher, key):
"""Verifies a string 'cypher' with the public key 'key'"""
if 'n' not in key:
raise Exception("You must use the public key with verify")
return gluechops(cypher, key['e'], key['n'], decrypt_int)
# Do doctest if we're not imported
if __name__ == "__main__":
import doctest
doctest.testmod()
__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"]

View File

@ -1,35 +0,0 @@
'''ASN.1 definitions.
Not all ASN.1-handling code use these definitions, but when it does, they should be here.
'''
from pyasn1.type import univ, namedtype, tag
class PubKeyHeader(univ.Sequence):
componentType = namedtype.NamedTypes(
namedtype.NamedType('oid', univ.ObjectIdentifier()),
namedtype.NamedType('parameters', univ.Null()),
)
class OpenSSLPubKey(univ.Sequence):
componentType = namedtype.NamedTypes(
namedtype.NamedType('header', PubKeyHeader()),
# This little hack (the implicit tag) allows us to get a Bit String as Octet String
namedtype.NamedType('key', univ.OctetString().subtype(
implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3))),
)
class AsnPubKey(univ.Sequence):
'''ASN.1 contents of DER encoded public key:
RSAPublicKey ::= SEQUENCE {
modulus INTEGER, -- n
publicExponent INTEGER, -- e
'''
componentType = namedtype.NamedTypes(
namedtype.NamedType('modulus', univ.Integer()),
namedtype.NamedType('publicExponent', univ.Integer()),
)

View File

@ -1,87 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Large file support
- break a file into smaller blocks, and encrypt them, and store the
encrypted blocks in another file.
- take such an encrypted files, decrypt its blocks, and reconstruct the
original file.
The encrypted file format is as follows, where || denotes byte concatenation:
FILE := VERSION || BLOCK || BLOCK ...
BLOCK := LENGTH || DATA
LENGTH := varint-encoded length of the subsequent data. Varint comes from
Google Protobuf, and encodes an integer into a variable number of bytes.
Each byte uses the 7 lowest bits to encode the value. The highest bit set
to 1 indicates the next byte is also part of the varint. The last byte will
have this bit set to 0.
This file format is called the VARBLOCK format, in line with the varint format
used to denote the block sizes.
'''
from rsa import key, common, pkcs1, varblock
from rsa._compat import byte
def encrypt_bigfile(infile, outfile, pub_key):
'''Encrypts a file, writing it to 'outfile' in VARBLOCK format.
:param infile: file-like object to read the cleartext from
:param outfile: file-like object to write the crypto in VARBLOCK format to
:param pub_key: :py:class:`rsa.PublicKey` to encrypt with
'''
if not isinstance(pub_key, key.PublicKey):
raise TypeError('Public key required, but got %r' % pub_key)
key_bytes = common.bit_size(pub_key.n) // 8
blocksize = key_bytes - 11 # keep space for PKCS#1 padding
# Write the version number to the VARBLOCK file
outfile.write(byte(varblock.VARBLOCK_VERSION))
# Encrypt and write each block
for block in varblock.yield_fixedblocks(infile, blocksize):
crypto = pkcs1.encrypt(block, pub_key)
varblock.write_varint(outfile, len(crypto))
outfile.write(crypto)
def decrypt_bigfile(infile, outfile, priv_key):
'''Decrypts an encrypted VARBLOCK file, writing it to 'outfile'
:param infile: file-like object to read the crypto in VARBLOCK format from
:param outfile: file-like object to write the cleartext to
:param priv_key: :py:class:`rsa.PrivateKey` to decrypt with
'''
if not isinstance(priv_key, key.PrivateKey):
raise TypeError('Private key required, but got %r' % priv_key)
for block in varblock.yield_varblocks(infile):
cleartext = pkcs1.decrypt(block, priv_key)
outfile.write(cleartext)
__all__ = ['encrypt_bigfile', 'decrypt_bigfile']

View File

@ -1,379 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Commandline scripts.
These scripts are called by the executables defined in setup.py.
'''
from __future__ import with_statement, print_function
import abc
import sys
from optparse import OptionParser
import rsa
import rsa.bigfile
import rsa.pkcs1
HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys())
def keygen():
'''Key generator.'''
# Parse the CLI options
parser = OptionParser(usage='usage: %prog [options] keysize',
description='Generates a new RSA keypair of "keysize" bits.')
parser.add_option('--pubout', type='string',
help='Output filename for the public key. The public key is '
'not saved if this option is not present. You can use '
'pyrsa-priv2pub to create the public key file later.')
parser.add_option('-o', '--out', type='string',
help='Output filename for the private key. The key is '
'written to stdout if this option is not present.')
parser.add_option('--form',
help='key format of the private and public keys - default PEM',
choices=('PEM', 'DER'), default='PEM')
(cli, cli_args) = parser.parse_args(sys.argv[1:])
if len(cli_args) != 1:
parser.print_help()
raise SystemExit(1)
try:
keysize = int(cli_args[0])
except ValueError:
parser.print_help()
print('Not a valid number: %s' % cli_args[0], file=sys.stderr)
raise SystemExit(1)
print('Generating %i-bit key' % keysize, file=sys.stderr)
(pub_key, priv_key) = rsa.newkeys(keysize)
# Save public key
if cli.pubout:
print('Writing public key to %s' % cli.pubout, file=sys.stderr)
data = pub_key.save_pkcs1(format=cli.form)
with open(cli.pubout, 'wb') as outfile:
outfile.write(data)
# Save private key
data = priv_key.save_pkcs1(format=cli.form)
if cli.out:
print('Writing private key to %s' % cli.out, file=sys.stderr)
with open(cli.out, 'wb') as outfile:
outfile.write(data)
else:
print('Writing private key to stdout', file=sys.stderr)
sys.stdout.write(data)
class CryptoOperation(object):
'''CLI callable that operates with input, output, and a key.'''
__metaclass__ = abc.ABCMeta
keyname = 'public' # or 'private'
usage = 'usage: %%prog [options] %(keyname)s_key'
description = None
operation = 'decrypt'
operation_past = 'decrypted'
operation_progressive = 'decrypting'
input_help = 'Name of the file to %(operation)s. Reads from stdin if ' \
'not specified.'
output_help = 'Name of the file to write the %(operation_past)s file ' \
'to. Written to stdout if this option is not present.'
expected_cli_args = 1
has_output = True
key_class = rsa.PublicKey
def __init__(self):
self.usage = self.usage % self.__class__.__dict__
self.input_help = self.input_help % self.__class__.__dict__
self.output_help = self.output_help % self.__class__.__dict__
@abc.abstractmethod
def perform_operation(self, indata, key, cli_args=None):
'''Performs the program's operation.
Implement in a subclass.
:returns: the data to write to the output.
'''
def __call__(self):
'''Runs the program.'''
(cli, cli_args) = self.parse_cli()
key = self.read_key(cli_args[0], cli.keyform)
indata = self.read_infile(cli.input)
print(self.operation_progressive.title(), file=sys.stderr)
outdata = self.perform_operation(indata, key, cli_args)
if self.has_output:
self.write_outfile(outdata, cli.output)
def parse_cli(self):
'''Parse the CLI options
:returns: (cli_opts, cli_args)
'''
parser = OptionParser(usage=self.usage, description=self.description)
parser.add_option('-i', '--input', type='string', help=self.input_help)
if self.has_output:
parser.add_option('-o', '--output', type='string', help=self.output_help)
parser.add_option('--keyform',
help='Key format of the %s key - default PEM' % self.keyname,
choices=('PEM', 'DER'), default='PEM')
(cli, cli_args) = parser.parse_args(sys.argv[1:])
if len(cli_args) != self.expected_cli_args:
parser.print_help()
raise SystemExit(1)
return (cli, cli_args)
def read_key(self, filename, keyform):
'''Reads a public or private key.'''
print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr)
with open(filename, 'rb') as keyfile:
keydata = keyfile.read()
return self.key_class.load_pkcs1(keydata, keyform)
def read_infile(self, inname):
'''Read the input file'''
if inname:
print('Reading input from %s' % inname, file=sys.stderr)
with open(inname, 'rb') as infile:
return infile.read()
print('Reading input from stdin', file=sys.stderr)
return sys.stdin.read()
def write_outfile(self, outdata, outname):
'''Write the output file'''
if outname:
print('Writing output to %s' % outname, file=sys.stderr)
with open(outname, 'wb') as outfile:
outfile.write(outdata)
else:
print('Writing output to stdout', file=sys.stderr)
sys.stdout.write(outdata)
class EncryptOperation(CryptoOperation):
'''Encrypts a file.'''
keyname = 'public'
description = ('Encrypts a file. The file must be shorter than the key '
'length in order to be encrypted. For larger files, use the '
'pyrsa-encrypt-bigfile command.')
operation = 'encrypt'
operation_past = 'encrypted'
operation_progressive = 'encrypting'
def perform_operation(self, indata, pub_key, cli_args=None):
'''Encrypts files.'''
return rsa.encrypt(indata, pub_key)
class DecryptOperation(CryptoOperation):
'''Decrypts a file.'''
keyname = 'private'
description = ('Decrypts a file. The original file must be shorter than '
'the key length in order to have been encrypted. For larger '
'files, use the pyrsa-decrypt-bigfile command.')
operation = 'decrypt'
operation_past = 'decrypted'
operation_progressive = 'decrypting'
key_class = rsa.PrivateKey
def perform_operation(self, indata, priv_key, cli_args=None):
'''Decrypts files.'''
return rsa.decrypt(indata, priv_key)
class SignOperation(CryptoOperation):
'''Signs a file.'''
keyname = 'private'
usage = 'usage: %%prog [options] private_key hash_method'
description = ('Signs a file, outputs the signature. Choose the hash '
'method from %s' % ', '.join(HASH_METHODS))
operation = 'sign'
operation_past = 'signature'
operation_progressive = 'Signing'
key_class = rsa.PrivateKey
expected_cli_args = 2
output_help = ('Name of the file to write the signature to. Written '
'to stdout if this option is not present.')
def perform_operation(self, indata, priv_key, cli_args):
'''Decrypts files.'''
hash_method = cli_args[1]
if hash_method not in HASH_METHODS:
raise SystemExit('Invalid hash method, choose one of %s' %
', '.join(HASH_METHODS))
return rsa.sign(indata, priv_key, hash_method)
class VerifyOperation(CryptoOperation):
'''Verify a signature.'''
keyname = 'public'
usage = 'usage: %%prog [options] public_key signature_file'
description = ('Verifies a signature, exits with status 0 upon success, '
'prints an error message and exits with status 1 upon error.')
operation = 'verify'
operation_past = 'verified'
operation_progressive = 'Verifying'
key_class = rsa.PublicKey
expected_cli_args = 2
has_output = False
def perform_operation(self, indata, pub_key, cli_args):
'''Decrypts files.'''
signature_file = cli_args[1]
with open(signature_file, 'rb') as sigfile:
signature = sigfile.read()
try:
rsa.verify(indata, signature, pub_key)
except rsa.VerificationError:
raise SystemExit('Verification failed.')
print('Verification OK', file=sys.stderr)
class BigfileOperation(CryptoOperation):
'''CryptoOperation that doesn't read the entire file into memory.'''
def __init__(self):
CryptoOperation.__init__(self)
self.file_objects = []
def __del__(self):
'''Closes any open file handles.'''
for fobj in self.file_objects:
fobj.close()
def __call__(self):
'''Runs the program.'''
(cli, cli_args) = self.parse_cli()
key = self.read_key(cli_args[0], cli.keyform)
# Get the file handles
infile = self.get_infile(cli.input)
outfile = self.get_outfile(cli.output)
# Call the operation
print(self.operation_progressive.title(), file=sys.stderr)
self.perform_operation(infile, outfile, key, cli_args)
def get_infile(self, inname):
'''Returns the input file object'''
if inname:
print('Reading input from %s' % inname, file=sys.stderr)
fobj = open(inname, 'rb')
self.file_objects.append(fobj)
else:
print('Reading input from stdin', file=sys.stderr)
fobj = sys.stdin
return fobj
def get_outfile(self, outname):
'''Returns the output file object'''
if outname:
print('Will write output to %s' % outname, file=sys.stderr)
fobj = open(outname, 'wb')
self.file_objects.append(fobj)
else:
print('Will write output to stdout', file=sys.stderr)
fobj = sys.stdout
return fobj
class EncryptBigfileOperation(BigfileOperation):
'''Encrypts a file to VARBLOCK format.'''
keyname = 'public'
description = ('Encrypts a file to an encrypted VARBLOCK file. The file '
'can be larger than the key length, but the output file is only '
'compatible with Python-RSA.')
operation = 'encrypt'
operation_past = 'encrypted'
operation_progressive = 'encrypting'
def perform_operation(self, infile, outfile, pub_key, cli_args=None):
'''Encrypts files to VARBLOCK.'''
return rsa.bigfile.encrypt_bigfile(infile, outfile, pub_key)
class DecryptBigfileOperation(BigfileOperation):
'''Decrypts a file in VARBLOCK format.'''
keyname = 'private'
description = ('Decrypts an encrypted VARBLOCK file that was encrypted '
'with pyrsa-encrypt-bigfile')
operation = 'decrypt'
operation_past = 'decrypted'
operation_progressive = 'decrypting'
key_class = rsa.PrivateKey
def perform_operation(self, infile, outfile, priv_key, cli_args=None):
'''Decrypts a VARBLOCK file.'''
return rsa.bigfile.decrypt_bigfile(infile, outfile, priv_key)
encrypt = EncryptOperation()
decrypt = DecryptOperation()
sign = SignOperation()
verify = VerifyOperation()
encrypt_bigfile = EncryptBigfileOperation()
decrypt_bigfile = DecryptBigfileOperation()

View File

@ -1,185 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Common functionality shared by several modules.'''
def bit_size(num):
'''
Number of bits needed to represent a integer excluding any prefix
0 bits.
As per definition from http://wiki.python.org/moin/BitManipulation and
to match the behavior of the Python 3 API.
Usage::
>>> bit_size(1023)
10
>>> bit_size(1024)
11
>>> bit_size(1025)
11
:param num:
Integer value. If num is 0, returns 0. Only the absolute value of the
number is considered. Therefore, signed integers will be abs(num)
before the number's bit length is determined.
:returns:
Returns the number of bits in the integer.
'''
if num == 0:
return 0
if num < 0:
num = -num
# Make sure this is an int and not a float.
num & 1
hex_num = "%x" % num
return ((len(hex_num) - 1) * 4) + {
'0':0, '1':1, '2':2, '3':2,
'4':3, '5':3, '6':3, '7':3,
'8':4, '9':4, 'a':4, 'b':4,
'c':4, 'd':4, 'e':4, 'f':4,
}[hex_num[0]]
def _bit_size(number):
'''
Returns the number of bits required to hold a specific long number.
'''
if number < 0:
raise ValueError('Only nonnegative numbers possible: %s' % number)
if number == 0:
return 0
# This works, even with very large numbers. When using math.log(number, 2),
# you'll get rounding errors and it'll fail.
bits = 0
while number:
bits += 1
number >>= 1
return bits
def byte_size(number):
'''
Returns the number of bytes required to hold a specific long number.
The number of bytes is rounded up.
Usage::
>>> byte_size(1 << 1023)
128
>>> byte_size((1 << 1024) - 1)
128
>>> byte_size(1 << 1024)
129
:param number:
An unsigned integer
:returns:
The number of bytes required to hold a specific long number.
'''
quanta, mod = divmod(bit_size(number), 8)
if mod or number == 0:
quanta += 1
return quanta
#return int(math.ceil(bit_size(number) / 8.0))
def extended_gcd(a, b):
'''Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb
'''
# r = gcd(a,b) i = multiplicitive inverse of a mod b
# or j = multiplicitive inverse of b mod a
# Neg return values for i or j are made positive mod b or a respectively
# Iterateive Version is faster and uses much less stack space
x = 0
y = 1
lx = 1
ly = 0
oa = a #Remember original a/b to remove
ob = b #negative values from return results
while b != 0:
q = a // b
(a, b) = (b, a % b)
(x, lx) = ((lx - (q * x)),x)
(y, ly) = ((ly - (q * y)),y)
if (lx < 0): lx += ob #If neg wrap modulo orignal b
if (ly < 0): ly += oa #If neg wrap modulo orignal a
return (a, lx, ly) #Return only positive values
def inverse(x, n):
'''Returns x^-1 (mod n)
>>> inverse(7, 4)
3
>>> (inverse(143, 4) * 143) % 4
1
'''
(divider, inv, _) = extended_gcd(x, n)
if divider != 1:
raise ValueError("x (%d) and n (%d) are not relatively prime" % (x, n))
return inv
def crt(a_values, modulo_values):
'''Chinese Remainder Theorem.
Calculates x such that x = a[i] (mod m[i]) for each i.
:param a_values: the a-values of the above equation
:param modulo_values: the m-values of the above equation
:returns: x such that x = a[i] (mod m[i]) for each i
>>> crt([2, 3], [3, 5])
8
>>> crt([2, 3, 2], [3, 5, 7])
23
>>> crt([2, 3, 0], [7, 11, 15])
135
'''
m = 1
x = 0
for modulo in modulo_values:
m *= modulo
for (m_i, a_i) in zip(modulo_values, a_values):
M_i = m // m_i
inv = inverse(M_i, m_i)
x = (x + a_i * M_i * inv) % m
return x
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -1,58 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Core mathematical operations.
This is the actual core RSA implementation, which is only defined
mathematically on integers.
'''
from rsa._compat import is_integer
def assert_int(var, name):
if is_integer(var):
return
raise TypeError('%s should be an integer, not %s' % (name, var.__class__))
def encrypt_int(message, ekey, n):
'''Encrypts a message using encryption key 'ekey', working modulo n'''
assert_int(message, 'message')
assert_int(ekey, 'ekey')
assert_int(n, 'n')
if message < 0:
raise ValueError('Only non-negative numbers are supported')
if message > n:
raise OverflowError("The message %i is too long for n=%i" % (message, n))
return pow(message, ekey, n)
def decrypt_int(cyphertext, dkey, n):
'''Decrypts a cypher text using the decryption key 'dkey', working
modulo n'''
assert_int(cyphertext, 'cyphertext')
assert_int(dkey, 'dkey')
assert_int(n, 'n')
message = pow(cyphertext, dkey, n)
return message

View File

@ -1,612 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''RSA key generation code.
Create new keys with the newkeys() function. It will give you a PublicKey and a
PrivateKey object.
Loading and saving keys requires the pyasn1 module. This module is imported as
late as possible, such that other functionality will remain working in absence
of pyasn1.
'''
import logging
from rsa._compat import b, bytes_type
import rsa.prime
import rsa.pem
import rsa.common
log = logging.getLogger(__name__)
class AbstractKey(object):
'''Abstract superclass for private and public keys.'''
@classmethod
def load_pkcs1(cls, keyfile, format='PEM'):
r'''Loads a key in PKCS#1 DER or PEM format.
:param keyfile: contents of a DER- or PEM-encoded file that contains
the public key.
:param format: the format of the file to load; 'PEM' or 'DER'
:return: a PublicKey object
'''
methods = {
'PEM': cls._load_pkcs1_pem,
'DER': cls._load_pkcs1_der,
}
if format not in methods:
formats = ', '.join(sorted(methods.keys()))
raise ValueError('Unsupported format: %r, try one of %s' % (format,
formats))
method = methods[format]
return method(keyfile)
def save_pkcs1(self, format='PEM'):
'''Saves the public key in PKCS#1 DER or PEM format.
:param format: the format to save; 'PEM' or 'DER'
:returns: the DER- or PEM-encoded public key.
'''
methods = {
'PEM': self._save_pkcs1_pem,
'DER': self._save_pkcs1_der,
}
if format not in methods:
formats = ', '.join(sorted(methods.keys()))
raise ValueError('Unsupported format: %r, try one of %s' % (format,
formats))
method = methods[format]
return method()
class PublicKey(AbstractKey):
'''Represents a public RSA key.
This key is also known as the 'encryption key'. It contains the 'n' and 'e'
values.
Supports attributes as well as dictionary-like access. Attribute accesss is
faster, though.
>>> PublicKey(5, 3)
PublicKey(5, 3)
>>> key = PublicKey(5, 3)
>>> key.n
5
>>> key['n']
5
>>> key.e
3
>>> key['e']
3
'''
__slots__ = ('n', 'e')
def __init__(self, n, e):
self.n = n
self.e = e
def __getitem__(self, key):
return getattr(self, key)
def __repr__(self):
return 'PublicKey(%i, %i)' % (self.n, self.e)
def __eq__(self, other):
if other is None:
return False
if not isinstance(other, PublicKey):
return False
return self.n == other.n and self.e == other.e
def __ne__(self, other):
return not (self == other)
@classmethod
def _load_pkcs1_der(cls, keyfile):
r'''Loads a key in PKCS#1 DER format.
@param keyfile: contents of a DER-encoded file that contains the public
key.
@return: a PublicKey object
First let's construct a DER encoded key:
>>> import base64
>>> b64der = 'MAwCBQCNGmYtAgMBAAE='
>>> der = base64.decodestring(b64der)
This loads the file:
>>> PublicKey._load_pkcs1_der(der)
PublicKey(2367317549, 65537)
'''
from pyasn1.codec.der import decoder
from rsa.asn1 import AsnPubKey
(priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey())
return cls(n=int(priv['modulus']), e=int(priv['publicExponent']))
def _save_pkcs1_der(self):
'''Saves the public key in PKCS#1 DER format.
@returns: the DER-encoded public key.
'''
from pyasn1.codec.der import encoder
from rsa.asn1 import AsnPubKey
# Create the ASN object
asn_key = AsnPubKey()
asn_key.setComponentByName('modulus', self.n)
asn_key.setComponentByName('publicExponent', self.e)
return encoder.encode(asn_key)
@classmethod
def _load_pkcs1_pem(cls, keyfile):
'''Loads a PKCS#1 PEM-encoded public key file.
The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and
after the "-----END RSA PUBLIC KEY-----" lines is ignored.
@param keyfile: contents of a PEM-encoded file that contains the public
key.
@return: a PublicKey object
'''
der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY')
return cls._load_pkcs1_der(der)
def _save_pkcs1_pem(self):
'''Saves a PKCS#1 PEM-encoded public key file.
@return: contents of a PEM-encoded file that contains the public key.
'''
der = self._save_pkcs1_der()
return rsa.pem.save_pem(der, 'RSA PUBLIC KEY')
@classmethod
def load_pkcs1_openssl_pem(cls, keyfile):
'''Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL.
These files can be recognised in that they start with BEGIN PUBLIC KEY
rather than BEGIN RSA PUBLIC KEY.
The contents of the file before the "-----BEGIN PUBLIC KEY-----" and
after the "-----END PUBLIC KEY-----" lines is ignored.
@param keyfile: contents of a PEM-encoded file that contains the public
key, from OpenSSL.
@return: a PublicKey object
'''
der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY')
return cls.load_pkcs1_openssl_der(der)
@classmethod
def load_pkcs1_openssl_der(cls, keyfile):
'''Loads a PKCS#1 DER-encoded public key file from OpenSSL.
@param keyfile: contents of a DER-encoded file that contains the public
key, from OpenSSL.
@return: a PublicKey object
'''
from rsa.asn1 import OpenSSLPubKey
from pyasn1.codec.der import decoder
from pyasn1.type import univ
(keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey())
if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'):
raise TypeError("This is not a DER-encoded OpenSSL-compatible public key")
return cls._load_pkcs1_der(keyinfo['key'][1:])
class PrivateKey(AbstractKey):
'''Represents a private RSA key.
This key is also known as the 'decryption key'. It contains the 'n', 'e',
'd', 'p', 'q' and other values.
Supports attributes as well as dictionary-like access. Attribute accesss is
faster, though.
>>> PrivateKey(3247, 65537, 833, 191, 17)
PrivateKey(3247, 65537, 833, 191, 17)
exp1, exp2 and coef don't have to be given, they will be calculated:
>>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
>>> pk.exp1
55063
>>> pk.exp2
10095
>>> pk.coef
50797
If you give exp1, exp2 or coef, they will be used as-is:
>>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8)
>>> pk.exp1
6
>>> pk.exp2
7
>>> pk.coef
8
'''
__slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef')
def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None):
self.n = n
self.e = e
self.d = d
self.p = p
self.q = q
# Calculate the other values if they aren't supplied
if exp1 is None:
self.exp1 = int(d % (p - 1))
else:
self.exp1 = exp1
if exp1 is None:
self.exp2 = int(d % (q - 1))
else:
self.exp2 = exp2
if coef is None:
self.coef = rsa.common.inverse(q, p)
else:
self.coef = coef
def __getitem__(self, key):
return getattr(self, key)
def __repr__(self):
return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self
def __eq__(self, other):
if other is None:
return False
if not isinstance(other, PrivateKey):
return False
return (self.n == other.n and
self.e == other.e and
self.d == other.d and
self.p == other.p and
self.q == other.q and
self.exp1 == other.exp1 and
self.exp2 == other.exp2 and
self.coef == other.coef)
def __ne__(self, other):
return not (self == other)
@classmethod
def _load_pkcs1_der(cls, keyfile):
r'''Loads a key in PKCS#1 DER format.
@param keyfile: contents of a DER-encoded file that contains the private
key.
@return: a PrivateKey object
First let's construct a DER encoded key:
>>> import base64
>>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt'
>>> der = base64.decodestring(b64der)
This loads the file:
>>> PrivateKey._load_pkcs1_der(der)
PrivateKey(3727264081, 65537, 3349121513, 65063, 57287)
'''
from pyasn1.codec.der import decoder
(priv, _) = decoder.decode(keyfile)
# ASN.1 contents of DER encoded private key:
#
# RSAPrivateKey ::= SEQUENCE {
# version Version,
# modulus INTEGER, -- n
# publicExponent INTEGER, -- e
# privateExponent INTEGER, -- d
# prime1 INTEGER, -- p
# prime2 INTEGER, -- q
# exponent1 INTEGER, -- d mod (p-1)
# exponent2 INTEGER, -- d mod (q-1)
# coefficient INTEGER, -- (inverse of q) mod p
# otherPrimeInfos OtherPrimeInfos OPTIONAL
# }
if priv[0] != 0:
raise ValueError('Unable to read this file, version %s != 0' % priv[0])
as_ints = tuple(int(x) for x in priv[1:9])
return cls(*as_ints)
def _save_pkcs1_der(self):
'''Saves the private key in PKCS#1 DER format.
@returns: the DER-encoded private key.
'''
from pyasn1.type import univ, namedtype
from pyasn1.codec.der import encoder
class AsnPrivKey(univ.Sequence):
componentType = namedtype.NamedTypes(
namedtype.NamedType('version', univ.Integer()),
namedtype.NamedType('modulus', univ.Integer()),
namedtype.NamedType('publicExponent', univ.Integer()),
namedtype.NamedType('privateExponent', univ.Integer()),
namedtype.NamedType('prime1', univ.Integer()),
namedtype.NamedType('prime2', univ.Integer()),
namedtype.NamedType('exponent1', univ.Integer()),
namedtype.NamedType('exponent2', univ.Integer()),
namedtype.NamedType('coefficient', univ.Integer()),
)
# Create the ASN object
asn_key = AsnPrivKey()
asn_key.setComponentByName('version', 0)
asn_key.setComponentByName('modulus', self.n)
asn_key.setComponentByName('publicExponent', self.e)
asn_key.setComponentByName('privateExponent', self.d)
asn_key.setComponentByName('prime1', self.p)
asn_key.setComponentByName('prime2', self.q)
asn_key.setComponentByName('exponent1', self.exp1)
asn_key.setComponentByName('exponent2', self.exp2)
asn_key.setComponentByName('coefficient', self.coef)
return encoder.encode(asn_key)
@classmethod
def _load_pkcs1_pem(cls, keyfile):
'''Loads a PKCS#1 PEM-encoded private key file.
The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and
after the "-----END RSA PRIVATE KEY-----" lines is ignored.
@param keyfile: contents of a PEM-encoded file that contains the private
key.
@return: a PrivateKey object
'''
der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY'))
return cls._load_pkcs1_der(der)
def _save_pkcs1_pem(self):
'''Saves a PKCS#1 PEM-encoded private key file.
@return: contents of a PEM-encoded file that contains the private key.
'''
der = self._save_pkcs1_der()
return rsa.pem.save_pem(der, b('RSA PRIVATE KEY'))
def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True):
''''Returns a tuple of two different primes of nbits bits each.
The resulting p * q has exacty 2 * nbits bits, and the returned p and q
will not be equal.
:param nbits: the number of bits in each of p and q.
:param getprime_func: the getprime function, defaults to
:py:func:`rsa.prime.getprime`.
*Introduced in Python-RSA 3.1*
:param accurate: whether to enable accurate mode or not.
:returns: (p, q), where p > q
>>> (p, q) = find_p_q(128)
>>> from rsa import common
>>> common.bit_size(p * q)
256
When not in accurate mode, the number of bits can be slightly less
>>> (p, q) = find_p_q(128, accurate=False)
>>> from rsa import common
>>> common.bit_size(p * q) <= 256
True
>>> common.bit_size(p * q) > 240
True
'''
total_bits = nbits * 2
# Make sure that p and q aren't too close or the factoring programs can
# factor n.
shift = nbits // 16
pbits = nbits + shift
qbits = nbits - shift
# Choose the two initial primes
log.debug('find_p_q(%i): Finding p', nbits)
p = getprime_func(pbits)
log.debug('find_p_q(%i): Finding q', nbits)
q = getprime_func(qbits)
def is_acceptable(p, q):
'''Returns True iff p and q are acceptable:
- p and q differ
- (p * q) has the right nr of bits (when accurate=True)
'''
if p == q:
return False
if not accurate:
return True
# Make sure we have just the right amount of bits
found_size = rsa.common.bit_size(p * q)
return total_bits == found_size
# Keep choosing other primes until they match our requirements.
change_p = False
while not is_acceptable(p, q):
# Change p on one iteration and q on the other
if change_p:
p = getprime_func(pbits)
else:
q = getprime_func(qbits)
change_p = not change_p
# We want p > q as described on
# http://www.di-mgt.com.au/rsa_alg.html#crt
return (max(p, q), min(p, q))
def calculate_keys(p, q, nbits):
'''Calculates an encryption and a decryption key given p and q, and
returns them as a tuple (e, d)
'''
phi_n = (p - 1) * (q - 1)
# A very common choice for e is 65537
e = 65537
try:
d = rsa.common.inverse(e, phi_n)
except ValueError:
raise ValueError("e (%d) and phi_n (%d) are not relatively prime" %
(e, phi_n))
if (e * d) % phi_n != 1:
raise ValueError("e (%d) and d (%d) are not mult. inv. modulo "
"phi_n (%d)" % (e, d, phi_n))
return (e, d)
def gen_keys(nbits, getprime_func, accurate=True):
'''Generate RSA keys of nbits bits. Returns (p, q, e, d).
Note: this can take a long time, depending on the key size.
:param nbits: the total number of bits in ``p`` and ``q``. Both ``p`` and
``q`` will use ``nbits/2`` bits.
:param getprime_func: either :py:func:`rsa.prime.getprime` or a function
with similar signature.
'''
(p, q) = find_p_q(nbits // 2, getprime_func, accurate)
(e, d) = calculate_keys(p, q, nbits // 2)
return (p, q, e, d)
def newkeys(nbits, accurate=True, poolsize=1):
'''Generates public and private keys, and returns them as (pub, priv).
The public key is also known as the 'encryption key', and is a
:py:class:`rsa.PublicKey` object. The private key is also known as the
'decryption key' and is a :py:class:`rsa.PrivateKey` object.
:param nbits: the number of bits required to store ``n = p*q``.
:param accurate: when True, ``n`` will have exactly the number of bits you
asked for. However, this makes key generation much slower. When False,
`n`` may have slightly less bits.
:param poolsize: the number of processes to use to generate the prime
numbers. If set to a number > 1, a parallel algorithm will be used.
This requires Python 2.6 or newer.
:returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`)
The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires
Python 2.6 or newer.
'''
if nbits < 16:
raise ValueError('Key too small')
if poolsize < 1:
raise ValueError('Pool size (%i) should be >= 1' % poolsize)
# Determine which getprime function to use
if poolsize > 1:
from rsa import parallel
import functools
getprime_func = functools.partial(parallel.getprime, poolsize=poolsize)
else: getprime_func = rsa.prime.getprime
# Generate the key components
(p, q, e, d) = gen_keys(nbits, getprime_func)
# Create the key objects
n = p * q
return (
PublicKey(n, e),
PrivateKey(n, e, d, p, q)
)
__all__ = ['PublicKey', 'PrivateKey', 'newkeys']
if __name__ == '__main__':
import doctest
try:
for count in range(100):
(failures, tests) = doctest.testmod()
if failures:
break
if (count and count % 10 == 0) or count == 1:
print('%i times' % count)
except KeyboardInterrupt:
print('Aborted')
else:
print('Doctests done')

View File

@ -1,94 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Functions for parallel computation on multiple cores.
Introduced in Python-RSA 3.1.
.. note::
Requires Python 2.6 or newer.
'''
from __future__ import print_function
import multiprocessing as mp
import rsa.prime
import rsa.randnum
def _find_prime(nbits, pipe):
while True:
integer = rsa.randnum.read_random_int(nbits)
# Make sure it's odd
integer |= 1
# Test for primeness
if rsa.prime.is_prime(integer):
pipe.send(integer)
return
def getprime(nbits, poolsize):
'''Returns a prime number that can be stored in 'nbits' bits.
Works in multiple threads at the same time.
>>> p = getprime(128, 3)
>>> rsa.prime.is_prime(p-1)
False
>>> rsa.prime.is_prime(p)
True
>>> rsa.prime.is_prime(p+1)
False
>>> from rsa import common
>>> common.bit_size(p) == 128
True
'''
(pipe_recv, pipe_send) = mp.Pipe(duplex=False)
# Create processes
procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send))
for _ in range(poolsize)]
[p.start() for p in procs]
result = pipe_recv.recv()
[p.terminate() for p in procs]
return result
__all__ = ['getprime']
if __name__ == '__main__':
print('Running doctests 1000x or until failure')
import doctest
for count in range(100):
(failures, tests) = doctest.testmod()
if failures:
break
if count and count % 10 == 0:
print('%i times' % count)
print('Doctests done')

View File

@ -1,120 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Functions that load and write PEM-encoded files.'''
import base64
from rsa._compat import b, is_bytes
def _markers(pem_marker):
'''
Returns the start and end PEM markers
'''
if is_bytes(pem_marker):
pem_marker = pem_marker.decode('utf-8')
return (b('-----BEGIN %s-----' % pem_marker),
b('-----END %s-----' % pem_marker))
def load_pem(contents, pem_marker):
'''Loads a PEM file.
@param contents: the contents of the file to interpret
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
'-----END RSA PRIVATE KEY-----' markers.
@return the base64-decoded content between the start and end markers.
@raise ValueError: when the content is invalid, for example when the start
marker cannot be found.
'''
(pem_start, pem_end) = _markers(pem_marker)
pem_lines = []
in_pem_part = False
for line in contents.splitlines():
line = line.strip()
# Skip empty lines
if not line:
continue
# Handle start marker
if line == pem_start:
if in_pem_part:
raise ValueError('Seen start marker "%s" twice' % pem_start)
in_pem_part = True
continue
# Skip stuff before first marker
if not in_pem_part:
continue
# Handle end marker
if in_pem_part and line == pem_end:
in_pem_part = False
break
# Load fields
if b(':') in line:
continue
pem_lines.append(line)
# Do some sanity checks
if not pem_lines:
raise ValueError('No PEM start marker "%s" found' % pem_start)
if in_pem_part:
raise ValueError('No PEM end marker "%s" found' % pem_end)
# Base64-decode the contents
pem = b('').join(pem_lines)
return base64.decodestring(pem)
def save_pem(contents, pem_marker):
'''Saves a PEM file.
@param contents: the contents to encode in PEM format
@param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY'
when your file has '-----BEGIN RSA PRIVATE KEY-----' and
'-----END RSA PRIVATE KEY-----' markers.
@return the base64-encoded content between the start and end markers.
'''
(pem_start, pem_end) = _markers(pem_marker)
b64 = base64.encodestring(contents).replace(b('\n'), b(''))
pem_lines = [pem_start]
for block_start in range(0, len(b64), 64):
block = b64[block_start:block_start + 64]
pem_lines.append(block)
pem_lines.append(pem_end)
pem_lines.append(b(''))
return b('\n').join(pem_lines)

View File

@ -1,391 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Functions for PKCS#1 version 1.5 encryption and signing
This module implements certain functionality from PKCS#1 version 1.5. For a
very clear example, read http://www.di-mgt.com.au/rsa_alg.html#pkcs1schemes
At least 8 bytes of random padding is used when encrypting a message. This makes
these methods much more secure than the ones in the ``rsa`` module.
WARNING: this module leaks information when decryption or verification fails.
The exceptions that are raised contain the Python traceback information, which
can be used to deduce where in the process the failure occurred. DO NOT PASS
SUCH INFORMATION to your users.
'''
import hashlib
import os
from rsa._compat import b
from rsa import common, transform, core, varblock
# ASN.1 codes that describe the hash algorithm used.
HASH_ASN1 = {
'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'),
'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'),
'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'),
'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'),
'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'),
}
HASH_METHODS = {
'MD5': hashlib.md5,
'SHA-1': hashlib.sha1,
'SHA-256': hashlib.sha256,
'SHA-384': hashlib.sha384,
'SHA-512': hashlib.sha512,
}
class CryptoError(Exception):
'''Base class for all exceptions in this module.'''
class DecryptionError(CryptoError):
'''Raised when decryption fails.'''
class VerificationError(CryptoError):
'''Raised when verification fails.'''
def _pad_for_encryption(message, target_length):
r'''Pads the message for encryption, returning the padded message.
:return: 00 02 RANDOM_DATA 00 MESSAGE
>>> block = _pad_for_encryption('hello', 16)
>>> len(block)
16
>>> block[0:2]
'\x00\x02'
>>> block[-6:]
'\x00hello'
'''
max_msglength = target_length - 11
msglength = len(message)
if msglength > max_msglength:
raise OverflowError('%i bytes needed for message, but there is only'
' space for %i' % (msglength, max_msglength))
# Get random padding
padding = b('')
padding_length = target_length - msglength - 3
# We remove 0-bytes, so we'll end up with less padding than we've asked for,
# so keep adding data until we're at the correct length.
while len(padding) < padding_length:
needed_bytes = padding_length - len(padding)
# Always read at least 8 bytes more than we need, and trim off the rest
# after removing the 0-bytes. This increases the chance of getting
# enough bytes, especially when needed_bytes is small
new_padding = os.urandom(needed_bytes + 5)
new_padding = new_padding.replace(b('\x00'), b(''))
padding = padding + new_padding[:needed_bytes]
assert len(padding) == padding_length
return b('').join([b('\x00\x02'),
padding,
b('\x00'),
message])
def _pad_for_signing(message, target_length):
r'''Pads the message for signing, returning the padded message.
The padding is always a repetition of FF bytes.
:return: 00 01 PADDING 00 MESSAGE
>>> block = _pad_for_signing('hello', 16)
>>> len(block)
16
>>> block[0:2]
'\x00\x01'
>>> block[-6:]
'\x00hello'
>>> block[2:-6]
'\xff\xff\xff\xff\xff\xff\xff\xff'
'''
max_msglength = target_length - 11
msglength = len(message)
if msglength > max_msglength:
raise OverflowError('%i bytes needed for message, but there is only'
' space for %i' % (msglength, max_msglength))
padding_length = target_length - msglength - 3
return b('').join([b('\x00\x01'),
padding_length * b('\xff'),
b('\x00'),
message])
def encrypt(message, pub_key):
'''Encrypts the given message using PKCS#1 v1.5
:param message: the message to encrypt. Must be a byte string no longer than
``k-11`` bytes, where ``k`` is the number of bytes needed to encode
the ``n`` component of the public key.
:param pub_key: the :py:class:`rsa.PublicKey` to encrypt with.
:raise OverflowError: when the message is too large to fit in the padded
block.
>>> from rsa import key, common
>>> (pub_key, priv_key) = key.newkeys(256)
>>> message = 'hello'
>>> crypto = encrypt(message, pub_key)
The crypto text should be just as long as the public key 'n' component:
>>> len(crypto) == common.byte_size(pub_key.n)
True
'''
keylength = common.byte_size(pub_key.n)
padded = _pad_for_encryption(message, keylength)
payload = transform.bytes2int(padded)
encrypted = core.encrypt_int(payload, pub_key.e, pub_key.n)
block = transform.int2bytes(encrypted, keylength)
return block
def decrypt(crypto, priv_key):
r'''Decrypts the given message using PKCS#1 v1.5
The decryption is considered 'failed' when the resulting cleartext doesn't
start with the bytes 00 02, or when the 00 byte between the padding and
the message cannot be found.
:param crypto: the crypto text as returned by :py:func:`rsa.encrypt`
:param priv_key: the :py:class:`rsa.PrivateKey` to decrypt with.
:raise DecryptionError: when the decryption fails. No details are given as
to why the code thinks the decryption fails, as this would leak
information about the private key.
>>> import rsa
>>> (pub_key, priv_key) = rsa.newkeys(256)
It works with strings:
>>> crypto = encrypt('hello', pub_key)
>>> decrypt(crypto, priv_key)
'hello'
And with binary data:
>>> crypto = encrypt('\x00\x00\x00\x00\x01', pub_key)
>>> decrypt(crypto, priv_key)
'\x00\x00\x00\x00\x01'
Altering the encrypted information will *likely* cause a
:py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use
:py:func:`rsa.sign`.
.. warning::
Never display the stack trace of a
:py:class:`rsa.pkcs1.DecryptionError` exception. It shows where in the
code the exception occurred, and thus leaks information about the key.
It's only a tiny bit of information, but every bit makes cracking the
keys easier.
>>> crypto = encrypt('hello', pub_key)
>>> crypto = crypto[0:5] + 'X' + crypto[6:] # change a byte
>>> decrypt(crypto, priv_key)
Traceback (most recent call last):
...
DecryptionError: Decryption failed
'''
blocksize = common.byte_size(priv_key.n)
encrypted = transform.bytes2int(crypto)
decrypted = core.decrypt_int(encrypted, priv_key.d, priv_key.n)
cleartext = transform.int2bytes(decrypted, blocksize)
# If we can't find the cleartext marker, decryption failed.
if cleartext[0:2] != b('\x00\x02'):
raise DecryptionError('Decryption failed')
# Find the 00 separator between the padding and the message
try:
sep_idx = cleartext.index(b('\x00'), 2)
except ValueError:
raise DecryptionError('Decryption failed')
return cleartext[sep_idx+1:]
def sign(message, priv_key, hash):
'''Signs the message with the private key.
Hashes the message, then signs the hash with the given key. This is known
as a "detached signature", because the message itself isn't altered.
:param message: the message to sign. Can be an 8-bit string or a file-like
object. If ``message`` has a ``read()`` method, it is assumed to be a
file-like object.
:param priv_key: the :py:class:`rsa.PrivateKey` to sign with
:param hash: the hash method used on the message. Use 'MD5', 'SHA-1',
'SHA-256', 'SHA-384' or 'SHA-512'.
:return: a message signature block.
:raise OverflowError: if the private key is too small to contain the
requested hash.
'''
# Get the ASN1 code for this hash method
if hash not in HASH_ASN1:
raise ValueError('Invalid hash method: %s' % hash)
asn1code = HASH_ASN1[hash]
# Calculate the hash
hash = _hash(message, hash)
# Encrypt the hash with the private key
cleartext = asn1code + hash
keylength = common.byte_size(priv_key.n)
padded = _pad_for_signing(cleartext, keylength)
payload = transform.bytes2int(padded)
encrypted = core.encrypt_int(payload, priv_key.d, priv_key.n)
block = transform.int2bytes(encrypted, keylength)
return block
def verify(message, signature, pub_key):
'''Verifies that the signature matches the message.
The hash method is detected automatically from the signature.
:param message: the signed message. Can be an 8-bit string or a file-like
object. If ``message`` has a ``read()`` method, it is assumed to be a
file-like object.
:param signature: the signature block, as created with :py:func:`rsa.sign`.
:param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message.
:raise VerificationError: when the signature doesn't match the message.
.. warning::
Never display the stack trace of a
:py:class:`rsa.pkcs1.VerificationError` exception. It shows where in
the code the exception occurred, and thus leaks information about the
key. It's only a tiny bit of information, but every bit makes cracking
the keys easier.
'''
blocksize = common.byte_size(pub_key.n)
encrypted = transform.bytes2int(signature)
decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n)
clearsig = transform.int2bytes(decrypted, blocksize)
# If we can't find the signature marker, verification failed.
if clearsig[0:2] != b('\x00\x01'):
raise VerificationError('Verification failed')
# Find the 00 separator between the padding and the payload
try:
sep_idx = clearsig.index(b('\x00'), 2)
except ValueError:
raise VerificationError('Verification failed')
# Get the hash and the hash method
(method_name, signature_hash) = _find_method_hash(clearsig[sep_idx+1:])
message_hash = _hash(message, method_name)
# Compare the real hash to the hash in the signature
if message_hash != signature_hash:
raise VerificationError('Verification failed')
return True
def _hash(message, method_name):
'''Returns the message digest.
:param message: the signed message. Can be an 8-bit string or a file-like
object. If ``message`` has a ``read()`` method, it is assumed to be a
file-like object.
:param method_name: the hash method, must be a key of
:py:const:`HASH_METHODS`.
'''
if method_name not in HASH_METHODS:
raise ValueError('Invalid hash method: %s' % method_name)
method = HASH_METHODS[method_name]
hasher = method()
if hasattr(message, 'read') and hasattr(message.read, '__call__'):
# read as 1K blocks
for block in varblock.yield_fixedblocks(message, 1024):
hasher.update(block)
else:
# hash the message object itself.
hasher.update(message)
return hasher.digest()
def _find_method_hash(method_hash):
'''Finds the hash method and the hash itself.
:param method_hash: ASN1 code for the hash method concatenated with the
hash itself.
:return: tuple (method, hash) where ``method`` is the used hash method, and
``hash`` is the hash itself.
:raise VerificationFailed: when the hash method cannot be found
'''
for (hashname, asn1code) in HASH_ASN1.items():
if not method_hash.startswith(asn1code):
continue
return (hashname, method_hash[len(asn1code):])
raise VerificationError('Verification failed')
__all__ = ['encrypt', 'decrypt', 'sign', 'verify',
'DecryptionError', 'VerificationError', 'CryptoError']
if __name__ == '__main__':
print('Running doctests 1000x or until failure')
import doctest
for count in range(1000):
(failures, tests) = doctest.testmod()
if failures:
break
if count and count % 100 == 0:
print('%i times' % count)
print('Doctests done')

View File

@ -1,166 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Numerical functions related to primes.
Implementation based on the book Algorithm Design by Michael T. Goodrich and
Roberto Tamassia, 2002.
'''
__all__ = [ 'getprime', 'are_relatively_prime']
import rsa.randnum
def gcd(p, q):
'''Returns the greatest common divisor of p and q
>>> gcd(48, 180)
12
'''
while q != 0:
if p < q: (p,q) = (q,p)
(p,q) = (q, p % q)
return p
def jacobi(a, b):
'''Calculates the value of the Jacobi symbol (a/b) where both a and b are
positive integers, and b is odd
:returns: -1, 0 or 1
'''
assert a > 0
assert b > 0
if a == 0: return 0
result = 1
while a > 1:
if a & 1:
if ((a-1)*(b-1) >> 2) & 1:
result = -result
a, b = b % a, a
else:
if (((b * b) - 1) >> 3) & 1:
result = -result
a >>= 1
if a == 0: return 0
return result
def jacobi_witness(x, n):
'''Returns False if n is an Euler pseudo-prime with base x, and
True otherwise.
'''
j = jacobi(x, n) % n
f = pow(x, n >> 1, n)
if j == f: return False
return True
def randomized_primality_testing(n, k):
'''Calculates whether n is composite (which is always correct) or
prime (which is incorrect with error probability 2**-k)
Returns False if the number is composite, and True if it's
probably prime.
'''
# 50% of Jacobi-witnesses can report compositness of non-prime numbers
# The implemented algorithm using the Jacobi witness function has error
# probability q <= 0.5, according to Goodrich et. al
#
# q = 0.5
# t = int(math.ceil(k / log(1 / q, 2)))
# So t = k / log(2, 2) = k / 1 = k
# this means we can use range(k) rather than range(t)
for _ in range(k):
x = rsa.randnum.randint(n-1)
if jacobi_witness(x, n): return False
return True
def is_prime(number):
'''Returns True if the number is prime, and False otherwise.
>>> is_prime(42)
False
>>> is_prime(41)
True
'''
return randomized_primality_testing(number, 6)
def getprime(nbits):
'''Returns a prime number that can be stored in 'nbits' bits.
>>> p = getprime(128)
>>> is_prime(p-1)
False
>>> is_prime(p)
True
>>> is_prime(p+1)
False
>>> from rsa import common
>>> common.bit_size(p) == 128
True
'''
while True:
integer = rsa.randnum.read_random_int(nbits)
# Make sure it's odd
integer |= 1
# Test for primeness
if is_prime(integer):
return integer
# Retry if not prime
def are_relatively_prime(a, b):
'''Returns True if a and b are relatively prime, and False if they
are not.
>>> are_relatively_prime(2, 3)
1
>>> are_relatively_prime(2, 4)
0
'''
d = gcd(a, b)
return (d == 1)
if __name__ == '__main__':
print('Running doctests 1000x or until failure')
import doctest
for count in range(1000):
(failures, tests) = doctest.testmod()
if failures:
break
if count and count % 100 == 0:
print('%i times' % count)
print('Doctests done')

View File

@ -1,85 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Functions for generating random numbers.'''
# Source inspired by code by Yesudeep Mangalapilly <yesudeep@gmail.com>
import os
from rsa import common, transform
from rsa._compat import byte
def read_random_bits(nbits):
'''Reads 'nbits' random bits.
If nbits isn't a whole number of bytes, an extra byte will be appended with
only the lower bits set.
'''
nbytes, rbits = divmod(nbits, 8)
# Get the random bytes
randomdata = os.urandom(nbytes)
# Add the remaining random bits
if rbits > 0:
randomvalue = ord(os.urandom(1))
randomvalue >>= (8 - rbits)
randomdata = byte(randomvalue) + randomdata
return randomdata
def read_random_int(nbits):
'''Reads a random integer of approximately nbits bits.
'''
randomdata = read_random_bits(nbits)
value = transform.bytes2int(randomdata)
# Ensure that the number is large enough to just fill out the required
# number of bits.
value |= 1 << (nbits - 1)
return value
def randint(maxvalue):
'''Returns a random integer x with 1 <= x <= maxvalue
May take a very long time in specific situations. If maxvalue needs N bits
to store, the closer maxvalue is to (2 ** N) - 1, the faster this function
is.
'''
bit_size = common.bit_size(maxvalue)
tries = 0
while True:
value = read_random_int(bit_size)
if value <= maxvalue:
break
if tries and tries % 10 == 0:
# After a lot of tries to get the right number of bits but still
# smaller than maxvalue, decrease the number of bits by 1. That'll
# dramatically increase the chances to get a large enough number.
bit_size -= 1
tries += 1
return value

View File

@ -1,220 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Data transformation functions.
From bytes to a number, number to bytes, etc.
'''
from __future__ import absolute_import
try:
# We'll use psyco if available on 32-bit architectures to speed up code.
# Using psyco (if available) cuts down the execution time on Python 2.5
# at least by half.
import psyco
psyco.full()
except ImportError:
pass
import binascii
from struct import pack
from rsa import common
from rsa._compat import is_integer, b, byte, get_word_alignment, ZERO_BYTE, EMPTY_BYTE
def bytes2int(raw_bytes):
r'''Converts a list of bytes or an 8-bit string to an integer.
When using unicode strings, encode it to some encoding like UTF8 first.
>>> (((128 * 256) + 64) * 256) + 15
8405007
>>> bytes2int('\x80@\x0f')
8405007
'''
return int(binascii.hexlify(raw_bytes), 16)
def _int2bytes(number, block_size=None):
r'''Converts a number to a string of bytes.
Usage::
>>> _int2bytes(123456789)
'\x07[\xcd\x15'
>>> bytes2int(_int2bytes(123456789))
123456789
>>> _int2bytes(123456789, 6)
'\x00\x00\x07[\xcd\x15'
>>> bytes2int(_int2bytes(123456789, 128))
123456789
>>> _int2bytes(123456789, 3)
Traceback (most recent call last):
...
OverflowError: Needed 4 bytes for number, but block size is 3
@param number: the number to convert
@param block_size: the number of bytes to output. If the number encoded to
bytes is less than this, the block will be zero-padded. When not given,
the returned block is not padded.
@throws OverflowError when block_size is given and the number takes up more
bytes than fit into the block.
'''
# Type checking
if not is_integer(number):
raise TypeError("You must pass an integer for 'number', not %s" %
number.__class__)
if number < 0:
raise ValueError('Negative numbers cannot be used: %i' % number)
# Do some bounds checking
if number == 0:
needed_bytes = 1
raw_bytes = [ZERO_BYTE]
else:
needed_bytes = common.byte_size(number)
raw_bytes = []
# You cannot compare None > 0 in Python 3x. It will fail with a TypeError.
if block_size and block_size > 0:
if needed_bytes > block_size:
raise OverflowError('Needed %i bytes for number, but block size '
'is %i' % (needed_bytes, block_size))
# Convert the number to bytes.
while number > 0:
raw_bytes.insert(0, byte(number & 0xFF))
number >>= 8
# Pad with zeroes to fill the block
if block_size and block_size > 0:
padding = (block_size - needed_bytes) * ZERO_BYTE
else:
padding = EMPTY_BYTE
return padding + EMPTY_BYTE.join(raw_bytes)
def bytes_leading(raw_bytes, needle=ZERO_BYTE):
'''
Finds the number of prefixed byte occurrences in the haystack.
Useful when you want to deal with padding.
:param raw_bytes:
Raw bytes.
:param needle:
The byte to count. Default \000.
:returns:
The number of leading needle bytes.
'''
leading = 0
# Indexing keeps compatibility between Python 2.x and Python 3.x
_byte = needle[0]
for x in raw_bytes:
if x == _byte:
leading += 1
else:
break
return leading
def int2bytes(number, fill_size=None, chunk_size=None, overflow=False):
'''
Convert an unsigned integer to bytes (base-256 representation)::
Does not preserve leading zeros if you don't specify a chunk size or
fill size.
.. NOTE:
You must not specify both fill_size and chunk_size. Only one
of them is allowed.
:param number:
Integer value
:param fill_size:
If the optional fill size is given the length of the resulting
byte string is expected to be the fill size and will be padded
with prefix zero bytes to satisfy that length.
:param chunk_size:
If optional chunk size is given and greater than zero, pad the front of
the byte string with binary zeros so that the length is a multiple of
``chunk_size``.
:param overflow:
``False`` (default). If this is ``True``, no ``OverflowError``
will be raised when the fill_size is shorter than the length
of the generated byte sequence. Instead the byte sequence will
be returned as is.
:returns:
Raw bytes (base-256 representation).
:raises:
``OverflowError`` when fill_size is given and the number takes up more
bytes than fit into the block. This requires the ``overflow``
argument to this function to be set to ``False`` otherwise, no
error will be raised.
'''
if number < 0:
raise ValueError("Number must be an unsigned integer: %d" % number)
if fill_size and chunk_size:
raise ValueError("You can either fill or pad chunks, but not both")
# Ensure these are integers.
number & 1
raw_bytes = b('')
# Pack the integer one machine word at a time into bytes.
num = number
word_bits, _, max_uint, pack_type = get_word_alignment(num)
pack_format = ">%s" % pack_type
while num > 0:
raw_bytes = pack(pack_format, num & max_uint) + raw_bytes
num >>= word_bits
# Obtain the index of the first non-zero byte.
zero_leading = bytes_leading(raw_bytes)
if number == 0:
raw_bytes = ZERO_BYTE
# De-padding.
raw_bytes = raw_bytes[zero_leading:]
length = len(raw_bytes)
if fill_size and fill_size > 0:
if not overflow and length > fill_size:
raise OverflowError(
"Need %d bytes for number, but fill size is %d" %
(length, fill_size)
)
raw_bytes = raw_bytes.rjust(fill_size, ZERO_BYTE)
elif chunk_size and chunk_size > 0:
remainder = length % chunk_size
if remainder:
padding_size = chunk_size - remainder
raw_bytes = raw_bytes.rjust(length + padding_size, ZERO_BYTE)
return raw_bytes
if __name__ == '__main__':
import doctest
doctest.testmod()

View File

@ -1,81 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''Utility functions.'''
from __future__ import with_statement, print_function
import sys
from optparse import OptionParser
import rsa.key
def private_to_public():
'''Reads a private key and outputs the corresponding public key.'''
# Parse the CLI options
parser = OptionParser(usage='usage: %prog [options]',
description='Reads a private key and outputs the '
'corresponding public key. Both private and public keys use '
'the format described in PKCS#1 v1.5')
parser.add_option('-i', '--input', dest='infilename', type='string',
help='Input filename. Reads from stdin if not specified')
parser.add_option('-o', '--output', dest='outfilename', type='string',
help='Output filename. Writes to stdout of not specified')
parser.add_option('--inform', dest='inform',
help='key format of input - default PEM',
choices=('PEM', 'DER'), default='PEM')
parser.add_option('--outform', dest='outform',
help='key format of output - default PEM',
choices=('PEM', 'DER'), default='PEM')
(cli, cli_args) = parser.parse_args(sys.argv)
# Read the input data
if cli.infilename:
print('Reading private key from %s in %s format' % \
(cli.infilename, cli.inform), file=sys.stderr)
with open(cli.infilename, 'rb') as infile:
in_data = infile.read()
else:
print('Reading private key from stdin in %s format' % cli.inform,
file=sys.stderr)
in_data = sys.stdin.read().encode('ascii')
assert type(in_data) == bytes, type(in_data)
# Take the public fields and create a public key
priv_key = rsa.key.PrivateKey.load_pkcs1(in_data, cli.inform)
pub_key = rsa.key.PublicKey(priv_key.n, priv_key.e)
# Save to the output file
out_data = pub_key.save_pkcs1(cli.outform)
if cli.outfilename:
print('Writing public key to %s in %s format' % \
(cli.outfilename, cli.outform), file=sys.stderr)
with open(cli.outfilename, 'wb') as outfile:
outfile.write(out_data)
else:
print('Writing public key to stdout in %s format' % cli.outform,
file=sys.stderr)
sys.stdout.write(out_data.decode('ascii'))

View File

@ -1,155 +0,0 @@
# -*- coding: utf-8 -*-
#
# Copyright 2011 Sybren A. Stüvel <sybren@stuvel.eu>
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
'''VARBLOCK file support
The VARBLOCK file format is as follows, where || denotes byte concatenation:
FILE := VERSION || BLOCK || BLOCK ...
BLOCK := LENGTH || DATA
LENGTH := varint-encoded length of the subsequent data. Varint comes from
Google Protobuf, and encodes an integer into a variable number of bytes.
Each byte uses the 7 lowest bits to encode the value. The highest bit set
to 1 indicates the next byte is also part of the varint. The last byte will
have this bit set to 0.
This file format is called the VARBLOCK format, in line with the varint format
used to denote the block sizes.
'''
from rsa._compat import byte, b
ZERO_BYTE = b('\x00')
VARBLOCK_VERSION = 1
def read_varint(infile):
'''Reads a varint from the file.
When the first byte to be read indicates EOF, (0, 0) is returned. When an
EOF occurs when at least one byte has been read, an EOFError exception is
raised.
@param infile: the file-like object to read from. It should have a read()
method.
@returns (varint, length), the read varint and the number of read bytes.
'''
varint = 0
read_bytes = 0
while True:
char = infile.read(1)
if len(char) == 0:
if read_bytes == 0:
return (0, 0)
raise EOFError('EOF while reading varint, value is %i so far' %
varint)
byte = ord(char)
varint += (byte & 0x7F) << (7 * read_bytes)
read_bytes += 1
if not byte & 0x80:
return (varint, read_bytes)
def write_varint(outfile, value):
'''Writes a varint to a file.
@param outfile: the file-like object to write to. It should have a write()
method.
@returns the number of written bytes.
'''
# there is a big difference between 'write the value 0' (this case) and
# 'there is nothing left to write' (the false-case of the while loop)
if value == 0:
outfile.write(ZERO_BYTE)
return 1
written_bytes = 0
while value > 0:
to_write = value & 0x7f
value = value >> 7
if value > 0:
to_write |= 0x80
outfile.write(byte(to_write))
written_bytes += 1
return written_bytes
def yield_varblocks(infile):
'''Generator, yields each block in the input file.
@param infile: file to read, is expected to have the VARBLOCK format as
described in the module's docstring.
@yields the contents of each block.
'''
# Check the version number
first_char = infile.read(1)
if len(first_char) == 0:
raise EOFError('Unable to read VARBLOCK version number')
version = ord(first_char)
if version != VARBLOCK_VERSION:
raise ValueError('VARBLOCK version %i not supported' % version)
while True:
(block_size, read_bytes) = read_varint(infile)
# EOF at block boundary, that's fine.
if read_bytes == 0 and block_size == 0:
break
block = infile.read(block_size)
read_size = len(block)
if read_size != block_size:
raise EOFError('Block size is %i, but could read only %i bytes' %
(block_size, read_size))
yield block
def yield_fixedblocks(infile, blocksize):
'''Generator, yields each block of ``blocksize`` bytes in the input file.
:param infile: file to read and separate in blocks.
:returns: a generator that yields the contents of each block
'''
while True:
block = infile.read(blocksize)
read_bytes = len(block)
if read_bytes == 0:
break
yield block
if read_bytes < blocksize:
break

View File

@ -16,7 +16,7 @@
"""The Tornado web server and tools."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
# version is a human-readable version number.
@ -25,5 +25,5 @@ from __future__ import absolute_import, division, print_function, with_statement
# is zero for an official release, positive for a development branch,
# or negative for a release candidate or beta (after the base version
# number has been incremented)
version = "4.3"
version_info = (4, 3, 0, 0)
version = "4.5.1"
version_info = (4, 5, 1, 0)

View File

@ -17,78 +17,69 @@
"""Data used by the tornado.locale module."""
from __future__ import absolute_import, division, print_function, with_statement
# NOTE: This file is supposed to contain unicode strings, which is
# exactly what you'd get with e.g. u"Español" in most python versions.
# However, Python 3.2 doesn't support the u"" syntax, so we use a u()
# function instead. tornado.util.u cannot be used because it doesn't
# support non-ascii characters on python 2.
# When we drop support for Python 3.2, we can remove the parens
# and make these plain unicode strings.
from tornado.escape import to_unicode as u
from __future__ import absolute_import, division, print_function
LOCALE_NAMES = {
"af_ZA": {"name_en": u("Afrikaans"), "name": u("Afrikaans")},
"am_ET": {"name_en": u("Amharic"), "name": u("አማርኛ")},
"ar_AR": {"name_en": u("Arabic"), "name": u("العربية")},
"bg_BG": {"name_en": u("Bulgarian"), "name": u("Български")},
"bn_IN": {"name_en": u("Bengali"), "name": u("বাংলা")},
"bs_BA": {"name_en": u("Bosnian"), "name": u("Bosanski")},
"ca_ES": {"name_en": u("Catalan"), "name": u("Català")},
"cs_CZ": {"name_en": u("Czech"), "name": u("Čeština")},
"cy_GB": {"name_en": u("Welsh"), "name": u("Cymraeg")},
"da_DK": {"name_en": u("Danish"), "name": u("Dansk")},
"de_DE": {"name_en": u("German"), "name": u("Deutsch")},
"el_GR": {"name_en": u("Greek"), "name": u("Ελληνικά")},
"en_GB": {"name_en": u("English (UK)"), "name": u("English (UK)")},
"en_US": {"name_en": u("English (US)"), "name": u("English (US)")},
"es_ES": {"name_en": u("Spanish (Spain)"), "name": u("Español (España)")},
"es_LA": {"name_en": u("Spanish"), "name": u("Español")},
"et_EE": {"name_en": u("Estonian"), "name": u("Eesti")},
"eu_ES": {"name_en": u("Basque"), "name": u("Euskara")},
"fa_IR": {"name_en": u("Persian"), "name": u("فارسی")},
"fi_FI": {"name_en": u("Finnish"), "name": u("Suomi")},
"fr_CA": {"name_en": u("French (Canada)"), "name": u("Français (Canada)")},
"fr_FR": {"name_en": u("French"), "name": u("Français")},
"ga_IE": {"name_en": u("Irish"), "name": u("Gaeilge")},
"gl_ES": {"name_en": u("Galician"), "name": u("Galego")},
"he_IL": {"name_en": u("Hebrew"), "name": u("עברית")},
"hi_IN": {"name_en": u("Hindi"), "name": u("हिन्दी")},
"hr_HR": {"name_en": u("Croatian"), "name": u("Hrvatski")},
"hu_HU": {"name_en": u("Hungarian"), "name": u("Magyar")},
"id_ID": {"name_en": u("Indonesian"), "name": u("Bahasa Indonesia")},
"is_IS": {"name_en": u("Icelandic"), "name": u("Íslenska")},
"it_IT": {"name_en": u("Italian"), "name": u("Italiano")},
"ja_JP": {"name_en": u("Japanese"), "name": u("日本語")},
"ko_KR": {"name_en": u("Korean"), "name": u("한국어")},
"lt_LT": {"name_en": u("Lithuanian"), "name": u("Lietuvių")},
"lv_LV": {"name_en": u("Latvian"), "name": u("Latviešu")},
"mk_MK": {"name_en": u("Macedonian"), "name": u("Македонски")},
"ml_IN": {"name_en": u("Malayalam"), "name": u("മലയാളം")},
"ms_MY": {"name_en": u("Malay"), "name": u("Bahasa Melayu")},
"nb_NO": {"name_en": u("Norwegian (bokmal)"), "name": u("Norsk (bokmål)")},
"nl_NL": {"name_en": u("Dutch"), "name": u("Nederlands")},
"nn_NO": {"name_en": u("Norwegian (nynorsk)"), "name": u("Norsk (nynorsk)")},
"pa_IN": {"name_en": u("Punjabi"), "name": u("ਪੰਜਾਬੀ")},
"pl_PL": {"name_en": u("Polish"), "name": u("Polski")},
"pt_BR": {"name_en": u("Portuguese (Brazil)"), "name": u("Português (Brasil)")},
"pt_PT": {"name_en": u("Portuguese (Portugal)"), "name": u("Português (Portugal)")},
"ro_RO": {"name_en": u("Romanian"), "name": u("Română")},
"ru_RU": {"name_en": u("Russian"), "name": u("Русский")},
"sk_SK": {"name_en": u("Slovak"), "name": u("Slovenčina")},
"sl_SI": {"name_en": u("Slovenian"), "name": u("Slovenščina")},
"sq_AL": {"name_en": u("Albanian"), "name": u("Shqip")},
"sr_RS": {"name_en": u("Serbian"), "name": u("Српски")},
"sv_SE": {"name_en": u("Swedish"), "name": u("Svenska")},
"sw_KE": {"name_en": u("Swahili"), "name": u("Kiswahili")},
"ta_IN": {"name_en": u("Tamil"), "name": u("தமிழ்")},
"te_IN": {"name_en": u("Telugu"), "name": u("తెలుగు")},
"th_TH": {"name_en": u("Thai"), "name": u("ภาษาไทย")},
"tl_PH": {"name_en": u("Filipino"), "name": u("Filipino")},
"tr_TR": {"name_en": u("Turkish"), "name": u("Türkçe")},
"uk_UA": {"name_en": u("Ukraini "), "name": u("Українська")},
"vi_VN": {"name_en": u("Vietnamese"), "name": u("Tiếng Việt")},
"zh_CN": {"name_en": u("Chinese (Simplified)"), "name": u("中文(简体)")},
"zh_TW": {"name_en": u("Chinese (Traditional)"), "name": u("中文(繁體)")},
"af_ZA": {"name_en": u"Afrikaans", "name": u"Afrikaans"},
"am_ET": {"name_en": u"Amharic", "name": u"አማርኛ"},
"ar_AR": {"name_en": u"Arabic", "name": u"العربية"},
"bg_BG": {"name_en": u"Bulgarian", "name": u"Български"},
"bn_IN": {"name_en": u"Bengali", "name": u"বাংলা"},
"bs_BA": {"name_en": u"Bosnian", "name": u"Bosanski"},
"ca_ES": {"name_en": u"Catalan", "name": u"Català"},
"cs_CZ": {"name_en": u"Czech", "name": u"Čeština"},
"cy_GB": {"name_en": u"Welsh", "name": u"Cymraeg"},
"da_DK": {"name_en": u"Danish", "name": u"Dansk"},
"de_DE": {"name_en": u"German", "name": u"Deutsch"},
"el_GR": {"name_en": u"Greek", "name": u"Ελληνικά"},
"en_GB": {"name_en": u"English (UK)", "name": u"English (UK)"},
"en_US": {"name_en": u"English (US)", "name": u"English (US)"},
"es_ES": {"name_en": u"Spanish (Spain)", "name": u"Español (España)"},
"es_LA": {"name_en": u"Spanish", "name": u"Español"},
"et_EE": {"name_en": u"Estonian", "name": u"Eesti"},
"eu_ES": {"name_en": u"Basque", "name": u"Euskara"},
"fa_IR": {"name_en": u"Persian", "name": u"فارسی"},
"fi_FI": {"name_en": u"Finnish", "name": u"Suomi"},
"fr_CA": {"name_en": u"French (Canada)", "name": u"Français (Canada)"},
"fr_FR": {"name_en": u"French", "name": u"Français"},
"ga_IE": {"name_en": u"Irish", "name": u"Gaeilge"},
"gl_ES": {"name_en": u"Galician", "name": u"Galego"},
"he_IL": {"name_en": u"Hebrew", "name": u"עברית"},
"hi_IN": {"name_en": u"Hindi", "name": u"हिन्दी"},
"hr_HR": {"name_en": u"Croatian", "name": u"Hrvatski"},
"hu_HU": {"name_en": u"Hungarian", "name": u"Magyar"},
"id_ID": {"name_en": u"Indonesian", "name": u"Bahasa Indonesia"},
"is_IS": {"name_en": u"Icelandic", "name": u"Íslenska"},
"it_IT": {"name_en": u"Italian", "name": u"Italiano"},
"ja_JP": {"name_en": u"Japanese", "name": u"日本語"},
"ko_KR": {"name_en": u"Korean", "name": u"한국어"},
"lt_LT": {"name_en": u"Lithuanian", "name": u"Lietuvių"},
"lv_LV": {"name_en": u"Latvian", "name": u"Latviešu"},
"mk_MK": {"name_en": u"Macedonian", "name": u"Македонски"},
"ml_IN": {"name_en": u"Malayalam", "name": u"മലയാളം"},
"ms_MY": {"name_en": u"Malay", "name": u"Bahasa Melayu"},
"nb_NO": {"name_en": u"Norwegian (bokmal)", "name": u"Norsk (bokmål)"},
"nl_NL": {"name_en": u"Dutch", "name": u"Nederlands"},
"nn_NO": {"name_en": u"Norwegian (nynorsk)", "name": u"Norsk (nynorsk)"},
"pa_IN": {"name_en": u"Punjabi", "name": u"ਪੰਜਾਬੀ"},
"pl_PL": {"name_en": u"Polish", "name": u"Polski"},
"pt_BR": {"name_en": u"Portuguese (Brazil)", "name": u"Português (Brasil)"},
"pt_PT": {"name_en": u"Portuguese (Portugal)", "name": u"Português (Portugal)"},
"ro_RO": {"name_en": u"Romanian", "name": u"Română"},
"ru_RU": {"name_en": u"Russian", "name": u"Русский"},
"sk_SK": {"name_en": u"Slovak", "name": u"Slovenčina"},
"sl_SI": {"name_en": u"Slovenian", "name": u"Slovenščina"},
"sq_AL": {"name_en": u"Albanian", "name": u"Shqip"},
"sr_RS": {"name_en": u"Serbian", "name": u"Српски"},
"sv_SE": {"name_en": u"Swedish", "name": u"Svenska"},
"sw_KE": {"name_en": u"Swahili", "name": u"Kiswahili"},
"ta_IN": {"name_en": u"Tamil", "name": u"தமிழ்"},
"te_IN": {"name_en": u"Telugu", "name": u"తెలుగు"},
"th_TH": {"name_en": u"Thai", "name": u"ภาษาไทย"},
"tl_PH": {"name_en": u"Filipino", "name": u"Filipino"},
"tr_TR": {"name_en": u"Turkish", "name": u"Türkçe"},
"uk_UA": {"name_en": u"Ukraini ", "name": u"Українська"},
"vi_VN": {"name_en": u"Vietnamese", "name": u"Tiếng Việt"},
"zh_CN": {"name_en": u"Chinese (Simplified)", "name": u"中文(简体)"},
"zh_TW": {"name_en": u"Chinese (Traditional)", "name": u"中文(繁體)"},
}

View File

@ -65,7 +65,7 @@ Example usage for Google OAuth:
errors are more consistently reported through the ``Future`` interfaces.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import base64
import binascii
@ -82,22 +82,15 @@ from tornado import escape
from tornado.httputil import url_concat
from tornado.log import gen_log
from tornado.stack_context import ExceptionStackContext
from tornado.util import u, unicode_type, ArgReplacer
from tornado.util import unicode_type, ArgReplacer, PY3
try:
import urlparse # py2
except ImportError:
import urllib.parse as urlparse # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse # py2
try:
long # py2
except NameError:
long = int # py3
if PY3:
import urllib.parse as urlparse
import urllib.parse as urllib_parse
long = int
else:
import urlparse
import urllib as urllib_parse
class AuthError(Exception):
@ -188,7 +181,7 @@ class OpenIdMixin(object):
"""
# Verify the OpenID response via direct request to the OP
args = dict((k, v[-1]) for k, v in self.request.arguments.items())
args["openid.mode"] = u("check_authentication")
args["openid.mode"] = u"check_authentication"
url = self._OPENID_ENDPOINT
if http_client is None:
http_client = self.get_auth_http_client()
@ -255,13 +248,13 @@ class OpenIdMixin(object):
ax_ns = None
for name in self.request.arguments:
if name.startswith("openid.ns.") and \
self.get_argument(name) == u("http://openid.net/srv/ax/1.0"):
self.get_argument(name) == u"http://openid.net/srv/ax/1.0":
ax_ns = name[10:]
break
def get_ax_arg(uri):
if not ax_ns:
return u("")
return u""
prefix = "openid." + ax_ns + ".type."
ax_name = None
for name in self.request.arguments.keys():
@ -270,8 +263,8 @@ class OpenIdMixin(object):
ax_name = "openid." + ax_ns + ".value." + part
break
if not ax_name:
return u("")
return self.get_argument(ax_name, u(""))
return u""
return self.get_argument(ax_name, u"")
email = get_ax_arg("http://axschema.org/contact/email")
name = get_ax_arg("http://axschema.org/namePerson")
@ -290,7 +283,7 @@ class OpenIdMixin(object):
if name:
user["name"] = name
elif name_parts:
user["name"] = u(" ").join(name_parts)
user["name"] = u" ".join(name_parts)
elif email:
user["name"] = email.split("@")[0]
if email:
@ -961,6 +954,20 @@ class FacebookGraphMixin(OAuth2Mixin):
.. testoutput::
:hide:
This method returns a dictionary which may contain the following fields:
* ``access_token``, a string which may be passed to `facebook_request`
* ``session_expires``, an integer encoded as a string representing
the time until the access token expires in seconds. This field should
be used like ``int(user['session_expires'])``; in a future version of
Tornado it will change from a string to an integer.
* ``id``, ``name``, ``first_name``, ``last_name``, ``locale``, ``picture``,
``link``, plus any fields named in the ``extra_fields`` argument. These
fields are copied from the Facebook graph API `user object <https://developers.facebook.com/docs/graph-api/reference/user>`_
.. versionchanged:: 4.5
The ``session_expires`` field was updated to support changes made to the
Facebook API in March 2017.
"""
http = self.get_auth_http_client()
args = {
@ -985,10 +992,10 @@ class FacebookGraphMixin(OAuth2Mixin):
future.set_exception(AuthError('Facebook auth error: %s' % str(response)))
return
args = urlparse.parse_qs(escape.native_str(response.body))
args = escape.json_decode(response.body)
session = {
"access_token": args["access_token"][-1],
"expires": args.get("expires")
"access_token": args.get("access_token"),
"expires_in": args.get("expires_in")
}
self.facebook_request(
@ -996,6 +1003,9 @@ class FacebookGraphMixin(OAuth2Mixin):
callback=functools.partial(
self._on_get_user_info, future, session, fields),
access_token=session["access_token"],
appsecret_proof=hmac.new(key=client_secret.encode('utf8'),
msg=session["access_token"].encode('utf8'),
digestmod=hashlib.sha256).hexdigest(),
fields=",".join(fields)
)
@ -1008,7 +1018,12 @@ class FacebookGraphMixin(OAuth2Mixin):
for field in fields:
fieldmap[field] = user.get(field)
fieldmap.update({"access_token": session["access_token"], "session_expires": session.get("expires")})
# session_expires is converted to str for compatibility with
# older versions in which the server used url-encoding and
# this code simply returned the string verbatim.
# This should change in Tornado 5.0.
fieldmap.update({"access_token": session["access_token"],
"session_expires": str(session.get("expires_in"))})
future.set_result(fieldmap)
@_auth_return_future

View File

@ -45,7 +45,7 @@ incorrectly.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import os
import sys
@ -83,7 +83,7 @@ if __name__ == "__main__":
import functools
import logging
import os
import pkgutil
import pkgutil # type: ignore
import sys
import traceback
import types
@ -103,16 +103,12 @@ except ImportError:
# os.execv is broken on Windows and can't properly parse command line
# arguments and executable name if they contain whitespaces. subprocess
# fixes that behavior.
# This distinction is also important because when we use execv, we want to
# close the IOLoop and all its file descriptors, to guard against any
# file descriptors that were not set CLOEXEC. When execv is not available,
# we must not close the IOLoop because we want the process to exit cleanly.
_has_execv = sys.platform != 'win32'
_watched_files = set()
_reload_hooks = []
_reload_attempted = False
_io_loops = weakref.WeakKeyDictionary()
_io_loops = weakref.WeakKeyDictionary() # type: ignore
def start(io_loop=None, check_time=500):
@ -127,8 +123,6 @@ def start(io_loop=None, check_time=500):
_io_loops[io_loop] = True
if len(_io_loops) > 1:
gen_log.warning("tornado.autoreload started more than once in the same process")
if _has_execv:
add_reload_hook(functools.partial(io_loop.close, all_fds=True))
modify_times = {}
callback = functools.partial(_reload_on_update, modify_times)
scheduler = ioloop.PeriodicCallback(callback, check_time, io_loop=io_loop)
@ -249,6 +243,7 @@ def _reload():
# unwind, so just exit uncleanly.
os._exit(0)
_USAGE = """\
Usage:
python -m tornado.autoreload -m module.to.run [args...]

View File

@ -21,7 +21,7 @@ a mostly-compatible `Future` class designed for use from coroutines,
as well as some utility functions for interacting with the
`concurrent.futures` package.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools
import platform
@ -31,13 +31,18 @@ import sys
from tornado.log import app_log
from tornado.stack_context import ExceptionStackContext, wrap
from tornado.util import raise_exc_info, ArgReplacer
from tornado.util import raise_exc_info, ArgReplacer, is_finalizing
try:
from concurrent import futures
except ImportError:
futures = None
try:
import typing
except ImportError:
typing = None
# Can the garbage collector handle cycles that include __del__ methods?
# This is true in cpython beginning with version 3.4 (PEP 442).
@ -118,8 +123,8 @@ class _TracebackLogger(object):
self.exc_info = None
self.formatted_tb = None
def __del__(self):
if self.formatted_tb:
def __del__(self, is_finalizing=is_finalizing):
if not is_finalizing() and self.formatted_tb:
app_log.error('Future exception was never retrieved: %s',
''.join(self.formatted_tb).rstrip())
@ -229,7 +234,10 @@ class Future(object):
if self._result is not None:
return self._result
if self._exc_info is not None:
raise_exc_info(self._exc_info)
try:
raise_exc_info(self._exc_info)
finally:
self = None
self._check_done()
return self._result
@ -324,8 +332,8 @@ class Future(object):
# cycle are never destroyed. It's no longer the case on Python 3.4 thanks to
# the PEP 442.
if _GC_CYCLE_FINALIZERS:
def __del__(self):
if not self._log_traceback:
def __del__(self, is_finalizing=is_finalizing):
if is_finalizing() or not self._log_traceback:
# set_exception() was not called, or result() or exception()
# has consumed the exception
return
@ -335,10 +343,11 @@ class Future(object):
app_log.error('Future %r exception was never retrieved: %s',
self, ''.join(tb).rstrip())
TracebackFuture = Future
if futures is None:
FUTURES = Future
FUTURES = Future # type: typing.Union[type, typing.Tuple[type, ...]]
else:
FUTURES = (futures.Future, Future)
@ -359,6 +368,7 @@ class DummyExecutor(object):
def shutdown(self, wait=True):
pass
dummy_executor = DummyExecutor()
@ -500,8 +510,9 @@ def chain_future(a, b):
assert future is a
if b.done():
return
if (isinstance(a, TracebackFuture) and isinstance(b, TracebackFuture)
and a.exc_info() is not None):
if (isinstance(a, TracebackFuture) and
isinstance(b, TracebackFuture) and
a.exc_info() is not None):
b.set_exc_info(a.exc_info())
elif a.exception() is not None:
b.set_exception(a.exception())

View File

@ -16,12 +16,12 @@
"""Non-blocking HTTP client implementation using pycurl."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import collections
import functools
import logging
import pycurl
import pycurl # type: ignore
import threading
import time
from io import BytesIO
@ -221,6 +221,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
# _process_queue() is called from
# _finish_pending_requests the exceptions have
# nowhere to go.
self._free_list.append(curl)
callback(HTTPResponse(
request=request,
code=599,
@ -277,6 +278,9 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
if curl_log.isEnabledFor(logging.DEBUG):
curl.setopt(pycurl.VERBOSE, 1)
curl.setopt(pycurl.DEBUGFUNCTION, self._curl_debug)
if hasattr(pycurl, 'PROTOCOLS'): # PROTOCOLS first appeared in pycurl 7.19.5 (2014-07-12)
curl.setopt(pycurl.PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
curl.setopt(pycurl.REDIR_PROTOCOLS, pycurl.PROTO_HTTP | pycurl.PROTO_HTTPS)
return curl
def _curl_setup_request(self, curl, request, buffer, headers):
@ -341,6 +345,15 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
credentials = '%s:%s' % (request.proxy_username,
request.proxy_password)
curl.setopt(pycurl.PROXYUSERPWD, credentials)
if (request.proxy_auth_mode is None or
request.proxy_auth_mode == "basic"):
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_BASIC)
elif request.proxy_auth_mode == "digest":
curl.setopt(pycurl.PROXYAUTH, pycurl.HTTPAUTH_DIGEST)
else:
raise ValueError(
"Unsupported proxy_auth_mode %s" % request.proxy_auth_mode)
else:
curl.setopt(pycurl.PROXY, '')
curl.unsetopt(pycurl.PROXYUSERPWD)
@ -461,7 +474,7 @@ class CurlAsyncHTTPClient(AsyncHTTPClient):
request.prepare_curl_callback(curl)
def _curl_header_callback(self, headers, header_callback, header_line):
header_line = native_str(header_line)
header_line = native_str(header_line.decode('latin1'))
if header_callback is not None:
self.io_loop.add_callback(header_callback, header_line)
# header_line as returned by curl includes the end-of-line characters.

View File

@ -20,34 +20,28 @@ Also includes a few other miscellaneous string manipulation functions that
have crept in over time.
"""
from __future__ import absolute_import, division, print_function, with_statement
import re
import sys
from tornado.util import unicode_type, basestring_type, u
try:
from urllib.parse import parse_qs as _parse_qs # py3
except ImportError:
from urlparse import parse_qs as _parse_qs # Python 2.6+
try:
import htmlentitydefs # py2
except ImportError:
import html.entities as htmlentitydefs # py3
try:
import urllib.parse as urllib_parse # py3
except ImportError:
import urllib as urllib_parse # py2
from __future__ import absolute_import, division, print_function
import json
import re
from tornado.util import PY3, unicode_type, basestring_type
if PY3:
from urllib.parse import parse_qs as _parse_qs
import html.entities as htmlentitydefs
import urllib.parse as urllib_parse
unichr = chr
else:
from urlparse import parse_qs as _parse_qs
import htmlentitydefs
import urllib as urllib_parse
try:
unichr
except NameError:
unichr = chr
import typing # noqa
except ImportError:
pass
_XHTML_ESCAPE_RE = re.compile('[&<>"\']')
_XHTML_ESCAPE_DICT = {'&': '&amp;', '<': '&lt;', '>': '&gt;', '"': '&quot;',
@ -116,7 +110,7 @@ def url_escape(value, plus=True):
# python 3 changed things around enough that we need two separate
# implementations of url_unescape. We also need our own implementation
# of parse_qs since python 3's version insists on decoding everything.
if sys.version_info[0] < 3:
if not PY3:
def url_unescape(value, encoding='utf-8', plus=True):
"""Decodes the given value from a URL.
@ -191,6 +185,7 @@ _UTF8_TYPES = (bytes, type(None))
def utf8(value):
# type: (typing.Union[bytes,unicode_type,None])->typing.Union[bytes,None]
"""Converts a string argument to a byte string.
If the argument is already a byte string or None, it is returned unchanged.
@ -204,6 +199,7 @@ def utf8(value):
)
return value.encode("utf-8")
_TO_UNICODE_TYPES = (unicode_type, type(None))
@ -221,6 +217,7 @@ def to_unicode(value):
)
return value.decode("utf-8")
# to_unicode was previously named _unicode not because it was private,
# but to avoid conflicts with the built-in unicode() function/type
_unicode = to_unicode
@ -269,6 +266,7 @@ def recursive_unicode(obj):
else:
return obj
# I originally used the regex from
# http://daringfireball.net/2010/07/improved_regex_for_matching_urls
# but it gets all exponential on certain patterns (such as too many trailing
@ -366,7 +364,7 @@ def linkify(text, shorten=False, extra_params="",
# have a status bar, such as Safari by default)
params += ' title="%s"' % href
return u('<a href="%s"%s>%s</a>') % (href, params, url)
return u'<a href="%s"%s>%s</a>' % (href, params, url)
# First HTML-escape so that our strings are all safe.
# The regex is modified to avoid character entites other than &amp; so
@ -396,4 +394,5 @@ def _build_unicode_map():
unicode_map[name] = unichr(value)
return unicode_map
_HTML_UNICODE_MAP = _build_unicode_map()

View File

@ -74,7 +74,7 @@ See the `convert_yielded` function to extend this mechanism.
via ``singledispatch``.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import collections
import functools
@ -83,16 +83,18 @@ import os
import sys
import textwrap
import types
import weakref
from tornado.concurrent import Future, TracebackFuture, is_future, chain_future
from tornado.ioloop import IOLoop
from tornado.log import app_log
from tornado import stack_context
from tornado.util import raise_exc_info
from tornado.util import PY3, raise_exc_info
try:
try:
from functools import singledispatch # py34+
# py34+
from functools import singledispatch # type: ignore
except ImportError:
from singledispatch import singledispatch # backport
except ImportError:
@ -108,12 +110,14 @@ except ImportError:
try:
try:
from collections.abc import Generator as GeneratorType # py35+
# py35+
from collections.abc import Generator as GeneratorType # type: ignore
except ImportError:
from backports_abc import Generator as GeneratorType
from backports_abc import Generator as GeneratorType # type: ignore
try:
from inspect import isawaitable # py35+
# py35+
from inspect import isawaitable # type: ignore
except ImportError:
from backports_abc import isawaitable
except ImportError:
@ -121,12 +125,12 @@ except ImportError:
raise
from types import GeneratorType
def isawaitable(x):
def isawaitable(x): # type: ignore
return False
try:
import builtins # py3
except ImportError:
if PY3:
import builtins
else:
import __builtin__ as builtins
@ -242,6 +246,26 @@ def coroutine(func, replace_callback=True):
return _make_coroutine_wrapper(func, replace_callback=True)
# Ties lifetime of runners to their result futures. Github Issue #1769
# Generators, like any object in Python, must be strong referenced
# in order to not be cleaned up by the garbage collector. When using
# coroutines, the Runner object is what strong-refs the inner
# generator. However, the only item that strong-reffed the Runner
# was the last Future that the inner generator yielded (via the
# Future's internal done_callback list). Usually this is enough, but
# it is also possible for this Future to not have any strong references
# other than other objects referenced by the Runner object (usually
# when using other callback patterns and/or weakrefs). In this
# situation, if a garbage collection ran, a cycle would be detected and
# Runner objects could be destroyed along with their inner generators
# and everything in their local scope.
# This map provides strong references to Runner objects as long as
# their result future objects also have strong references (typically
# from the parent coroutine's Runner). This keeps the coroutine's
# Runner alive.
_futures_to_runners = weakref.WeakKeyDictionary()
def _make_coroutine_wrapper(func, replace_callback):
"""The inner workings of ``@gen.coroutine`` and ``@gen.engine``.
@ -251,10 +275,11 @@ def _make_coroutine_wrapper(func, replace_callback):
"""
# On Python 3.5, set the coroutine flag on our generator, to allow it
# to be used with 'await'.
wrapped = func
if hasattr(types, 'coroutine'):
func = types.coroutine(func)
@functools.wraps(func)
@functools.wraps(wrapped)
def wrapper(*args, **kwargs):
future = TracebackFuture()
@ -291,7 +316,8 @@ def _make_coroutine_wrapper(func, replace_callback):
except Exception:
future.set_exc_info(sys.exc_info())
else:
Runner(result, future, yielded)
_futures_to_runners[future] = Runner(result, future, yielded)
yielded = None
try:
return future
finally:
@ -306,9 +332,21 @@ def _make_coroutine_wrapper(func, replace_callback):
future = None
future.set_result(result)
return future
wrapper.__wrapped__ = wrapped
wrapper.__tornado_coroutine__ = True
return wrapper
def is_coroutine_function(func):
"""Return whether *func* is a coroutine function, i.e. a function
wrapped with `~.gen.coroutine`.
.. versionadded:: 4.5
"""
return getattr(func, '__tornado_coroutine__', False)
class Return(Exception):
"""Special exception to return a value from a `coroutine`.
@ -682,6 +720,7 @@ def multi(children, quiet_exceptions=()):
else:
return multi_future(children, quiet_exceptions=quiet_exceptions)
Multi = multi
@ -830,7 +869,7 @@ def maybe_future(x):
def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
"""Wraps a `.Future` in a timeout.
"""Wraps a `.Future` (or other yieldable object) in a timeout.
Raises `TimeoutError` if the input future does not complete before
``timeout``, which may be specified in any form allowed by
@ -841,15 +880,18 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
will be logged unless it is of a type contained in ``quiet_exceptions``
(which may be an exception type or a sequence of types).
Currently only supports Futures, not other `YieldPoint` classes.
Does not support `YieldPoint` subclasses.
.. versionadded:: 4.0
.. versionchanged:: 4.1
Added the ``quiet_exceptions`` argument and the logging of unhandled
exceptions.
.. versionchanged:: 4.4
Added support for yieldable objects other than `.Future`.
"""
# TODO: allow yield points in addition to futures?
# TODO: allow YieldPoints in addition to other yieldables?
# Tricky to do with stack_context semantics.
#
# It's tempting to optimize this by cancelling the input future on timeout
@ -857,6 +899,7 @@ def with_timeout(timeout, future, io_loop=None, quiet_exceptions=()):
# one waiting on the input future, so cancelling it might disrupt other
# callers and B) concurrent futures can only be cancelled while they are
# in the queue, so cancellation cannot reliably bound our waiting time.
future = convert_yielded(future)
result = Future()
chain_future(future, result)
if io_loop is None:
@ -923,6 +966,9 @@ coroutines that are likely to yield Futures that are ready instantly.
Usage: ``yield gen.moment``
.. versionadded:: 4.0
.. deprecated:: 4.5
``yield None`` is now equivalent to ``yield gen.moment``.
"""
moment.set_result(None)
@ -953,6 +999,7 @@ class Runner(object):
# of the coroutine.
self.stack_context_deactivate = None
if self.handle_yield(first_yielded):
gen = result_future = first_yielded = None
self.run()
def register_callback(self, key):
@ -1009,10 +1056,15 @@ class Runner(object):
except Exception:
self.had_exception = True
exc_info = sys.exc_info()
future = None
if exc_info is not None:
yielded = self.gen.throw(*exc_info)
exc_info = None
try:
yielded = self.gen.throw(*exc_info)
finally:
# Break up a reference to itself
# for faster GC on CPython.
exc_info = None
else:
yielded = self.gen.send(value)
@ -1045,6 +1097,7 @@ class Runner(object):
return
if not self.handle_yield(yielded):
return
yielded = None
finally:
self.running = False
@ -1093,8 +1146,12 @@ class Runner(object):
self.future.set_exc_info(sys.exc_info())
if not self.future.done() or self.future is moment:
def inner(f):
# Break a reference cycle to speed GC.
f = None # noqa
self.run()
self.io_loop.add_future(
self.future, lambda f: self.run())
self.future, inner)
return False
return True
@ -1116,6 +1173,7 @@ class Runner(object):
self.stack_context_deactivate()
self.stack_context_deactivate = None
Arguments = collections.namedtuple('Arguments', ['args', 'kwargs'])
@ -1135,6 +1193,7 @@ def _argument_adapter(callback):
callback(None)
return wrapper
# Convert Awaitables into Futures. It is unfortunately possible
# to have infinite recursion here if those Awaitables assume that
# we're using a different coroutine runner and yield objects
@ -1212,7 +1271,9 @@ def convert_yielded(yielded):
.. versionadded:: 4.1
"""
# Lists and dicts containing YieldPoints were handled earlier.
if isinstance(yielded, (list, dict)):
if yielded is None:
return moment
elif isinstance(yielded, (list, dict)):
return multi(yielded)
elif is_future(yielded):
return yielded
@ -1221,6 +1282,7 @@ def convert_yielded(yielded):
else:
raise BadYieldError("yielded unknown object %r" % (yielded,))
if singledispatch is not None:
convert_yielded = singledispatch(convert_yielded)

View File

@ -19,7 +19,7 @@
.. versionadded:: 4.0
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import re
@ -30,7 +30,7 @@ from tornado import httputil
from tornado import iostream
from tornado.log import gen_log, app_log
from tornado import stack_context
from tornado.util import GzipDecompressor
from tornado.util import GzipDecompressor, PY3
class _QuietException(Exception):
@ -257,6 +257,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if need_delegate_close:
with _ExceptionLoggingContext(app_log):
delegate.on_connection_close()
header_future = None
self._clear_callbacks()
raise gen.Return(True)
@ -342,7 +343,7 @@ class HTTP1Connection(httputil.HTTPConnection):
'Transfer-Encoding' not in headers)
else:
self._response_start_line = start_line
lines.append(utf8('HTTP/1.1 %s %s' % (start_line[1], start_line[2])))
lines.append(utf8('HTTP/1.1 %d %s' % (start_line[1], start_line[2])))
self._chunking_output = (
# TODO: should this use
# self._request_start_line.version or
@ -351,7 +352,7 @@ class HTTP1Connection(httputil.HTTPConnection):
# 304 responses have no body (not even a zero-length body), and so
# should not have either Content-Length or Transfer-Encoding.
# headers.
start_line.code != 304 and
start_line.code not in (204, 304) and
# No need to chunk the output if a Content-Length is specified.
'Content-Length' not in headers and
# Applications are discouraged from touching Transfer-Encoding,
@ -359,8 +360,8 @@ class HTTP1Connection(httputil.HTTPConnection):
'Transfer-Encoding' not in headers)
# If a 1.0 client asked for keep-alive, add the header.
if (self._request_start_line.version == 'HTTP/1.0' and
(self._request_headers.get('Connection', '').lower()
== 'keep-alive')):
(self._request_headers.get('Connection', '').lower() ==
'keep-alive')):
headers['Connection'] = 'Keep-Alive'
if self._chunking_output:
headers['Transfer-Encoding'] = 'chunked'
@ -372,7 +373,14 @@ class HTTP1Connection(httputil.HTTPConnection):
self._expected_content_remaining = int(headers['Content-Length'])
else:
self._expected_content_remaining = None
lines.extend([utf8(n) + b": " + utf8(v) for n, v in headers.get_all()])
# TODO: headers are supposed to be of type str, but we still have some
# cases that let bytes slip through. Remove these native_str calls when those
# are fixed.
header_lines = (native_str(n) + ": " + native_str(v) for n, v in headers.get_all())
if PY3:
lines.extend(l.encode('latin1') for l in header_lines)
else:
lines.extend(header_lines)
for line in lines:
if b'\n' in line:
raise ValueError('Newline in header: ' + repr(line))
@ -479,9 +487,11 @@ class HTTP1Connection(httputil.HTTPConnection):
connection_header = connection_header.lower()
if start_line.version == "HTTP/1.1":
return connection_header != "close"
elif ("Content-Length" in headers
or headers.get("Transfer-Encoding", "").lower() == "chunked"
or start_line.method in ("HEAD", "GET")):
elif ("Content-Length" in headers or
headers.get("Transfer-Encoding", "").lower() == "chunked" or
getattr(start_line, 'method', None) in ("HEAD", "GET")):
# start_line may be a request or response start line; only
# the former has a method attribute.
return connection_header == "keep-alive"
return False
@ -531,7 +541,13 @@ class HTTP1Connection(httputil.HTTPConnection):
"Multiple unequal Content-Lengths: %r" %
headers["Content-Length"])
headers["Content-Length"] = pieces[0]
content_length = int(headers["Content-Length"])
try:
content_length = int(headers["Content-Length"])
except ValueError:
# Handles non-integer Content-Length value.
raise httputil.HTTPInputError(
"Only integer Content-Length is allowed: %s" % headers["Content-Length"])
if content_length > self._max_body_size:
raise httputil.HTTPInputError("Content-Length too long")
@ -550,7 +566,7 @@ class HTTP1Connection(httputil.HTTPConnection):
if content_length is not None:
return self._read_fixed_body(content_length, delegate)
if headers.get("Transfer-Encoding") == "chunked":
if headers.get("Transfer-Encoding", "").lower() == "chunked":
return self._read_chunked_body(delegate)
if self.is_client:
return self._read_body_until_close(delegate)

View File

@ -25,7 +25,7 @@ to switch to ``curl_httpclient`` for reasons such as the following:
Note that if you are using ``curl_httpclient``, it is highly
recommended that you use a recent version of ``libcurl`` and
``pycurl``. Currently the minimum supported version of libcurl is
7.21.1, and the minimum version of pycurl is 7.18.2. It is highly
7.22.0, and the minimum version of pycurl is 7.18.2. It is highly
recommended that your ``libcurl`` installation is built with
asynchronous DNS resolver (threaded or c-ares), otherwise you may
encounter various problems with request timeouts (for more
@ -38,7 +38,7 @@ To select ``curl_httpclient``, call `AsyncHTTPClient.configure` at startup::
AsyncHTTPClient.configure("tornado.curl_httpclient.CurlAsyncHTTPClient")
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools
import time
@ -61,7 +61,7 @@ class HTTPClient(object):
http_client = httpclient.HTTPClient()
try:
response = http_client.fetch("http://www.google.com/")
print response.body
print(response.body)
except httpclient.HTTPError as e:
# HTTPError is raised for non-200 responses; the response
# can be found in e.response.
@ -108,14 +108,14 @@ class AsyncHTTPClient(Configurable):
Example usage::
def handle_request(response):
def handle_response(response):
if response.error:
print "Error:", response.error
print("Error: %s" % response.error)
else:
print response.body
print(response.body)
http_client = AsyncHTTPClient()
http_client.fetch("http://www.google.com/", handle_request)
http_client.fetch("http://www.google.com/", handle_response)
The constructor for this class is magic in several respects: It
actually creates an instance of an implementation-specific
@ -211,10 +211,12 @@ class AsyncHTTPClient(Configurable):
kwargs: ``HTTPRequest(request, **kwargs)``
This method returns a `.Future` whose result is an
`HTTPResponse`. By default, the ``Future`` will raise an `HTTPError`
if the request returned a non-200 response code. Instead, if
``raise_error`` is set to False, the response will always be
returned regardless of the response code.
`HTTPResponse`. By default, the ``Future`` will raise an
`HTTPError` if the request returned a non-200 response code
(other errors may also be raised if the server could not be
contacted). Instead, if ``raise_error`` is set to False, the
response will always be returned regardless of the response
code.
If a ``callback`` is given, it will be invoked with the `HTTPResponse`.
In the callback interface, `HTTPError` is not automatically raised.
@ -225,6 +227,9 @@ class AsyncHTTPClient(Configurable):
raise RuntimeError("fetch() called on closed AsyncHTTPClient")
if not isinstance(request, HTTPRequest):
request = HTTPRequest(url=request, **kwargs)
else:
if kwargs:
raise ValueError("kwargs can't be used if request is an HTTPRequest object")
# We may modify this (to add Host, Accept-Encoding, etc),
# so make sure we don't modify the caller's object. This is also
# where normal dicts get converted to HTTPHeaders objects.
@ -305,10 +310,10 @@ class HTTPRequest(object):
network_interface=None, streaming_callback=None,
header_callback=None, prepare_curl_callback=None,
proxy_host=None, proxy_port=None, proxy_username=None,
proxy_password=None, allow_nonstandard_methods=None,
validate_cert=None, ca_certs=None,
allow_ipv6=None,
client_key=None, client_cert=None, body_producer=None,
proxy_password=None, proxy_auth_mode=None,
allow_nonstandard_methods=None, validate_cert=None,
ca_certs=None, allow_ipv6=None, client_key=None,
client_cert=None, body_producer=None,
expect_100_continue=False, decompress_response=None,
ssl_options=None):
r"""All parameters except ``url`` are optional.
@ -336,13 +341,15 @@ class HTTPRequest(object):
Allowed values are implementation-defined; ``curl_httpclient``
supports "basic" and "digest"; ``simple_httpclient`` only supports
"basic"
:arg float connect_timeout: Timeout for initial connection in seconds
:arg float request_timeout: Timeout for entire request in seconds
:arg float connect_timeout: Timeout for initial connection in seconds,
default 20 seconds
:arg float request_timeout: Timeout for entire request in seconds,
default 20 seconds
:arg if_modified_since: Timestamp for ``If-Modified-Since`` header
:type if_modified_since: `datetime` or `float`
:arg bool follow_redirects: Should redirects be followed automatically
or return the 3xx response?
:arg int max_redirects: Limit for ``follow_redirects``
or return the 3xx response? Default True.
:arg int max_redirects: Limit for ``follow_redirects``, default 5.
:arg string user_agent: String to send as ``User-Agent`` header
:arg bool decompress_response: Request a compressed response from
the server and decompress it after downloading. Default is True.
@ -367,16 +374,18 @@ class HTTPRequest(object):
a ``pycurl.Curl`` object to allow the application to make additional
``setopt`` calls.
:arg string proxy_host: HTTP proxy hostname. To use proxies,
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username`` and
``proxy_pass`` are optional. Proxies are currently only supported
with ``curl_httpclient``.
``proxy_host`` and ``proxy_port`` must be set; ``proxy_username``,
``proxy_pass`` and ``proxy_auth_mode`` are optional. Proxies are
currently only supported with ``curl_httpclient``.
:arg int proxy_port: HTTP proxy port
:arg string proxy_username: HTTP proxy username
:arg string proxy_password: HTTP proxy password
:arg string proxy_auth_mode: HTTP proxy Authentication mode;
default is "basic". supports "basic" and "digest"
:arg bool allow_nonstandard_methods: Allow unknown values for ``method``
argument?
argument? Default is False.
:arg bool validate_cert: For HTTPS requests, validate the server's
certificate?
certificate? Default is True.
:arg string ca_certs: filename of CA certificates in PEM format,
or None to use defaults. See note below when used with
``curl_httpclient``.
@ -414,6 +423,9 @@ class HTTPRequest(object):
.. versionadded:: 4.2
The ``ssl_options`` argument.
.. versionadded:: 4.5
The ``proxy_auth_mode`` argument.
"""
# Note that some of these attributes go through property setters
# defined below.
@ -425,6 +437,7 @@ class HTTPRequest(object):
self.proxy_port = proxy_port
self.proxy_username = proxy_username
self.proxy_password = proxy_password
self.proxy_auth_mode = proxy_auth_mode
self.url = url
self.method = method
self.body = body
@ -525,7 +538,7 @@ class HTTPResponse(object):
* buffer: ``cStringIO`` object for response body
* body: response body as string (created on demand from ``self.buffer``)
* body: response body as bytes (created on demand from ``self.buffer``)
* error: Exception object, if any
@ -567,7 +580,8 @@ class HTTPResponse(object):
self.request_time = request_time
self.time_info = time_info or {}
def _get_body(self):
@property
def body(self):
if self.buffer is None:
return None
elif self._body is None:
@ -575,8 +589,6 @@ class HTTPResponse(object):
return self._body
body = property(_get_body)
def rethrow(self):
"""If there was an error on the request, raise an `HTTPError`."""
if self.error:
@ -610,6 +622,12 @@ class HTTPError(Exception):
def __str__(self):
return "HTTP %d: %s" % (self.code, self.message)
# There is a cyclic reference between self and self.response,
# which breaks the default __repr__ implementation.
# (especially on pypy, which doesn't have the same recursion
# detection as cpython).
__repr__ = __str__
class _RequestProxy(object):
"""Combines an object with a dictionary of defaults.
@ -655,5 +673,6 @@ def main():
print(native_str(response.body))
client.close()
if __name__ == "__main__":
main()

View File

@ -26,7 +26,7 @@ class except to start a server at the beginning of the process
to `tornado.httputil.HTTPServerRequest`. The old name remains as an alias.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import socket
@ -62,6 +62,13 @@ class HTTPServer(TCPServer, Configurable,
if Tornado is run behind an SSL-decoding proxy that does not set one of
the supported ``xheaders``.
By default, when parsing the ``X-Forwarded-For`` header, Tornado will
select the last (i.e., the closest) address on the list of hosts as the
remote host IP address. To select the next server in the chain, a list of
trusted downstream hosts may be passed as the ``trusted_downstream``
argument. These hosts will be skipped when parsing the ``X-Forwarded-For``
header.
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
versions of Python ``ssl_options`` may also be a dictionary of keyword
@ -124,6 +131,9 @@ class HTTPServer(TCPServer, Configurable,
.. versionchanged:: 4.2
`HTTPServer` is now a subclass of `tornado.util.Configurable`.
.. versionchanged:: 4.5
Added the ``trusted_downstream`` argument.
"""
def __init__(self, *args, **kwargs):
# Ignore args to __init__; real initialization belongs in
@ -138,7 +148,8 @@ class HTTPServer(TCPServer, Configurable,
decompress_request=False,
chunk_size=None, max_header_size=None,
idle_connection_timeout=None, body_timeout=None,
max_body_size=None, max_buffer_size=None):
max_body_size=None, max_buffer_size=None,
trusted_downstream=None):
self.request_callback = request_callback
self.no_keep_alive = no_keep_alive
self.xheaders = xheaders
@ -149,11 +160,13 @@ class HTTPServer(TCPServer, Configurable,
max_header_size=max_header_size,
header_timeout=idle_connection_timeout or 3600,
max_body_size=max_body_size,
body_timeout=body_timeout)
body_timeout=body_timeout,
no_keep_alive=no_keep_alive)
TCPServer.__init__(self, io_loop=io_loop, ssl_options=ssl_options,
max_buffer_size=max_buffer_size,
read_chunk_size=chunk_size)
self._connections = set()
self.trusted_downstream = trusted_downstream
@classmethod
def configurable_base(cls):
@ -172,21 +185,55 @@ class HTTPServer(TCPServer, Configurable,
def handle_stream(self, stream, address):
context = _HTTPRequestContext(stream, address,
self.protocol)
self.protocol,
self.trusted_downstream)
conn = HTTP1ServerConnection(
stream, self.conn_params, context)
self._connections.add(conn)
conn.start_serving(self)
def start_request(self, server_conn, request_conn):
return _ServerRequestAdapter(self, server_conn, request_conn)
if isinstance(self.request_callback, httputil.HTTPServerConnectionDelegate):
delegate = self.request_callback.start_request(server_conn, request_conn)
else:
delegate = _CallableAdapter(self.request_callback, request_conn)
if self.xheaders:
delegate = _ProxyAdapter(delegate, request_conn)
return delegate
def on_close(self, server_conn):
self._connections.remove(server_conn)
class _CallableAdapter(httputil.HTTPMessageDelegate):
def __init__(self, request_callback, request_conn):
self.connection = request_conn
self.request_callback = request_callback
self.request = None
self.delegate = None
self._chunks = []
def headers_received(self, start_line, headers):
self.request = httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line,
headers=headers)
def data_received(self, chunk):
self._chunks.append(chunk)
def finish(self):
self.request.body = b''.join(self._chunks)
self.request._parse_body()
self.request_callback(self.request)
def on_connection_close(self):
self._chunks = None
class _HTTPRequestContext(object):
def __init__(self, stream, address, protocol):
def __init__(self, stream, address, protocol, trusted_downstream=None):
self.address = address
# Save the socket's address family now so we know how to
# interpret self.address even after the stream is closed
@ -210,6 +257,7 @@ class _HTTPRequestContext(object):
self.protocol = "http"
self._orig_remote_ip = self.remote_ip
self._orig_protocol = self.protocol
self.trusted_downstream = set(trusted_downstream or [])
def __str__(self):
if self.address_family in (socket.AF_INET, socket.AF_INET6):
@ -226,7 +274,10 @@ class _HTTPRequestContext(object):
"""Rewrite the ``remote_ip`` and ``protocol`` fields."""
# Squid uses X-Forwarded-For, others use X-Real-Ip
ip = headers.get("X-Forwarded-For", self.remote_ip)
ip = ip.split(',')[-1].strip()
# Skip trusted downstream hosts in X-Forwarded-For list
for ip in (cand.strip() for cand in reversed(ip.split(','))):
if ip not in self.trusted_downstream:
break
ip = headers.get("X-Real-Ip", ip)
if netutil.is_valid_ip(ip):
self.remote_ip = ip
@ -247,58 +298,28 @@ class _HTTPRequestContext(object):
self.protocol = self._orig_protocol
class _ServerRequestAdapter(httputil.HTTPMessageDelegate):
"""Adapts the `HTTPMessageDelegate` interface to the interface expected
by our clients.
"""
def __init__(self, server, server_conn, request_conn):
self.server = server
class _ProxyAdapter(httputil.HTTPMessageDelegate):
def __init__(self, delegate, request_conn):
self.connection = request_conn
self.request = None
if isinstance(server.request_callback,
httputil.HTTPServerConnectionDelegate):
self.delegate = server.request_callback.start_request(
server_conn, request_conn)
self._chunks = None
else:
self.delegate = None
self._chunks = []
self.delegate = delegate
def headers_received(self, start_line, headers):
if self.server.xheaders:
self.connection.context._apply_xheaders(headers)
if self.delegate is None:
self.request = httputil.HTTPServerRequest(
connection=self.connection, start_line=start_line,
headers=headers)
else:
return self.delegate.headers_received(start_line, headers)
self.connection.context._apply_xheaders(headers)
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
if self.delegate is None:
self._chunks.append(chunk)
else:
return self.delegate.data_received(chunk)
return self.delegate.data_received(chunk)
def finish(self):
if self.delegate is None:
self.request.body = b''.join(self._chunks)
self.request._parse_body()
self.server.request_callback(self.request)
else:
self.delegate.finish()
self.delegate.finish()
self._cleanup()
def on_connection_close(self):
if self.delegate is None:
self._chunks = None
else:
self.delegate.on_connection_close()
self.delegate.on_connection_close()
self._cleanup()
def _cleanup(self):
if self.server.xheaders:
self.connection.context._unapply_xheaders()
self.connection.context._unapply_xheaders()
HTTPRequest = httputil.HTTPServerRequest

View File

@ -20,7 +20,7 @@ This module also defines the `HTTPServerRequest` class which is exposed
via `tornado.web.RequestHandler.request`.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import calendar
import collections
@ -33,33 +33,37 @@ import time
from tornado.escape import native_str, parse_qs_bytes, utf8
from tornado.log import gen_log
from tornado.util import ObjectDict
from tornado.util import ObjectDict, PY3
try:
import Cookie # py2
except ImportError:
import http.cookies as Cookie # py3
if PY3:
import http.cookies as Cookie
from http.client import responses
from urllib.parse import urlencode, urlparse, urlunparse, parse_qsl
else:
import Cookie
from httplib import responses
from urllib import urlencode
from urlparse import urlparse, urlunparse, parse_qsl
try:
from httplib import responses # py2
except ImportError:
from http.client import responses # py3
# responses is unused in this file, but we re-export it to other files.
# Reference it so pyflakes doesn't complain.
responses
try:
from urllib import urlencode # py2
except ImportError:
from urllib.parse import urlencode # py3
try:
from ssl import SSLError
except ImportError:
# ssl is unavailable on app engine.
class SSLError(Exception):
class _SSLError(Exception):
pass
# Hack around a mypy limitation. We can't simply put "type: ignore"
# on the class definition itself; must go through an assignment.
SSLError = _SSLError # type: ignore
try:
import typing
except ImportError:
pass
# RFC 7230 section 3.5: a recipient MAY recognize a single LF as a line
@ -95,6 +99,7 @@ class _NormalizedHeaderCache(dict):
del self[old_key]
return normalized
_normalized_headers = _NormalizedHeaderCache(1000)
@ -127,8 +132,8 @@ class HTTPHeaders(collections.MutableMapping):
Set-Cookie: C=D
"""
def __init__(self, *args, **kwargs):
self._dict = {}
self._as_list = {}
self._dict = {} # type: typing.Dict[str, str]
self._as_list = {} # type: typing.Dict[str, typing.List[str]]
self._last_key = None
if (len(args) == 1 and len(kwargs) == 0 and
isinstance(args[0], HTTPHeaders)):
@ -142,6 +147,7 @@ class HTTPHeaders(collections.MutableMapping):
# new public methods
def add(self, name, value):
# type: (str, str) -> None
"""Adds a new value for the given key."""
norm_name = _normalized_headers[name]
self._last_key = norm_name
@ -158,6 +164,7 @@ class HTTPHeaders(collections.MutableMapping):
return self._as_list.get(norm_name, [])
def get_all(self):
# type: () -> typing.Iterable[typing.Tuple[str, str]]
"""Returns an iterable of all (name, value) pairs.
If a header has multiple values, multiple pairs will be
@ -206,6 +213,7 @@ class HTTPHeaders(collections.MutableMapping):
self._as_list[norm_name] = [value]
def __getitem__(self, name):
# type: (str) -> str
return self._dict[_normalized_headers[name]]
def __delitem__(self, name):
@ -228,6 +236,14 @@ class HTTPHeaders(collections.MutableMapping):
# the appearance that HTTPHeaders is a single container.
__copy__ = copy
def __str__(self):
lines = []
for name, value in self.get_all():
lines.append("%s: %s\n" % (name, value))
return "".join(lines)
__unicode__ = __str__
class HTTPServerRequest(object):
"""A single HTTP request.
@ -323,7 +339,7 @@ class HTTPServerRequest(object):
"""
def __init__(self, method=None, uri=None, version="HTTP/1.0", headers=None,
body=None, host=None, files=None, connection=None,
start_line=None):
start_line=None, server_connection=None):
if start_line is not None:
method, uri, version = start_line
self.method = method
@ -338,8 +354,10 @@ class HTTPServerRequest(object):
self.protocol = getattr(context, 'protocol', "http")
self.host = host or self.headers.get("Host") or "127.0.0.1"
self.host_name = split_host_and_port(self.host.lower())[0]
self.files = files or {}
self.connection = connection
self.server_connection = server_connection
self._start_time = time.time()
self._finish_time = None
@ -365,10 +383,18 @@ class HTTPServerRequest(object):
self._cookies = Cookie.SimpleCookie()
if "Cookie" in self.headers:
try:
self._cookies.load(
native_str(self.headers["Cookie"]))
parsed = parse_cookie(self.headers["Cookie"])
except Exception:
self._cookies = {}
pass
else:
for k, v in parsed.items():
try:
self._cookies[k] = v
except Exception:
# SimpleCookie imposes some restrictions on keys;
# parse_cookie does not. Discard any cookies
# with disallowed keys.
pass
return self._cookies
def write(self, chunk, callback=None):
@ -577,11 +603,28 @@ def url_concat(url, args):
>>> url_concat("http://example.com/foo?a=b", [("c", "d"), ("c", "d2")])
'http://example.com/foo?a=b&c=d&c=d2'
"""
if not args:
if args is None:
return url
if url[-1] not in ('?', '&'):
url += '&' if ('?' in url) else '?'
return url + urlencode(args)
parsed_url = urlparse(url)
if isinstance(args, dict):
parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
parsed_query.extend(args.items())
elif isinstance(args, list) or isinstance(args, tuple):
parsed_query = parse_qsl(parsed_url.query, keep_blank_values=True)
parsed_query.extend(args)
else:
err = "'args' parameter should be dict, list or tuple. Not {0}".format(
type(args))
raise TypeError(err)
final_query = urlencode(parsed_query)
url = urlunparse((
parsed_url[0],
parsed_url[1],
parsed_url[2],
parsed_url[3],
final_query,
parsed_url[5]))
return url
class HTTPFile(ObjectDict):
@ -743,7 +786,7 @@ def parse_multipart_form_data(boundary, data, arguments, files):
name = disp_params["name"]
if disp_params.get("filename"):
ctype = headers.get("Content-Type", "application/unknown")
files.setdefault(name, []).append(HTTPFile(
files.setdefault(name, []).append(HTTPFile( # type: ignore
filename=disp_params["filename"], body=value,
content_type=ctype))
else:
@ -895,3 +938,84 @@ def split_host_and_port(netloc):
host = netloc
port = None
return (host, port)
_OctalPatt = re.compile(r"\\[0-3][0-7][0-7]")
_QuotePatt = re.compile(r"[\\].")
_nulljoin = ''.join
def _unquote_cookie(str):
"""Handle double quotes and escaping in cookie values.
This method is copied verbatim from the Python 3.5 standard
library (http.cookies._unquote) so we don't have to depend on
non-public interfaces.
"""
# If there aren't any doublequotes,
# then there can't be any special characters. See RFC 2109.
if str is None or len(str) < 2:
return str
if str[0] != '"' or str[-1] != '"':
return str
# We have to assume that we must decode this string.
# Down to work.
# Remove the "s
str = str[1:-1]
# Check for special sequences. Examples:
# \012 --> \n
# \" --> "
#
i = 0
n = len(str)
res = []
while 0 <= i < n:
o_match = _OctalPatt.search(str, i)
q_match = _QuotePatt.search(str, i)
if not o_match and not q_match: # Neither matched
res.append(str[i:])
break
# else:
j = k = -1
if o_match:
j = o_match.start(0)
if q_match:
k = q_match.start(0)
if q_match and (not o_match or k < j): # QuotePatt matched
res.append(str[i:k])
res.append(str[k + 1])
i = k + 2
else: # OctalPatt matched
res.append(str[i:j])
res.append(chr(int(str[j + 1:j + 4], 8)))
i = j + 4
return _nulljoin(res)
def parse_cookie(cookie):
"""Parse a ``Cookie`` HTTP header into a dict of name/value pairs.
This function attempts to mimic browser cookie parsing behavior;
it specifically does not follow any of the cookie-related RFCs
(because browsers don't either).
The algorithm used is identical to that used by Django version 1.9.10.
.. versionadded:: 4.4.2
"""
cookiedict = {}
for chunk in cookie.split(str(';')):
if str('=') in chunk:
key, val = chunk.split(str('='), 1)
else:
# Assume an empty name per
# https://bugzilla.mozilla.org/show_bug.cgi?id=169091
key, val = str(''), chunk
key, val = key.strip(), val.strip()
if key or val:
# unquote using Python's algorithm.
cookiedict[key] = _unquote_cookie(val)
return cookiedict

View File

@ -26,8 +26,9 @@ In addition to I/O events, the `IOLoop` can also schedule time-based events.
`IOLoop.add_timeout` is a non-blocking alternative to `time.sleep`.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import collections
import datetime
import errno
import functools
@ -45,20 +46,20 @@ import math
from tornado.concurrent import TracebackFuture, is_future
from tornado.log import app_log, gen_log
from tornado.platform.auto import set_close_exec, Waker
from tornado import stack_context
from tornado.util import Configurable, errno_from_exception, timedelta_to_seconds
from tornado.util import PY3, Configurable, errno_from_exception, timedelta_to_seconds
try:
import signal
except ImportError:
signal = None
try:
import thread # py2
except ImportError:
import _thread as thread # py3
from tornado.platform.auto import set_close_exec, Waker
if PY3:
import _thread as thread
else:
import thread
_POLL_TIMEOUT = 3600.0
@ -172,6 +173,10 @@ class IOLoop(Configurable):
This is normally not necessary as `instance()` will create
an `IOLoop` on demand, but you may want to call `install` to use
a custom subclass of `IOLoop`.
When using an `IOLoop` subclass, `install` must be called prior
to creating any objects that implicitly create their own
`IOLoop` (e.g., :class:`tornado.httpclient.AsyncHTTPClient`).
"""
assert not IOLoop.initialized()
IOLoop._instance = self
@ -612,10 +617,14 @@ class IOLoop(Configurable):
# result, which should just be ignored.
pass
else:
self.add_future(ret, lambda f: f.result())
self.add_future(ret, self._discard_future_result)
except Exception:
self.handle_callback_exception(callback)
def _discard_future_result(self, future):
"""Avoid unhandled-exception warnings from spawned coroutines."""
future.result()
def handle_callback_exception(self, callback):
"""This method is called whenever a callback run by the `IOLoop`
throws an exception.
@ -685,8 +694,7 @@ class PollIOLoop(IOLoop):
self.time_func = time_func or time.time
self._handlers = {}
self._events = {}
self._callbacks = []
self._callback_lock = threading.Lock()
self._callbacks = collections.deque()
self._timeouts = []
self._cancellations = 0
self._running = False
@ -704,11 +712,10 @@ class PollIOLoop(IOLoop):
self.READ)
def close(self, all_fds=False):
with self._callback_lock:
self._closing = True
self._closing = True
self.remove_handler(self._waker.fileno())
if all_fds:
for fd, handler in self._handlers.values():
for fd, handler in list(self._handlers.values()):
self.close_fd(fd)
self._waker.close()
self._impl.close()
@ -792,9 +799,7 @@ class PollIOLoop(IOLoop):
while True:
# Prevent IO event starvation by delaying new callbacks
# to the next iteration of the event loop.
with self._callback_lock:
callbacks = self._callbacks
self._callbacks = []
ncallbacks = len(self._callbacks)
# Add any timeouts that have come due to the callback list.
# Do not run anything until we have determined which ones
@ -814,8 +819,8 @@ class PollIOLoop(IOLoop):
due_timeouts.append(heapq.heappop(self._timeouts))
else:
break
if (self._cancellations > 512
and self._cancellations > (len(self._timeouts) >> 1)):
if (self._cancellations > 512 and
self._cancellations > (len(self._timeouts) >> 1)):
# Clean up the timeout queue when it gets large and it's
# more than half cancellations.
self._cancellations = 0
@ -823,14 +828,14 @@ class PollIOLoop(IOLoop):
if x.callback is not None]
heapq.heapify(self._timeouts)
for callback in callbacks:
self._run_callback(callback)
for i in range(ncallbacks):
self._run_callback(self._callbacks.popleft())
for timeout in due_timeouts:
if timeout.callback is not None:
self._run_callback(timeout.callback)
# Closures may be holding on to a lot of memory, so allow
# them to be freed before we go into our poll wait.
callbacks = callback = due_timeouts = timeout = None
due_timeouts = timeout = None
if self._callbacks:
# If any callbacks or timeouts called add_callback,
@ -874,7 +879,7 @@ class PollIOLoop(IOLoop):
# Pop one fd at a time from the set of pending fds and run
# its handler. Since that handler may perform actions on
# other file descriptors, there may be reentrant calls to
# this IOLoop that update self._events
# this IOLoop that modify self._events
self._events.update(event_pairs)
while self._events:
fd, events = self._events.popitem()
@ -926,36 +931,20 @@ class PollIOLoop(IOLoop):
self._cancellations += 1
def add_callback(self, callback, *args, **kwargs):
if self._closing:
return
# Blindly insert into self._callbacks. This is safe even
# from signal handlers because deque.append is atomic.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if thread.get_ident() != self._thread_ident:
# If we're not on the IOLoop's thread, we need to synchronize
# with other threads, or waking logic will induce a race.
with self._callback_lock:
if self._closing:
return
list_empty = not self._callbacks
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
if list_empty:
# If we're not in the IOLoop's thread, and we added the
# first callback to an empty list, we may need to wake it
# up (it may wake up on its own, but an occasional extra
# wake is harmless). Waking up a polling IOLoop is
# relatively expensive, so we try to avoid it when we can.
self._waker.wake()
# This will write one byte but Waker.consume() reads many
# at once, so it's ok to write even when not strictly
# necessary.
self._waker.wake()
else:
if self._closing:
return
# If we're on the IOLoop's thread, we don't need the lock,
# since we don't need to wake anyone, just add the
# callback. Blindly insert into self._callbacks. This is
# safe even from signal handlers because the GIL makes
# list.append atomic. One subtlety is that if the signal
# is interrupting another thread holding the
# _callback_lock block in IOLoop.start, we may modify
# either the old or new version of self._callbacks, but
# either way will work.
self._callbacks.append(functools.partial(
stack_context.wrap(callback), *args, **kwargs))
# If we're on the IOLoop's thread, we don't need to wake anyone.
pass
def add_callback_from_signal(self, callback, *args, **kwargs):
with stack_context.NullContext():
@ -966,26 +955,24 @@ class _Timeout(object):
"""An IOLoop timeout, a UNIX timestamp and a callback"""
# Reduce memory overhead when there are lots of pending callbacks
__slots__ = ['deadline', 'callback', 'tiebreaker']
__slots__ = ['deadline', 'callback', 'tdeadline']
def __init__(self, deadline, callback, io_loop):
if not isinstance(deadline, numbers.Real):
raise TypeError("Unsupported deadline %r" % deadline)
self.deadline = deadline
self.callback = callback
self.tiebreaker = next(io_loop._timeout_counter)
self.tdeadline = (deadline, next(io_loop._timeout_counter))
# Comparison methods to sort by deadline, with object id as a tiebreaker
# to guarantee a consistent ordering. The heapq module uses __le__
# in python2.5, and __lt__ in 2.6+ (sort() and most other comparisons
# use __lt__).
def __lt__(self, other):
return ((self.deadline, self.tiebreaker) <
(other.deadline, other.tiebreaker))
return self.tdeadline < other.tdeadline
def __le__(self, other):
return ((self.deadline, self.tiebreaker) <=
(other.deadline, other.tiebreaker))
return self.tdeadline <= other.tdeadline
class PeriodicCallback(object):
@ -1048,6 +1035,7 @@ class PeriodicCallback(object):
if self._next_timeout <= current_time:
callback_time_sec = self.callback_time / 1000.0
self._next_timeout += (math.floor((current_time - self._next_timeout) / callback_time_sec) + 1) * callback_time_sec
self._next_timeout += (math.floor((current_time - self._next_timeout) /
callback_time_sec) + 1) * callback_time_sec
self._timeout = self.io_loop.add_timeout(self._next_timeout, self._run)

View File

@ -24,7 +24,7 @@ Contents:
* `PipeIOStream`: Pipe-based IOStream implementation.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import collections
import errno
@ -58,7 +58,7 @@ except ImportError:
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
# These errnos indicate that a connection has been abruptly terminated.
# They should be caught and handled less noisily than other errors.
@ -66,7 +66,7 @@ _ERRNO_CONNRESET = (errno.ECONNRESET, errno.ECONNABORTED, errno.EPIPE,
errno.ETIMEDOUT)
if hasattr(errno, "WSAECONNRESET"):
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT)
_ERRNO_CONNRESET += (errno.WSAECONNRESET, errno.WSAECONNABORTED, errno.WSAETIMEDOUT) # type: ignore
if sys.platform == 'darwin':
# OSX appears to have a race condition that causes send(2) to return
@ -74,13 +74,15 @@ if sys.platform == 'darwin':
# http://erickt.github.io/blog/2014/11/19/adventures-in-debugging-a-potential-osx-kernel-bug/
# Since the socket is being closed anyway, treat this as an ECONNRESET
# instead of an unexpected error.
_ERRNO_CONNRESET += (errno.EPROTOTYPE,)
_ERRNO_CONNRESET += (errno.EPROTOTYPE,) # type: ignore
# More non-portable errnos:
_ERRNO_INPROGRESS = (errno.EINPROGRESS,)
if hasattr(errno, "WSAEINPROGRESS"):
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,)
_ERRNO_INPROGRESS += (errno.WSAEINPROGRESS,) # type: ignore
_WINDOWS = sys.platform.startswith('win')
class StreamClosedError(IOError):
@ -158,11 +160,16 @@ class BaseIOStream(object):
self.max_buffer_size // 2)
self.max_write_buffer_size = max_write_buffer_size
self.error = None
self._read_buffer = collections.deque()
self._write_buffer = collections.deque()
self._read_buffer = bytearray()
self._read_buffer_pos = 0
self._read_buffer_size = 0
self._write_buffer = bytearray()
self._write_buffer_pos = 0
self._write_buffer_size = 0
self._write_buffer_frozen = False
self._total_write_index = 0
self._total_write_done_index = 0
self._pending_writes_while_frozen = []
self._read_delimiter = None
self._read_regex = None
self._read_max_bytes = None
@ -173,7 +180,7 @@ class BaseIOStream(object):
self._read_future = None
self._streaming_callback = None
self._write_callback = None
self._write_future = None
self._write_futures = collections.deque()
self._close_callback = None
self._connect_callback = None
self._connect_future = None
@ -367,36 +374,37 @@ class BaseIOStream(object):
If no ``callback`` is given, this method returns a `.Future` that
resolves (with a result of ``None``) when the write has been
completed. If `write` is called again before that `.Future` has
resolved, the previous future will be orphaned and will never resolve.
completed.
The ``data`` argument may be of type `bytes` or `memoryview`.
.. versionchanged:: 4.0
Now returns a `.Future` if no callback is given.
.. versionchanged:: 4.5
Added support for `memoryview` arguments.
"""
assert isinstance(data, bytes)
self._check_closed()
# We use bool(_write_buffer) as a proxy for write_buffer_size>0,
# so never put empty strings in the buffer.
if data:
if (self.max_write_buffer_size is not None and
self._write_buffer_size + len(data) > self.max_write_buffer_size):
raise StreamBufferFullError("Reached maximum write buffer size")
# Break up large contiguous strings before inserting them in the
# write buffer, so we don't have to recopy the entire thing
# as we slice off pieces to send to the socket.
WRITE_BUFFER_CHUNK_SIZE = 128 * 1024
for i in range(0, len(data), WRITE_BUFFER_CHUNK_SIZE):
self._write_buffer.append(data[i:i + WRITE_BUFFER_CHUNK_SIZE])
self._write_buffer_size += len(data)
if self._write_buffer_frozen:
self._pending_writes_while_frozen.append(data)
else:
self._write_buffer += data
self._write_buffer_size += len(data)
self._total_write_index += len(data)
if callback is not None:
self._write_callback = stack_context.wrap(callback)
future = None
else:
future = self._write_future = TracebackFuture()
future = TracebackFuture()
future.add_done_callback(lambda f: f.exception())
self._write_futures.append((self._total_write_index, future))
if not self._connecting:
self._handle_write()
if self._write_buffer:
if self._write_buffer_size:
self._add_io_state(self.io_loop.WRITE)
self._maybe_add_error_listener()
return future
@ -445,9 +453,8 @@ class BaseIOStream(object):
if self._read_future is not None:
futures.append(self._read_future)
self._read_future = None
if self._write_future is not None:
futures.append(self._write_future)
self._write_future = None
futures += [future for _, future in self._write_futures]
self._write_futures.clear()
if self._connect_future is not None:
futures.append(self._connect_future)
self._connect_future = None
@ -466,6 +473,7 @@ class BaseIOStream(object):
# if the IOStream object is kept alive by a reference cycle.
# TODO: Clear the read buffer too; it currently breaks some tests.
self._write_buffer = None
self._write_buffer_size = 0
def reading(self):
"""Returns true if we are currently reading from the stream."""
@ -473,7 +481,7 @@ class BaseIOStream(object):
def writing(self):
"""Returns true if we are currently writing to the stream."""
return bool(self._write_buffer)
return self._write_buffer_size > 0
def closed(self):
"""Returns true if the stream has been closed."""
@ -743,7 +751,7 @@ class BaseIOStream(object):
break
if chunk is None:
return 0
self._read_buffer.append(chunk)
self._read_buffer += chunk
self._read_buffer_size += len(chunk)
if self._read_buffer_size > self.max_buffer_size:
gen_log.error("Reached maximum read buffer size")
@ -791,30 +799,25 @@ class BaseIOStream(object):
# since large merges are relatively expensive and get undone in
# _consume().
if self._read_buffer:
while True:
loc = self._read_buffer[0].find(self._read_delimiter)
if loc != -1:
delimiter_len = len(self._read_delimiter)
self._check_max_bytes(self._read_delimiter,
loc + delimiter_len)
return loc + delimiter_len
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
loc = self._read_buffer.find(self._read_delimiter,
self._read_buffer_pos)
if loc != -1:
loc -= self._read_buffer_pos
delimiter_len = len(self._read_delimiter)
self._check_max_bytes(self._read_delimiter,
loc + delimiter_len)
return loc + delimiter_len
self._check_max_bytes(self._read_delimiter,
len(self._read_buffer[0]))
self._read_buffer_size)
elif self._read_regex is not None:
if self._read_buffer:
while True:
m = self._read_regex.search(self._read_buffer[0])
if m is not None:
self._check_max_bytes(self._read_regex, m.end())
return m.end()
if len(self._read_buffer) == 1:
break
_double_prefix(self._read_buffer)
self._check_max_bytes(self._read_regex,
len(self._read_buffer[0]))
m = self._read_regex.search(self._read_buffer,
self._read_buffer_pos)
if m is not None:
loc = m.end() - self._read_buffer_pos
self._check_max_bytes(self._read_regex, loc)
return loc
self._check_max_bytes(self._read_regex, self._read_buffer_size)
return None
def _check_max_bytes(self, delimiter, size):
@ -824,35 +827,56 @@ class BaseIOStream(object):
"delimiter %r not found within %d bytes" % (
delimiter, self._read_max_bytes))
def _freeze_write_buffer(self, size):
self._write_buffer_frozen = size
def _unfreeze_write_buffer(self):
self._write_buffer_frozen = False
self._write_buffer += b''.join(self._pending_writes_while_frozen)
self._write_buffer_size += sum(map(len, self._pending_writes_while_frozen))
self._pending_writes_while_frozen[:] = []
def _got_empty_write(self, size):
"""
Called when a non-blocking write() failed writing anything.
Can be overridden in subclasses.
"""
def _handle_write(self):
while self._write_buffer:
while self._write_buffer_size:
assert self._write_buffer_size >= 0
try:
if not self._write_buffer_frozen:
start = self._write_buffer_pos
if self._write_buffer_frozen:
size = self._write_buffer_frozen
elif _WINDOWS:
# On windows, socket.send blows up if given a
# write buffer that's too large, instead of just
# returning the number of bytes it was able to
# process. Therefore we must not call socket.send
# with more than 128KB at a time.
_merge_prefix(self._write_buffer, 128 * 1024)
num_bytes = self.write_to_fd(self._write_buffer[0])
size = 128 * 1024
else:
size = self._write_buffer_size
num_bytes = self.write_to_fd(
memoryview(self._write_buffer)[start:start + size])
if num_bytes == 0:
# With OpenSSL, if we couldn't write the entire buffer,
# the very same string object must be used on the
# next call to send. Therefore we suppress
# merging the write buffer after an incomplete send.
# A cleaner solution would be to set
# SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
# not yet accessible from python
# (http://bugs.python.org/issue8240)
self._write_buffer_frozen = True
self._got_empty_write(size)
break
self._write_buffer_frozen = False
_merge_prefix(self._write_buffer, num_bytes)
self._write_buffer.popleft()
self._write_buffer_pos += num_bytes
self._write_buffer_size -= num_bytes
# Amortized O(1) shrink
# (this heuristic is implemented natively in Python 3.4+
# but is replicated here for Python 2)
if self._write_buffer_pos > self._write_buffer_size:
del self._write_buffer[:self._write_buffer_pos]
self._write_buffer_pos = 0
if self._write_buffer_frozen:
self._unfreeze_write_buffer()
self._total_write_done_index += num_bytes
except (socket.error, IOError, OSError) as e:
if e.args[0] in _ERRNO_WOULDBLOCK:
self._write_buffer_frozen = True
self._got_empty_write(size)
break
else:
if not self._is_connreset(e):
@ -863,22 +887,38 @@ class BaseIOStream(object):
self.fileno(), e)
self.close(exc_info=True)
return
if not self._write_buffer:
while self._write_futures:
index, future = self._write_futures[0]
if index > self._total_write_done_index:
break
self._write_futures.popleft()
future.set_result(None)
if not self._write_buffer_size:
if self._write_callback:
callback = self._write_callback
self._write_callback = None
self._run_callback(callback)
if self._write_future:
future = self._write_future
self._write_future = None
future.set_result(None)
def _consume(self, loc):
# Consume loc bytes from the read buffer and return them
if loc == 0:
return b""
_merge_prefix(self._read_buffer, loc)
assert loc <= self._read_buffer_size
# Slice the bytearray buffer into bytes, without intermediate copying
b = (memoryview(self._read_buffer)
[self._read_buffer_pos:self._read_buffer_pos + loc]
).tobytes()
self._read_buffer_pos += loc
self._read_buffer_size -= loc
return self._read_buffer.popleft()
# Amortized O(1) shrink
# (this heuristic is implemented natively in Python 3.4+
# but is replicated here for Python 2)
if self._read_buffer_pos > self._read_buffer_size:
del self._read_buffer[:self._read_buffer_pos]
self._read_buffer_pos = 0
return b
def _check_closed(self):
if self.closed():
@ -1124,7 +1164,7 @@ class IOStream(BaseIOStream):
suitably-configured `ssl.SSLContext` to disable.
"""
if (self._read_callback or self._read_future or
self._write_callback or self._write_future or
self._write_callback or self._write_futures or
self._connect_callback or self._connect_future or
self._pending_callbacks or self._closed or
self._read_buffer or self._write_buffer):
@ -1251,6 +1291,17 @@ class SSLIOStream(IOStream):
def writing(self):
return self._handshake_writing or super(SSLIOStream, self).writing()
def _got_empty_write(self, size):
# With OpenSSL, if we couldn't write the entire buffer,
# the very same string object must be used on the
# next call to send. Therefore we suppress
# merging the write buffer after an incomplete send.
# A cleaner solution would be to set
# SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER, but this is
# not yet accessible from python
# (http://bugs.python.org/issue8240)
self._freeze_write_buffer(size)
def _do_ssl_handshake(self):
# Based on code from test_ssl.py in the python stdlib
try:
@ -1498,53 +1549,6 @@ class PipeIOStream(BaseIOStream):
return chunk
def _double_prefix(deque):
"""Grow by doubling, but don't split the second chunk just because the
first one is small.
"""
new_len = max(len(deque[0]) * 2,
(len(deque[0]) + len(deque[1])))
_merge_prefix(deque, new_len)
def _merge_prefix(deque, size):
"""Replace the first entries in a deque of strings with a single
string of up to size bytes.
>>> d = collections.deque(['abc', 'de', 'fghi', 'j'])
>>> _merge_prefix(d, 5); print(d)
deque(['abcde', 'fghi', 'j'])
Strings will be split as necessary to reach the desired size.
>>> _merge_prefix(d, 7); print(d)
deque(['abcdefg', 'hi', 'j'])
>>> _merge_prefix(d, 3); print(d)
deque(['abc', 'defg', 'hi', 'j'])
>>> _merge_prefix(d, 100); print(d)
deque(['abcdefghij'])
"""
if len(deque) == 1 and len(deque[0]) <= size:
return
prefix = []
remaining = size
while deque and remaining > 0:
chunk = deque.popleft()
if len(chunk) > remaining:
deque.appendleft(chunk[remaining:])
chunk = chunk[:remaining]
prefix.append(chunk)
remaining -= len(chunk)
# This data structure normally just contains byte strings, but
# the unittest gets messy if it doesn't use the default str() type,
# so do the merge based on the type of data that's actually present.
if prefix:
deque.appendleft(type(prefix[0])().join(prefix))
if not deque:
deque.appendleft(b"")
def doctests():
import doctest
return doctest.DocTestSuite()

View File

@ -19,7 +19,7 @@
To load a locale and generate a translated string::
user_locale = tornado.locale.get("es_LA")
print user_locale.translate("Sign out")
print(user_locale.translate("Sign out"))
`tornado.locale.get()` returns the closest matching locale, not necessarily the
specific locale you requested. You can support pluralization with
@ -28,7 +28,7 @@ additional arguments to `~Locale.translate()`, e.g.::
people = [...]
message = user_locale.translate(
"%(list)s is online", "%(list)s are online", len(people))
print message % {"list": user_locale.list(people)}
print(message % {"list": user_locale.list(people)})
The first string is chosen if ``len(people) == 1``, otherwise the second
string is chosen.
@ -39,7 +39,7 @@ supported by `gettext` and related tools). If neither method is called,
the `Locale.translate` method will simply return the original string.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import codecs
import csv
@ -51,12 +51,12 @@ import re
from tornado import escape
from tornado.log import gen_log
from tornado.util import u
from tornado.util import PY3
from tornado._locale_data import LOCALE_NAMES
_default_locale = "en_US"
_translations = {}
_translations = {} # type: dict
_supported_locales = frozenset([_default_locale])
_use_gettext = False
CONTEXT_SEPARATOR = "\x04"
@ -148,11 +148,11 @@ def load_translations(directory, encoding=None):
# in most cases but is common with CSV files because Excel
# cannot read utf-8 files without a BOM.
encoding = 'utf-8-sig'
try:
if PY3:
# python 3: csv.reader requires a file open in text mode.
# Force utf8 to avoid dependence on $LANG environment variable.
f = open(full_path, "r", encoding=encoding)
except TypeError:
else:
# python 2: csv can only handle byte strings (in ascii-compatible
# encodings), which we decode below. Transcode everything into
# utf8 before passing it to csv.reader.
@ -187,7 +187,7 @@ def load_gettext_translations(directory, domain):
{directory}/{lang}/LC_MESSAGES/{domain}.mo
Three steps are required to have you app translated:
Three steps are required to have your app translated:
1. Generate POT translation file::
@ -274,7 +274,7 @@ class Locale(object):
def __init__(self, code, translations):
self.code = code
self.name = LOCALE_NAMES.get(code, {}).get("name", u("Unknown"))
self.name = LOCALE_NAMES.get(code, {}).get("name", u"Unknown")
self.rtl = False
for prefix in ["fa", "ar", "he"]:
if self.code.startswith(prefix):
@ -376,7 +376,7 @@ class Locale(object):
str_time = "%d:%02d" % (local_date.hour, local_date.minute)
elif self.code == "zh_CN":
str_time = "%s%d:%02d" % (
(u('\u4e0a\u5348'), u('\u4e0b\u5348'))[local_date.hour >= 12],
(u'\u4e0a\u5348', u'\u4e0b\u5348')[local_date.hour >= 12],
local_date.hour % 12 or 12, local_date.minute)
else:
str_time = "%d:%02d %s" % (
@ -422,7 +422,7 @@ class Locale(object):
return ""
if len(parts) == 1:
return parts[0]
comma = u(' \u0648 ') if self.code.startswith("fa") else u(", ")
comma = u' \u0648 ' if self.code.startswith("fa") else u", "
return _("%(commas)s and %(last)s") % {
"commas": comma.join(parts[:-1]),
"last": parts[len(parts) - 1],

View File

@ -12,15 +12,15 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
from __future__ import absolute_import, division, print_function
import collections
from tornado import gen, ioloop
from tornado.concurrent import Future
__all__ = ['Condition', 'Event', 'Semaphore', 'BoundedSemaphore', 'Lock']
class _TimeoutGarbageCollector(object):
"""Base class for objects that periodically clean up timed-out waiters.
@ -465,7 +465,7 @@ class Lock(object):
...
... # Now the lock is released.
.. versionchanged:: 3.5
.. versionchanged:: 4.3
Added ``async with`` support in Python 3.5.
"""

View File

@ -28,7 +28,7 @@ These streams may be configured independently using the standard library's
`logging` module. For example, you may wish to send ``tornado.access`` logs
to a separate file for analysis.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import logging
import logging.handlers
@ -38,7 +38,12 @@ from tornado.escape import _unicode
from tornado.util import unicode_type, basestring_type
try:
import curses
import colorama
except ImportError:
colorama = None
try:
import curses # type: ignore
except ImportError:
curses = None
@ -49,15 +54,21 @@ gen_log = logging.getLogger("tornado.general")
def _stderr_supports_color():
color = False
if curses and hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
try:
curses.setupterm()
if curses.tigetnum("colors") > 0:
color = True
except Exception:
pass
return color
try:
if hasattr(sys.stderr, 'isatty') and sys.stderr.isatty():
if curses:
curses.setupterm()
if curses.tigetnum("colors") > 0:
return True
elif colorama:
if sys.stderr is getattr(colorama.initialise, 'wrapped_stderr',
object()):
return True
except Exception:
# Very broad exception handling because it's always better to
# fall back to non-colored logs than to break at startup.
pass
return False
def _safe_unicode(s):
@ -77,8 +88,19 @@ class LogFormatter(logging.Formatter):
* Robust against str/bytes encoding problems.
This formatter is enabled automatically by
`tornado.options.parse_command_line` (unless ``--logging=none`` is
used).
`tornado.options.parse_command_line` or `tornado.options.parse_config_file`
(unless ``--logging=none`` is used).
Color support on Windows versions that do not support ANSI color codes is
enabled by use of the colorama__ library. Applications that wish to use
this must first initialize colorama with a call to ``colorama.init``.
See the colorama documentation for details.
__ https://pypi.python.org/pypi/colorama
.. versionchanged:: 4.5
Added support for ``colorama``. Changed the constructor
signature to be compatible with `logging.config.dictConfig`.
"""
DEFAULT_FORMAT = '%(color)s[%(levelname)1.1s %(asctime)s %(module)s:%(lineno)d]%(end_color)s %(message)s'
DEFAULT_DATE_FORMAT = '%y%m%d %H:%M:%S'
@ -89,8 +111,8 @@ class LogFormatter(logging.Formatter):
logging.ERROR: 1, # Red
}
def __init__(self, color=True, fmt=DEFAULT_FORMAT,
datefmt=DEFAULT_DATE_FORMAT, colors=DEFAULT_COLORS):
def __init__(self, fmt=DEFAULT_FORMAT, datefmt=DEFAULT_DATE_FORMAT,
style='%', color=True, colors=DEFAULT_COLORS):
r"""
:arg bool color: Enables color support.
:arg string fmt: Log message format.
@ -111,21 +133,28 @@ class LogFormatter(logging.Formatter):
self._colors = {}
if color and _stderr_supports_color():
# The curses module has some str/bytes confusion in
# python3. Until version 3.2.3, most methods return
# bytes, but only accept strings. In addition, we want to
# output these strings with the logging module, which
# works with unicode strings. The explicit calls to
# unicode() below are harmless in python2 but will do the
# right conversion in python 3.
fg_color = (curses.tigetstr("setaf") or
curses.tigetstr("setf") or "")
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode_type(fg_color, "ascii")
if curses is not None:
# The curses module has some str/bytes confusion in
# python3. Until version 3.2.3, most methods return
# bytes, but only accept strings. In addition, we want to
# output these strings with the logging module, which
# works with unicode strings. The explicit calls to
# unicode() below are harmless in python2 but will do the
# right conversion in python 3.
fg_color = (curses.tigetstr("setaf") or
curses.tigetstr("setf") or "")
if (3, 0) < sys.version_info < (3, 2, 3):
fg_color = unicode_type(fg_color, "ascii")
for levelno, code in colors.items():
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
for levelno, code in colors.items():
self._colors[levelno] = unicode_type(curses.tparm(fg_color, code), "ascii")
self._normal = unicode_type(curses.tigetstr("sgr0"), "ascii")
else:
# If curses is not present (currently we'll only get here for
# colorama on windows), assume hard-coded ANSI color codes.
for levelno, code in colors.items():
self._colors[levelno] = '\033[2;3%dm' % code
self._normal = '\033[0m'
else:
self._normal = ''
@ -183,7 +212,8 @@ def enable_pretty_logging(options=None, logger=None):
and `tornado.options.parse_config_file`.
"""
if options is None:
from tornado.options import options
import tornado.options
options = tornado.options.options
if options.logging is None or options.logging.lower() == 'none':
return
if logger is None:
@ -228,7 +258,8 @@ def define_logging_options(options=None):
"""
if options is None:
# late import to prevent cycle
from tornado.options import options
import tornado.options
options = tornado.options.options
options.define("logging", default="info",
help=("Set the Python log level. If 'none', tornado won't touch the "
"logging configuration."),

View File

@ -16,7 +16,7 @@
"""Miscellaneous network utility code."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import errno
import os
@ -27,7 +27,7 @@ import stat
from tornado.concurrent import dummy_executor, run_on_executor
from tornado.ioloop import IOLoop
from tornado.platform.auto import set_close_exec
from tornado.util import u, Configurable, errno_from_exception
from tornado.util import PY3, Configurable, errno_from_exception
try:
import ssl
@ -44,20 +44,18 @@ except ImportError:
else:
raise
try:
xrange # py2
except NameError:
xrange = range # py3
if PY3:
xrange = range
if hasattr(ssl, 'match_hostname') and hasattr(ssl, 'CertificateError'): # python 3.2+
ssl_match_hostname = ssl.match_hostname
SSLCertificateError = ssl.CertificateError
elif ssl is None:
ssl_match_hostname = SSLCertificateError = None
ssl_match_hostname = SSLCertificateError = None # type: ignore
else:
import backports.ssl_match_hostname
ssl_match_hostname = backports.ssl_match_hostname.match_hostname
SSLCertificateError = backports.ssl_match_hostname.CertificateError
SSLCertificateError = backports.ssl_match_hostname.CertificateError # type: ignore
if hasattr(ssl, 'SSLContext'):
if hasattr(ssl, 'create_default_context'):
@ -96,7 +94,10 @@ else:
# module-import time, the import lock is already held by the main thread,
# leading to deadlock. Avoid it by caching the idna encoder on the main
# thread now.
u('foo').encode('idna')
u'foo'.encode('idna')
# For undiagnosed reasons, 'latin1' codec may also need to be preloaded.
u'foo'.encode('latin1')
# These errnos indicate that a non-blocking operation must be retried
# at a later time. On most platforms they're the same value, but on
@ -104,7 +105,7 @@ u('foo').encode('idna')
_ERRNO_WOULDBLOCK = (errno.EWOULDBLOCK, errno.EAGAIN)
if hasattr(errno, "WSAEWOULDBLOCK"):
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,)
_ERRNO_WOULDBLOCK += (errno.WSAEWOULDBLOCK,) # type: ignore
# Default backlog used when calling sock.listen()
_DEFAULT_BACKLOG = 128
@ -131,7 +132,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
``flags`` is a bitmask of AI_* flags to `~socket.getaddrinfo`, like
``socket.AI_PASSIVE | socket.AI_NUMERICHOST``.
``resuse_port`` option sets ``SO_REUSEPORT`` option for every socket
``reuse_port`` option sets ``SO_REUSEPORT`` option for every socket
in the list. If your platform doesn't support this option ValueError will
be raised.
"""
@ -199,6 +200,7 @@ def bind_sockets(port, address=None, family=socket.AF_UNSPEC,
sockets.append(sock)
return sockets
if hasattr(socket, 'AF_UNIX'):
def bind_unix_socket(file, mode=0o600, backlog=_DEFAULT_BACKLOG):
"""Creates a listening unix socket.
@ -334,6 +336,11 @@ class Resolver(Configurable):
port)`` pair for IPv4; additional fields may be present for
IPv6). If a ``callback`` is passed, it will be run with the
result as an argument when it is complete.
:raises IOError: if the address cannot be resolved.
.. versionchanged:: 4.4
Standardized all implementations to raise `IOError`.
"""
raise NotImplementedError()
@ -413,8 +420,8 @@ class ThreadedResolver(ExecutorResolver):
All ``ThreadedResolvers`` share a single thread pool, whose
size is set by the first one to be created.
"""
_threadpool = None
_threadpool_pid = None
_threadpool = None # type: ignore
_threadpool_pid = None # type: int
def initialize(self, io_loop=None, num_threads=10):
threadpool = ThreadedResolver._create_threadpool(num_threads)
@ -518,4 +525,4 @@ def ssl_wrap_socket(socket, ssl_options, server_hostname=None, **kwargs):
else:
return context.wrap_socket(socket, **kwargs)
else:
return ssl.wrap_socket(socket, **dict(context, **kwargs))
return ssl.wrap_socket(socket, **dict(context, **kwargs)) # type: ignore

View File

@ -41,6 +41,12 @@ either::
# or
tornado.options.parse_config_file("/etc/server.conf")
.. note:
When using tornado.options.parse_command_line or
tornado.options.parse_config_file, the only options that are set are
ones that were previously defined with tornado.options.define.
Command line formats are what you would expect (``--myoption=myvalue``).
Config files are just Python files. Global names become options, e.g.::
@ -76,7 +82,7 @@ instances to define isolated sets of options, such as for subcommands.
underscores.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import datetime
import numbers
@ -132,8 +138,10 @@ class OptionParser(object):
return name in self._options
def __getitem__(self, name):
name = self._normalize_name(name)
return self._options[name].value()
return self.__getattr__(name)
def __setitem__(self, name, value):
return self.__setattr__(name, value)
def items(self):
"""A sequence of (name, value) pairs.
@ -300,8 +308,12 @@ class OptionParser(object):
.. versionchanged:: 4.1
Config files are now always interpreted as utf-8 instead of
the system default encoding.
.. versionchanged:: 4.4
The special variable ``__file__`` is available inside config
files, specifying the absolute path to the config file itself.
"""
config = {}
config = {'__file__': os.path.abspath(path)}
with open(path, 'rb') as f:
exec_in(native_str(f.read()), config, config)
for name in config:

View File

@ -14,12 +14,12 @@ loops.
.. note::
Tornado requires the `~asyncio.BaseEventLoop.add_reader` family of methods,
so it is not compatible with the `~asyncio.ProactorEventLoop` on Windows.
Use the `~asyncio.SelectorEventLoop` instead.
Tornado requires the `~asyncio.AbstractEventLoop.add_reader` family of
methods, so it is not compatible with the `~asyncio.ProactorEventLoop` on
Windows. Use the `~asyncio.SelectorEventLoop` instead.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools
import tornado.concurrent
@ -30,11 +30,11 @@ from tornado import stack_context
try:
# Import the real asyncio module for py33+ first. Older versions of the
# trollius backport also use this name.
import asyncio
import asyncio # type: ignore
except ImportError as e:
# Asyncio itself isn't available; see if trollius is (backport to py26+).
try:
import trollius as asyncio
import trollius as asyncio # type: ignore
except ImportError:
# Re-raise the original asyncio error, not the trollius one.
raise e
@ -141,6 +141,8 @@ class BaseAsyncIOLoop(IOLoop):
def add_callback(self, callback, *args, **kwargs):
if self.closing:
# TODO: this is racy; we need a lock to ensure that the
# loop isn't closed during call_soon_threadsafe.
raise RuntimeError("IOLoop is closing")
self.asyncio_loop.call_soon_threadsafe(
self._run_callback,
@ -158,6 +160,9 @@ class AsyncIOMainLoop(BaseAsyncIOLoop):
import asyncio
AsyncIOMainLoop().install()
asyncio.get_event_loop().run_forever()
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
installing alternative IOLoops.
"""
def initialize(self, **kwargs):
super(AsyncIOMainLoop, self).initialize(asyncio.get_event_loop(),
@ -212,5 +217,6 @@ def to_asyncio_future(tornado_future):
tornado.concurrent.chain_future(tornado_future, af)
return af
if hasattr(convert_yielded, 'register'):
convert_yielded.register(asyncio.Future, to_tornado_future)
convert_yielded.register(asyncio.Future, to_tornado_future) # type: ignore

View File

@ -23,7 +23,7 @@ Most code that needs access to this functionality should do e.g.::
from tornado.platform.auto import set_close_exec
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import os
@ -47,8 +47,13 @@ try:
except ImportError:
pass
try:
from time import monotonic as monotonic_time
# monotonic can provide a monotonic function in versions of python before
# 3.3, too.
from monotonic import monotonic as monotonic_time
except ImportError:
monotonic_time = None
try:
from time import monotonic as monotonic_time
except ImportError:
monotonic_time = None
__all__ = ['Waker', 'set_close_exec', 'monotonic_time']

View File

@ -1,5 +1,5 @@
from __future__ import absolute_import, division, print_function, with_statement
import pycares
from __future__ import absolute_import, division, print_function
import pycares # type: ignore
import socket
from tornado import gen
@ -61,8 +61,8 @@ class CaresResolver(Resolver):
assert not callback_args.kwargs
result, error = callback_args.args
if error:
raise Exception('C-Ares returned error %s: %s while resolving %s' %
(error, pycares.errno.strerror(error), host))
raise IOError('C-Ares returned error %s: %s while resolving %s' %
(error, pycares.errno.strerror(error), host))
addresses = result.addresses
addrinfo = []
for address in addresses:
@ -73,7 +73,7 @@ class CaresResolver(Resolver):
else:
address_family = socket.AF_UNSPEC
if family != socket.AF_UNSPEC and family != address_family:
raise Exception('Requested socket family %d but got %d' %
(family, address_family))
raise IOError('Requested socket family %d but got %d' %
(family, address_family))
addrinfo.append((address_family, (address, port)))
raise gen.Return(addrinfo)

View File

@ -1,10 +1,27 @@
"""Lowest-common-denominator implementations of platform functionality."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import errno
import socket
import time
from tornado.platform import interface
from tornado.util import errno_from_exception
def try_close(f):
# Avoid issue #875 (race condition when using the file in another
# thread).
for i in range(10):
try:
f.close()
except IOError:
# Yield to another thread
time.sleep(1e-3)
else:
break
# Try a last time and let raise
f.close()
class Waker(interface.Waker):
@ -45,7 +62,7 @@ class Waker(interface.Waker):
break # success
except socket.error as detail:
if (not hasattr(errno, 'WSAEADDRINUSE') or
detail[0] != errno.WSAEADDRINUSE):
errno_from_exception(detail) != errno.WSAEADDRINUSE):
# "Address already in use" is the only error
# I've seen on two WinXP Pro SP2 boxes, under
# Pythons 2.3.5 and 2.4.1.
@ -75,7 +92,7 @@ class Waker(interface.Waker):
def wake(self):
try:
self.writer.send(b"x")
except (IOError, socket.error):
except (IOError, socket.error, ValueError):
pass
def consume(self):
@ -89,4 +106,4 @@ class Waker(interface.Waker):
def close(self):
self.reader.close()
self.writer.close()
try_close(self.writer)

View File

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""EPoll-based IOLoop implementation for Linux systems."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import select

View File

@ -21,7 +21,7 @@ for other tornado.platform modules. Most code should import the appropriate
implementation from `tornado.platform.auto`.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
def set_close_exec(fd):
@ -61,3 +61,7 @@ class Waker(object):
def close(self):
"""Closes the waker's file descriptor(s)."""
raise NotImplementedError()
def monotonic_time():
raise NotImplementedError()

View File

@ -14,7 +14,7 @@
# License for the specific language governing permissions and limitations
# under the License.
"""KQueue-based IOLoop implementation for BSD/Mac systems."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import select

View File

@ -16,12 +16,12 @@
"""Posix implementations of platform-specific functionality."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import fcntl
import os
from tornado.platform import interface
from tornado.platform import common, interface
def set_close_exec(fd):
@ -53,7 +53,7 @@ class Waker(interface.Waker):
def wake(self):
try:
self.writer.write(b"x")
except IOError:
except (IOError, ValueError):
pass
def consume(self):
@ -67,4 +67,4 @@ class Waker(interface.Waker):
def close(self):
self.reader.close()
self.writer.close()
common.try_close(self.writer)

View File

@ -17,7 +17,7 @@
Used as a fallback for systems that don't support epoll or kqueue.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import select

View File

@ -21,7 +21,7 @@ depending on which library's underlying event loop you want to use.
This module has been tested with Twisted versions 11.0.0 and newer.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import datetime
import functools
@ -29,19 +29,18 @@ import numbers
import socket
import sys
import twisted.internet.abstract
from twisted.internet.defer import Deferred
from twisted.internet.posixbase import PosixReactorBase
from twisted.internet.interfaces import \
IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor
from twisted.python import failure, log
from twisted.internet import error
import twisted.names.cache
import twisted.names.client
import twisted.names.hosts
import twisted.names.resolve
import twisted.internet.abstract # type: ignore
from twisted.internet.defer import Deferred # type: ignore
from twisted.internet.posixbase import PosixReactorBase # type: ignore
from twisted.internet.interfaces import IReactorFDSet, IDelayedCall, IReactorTime, IReadDescriptor, IWriteDescriptor # type: ignore
from twisted.python import failure, log # type: ignore
from twisted.internet import error # type: ignore
import twisted.names.cache # type: ignore
import twisted.names.client # type: ignore
import twisted.names.hosts # type: ignore
import twisted.names.resolve # type: ignore
from zope.interface import implementer
from zope.interface import implementer # type: ignore
from tornado.concurrent import Future
from tornado.escape import utf8
@ -354,7 +353,7 @@ def install(io_loop=None):
if not io_loop:
io_loop = tornado.ioloop.IOLoop.current()
reactor = TornadoReactor(io_loop)
from twisted.internet.main import installReactor
from twisted.internet.main import installReactor # type: ignore
installReactor(reactor)
return reactor
@ -408,11 +407,14 @@ class TwistedIOLoop(tornado.ioloop.IOLoop):
Not compatible with `tornado.process.Subprocess.set_exit_callback`
because the ``SIGCHLD`` handlers used by Tornado and Twisted conflict
with each other.
See also :meth:`tornado.ioloop.IOLoop.install` for general notes on
installing alternative IOLoops.
"""
def initialize(self, reactor=None, **kwargs):
super(TwistedIOLoop, self).initialize(**kwargs)
if reactor is None:
import twisted.internet.reactor
import twisted.internet.reactor # type: ignore
reactor = twisted.internet.reactor
self.reactor = reactor
self.fds = {}
@ -554,7 +556,10 @@ class TwistedResolver(Resolver):
deferred = self.resolver.getHostByName(utf8(host))
resolved = yield gen.Task(deferred.addBoth)
if isinstance(resolved, failure.Failure):
resolved.raiseException()
try:
resolved.raiseException()
except twisted.names.error.DomainError as e:
raise IOError(e)
elif twisted.internet.abstract.isIPAddress(resolved):
resolved_family = socket.AF_INET
elif twisted.internet.abstract.isIPv6Address(resolved):
@ -569,8 +574,9 @@ class TwistedResolver(Resolver):
]
raise gen.Return(result)
if hasattr(gen.convert_yielded, 'register'):
@gen.convert_yielded.register(Deferred)
@gen.convert_yielded.register(Deferred) # type: ignore
def _(d):
f = Future()

View File

@ -2,9 +2,9 @@
# for production use.
from __future__ import absolute_import, division, print_function, with_statement
import ctypes
import ctypes.wintypes
from __future__ import absolute_import, division, print_function
import ctypes # type: ignore
import ctypes.wintypes # type: ignore
# See: http://msdn.microsoft.com/en-us/library/ms724935(VS.85).aspx
SetHandleInformation = ctypes.windll.kernel32.SetHandleInformation
@ -17,4 +17,4 @@ HANDLE_FLAG_INHERIT = 0x00000001
def set_close_exec(fd):
success = SetHandleInformation(fd, HANDLE_FLAG_INHERIT, 0)
if not success:
raise ctypes.GetLastError()
raise ctypes.WinError()

View File

@ -18,7 +18,7 @@
the server into multiple processes and managing subprocesses.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import errno
import os
@ -35,7 +35,7 @@ from tornado.iostream import PipeIOStream
from tornado.log import gen_log
from tornado.platform.auto import set_close_exec
from tornado import stack_context
from tornado.util import errno_from_exception
from tornado.util import errno_from_exception, PY3
try:
import multiprocessing
@ -43,11 +43,8 @@ except ImportError:
# Multiprocessing is not available on Google App Engine.
multiprocessing = None
try:
long # py2
except NameError:
long = int # py3
if PY3:
long = int
# Re-export this exception for convenience.
try:
@ -70,7 +67,7 @@ def cpu_count():
pass
try:
return os.sysconf("SC_NPROCESSORS_CONF")
except ValueError:
except (AttributeError, ValueError):
pass
gen_log.error("Could not detect number of processors; assuming 1")
return 1
@ -147,6 +144,7 @@ def fork_processes(num_processes, max_restarts=100):
else:
children[pid] = i
return None
for i in range(num_processes):
id = start_child(i)
if id is not None:
@ -204,13 +202,19 @@ class Subprocess(object):
attribute of the resulting Subprocess a `.PipeIOStream`.
* A new keyword argument ``io_loop`` may be used to pass in an IOLoop.
The ``Subprocess.STREAM`` option and the ``set_exit_callback`` and
``wait_for_exit`` methods do not work on Windows. There is
therefore no reason to use this class instead of
``subprocess.Popen`` on that platform.
.. versionchanged:: 4.1
The ``io_loop`` argument is deprecated.
"""
STREAM = object()
_initialized = False
_waiting = {}
_waiting = {} # type: ignore
def __init__(self, *args, **kwargs):
self.io_loop = kwargs.pop('io_loop', None) or ioloop.IOLoop.current()
@ -351,6 +355,10 @@ class Subprocess(object):
else:
assert os.WIFEXITED(status)
self.returncode = os.WEXITSTATUS(status)
# We've taken over wait() duty from the subprocess.Popen
# object. If we don't inform it of the process's return code,
# it will log a warning at destruction in python 3.6+.
self.proc.returncode = self.returncode
if self._exit_callback:
callback = self._exit_callback
self._exit_callback = None

View File

@ -12,9 +12,17 @@
# License for the specific language governing permissions and limitations
# under the License.
from __future__ import absolute_import, division, print_function, with_statement
"""Asynchronous queues for coroutines.
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
.. warning::
Unlike the standard library's `queue` module, the classes defined here
are *not* thread-safe. To use these queues from another thread,
use `.IOLoop.add_callback` to transfer control to the `.IOLoop` thread
before calling any queue methods.
"""
from __future__ import absolute_import, division, print_function
import collections
import heapq
@ -23,6 +31,8 @@ from tornado import gen, ioloop
from tornado.concurrent import Future
from tornado.locks import Event
__all__ = ['Queue', 'PriorityQueue', 'LifoQueue', 'QueueFull', 'QueueEmpty']
class QueueEmpty(Exception):
"""Raised by `.Queue.get_nowait` when the queue has no items."""

View File

@ -0,0 +1,625 @@
# Copyright 2015 The Tornado Authors
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
"""Flexible routing implementation.
Tornado routes HTTP requests to appropriate handlers using `Router`
class implementations. The `tornado.web.Application` class is a
`Router` implementation and may be used directly, or the classes in
this module may be used for additional flexibility. The `RuleRouter`
class can match on more criteria than `.Application`, or the `Router`
interface can be subclassed for maximum customization.
`Router` interface extends `~.httputil.HTTPServerConnectionDelegate`
to provide additional routing capabilities. This also means that any
`Router` implementation can be used directly as a ``request_callback``
for `~.httpserver.HTTPServer` constructor.
`Router` subclass must implement a ``find_handler`` method to provide
a suitable `~.httputil.HTTPMessageDelegate` instance to handle the
request:
.. code-block:: python
class CustomRouter(Router):
def find_handler(self, request, **kwargs):
# some routing logic providing a suitable HTTPMessageDelegate instance
return MessageDelegate(request.connection)
class MessageDelegate(HTTPMessageDelegate):
def __init__(self, connection):
self.connection = connection
def finish(self):
self.connection.write_headers(
ResponseStartLine("HTTP/1.1", 200, "OK"),
HTTPHeaders({"Content-Length": "2"}),
b"OK")
self.connection.finish()
router = CustomRouter()
server = HTTPServer(router)
The main responsibility of `Router` implementation is to provide a
mapping from a request to `~.httputil.HTTPMessageDelegate` instance
that will handle this request. In the example above we can see that
routing is possible even without instantiating an `~.web.Application`.
For routing to `~.web.RequestHandler` implementations we need an
`~.web.Application` instance. `~.web.Application.get_handler_delegate`
provides a convenient way to create `~.httputil.HTTPMessageDelegate`
for a given request and `~.web.RequestHandler`.
Here is a simple example of how we can we route to
`~.web.RequestHandler` subclasses by HTTP method:
.. code-block:: python
resources = {}
class GetResource(RequestHandler):
def get(self, path):
if path not in resources:
raise HTTPError(404)
self.finish(resources[path])
class PostResource(RequestHandler):
def post(self, path):
resources[path] = self.request.body
class HTTPMethodRouter(Router):
def __init__(self, app):
self.app = app
def find_handler(self, request, **kwargs):
handler = GetResource if request.method == "GET" else PostResource
return self.app.get_handler_delegate(request, handler, path_args=[request.path])
router = HTTPMethodRouter(Application())
server = HTTPServer(router)
`ReversibleRouter` interface adds the ability to distinguish between
the routes and reverse them to the original urls using route's name
and additional arguments. `~.web.Application` is itself an
implementation of `ReversibleRouter` class.
`RuleRouter` and `ReversibleRuleRouter` are implementations of
`Router` and `ReversibleRouter` interfaces and can be used for
creating rule-based routing configurations.
Rules are instances of `Rule` class. They contain a `Matcher`, which
provides the logic for determining whether the rule is a match for a
particular request and a target, which can be one of the following.
1) An instance of `~.httputil.HTTPServerConnectionDelegate`:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/handler"), ConnectionDelegate()),
# ... more rules
])
class ConnectionDelegate(HTTPServerConnectionDelegate):
def start_request(self, server_conn, request_conn):
return MessageDelegate(request_conn)
2) A callable accepting a single argument of `~.httputil.HTTPServerRequest` type:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/callable"), request_callable)
])
def request_callable(request):
request.write(b"HTTP/1.1 200 OK\\r\\nContent-Length: 2\\r\\n\\r\\nOK")
request.finish()
3) Another `Router` instance:
.. code-block:: python
router = RuleRouter([
Rule(PathMatches("/router.*"), CustomRouter())
])
Of course a nested `RuleRouter` or a `~.web.Application` is allowed:
.. code-block:: python
router = RuleRouter([
Rule(HostMatches("example.com"), RuleRouter([
Rule(PathMatches("/app1/.*"), Application([(r"/app1/handler", Handler)]))),
]))
])
server = HTTPServer(router)
In the example below `RuleRouter` is used to route between applications:
.. code-block:: python
app1 = Application([
(r"/app1/handler", Handler1),
# other handlers ...
])
app2 = Application([
(r"/app2/handler", Handler2),
# other handlers ...
])
router = RuleRouter([
Rule(PathMatches("/app1.*"), app1),
Rule(PathMatches("/app2.*"), app2)
])
server = HTTPServer(router)
For more information on application-level routing see docs for `~.web.Application`.
.. versionadded:: 4.5
"""
from __future__ import absolute_import, division, print_function
import re
from functools import partial
from tornado import httputil
from tornado.httpserver import _CallableAdapter
from tornado.escape import url_escape, url_unescape, utf8
from tornado.log import app_log
from tornado.util import basestring_type, import_object, re_unescape, unicode_type
try:
import typing # noqa
except ImportError:
pass
class Router(httputil.HTTPServerConnectionDelegate):
"""Abstract router interface."""
def find_handler(self, request, **kwargs):
# type: (httputil.HTTPServerRequest, typing.Any)->httputil.HTTPMessageDelegate
"""Must be implemented to return an appropriate instance of `~.httputil.HTTPMessageDelegate`
that can serve the request.
Routing implementations may pass additional kwargs to extend the routing logic.
:arg httputil.HTTPServerRequest request: current HTTP request.
:arg kwargs: additional keyword arguments passed by routing implementation.
:returns: an instance of `~.httputil.HTTPMessageDelegate` that will be used to
process the request.
"""
raise NotImplementedError()
def start_request(self, server_conn, request_conn):
return _RoutingDelegate(self, server_conn, request_conn)
class ReversibleRouter(Router):
"""Abstract router interface for routers that can handle named routes
and support reversing them to original urls.
"""
def reverse_url(self, name, *args):
"""Returns url string for a given route name and arguments
or ``None`` if no match is found.
:arg str name: route name.
:arg args: url parameters.
:returns: parametrized url string for a given route name (or ``None``).
"""
raise NotImplementedError()
class _RoutingDelegate(httputil.HTTPMessageDelegate):
def __init__(self, router, server_conn, request_conn):
self.server_conn = server_conn
self.request_conn = request_conn
self.delegate = None
self.router = router # type: Router
def headers_received(self, start_line, headers):
request = httputil.HTTPServerRequest(
connection=self.request_conn,
server_connection=self.server_conn,
start_line=start_line, headers=headers)
self.delegate = self.router.find_handler(request)
return self.delegate.headers_received(start_line, headers)
def data_received(self, chunk):
return self.delegate.data_received(chunk)
def finish(self):
self.delegate.finish()
def on_connection_close(self):
self.delegate.on_connection_close()
class RuleRouter(Router):
"""Rule-based router implementation."""
def __init__(self, rules=None):
"""Constructs a router from an ordered list of rules::
RuleRouter([
Rule(PathMatches("/handler"), Target),
# ... more rules
])
You can also omit explicit `Rule` constructor and use tuples of arguments::
RuleRouter([
(PathMatches("/handler"), Target),
])
`PathMatches` is a default matcher, so the example above can be simplified::
RuleRouter([
("/handler", Target),
])
In the examples above, ``Target`` can be a nested `Router` instance, an instance of
`~.httputil.HTTPServerConnectionDelegate` or an old-style callable, accepting a request argument.
:arg rules: a list of `Rule` instances or tuples of `Rule`
constructor arguments.
"""
self.rules = [] # type: typing.List[Rule]
if rules:
self.add_rules(rules)
def add_rules(self, rules):
"""Appends new rules to the router.
:arg rules: a list of Rule instances (or tuples of arguments, which are
passed to Rule constructor).
"""
for rule in rules:
if isinstance(rule, (tuple, list)):
assert len(rule) in (2, 3, 4)
if isinstance(rule[0], basestring_type):
rule = Rule(PathMatches(rule[0]), *rule[1:])
else:
rule = Rule(*rule)
self.rules.append(self.process_rule(rule))
def process_rule(self, rule):
"""Override this method for additional preprocessing of each rule.
:arg Rule rule: a rule to be processed.
:returns: the same or modified Rule instance.
"""
return rule
def find_handler(self, request, **kwargs):
for rule in self.rules:
target_params = rule.matcher.match(request)
if target_params is not None:
if rule.target_kwargs:
target_params['target_kwargs'] = rule.target_kwargs
delegate = self.get_target_delegate(
rule.target, request, **target_params)
if delegate is not None:
return delegate
return None
def get_target_delegate(self, target, request, **target_params):
"""Returns an instance of `~.httputil.HTTPMessageDelegate` for a
Rule's target. This method is called by `~.find_handler` and can be
extended to provide additional target types.
:arg target: a Rule's target.
:arg httputil.HTTPServerRequest request: current request.
:arg target_params: additional parameters that can be useful
for `~.httputil.HTTPMessageDelegate` creation.
"""
if isinstance(target, Router):
return target.find_handler(request, **target_params)
elif isinstance(target, httputil.HTTPServerConnectionDelegate):
return target.start_request(request.server_connection, request.connection)
elif callable(target):
return _CallableAdapter(
partial(target, **target_params), request.connection
)
return None
class ReversibleRuleRouter(ReversibleRouter, RuleRouter):
"""A rule-based router that implements ``reverse_url`` method.
Each rule added to this router may have a ``name`` attribute that can be
used to reconstruct an original uri. The actual reconstruction takes place
in a rule's matcher (see `Matcher.reverse`).
"""
def __init__(self, rules=None):
self.named_rules = {} # type: typing.Dict[str]
super(ReversibleRuleRouter, self).__init__(rules)
def process_rule(self, rule):
rule = super(ReversibleRuleRouter, self).process_rule(rule)
if rule.name:
if rule.name in self.named_rules:
app_log.warning(
"Multiple handlers named %s; replacing previous value",
rule.name)
self.named_rules[rule.name] = rule
return rule
def reverse_url(self, name, *args):
if name in self.named_rules:
return self.named_rules[name].matcher.reverse(*args)
for rule in self.rules:
if isinstance(rule.target, ReversibleRouter):
reversed_url = rule.target.reverse_url(name, *args)
if reversed_url is not None:
return reversed_url
return None
class Rule(object):
"""A routing rule."""
def __init__(self, matcher, target, target_kwargs=None, name=None):
"""Constructs a Rule instance.
:arg Matcher matcher: a `Matcher` instance used for determining
whether the rule should be considered a match for a specific
request.
:arg target: a Rule's target (typically a ``RequestHandler`` or
`~.httputil.HTTPServerConnectionDelegate` subclass or even a nested `Router`,
depending on routing implementation).
:arg dict target_kwargs: a dict of parameters that can be useful
at the moment of target instantiation (for example, ``status_code``
for a ``RequestHandler`` subclass). They end up in
``target_params['target_kwargs']`` of `RuleRouter.get_target_delegate`
method.
:arg str name: the name of the rule that can be used to find it
in `ReversibleRouter.reverse_url` implementation.
"""
if isinstance(target, str):
# import the Module and instantiate the class
# Must be a fully qualified name (module.ClassName)
target = import_object(target)
self.matcher = matcher # type: Matcher
self.target = target
self.target_kwargs = target_kwargs if target_kwargs else {}
self.name = name
def reverse(self, *args):
return self.matcher.reverse(*args)
def __repr__(self):
return '%s(%r, %s, kwargs=%r, name=%r)' % \
(self.__class__.__name__, self.matcher,
self.target, self.target_kwargs, self.name)
class Matcher(object):
"""Represents a matcher for request features."""
def match(self, request):
"""Matches current instance against the request.
:arg httputil.HTTPServerRequest request: current HTTP request
:returns: a dict of parameters to be passed to the target handler
(for example, ``handler_kwargs``, ``path_args``, ``path_kwargs``
can be passed for proper `~.web.RequestHandler` instantiation).
An empty dict is a valid (and common) return value to indicate a match
when the argument-passing features are not used.
``None`` must be returned to indicate that there is no match."""
raise NotImplementedError()
def reverse(self, *args):
"""Reconstructs full url from matcher instance and additional arguments."""
return None
class AnyMatches(Matcher):
"""Matches any request."""
def match(self, request):
return {}
class HostMatches(Matcher):
"""Matches requests from hosts specified by ``host_pattern`` regex."""
def __init__(self, host_pattern):
if isinstance(host_pattern, basestring_type):
if not host_pattern.endswith("$"):
host_pattern += "$"
self.host_pattern = re.compile(host_pattern)
else:
self.host_pattern = host_pattern
def match(self, request):
if self.host_pattern.match(request.host_name):
return {}
return None
class DefaultHostMatches(Matcher):
"""Matches requests from host that is equal to application's default_host.
Always returns no match if ``X-Real-Ip`` header is present.
"""
def __init__(self, application, host_pattern):
self.application = application
self.host_pattern = host_pattern
def match(self, request):
# Look for default host if not behind load balancer (for debugging)
if "X-Real-Ip" not in request.headers:
if self.host_pattern.match(self.application.default_host):
return {}
return None
class PathMatches(Matcher):
"""Matches requests with paths specified by ``path_pattern`` regex."""
def __init__(self, path_pattern):
if isinstance(path_pattern, basestring_type):
if not path_pattern.endswith('$'):
path_pattern += '$'
self.regex = re.compile(path_pattern)
else:
self.regex = path_pattern
assert len(self.regex.groupindex) in (0, self.regex.groups), \
("groups in url regexes must either be all named or all "
"positional: %r" % self.regex.pattern)
self._path, self._group_count = self._find_groups()
def match(self, request):
match = self.regex.match(request.path)
if match is None:
return None
if not self.regex.groups:
return {}
path_args, path_kwargs = [], {}
# Pass matched groups to the handler. Since
# match.groups() includes both named and
# unnamed groups, we want to use either groups
# or groupdict but not both.
if self.regex.groupindex:
path_kwargs = dict(
(str(k), _unquote_or_none(v))
for (k, v) in match.groupdict().items())
else:
path_args = [_unquote_or_none(s) for s in match.groups()]
return dict(path_args=path_args, path_kwargs=path_kwargs)
def reverse(self, *args):
if self._path is None:
raise ValueError("Cannot reverse url regex " + self.regex.pattern)
assert len(args) == self._group_count, "required number of arguments " \
"not found"
if not len(args):
return self._path
converted_args = []
for a in args:
if not isinstance(a, (unicode_type, bytes)):
a = str(a)
converted_args.append(url_escape(utf8(a), plus=False))
return self._path % tuple(converted_args)
def _find_groups(self):
"""Returns a tuple (reverse string, group count) for a url.
For example: Given the url pattern /([0-9]{4})/([a-z-]+)/, this method
would return ('/%s/%s/', 2).
"""
pattern = self.regex.pattern
if pattern.startswith('^'):
pattern = pattern[1:]
if pattern.endswith('$'):
pattern = pattern[:-1]
if self.regex.groups != pattern.count('('):
# The pattern is too complicated for our simplistic matching,
# so we can't support reversing it.
return None, None
pieces = []
for fragment in pattern.split('('):
if ')' in fragment:
paren_loc = fragment.index(')')
if paren_loc >= 0:
pieces.append('%s' + fragment[paren_loc + 1:])
else:
try:
unescaped_fragment = re_unescape(fragment)
except ValueError as exc:
# If we can't unescape part of it, we can't
# reverse this url.
return (None, None)
pieces.append(unescaped_fragment)
return ''.join(pieces), self.regex.groups
class URLSpec(Rule):
"""Specifies mappings between URLs and handlers.
.. versionchanged: 4.5
`URLSpec` is now a subclass of a `Rule` with `PathMatches` matcher and is preserved for
backwards compatibility.
"""
def __init__(self, pattern, handler, kwargs=None, name=None):
"""Parameters:
* ``pattern``: Regular expression to be matched. Any capturing
groups in the regex will be passed in to the handler's
get/post/etc methods as arguments (by keyword if named, by
position if unnamed. Named and unnamed capturing groups may
may not be mixed in the same rule).
* ``handler``: `~.web.RequestHandler` subclass to be invoked.
* ``kwargs`` (optional): A dictionary of additional arguments
to be passed to the handler's constructor.
* ``name`` (optional): A name for this handler. Used by
`~.web.Application.reverse_url`.
"""
super(URLSpec, self).__init__(PathMatches(pattern), handler, kwargs, name)
self.regex = self.matcher.regex
self.handler_class = self.target
self.kwargs = kwargs
def __repr__(self):
return '%s(%r, %s, kwargs=%r, name=%r)' % \
(self.__class__.__name__, self.regex.pattern,
self.handler_class, self.kwargs, self.name)
def _unquote_or_none(s):
"""None-safe wrapper around url_unescape to handle unmatched optional
groups correctly.
Note that args are passed as bytes so the handler can decide what
encoding to use.
"""
if s is None:
return s
return url_unescape(s, encoding=None, plus=False)

View File

@ -1,5 +1,5 @@
#!/usr/bin/env python
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
from tornado.escape import utf8, _unicode
from tornado import gen
@ -11,6 +11,7 @@ from tornado.netutil import Resolver, OverrideResolver, _client_ssl_defaults
from tornado.log import gen_log
from tornado import stack_context
from tornado.tcpclient import TCPClient
from tornado.util import PY3
import base64
import collections
@ -22,10 +23,10 @@ import sys
from io import BytesIO
try:
import urlparse # py2
except ImportError:
import urllib.parse as urlparse # py3
if PY3:
import urllib.parse as urlparse
else:
import urlparse
try:
import ssl
@ -126,7 +127,7 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
timeout_handle = self.io_loop.add_timeout(
self.io_loop.time() + min(request.connect_timeout,
request.request_timeout),
functools.partial(self._on_timeout, key))
functools.partial(self._on_timeout, key, "in request queue"))
else:
timeout_handle = None
self.waiting[key] = (request, callback, timeout_handle)
@ -167,11 +168,20 @@ class SimpleAsyncHTTPClient(AsyncHTTPClient):
self.io_loop.remove_timeout(timeout_handle)
del self.waiting[key]
def _on_timeout(self, key):
def _on_timeout(self, key, info=None):
"""Timeout callback of request.
Construct a timeout HTTPResponse when a timeout occurs.
:arg object key: A simple object to mark the request.
:info string key: More detailed timeout information.
"""
request, callback, timeout_handle = self.waiting[key]
self.queue.remove((key, request, callback))
error_message = "Timeout {0}".format(info) if info else "Timeout"
timeout_response = HTTPResponse(
request, 599, error=HTTPError(599, "Timeout"),
request, 599, error=HTTPError(599, error_message),
request_time=self.io_loop.time() - request.start_time)
self.io_loop.add_callback(callback, timeout_response)
del self.waiting[key]
@ -229,7 +239,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + timeout,
stack_context.wrap(self._on_timeout))
stack_context.wrap(functools.partial(self._on_timeout, "while connecting")))
self.tcp_client.connect(host, port, af=af,
ssl_options=ssl_options,
max_buffer_size=self.max_buffer_size,
@ -284,10 +294,17 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
return ssl_options
return None
def _on_timeout(self):
def _on_timeout(self, info=None):
"""Timeout callback of _HTTPConnection instance.
Raise a timeout HTTPError when a timeout occurs.
:info string key: More detailed timeout information.
"""
self._timeout = None
error_message = "Timeout {0}".format(info) if info else "Timeout"
if self.final_callback is not None:
raise HTTPError(599, "Timeout")
raise HTTPError(599, error_message)
def _remove_timeout(self):
if self._timeout is not None:
@ -307,13 +324,14 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
if self.request.request_timeout:
self._timeout = self.io_loop.add_timeout(
self.start_time + self.request.request_timeout,
stack_context.wrap(self._on_timeout))
stack_context.wrap(functools.partial(self._on_timeout, "during request")))
if (self.request.method not in self._SUPPORTED_METHODS and
not self.request.allow_nonstandard_methods):
raise KeyError("unknown method %s" % self.request.method)
for key in ('network_interface',
'proxy_host', 'proxy_port',
'proxy_username', 'proxy_password'):
'proxy_username', 'proxy_password',
'proxy_auth_mode'):
if getattr(self.request, key, None):
raise NotImplementedError('%s not supported' % key)
if "Connection" not in self.request.headers:
@ -481,7 +499,7 @@ class _HTTPConnection(httputil.HTTPMessageDelegate):
def _should_follow_redirect(self):
return (self.request.follow_redirects and
self.request.max_redirects > 0 and
self.code in (301, 302, 303, 307))
self.code in (301, 302, 303, 307, 308))
def finish(self):
data = b''.join(self.chunks)

View File

@ -67,7 +67,7 @@ Here are a few rules of thumb for when it's necessary:
block that references your `StackContext`.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import sys
import threading
@ -82,6 +82,8 @@ class StackContextInconsistentError(Exception):
class _State(threading.local):
def __init__(self):
self.contexts = (tuple(), None)
_state = _State()

View File

@ -16,7 +16,7 @@
"""A non-blocking TCP connection factory.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import functools
import socket
@ -155,16 +155,30 @@ class TCPClient(object):
@gen.coroutine
def connect(self, host, port, af=socket.AF_UNSPEC, ssl_options=None,
max_buffer_size=None):
max_buffer_size=None, source_ip=None, source_port=None):
"""Connect to the given host and port.
Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
``ssl_options`` is not None).
Using the ``source_ip`` kwarg, one can specify the source
IP address to use when establishing the connection.
In case the user needs to resolve and
use a specific interface, it has to be handled outside
of Tornado as this depends very much on the platform.
Similarly, when the user requires a certain source port, it can
be specified using the ``source_port`` arg.
.. versionchanged:: 4.5
Added the ``source_ip`` and ``source_port`` arguments.
"""
addrinfo = yield self.resolver.resolve(host, port, af)
connector = _Connector(
addrinfo, self.io_loop,
functools.partial(self._create_stream, max_buffer_size))
functools.partial(self._create_stream, max_buffer_size,
source_ip=source_ip, source_port=source_port)
)
af, addr, stream = yield connector.start()
# TODO: For better performance we could cache the (af, addr)
# information here and re-use it on subsequent connections to
@ -174,10 +188,35 @@ class TCPClient(object):
server_hostname=host)
raise gen.Return(stream)
def _create_stream(self, max_buffer_size, af, addr):
def _create_stream(self, max_buffer_size, af, addr, source_ip=None,
source_port=None):
# Always connect in plaintext; we'll convert to ssl if necessary
# after one connection has completed.
stream = IOStream(socket.socket(af),
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
return stream.connect(addr)
source_port_bind = source_port if isinstance(source_port, int) else 0
source_ip_bind = source_ip
if source_port_bind and not source_ip:
# User required a specific port, but did not specify
# a certain source IP, will bind to the default loopback.
source_ip_bind = '::1' if af == socket.AF_INET6 else '127.0.0.1'
# Trying to use the same address family as the requested af socket:
# - 127.0.0.1 for IPv4
# - ::1 for IPv6
socket_obj = socket.socket(af)
if source_port_bind or source_ip_bind:
# If the user requires binding also to a specific IP/port.
try:
socket_obj.bind((source_ip_bind, source_port_bind))
except socket.error:
socket_obj.close()
# Fail loudly if unable to use the IP/port.
raise
try:
stream = IOStream(socket_obj,
io_loop=self.io_loop,
max_buffer_size=max_buffer_size)
except socket.error as e:
fu = Future()
fu.set_exception(e)
return fu
else:
return stream.connect(addr)

View File

@ -15,12 +15,13 @@
# under the License.
"""A non-blocking, single-threaded TCP server."""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import errno
import os
import socket
from tornado import gen
from tornado.log import app_log
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream, SSLIOStream
@ -39,7 +40,21 @@ class TCPServer(object):
r"""A non-blocking, single-threaded TCP server.
To use `TCPServer`, define a subclass which overrides the `handle_stream`
method.
method. For example, a simple echo server could be defined like this::
from tornado.tcpserver import TCPServer
from tornado.iostream import StreamClosedError
from tornado import gen
class EchoServer(TCPServer):
@gen.coroutine
def handle_stream(self, stream, address):
while True:
try:
data = yield stream.read_until(b"\n")
yield stream.write(data)
except StreamClosedError:
break
To make this server serve SSL traffic, send the ``ssl_options`` keyword
argument with an `ssl.SSLContext` object. For compatibility with older
@ -95,6 +110,7 @@ class TCPServer(object):
self._sockets = {} # fd -> socket object
self._pending_sockets = []
self._started = False
self._stopped = False
self.max_buffer_size = max_buffer_size
self.read_chunk_size = read_chunk_size
@ -147,7 +163,8 @@ class TCPServer(object):
"""Singular version of `add_sockets`. Takes a single socket object."""
self.add_sockets([socket])
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128):
def bind(self, port, address=None, family=socket.AF_UNSPEC, backlog=128,
reuse_port=False):
"""Binds this server to the given port on the given address.
To start the server, call `start`. If you want to run this server
@ -162,13 +179,17 @@ class TCPServer(object):
both will be used if available.
The ``backlog`` argument has the same meaning as for
`socket.listen <socket.socket.listen>`.
`socket.listen <socket.socket.listen>`. The ``reuse_port`` argument
has the same meaning as for `.bind_sockets`.
This method may be called multiple times prior to `start` to listen
on multiple ports or interfaces.
.. versionchanged:: 4.4
Added the ``reuse_port`` argument.
"""
sockets = bind_sockets(port, address=address, family=family,
backlog=backlog)
backlog=backlog, reuse_port=reuse_port)
if self._started:
self.add_sockets(sockets)
else:
@ -208,7 +229,11 @@ class TCPServer(object):
Requests currently in progress may still continue after the
server is stopped.
"""
if self._stopped:
return
self._stopped = True
for fd, sock in self._sockets.items():
assert sock.fileno() == fd
self.io_loop.remove_handler(fd)
sock.close()
@ -266,8 +291,10 @@ class TCPServer(object):
stream = IOStream(connection, io_loop=self.io_loop,
max_buffer_size=self.max_buffer_size,
read_chunk_size=self.read_chunk_size)
future = self.handle_stream(stream, address)
if future is not None:
self.io_loop.add_future(future, lambda f: f.result())
self.io_loop.add_future(gen.convert_yielded(future),
lambda f: f.result())
except Exception:
app_log.error("Error in connection callback", exc_info=True)

View File

@ -19,13 +19,13 @@
Basic usage looks like::
t = template.Template("<html>{{ myvalue }}</html>")
print t.generate(myvalue="XXX")
print(t.generate(myvalue="XXX"))
`Loader` is a class that loads templates from a root directory and caches
the compiled templates::
loader = template.Loader("/home/btaylor")
print loader.load("test.html").generate(myvalue="XXX")
print(loader.load("test.html").generate(myvalue="XXX"))
We compile all templates to raw Python. Error-reporting is currently... uh,
interesting. Syntax for the templates::
@ -94,12 +94,15 @@ Syntax Reference
Template expressions are surrounded by double curly braces: ``{{ ... }}``.
The contents may be any python expression, which will be escaped according
to the current autoescape setting and inserted into the output. Other
template directives use ``{% %}``. These tags may be escaped as ``{{!``
and ``{%!`` if you need to include a literal ``{{`` or ``{%`` in the output.
template directives use ``{% %}``.
To comment out a section so that it is omitted from the output, surround it
with ``{# ... #}``.
These tags may be escaped as ``{{!``, ``{%!``, and ``{#!``
if you need to include a literal ``{{``, ``{%``, or ``{#`` in the output.
``{% apply *function* %}...{% end %}``
Applies a function to the output of all template code between ``apply``
and ``end``::
@ -193,7 +196,7 @@ with ``{# ... #}``.
`filter_whitespace` for available options. New in Tornado 4.3.
"""
from __future__ import absolute_import, division, print_function, with_statement
from __future__ import absolute_import, division, print_function
import datetime
import linecache
@ -204,12 +207,12 @@ import threading
from tornado import escape
from tornado.log import app_log
from tornado.util import ObjectDict, exec_in, unicode_type
from tornado.util import ObjectDict, exec_in, unicode_type, PY3
try:
from cStringIO import StringIO # py2
except ImportError:
from io import StringIO # py3
if PY3:
from io import StringIO
else:
from cStringIO import StringIO
_DEFAULT_AUTOESCAPE = "xhtml_escape"
_UNSET = object()
@ -665,7 +668,7 @@ class ParseError(Exception):
.. versionchanged:: 4.3
Added ``filename`` and ``lineno`` attributes.
"""
def __init__(self, message, filename, lineno):
def __init__(self, message, filename=None, lineno=0):
self.message = message
# The names "filename" and "lineno" are chosen for consistency
# with python SyntaxError.

Some files were not shown because too many files have changed in this diff Show More