Uploaded Test files

This commit is contained in:
Batuhan Berk Başoğlu 2020-11-12 11:05:57 -05:00
parent f584ad9d97
commit 2e81cb7d99
16627 changed files with 2065359 additions and 102444 deletions

View file

@@ -0,0 +1,12 @@
"""Shim to allow python -m tornado.test.

This only works in python 2.7+.
"""
from tornado.test.runtests import all, main

# tornado.testing.main autodiscovery relies on 'all' being present in
# the main module, so import it here even though it is not used directly.
# The following line prevents a pyflakes warning.
all = all

# Run the full test suite on import of this module via ``python -m``.
main()

View file

@@ -0,0 +1,190 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
import unittest
from concurrent.futures import ThreadPoolExecutor
from tornado import gen
from tornado.ioloop import IOLoop
from tornado.platform.asyncio import (
AsyncIOLoop,
to_asyncio_future,
AnyThreadEventLoopPolicy,
)
from tornado.testing import AsyncTestCase, gen_test
class AsyncIOLoopTest(AsyncTestCase):
    """Tests the AsyncIOLoop bridge between Tornado and asyncio."""

    def get_new_ioloop(self):
        # Run each test on a Tornado loop backed by a fresh asyncio loop.
        io_loop = AsyncIOLoop()
        return io_loop

    def test_asyncio_callback(self):
        # Basic test that the asyncio loop is set up correctly.
        asyncio.get_event_loop().call_soon(self.stop)
        self.wait()

    @gen_test
    def test_asyncio_future(self):
        # Test that we can yield an asyncio future from a tornado coroutine.
        # Without 'yield from', we must wrap coroutines in ensure_future,
        # which was introduced during Python 3.4, deprecating the prior "async".
        if hasattr(asyncio, "ensure_future"):
            ensure_future = asyncio.ensure_future
        else:
            # async is a reserved word in Python 3.7
            ensure_future = getattr(asyncio, "async")
        x = yield ensure_future(
            asyncio.get_event_loop().run_in_executor(None, lambda: 42)
        )
        self.assertEqual(x, 42)

    @gen_test
    def test_asyncio_yield_from(self):
        # 'yield from' an asyncio future directly inside a gen.coroutine
        # generator also works.
        @gen.coroutine
        def f():
            event_loop = asyncio.get_event_loop()
            x = yield from event_loop.run_in_executor(None, lambda: 42)
            return x

        result = yield f()
        self.assertEqual(result, 42)

    def test_asyncio_adapter(self):
        # This test demonstrates that when using the asyncio coroutine
        # runner (i.e. run_until_complete), the to_asyncio_future
        # adapter is needed. No adapter is needed in the other direction,
        # as demonstrated by other tests in the package.
        @gen.coroutine
        def tornado_coroutine():
            yield gen.moment
            raise gen.Return(42)

        async def native_coroutine_without_adapter():
            return await tornado_coroutine()

        async def native_coroutine_with_adapter():
            return await to_asyncio_future(tornado_coroutine())

        # Use the adapter, but two degrees from the tornado coroutine.
        async def native_coroutine_with_adapter2():
            return await to_asyncio_future(native_coroutine_without_adapter())

        # Tornado supports native coroutines both with and without adapters
        self.assertEqual(self.io_loop.run_sync(native_coroutine_without_adapter), 42)
        self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter), 42)
        self.assertEqual(self.io_loop.run_sync(native_coroutine_with_adapter2), 42)

        # Asyncio only supports coroutines that yield asyncio-compatible
        # Futures (which our Future is since 5.0).
        self.assertEqual(
            asyncio.get_event_loop().run_until_complete(
                native_coroutine_without_adapter()
            ),
            42,
        )
        self.assertEqual(
            asyncio.get_event_loop().run_until_complete(
                native_coroutine_with_adapter()
            ),
            42,
        )
        self.assertEqual(
            asyncio.get_event_loop().run_until_complete(
                native_coroutine_with_adapter2()
            ),
            42,
        )
class LeakTest(unittest.TestCase):
    """Verifies that IOLoop._ioloop_for_asyncio does not accumulate entries."""

    def setUp(self):
        # Trigger a cleanup of the mapping so we start with a clean slate.
        AsyncIOLoop().close()
        # If we don't clean up after ourselves other tests may fail on
        # py34.
        self.orig_policy = asyncio.get_event_loop_policy()
        asyncio.set_event_loop_policy(asyncio.DefaultEventLoopPolicy())

    def tearDown(self):
        asyncio.get_event_loop().close()
        asyncio.set_event_loop_policy(self.orig_policy)

    def test_ioloop_close_leak(self):
        orig_count = len(IOLoop._ioloop_for_asyncio)
        for i in range(10):
            # Create and close an AsyncIOLoop using Tornado interfaces.
            loop = AsyncIOLoop()
            loop.close()
        new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
        # Closing through Tornado should leave no new mapping entries.
        self.assertEqual(new_count, 0)

    def test_asyncio_close_leak(self):
        orig_count = len(IOLoop._ioloop_for_asyncio)
        for i in range(10):
            # Create and close an AsyncIOMainLoop using asyncio interfaces.
            loop = asyncio.new_event_loop()
            loop.call_soon(IOLoop.current)
            loop.call_soon(loop.stop)
            loop.run_forever()
            loop.close()
        new_count = len(IOLoop._ioloop_for_asyncio) - orig_count
        # Because the cleanup is run on new loop creation, we have one
        # dangling entry in the map (but only one).
        self.assertEqual(new_count, 1)
class AnyThreadEventLoopPolicyTest(unittest.TestCase):
    """Tests that AnyThreadEventLoopPolicy provides loops on non-main threads."""

    def setUp(self):
        self.orig_policy = asyncio.get_event_loop_policy()
        # A single worker thread so every submitted call runs on the same
        # non-main thread.
        self.executor = ThreadPoolExecutor(1)

    def tearDown(self):
        asyncio.set_event_loop_policy(self.orig_policy)
        self.executor.shutdown()

    def get_event_loop_on_thread(self):
        def get_and_close_event_loop():
            """Get the event loop. Close it if one is returned.

            Returns the (closed) event loop. This is a silly thing
            to do and leaves the thread in a broken state, but it's
            enough for this test. Closing the loop avoids resource
            leak warnings.
            """
            loop = asyncio.get_event_loop()
            loop.close()
            return loop

        future = self.executor.submit(get_and_close_event_loop)
        return future.result()

    def run_policy_test(self, accessor, expected_type):
        # With the default policy, non-main threads don't get an event
        # loop.
        self.assertRaises(
            (RuntimeError, AssertionError), self.executor.submit(accessor).result
        )
        # Set the policy and we can get a loop.
        asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())
        self.assertIsInstance(self.executor.submit(accessor).result(), expected_type)
        # Clean up to silence leak warnings. Always use asyncio since
        # IOLoop doesn't (currently) close the underlying loop.
        self.executor.submit(lambda: asyncio.get_event_loop().close()).result()  # type: ignore

    def test_asyncio_accessor(self):
        self.run_policy_test(asyncio.get_event_loop, asyncio.AbstractEventLoop)

    def test_tornado_accessor(self):
        self.run_policy_test(IOLoop.current, IOLoop)

View file

@@ -0,0 +1,609 @@
# These tests do not currently do much to verify the correct implementation
# of the openid/oauth protocols, they just exercise the major code paths
# and ensure that it doesn't blow up (e.g. with unicode/bytes issues in
# python 3)
import unittest
from tornado.auth import (
OpenIdMixin,
OAuthMixin,
OAuth2Mixin,
GoogleOAuth2Mixin,
FacebookGraphMixin,
TwitterMixin,
)
from tornado.escape import json_decode
from tornado import gen
from tornado.httpclient import HTTPClientError
from tornado.httputil import url_concat
from tornado.log import app_log
from tornado.testing import AsyncHTTPTestCase, ExpectLog
from tornado.web import RequestHandler, Application, HTTPError
try:
from unittest import mock
except ImportError:
mock = None # type: ignore
class OpenIdClientLoginHandler(RequestHandler, OpenIdMixin):
    """Client handler exercising the OpenIdMixin login flow."""

    def initialize(self, test):
        # Point the mixin at the simulated OpenID provider.
        self._OPENID_ENDPOINT = test.get_url("/openid/server/authenticate")

    @gen.coroutine
    def get(self):
        if self.get_argument("openid.mode", None):
            # Callback leg: validate the provider's response.
            user = yield self.get_authenticated_user(
                http_client=self.settings["http_client"]
            )
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        # Initial leg: redirect the browser to the provider.
        res = self.authenticate_redirect()  # type: ignore
        assert res is None
class OpenIdServerAuthenticateHandler(RequestHandler):
    """Simulated OpenID provider endpoint that always validates the request."""

    def post(self):
        mode = self.get_argument("openid.mode")
        if mode != "check_authentication":
            # Bug fix: the original raised "incorrect openid.mode %r" without
            # supplying the %r argument, so the message never showed the mode.
            raise Exception("incorrect openid.mode %r" % mode)
        self.write("is_valid:true")
class OAuth1ClientLoginHandler(RequestHandler, OAuthMixin):
    """Client handler exercising the OAuthMixin (OAuth 1.0/1.0a) login flow."""

    def initialize(self, test, version):
        # version is "1.0" or "1.0a"; endpoints point at the simulated server.
        self._OAUTH_VERSION = version
        self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
        self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
        self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/oauth1/server/access_token")

    def _oauth_consumer_token(self):
        # Static consumer credentials used for request signing.
        return dict(key="asdf", secret="qwer")

    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            # Callback leg: exchange the request token for an access token.
            user = yield self.get_authenticated_user(
                http_client=self.settings["http_client"]
            )
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        # Initial leg: fetch a request token and redirect to authorize.
        yield self.authorize_redirect(http_client=self.settings["http_client"])

    @gen.coroutine
    def _oauth_get_user_future(self, access_token):
        # Optional failure injection for the error-propagation tests.
        if self.get_argument("fail_in_get_user", None):
            raise Exception("failing in get_user")
        if access_token != dict(key="uiop", secret="5678"):
            raise Exception("incorrect access token %r" % access_token)
        return dict(email="foo@example.com")
class OAuth1ClientLoginCoroutineHandler(OAuth1ClientLoginHandler):
    """Replaces OAuth1ClientLoginHandler's get() with a coroutine."""

    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            # Ensure that any exceptions are set on the returned Future,
            # not simply thrown into the surrounding StackContext.
            try:
                yield self.get_authenticated_user()
            except Exception as e:
                self.set_status(503)
                self.write("got exception: %s" % e)
        else:
            yield self.authorize_redirect()
class OAuth1ClientRequestParametersHandler(RequestHandler, OAuthMixin):
    """Returns the signed OAuth 1 request parameters as a JSON object."""

    def initialize(self, version):
        self._OAUTH_VERSION = version

    def _oauth_consumer_token(self):
        # Static consumer credentials used for request signing.
        return dict(key="asdf", secret="qwer")

    def get(self):
        access_token = dict(key="uiop", secret="5678")
        signed_params = self._oauth_request_parameters(
            "http://www.example.com/api/asdf",
            access_token,
            parameters=dict(foo="bar"),
        )
        self.write(signed_params)
class OAuth1ServerRequestTokenHandler(RequestHandler):
    """Simulated server endpoint issuing a fixed OAuth 1 request token."""

    def get(self):
        body = "oauth_token=zxcv&oauth_token_secret=1234"
        self.write(body)
class OAuth1ServerAccessTokenHandler(RequestHandler):
    """Simulated server endpoint issuing a fixed OAuth 1 access token."""

    def get(self):
        body = "oauth_token=uiop&oauth_token_secret=5678"
        self.write(body)
class OAuth2ClientLoginHandler(RequestHandler, OAuth2Mixin):
    """Client handler exercising the OAuth2Mixin redirect."""

    def initialize(self, test):
        self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth2/server/authorize")

    def get(self):
        # authorize_redirect is synchronous here and returns nothing.
        redirect_result = self.authorize_redirect()  # type: ignore
        assert redirect_result is None
class FacebookClientLoginHandler(RequestHandler, FacebookGraphMixin):
    """Client handler exercising the FacebookGraphMixin login flow."""

    def initialize(self, test):
        # Point the mixin at the simulated Facebook endpoints.
        self._OAUTH_AUTHORIZE_URL = test.get_url("/facebook/server/authorize")
        self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/facebook/server/access_token")
        self._FACEBOOK_BASE_URL = test.get_url("/facebook/server")

    @gen.coroutine
    def get(self):
        if self.get_argument("code", None):
            # Callback leg: exchange the code for a token and user info.
            user = yield self.get_authenticated_user(
                redirect_uri=self.request.full_url(),
                client_id=self.settings["facebook_api_key"],
                client_secret=self.settings["facebook_secret"],
                code=self.get_argument("code"),
            )
            self.write(user)
        else:
            # Initial leg: redirect the browser to the provider.
            self.authorize_redirect(
                redirect_uri=self.request.full_url(),
                client_id=self.settings["facebook_api_key"],
                extra_params={"scope": "read_stream,offline_access"},
            )
class FacebookServerAccessTokenHandler(RequestHandler):
    """Simulated Facebook token endpoint returning a fixed access token."""

    def get(self):
        token_payload = {"access_token": "asdf", "expires_in": 3600}
        self.write(token_payload)
class FacebookServerMeHandler(RequestHandler):
    """Simulated Facebook /me endpoint returning an empty JSON object."""

    def get(self):
        empty_user = "{}"
        self.write(empty_user)
class TwitterClientHandler(RequestHandler, TwitterMixin):
    """Base class for Twitter client handlers; wires TwitterMixin to the
    simulated server endpoints."""

    def initialize(self, test):
        self._OAUTH_REQUEST_TOKEN_URL = test.get_url("/oauth1/server/request_token")
        self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/twitter/server/access_token")
        self._OAUTH_AUTHORIZE_URL = test.get_url("/oauth1/server/authorize")
        self._OAUTH_AUTHENTICATE_URL = test.get_url("/twitter/server/authenticate")
        self._TWITTER_BASE_URL = test.get_url("/twitter/api")

    def get_auth_http_client(self):
        # Use the HTTP client supplied in the application settings.
        return self.settings["http_client"]
class TwitterClientLoginHandler(TwitterClientHandler):
    """Twitter login flow using authorize_redirect."""

    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            # Callback leg: finish the login and return the user as JSON.
            user = yield self.get_authenticated_user()
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        yield self.authorize_redirect()
class TwitterClientAuthenticateHandler(TwitterClientHandler):
    # Like TwitterClientLoginHandler, but uses authenticate_redirect
    # instead of authorize_redirect.
    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            user = yield self.get_authenticated_user()
            if user is None:
                raise Exception("user is None")
            self.finish(user)
            return
        yield self.authenticate_redirect()
class TwitterClientLoginGenCoroutineHandler(TwitterClientHandler):
    """Twitter login flow written fully in @gen.coroutine style."""

    @gen.coroutine
    def get(self):
        if self.get_argument("oauth_token", None):
            user = yield self.get_authenticated_user()
            self.finish(user)
        else:
            # New style: with @gen.coroutine the result must be yielded
            # or else the request will be auto-finished too soon.
            yield self.authorize_redirect()
class TwitterClientShowUserHandler(TwitterClientHandler):
    """Fetches a Twitter user profile via twitter_request."""

    @gen.coroutine
    def get(self):
        # TODO: would be nice to go through the login flow instead of
        # cheating with a hard-coded access token.
        try:
            response = yield self.twitter_request(
                "/users/show/%s" % self.get_argument("name"),
                access_token=dict(key="hjkl", secret="vbnm"),
            )
        except HTTPClientError:
            # TODO(bdarnell): Should we catch HTTP errors and
            # transform some of them (like 403s) into AuthError?
            self.set_status(500)
            self.finish("error from twitter request")
        else:
            self.finish(response)
class TwitterServerAccessTokenHandler(RequestHandler):
    """Simulated Twitter access-token endpoint returning fixed credentials."""

    def get(self):
        body = "oauth_token=hjkl&oauth_token_secret=vbnm&screen_name=foo"
        self.write(body)
class TwitterServerShowUserHandler(RequestHandler):
    """Simulated users/show endpoint; verifies the OAuth 1 signature fields."""

    def get(self, screen_name):
        # The name "error" triggers a server error for the error-path test.
        if screen_name == "error":
            raise HTTPError(500)
        assert "oauth_nonce" in self.request.arguments
        assert "oauth_timestamp" in self.request.arguments
        assert "oauth_signature" in self.request.arguments
        assert self.get_argument("oauth_consumer_key") == "test_twitter_consumer_key"
        assert self.get_argument("oauth_signature_method") == "HMAC-SHA1"
        assert self.get_argument("oauth_version") == "1.0"
        assert self.get_argument("oauth_token") == "hjkl"
        self.write(dict(screen_name=screen_name, name=screen_name.capitalize()))
class TwitterServerVerifyCredentialsHandler(RequestHandler):
    """Simulated verify_credentials endpoint; checks OAuth 1 signature fields."""

    def get(self):
        for required in ("oauth_nonce", "oauth_timestamp", "oauth_signature"):
            assert required in self.request.arguments
        expected_arguments = {
            "oauth_consumer_key": "test_twitter_consumer_key",
            "oauth_signature_method": "HMAC-SHA1",
            "oauth_version": "1.0",
            "oauth_token": "hjkl",
        }
        for name, value in expected_arguments.items():
            assert self.get_argument(name) == value
        self.write(dict(screen_name="foo", name="Foo"))
class AuthTest(AsyncHTTPTestCase):
    """End-to-end tests for the auth mixins against the simulated provider
    servers defined above; client and server handlers share one Application."""

    def get_app(self):
        return Application(
            [
                # test endpoints
                ("/openid/client/login", OpenIdClientLoginHandler, dict(test=self)),
                (
                    "/oauth10/client/login",
                    OAuth1ClientLoginHandler,
                    dict(test=self, version="1.0"),
                ),
                (
                    "/oauth10/client/request_params",
                    OAuth1ClientRequestParametersHandler,
                    dict(version="1.0"),
                ),
                (
                    "/oauth10a/client/login",
                    OAuth1ClientLoginHandler,
                    dict(test=self, version="1.0a"),
                ),
                (
                    "/oauth10a/client/login_coroutine",
                    OAuth1ClientLoginCoroutineHandler,
                    dict(test=self, version="1.0a"),
                ),
                (
                    "/oauth10a/client/request_params",
                    OAuth1ClientRequestParametersHandler,
                    dict(version="1.0a"),
                ),
                ("/oauth2/client/login", OAuth2ClientLoginHandler, dict(test=self)),
                ("/facebook/client/login", FacebookClientLoginHandler, dict(test=self)),
                ("/twitter/client/login", TwitterClientLoginHandler, dict(test=self)),
                (
                    "/twitter/client/authenticate",
                    TwitterClientAuthenticateHandler,
                    dict(test=self),
                ),
                (
                    "/twitter/client/login_gen_coroutine",
                    TwitterClientLoginGenCoroutineHandler,
                    dict(test=self),
                ),
                (
                    "/twitter/client/show_user",
                    TwitterClientShowUserHandler,
                    dict(test=self),
                ),
                # simulated servers
                ("/openid/server/authenticate", OpenIdServerAuthenticateHandler),
                ("/oauth1/server/request_token", OAuth1ServerRequestTokenHandler),
                ("/oauth1/server/access_token", OAuth1ServerAccessTokenHandler),
                ("/facebook/server/access_token", FacebookServerAccessTokenHandler),
                ("/facebook/server/me", FacebookServerMeHandler),
                ("/twitter/server/access_token", TwitterServerAccessTokenHandler),
                (r"/twitter/api/users/show/(.*)\.json", TwitterServerShowUserHandler),
                (
                    r"/twitter/api/account/verify_credentials\.json",
                    TwitterServerVerifyCredentialsHandler,
                ),
            ],
            http_client=self.http_client,
            twitter_consumer_key="test_twitter_consumer_key",
            twitter_consumer_secret="test_twitter_consumer_secret",
            facebook_api_key="test_facebook_api_key",
            facebook_secret="test_facebook_secret",
        )

    def test_openid_redirect(self):
        response = self.fetch("/openid/client/login", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue("/openid/server/authenticate?" in response.headers["Location"])

    def test_openid_get_user(self):
        response = self.fetch(
            "/openid/client/login?openid.mode=blah"
            "&openid.ns.ax=http://openid.net/srv/ax/1.0"
            "&openid.ax.type.email=http://axschema.org/contact/email"
            "&openid.ax.value.email=foo@example.com"
        )
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(parsed["email"], "foo@example.com")

    def test_oauth10_redirect(self):
        response = self.fetch("/oauth10/client/login", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue(
            response.headers["Location"].endswith(
                "/oauth1/server/authorize?oauth_token=zxcv"
            )
        )
        # the cookie is base64('zxcv')|base64('1234')
        self.assertTrue(
            '_oauth_request_token="enhjdg==|MTIzNA=="'
            in response.headers["Set-Cookie"],
            response.headers["Set-Cookie"],
        )

    def test_oauth10_get_user(self):
        response = self.fetch(
            "/oauth10/client/login?oauth_token=zxcv",
            headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
        )
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(parsed["email"], "foo@example.com")
        self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))

    def test_oauth10_request_parameters(self):
        response = self.fetch("/oauth10/client/request_params")
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(parsed["oauth_consumer_key"], "asdf")
        self.assertEqual(parsed["oauth_token"], "uiop")
        self.assertTrue("oauth_nonce" in parsed)
        self.assertTrue("oauth_signature" in parsed)

    def test_oauth10a_redirect(self):
        response = self.fetch("/oauth10a/client/login", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue(
            response.headers["Location"].endswith(
                "/oauth1/server/authorize?oauth_token=zxcv"
            )
        )
        # the cookie is base64('zxcv')|base64('1234')
        self.assertTrue(
            '_oauth_request_token="enhjdg==|MTIzNA=="'
            in response.headers["Set-Cookie"],
            response.headers["Set-Cookie"],
        )

    @unittest.skipIf(mock is None, "mock package not present")
    def test_oauth10a_redirect_error(self):
        # A failure fetching the request token surfaces as a 500.
        with mock.patch.object(OAuth1ServerRequestTokenHandler, "get") as get:
            get.side_effect = Exception("boom")
            with ExpectLog(app_log, "Uncaught exception"):
                response = self.fetch("/oauth10a/client/login", follow_redirects=False)
            self.assertEqual(response.code, 500)

    def test_oauth10a_get_user(self):
        response = self.fetch(
            "/oauth10a/client/login?oauth_token=zxcv",
            headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
        )
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(parsed["email"], "foo@example.com")
        self.assertEqual(parsed["access_token"], dict(key="uiop", secret="5678"))

    def test_oauth10a_request_parameters(self):
        response = self.fetch("/oauth10a/client/request_params")
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(parsed["oauth_consumer_key"], "asdf")
        self.assertEqual(parsed["oauth_token"], "uiop")
        self.assertTrue("oauth_nonce" in parsed)
        self.assertTrue("oauth_signature" in parsed)

    def test_oauth10a_get_user_coroutine_exception(self):
        # Exceptions in _oauth_get_user_future are reported via the
        # coroutine handler as a 503.
        response = self.fetch(
            "/oauth10a/client/login_coroutine?oauth_token=zxcv&fail_in_get_user=true",
            headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
        )
        self.assertEqual(response.code, 503)

    def test_oauth2_redirect(self):
        response = self.fetch("/oauth2/client/login", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue("/oauth2/server/authorize?" in response.headers["Location"])

    def test_facebook_login(self):
        response = self.fetch("/facebook/client/login", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue("/facebook/server/authorize?" in response.headers["Location"])
        response = self.fetch(
            "/facebook/client/login?code=1234", follow_redirects=False
        )
        self.assertEqual(response.code, 200)
        user = json_decode(response.body)
        self.assertEqual(user["access_token"], "asdf")
        self.assertEqual(user["session_expires"], "3600")

    def base_twitter_redirect(self, url):
        # Same as test_oauth10a_redirect
        response = self.fetch(url, follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue(
            response.headers["Location"].endswith(
                "/oauth1/server/authorize?oauth_token=zxcv"
            )
        )
        # the cookie is base64('zxcv')|base64('1234')
        self.assertTrue(
            '_oauth_request_token="enhjdg==|MTIzNA=="'
            in response.headers["Set-Cookie"],
            response.headers["Set-Cookie"],
        )

    def test_twitter_redirect(self):
        self.base_twitter_redirect("/twitter/client/login")

    def test_twitter_redirect_gen_coroutine(self):
        self.base_twitter_redirect("/twitter/client/login_gen_coroutine")

    def test_twitter_authenticate_redirect(self):
        response = self.fetch("/twitter/client/authenticate", follow_redirects=False)
        self.assertEqual(response.code, 302)
        self.assertTrue(
            response.headers["Location"].endswith(
                "/twitter/server/authenticate?oauth_token=zxcv"
            ),
            response.headers["Location"],
        )
        # the cookie is base64('zxcv')|base64('1234')
        self.assertTrue(
            '_oauth_request_token="enhjdg==|MTIzNA=="'
            in response.headers["Set-Cookie"],
            response.headers["Set-Cookie"],
        )

    def test_twitter_get_user(self):
        response = self.fetch(
            "/twitter/client/login?oauth_token=zxcv",
            headers={"Cookie": "_oauth_request_token=enhjdg==|MTIzNA=="},
        )
        response.rethrow()
        parsed = json_decode(response.body)
        self.assertEqual(
            parsed,
            {
                u"access_token": {
                    u"key": u"hjkl",
                    u"screen_name": u"foo",
                    u"secret": u"vbnm",
                },
                u"name": u"Foo",
                u"screen_name": u"foo",
                u"username": u"foo",
            },
        )

    def test_twitter_show_user(self):
        response = self.fetch("/twitter/client/show_user?name=somebody")
        response.rethrow()
        self.assertEqual(
            json_decode(response.body), {"name": "Somebody", "screen_name": "somebody"}
        )

    def test_twitter_show_user_error(self):
        response = self.fetch("/twitter/client/show_user?name=error")
        self.assertEqual(response.code, 500)
        self.assertEqual(response.body, b"error from twitter request")
class GoogleLoginHandler(RequestHandler, GoogleOAuth2Mixin):
    """Client handler exercising the Google OAuth2 login flow."""

    def initialize(self, test):
        self.test = test
        # Point the mixin at the simulated Google endpoints.
        self._OAUTH_REDIRECT_URI = test.get_url("/client/login")
        self._OAUTH_AUTHORIZE_URL = test.get_url("/google/oauth2/authorize")
        self._OAUTH_ACCESS_TOKEN_URL = test.get_url("/google/oauth2/token")

    @gen.coroutine
    def get(self):
        code = self.get_argument("code", None)
        if code is not None:
            # retrieve the authenticated google user
            access = yield self.get_authenticated_user(self._OAUTH_REDIRECT_URI, code)
            user = yield self.oauth2_request(
                self.test.get_url("/google/oauth2/userinfo"),
                access_token=access["access_token"],
            )
            # return the user and access token as json
            user["access_token"] = access["access_token"]
            self.write(user)
        else:
            self.authorize_redirect(
                redirect_uri=self._OAUTH_REDIRECT_URI,
                client_id=self.settings["google_oauth"]["key"],
                client_secret=self.settings["google_oauth"]["secret"],
                scope=["profile", "email"],
                response_type="code",
                extra_params={"prompt": "select_account"},
            )
class GoogleOAuth2AuthorizeHandler(RequestHandler):
    """Simulated Google authorization endpoint."""

    def get(self):
        # issue a fake auth code and redirect to redirect_uri
        redirect_uri = self.get_argument("redirect_uri")
        self.redirect(url_concat(redirect_uri, dict(code="fake-authorization-code")))
class GoogleOAuth2TokenHandler(RequestHandler):
    """Simulated Google token endpoint; swaps the fake auth code for a token."""

    def post(self):
        assert self.get_argument("code") == "fake-authorization-code"
        # issue a fake token
        token_response = {
            "access_token": "fake-access-token",
            "expires_in": "never-expires",
        }
        self.finish(token_response)
class GoogleOAuth2UserinfoHandler(RequestHandler):
    """Simulated Google userinfo endpoint; requires the fake access token."""

    def get(self):
        assert self.get_argument("access_token") == "fake-access-token"
        # return a fake user
        fake_user = {"name": "Foo", "email": "foo@example.com"}
        self.finish(fake_user)
class GoogleOAuth2Test(AsyncHTTPTestCase):
    """Runs the Google OAuth2 login flow against simulated Google endpoints."""

    def get_app(self):
        return Application(
            [
                # test endpoints
                ("/client/login", GoogleLoginHandler, dict(test=self)),
                # simulated google authorization server endpoints
                ("/google/oauth2/authorize", GoogleOAuth2AuthorizeHandler),
                ("/google/oauth2/token", GoogleOAuth2TokenHandler),
                ("/google/oauth2/userinfo", GoogleOAuth2UserinfoHandler),
            ],
            google_oauth={
                "key": "fake_google_client_id",
                "secret": "fake_google_client_secret",
            },
        )

    def test_google_login(self):
        # Follows the full redirect chain: authorize -> callback -> token
        # exchange -> userinfo.
        response = self.fetch("/client/login")
        self.assertDictEqual(
            {
                u"name": u"Foo",
                u"email": u"foo@example.com",
                u"access_token": u"fake-access-token",
            },
            json_decode(response.body),
        )

View file

@@ -0,0 +1,127 @@
import os
import shutil
import subprocess
from subprocess import Popen
import sys
from tempfile import mkdtemp
import time
import unittest
class AutoreloadTest(unittest.TestCase):
    """Tests tornado.autoreload by running generated apps in subprocesses."""

    def setUp(self):
        # Each test builds a throwaway application package in a temp dir.
        self.path = mkdtemp()

    def tearDown(self):
        try:
            shutil.rmtree(self.path)
        except OSError:
            # Windows disallows deleting files that are in use by
            # another process, and even though we've waited for our
            # child process below, it appears that its lock on these
            # files is not guaranteed to be released by this point.
            # Sleep and try again (once).
            time.sleep(1)
            shutil.rmtree(self.path)

    def test_reload_module(self):
        # The generated app prints 'Starting' and reloads itself once on
        # the first run, so a successful reload prints it twice.
        main = """\
import os
import sys

from tornado import autoreload

# This import will fail if path is not set up correctly
import testapp

print('Starting')
if 'TESTAPP_STARTED' not in os.environ:
    os.environ['TESTAPP_STARTED'] = '1'
    sys.stdout.flush()
    autoreload._reload()
"""

        # Create temporary test application
        os.mkdir(os.path.join(self.path, "testapp"))
        open(os.path.join(self.path, "testapp/__init__.py"), "w").close()
        with open(os.path.join(self.path, "testapp/__main__.py"), "w") as f:
            f.write(main)

        # Make sure the tornado module under test is available to the test
        # application
        pythonpath = os.getcwd()
        if "PYTHONPATH" in os.environ:
            pythonpath += os.pathsep + os.environ["PYTHONPATH"]

        p = Popen(
            [sys.executable, "-m", "testapp"],
            stdout=subprocess.PIPE,
            cwd=self.path,
            env=dict(os.environ, PYTHONPATH=pythonpath),
            universal_newlines=True,
        )
        out = p.communicate()[0]
        self.assertEqual(out, "Starting\nStarting\n")

    def test_reload_wrapper_preservation(self):
        # This test verifies that when `python -m tornado.autoreload`
        # is used on an application that also has an internal
        # autoreload, the reload wrapper is preserved on restart.
        main = """\
import os
import sys

# This import will fail if path is not set up correctly
import testapp

if 'tornado.autoreload' not in sys.modules:
    raise Exception('started without autoreload wrapper')

import tornado.autoreload

print('Starting')
sys.stdout.flush()
if 'TESTAPP_STARTED' not in os.environ:
    os.environ['TESTAPP_STARTED'] = '1'
    # Simulate an internal autoreload (one not caused
    # by the wrapper).
    tornado.autoreload._reload()
else:
    # Exit directly so autoreload doesn't catch it.
    os._exit(0)
"""

        # Create temporary test application
        os.mkdir(os.path.join(self.path, "testapp"))
        init_file = os.path.join(self.path, "testapp", "__init__.py")
        open(init_file, "w").close()
        main_file = os.path.join(self.path, "testapp", "__main__.py")
        with open(main_file, "w") as f:
            f.write(main)

        # Make sure the tornado module under test is available to the test
        # application
        pythonpath = os.getcwd()
        if "PYTHONPATH" in os.environ:
            pythonpath += os.pathsep + os.environ["PYTHONPATH"]

        autoreload_proc = Popen(
            [sys.executable, "-m", "tornado.autoreload", "-m", "testapp"],
            stdout=subprocess.PIPE,
            cwd=self.path,
            env=dict(os.environ, PYTHONPATH=pythonpath),
            universal_newlines=True,
        )

        # This timeout needs to be fairly generous for pypy due to jit
        # warmup costs.
        for i in range(40):
            if autoreload_proc.poll() is not None:
                break
            time.sleep(0.1)
        else:
            autoreload_proc.kill()
            raise Exception("subprocess failed to terminate")

        out = autoreload_proc.communicate()[0]
        self.assertEqual(out, "Starting\n" * 2)

View file

@@ -0,0 +1,212 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from concurrent import futures
import logging
import re
import socket
import typing
import unittest
from tornado.concurrent import (
Future,
run_on_executor,
future_set_result_unless_cancelled,
)
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class MiscFutureTest(AsyncTestCase):
    """Tests for small helpers in tornado.concurrent."""

    def test_future_set_result_unless_cancelled(self):
        # Normal case: the result is set on a live future.
        fut = Future()  # type: Future[int]
        future_set_result_unless_cancelled(fut, 42)
        self.assertEqual(fut.result(), 42)
        self.assertFalse(fut.cancelled())

        # Cancelled case: setting the result must not raise, and the
        # cancelled state must be preserved.
        fut = Future()
        fut.cancel()
        is_cancelled = fut.cancelled()
        future_set_result_unless_cancelled(fut, 42)
        self.assertEqual(fut.cancelled(), is_cancelled)
        if not is_cancelled:
            self.assertEqual(fut.result(), 42)
# The following series of classes demonstrate and test various styles
# of use, with and without generators and futures.
class CapServer(TCPServer):
    """TCP server that capitalizes one newline-terminated line per connection.

    Replies "ok\\t<UPPER>" on success or "error\\talready capitalized" if the
    input was already upper-case, then closes the stream.
    """

    @gen.coroutine
    def handle_stream(self, stream, address):
        data = yield stream.read_until(b"\n")
        data = to_unicode(data)
        if data == data.upper():
            stream.write(b"error\talready capitalized\n")
        else:
            # data already has \n
            stream.write(utf8("ok\t%s" % data.upper()))
        stream.close()
class CapError(Exception):
    """Error reported by the capitalization server."""
class BaseCapClient(object):
    """Shared client logic: remembers the server port and parses responses."""

    def __init__(self, port):
        self.port = port

    def process_response(self, data):
        """Parse a tab-separated status/message response line.

        Returns the message when the status is "ok"; raises CapError with
        the message otherwise, or a generic Exception if the line does not
        match the expected format.
        """
        # Idiom fix: use a raw string for the regex so the backslash
        # escapes are unambiguous (same pattern as before).
        m = re.match(r"(.*)\t(.*)\n", to_unicode(data))
        if m is None:
            raise Exception("did not match")
        status, message = m.groups()
        if status == "ok":
            return message
        else:
            raise CapError(message)
class GeneratorCapClient(BaseCapClient):
    """Client implemented with a @gen.coroutine generator."""

    @gen.coroutine
    def capitalize(self, request_data):
        logging.debug("capitalize")
        stream = IOStream(socket.socket())
        logging.debug("connecting")
        yield stream.connect(("127.0.0.1", self.port))
        stream.write(utf8(request_data + "\n"))
        logging.debug("reading")
        data = yield stream.read_until(b"\n")
        logging.debug("returning")
        stream.close()
        raise gen.Return(self.process_response(data))
class ClientTestMixin(object):
    """Common tests run against each client style.

    Concrete subclasses set ``client_class`` and mix in AsyncTestCase.
    """

    # Set by subclasses to the client implementation under test.
    client_class = None  # type: typing.Callable

    def setUp(self):
        super().setUp()  # type: ignore
        self.server = CapServer()
        sock, port = bind_unused_port()
        self.server.add_sockets([sock])
        self.client = self.client_class(port=port)

    def tearDown(self):
        self.server.stop()
        super().tearDown()  # type: ignore

    def test_future(self: typing.Any):
        future = self.client.capitalize("hello")
        self.io_loop.add_future(future, self.stop)
        self.wait()
        self.assertEqual(future.result(), "HELLO")

    def test_future_error(self: typing.Any):
        # Already-capitalized input produces a CapError on the future.
        future = self.client.capitalize("HELLO")
        self.io_loop.add_future(future, self.stop)
        self.wait()
        self.assertRaisesRegexp(CapError, "already capitalized", future.result)  # type: ignore

    def test_generator(self: typing.Any):
        @gen.coroutine
        def f():
            result = yield self.client.capitalize("hello")
            self.assertEqual(result, "HELLO")

        self.io_loop.run_sync(f)

    def test_generator_error(self: typing.Any):
        @gen.coroutine
        def f():
            with self.assertRaisesRegexp(CapError, "already capitalized"):
                yield self.client.capitalize("HELLO")

        self.io_loop.run_sync(f)
class GeneratorClientTest(ClientTestMixin, AsyncTestCase):
    """Runs the ClientTestMixin suite against GeneratorCapClient."""

    client_class = GeneratorCapClient
class RunOnExecutorTest(AsyncTestCase):
    """Tests the run_on_executor decorator in its various invocation forms."""

    @gen_test
    def test_no_calling(self):
        # Decorator applied bare (without parentheses).
        class Object(object):
            def __init__(self):
                self.executor = futures.thread.ThreadPoolExecutor(1)

            @run_on_executor
            def f(self):
                return 42

        o = Object()
        answer = yield o.f()
        self.assertEqual(answer, 42)

    @gen_test
    def test_call_with_no_args(self):
        # Decorator called with no arguments.
        class Object(object):
            def __init__(self):
                self.executor = futures.thread.ThreadPoolExecutor(1)

            @run_on_executor()
            def f(self):
                return 42

        o = Object()
        answer = yield o.f()
        self.assertEqual(answer, 42)

    @gen_test
    def test_call_with_executor(self):
        # Decorator pointed at a non-default executor attribute (here the
        # name-mangled form of a private attribute).
        class Object(object):
            def __init__(self):
                self.__executor = futures.thread.ThreadPoolExecutor(1)

            @run_on_executor(executor="_Object__executor")
            def f(self):
                return 42

        o = Object()
        answer = yield o.f()
        self.assertEqual(answer, 42)

    @gen_test
    def test_async_await(self):
        class Object(object):
            def __init__(self):
                self.executor = futures.thread.ThreadPoolExecutor(1)

            @run_on_executor()
            def f(self):
                return 42

        o = Object()

        # The decorated method's result is awaitable from a native coroutine.
        async def f():
            answer = await o.f()
            return answer

        result = yield f()
        self.assertEqual(result, 42)
# Allow running this test module directly (python -m ...).
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1 @@
"school","école"
1 school école

View file

@ -0,0 +1,129 @@
from hashlib import md5
import unittest
from tornado.escape import utf8
from tornado.testing import AsyncHTTPTestCase
from tornado.test import httpclient_test
from tornado.web import Application, RequestHandler
try:
import pycurl
except ImportError:
pycurl = None # type: ignore
if pycurl is not None:
from tornado.curl_httpclient import CurlAsyncHTTPClient
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
    """Runs the shared HTTP client test suite against the curl-based client."""

    def get_http_client(self):
        client = CurlAsyncHTTPClient(defaults=dict(allow_ipv6=False))
        # make sure AsyncHTTPClient magic doesn't give us the wrong class
        self.assertTrue(isinstance(client, CurlAsyncHTTPClient))
        return client
class DigestAuthHandler(RequestHandler):
    """Minimal server-side HTTP digest auth for tests.

    Implements only the bare RFC 2617 scheme (no qop/cnonce) with a fixed
    nonce; just enough to validate the client's digest support.
    """

    def initialize(self, username, password):
        # Expected credentials, supplied per-route via the Application spec.
        self.username = username
        self.password = password

    def get(self):
        realm = "test"
        opaque = "asdf"
        # Real implementations would use a random nonce.
        nonce = "1234"
        auth_header = self.request.headers.get("Authorization", None)
        if auth_header is not None:
            auth_mode, params = auth_header.split(" ", 1)
            assert auth_mode == "Digest"
            # Parse the comma-separated key="value" parameter list.
            # NOTE(review): naive split(",") would break on quoted commas;
            # acceptable here because the test client never sends any.
            param_dict = {}
            for pair in params.split(","):
                k, v = pair.strip().split("=", 1)
                if v[0] == '"' and v[-1] == '"':
                    v = v[1:-1]
                param_dict[k] = v
            assert param_dict["realm"] == realm
            assert param_dict["opaque"] == opaque
            assert param_dict["nonce"] == nonce
            assert param_dict["username"] == self.username
            assert param_dict["uri"] == self.request.path
            # HA1, HA2 and response per RFC 2617 section 3.2.2.
            h1 = md5(
                utf8("%s:%s:%s" % (self.username, realm, self.password))
            ).hexdigest()
            h2 = md5(
                utf8("%s:%s" % (self.request.method, self.request.path))
            ).hexdigest()
            digest = md5(utf8("%s:%s:%s" % (h1, nonce, h2))).hexdigest()
            if digest == param_dict["response"]:
                self.write("ok")
            else:
                self.write("fail")
        else:
            # No credentials yet: challenge the client.
            self.set_status(401)
            self.set_header(
                "WWW-Authenticate",
                'Digest realm="%s", nonce="%s", opaque="%s"' % (realm, nonce, opaque),
            )
class CustomReasonHandler(RequestHandler):
    # 200 response with a non-standard reason phrase; clients should
    # surface the custom reason string.
    def get(self):
        self.set_status(200, "Custom reason")
class CustomFailReasonHandler(RequestHandler):
    # Error status with a custom reason phrase; clients should include it
    # in the HTTPError message.
    def get(self):
        self.set_status(400, "Custom reason")
@unittest.skipIf(pycurl is None, "pycurl module not present")
class CurlHTTPClientTestCase(AsyncHTTPTestCase):
    """curl-specific client tests: digest auth and custom reason phrases."""

    def setUp(self):
        super().setUp()
        self.http_client = self.create_client()

    def get_app(self):
        return Application(
            [
                ("/digest", DigestAuthHandler, {"username": "foo", "password": "bar"}),
                (
                    "/digest_non_ascii",
                    DigestAuthHandler,
                    {"username": "foo", "password": "barユ£"},
                ),
                ("/custom_reason", CustomReasonHandler),
                ("/custom_fail_reason", CustomFailReasonHandler),
            ]
        )

    def create_client(self, **kwargs):
        # force_instance bypasses the AsyncHTTPClient instance cache so each
        # test gets an isolated client.
        return CurlAsyncHTTPClient(
            force_instance=True, defaults=dict(allow_ipv6=False), **kwargs
        )

    def test_digest_auth(self):
        response = self.fetch(
            "/digest", auth_mode="digest", auth_username="foo", auth_password="bar"
        )
        self.assertEqual(response.body, b"ok")

    def test_custom_reason(self):
        response = self.fetch("/custom_reason")
        self.assertEqual(response.reason, "Custom reason")

    def test_fail_custom_reason(self):
        response = self.fetch("/custom_fail_reason")
        self.assertEqual(str(response.error), "HTTP 400: Custom reason")

    def test_digest_auth_non_ascii(self):
        # Non-ASCII passwords must survive the digest computation.
        response = self.fetch(
            "/digest_non_ascii",
            auth_mode="digest",
            auth_username="foo",
            auth_password="barユ£",
        )
        self.assertEqual(response.body, b"ok")

View file

@ -0,0 +1,322 @@
import unittest
import tornado.escape
from tornado.escape import (
utf8,
xhtml_escape,
xhtml_unescape,
url_escape,
url_unescape,
to_unicode,
json_decode,
json_encode,
squeeze,
recursive_unicode,
)
from tornado.util import unicode_type
from typing import List, Tuple, Union, Dict, Any # noqa: F401
linkify_tests = [
# (input, linkify_kwargs, expected_output)
(
"hello http://world.com/!",
{},
u'hello <a href="http://world.com/">http://world.com/</a>!',
),
(
"hello http://world.com/with?param=true&stuff=yes",
{},
u'hello <a href="http://world.com/with?param=true&amp;stuff=yes">http://world.com/with?param=true&amp;stuff=yes</a>', # noqa: E501
),
# an opened paren followed by many chars killed Gruber's regex
(
"http://url.com/w(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa",
{},
u'<a href="http://url.com/w">http://url.com/w</a>(aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa', # noqa: E501
),
# as did too many dots at the end
(
"http://url.com/withmany.......................................",
{},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>.......................................', # noqa: E501
),
(
"http://url.com/withmany((((((((((((((((((((((((((((((((((a)",
{},
u'<a href="http://url.com/withmany">http://url.com/withmany</a>((((((((((((((((((((((((((((((((((a)', # noqa: E501
),
# some examples from http://daringfireball.net/2009/11/liberal_regex_for_matching_urls
# plus a fex extras (such as multiple parentheses).
(
"http://foo.com/blah_blah",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>',
),
(
"http://foo.com/blah_blah/",
{},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>',
),
(
"(Something like http://foo.com/blah_blah)",
{},
u'(Something like <a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>)',
),
(
"http://foo.com/blah_blah_(wikipedia)",
{},
u'<a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>',
),
(
"http://foo.com/blah_(blah)_(wikipedia)_blah",
{},
u'<a href="http://foo.com/blah_(blah)_(wikipedia)_blah">http://foo.com/blah_(blah)_(wikipedia)_blah</a>', # noqa: E501
),
(
"(Something like http://foo.com/blah_blah_(wikipedia))",
{},
u'(Something like <a href="http://foo.com/blah_blah_(wikipedia)">http://foo.com/blah_blah_(wikipedia)</a>)', # noqa: E501
),
(
"http://foo.com/blah_blah.",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>.',
),
(
"http://foo.com/blah_blah/.",
{},
u'<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>.',
),
(
"<http://foo.com/blah_blah>",
{},
u'&lt;<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>&gt;',
),
(
"<http://foo.com/blah_blah/>",
{},
u'&lt;<a href="http://foo.com/blah_blah/">http://foo.com/blah_blah/</a>&gt;',
),
(
"http://foo.com/blah_blah,",
{},
u'<a href="http://foo.com/blah_blah">http://foo.com/blah_blah</a>,',
),
(
"http://www.example.com/wpstyle/?p=364.",
{},
u'<a href="http://www.example.com/wpstyle/?p=364">http://www.example.com/wpstyle/?p=364</a>.', # noqa: E501
),
(
"rdar://1234",
{"permitted_protocols": ["http", "rdar"]},
u'<a href="rdar://1234">rdar://1234</a>',
),
(
"rdar:/1234",
{"permitted_protocols": ["rdar"]},
u'<a href="rdar:/1234">rdar:/1234</a>',
),
(
"http://userid:password@example.com:8080",
{},
u'<a href="http://userid:password@example.com:8080">http://userid:password@example.com:8080</a>', # noqa: E501
),
(
"http://userid@example.com",
{},
u'<a href="http://userid@example.com">http://userid@example.com</a>',
),
(
"http://userid@example.com:8080",
{},
u'<a href="http://userid@example.com:8080">http://userid@example.com:8080</a>',
),
(
"http://userid:password@example.com",
{},
u'<a href="http://userid:password@example.com">http://userid:password@example.com</a>',
),
(
"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e",
{"permitted_protocols": ["http", "message"]},
u'<a href="message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e">'
u"message://%3c330e7f8409726r6a4ba78dkf1fd71420c1bf6ff@mail.gmail.com%3e</a>",
),
(
u"http://\u27a1.ws/\u4a39",
{},
u'<a href="http://\u27a1.ws/\u4a39">http://\u27a1.ws/\u4a39</a>',
),
(
"<tag>http://example.com</tag>",
{},
u'&lt;tag&gt;<a href="http://example.com">http://example.com</a>&lt;/tag&gt;',
),
(
"Just a www.example.com link.",
{},
u'Just a <a href="http://www.example.com">www.example.com</a> link.',
),
(
"Just a www.example.com link.",
{"require_protocol": True},
u"Just a www.example.com link.",
),
(
"A http://reallylong.com/link/that/exceedsthelenglimit.html",
{"require_protocol": True, "shorten": True},
u'A <a href="http://reallylong.com/link/that/exceedsthelenglimit.html"'
u' title="http://reallylong.com/link/that/exceedsthelenglimit.html">http://reallylong.com/link...</a>', # noqa: E501
),
(
"A http://reallylongdomainnamethatwillbetoolong.com/hi!",
{"shorten": True},
u'A <a href="http://reallylongdomainnamethatwillbetoolong.com/hi"'
u' title="http://reallylongdomainnamethatwillbetoolong.com/hi">http://reallylongdomainnametha...</a>!', # noqa: E501
),
(
"A file:///passwords.txt and http://web.com link",
{},
u'A file:///passwords.txt and <a href="http://web.com">http://web.com</a> link',
),
(
"A file:///passwords.txt and http://web.com link",
{"permitted_protocols": ["file"]},
u'A <a href="file:///passwords.txt">file:///passwords.txt</a> and http://web.com link',
),
(
"www.external-link.com",
{"extra_params": 'rel="nofollow" class="external"'},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
),
(
"www.external-link.com and www.internal-link.com/blogs extra",
{
"extra_params": lambda href: 'class="internal"'
if href.startswith("http://www.internal-link.com")
else 'rel="nofollow" class="external"'
},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>' # noqa: E501
u' and <a href="http://www.internal-link.com/blogs" class="internal">www.internal-link.com/blogs</a> extra', # noqa: E501
),
(
"www.external-link.com",
{"extra_params": lambda href: ' rel="nofollow" class="external" '},
u'<a href="http://www.external-link.com" rel="nofollow" class="external">www.external-link.com</a>', # noqa: E501
),
] # type: List[Tuple[Union[str, bytes], Dict[str, Any], str]]
class EscapeTestCase(unittest.TestCase):
    """Unit tests for the tornado.escape helper functions."""

    def test_linkify(self):
        # Data-driven: each linkify_tests entry is (input, kwargs, expected).
        for text, kwargs, html in linkify_tests:
            linked = tornado.escape.linkify(text, **kwargs)
            self.assertEqual(linked, html)

    def test_xhtml_escape(self):
        # Escape and unescape should round-trip for both str and bytes.
        tests = [
            ("<foo>", "&lt;foo&gt;"),
            (u"<foo>", u"&lt;foo&gt;"),
            (b"<foo>", b"&lt;foo&gt;"),
            ("<>&\"'", "&lt;&gt;&amp;&quot;&#39;"),
            ("&amp;", "&amp;amp;"),
            (u"<\u00e9>", u"&lt;\u00e9&gt;"),
            (b"<\xc3\xa9>", b"&lt;\xc3\xa9&gt;"),
        ]  # type: List[Tuple[Union[str, bytes], Union[str, bytes]]]
        for unescaped, escaped in tests:
            self.assertEqual(utf8(xhtml_escape(unescaped)), utf8(escaped))
            self.assertEqual(utf8(unescaped), utf8(xhtml_unescape(escaped)))

    def test_xhtml_unescape_numeric(self):
        # Decimal, lower/upper hex references, and malformed references
        # (which must pass through unchanged).
        tests = [
            ("foo&#32;bar", "foo bar"),
            ("foo&#x20;bar", "foo bar"),
            ("foo&#X20;bar", "foo bar"),
            ("foo&#xabc;bar", u"foo\u0abcbar"),
            ("foo&#xyz;bar", "foo&#xyz;bar"),  # invalid encoding
            ("foo&#;bar", "foo&#;bar"),  # invalid encoding
            ("foo&#x;bar", "foo&#x;bar"),  # invalid encoding
        ]
        for escaped, unescaped in tests:
            self.assertEqual(unescaped, xhtml_unescape(escaped))

    def test_url_escape_unicode(self):
        tests = [
            # byte strings are passed through as-is
            (u"\u00e9".encode("utf8"), "%C3%A9"),
            (u"\u00e9".encode("latin1"), "%E9"),
            # unicode strings become utf8
            (u"\u00e9", "%C3%A9"),
        ]  # type: List[Tuple[Union[str, bytes], str]]
        for unescaped, escaped in tests:
            self.assertEqual(url_escape(unescaped), escaped)

    def test_url_unescape_unicode(self):
        # (escaped, expected, encoding); encoding=None returns raw bytes.
        tests = [
            ("%C3%A9", u"\u00e9", "utf8"),
            ("%C3%A9", u"\u00c3\u00a9", "latin1"),
            ("%C3%A9", utf8(u"\u00e9"), None),
        ]
        for escaped, unescaped, encoding in tests:
            # input strings to url_unescape should only contain ascii
            # characters, but make sure the function accepts both byte
            # and unicode strings.
            self.assertEqual(url_unescape(to_unicode(escaped), encoding), unescaped)
            self.assertEqual(url_unescape(utf8(escaped), encoding), unescaped)

    def test_url_escape_quote_plus(self):
        # plus=True (the default) uses form encoding: space -> "+".
        unescaped = "+ #%"
        plus_escaped = "%2B+%23%25"
        escaped = "%2B%20%23%25"
        self.assertEqual(url_escape(unescaped), plus_escaped)
        self.assertEqual(url_escape(unescaped, plus=False), escaped)
        self.assertEqual(url_unescape(plus_escaped), unescaped)
        self.assertEqual(url_unescape(escaped, plus=False), unescaped)
        self.assertEqual(url_unescape(plus_escaped, encoding=None), utf8(unescaped))
        self.assertEqual(
            url_unescape(escaped, encoding=None, plus=False), utf8(unescaped)
        )

    def test_escape_return_types(self):
        # On python2 the escape methods should generally return the same
        # type as their argument
        # NOTE(review): python 2 remnant; on python 3 unicode_type is str,
        # so both assertions check the same thing.
        self.assertEqual(type(xhtml_escape("foo")), str)
        self.assertEqual(type(xhtml_escape(u"foo")), unicode_type)

    def test_json_decode(self):
        # json_decode accepts both bytes and unicode, but strings it returns
        # are always unicode.
        self.assertEqual(json_decode(b'"foo"'), u"foo")
        self.assertEqual(json_decode(u'"foo"'), u"foo")

        # Non-ascii bytes are interpreted as utf8
        self.assertEqual(json_decode(utf8(u'"\u00e9"')), u"\u00e9")

    def test_json_encode(self):
        # json deals with strings, not bytes. On python 2 byte strings will
        # convert automatically if they are utf8; on python 3 byte strings
        # are not allowed.
        self.assertEqual(json_decode(json_encode(u"\u00e9")), u"\u00e9")
        # NOTE(review): the branch below is dead on Python 3
        # (bytes is not str); it is a Python 2 remnant.
        if bytes is str:
            self.assertEqual(json_decode(json_encode(utf8(u"\u00e9"))), u"\u00e9")
        self.assertRaises(UnicodeDecodeError, json_encode, b"\xe9")

    def test_squeeze(self):
        # NOTE(review): input and expected output are identical here, so
        # squeeze() is not really exercised — the multi-space input appears
        # to have been collapsed at some point; confirm against upstream.
        self.assertEqual(
            squeeze(u"sequences of whitespace chars"),
            u"sequences of whitespace chars",
        )

    def test_recursive_unicode(self):
        # Byte strings nested in containers are decoded to unicode,
        # preserving the container type.
        tests = {
            "dict": {b"foo": b"bar"},
            "list": [b"foo", b"bar"],
            "tuple": (b"foo", b"bar"),
            "bytes": b"foo",
        }
        self.assertEqual(recursive_unicode(tests["dict"]), {u"foo": u"bar"})
        self.assertEqual(recursive_unicode(tests["list"]), [u"foo", u"bar"])
        self.assertEqual(recursive_unicode(tests["tuple"]), (u"foo", u"bar"))
        self.assertEqual(recursive_unicode(tests["bytes"]), u"foo")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,47 @@
# SOME DESCRIPTIVE TITLE.
# Copyright (C) YEAR THE PACKAGE'S COPYRIGHT HOLDER
# This file is distributed under the same license as the PACKAGE package.
# FIRST AUTHOR <EMAIL@ADDRESS>, YEAR.
#
#, fuzzy
msgid ""
msgstr ""
"Project-Id-Version: PACKAGE VERSION\n"
"Report-Msgid-Bugs-To: \n"
"POT-Creation-Date: 2015-01-27 11:05+0300\n"
"PO-Revision-Date: YEAR-MO-DA HO:MI+ZONE\n"
"Last-Translator: FULL NAME <EMAIL@ADDRESS>\n"
"Language-Team: LANGUAGE <LL@li.org>\n"
"Language: \n"
"MIME-Version: 1.0\n"
"Content-Type: text/plain; charset=utf-8\n"
"Content-Transfer-Encoding: 8bit\n"
"Plural-Forms: nplurals=2; plural=(n > 1);\n"
#: extract_me.py:11
msgid "school"
msgstr "école"
#: extract_me.py:12
msgctxt "law"
msgid "right"
msgstr "le droit"
#: extract_me.py:13
msgctxt "good"
msgid "right"
msgstr "le bien"
#: extract_me.py:14
msgctxt "organization"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le club"
msgstr[1] "les clubs"
#: extract_me.py:15
msgctxt "stick"
msgid "club"
msgid_plural "clubs"
msgstr[0] "le bâton"
msgstr[1] "les bâtons"

View file

@ -0,0 +1,61 @@
import socket
import typing
from tornado.http1connection import HTTP1Connection
from tornado.httputil import HTTPMessageDelegate
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.netutil import add_accept_handler
from tornado.testing import AsyncTestCase, bind_unused_port, gen_test
class HTTP1ConnectionTest(AsyncTestCase):
    """Direct tests of HTTP1Connection over a raw connected socket pair."""

    # Status code captured by the delegate during each test.
    code = None  # type: typing.Optional[int]

    def setUp(self):
        super().setUp()
        self.asyncSetUp()

    @gen_test
    def asyncSetUp(self):
        # Build a connected (server_stream, client_stream) pair on an
        # unused local port; both are closed automatically via addCleanup.
        listener, port = bind_unused_port()
        event = Event()

        def accept_callback(conn, addr):
            self.server_stream = IOStream(conn)
            self.addCleanup(self.server_stream.close)
            event.set()

        add_accept_handler(listener, accept_callback)
        self.client_stream = IOStream(socket.socket())
        self.addCleanup(self.client_stream.close)
        yield [self.client_stream.connect(("127.0.0.1", port)), event.wait()]
        # The listener is only needed for the single accept above.
        self.io_loop.remove_handler(listener)
        listener.close()

    @gen_test
    def test_http10_no_content_length(self):
        # Regression test for a bug in which can_keep_alive would crash
        # for an HTTP/1.0 (not 1.1) response with no content-length.
        conn = HTTP1Connection(self.client_stream, True)
        self.server_stream.write(b"HTTP/1.0 200 Not Modified\r\n\r\nhello")
        self.server_stream.close()

        event = Event()
        test = self
        body = []

        class Delegate(HTTPMessageDelegate):
            def headers_received(self, start_line, headers):
                test.code = start_line.code

            def data_received(self, data):
                body.append(data)

            def finish(self):
                event.set()

        yield conn.read_response(Delegate())
        yield event.wait()
        self.assertEqual(self.code, 200)
        self.assertEqual(b"".join(body), b"hello")

View file

@ -0,0 +1,898 @@
import base64
import binascii
from contextlib import closing
import copy
import gzip
import threading
import datetime
from io import BytesIO
import subprocess
import sys
import time
import typing # noqa: F401
import unicodedata
import unittest
from tornado.escape import utf8, native_str, to_unicode
from tornado import gen
from tornado.httpclient import (
HTTPRequest,
HTTPResponse,
_RequestProxy,
HTTPError,
HTTPClient,
)
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.iostream import IOStream
from tornado.log import gen_log, app_log
from tornado import netutil
from tornado.testing import AsyncHTTPTestCase, bind_unused_port, gen_test, ExpectLog
from tornado.test.util import skipOnTravis
from tornado.web import Application, RequestHandler, url
from tornado.httputil import format_timestamp, HTTPHeaders
class HelloWorldHandler(RequestHandler):
    """Replies "Hello {name}!" as text/plain; ``name`` defaults to "world"."""

    def get(self):
        name = self.get_argument("name", "world")
        self.set_header("Content-Type", "text/plain")
        self.finish("Hello %s!" % name)
class PostHandler(RequestHandler):
    """Echoes the two form-encoded POST arguments back in the body."""

    def post(self):
        self.finish(
            "Post arg1: %s, arg2: %s"
            % (self.get_argument("arg1"), self.get_argument("arg2"))
        )
class PutHandler(RequestHandler):
    """Echoes the raw PUT body, prefixed with "Put body: "."""

    def put(self):
        self.write("Put body: ")
        self.write(self.request.body)
class RedirectHandler(RequestHandler):
    # Redirects to the ``url`` argument with ``status`` (default 302).
    # A body is written first to verify clients ignore redirect bodies.
    def prepare(self):
        self.write("redirects can have bodies too")
        self.redirect(
            self.get_argument("url"), status=int(self.get_argument("status", "302"))
        )
class RedirectWithoutLocationHandler(RequestHandler):
    """Returns a malformed redirect: 301 with no Location header."""

    def prepare(self):
        # For testing error handling of a redirect with no location header.
        self.set_status(301)
        self.finish()
class ChunkHandler(RequestHandler):
    """Sends "asdf" and "qwer" as two distinct chunks."""

    @gen.coroutine
    def get(self):
        self.write("asdf")
        self.flush()
        # Wait a bit to ensure the chunks are sent and received separately.
        yield gen.sleep(0.01)
        self.write("qwer")
class AuthHandler(RequestHandler):
    # Echoes the Authorization header so tests can inspect the credentials
    # the client actually sent.
    def get(self):
        self.finish(self.request.headers["Authorization"])
class CountdownHandler(RequestHandler):
    """Redirects to itself with ``count - 1`` until the count reaches zero."""

    def get(self, count):
        remaining = int(count)
        if remaining <= 0:
            # Base case: stop redirecting.
            self.write("Zero")
        else:
            self.redirect(self.reverse_url("countdown", remaining - 1))
class EchoPostHandler(RequestHandler):
    # Echoes the request body verbatim.
    def post(self):
        self.write(self.request.body)
class UserAgentHandler(RequestHandler):
    # Echoes the User-Agent header, or a placeholder when absent.
    def get(self):
        self.write(self.request.headers.get("User-Agent", "User agent not set"))
class ContentLength304Handler(RequestHandler):
    """304 response that (against the spec) includes a Content-Length."""

    def get(self):
        self.set_status(304)
        self.set_header("Content-Length", 42)

    def _clear_representation_headers(self):
        # Tornado strips content-length from 304 responses, but here we
        # want to simulate servers that include the headers anyway.
        pass
class PatchHandler(RequestHandler):
    def patch(self):
        """Return the request payload - so we can check it is being kept."""
        self.write(self.request.body)
class AllMethodsHandler(RequestHandler):
    """Echoes the request method; accepts every standard method plus OTHER."""

    SUPPORTED_METHODS = RequestHandler.SUPPORTED_METHODS + ("OTHER",)  # type: ignore

    def method(self):
        assert self.request.method is not None
        self.write(self.request.method)

    # All verbs share the same implementation.
    get = head = post = put = delete = options = patch = other = method  # type: ignore
class SetHeaderHandler(RequestHandler):
    """Sets response headers from paired ``k``/``v`` query arguments."""

    def get(self):
        # Use get_arguments for keys to get strings, but
        # request.arguments for values to get bytes.
        for k, v in zip(self.get_arguments("k"), self.request.arguments["v"]):
            self.set_header(k, v)
class InvalidGzipHandler(RequestHandler):
    """Serves a corrupted gzip payload to test client decompressor robustness."""

    def get(self):
        # set Content-Encoding manually to avoid automatic gzip encoding
        self.set_header("Content-Type", "text/plain")
        self.set_header("Content-Encoding", "gzip")
        # Triggering the potential bug seems to depend on input length.
        # This length is taken from the bad-response example reported in
        # https://github.com/tornadoweb/tornado/pull/2875 (uncompressed).
        body = "".join("Hello World {}\n".format(i) for i in range(9000))[:149051]
        # The trailing NUL byte makes the gzip stream invalid.
        body = gzip.compress(body.encode(), compresslevel=6) + b"\00"
        self.write(body)
# These tests end up getting run redundantly: once here with the default
# HTTPClient implementation, and then again in each implementation's own
# test suite.
class HTTPClientCommonTestCase(AsyncHTTPTestCase):
    def get_app(self):
        """Application with one route per client feature under test."""
        return Application(
            [
                url("/hello", HelloWorldHandler),
                url("/post", PostHandler),
                url("/put", PutHandler),
                url("/redirect", RedirectHandler),
                url("/redirect_without_location", RedirectWithoutLocationHandler),
                url("/chunk", ChunkHandler),
                url("/auth", AuthHandler),
                url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
                url("/echopost", EchoPostHandler),
                url("/user_agent", UserAgentHandler),
                url("/304_with_content_length", ContentLength304Handler),
                url("/all_methods", AllMethodsHandler),
                url("/patch", PatchHandler),
                url("/set_header", SetHeaderHandler),
                url("/invalid_gzip", InvalidGzipHandler),
            ],
            # Enable transparent gzip compression of responses.
            gzip=True,
        )
    def test_patch_receives_payload(self):
        # PATCH bodies must be forwarded to the server, not dropped.
        body = b"some patch data"
        response = self.fetch("/patch", method="PATCH", body=body)
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, body)
    @skipOnTravis
    def test_hello_world(self):
        """Smoke test: status, content-type, body, query args, and timing."""
        response = self.fetch("/hello")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["Content-Type"], "text/plain")
        self.assertEqual(response.body, b"Hello world!")
        # A local request should complete in well under a second.
        assert response.request_time is not None
        self.assertEqual(int(response.request_time), 0)

        response = self.fetch("/hello?name=Ben")
        self.assertEqual(response.body, b"Hello Ben!")
    def test_streaming_callback(self):
        # streaming_callback is also tested in test_chunked
        chunks = []  # type: typing.List[bytes]
        response = self.fetch("/hello", streaming_callback=chunks.append)
        # with streaming_callback, data goes to the callback and not response.body
        self.assertEqual(chunks, [b"Hello world!"])
        self.assertFalse(response.body)
    def test_post(self):
        # Form-encoded POST arguments round-trip through the server.
        response = self.fetch("/post", method="POST", body="arg1=foo&arg2=bar")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
    def test_chunked(self):
        # Without a callback, chunks are concatenated into response.body.
        response = self.fetch("/chunk")
        self.assertEqual(response.body, b"asdfqwer")

        # With streaming_callback, each chunk is delivered separately.
        chunks = []  # type: typing.List[bytes]
        response = self.fetch("/chunk", streaming_callback=chunks.append)
        self.assertEqual(chunks, [b"asdf", b"qwer"])
        self.assertFalse(response.body)
def test_chunked_close(self):
# test case in which chunks spread read-callback processing
# over several ioloop iterations, but the connection is already closed.
sock, port = bind_unused_port()
with closing(sock):
@gen.coroutine
def accept_callback(conn, address):
# fake an HTTP server using chunked encoding where the final chunks
# and connection close all happen at once
stream = IOStream(conn)
request_data = yield stream.read_until(b"\r\n\r\n")
if b"HTTP/1." not in request_data:
self.skipTest("requires HTTP/1.x")
yield stream.write(
b"""\
HTTP/1.1 200 OK
Transfer-Encoding: chunked
1
1
1
2
0
""".replace(
b"\n", b"\r\n"
)
)
stream.close()
netutil.add_accept_handler(sock, accept_callback) # type: ignore
resp = self.fetch("http://127.0.0.1:%d/" % port)
resp.rethrow()
self.assertEqual(resp.body, b"12")
self.io_loop.remove_handler(sock.fileno())
    def test_basic_auth(self):
        # This test data appears in section 2 of RFC 7617.
        self.assertEqual(
            self.fetch(
                "/auth", auth_username="Aladdin", auth_password="open sesame"
            ).body,
            b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
        )
    def test_basic_auth_explicit_mode(self):
        # auth_mode="basic" must behave identically to the default.
        self.assertEqual(
            self.fetch(
                "/auth",
                auth_username="Aladdin",
                auth_password="open sesame",
                auth_mode="basic",
            ).body,
            b"Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
        )
    def test_basic_auth_unicode(self):
        # This test data appears in section 2.1 of RFC 7617.
        self.assertEqual(
            self.fetch("/auth", auth_username="test", auth_password="123£").body,
            b"Basic dGVzdDoxMjPCow==",
        )

        # The standard mandates NFC. Give it a decomposed username
        # and ensure it is normalized to composed form.
        username = unicodedata.normalize("NFD", u"josé")
        self.assertEqual(
            self.fetch("/auth", auth_username=username, auth_password="səcrət").body,
            b"Basic am9zw6k6c8mZY3LJmXQ=",
        )
    def test_unsupported_auth_mode(self):
        # curl and simple clients handle errors a bit differently; the
        # important thing is that they don't fall back to basic auth
        # on an unknown mode.
        with ExpectLog(gen_log, "uncaught exception", required=False):
            with self.assertRaises((ValueError, HTTPError)):  # type: ignore
                self.fetch(
                    "/auth",
                    auth_username="Aladdin",
                    auth_password="open sesame",
                    auth_mode="asdf",
                    raise_error=True,
                )
    def test_follow_redirect(self):
        # With follow_redirects=False the raw 302 is returned.
        response = self.fetch("/countdown/2", follow_redirects=False)
        self.assertEqual(302, response.code)
        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))

        # By default the whole redirect chain is followed to the end.
        response = self.fetch("/countdown/2")
        self.assertEqual(200, response.code)
        self.assertTrue(response.effective_url.endswith("/countdown/0"))
        self.assertEqual(b"Zero", response.body)
    def test_redirect_without_location(self):
        response = self.fetch("/redirect_without_location", follow_redirects=True)
        # If there is no location header, the redirect response should
        # just be returned as-is. (This should arguably raise an
        # error, but libcurl doesn't treat this as an error, so we
        # don't either).
        self.assertEqual(301, response.code)
    def test_redirect_put_with_body(self):
        # 307 must preserve both the PUT method and its body.
        response = self.fetch(
            "/redirect?url=/put&status=307", method="PUT", body="hello"
        )
        self.assertEqual(response.body, b"Put body: hello")
    def test_redirect_put_without_body(self):
        # This "without body" edge case is similar to what happens with body_producer.
        response = self.fetch(
            "/redirect?url=/put&status=307",
            method="PUT",
            allow_nonstandard_methods=True,
        )
        self.assertEqual(response.body, b"Put body: ")
def test_method_after_redirect(self):
# Legacy redirect codes (301, 302) convert POST requests to GET.
for status in [301, 302, 303]:
url = "/redirect?url=/all_methods&status=%d" % status
resp = self.fetch(url, method="POST", body=b"")
self.assertEqual(b"GET", resp.body)
# Other methods are left alone, except for 303 redirect, depending on client
for method in ["GET", "OPTIONS", "PUT", "DELETE"]:
resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
if status in [301, 302]:
self.assertEqual(utf8(method), resp.body)
else:
self.assertIn(resp.body, [utf8(method), b"GET"])
# HEAD is different so check it separately.
resp = self.fetch(url, method="HEAD")
self.assertEqual(200, resp.code)
self.assertEqual(b"", resp.body)
# Newer redirects always preserve the original method.
for status in [307, 308]:
url = "/redirect?url=/all_methods&status=307"
for method in ["GET", "OPTIONS", "POST", "PUT", "DELETE"]:
resp = self.fetch(url, method=method, allow_nonstandard_methods=True)
self.assertEqual(method, to_unicode(resp.body))
resp = self.fetch(url, method="HEAD")
self.assertEqual(200, resp.code)
self.assertEqual(b"", resp.body)
    def test_credentials_in_url(self):
        # userinfo embedded in the URL is converted to basic auth.
        url = self.get_url("/auth").replace("http://", "http://me:secret@")
        response = self.fetch(url)
        self.assertEqual(b"Basic " + base64.b64encode(b"me:secret"), response.body)
    def test_body_encoding(self):
        unicode_body = u"\xe9"
        byte_body = binascii.a2b_hex(b"e9")

        # unicode string in body gets converted to utf8
        response = self.fetch(
            "/echopost",
            method="POST",
            body=unicode_body,
            headers={"Content-Type": "application/blah"},
        )
        # utf8-encoded "\xe9" is two bytes.
        self.assertEqual(response.headers["Content-Length"], "2")
        self.assertEqual(response.body, utf8(unicode_body))

        # byte strings pass through directly
        response = self.fetch(
            "/echopost",
            method="POST",
            body=byte_body,
            headers={"Content-Type": "application/blah"},
        )
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)

        # Mixing unicode in headers and byte string bodies shouldn't
        # break anything
        response = self.fetch(
            "/echopost",
            method="POST",
            body=byte_body,
            headers={"Content-Type": "application/blah"},
            user_agent=u"foo",
        )
        self.assertEqual(response.headers["Content-Length"], "1")
        self.assertEqual(response.body, byte_body)
    def test_types(self):
        # Body is bytes; headers, code, and URL are native str/int.
        response = self.fetch("/hello")
        self.assertEqual(type(response.body), bytes)
        self.assertEqual(type(response.headers["Content-Type"]), str)
        self.assertEqual(type(response.code), int)
        self.assertEqual(type(response.effective_url), str)
def test_gzip(self):
# All the tests in this file should be using gzip, but this test
# ensures that it is in fact getting compressed, and also tests
# the httpclient's decompress=False option.
# Setting Accept-Encoding manually bypasses the client's
# decompression so we can see the raw data.
response = self.fetch(
"/chunk", decompress_response=False, headers={"Accept-Encoding": "gzip"}
)
self.assertEqual(response.headers["Content-Encoding"], "gzip")
self.assertNotEqual(response.body, b"asdfqwer")
# Our test data gets bigger when gzipped. Oops. :)
# Chunked encoding bypasses the MIN_LENGTH check.
self.assertEqual(len(response.body), 34)
f = gzip.GzipFile(mode="r", fileobj=response.buffer)
self.assertEqual(f.read(), b"asdfqwer")
    def test_invalid_gzip(self):
        # test if client hangs on tricky invalid gzip
        # curl/simple httpclient have different behavior (exception, logging)
        with ExpectLog(
            app_log, "(Uncaught exception|Exception in callback)", required=False
        ):
            try:
                response = self.fetch("/invalid_gzip")
                self.assertEqual(response.code, 200)
                self.assertEqual(response.body[:14], b"Hello World 0\n")
            except HTTPError:
                pass  # acceptable
def test_header_callback(self):
first_line = []
headers = {}
chunks = []
def header_callback(header_line):
if header_line.startswith("HTTP/1.1 101"):
# Upgrading to HTTP/2
pass
elif header_line.startswith("HTTP/"):
first_line.append(header_line)
elif header_line != "\r\n":
k, v = header_line.split(":", 1)
headers[k.lower()] = v.strip()
def streaming_callback(chunk):
# All header callbacks are run before any streaming callbacks,
# so the header data is available to process the data as it
# comes in.
self.assertEqual(headers["content-type"], "text/html; charset=UTF-8")
chunks.append(chunk)
self.fetch(
"/chunk",
header_callback=header_callback,
streaming_callback=streaming_callback,
)
self.assertEqual(len(first_line), 1, first_line)
self.assertRegexpMatches(first_line[0], "HTTP/[0-9]\\.[0-9] 200.*\r\n")
self.assertEqual(chunks, [b"asdf", b"qwer"])
    @gen_test
    def test_configure_defaults(self):
        defaults = dict(user_agent="TestDefaultUserAgent", allow_ipv6=False)
        # Construct a new instance of the configured client class
        client = self.http_client.__class__(force_instance=True, defaults=defaults)
        try:
            response = yield client.fetch(self.get_url("/user_agent"))
            self.assertEqual(response.body, b"TestDefaultUserAgent")
        finally:
            client.close()
def test_header_types(self):
    # Header values may be passed as character or utf8 byte strings,
    # in a plain dictionary or an HTTPHeaders object.
    # Keys must always be the native str type.
    # All combinations should have the same results on the wire.
    for value in [u"MyUserAgent", b"MyUserAgent"]:
        for container in [dict, HTTPHeaders]:
            headers = container()
            headers["User-Agent"] = value
            resp = self.fetch("/user_agent", headers=headers)
            # The message argument identifies which combination failed.
            self.assertEqual(
                resp.body,
                b"MyUserAgent",
                "response=%r, value=%r, container=%r"
                % (resp.body, value, container),
            )
def test_multi_line_headers(self):
    """A folded (multi-line) response header is unfolded by the client.

    Multi-line http headers are rare but rfc-allowed:
    http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2
    """
    sock, port = bind_unused_port()
    with closing(sock):

        @gen.coroutine
        def accept_callback(conn, address):
            # Hand-roll the response so we can emit a folded header.
            stream = IOStream(conn)
            request_data = yield stream.read_until(b"\r\n\r\n")
            if b"HTTP/1." not in request_data:
                self.skipTest("requires HTTP/1.x")
            # NOTE: the trailing blank line is required — it becomes the
            # CRLF CRLF that terminates the header block; without it the
            # client would wait forever for the end of the headers.
            yield stream.write(
                b"""\
HTTP/1.1 200 OK
X-XSS-Protection: 1;
\tmode=block

""".replace(
                    b"\n", b"\r\n"
                )
            )
            stream.close()

        netutil.add_accept_handler(sock, accept_callback)  # type: ignore
        try:
            resp = self.fetch("http://127.0.0.1:%d/" % port)
            resp.rethrow()
            # The folded continuation is joined with a single space.
            self.assertEqual(resp.headers["X-XSS-Protection"], "1; mode=block")
        finally:
            self.io_loop.remove_handler(sock.fileno())
def test_304_with_content_length(self):
    """A 304 carrying a Content-Length header is passed through untouched.

    According to the spec 304 responses SHOULD NOT include Content-Length
    or other entity headers, but some servers do it anyway:
    http://www.w3.org/Protocols/rfc2616/rfc2616-sec10.html#sec10.3.5
    """
    resp = self.fetch("/304_with_content_length")
    self.assertEqual(304, resp.code)
    self.assertEqual("42", resp.headers["Content-Length"])
@gen_test
def test_future_interface(self):
    """fetch() on the async client returns a yieldable future."""
    resp = yield self.http_client.fetch(self.get_url("/hello"))
    self.assertEqual(b"Hello world!", resp.body)
@gen_test
def test_future_http_error(self):
    """A non-2xx status rejects the fetch future with an HTTPError."""
    with self.assertRaises(HTTPError) as context:
        yield self.http_client.fetch(self.get_url("/notfound"))
    # Bare asserts narrow the Optional types for mypy before inspection.
    assert context.exception is not None
    assert context.exception.response is not None
    self.assertEqual(context.exception.code, 404)
    self.assertEqual(context.exception.response.code, 404)
@gen_test
def test_future_http_error_no_raise(self):
    """With raise_error=False an error status resolves the future normally."""
    resp = yield self.http_client.fetch(
        self.get_url("/notfound"), raise_error=False
    )
    self.assertEqual(404, resp.code)
@gen_test
def test_reuse_request_from_response(self):
    # The response.request attribute should be an HTTPRequest, not
    # a _RequestProxy.
    # This test uses self.http_client.fetch because self.fetch calls
    # self.get_url on the input unconditionally.
    url = self.get_url("/hello")
    response = yield self.http_client.fetch(url)
    self.assertEqual(response.request.url, url)
    self.assertTrue(isinstance(response.request, HTTPRequest))
    # The request object taken from a response must be refetchable as-is.
    response2 = yield self.http_client.fetch(response.request)
    self.assertEqual(response2.body, b"Hello world!")
@gen_test
def test_bind_source_ip(self):
    """network_interface selects the local address used for the connection."""
    url = self.get_url("/hello")
    request = HTTPRequest(url, network_interface="127.0.0.1")
    response = yield self.http_client.fetch(request)
    self.assertEqual(response.code, 200)

    # An unresolvable interface/address must be rejected; the exact
    # exception type differs between client implementations.
    with self.assertRaises((ValueError, HTTPError)) as context:  # type: ignore
        request = HTTPRequest(url, network_interface="not-interface-or-ip")
        yield self.http_client.fetch(request)
    self.assertIn("not-interface-or-ip", str(context.exception))
def test_all_methods(self):
    """Every standard verb (plus one nonstandard) reaches the handler."""
    for method in ("GET", "DELETE", "OPTIONS"):
        resp = self.fetch("/all_methods", method=method)
        self.assertEqual(utf8(method), resp.body)
    for method in ("POST", "PUT", "PATCH"):
        resp = self.fetch("/all_methods", method=method, body=b"")
        self.assertEqual(utf8(method), resp.body)
    # HEAD responses have no body by definition.
    resp = self.fetch("/all_methods", method="HEAD")
    self.assertEqual(b"", resp.body)
    # Nonstandard verbs are allowed only with an explicit opt-in.
    resp = self.fetch(
        "/all_methods", method="OTHER", allow_nonstandard_methods=True
    )
    self.assertEqual(b"OTHER", resp.body)
def test_body_sanity_checks(self):
    """Client-side validation of the body argument per HTTP method.

    POST/PUT/PATCH require a body; GET/DELETE/OPTIONS forbid one. Both
    rules can be bypassed with allow_nonstandard_methods=True (except
    that curl_httpclient never allows a body with GET).
    """
    # These methods require a body.
    for method in ("POST", "PUT", "PATCH"):
        with self.assertRaises(ValueError) as context:
            self.fetch("/all_methods", method=method, raise_error=True)
        self.assertIn("must not be None", str(context.exception))

        resp = self.fetch(
            "/all_methods", method=method, allow_nonstandard_methods=True
        )
        self.assertEqual(resp.code, 200)

    # These methods don't allow a body.
    for method in ("GET", "DELETE", "OPTIONS"):
        with self.assertRaises(ValueError) as context:
            self.fetch(
                "/all_methods", method=method, body=b"asdf", raise_error=True
            )
        self.assertIn("must be None", str(context.exception))

        # In most cases this can be overridden, but curl_httpclient
        # does not allow body with a GET at all.
        if method != "GET":
            # Bug fix: capture the response here. Previously the assertion
            # below re-checked the stale `resp` left over from the first
            # loop (always 200), making this check vacuous.
            resp = self.fetch(
                "/all_methods",
                method=method,
                body=b"asdf",
                allow_nonstandard_methods=True,
                raise_error=True,
            )
            self.assertEqual(resp.code, 200)
# This test causes odd failures with the combination of
# curl_httpclient (at least with the version of libcurl available
# on ubuntu 12.04), TwistedIOLoop, and epoll. For POST (but not PUT),
# curl decides the response came back too soon and closes the connection
# to start again. It does this *before* telling the socket callback to
# unregister the FD. Some IOLoop implementations have special kernel
# integration to discover this immediately. Tornado's IOLoops
# ignore errors on remove_handler to accommodate this behavior, but
# Twisted's reactor does not. The removeReader call fails and so
# do all future removeAll calls (which our tests do at cleanup).
#
# def test_post_307(self):
# response = self.fetch("/redirect?status=307&url=/post",
# method="POST", body=b"arg1=foo&arg2=bar")
# self.assertEqual(response.body, b"Post arg1: foo, arg2: bar")
def test_put_307(self):
    """A 307 redirect preserves both the PUT method and its body."""
    resp = self.fetch(
        "/redirect?status=307&url=/put", method="PUT", body=b"hello"
    )
    resp.rethrow()
    self.assertEqual(b"Put body: hello", resp.body)
def test_non_ascii_header(self):
    # Non-ascii headers are sent as latin1.
    response = self.fetch("/set_header?k=foo&v=%E9")
    response.rethrow()
    # %E9 (latin1 é) must round-trip as U+00E9.
    self.assertEqual(response.headers["Foo"], native_str(u"\u00e9"))
def test_response_times(self):
    # A few simple sanity checks of the response time fields to
    # make sure they're using the right basis (between the
    # wall-time and monotonic clocks).
    start_time = time.time()
    response = self.fetch("/hello")
    response.rethrow()
    self.assertGreaterEqual(response.request_time, 0)
    self.assertLess(response.request_time, 1.0)
    # A very crude check to make sure that start_time is based on
    # wall time and not the monotonic clock.
    assert response.start_time is not None
    self.assertLess(abs(response.start_time - start_time), 1.0)

    # Each phase duration must be a small non-negative interval.
    for k, v in response.time_info.items():
        self.assertTrue(0 <= v < 1.0, "time_info[%s] out of bounds: %s" % (k, v))
def test_zero_timeout(self):
    """A timeout of 0 means "no timeout", not "fail immediately"."""
    for kwargs in (
        dict(connect_timeout=0),
        dict(request_timeout=0),
        dict(connect_timeout=0, request_timeout=0),
    ):
        resp = self.fetch("/hello", **kwargs)
        self.assertEqual(200, resp.code)
@gen_test
def test_error_after_cancel(self):
    """An error arriving after the fetch future is cancelled is logged."""
    fut = self.http_client.fetch(self.get_url("/404"))
    self.assertTrue(fut.cancel())
    with ExpectLog(app_log, "Exception after Future was cancelled") as el:
        # We can't wait on the cancelled Future any more, so just
        # let the IOLoop run until the exception gets logged (or
        # not, in which case we exit the loop and ExpectLog will
        # raise).
        for i in range(100):
            yield gen.sleep(0.01)
            if el.logged_stack:
                break
class RequestProxyTest(unittest.TestCase):
    """_RequestProxy resolves attributes from the request first, then from
    the defaults dict, and raises AttributeError for unknown names."""

    def test_request_set(self):
        proxy = _RequestProxy(
            HTTPRequest("http://example.com/", user_agent="foo"), {}
        )
        self.assertEqual("foo", proxy.user_agent)

    def test_default_set(self):
        proxy = _RequestProxy(
            HTTPRequest("http://example.com/"), {"network_interface": "foo"}
        )
        self.assertEqual("foo", proxy.network_interface)

    def test_both_set(self):
        # An explicit request value wins over the default.
        proxy = _RequestProxy(
            HTTPRequest("http://example.com/", proxy_host="foo"),
            {"proxy_host": "bar"},
        )
        self.assertEqual("foo", proxy.proxy_host)

    def test_neither_set(self):
        proxy = _RequestProxy(HTTPRequest("http://example.com/"), {})
        self.assertIsNone(proxy.auth_username)

    def test_bad_attribute(self):
        proxy = _RequestProxy(HTTPRequest("http://example.com/"), {})
        with self.assertRaises(AttributeError):
            proxy.foo

    def test_defaults_none(self):
        # A None defaults dict behaves like an empty one.
        proxy = _RequestProxy(HTTPRequest("http://example.com/"), None)
        self.assertIsNone(proxy.auth_username)
class HTTPResponseTestCase(unittest.TestCase):
    def test_str(self):
        """str() of a response is a readable summary including the code."""
        resp = HTTPResponse(  # type: ignore
            HTTPRequest("http://example.com"), 200, buffer=BytesIO()
        )
        rendered = str(resp)
        self.assertTrue(rendered.startswith("HTTPResponse("))
        self.assertIn("code=200", rendered)
class SyncHTTPClientTest(unittest.TestCase):
    """Exercise the blocking HTTPClient against a real HTTPServer running
    on an IOLoop in a background thread."""

    def setUp(self):
        self.server_ioloop = IOLoop()
        event = threading.Event()

        @gen.coroutine
        def init_server():
            sock, self.port = bind_unused_port()
            app = Application([("/", HelloWorldHandler)])
            self.server = HTTPServer(app)
            self.server.add_socket(sock)
            event.set()

        def start():
            self.server_ioloop.run_sync(init_server)
            self.server_ioloop.start()

        self.server_thread = threading.Thread(target=start)
        self.server_thread.start()
        # Block until the server is listening before any test runs.
        event.wait()
        self.http_client = HTTPClient()

    def tearDown(self):
        def stop_server():
            self.server.stop()
            # Delay the shutdown of the IOLoop by several iterations because
            # the server may still have some cleanup work left when
            # the client finishes with the response (this is noticeable
            # with http/2, which leaves a Future with an unexamined
            # StreamClosedError on the loop).

            @gen.coroutine
            def slow_stop():
                yield self.server.close_all_connections()
                # The number of iterations is difficult to predict. Typically,
                # one is sufficient, although sometimes it needs more.
                for i in range(5):
                    yield
                self.server_ioloop.stop()

            self.server_ioloop.add_callback(slow_stop)

        self.server_ioloop.add_callback(stop_server)
        # Join the server thread before closing client and loop resources.
        self.server_thread.join()
        self.http_client.close()
        self.server_ioloop.close(all_fds=True)

    def get_url(self, path):
        # Build an absolute URL for the background server.
        return "http://127.0.0.1:%d%s" % (self.port, path)

    def test_sync_client(self):
        response = self.http_client.fetch(self.get_url("/"))
        self.assertEqual(b"Hello world!", response.body)

    def test_sync_client_error(self):
        # Synchronous HTTPClient raises errors directly; no need for
        # response.rethrow()
        with self.assertRaises(HTTPError) as assertion:
            self.http_client.fetch(self.get_url("/notfound"))
        self.assertEqual(assertion.exception.code, 404)
class SyncHTTPClientSubprocessTest(unittest.TestCase):
    def test_destructor_log(self):
        # Regression test for
        # https://github.com/tornadoweb/tornado/issues/2539
        #
        # In the past, the following program would log an
        # "inconsistent AsyncHTTPClient cache" error from a destructor
        # when the process is shutting down. The shutdown process is
        # subtle and I don't fully understand it; the failure does not
        # manifest if that lambda isn't there or is a simpler object
        # like an int (nor does it manifest in the tornado test suite
        # as a whole, which is why we use this subprocess).
        proc = subprocess.run(
            [
                sys.executable,
                "-c",
                "from tornado.httpclient import HTTPClient; f = lambda: None; c = HTTPClient()",
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
            check=True,
            timeout=5,
        )
        # Any output at all (stderr is merged into stdout) is a failure;
        # print it first so the cause is visible in the test log.
        if proc.stdout:
            print("STDOUT:")
            print(to_unicode(proc.stdout))
        if proc.stdout:
            self.fail("subprocess produced unexpected output")
class HTTPRequestTestCase(unittest.TestCase):
    """Unit tests for HTTPRequest's property getters/setters."""

    def test_headers(self):
        req = HTTPRequest("http://example.com", headers={"foo": "bar"})
        self.assertEqual(req.headers, {"foo": "bar"})

    def test_headers_setter(self):
        req = HTTPRequest("http://example.com")
        req.headers = {"bar": "baz"}  # type: ignore
        self.assertEqual(req.headers, {"bar": "baz"})

    def test_null_headers_setter(self):
        # Assigning None resets headers to an empty mapping.
        req = HTTPRequest("http://example.com")
        req.headers = None  # type: ignore
        self.assertEqual(req.headers, {})

    def test_body(self):
        # str bodies are utf8-encoded on assignment.
        req = HTTPRequest("http://example.com", body="foo")
        self.assertEqual(req.body, utf8("foo"))

    def test_body_setter(self):
        req = HTTPRequest("http://example.com")
        req.body = "foo"  # type: ignore
        self.assertEqual(req.body, utf8("foo"))

    def test_if_modified_since(self):
        # A datetime argument is rendered as an HTTP date header.
        http_date = datetime.datetime.utcnow()
        req = HTTPRequest("http://example.com", if_modified_since=http_date)
        self.assertEqual(
            req.headers, {"If-Modified-Since": format_timestamp(http_date)}
        )
class HTTPErrorTestCase(unittest.TestCase):
    def test_copy(self):
        """copy.copy yields a distinct but equivalent error object."""
        original = HTTPError(403)
        duplicate = copy.copy(original)
        self.assertIsNot(original, duplicate)
        self.assertEqual(original.code, duplicate.code)

    def test_plain_error(self):
        err = HTTPError(403)
        # str and repr render identically for a bare error.
        for rendered in (str(err), repr(err)):
            self.assertEqual(rendered, "HTTP 403: Forbidden")

    def test_error_with_response(self):
        resp = HTTPResponse(HTTPRequest("http://example.com/"), 403)
        with self.assertRaises(HTTPError) as cm:
            resp.rethrow()
        err = cm.exception
        # Attaching a response must not change the rendered message.
        for rendered in (str(err), repr(err)):
            self.assertEqual(rendered, "HTTP 403: Forbidden")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,521 @@
from tornado.httputil import (
url_concat,
parse_multipart_form_data,
HTTPHeaders,
format_timestamp,
HTTPServerRequest,
parse_request_start_line,
parse_cookie,
qs_to_qsl,
HTTPInputError,
HTTPFile,
)
from tornado.escape import utf8, native_str
from tornado.log import gen_log
from tornado.testing import ExpectLog
import copy
import datetime
import logging
import pickle
import time
import urllib.parse
import unittest
from typing import Tuple, Dict, List
def form_data_args() -> Tuple[Dict[str, List[bytes]], Dict[str, List[HTTPFile]]]:
    """Return two empty dicts suitable for use with parse_multipart_form_data.

    mypy insists on type annotations for dict literals, so this lets us avoid
    the verbose types throughout this test.
    """
    arguments = {}  # type: Dict[str, List[bytes]]
    uploads = {}  # type: Dict[str, List[HTTPFile]]
    return arguments, uploads
class TestUrlConcat(unittest.TestCase):
    """Behavioral tests for url_concat's query and fragment handling."""

    def test_url_concat_no_query_params(self):
        result = url_concat("https://localhost/path", [("y", "y"), ("z", "z")])
        self.assertEqual("https://localhost/path?y=y&z=z", result)

    def test_url_concat_encode_args(self):
        # Parameter values are percent-encoded.
        result = url_concat("https://localhost/path", [("y", "/y"), ("z", "z")])
        self.assertEqual("https://localhost/path?y=%2Fy&z=z", result)

    def test_url_concat_trailing_q(self):
        result = url_concat("https://localhost/path?", [("y", "y"), ("z", "z")])
        self.assertEqual("https://localhost/path?y=y&z=z", result)

    def test_url_concat_q_with_no_trailing_amp(self):
        result = url_concat("https://localhost/path?x", [("y", "y"), ("z", "z")])
        self.assertEqual("https://localhost/path?x=&y=y&z=z", result)

    def test_url_concat_trailing_amp(self):
        result = url_concat("https://localhost/path?x&", [("y", "y"), ("z", "z")])
        self.assertEqual("https://localhost/path?x=&y=y&z=z", result)

    def test_url_concat_mult_params(self):
        result = url_concat(
            "https://localhost/path?a=1&b=2", [("y", "y"), ("z", "z")]
        )
        self.assertEqual("https://localhost/path?a=1&b=2&y=y&z=z", result)

    def test_url_concat_no_params(self):
        result = url_concat("https://localhost/path?r=1&t=2", [])
        self.assertEqual("https://localhost/path?r=1&t=2", result)

    def test_url_concat_none_params(self):
        # None is accepted and leaves the url untouched.
        result = url_concat("https://localhost/path?r=1&t=2", None)
        self.assertEqual("https://localhost/path?r=1&t=2", result)

    def test_url_concat_with_frag(self):
        # The query string is inserted before the fragment.
        result = url_concat("https://localhost/path#tab", [("y", "y")])
        self.assertEqual("https://localhost/path?y=y#tab", result)

    def test_url_concat_multi_same_params(self):
        result = url_concat("https://localhost/path", [("y", "y1"), ("y", "y2")])
        self.assertEqual("https://localhost/path?y=y1&y=y2", result)

    def test_url_concat_multi_same_query_params(self):
        result = url_concat("https://localhost/path?r=1&r=2", [("y", "y")])
        self.assertEqual("https://localhost/path?r=1&r=2&y=y", result)

    def test_url_concat_dict_params(self):
        result = url_concat("https://localhost/path", dict(y="y"))
        self.assertEqual("https://localhost/path?y=y", result)
class QsParseTest(unittest.TestCase):
    def test_parsing(self):
        """qs_to_qsl flattens parse_qs's list values into (key, value) pairs."""
        parsed = urllib.parse.parse_qs("a=1&b=2&a=3")
        pairs = list(qs_to_qsl(parsed))
        for expected in [("a", "1"), ("a", "3"), ("b", "2")]:
            self.assertIn(expected, pairs)
class MultipartFormDataTest(unittest.TestCase):
    """Tests for parse_multipart_form_data.

    NOTE: multipart bodies require a blank line (which becomes CRLF CRLF
    after the .replace) between each part's headers and its payload. Those
    blank lines were restored here — without them the parser reports
    "missing headers" and every success-path assertion fails.
    """

    def test_file_upload(self):
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        parse_multipart_form_data(b"1234", data, args, files)
        file = files["files"][0]
        self.assertEqual(file["filename"], "ab.txt")
        self.assertEqual(file["body"], b"Foo")

    def test_unquoted_names(self):
        # quotes are optional unless special characters are present
        data = b"""\
--1234
Content-Disposition: form-data; name=files; filename=ab.txt

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        parse_multipart_form_data(b"1234", data, args, files)
        file = files["files"][0]
        self.assertEqual(file["filename"], "ab.txt")
        self.assertEqual(file["body"], b"Foo")

    def test_special_filenames(self):
        # Filenames containing quotes, semicolons and backslashes must
        # round-trip through the quoted-string encoding.
        filenames = [
            "a;b.txt",
            'a"b.txt',
            'a";b.txt',
            'a;"b.txt',
            'a";";.txt',
            'a\\"b.txt',
            "a\\b.txt",
        ]
        for filename in filenames:
            logging.debug("trying filename %r", filename)
            str_data = """\
--1234
Content-Disposition: form-data; name="files"; filename="%s"

Foo
--1234--""" % filename.replace(
                "\\", "\\\\"
            ).replace(
                '"', '\\"'
            )
            data = utf8(str_data.replace("\n", "\r\n"))
            args, files = form_data_args()
            parse_multipart_form_data(b"1234", data, args, files)
            file = files["files"][0]
            self.assertEqual(file["filename"], filename)
            self.assertEqual(file["body"], b"Foo")

    def test_non_ascii_filename(self):
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"; filename*=UTF-8''%C3%A1b.txt

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        parse_multipart_form_data(b"1234", data, args, files)
        file = files["files"][0]
        # The RFC 5987 filename* parameter takes precedence over filename.
        self.assertEqual(file["filename"], u"áb.txt")
        self.assertEqual(file["body"], b"Foo")

    def test_boundary_starts_and_ends_with_quotes(self):
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        # A quoted boundary parameter must be accepted and unquoted.
        parse_multipart_form_data(b'"1234"', data, args, files)
        file = files["files"][0]
        self.assertEqual(file["filename"], "ab.txt")
        self.assertEqual(file["body"], b"Foo")

    def test_missing_headers(self):
        # This part deliberately has no header block at all (no CRLF CRLF),
        # which must be reported rather than crash.
        data = b"""\
--1234
Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        with ExpectLog(gen_log, "multipart/form-data missing headers"):
            parse_multipart_form_data(b"1234", data, args, files)
        self.assertEqual(files, {})

    def test_invalid_content_disposition(self):
        data = b"""\
--1234
Content-Disposition: invalid; name="files"; filename="ab.txt"

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        with ExpectLog(gen_log, "Invalid multipart/form-data"):
            parse_multipart_form_data(b"1234", data, args, files)
        self.assertEqual(files, {})

    def test_line_does_not_end_with_correct_line_break(self):
        # The payload runs straight into the final boundary with no CRLF.
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        with ExpectLog(gen_log, "Invalid multipart/form-data"):
            parse_multipart_form_data(b"1234", data, args, files)
        self.assertEqual(files, {})

    def test_content_disposition_header_without_name_parameter(self):
        data = b"""\
--1234
Content-Disposition: form-data; filename="ab.txt"

Foo
--1234--""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        with ExpectLog(gen_log, "multipart/form-data value missing name"):
            parse_multipart_form_data(b"1234", data, args, files)
        self.assertEqual(files, {})

    def test_data_after_final_boundary(self):
        # The spec requires that data after the final boundary be ignored.
        # http://www.w3.org/Protocols/rfc1341/7_2_Multipart.html
        # In practice, some libraries include an extra CRLF after the boundary.
        data = b"""\
--1234
Content-Disposition: form-data; name="files"; filename="ab.txt"

Foo
--1234--
""".replace(
            b"\n", b"\r\n"
        )
        args, files = form_data_args()
        parse_multipart_form_data(b"1234", data, args, files)
        file = files["files"][0]
        self.assertEqual(file["filename"], "ab.txt")
        self.assertEqual(file["body"], b"Foo")
class HTTPHeadersTest(unittest.TestCase):
    def test_multi_line(self):
        # Lines beginning with whitespace are appended to the previous line
        # with any leading whitespace replaced by a single space.
        # Note that while multi-line headers are a part of the HTTP spec,
        # their use is strongly discouraged.
        # NOTE: the continuation lines below must begin with whitespace —
        # the leading spaces were restored here because the parser treats
        # an unindented line as a new header, and the assertions below
        # ("bar baz", "qwer zxcv") require folding to occur.
        data = """\
Foo: bar
 baz
Asdf: qwer
\tzxcv
Foo: even
 more
 lines
""".replace(
            "\n", "\r\n"
        )
        headers = HTTPHeaders.parse(data)
        self.assertEqual(headers["asdf"], "qwer zxcv")
        self.assertEqual(headers.get_list("asdf"), ["qwer zxcv"])
        self.assertEqual(headers["Foo"], "bar baz,even more lines")
        self.assertEqual(headers.get_list("foo"), ["bar baz", "even more lines"])
        self.assertEqual(
            sorted(list(headers.get_all())),
            [("Asdf", "qwer zxcv"), ("Foo", "bar baz"), ("Foo", "even more lines")],
        )

    def test_malformed_continuation(self):
        # If the first line starts with whitespace, it's a
        # continuation line with nothing to continue, so reject it
        # (with a proper error).
        data = " Foo: bar"
        self.assertRaises(HTTPInputError, HTTPHeaders.parse, data)

    def test_unicode_newlines(self):
        # Ensure that only \r\n is recognized as a header separator, and not
        # the other newline-like unicode characters.
        # Characters that are likely to be problematic can be found in
        # http://unicode.org/standard/reports/tr13/tr13-5.html
        # and cpython's unicodeobject.c (which defines the implementation
        # of unicode_type.splitlines(), and uses a different list than TR13).
        newlines = [
            u"\u001b",  # VERTICAL TAB
            u"\u001c",  # FILE SEPARATOR
            u"\u001d",  # GROUP SEPARATOR
            u"\u001e",  # RECORD SEPARATOR
            u"\u0085",  # NEXT LINE
            u"\u2028",  # LINE SEPARATOR
            u"\u2029",  # PARAGRAPH SEPARATOR
        ]
        for newline in newlines:
            # Try the utf8 and latin1 representations of each newline
            for encoding in ["utf8", "latin1"]:
                try:
                    try:
                        encoded = newline.encode(encoding)
                    except UnicodeEncodeError:
                        # Some chars cannot be represented in latin1
                        continue
                    data = b"Cookie: foo=" + encoded + b"bar"
                    # parse() wants a native_str, so decode through latin1
                    # in the same way the real parser does.
                    headers = HTTPHeaders.parse(native_str(data.decode("latin1")))
                    expected = [
                        (
                            "Cookie",
                            "foo=" + native_str(encoded.decode("latin1")) + "bar",
                        )
                    ]
                    self.assertEqual(expected, list(headers.get_all()))
                except Exception:
                    # Identify which combination failed before re-raising.
                    gen_log.warning("failed while trying %r in %s", newline, encoding)
                    raise

    def test_optional_cr(self):
        # Both CRLF and LF should be accepted as separators. CR should not be
        # part of the data when followed by LF, but it is a normal char
        # otherwise (or should bare CR be an error?)
        headers = HTTPHeaders.parse("CRLF: crlf\r\nLF: lf\nCR: cr\rMore: more\r\n")
        self.assertEqual(
            sorted(headers.get_all()),
            [("Cr", "cr\rMore: more"), ("Crlf", "crlf"), ("Lf", "lf")],
        )

    def test_copy(self):
        all_pairs = [("A", "1"), ("A", "2"), ("B", "c")]
        h1 = HTTPHeaders()
        for k, v in all_pairs:
            h1.add(k, v)
        h2 = h1.copy()
        h3 = copy.copy(h1)
        h4 = copy.deepcopy(h1)
        for headers in [h1, h2, h3, h4]:
            # All the copies are identical, no matter how they were
            # constructed.
            self.assertEqual(list(sorted(headers.get_all())), all_pairs)
        for headers in [h2, h3, h4]:
            # Neither the dict or its member lists are reused.
            self.assertIsNot(headers, h1)
            self.assertIsNot(headers.get_list("A"), h1.get_list("A"))

    def test_pickle_roundtrip(self):
        headers = HTTPHeaders()
        headers.add("Set-Cookie", "a=b")
        headers.add("Set-Cookie", "c=d")
        headers.add("Content-Type", "text/html")
        pickled = pickle.dumps(headers)
        unpickled = pickle.loads(pickled)
        self.assertEqual(sorted(headers.get_all()), sorted(unpickled.get_all()))
        self.assertEqual(sorted(headers.items()), sorted(unpickled.items()))

    def test_setdefault(self):
        headers = HTTPHeaders()
        headers["foo"] = "bar"
        # If a value is present, setdefault returns it without changes.
        self.assertEqual(headers.setdefault("foo", "baz"), "bar")
        self.assertEqual(headers["foo"], "bar")
        # If a value is not present, setdefault sets it for future use.
        self.assertEqual(headers.setdefault("quux", "xyzzy"), "xyzzy")
        self.assertEqual(headers["quux"], "xyzzy")
        self.assertEqual(sorted(headers.get_all()), [("Foo", "bar"), ("Quux", "xyzzy")])

    def test_string(self):
        # str() output must itself be parseable back to an equal object.
        headers = HTTPHeaders()
        headers.add("Foo", "1")
        headers.add("Foo", "2")
        headers.add("Foo", "3")
        headers2 = HTTPHeaders.parse(str(headers))
        self.assertEqual(headers, headers2)
class FormatTimestampTest(unittest.TestCase):
    """format_timestamp accepts several input types; every one must render
    the same fixed HTTP date."""

    # Make sure that all the input types are supported.
    TIMESTAMP = 1359312200.503611
    EXPECTED = "Sun, 27 Jan 2013 18:43:20 GMT"

    def check(self, value):
        self.assertEqual(format_timestamp(value), self.EXPECTED)

    def test_unix_time_float(self):
        self.check(self.TIMESTAMP)

    def test_unix_time_int(self):
        self.check(int(self.TIMESTAMP))

    def test_struct_time(self):
        self.check(time.gmtime(self.TIMESTAMP))

    def test_time_tuple(self):
        tup = tuple(time.gmtime(self.TIMESTAMP))
        # A time tuple always carries exactly nine fields.
        self.assertEqual(len(tup), 9)
        self.check(tup)

    def test_datetime(self):
        self.check(datetime.datetime.utcfromtimestamp(self.TIMESTAMP))
# HTTPServerRequest is mainly tested incidentally to the server itself,
# but this tests the parts of the class that can be tested in isolation.
class HTTPServerRequestTest(unittest.TestCase):
    def test_default_constructor(self):
        # All parameters are formally optional, but uri is required
        # (and has been for some time). This test ensures that no
        # more required parameters slip in.
        HTTPServerRequest(uri="/")

    def test_body_is_a_byte_string(self):
        # The default body must be bytes, never str.
        # (Fixed local-variable typo: "requets" -> "request".)
        request = HTTPServerRequest(uri="/")
        self.assertIsInstance(request.body, bytes)

    def test_repr_does_not_contain_headers(self):
        # Header values (cookies, auth tokens) must not leak into logs
        # through repr().
        request = HTTPServerRequest(
            uri="/", headers=HTTPHeaders({"Canary": ["Coal Mine"]})
        )
        self.assertTrue("Canary" not in repr(request))
class ParseRequestStartLineTest(unittest.TestCase):
    METHOD = "GET"
    PATH = "/foo"
    VERSION = "HTTP/1.1"

    def test_parse_request_start_line(self):
        """The three space-separated fields round-trip through the parser."""
        start_line = "%s %s %s" % (self.METHOD, self.PATH, self.VERSION)
        parsed = parse_request_start_line(start_line)
        self.assertEqual(parsed.method, self.METHOD)
        self.assertEqual(parsed.path, self.PATH)
        self.assertEqual(parsed.version, self.VERSION)
class ParseCookieTest(unittest.TestCase):
    # These tests copied from Django:
    # https://github.com/django/django/pull/6277/commits/da810901ada1cae9fc1f018f879f11a7fb467b28
    def test_python_cookies(self):
        """
        Test cases copied from Python's Lib/test/test_http_cookies.py
        """
        self.assertEqual(
            parse_cookie("chips=ahoy; vienna=finger"),
            {"chips": "ahoy", "vienna": "finger"},
        )
        # Here parse_cookie() differs from Python's cookie parsing in that it
        # treats all semicolons as delimiters, even within quotes.
        self.assertEqual(
            parse_cookie('keebler="E=mc2; L=\\"Loves\\"; fudge=\\012;"'),
            {"keebler": '"E=mc2', "L": '\\"Loves\\"', "fudge": "\\012", "": '"'},
        )
        # Illegal cookies that have an '=' char in an unquoted value.
        self.assertEqual(parse_cookie("keebler=E=mc2"), {"keebler": "E=mc2"})
        # Cookies with ':' character in their name.
        self.assertEqual(
            parse_cookie("key:term=value:term"), {"key:term": "value:term"}
        )
        # Cookies with '[' and ']'.
        self.assertEqual(
            parse_cookie("a=b; c=[; d=r; f=h"), {"a": "b", "c": "[", "d": "r", "f": "h"}
        )

    def test_cookie_edgecases(self):
        # Cookies that RFC6265 allows.
        self.assertEqual(
            parse_cookie("a=b; Domain=example.com"), {"a": "b", "Domain": "example.com"}
        )
        # parse_cookie() has historically kept only the last cookie with the
        # same name.
        self.assertEqual(parse_cookie("a=b; h=i; a=c"), {"a": "c", "h": "i"})

    def test_invalid_cookies(self):
        """
        Cookie strings that go against RFC6265 but browsers will send if set
        via document.cookie.
        """
        # Chunks without an equals sign appear as unnamed values per
        # https://bugzilla.mozilla.org/show_bug.cgi?id=169091
        self.assertIn(
            "django_language",
            parse_cookie("abc=def; unnamed; django_language=en").keys(),
        )
        # Even a double quote may be an unnamed value.
        self.assertEqual(parse_cookie('a=b; "; c=d'), {"a": "b", "": '"', "c": "d"})
        # Spaces in names and values, and an equals sign in values.
        self.assertEqual(
            parse_cookie("a b c=d e = f; gh=i"), {"a b c": "d e = f", "gh": "i"}
        )
        # More characters the spec forbids.
        self.assertEqual(
            parse_cookie('a b,c<>@:/[]?{}=d " =e,f g'),
            {"a b,c<>@:/[]?{}": 'd " =e,f g'},
        )
        # Unicode characters. The spec only allows ASCII.
        self.assertEqual(
            parse_cookie("saint=André Bessette"),
            {"saint": native_str("André Bessette")},
        )
        # Browsers don't send extra whitespace or semicolons in Cookie headers,
        # but parse_cookie() should parse whitespace the same way
        # document.cookie parses whitespace.
        self.assertEqual(
            parse_cookie("  =  b  ;  ;  =  ;   c  =  ;  "), {"": "b", "c": ""}
        )

View file

@ -0,0 +1,66 @@
# flake8: noqa
import subprocess
import sys
import unittest
_import_everything = b"""
# The event loop is not fork-safe, and it's easy to initialize an asyncio.Future
# at startup, which in turn creates the default event loop and prevents forking.
# Explicitly disallow the default event loop so that an error will be raised
# if something tries to touch it.
import asyncio
asyncio.set_event_loop(None)
import tornado.auth
import tornado.autoreload
import tornado.concurrent
import tornado.escape
import tornado.gen
import tornado.http1connection
import tornado.httpclient
import tornado.httpserver
import tornado.httputil
import tornado.ioloop
import tornado.iostream
import tornado.locale
import tornado.log
import tornado.netutil
import tornado.options
import tornado.process
import tornado.simple_httpclient
import tornado.tcpserver
import tornado.tcpclient
import tornado.template
import tornado.testing
import tornado.util
import tornado.web
import tornado.websocket
import tornado.wsgi
try:
import pycurl
except ImportError:
pass
else:
import tornado.curl_httpclient
"""
class ImportTest(unittest.TestCase):
    def test_import_everything(self):
        # Test that all Tornado modules can be imported without side effects,
        # specifically without initializing the default asyncio event loop.
        # Since we can't tell which modules may have already been imported
        # in our process, do it in a subprocess for a clean slate.
        proc = subprocess.Popen([sys.executable], stdin=subprocess.PIPE)
        proc.communicate(_import_everything)
        self.assertEqual(proc.returncode, 0)

    def test_import_aliases(self):
        # Ensure we don't delete formerly-documented aliases accidentally.
        import tornado.ioloop
        import tornado.gen
        import tornado.util

        self.assertIs(tornado.ioloop.TimeoutError, tornado.util.TimeoutError)
        self.assertIs(tornado.gen.TimeoutError, tornado.util.TimeoutError)

View file

@ -0,0 +1,725 @@
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
import contextlib
import datetime
import functools
import socket
import subprocess
import sys
import threading
import time
import types
from unittest import mock
import unittest
from tornado.escape import native_str
from tornado import gen
from tornado.ioloop import IOLoop, TimeoutError, PeriodicCallback
from tornado.log import app_log
from tornado.testing import AsyncTestCase, bind_unused_port, ExpectLog, gen_test
from tornado.test.util import skipIfNonUnix, skipOnTravis
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class TestIOLoop(AsyncTestCase):
def test_add_callback_return_sequence(self):
    # A callback returning {} or [] shouldn't spin the CPU, see Issue #1803.
    self.calls = 0

    loop = self.io_loop
    test = self
    old_add_callback = loop.add_callback

    def add_callback(self, callback, *args, **kwargs):
        # Count every scheduled callback, then delegate to the real method.
        test.calls += 1
        old_add_callback(callback, *args, **kwargs)

    loop.add_callback = types.MethodType(add_callback, loop)  # type: ignore
    loop.add_callback(lambda: {})  # type: ignore
    loop.add_callback(lambda: [])  # type: ignore
    loop.add_timeout(datetime.timedelta(milliseconds=50), loop.stop)
    loop.start()
    # If the loop had spun, the counter would be in the hundreds or more.
    self.assertLess(self.calls, 10)
@skipOnTravis
def test_add_callback_wakeup(self):
    # Make sure that add_callback from inside a running IOLoop
    # wakes up the IOLoop immediately instead of waiting for a timeout.
    def callback():
        self.called = True
        self.stop()

    def schedule_callback():
        self.called = False
        self.io_loop.add_callback(callback)
        # Store away the time so we can check if we woke up immediately
        self.start_time = time.time()

    self.io_loop.add_timeout(self.io_loop.time(), schedule_callback)
    self.wait()
    # places=2 tolerates ~10ms of scheduling jitter.
    self.assertAlmostEqual(time.time(), self.start_time, places=2)
    self.assertTrue(self.called)
@skipOnTravis
def test_add_callback_wakeup_other_thread(self):
def target():
# sleep a bit to let the ioloop go into its poll loop
time.sleep(0.01)
self.stop_time = time.time()
self.io_loop.add_callback(self.stop)
thread = threading.Thread(target=target)
self.io_loop.add_callback(thread.start)
self.wait()
delta = time.time() - self.stop_time
self.assertLess(delta, 0.1)
thread.join()
def test_add_timeout_timedelta(self):
self.io_loop.add_timeout(datetime.timedelta(microseconds=1), self.stop)
self.wait()
def test_multiple_add(self):
sock, port = bind_unused_port()
try:
self.io_loop.add_handler(
sock.fileno(), lambda fd, events: None, IOLoop.READ
)
# Attempting to add the same handler twice fails
# (with a platform-dependent exception)
self.assertRaises(
Exception,
self.io_loop.add_handler,
sock.fileno(),
lambda fd, events: None,
IOLoop.READ,
)
finally:
self.io_loop.remove_handler(sock.fileno())
sock.close()
def test_remove_without_add(self):
# remove_handler should not throw an exception if called on an fd
# was never added.
sock, port = bind_unused_port()
try:
self.io_loop.remove_handler(sock.fileno())
finally:
sock.close()
def test_add_callback_from_signal(self):
# cheat a little bit and just run this normally, since we can't
# easily simulate the races that happen with real signal handlers
self.io_loop.add_callback_from_signal(self.stop)
self.wait()
def test_add_callback_from_signal_other_thread(self):
# Very crude test, just to make sure that we cover this case.
# This also happens to be the first test where we run an IOLoop in
# a non-main thread.
other_ioloop = IOLoop()
thread = threading.Thread(target=other_ioloop.start)
thread.start()
other_ioloop.add_callback_from_signal(other_ioloop.stop)
thread.join()
other_ioloop.close()
def test_add_callback_while_closing(self):
# add_callback should not fail if it races with another thread
# closing the IOLoop. The callbacks are dropped silently
# without executing.
closing = threading.Event()
def target():
other_ioloop.add_callback(other_ioloop.stop)
other_ioloop.start()
closing.set()
other_ioloop.close(all_fds=True)
other_ioloop = IOLoop()
thread = threading.Thread(target=target)
thread.start()
closing.wait()
for i in range(1000):
other_ioloop.add_callback(lambda: None)
@skipIfNonUnix # just because socketpair is so convenient
def test_read_while_writeable(self):
# Ensure that write events don't come in while we're waiting for
# a read and haven't asked for writeability. (the reverse is
# difficult to test for)
client, server = socket.socketpair()
try:
def handler(fd, events):
self.assertEqual(events, IOLoop.READ)
self.stop()
self.io_loop.add_handler(client.fileno(), handler, IOLoop.READ)
self.io_loop.add_timeout(
self.io_loop.time() + 0.01, functools.partial(server.send, b"asdf") # type: ignore
)
self.wait()
self.io_loop.remove_handler(client.fileno())
finally:
client.close()
server.close()
def test_remove_timeout_after_fire(self):
# It is not an error to call remove_timeout after it has run.
handle = self.io_loop.add_timeout(self.io_loop.time(), self.stop)
self.wait()
self.io_loop.remove_timeout(handle)
def test_remove_timeout_cleanup(self):
# Add and remove enough callbacks to trigger cleanup.
# Not a very thorough test, but it ensures that the cleanup code
# gets executed and doesn't blow up. This test is only really useful
# on PollIOLoop subclasses, but it should run silently on any
# implementation.
for i in range(2000):
timeout = self.io_loop.add_timeout(self.io_loop.time() + 3600, lambda: None)
self.io_loop.remove_timeout(timeout)
# HACK: wait two IOLoop iterations for the GC to happen.
self.io_loop.add_callback(lambda: self.io_loop.add_callback(self.stop))
self.wait()
def test_remove_timeout_from_timeout(self):
calls = [False, False]
# Schedule several callbacks and wait for them all to come due at once.
# t2 should be cancelled by t1, even though it is already scheduled to
# be run before the ioloop even looks at it.
now = self.io_loop.time()
def t1():
calls[0] = True
self.io_loop.remove_timeout(t2_handle)
self.io_loop.add_timeout(now + 0.01, t1)
def t2():
calls[1] = True
t2_handle = self.io_loop.add_timeout(now + 0.02, t2)
self.io_loop.add_timeout(now + 0.03, self.stop)
time.sleep(0.03)
self.wait()
self.assertEqual(calls, [True, False])
def test_timeout_with_arguments(self):
# This tests that all the timeout methods pass through *args correctly.
results = [] # type: List[int]
self.io_loop.add_timeout(self.io_loop.time(), results.append, 1)
self.io_loop.add_timeout(datetime.timedelta(seconds=0), results.append, 2)
self.io_loop.call_at(self.io_loop.time(), results.append, 3)
self.io_loop.call_later(0, results.append, 4)
self.io_loop.call_later(0, self.stop)
self.wait()
# The asyncio event loop does not guarantee the order of these
# callbacks.
self.assertEqual(sorted(results), [1, 2, 3, 4])
def test_add_timeout_return(self):
# All the timeout methods return non-None handles that can be
# passed to remove_timeout.
handle = self.io_loop.add_timeout(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_at_return(self):
handle = self.io_loop.call_at(self.io_loop.time(), lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_call_later_return(self):
handle = self.io_loop.call_later(0, lambda: None)
self.assertFalse(handle is None)
self.io_loop.remove_timeout(handle)
def test_close_file_object(self):
"""When a file object is used instead of a numeric file descriptor,
the object should be closed (by IOLoop.close(all_fds=True),
not just the fd.
"""
# Use a socket since they are supported by IOLoop on all platforms.
# Unfortunately, sockets don't support the .closed attribute for
# inspecting their close status, so we must use a wrapper.
class SocketWrapper(object):
def __init__(self, sockobj):
self.sockobj = sockobj
self.closed = False
def fileno(self):
return self.sockobj.fileno()
def close(self):
self.closed = True
self.sockobj.close()
sockobj, port = bind_unused_port()
socket_wrapper = SocketWrapper(sockobj)
io_loop = IOLoop()
io_loop.add_handler(socket_wrapper, lambda fd, events: None, IOLoop.READ)
io_loop.close(all_fds=True)
self.assertTrue(socket_wrapper.closed)
def test_handler_callback_file_object(self):
"""The handler callback receives the same fd object it passed in."""
server_sock, port = bind_unused_port()
fds = []
def handle_connection(fd, events):
fds.append(fd)
conn, addr = server_sock.accept()
conn.close()
self.stop()
self.io_loop.add_handler(server_sock, handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(("127.0.0.1", port))
self.wait()
self.io_loop.remove_handler(server_sock)
self.io_loop.add_handler(server_sock.fileno(), handle_connection, IOLoop.READ)
with contextlib.closing(socket.socket()) as client_sock:
client_sock.connect(("127.0.0.1", port))
self.wait()
self.assertIs(fds[0], server_sock)
self.assertEqual(fds[1], server_sock.fileno())
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_mixed_fd_fileobj(self):
server_sock, port = bind_unused_port()
def f(fd, events):
pass
self.io_loop.add_handler(server_sock, f, IOLoop.READ)
with self.assertRaises(Exception):
# The exact error is unspecified - some implementations use
# IOError, others use ValueError.
self.io_loop.add_handler(server_sock.fileno(), f, IOLoop.READ)
self.io_loop.remove_handler(server_sock.fileno())
server_sock.close()
def test_reentrant(self):
"""Calling start() twice should raise an error, not deadlock."""
returned_from_start = [False]
got_exception = [False]
def callback():
try:
self.io_loop.start()
returned_from_start[0] = True
except Exception:
got_exception[0] = True
self.stop()
self.io_loop.add_callback(callback)
self.wait()
self.assertTrue(got_exception[0])
self.assertFalse(returned_from_start[0])
def test_exception_logging(self):
"""Uncaught exceptions get logged by the IOLoop."""
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_future(self):
"""The IOLoop examines exceptions from Futures and logs them."""
@gen.coroutine
def callback():
self.io_loop.add_callback(self.stop)
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_exception_logging_native_coro(self):
"""The IOLoop examines exceptions from awaitables and logs them."""
async def callback():
# Stop the IOLoop two iterations after raising an exception
# to give the exception time to be logged.
self.io_loop.add_callback(self.io_loop.add_callback, self.stop)
1 / 0
self.io_loop.add_callback(callback)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
def test_spawn_callback(self):
# Both add_callback and spawn_callback run directly on the IOLoop,
# so their errors are logged without stopping the test.
self.io_loop.add_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
# A spawned callback is run directly on the IOLoop, so it will be
# logged without stopping the test.
self.io_loop.spawn_callback(lambda: 1 / 0)
self.io_loop.add_callback(self.stop)
with ExpectLog(app_log, "Exception in callback"):
self.wait()
@skipIfNonUnix
def test_remove_handler_from_handler(self):
# Create two sockets with simultaneous read events.
client, server = socket.socketpair()
try:
client.send(b"abc")
server.send(b"abc")
# After reading from one fd, remove the other from the IOLoop.
chunks = []
def handle_read(fd, events):
chunks.append(fd.recv(1024))
if fd is client:
self.io_loop.remove_handler(server)
else:
self.io_loop.remove_handler(client)
self.io_loop.add_handler(client, handle_read, self.io_loop.READ)
self.io_loop.add_handler(server, handle_read, self.io_loop.READ)
self.io_loop.call_later(0.1, self.stop)
self.wait()
# Only one fd was read; the other was cleanly removed.
self.assertEqual(chunks, [b"abc"])
finally:
client.close()
server.close()
@skipIfNonUnix
@gen_test
def test_init_close_race(self):
# Regression test for #2367
#
# Skipped on windows because of what looks like a bug in the
# proactor event loop when started and stopped on non-main
# threads.
def f():
for i in range(10):
loop = IOLoop()
loop.close()
yield gen.multi([self.io_loop.run_in_executor(None, f) for i in range(2)])
# Deliberately not a subclass of AsyncTestCase so the IOLoop isn't
# automatically set as current.
class TestIOLoopCurrent(unittest.TestCase):
    """Tests for IOLoop.current() bookkeeping.

    Deliberately not an AsyncTestCase, so no IOLoop is installed as
    current before each test runs.
    """

    def setUp(self):
        self.io_loop = None  # type: typing.Optional[IOLoop]
        IOLoop.clear_current()

    def tearDown(self):
        if self.io_loop is None:
            return
        self.io_loop.close()

    def test_default_current(self):
        # Constructing the first IOLoop with default arguments installs
        # it as the current loop.
        self.io_loop = IOLoop()
        self.assertIs(self.io_loop, IOLoop.current())
        # Further loops can be constructed without displacing it.
        extra_loop = IOLoop()
        self.assertIs(self.io_loop, IOLoop.current())
        extra_loop.close()

    def test_non_current(self):
        self.io_loop = IOLoop(make_current=False)
        # With make_current=False the new loop is not installed up front.
        self.assertIsNone(IOLoop.current(instance=False))

        # start() makes the loop current and stop() un-makes it; the
        # cycle must be repeatable.
        for _ in range(3):

            def capture_and_stop():
                self.current_io_loop = IOLoop.current()
                assert self.io_loop is not None
                self.io_loop.stop()

            self.io_loop.add_callback(capture_and_stop)
            self.io_loop.start()
            self.assertIs(self.current_io_loop, self.io_loop)
        # Once stopped, the loop is no longer current.
        self.assertIsNone(IOLoop.current(instance=False))

    def test_force_current(self):
        self.io_loop = IOLoop(make_current=True)
        self.assertIs(self.io_loop, IOLoop.current())
        # A second make_current=True construction must fail...
        with self.assertRaises(RuntimeError):
            IOLoop(make_current=True)
        # ...and the failed construction leaves current() untouched.
        self.assertIs(self.io_loop, IOLoop.current())
class TestIOLoopCurrentAsync(AsyncTestCase):
    @gen_test
    def test_clear_without_current(self):
        # clear_current must be a harmless no-op when no IOLoop is
        # current. Run it on a worker thread whose threading.Local
        # state is still pristine.
        with ThreadPoolExecutor(1) as executor:
            yield executor.submit(IOLoop.clear_current)
class TestIOLoopFutures(AsyncTestCase):
    """Integration between IOLoop and concurrent.futures executors."""

    def test_add_future_threads(self):
        with futures.ThreadPoolExecutor(1) as executor:

            def noop():
                pass

            self.io_loop.add_future(
                executor.submit(noop), lambda f: self.stop(f)
            )
            resolved = self.wait()
            self.assertTrue(resolved.done())
            self.assertTrue(resolved.result() is None)

    @gen_test
    def test_run_in_executor_gen(self):
        event1 = threading.Event()
        event2 = threading.Event()

        def sync_func(self_event, other_event):
            # Signal our own event, then block until the peer signals
            # theirs; this deadlocks unless both run concurrently.
            self_event.set()
            other_event.wait()
            # The return value does nothing by itself; it is passed
            # through so the final assertion can check plumbing.
            return self_event

        loop = IOLoop.current()
        results = yield [
            loop.run_in_executor(None, sync_func, event1, event2),
            loop.run_in_executor(None, sync_func, event2, event1),
        ]
        self.assertEqual([event1, event2], results)

    @gen_test
    def test_run_in_executor_native(self):
        event1 = threading.Event()
        event2 = threading.Event()

        def sync_func(self_event, other_event):
            self_event.set()
            other_event.wait()
            return self_event

        # Wrap in a native coroutine so run_in_executor's result is
        # exercised through 'await' rather than gen.coroutine alone
        # (passing the raw concurrent future would only test the latter).
        async def run_wrapped(self_event, other_event):
            return await IOLoop.current().run_in_executor(
                None, sync_func, self_event, other_event
            )

        results = yield [run_wrapped(event1, event2), run_wrapped(event2, event1)]
        self.assertEqual([event1, event2], results)

    @gen_test
    def test_set_default_executor(self):
        submissions = [0]

        class CountingExecutor(futures.ThreadPoolExecutor):
            def submit(self, func, *args):
                submissions[0] += 1
                return super().submit(func, *args)

        finished = threading.Event()

        loop = IOLoop.current()
        loop.set_default_executor(CountingExecutor(1))
        # Passing None must route through the configured default executor.
        yield loop.run_in_executor(None, finished.set)
        self.assertEqual(1, submissions[0])
        self.assertTrue(finished.is_set())
class TestIOLoopRunSync(unittest.TestCase):
    """IOLoop.run_sync with plain callables, coroutines, and timeouts."""

    def setUp(self):
        self.io_loop = IOLoop()

    def tearDown(self):
        self.io_loop.close()

    def test_sync_result(self):
        # A bare (non-awaitable) return value is rejected.
        with self.assertRaises(gen.BadYieldError):
            self.io_loop.run_sync(lambda: 42)

    def test_sync_exception(self):
        # An exception from a synchronous callable propagates out.
        with self.assertRaises(ZeroDivisionError):
            self.io_loop.run_sync(lambda: 1 / 0)

    def test_async_result(self):
        @gen.coroutine
        def answer():
            yield gen.moment
            raise gen.Return(42)

        self.assertEqual(self.io_loop.run_sync(answer), 42)

    def test_async_exception(self):
        @gen.coroutine
        def fail():
            yield gen.moment
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            self.io_loop.run_sync(fail)

    def test_current(self):
        def check_current():
            # While run_sync is active, this loop is the current one.
            self.assertIs(IOLoop.current(), self.io_loop)

        self.io_loop.run_sync(check_current)

    def test_timeout(self):
        @gen.coroutine
        def sleeper():
            yield gen.sleep(1)

        self.assertRaises(TimeoutError, self.io_loop.run_sync, sleeper, timeout=0.01)

    def test_native_coroutine(self):
        @gen.coroutine
        def decorated():
            yield gen.moment

        async def native():
            await decorated()

        self.io_loop.run_sync(native)
class TestPeriodicCallbackMath(unittest.TestCase):
    """Scheduling arithmetic of PeriodicCallback, driven without a real
    IOLoop by poking its private _next_timeout/_update_next state."""

    def simulate_calls(self, pc, durations):
        """Simulate a series of calls to the PeriodicCallback.

        Pass a list of call durations in seconds (negative values
        work to simulate clock adjustments during the call, or more or
        less equivalently, between calls). This method returns the
        times at which each call would be made.
        """
        calls = []
        now = 1000
        pc._next_timeout = now
        for d in durations:
            pc._update_next(now)
            calls.append(pc._next_timeout)
            now = pc._next_timeout + d
        return calls

    def dummy(self):
        # Placeholder callback; it is never actually invoked here.
        pass

    def test_basic(self):
        # 10000ms period, zero-duration calls: one call every 10s.
        pc = PeriodicCallback(self.dummy, 10000)
        self.assertEqual(
            self.simulate_calls(pc, [0] * 5), [1010, 1020, 1030, 1040, 1050]
        )

    def test_overrun(self):
        # If a call runs for too long, we skip entire cycles to get
        # back on schedule.
        call_durations = [9, 9, 10, 11, 20, 20, 35, 35, 0, 0, 0]
        expected = [
            1010,
            1020,
            1030,  # first 3 calls on schedule
            1050,
            1070,  # next 2 delayed one cycle
            1100,
            1130,  # next 2 delayed 2 cycles
            1170,
            1210,  # next 2 delayed 3 cycles
            1220,
            1230,  # then back on schedule.
        ]
        pc = PeriodicCallback(self.dummy, 10000)
        self.assertEqual(self.simulate_calls(pc, call_durations), expected)

    def test_clock_backwards(self):
        pc = PeriodicCallback(self.dummy, 10000)
        # Backwards jumps are ignored, potentially resulting in a
        # slightly slow schedule (although we assume that when
        # time.time() and time.monotonic() are different, time.time()
        # is getting adjusted by NTP and is therefore more accurate)
        self.assertEqual(
            self.simulate_calls(pc, [-2, -1, -3, -2, 0]), [1010, 1020, 1030, 1040, 1050]
        )
        # For big jumps, we should perhaps alter the schedule, but we
        # don't currently. This trace shows that we run callbacks
        # every 10s of time.time(), but the first and second calls are
        # 110s of real time apart because the backwards jump is
        # ignored.
        self.assertEqual(self.simulate_calls(pc, [-100, 0, 0]), [1010, 1020, 1030])

    def test_jitter(self):
        random_times = [0.5, 1, 0, 0.75]
        expected = [1010, 1022.5, 1030, 1041.25]
        call_durations = [0] * len(random_times)
        pc = PeriodicCallback(self.dummy, 10000, jitter=0.5)

        def mock_random():
            # Deterministic stand-in for random.random().
            return random_times.pop(0)

        with mock.patch("random.random", mock_random):
            self.assertEqual(self.simulate_calls(pc, call_durations), expected)
class TestIOLoopConfiguration(unittest.TestCase):
    """IOLoop implementation selection, checked in fresh subprocesses."""

    def run_python(self, *statements):
        # Execute the statements in a clean interpreter so module state
        # from this process cannot influence the result.
        preamble = [
            "from tornado.ioloop import IOLoop",
            "classname = lambda x: x.__class__.__name__",
        ]
        script = "; ".join(preamble + list(statements))
        output = subprocess.check_output([sys.executable, "-c", script])
        return native_str(output).strip()

    def test_default(self):
        # When asyncio is available, it is used by default.
        self.assertEqual(
            self.run_python("print(classname(IOLoop.current()))"), "AsyncIOMainLoop"
        )
        self.assertEqual(self.run_python("print(classname(IOLoop()))"), "AsyncIOLoop")

    def test_asyncio(self):
        result = self.run_python(
            'IOLoop.configure("tornado.platform.asyncio.AsyncIOLoop")',
            "print(classname(IOLoop.current()))",
        )
        self.assertEqual(result, "AsyncIOMainLoop")

    def test_asyncio_main(self):
        result = self.run_python(
            "from tornado.platform.asyncio import AsyncIOMainLoop",
            "AsyncIOMainLoop().install()",
            "print(classname(IOLoop.current()))",
        )
        self.assertEqual(result, "AsyncIOMainLoop")
# Allow running this test module directly (python -m ... or as a script).
if __name__ == "__main__":
    unittest.main()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,151 @@
import datetime
import os
import shutil
import tempfile
import unittest
import tornado.locale
from tornado.escape import utf8, to_unicode
from tornado.util import unicode_type
class TranslationLoaderTest(unittest.TestCase):
    # TODO: less hacky way to get isolated tests
    # Module-level state in tornado.locale that must be saved in setUp and
    # restored in tearDown so the loader tests do not leak into each other.
    SAVE_VARS = ["_translations", "_supported_locales", "_use_gettext"]

    def clear_locale_cache(self):
        # Drop memoized Locale instances so each test loads fresh data.
        tornado.locale.Locale._cache = {}

    def setUp(self):
        self.saved = {}  # type: dict
        for var in TranslationLoaderTest.SAVE_VARS:
            self.saved[var] = getattr(tornado.locale, var)
        self.clear_locale_cache()

    def tearDown(self):
        for k, v in self.saved.items():
            setattr(tornado.locale, k, v)
        self.clear_locale_cache()

    def test_csv(self):
        tornado.locale.load_translations(
            os.path.join(os.path.dirname(__file__), "csv_translations")
        )
        locale = tornado.locale.get("fr_FR")
        self.assertTrue(isinstance(locale, tornado.locale.CSVLocale))
        self.assertEqual(locale.translate("school"), u"\u00e9cole")

    def test_csv_bom(self):
        with open(
            os.path.join(os.path.dirname(__file__), "csv_translations", "fr_FR.csv"),
            "rb",
        ) as f:
            char_data = to_unicode(f.read())
        # Re-encode our input data (which is utf-8 without BOM) in
        # encodings that use the BOM and ensure that we can still load
        # it. Note that utf-16-le and utf-16-be do not write a BOM,
        # so we only test whichever variant is native to our platform.
        for encoding in ["utf-8-sig", "utf-16"]:
            tmpdir = tempfile.mkdtemp()
            try:
                with open(os.path.join(tmpdir, "fr_FR.csv"), "wb") as f:
                    f.write(char_data.encode(encoding))
                tornado.locale.load_translations(tmpdir)
                locale = tornado.locale.get("fr_FR")
                self.assertIsInstance(locale, tornado.locale.CSVLocale)
                self.assertEqual(locale.translate("school"), u"\u00e9cole")
            finally:
                shutil.rmtree(tmpdir)

    def test_gettext(self):
        tornado.locale.load_gettext_translations(
            os.path.join(os.path.dirname(__file__), "gettext_translations"),
            "tornado_test",
        )
        locale = tornado.locale.get("fr_FR")
        self.assertTrue(isinstance(locale, tornado.locale.GettextLocale))
        self.assertEqual(locale.translate("school"), u"\u00e9cole")
        # pgettext disambiguates the same message by context, with
        # optional plural handling.
        self.assertEqual(locale.pgettext("law", "right"), u"le droit")
        self.assertEqual(locale.pgettext("good", "right"), u"le bien")
        self.assertEqual(
            locale.pgettext("organization", "club", "clubs", 1), u"le club"
        )
        self.assertEqual(
            locale.pgettext("organization", "club", "clubs", 2), u"les clubs"
        )
        self.assertEqual(locale.pgettext("stick", "club", "clubs", 1), u"le b\xe2ton")
        self.assertEqual(locale.pgettext("stick", "club", "clubs", 2), u"les b\xe2tons")
class LocaleDataTest(unittest.TestCase):
    """Sanity checks on the bundled LOCALE_NAMES table."""

    def test_non_ascii_name(self):
        # The display name for es_LA contains a non-ASCII character and
        # must round-trip both as text and as utf-8 bytes.
        display_name = tornado.locale.LOCALE_NAMES["es_LA"]["name"]
        self.assertIsInstance(display_name, unicode_type)
        self.assertEqual(display_name, u"Espa\u00f1ol")
        self.assertEqual(utf8(display_name), b"Espa\xc3\xb1ol")
class EnglishTest(unittest.TestCase):
    """Formatting behavior of the default en_US locale."""

    def test_format_date(self):
        locale = tornado.locale.get("en_US")
        date = datetime.datetime(2013, 4, 28, 18, 35)
        self.assertEqual(
            locale.format_date(date, full_format=True), "April 28, 2013 at 6:35 pm"
        )

        # Relative formats are computed against the current time.
        now = datetime.datetime.utcnow()

        self.assertEqual(
            locale.format_date(now - datetime.timedelta(seconds=2), full_format=False),
            "2 seconds ago",
        )
        self.assertEqual(
            locale.format_date(now - datetime.timedelta(minutes=2), full_format=False),
            "2 minutes ago",
        )
        self.assertEqual(
            locale.format_date(now - datetime.timedelta(hours=2), full_format=False),
            "2 hours ago",
        )

        self.assertEqual(
            locale.format_date(
                now - datetime.timedelta(days=1), full_format=False, shorter=True
            ),
            "yesterday",
        )

        # Within the last week: weekday name.
        date = now - datetime.timedelta(days=2)
        self.assertEqual(
            locale.format_date(date, full_format=False, shorter=True),
            locale._weekdays[date.weekday()],
        )

        # Same year but older than a week: "Month day".
        date = now - datetime.timedelta(days=300)
        self.assertEqual(
            locale.format_date(date, full_format=False, shorter=True),
            "%s %d" % (locale._months[date.month - 1], date.day),
        )

        # Previous years include the year: "Month day, year".
        date = now - datetime.timedelta(days=500)
        self.assertEqual(
            locale.format_date(date, full_format=False, shorter=True),
            "%s %d, %d" % (locale._months[date.month - 1], date.day, date.year),
        )

    def test_friendly_number(self):
        locale = tornado.locale.get("en_US")
        self.assertEqual(locale.friendly_number(1000000), "1,000,000")

    def test_list(self):
        locale = tornado.locale.get("en_US")
        self.assertEqual(locale.list([]), "")
        self.assertEqual(locale.list(["A"]), "A")
        self.assertEqual(locale.list(["A", "B"]), "A and B")
        self.assertEqual(locale.list(["A", "B", "C"]), "A, B and C")

    def test_format_day(self):
        locale = tornado.locale.get("en_US")
        date = datetime.datetime(2013, 4, 28, 18, 35)
        self.assertEqual(locale.format_day(date=date, dow=True), "Sunday, April 28")
        self.assertEqual(locale.format_day(date=date, dow=False), "April 28")

View file

@ -0,0 +1,535 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
from datetime import timedelta
import typing # noqa: F401
import unittest
from tornado import gen, locks
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
class ConditionTest(AsyncTestCase):
    """Tests for tornado.locks.Condition.

    Fix: `test_notify_n_with_timeout` contained the same assertion twice in
    a row (`assertEqual(["timeout", 0, 2], self.history)`); the redundant
    duplicate has been removed. No behavioral change.
    """

    def setUp(self):
        super().setUp()
        # Records the order in which waiters resolve ("timeout" or the key).
        self.history = []  # type: typing.List[typing.Union[int, str]]

    def record_done(self, future, key):
        """Record the resolution of a Future returned by Condition.wait."""

        def callback(_):
            if not future.result():
                # wait() resolved to False, meaning it timed out.
                self.history.append("timeout")
            else:
                self.history.append(key)

        future.add_done_callback(callback)

    def loop_briefly(self):
        """Run all queued callbacks on the IOLoop.

        In these tests, this method is used after calling notify() to
        preserve the pre-5.0 behavior in which callbacks ran
        synchronously.
        """
        self.io_loop.add_callback(self.stop)
        self.wait()

    def test_repr(self):
        c = locks.Condition()
        self.assertIn("Condition", repr(c))
        self.assertNotIn("waiters", repr(c))
        c.wait()
        self.assertIn("waiters", repr(c))

    @gen_test
    def test_notify(self):
        c = locks.Condition()
        self.io_loop.call_later(0.01, c.notify)
        yield c.wait()

    def test_notify_1(self):
        # Each notify(1) wakes exactly one waiter, in registration order.
        c = locks.Condition()
        self.record_done(c.wait(), "wait1")
        self.record_done(c.wait(), "wait2")
        c.notify(1)
        self.loop_briefly()
        self.history.append("notify1")
        c.notify(1)
        self.loop_briefly()
        self.history.append("notify2")
        self.assertEqual(["wait1", "notify1", "wait2", "notify2"], self.history)

    def test_notify_n(self):
        c = locks.Condition()
        for i in range(6):
            self.record_done(c.wait(), i)

        c.notify(3)
        self.loop_briefly()

        # Callbacks execute in the order they were registered.
        self.assertEqual(list(range(3)), self.history)
        c.notify(1)
        self.loop_briefly()
        self.assertEqual(list(range(4)), self.history)
        c.notify(2)
        self.loop_briefly()
        self.assertEqual(list(range(6)), self.history)

    def test_notify_all(self):
        c = locks.Condition()
        for i in range(4):
            self.record_done(c.wait(), i)

        c.notify_all()
        self.loop_briefly()
        self.history.append("notify_all")

        # Callbacks execute in the order they were registered.
        self.assertEqual(list(range(4)) + ["notify_all"], self.history)  # type: ignore

    @gen_test
    def test_wait_timeout(self):
        c = locks.Condition()
        wait = c.wait(timedelta(seconds=0.01))
        self.io_loop.call_later(0.02, c.notify)  # Too late.
        yield gen.sleep(0.03)
        # wait() resolves to False on timeout rather than raising.
        self.assertFalse((yield wait))

    @gen_test
    def test_wait_timeout_preempted(self):
        c = locks.Condition()

        # This fires before the wait times out.
        self.io_loop.call_later(0.01, c.notify)
        wait = c.wait(timedelta(seconds=0.02))
        yield gen.sleep(0.03)
        yield wait  # No TimeoutError.

    @gen_test
    def test_notify_n_with_timeout(self):
        # Register callbacks 0, 1, 2, and 3. Callback 1 has a timeout.
        # Wait for that timeout to expire, then do notify(2) and make
        # sure everyone runs. Verifies that a timed-out callback does
        # not count against the 'n' argument to notify().
        c = locks.Condition()
        self.record_done(c.wait(), 0)
        self.record_done(c.wait(timedelta(seconds=0.01)), 1)
        self.record_done(c.wait(), 2)
        self.record_done(c.wait(), 3)

        # Wait for callback 1 to time out.
        yield gen.sleep(0.02)
        self.assertEqual(["timeout"], self.history)

        c.notify(2)
        yield gen.sleep(0.01)
        # Waiters 0 and 2 were woken: the timed-out waiter 1 did not
        # consume one of notify(2)'s slots.
        self.assertEqual(["timeout", 0, 2], self.history)

        c.notify()
        yield
        self.assertEqual(["timeout", 0, 2, 3], self.history)

    @gen_test
    def test_notify_all_with_timeout(self):
        c = locks.Condition()
        self.record_done(c.wait(), 0)
        self.record_done(c.wait(timedelta(seconds=0.01)), 1)
        self.record_done(c.wait(), 2)

        # Wait for callback 1 to time out.
        yield gen.sleep(0.02)
        self.assertEqual(["timeout"], self.history)

        c.notify_all()
        yield
        self.assertEqual(["timeout", 0, 2], self.history)

    @gen_test
    def test_nested_notify(self):
        # Ensure no notifications lost, even if notify() is reentered by a
        # waiter calling notify().
        c = locks.Condition()

        # Three waiters.
        futures = [asyncio.ensure_future(c.wait()) for _ in range(3)]

        # First and second futures resolved. Second future reenters notify(),
        # resolving third future.
        futures[1].add_done_callback(lambda _: c.notify())
        c.notify(2)
        yield
        self.assertTrue(all(f.done() for f in futures))

    @gen_test
    def test_garbage_collection(self):
        # Test that timed-out waiters are occasionally cleaned from the queue.
        c = locks.Condition()
        for _ in range(101):
            c.wait(timedelta(seconds=0.01))

        future = asyncio.ensure_future(c.wait())
        self.assertEqual(102, len(c._waiters))

        # Let first 101 waiters time out, triggering a collection.
        yield gen.sleep(0.02)
        self.assertEqual(1, len(c._waiters))

        # Final waiter is still active.
        self.assertFalse(future.done())
        c.notify()
        self.assertTrue(future.done())
class EventTest(AsyncTestCase):
    """Tests for tornado.locks.Event."""

    def test_repr(self):
        event = locks.Event()
        self.assertIn("clear", str(event))
        self.assertNotIn("set", str(event))
        event.set()
        self.assertNotIn("clear", str(event))
        self.assertIn("set", str(event))

    def test_event(self):
        event = locks.Event()
        waiter_before_set = asyncio.ensure_future(event.wait())
        event.set()
        waiter_while_set = asyncio.ensure_future(event.wait())
        event.clear()
        waiter_after_clear = asyncio.ensure_future(event.wait())
        # Waiters registered before or during the set period resolve;
        # one registered after clear() stays pending.
        self.assertTrue(waiter_before_set.done())
        self.assertTrue(waiter_while_set.done())
        self.assertFalse(waiter_after_clear.done())

    @gen_test
    def test_event_timeout(self):
        event = locks.Event()
        with self.assertRaises(TimeoutError):
            yield event.wait(timedelta(seconds=0.01))

        # A previously timed-out waiter does not break normal operation.
        self.io_loop.add_timeout(timedelta(seconds=0.01), event.set)
        yield event.wait(timedelta(seconds=1))

    def test_event_set_multiple(self):
        event = locks.Event()
        event.set()
        event.set()
        # Setting an already-set event is a no-op; it remains set.
        self.assertTrue(event.is_set())

    def test_event_wait_clear(self):
        event = locks.Event()
        waiter_before_clear = asyncio.ensure_future(event.wait())
        event.clear()
        waiter_after_clear = asyncio.ensure_future(event.wait())
        event.set()
        self.assertTrue(waiter_before_clear.done())
        self.assertTrue(waiter_after_clear.done())
class SemaphoreTest(AsyncTestCase):
def test_negative_value(self):
self.assertRaises(ValueError, locks.Semaphore, value=-1)
def test_repr(self):
sem = locks.Semaphore()
self.assertIn("Semaphore", repr(sem))
self.assertIn("unlocked,value:1", repr(sem))
sem.acquire()
self.assertIn("locked", repr(sem))
self.assertNotIn("waiters", repr(sem))
sem.acquire()
self.assertIn("waiters", repr(sem))
def test_acquire(self):
sem = locks.Semaphore()
f0 = asyncio.ensure_future(sem.acquire())
self.assertTrue(f0.done())
# Wait for release().
f1 = asyncio.ensure_future(sem.acquire())
self.assertFalse(f1.done())
f2 = asyncio.ensure_future(sem.acquire())
sem.release()
self.assertTrue(f1.done())
self.assertFalse(f2.done())
sem.release()
self.assertTrue(f2.done())
sem.release()
# Now acquire() is instant.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertEqual(0, len(sem._waiters))
@gen_test
def test_acquire_timeout(self):
sem = locks.Semaphore(2)
yield sem.acquire()
yield sem.acquire()
acquire = sem.acquire(timedelta(seconds=0.01))
self.io_loop.call_later(0.02, sem.release) # Too late.
yield gen.sleep(0.3)
with self.assertRaises(gen.TimeoutError):
yield acquire
sem.acquire()
f = asyncio.ensure_future(sem.acquire())
self.assertFalse(f.done())
sem.release()
self.assertTrue(f.done())
@gen_test
def test_acquire_timeout_preempted(self):
sem = locks.Semaphore(1)
yield sem.acquire()
# This fires before the wait times out.
self.io_loop.call_later(0.01, sem.release)
acquire = sem.acquire(timedelta(seconds=0.02))
yield gen.sleep(0.03)
yield acquire # No TimeoutError.
def test_release_unacquired(self):
# Unbounded releases are allowed, and increment the semaphore's value.
sem = locks.Semaphore()
sem.release()
sem.release()
# Now the counter is 3. We can acquire three times before blocking.
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertTrue(asyncio.ensure_future(sem.acquire()).done())
self.assertFalse(asyncio.ensure_future(sem.acquire()).done())
    @gen_test
    def test_garbage_collection(self):
        # Test that timed-out waiters are occasionally cleaned from the queue.
        sem = locks.Semaphore(value=0)
        futures = [
            asyncio.ensure_future(sem.acquire(timedelta(seconds=0.01)))
            for _ in range(101)
        ]
        # One waiter with no timeout, queued behind the 101 doomed ones.
        future = asyncio.ensure_future(sem.acquire())
        self.assertEqual(102, len(sem._waiters))

        # Let first 101 waiters time out, triggering a collection.
        yield gen.sleep(0.02)
        self.assertEqual(1, len(sem._waiters))

        # Final waiter is still active.
        self.assertFalse(future.done())
        sem.release()
        self.assertTrue(future.done())

        # Prevent "Future exception was never retrieved" messages.
        for future in futures:
            self.assertRaises(TimeoutError, future.result)
class SemaphoreContextManagerTest(AsyncTestCase):
    """Tests for using a Semaphore as a (sync or async) context manager."""

    @gen_test
    def test_context_manager(self):
        sem = locks.Semaphore()
        # acquire() resolves to a context manager that releases on exit.
        with (yield sem.acquire()) as yielded:
            # The context manager yields None, not the semaphore itself.
            self.assertTrue(yielded is None)
        # Semaphore was released and can be acquired again.
        self.assertTrue(asyncio.ensure_future(sem.acquire()).done())

    @gen_test
    def test_context_manager_async_await(self):
        # Repeat the above test using 'async with'.
        sem = locks.Semaphore()

        async def f():
            async with sem as yielded:
                self.assertTrue(yielded is None)

        yield f()
        # Semaphore was released and can be acquired again.
        self.assertTrue(asyncio.ensure_future(sem.acquire()).done())

    @gen_test
    def test_context_manager_exception(self):
        # An exception inside the block must still release the semaphore.
        sem = locks.Semaphore()
        with self.assertRaises(ZeroDivisionError):
            with (yield sem.acquire()):
                1 / 0

        # Semaphore was released and can be acquired again.
        self.assertTrue(asyncio.ensure_future(sem.acquire()).done())

    @gen_test
    def test_context_manager_timeout(self):
        # A timeout that does not fire leaves normal release-on-exit behavior.
        sem = locks.Semaphore()
        with (yield sem.acquire(timedelta(seconds=0.01))):
            pass

        # Semaphore was released and can be acquired again.
        self.assertTrue(asyncio.ensure_future(sem.acquire()).done())

    @gen_test
    def test_context_manager_timeout_error(self):
        # The acquire times out, so the with-block never runs.
        sem = locks.Semaphore(value=0)
        with self.assertRaises(gen.TimeoutError):
            with (yield sem.acquire(timedelta(seconds=0.01))):
                pass

        # Counter is still 0.
        self.assertFalse(asyncio.ensure_future(sem.acquire()).done())

    @gen_test
    def test_context_manager_contended(self):
        # Two coroutines contend for the semaphore; each must hold it
        # exclusively for its whole block (no interleaving in history).
        sem = locks.Semaphore()
        history = []

        @gen.coroutine
        def f(index):
            with (yield sem.acquire()):
                history.append("acquired %d" % index)
                yield gen.sleep(0.01)
                history.append("release %d" % index)

        yield [f(i) for i in range(2)]

        expected_history = []
        for i in range(2):
            expected_history.extend(["acquired %d" % i, "release %d" % i])

        self.assertEqual(expected_history, history)

    @gen_test
    def test_yield_sem(self):
        # Ensure we catch a "with (yield sem)", which should be
        # "with (yield sem.acquire())".
        with self.assertRaises(gen.BadYieldError):
            with (yield locks.Semaphore()):
                pass

    def test_context_manager_misuse(self):
        # Ensure we catch a "with sem", which should be
        # "with (yield sem.acquire())".
        with self.assertRaises(RuntimeError):
            with locks.Semaphore():
                pass
class BoundedSemaphoreTest(AsyncTestCase):
    """Tests for BoundedSemaphore, which forbids unmatched releases."""

    def test_release_unacquired(self):
        sem = locks.BoundedSemaphore()
        # Releasing above the initial value is an error.
        self.assertRaises(ValueError, sem.release)
        # Value is 0.
        sem.acquire()
        # Block on acquire().
        future = asyncio.ensure_future(sem.acquire())
        self.assertFalse(future.done())
        sem.release()
        self.assertTrue(future.done())
        # Value is 1.
        sem.release()
        # Releasing past the bound raises again.
        self.assertRaises(ValueError, sem.release)
class LockTests(AsyncTestCase):
    """Tests for tornado.locks.Lock."""

    def test_repr(self):
        lock = locks.Lock()
        # No errors.
        repr(lock)
        lock.acquire()
        repr(lock)

    def test_acquire_release(self):
        # An uncontended acquire succeeds immediately; a second acquire
        # blocks until release().
        lock = locks.Lock()
        self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
        future = asyncio.ensure_future(lock.acquire())
        self.assertFalse(future.done())
        lock.release()
        self.assertTrue(future.done())

    @gen_test
    def test_acquire_fifo(self):
        # Waiters are granted the lock in the order they queued.
        lock = locks.Lock()
        self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
        N = 5
        history = []

        @gen.coroutine
        def f(idx):
            with (yield lock.acquire()):
                history.append(idx)

        futures = [f(i) for i in range(N)]
        self.assertFalse(any(future.done() for future in futures))
        lock.release()
        yield futures
        self.assertEqual(list(range(N)), history)

    @gen_test
    def test_acquire_fifo_async_with(self):
        # Repeat the above test using `async with lock:`
        # instead of `with (yield lock.acquire()):`.
        lock = locks.Lock()
        self.assertTrue(asyncio.ensure_future(lock.acquire()).done())
        N = 5
        history = []

        async def f(idx):
            async with lock:
                history.append(idx)

        futures = [f(i) for i in range(N)]
        lock.release()
        yield futures
        self.assertEqual(list(range(N)), history)

    @gen_test
    def test_acquire_timeout(self):
        # A timed-out acquire raises and must not take the lock.
        lock = locks.Lock()
        lock.acquire()
        with self.assertRaises(gen.TimeoutError):
            yield lock.acquire(timeout=timedelta(seconds=0.01))

        # Still locked.
        self.assertFalse(asyncio.ensure_future(lock.acquire()).done())

    def test_multi_release(self):
        # release() without a matching acquire() is an error, both on a
        # fresh lock and after a balanced acquire/release pair.
        lock = locks.Lock()
        self.assertRaises(RuntimeError, lock.release)
        lock.acquire()
        lock.release()
        self.assertRaises(RuntimeError, lock.release)

    @gen_test
    def test_yield_lock(self):
        # Ensure we catch a "with (yield lock)", which should be
        # "with (yield lock.acquire())".
        with self.assertRaises(gen.BadYieldError):
            with (yield locks.Lock()):
                pass

    def test_context_manager_misuse(self):
        # Ensure we catch a "with lock", which should be
        # "with (yield lock.acquire())".
        with self.assertRaises(RuntimeError):
            with locks.Lock():
                pass
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,245 @@
#
# Copyright 2012 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import contextlib
import glob
import logging
import os
import re
import subprocess
import sys
import tempfile
import unittest
import warnings
from tornado.escape import utf8
from tornado.log import LogFormatter, define_logging_options, enable_pretty_logging
from tornado.options import OptionParser
from tornado.util import basestring_type
@contextlib.contextmanager
def ignore_bytes_warning():
    """Suppress ``BytesWarning`` for the duration of the managed block.

    The previous warning filters are saved on entry and restored on exit.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=BytesWarning)
        yield
class LogFormatterTest(unittest.TestCase):
    """Tests for LogFormatter's handling of color codes and byte/unicode data."""

    # Matches the output of a single logging call (which may be multiple lines
    # if a traceback was included, so we use the DOTALL option)
    LINE_RE = re.compile(
        b"(?s)\x01\\[E [0-9]{6} [0-9]{2}:[0-9]{2}:[0-9]{2} log_test:[0-9]+\\]\x02 (.*)"
    )

    def setUp(self):
        self.formatter = LogFormatter(color=False)
        # Fake color support. We can't guarantee anything about the $TERM
        # variable when the tests are run, so just patch in some values
        # for testing. (testing with color off fails to expose some potential
        # encoding issues from the control characters)
        self.formatter._colors = {logging.ERROR: u"\u0001"}
        self.formatter._normal = u"\u0002"
        # construct a Logger directly to bypass getLogger's caching
        self.logger = logging.Logger("LogFormatterTest")
        self.logger.propagate = False
        self.tempdir = tempfile.mkdtemp()
        self.filename = os.path.join(self.tempdir, "log.out")
        self.handler = self.make_handler(self.filename)
        self.handler.setFormatter(self.formatter)
        self.logger.addHandler(self.handler)

    def tearDown(self):
        self.handler.close()
        os.unlink(self.filename)
        os.rmdir(self.tempdir)

    def make_handler(self, filename):
        # Base case: default setup without explicit encoding.
        # In python 2, supports arbitrary byte strings and unicode objects
        # that contain only ascii. In python 3, supports ascii-only unicode
        # strings (but byte strings will be repr'd automatically).
        return logging.FileHandler(filename)

    def get_output(self):
        """Return the formatted message body from the log file, or fail."""
        with open(self.filename, "rb") as f:
            line = f.read().strip()
            m = LogFormatterTest.LINE_RE.match(line)
            if m:
                return m.group(1)
            else:
                raise Exception("output didn't match regex: %r" % line)

    def test_basic_logging(self):
        self.logger.error("foo")
        self.assertEqual(self.get_output(), b"foo")

    def test_bytes_logging(self):
        with ignore_bytes_warning():
            # This will be "\xe9" on python 2 or "b'\xe9'" on python 3
            self.logger.error(b"\xe9")
            self.assertEqual(self.get_output(), utf8(repr(b"\xe9")))

    def test_utf8_logging(self):
        with ignore_bytes_warning():
            self.logger.error(u"\u00e9".encode("utf8"))
        if issubclass(bytes, basestring_type):
            # on python 2, utf8 byte strings (and by extension ascii byte
            # strings) are passed through as-is.
            self.assertEqual(self.get_output(), utf8(u"\u00e9"))
        else:
            # on python 3, byte strings always get repr'd even if
            # they're ascii-only, so this degenerates into another
            # copy of test_bytes_logging.
            self.assertEqual(self.get_output(), utf8(repr(utf8(u"\u00e9"))))

    def test_bytes_exception_logging(self):
        try:
            raise Exception(b"\xe9")
        except Exception:
            self.logger.exception("caught exception")
        # This will be "Exception: \xe9" on python 2 or
        # "Exception: b'\xe9'" on python 3.
        output = self.get_output()
        # assertRegexpMatches is a deprecated alias removed in Python 3.12;
        # use its modern name assertRegex.
        self.assertRegex(output, br"Exception.*\\xe9")
        # The traceback contains newlines, which should not have been escaped.
        self.assertNotIn(br"\n", output)
class UnicodeLogFormatterTest(LogFormatterTest):
    """Reruns LogFormatterTest with a handler that has an explicit encoding."""

    def make_handler(self, filename):
        # Adding an explicit encoding configuration allows non-ascii unicode
        # strings in both python 2 and 3, without changing the behavior
        # for byte strings.
        return logging.FileHandler(filename, encoding="utf8")

    def test_unicode_logging(self):
        # Non-ascii text must round-trip through the utf8-encoded handler.
        self.logger.error(u"\u00e9")
        self.assertEqual(self.get_output(), utf8(u"\u00e9"))
class EnablePrettyLoggingTest(unittest.TestCase):
    """Tests for enable_pretty_logging's file-based handler configuration."""

    def setUp(self):
        super().setUp()
        self.options = OptionParser()
        define_logging_options(self.options)
        # Construct a Logger directly to bypass getLogger's caching.
        self.logger = logging.Logger("tornado.test.log_test.EnablePrettyLoggingTest")
        self.logger.propagate = False

    def _check_log_file(self, rotate_mode=None):
        # Shared body of the two log-file tests: configure file logging in
        # a fresh temp dir, emit one error line, and assert it was written.
        # Handlers and files are always cleaned up afterwards.
        tmpdir = tempfile.mkdtemp()
        try:
            self.options.log_file_prefix = tmpdir + "/test_log"
            if rotate_mode is not None:
                self.options.log_rotate_mode = rotate_mode
            enable_pretty_logging(options=self.options, logger=self.logger)
            self.assertEqual(1, len(self.logger.handlers))
            self.logger.error("hello")
            self.logger.handlers[0].flush()
            filenames = glob.glob(tmpdir + "/test_log*")
            self.assertEqual(1, len(filenames))
            with open(filenames[0]) as f:
                # assertRegexpMatches was removed in Python 3.12; use
                # its modern name assertRegex.
                self.assertRegex(f.read(), r"^\[E [^]]*\] hello$")
        finally:
            for handler in self.logger.handlers:
                handler.flush()
                handler.close()
            for filename in glob.glob(tmpdir + "/test_log*"):
                os.unlink(filename)
            os.rmdir(tmpdir)

    def test_log_file(self):
        self._check_log_file()

    def test_log_file_with_timed_rotating(self):
        self._check_log_file(rotate_mode="time")

    def test_wrong_rotate_mode_value(self):
        # An unknown rotate mode must be rejected with ValueError.
        try:
            self.options.log_file_prefix = "some_path"
            self.options.log_rotate_mode = "wrong_mode"
            self.assertRaises(
                ValueError,
                enable_pretty_logging,
                options=self.options,
                logger=self.logger,
            )
        finally:
            for handler in self.logger.handlers:
                handler.flush()
                handler.close()
class LoggingOptionTest(unittest.TestCase):
    """Test the ability to enable and disable Tornado's logging hooks."""

    def logs_present(self, statement, args=None):
        """Run ``statement`` in a subprocess and report whether an
        info-level log line made it to the output.

        Each test may manipulate and/or parse the options and then logs
        a line at the 'info' level. This level is ignored in the logging
        module by default, but Tornado turns it on by default so it is
        the easiest way to tell whether tornado's logging hooks ran.
        """
        preamble = "from tornado.options import options, parse_command_line"
        emit = 'import logging; logging.info("hello")'
        script = ";".join([preamble, statement, emit])
        result = subprocess.run(
            [sys.executable, "-c", script] + (args or []),
            stdout=subprocess.PIPE,
            stderr=subprocess.STDOUT,
        )
        self.assertEqual(result.returncode, 0, "process failed: %r" % result.stdout)
        return b"hello" in result.stdout

    def test_default(self):
        self.assertFalse(self.logs_present("pass"))

    def test_tornado_default(self):
        self.assertTrue(self.logs_present("parse_command_line()"))

    def test_disable_command_line(self):
        self.assertFalse(self.logs_present("parse_command_line()", ["--logging=none"]))

    def test_disable_command_line_case_insensitive(self):
        self.assertFalse(self.logs_present("parse_command_line()", ["--logging=None"]))

    def test_disable_code_string(self):
        self.assertFalse(
            self.logs_present('options.logging = "none"; parse_command_line()')
        )

    def test_disable_code_none(self):
        self.assertFalse(
            self.logs_present("options.logging = None; parse_command_line()")
        )

    def test_disable_override(self):
        # command line trumps code defaults
        self.assertTrue(
            self.logs_present(
                "options.logging = None; parse_command_line()", ["--logging=info"]
            )
        )

View file

@ -0,0 +1,233 @@
import errno
import os
import signal
import socket
from subprocess import Popen
import sys
import time
import unittest
from tornado.netutil import (
BlockingResolver,
OverrideResolver,
ThreadedResolver,
is_valid_ip,
bind_sockets,
)
from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
from tornado.test.util import skipIfNoNetwork
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
try:
import pycares # type: ignore
except ImportError:
pycares = None
else:
from tornado.platform.caresresolver import CaresResolver
try:
import twisted # type: ignore
import twisted.names # type: ignore
except ImportError:
twisted = None
else:
from tornado.platform.twisted import TwistedResolver
class _ResolverTestMixin(object):
    """Happy-path resolution test shared by every concrete resolver TestCase."""

    resolver = None  # type: typing.Any

    @gen_test
    def test_localhost(self: typing.Any):
        addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)
        # Every resolver must report the IPv4 loopback address for localhost.
        self.assertIn((socket.AF_INET, ("127.0.0.1", 80)), addrinfo)
# It is impossible to quickly and consistently generate an error in name
# resolution, so test this case separately, using mocks as needed.
class _ResolverErrorTestMixin(object):
    """Failure-path test shared by resolver TestCases that can produce errors."""

    resolver = None  # type: typing.Any

    @gen_test
    def test_bad_host(self: typing.Any):
        # Resolving a syntactically invalid name must raise an IOError
        # (socket.gaierror is a subclass).
        with self.assertRaises(IOError):
            yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC)
def _failing_getaddrinfo(*args, **kwargs):
    """Dummy implementation of getaddrinfo for use in mocks.

    Always raises socket.gaierror(EIO) regardless of arguments.  Accepts
    any positional or keyword arguments so it can stand in for
    ``socket.getaddrinfo``, whose optional parameters (family, type,
    proto, flags) may be passed by keyword.
    """
    raise socket.gaierror(errno.EIO, "mock: lookup failed")
@skipIfNoNetwork
class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
    """Happy-path resolution via the synchronous BlockingResolver."""

    def setUp(self):
        super().setUp()
        self.resolver = BlockingResolver()
# getaddrinfo-based tests need mocking to reliably generate errors;
# some configurations are slow to produce errors and take longer than
# our default timeout.
class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
    """Failure path for BlockingResolver, with getaddrinfo forced to fail."""

    def setUp(self):
        super().setUp()
        self.resolver = BlockingResolver()
        # Monkey-patch socket.getaddrinfo so the lookup fails immediately;
        # restored in tearDown.
        self.real_getaddrinfo = socket.getaddrinfo
        socket.getaddrinfo = _failing_getaddrinfo

    def tearDown(self):
        socket.getaddrinfo = self.real_getaddrinfo
        super().tearDown()
class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
    """Tests for OverrideResolver's (host, port[, family]) mapping lookups."""

    def setUp(self):
        super().setUp()
        # Entries may be keyed by (host, port) or (host, port, family);
        # the more specific family-keyed entries take precedence.
        mapping = {
            ("google.com", 80): ("1.2.3.4", 80),
            ("google.com", 80, socket.AF_INET): ("1.2.3.4", 80),
            ("google.com", 80, socket.AF_INET6): (
                "2a02:6b8:7c:40c:c51e:495f:e23a:3",
                80,
            ),
        }
        self.resolver = OverrideResolver(BlockingResolver(), mapping)

    @gen_test
    def test_resolve_multiaddr(self):
        # The requested address family selects the matching override entry.
        result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
        self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result)
        result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
        self.assertIn(
            (socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result
        )
@skipIfNoNetwork
class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
    """Happy-path resolution via the thread-pool-backed ThreadedResolver."""

    def setUp(self):
        super().setUp()
        self.resolver = ThreadedResolver()

    def tearDown(self):
        # ThreadedResolver owns an executor that must be shut down explicitly.
        self.resolver.close()
        super().tearDown()
class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
    """Failure path with getaddrinfo forced to fail.

    NOTE(review): despite the class name this instantiates BlockingResolver,
    not ThreadedResolver — confirm whether that is intentional (patching
    socket.getaddrinfo from the test thread may not be visible to resolver
    worker threads in a deterministic way).
    """

    def setUp(self):
        super().setUp()
        self.resolver = BlockingResolver()
        # Monkey-patch getaddrinfo to fail immediately; restored in tearDown.
        self.real_getaddrinfo = socket.getaddrinfo
        socket.getaddrinfo = _failing_getaddrinfo

    def tearDown(self):
        socket.getaddrinfo = self.real_getaddrinfo
        super().tearDown()
@skipIfNoNetwork
@unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
class ThreadedResolverImportTest(unittest.TestCase):
    """Guards against a deadlock when ThreadedResolver runs at import time."""

    def test_import(self):
        TIMEOUT = 5

        # Test for a deadlock when importing a module that runs the
        # ThreadedResolver at import-time. See resolve_test.py for
        # full explanation.
        command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]

        # The SIGALRM set in the child guarantees it dies even if we don't.
        popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
        deadline = time.time() + TIMEOUT
        while time.time() < deadline:
            rc = popen.poll()
            if rc is not None:
                self.assertEqual(0, rc)
                return  # Success.
            time.sleep(0.05)
        self.fail("import timed out")
# We do not test errors with CaresResolver:
# Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
# with an NXDOMAIN status code. Most resolvers treat this as an error;
# C-ares returns the results, making the "bad_host" tests unreliable.
# C-ares will try to resolve even malformed names, such as the
# name with spaces used in this test.
@skipIfNoNetwork
@unittest.skipIf(pycares is None, "pycares module not present")
@unittest.skipIf(sys.platform == "win32", "pycares doesn't return loopback on windows")
@unittest.skipIf(sys.platform == "darwin", "pycares doesn't return 127.0.0.1 on darwin")
class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
    """Happy-path resolution via the optional pycares-backed CaresResolver."""

    def setUp(self):
        super().setUp()
        self.resolver = CaresResolver()
# TwistedResolver produces consistent errors in our test cases so we
# could test the regular and error cases in the same class. However,
# in the error cases it appears that cleanup of socket objects is
# handled asynchronously and occasionally results in "unclosed socket"
# warnings if not given time to shut down (and there is no way to
# explicitly shut it down). This makes the test flaky, so we do not
# test error cases here.
@skipIfNoNetwork
@unittest.skipIf(twisted is None, "twisted module not present")
@unittest.skipIf(
    getattr(twisted, "__version__", "0.0") < "12.1", "old version of twisted"
)
@unittest.skipIf(sys.platform == "win32", "twisted resolver hangs on windows")
class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
    """Happy-path resolution via the optional TwistedResolver."""

    def setUp(self):
        super().setUp()
        self.resolver = TwistedResolver()
class IsValidIPTest(unittest.TestCase):
    """Tests for netutil.is_valid_ip."""

    def test_is_valid_ip(self):
        valid_addresses = [
            "127.0.0.1",
            "4.4.4.4",
            "::1",
            "2620:0:1cfe:face:b00c::3",
        ]
        invalid_addresses = [
            "www.google.com",
            "localhost",
            "4.4.4.4<",
            " 127.0.0.1",
            "",
            " ",
            "\n",
            "\x00",
        ]
        for addr in valid_addresses:
            self.assertTrue(is_valid_ip(addr))
        for addr in invalid_addresses:
            self.assertFalse(is_valid_ip(addr))
class TestPortAllocation(unittest.TestCase):
    """Tests for bind_sockets' port selection and SO_REUSEPORT behavior."""

    def test_same_port_allocation(self):
        # bind_sockets(0, ...) must pick the same ephemeral port for every
        # address family it binds.
        if "TRAVIS" in os.environ:
            self.skipTest("dual-stack servers often have port conflicts on travis")
        sockets = bind_sockets(0, "localhost")
        try:
            port = sockets[0].getsockname()[1]
            self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))
        finally:
            for sock in sockets:
                sock.close()

    @unittest.skipIf(
        not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
    )
    def test_reuse_port(self):
        # With reuse_port=True, a port that is already bound can be bound
        # again.  (Fix: the bound socket was previously named ``socket``,
        # shadowing the module within this method.)
        sockets = []  # type: List[socket.socket]
        sock, port = bind_unused_port(reuse_port=True)
        try:
            sockets = bind_sockets(port, "127.0.0.1", reuse_port=True)
            self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
        finally:
            sock.close()
            for other in sockets:
                other.close()

View file

@ -0,0 +1,7 @@
port=443
port=443
username='李康'
foo_bar='a'
my_path = __file__

View file

@ -0,0 +1,328 @@
import datetime
from io import StringIO
import os
import sys
from unittest import mock
import unittest
from tornado.options import OptionParser, Error
from tornado.util import basestring_type
from tornado.test.util import subTest
import typing
if typing.TYPE_CHECKING:
from typing import List # noqa: F401
class Email(object):
    """Minimal wrapper type used to exercise custom option types.

    Accepts any ``str`` containing an "@"; anything else raises
    ``ValueError`` (now with a descriptive message instead of an empty one).
    """

    def __init__(self, value):
        if isinstance(value, str) and "@" in value:
            self._value = value
        else:
            raise ValueError("not a valid email address: %r" % (value,))

    @property
    def value(self):
        """The validated email string."""
        return self._value
class OptionsTest(unittest.TestCase):
    """Tests for OptionParser: command line, config files, types, and
    dash/underscore normalization.

    The only code change from the original is replacing
    ``assertRegexpMatches`` (a deprecated alias removed in Python 3.12)
    with ``assertRegex``.
    """

    def test_parse_command_line(self):
        options = OptionParser()
        options.define("port", default=80)
        options.parse_command_line(["main.py", "--port=443"])
        self.assertEqual(options.port, 443)

    def test_parse_config_file(self):
        options = OptionParser()
        options.define("port", default=80)
        options.define("username", default="foo")
        options.define("my_path")
        config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "options_test.cfg"
        )
        options.parse_config_file(config_path)
        self.assertEqual(options.port, 443)
        self.assertEqual(options.username, "李康")
        # __file__ inside the config file refers to the config file itself.
        self.assertEqual(options.my_path, config_path)

    def test_parse_callbacks(self):
        options = OptionParser()
        self.called = False

        def callback():
            self.called = True

        options.add_parse_callback(callback)

        # non-final parse doesn't run callbacks
        options.parse_command_line(["main.py"], final=False)
        self.assertFalse(self.called)

        # final parse does
        options.parse_command_line(["main.py"])
        self.assertTrue(self.called)

        # callbacks can be run more than once on the same options
        # object if there are multiple final parses
        self.called = False
        options.parse_command_line(["main.py"])
        self.assertTrue(self.called)

    def test_help(self):
        options = OptionParser()
        try:
            orig_stderr = sys.stderr
            sys.stderr = StringIO()
            with self.assertRaises(SystemExit):
                options.parse_command_line(["main.py", "--help"])
            usage = sys.stderr.getvalue()
        finally:
            sys.stderr = orig_stderr
        self.assertIn("Usage:", usage)

    def test_subcommand(self):
        base_options = OptionParser()
        base_options.define("verbose", default=False)
        sub_options = OptionParser()
        sub_options.define("foo", type=str)
        rest = base_options.parse_command_line(
            ["main.py", "--verbose", "subcommand", "--foo=bar"]
        )
        self.assertEqual(rest, ["subcommand", "--foo=bar"])
        self.assertTrue(base_options.verbose)
        rest2 = sub_options.parse_command_line(rest)
        self.assertEqual(rest2, [])
        self.assertEqual(sub_options.foo, "bar")

        # the two option sets are distinct
        try:
            orig_stderr = sys.stderr
            sys.stderr = StringIO()
            with self.assertRaises(Error):
                sub_options.parse_command_line(["subcommand", "--verbose"])
        finally:
            sys.stderr = orig_stderr

    def test_setattr(self):
        options = OptionParser()
        options.define("foo", default=1, type=int)
        options.foo = 2
        self.assertEqual(options.foo, 2)

    def test_setattr_type_check(self):
        # setattr requires that options be the right type and doesn't
        # parse from string formats.
        options = OptionParser()
        options.define("foo", default=1, type=int)
        with self.assertRaises(Error):
            options.foo = "2"

    def test_setattr_with_callback(self):
        values = []  # type: List[int]
        options = OptionParser()
        options.define("foo", default=1, type=int, callback=values.append)
        options.foo = 2
        self.assertEqual(values, [2])

    def _sample_options(self):
        # Small fixture shared by the iteration/access tests below.
        options = OptionParser()
        options.define("a", default=1)
        options.define("b", default=2)
        return options

    def test_iter(self):
        options = self._sample_options()
        # OptionParsers always define 'help'.
        self.assertEqual(set(["a", "b", "help"]), set(iter(options)))

    def test_getitem(self):
        options = self._sample_options()
        self.assertEqual(1, options["a"])

    def test_setitem(self):
        options = OptionParser()
        options.define("foo", default=1, type=int)
        options["foo"] = 2
        self.assertEqual(options["foo"], 2)

    def test_items(self):
        options = self._sample_options()
        # OptionParsers always define 'help'.
        expected = [("a", 1), ("b", 2), ("help", options.help)]
        actual = sorted(options.items())
        self.assertEqual(expected, actual)

    def test_as_dict(self):
        options = self._sample_options()
        expected = {"a": 1, "b": 2, "help": options.help}
        self.assertEqual(expected, options.as_dict())

    def test_group_dict(self):
        options = OptionParser()
        options.define("a", default=1)
        options.define("b", group="b_group", default=2)

        # Options defined without an explicit group fall into a group named
        # after the defining file.
        frame = sys._getframe(0)
        this_file = frame.f_code.co_filename
        self.assertEqual(set(["b_group", "", this_file]), options.groups())

        b_group_dict = options.group_dict("b_group")
        self.assertEqual({"b": 2}, b_group_dict)

        self.assertEqual({}, options.group_dict("nonexistent"))

    def test_mock_patch(self):
        # ensure that our setattr hooks don't interfere with mock.patch
        options = OptionParser()
        options.define("foo", default=1)
        options.parse_command_line(["main.py", "--foo=2"])
        self.assertEqual(options.foo, 2)

        with mock.patch.object(options.mockable(), "foo", 3):
            self.assertEqual(options.foo, 3)
        self.assertEqual(options.foo, 2)

        # Try nested patches mixed with explicit sets
        with mock.patch.object(options.mockable(), "foo", 4):
            self.assertEqual(options.foo, 4)
            options.foo = 5
            self.assertEqual(options.foo, 5)
            with mock.patch.object(options.mockable(), "foo", 6):
                self.assertEqual(options.foo, 6)
            self.assertEqual(options.foo, 5)
        self.assertEqual(options.foo, 2)

    def _define_options(self):
        # One option per supported conversion type; shared by the
        # command-line and config-file type tests.
        options = OptionParser()
        options.define("str", type=str)
        options.define("basestring", type=basestring_type)
        options.define("int", type=int)
        options.define("float", type=float)
        options.define("datetime", type=datetime.datetime)
        options.define("timedelta", type=datetime.timedelta)
        options.define("email", type=Email)
        options.define("list-of-int", type=int, multiple=True)
        return options

    def _check_options_values(self, options):
        self.assertEqual(options.str, "asdf")
        self.assertEqual(options.basestring, "qwer")
        self.assertEqual(options.int, 42)
        self.assertEqual(options.float, 1.5)
        self.assertEqual(options.datetime, datetime.datetime(2013, 4, 28, 5, 16))
        self.assertEqual(options.timedelta, datetime.timedelta(seconds=45))
        self.assertEqual(options.email.value, "tornado@web.com")
        self.assertTrue(isinstance(options.email, Email))
        self.assertEqual(options.list_of_int, [1, 2, 3])

    def test_types(self):
        options = self._define_options()
        options.parse_command_line(
            [
                "main.py",
                "--str=asdf",
                "--basestring=qwer",
                "--int=42",
                "--float=1.5",
                "--datetime=2013-04-28 05:16",
                "--timedelta=45s",
                "--email=tornado@web.com",
                "--list-of-int=1,2,3",
            ]
        )
        self._check_options_values(options)

    def test_types_with_conf_file(self):
        for config_file_name in (
            "options_test_types.cfg",
            "options_test_types_str.cfg",
        ):
            options = self._define_options()
            options.parse_config_file(
                os.path.join(os.path.dirname(__file__), config_file_name)
            )
            self._check_options_values(options)

    def test_multiple_string(self):
        options = OptionParser()
        options.define("foo", type=str, multiple=True)
        options.parse_command_line(["main.py", "--foo=a,b,c"])
        self.assertEqual(options.foo, ["a", "b", "c"])

    def test_multiple_int(self):
        # Integer "multiple" options support colon-delimited ranges.
        options = OptionParser()
        options.define("foo", type=int, multiple=True)
        options.parse_command_line(["main.py", "--foo=1,3,5:7"])
        self.assertEqual(options.foo, [1, 3, 5, 6, 7])

    def test_error_redefine(self):
        options = OptionParser()
        options.define("foo")
        with self.assertRaises(Error) as cm:
            options.define("foo")
        self.assertRegex(str(cm.exception), "Option.*foo.*already defined")

    def test_error_redefine_underscore(self):
        # Ensure that the dash/underscore normalization doesn't
        # interfere with the redefinition error.
        tests = [
            ("foo-bar", "foo-bar"),
            ("foo_bar", "foo_bar"),
            ("foo-bar", "foo_bar"),
            ("foo_bar", "foo-bar"),
        ]
        for a, b in tests:
            with subTest(self, a=a, b=b):
                options = OptionParser()
                options.define(a)
                with self.assertRaises(Error) as cm:
                    options.define(b)
                self.assertRegex(
                    str(cm.exception), "Option.*foo.bar.*already defined"
                )

    def test_dash_underscore_cli(self):
        # Dashes and underscores should be interchangeable.
        for defined_name in ["foo-bar", "foo_bar"]:
            for flag in ["--foo-bar=a", "--foo_bar=a"]:
                options = OptionParser()
                options.define(defined_name)
                options.parse_command_line(["main.py", flag])
                # Attr-style access always uses underscores.
                self.assertEqual(options.foo_bar, "a")
                # Dict-style access allows both.
                self.assertEqual(options["foo-bar"], "a")
                self.assertEqual(options["foo_bar"], "a")

    def test_dash_underscore_file(self):
        # No matter how an option was defined, it can be set with underscores
        # in a config file.
        for defined_name in ["foo-bar", "foo_bar"]:
            options = OptionParser()
            options.define(defined_name)
            options.parse_config_file(
                os.path.join(os.path.dirname(__file__), "options_test.cfg")
            )
            self.assertEqual(options.foo_bar, "a")

    def test_dash_underscore_introspection(self):
        # Original names are preserved in introspection APIs.
        options = OptionParser()
        options.define("with-dash", group="g")
        options.define("with_underscore", group="g")
        all_options = ["help", "with-dash", "with_underscore"]
        self.assertEqual(sorted(options), all_options)
        self.assertEqual(sorted(k for (k, v) in options.items()), all_options)
        self.assertEqual(sorted(options.as_dict().keys()), all_options)
        self.assertEqual(
            sorted(options.group_dict("g")), ["with-dash", "with_underscore"]
        )

        # --help shows CLI-style names with dashes.
        buf = StringIO()
        options.print_help(buf)
        self.assertIn("--with-dash", buf.getvalue())
        self.assertIn("--with-underscore", buf.getvalue())

View file

@ -0,0 +1,11 @@
from datetime import datetime, timedelta
from tornado.test.options_test import Email
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = datetime(2013, 4, 28, 5, 16)
timedelta = timedelta(0, 45)
email = Email('tornado@web.com')
list_of_int = [1, 2, 3]

View file

@ -0,0 +1,8 @@
str = 'asdf'
basestring = 'qwer'
int = 42
float = 1.5
datetime = '2013-04-28 05:16'
timedelta = '45s'
email = 'tornado@web.com'
list_of_int = '1,2,3'

View file

@ -0,0 +1,274 @@
import asyncio
import logging
import os
import signal
import subprocess
import sys
import time
import unittest
from tornado.httpclient import HTTPClient, HTTPError
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.log import gen_log
from tornado.process import fork_processes, task_id, Subprocess
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import bind_unused_port, ExpectLog, AsyncTestCase, gen_test
from tornado.test.util import skipIfNonUnix
from tornado.web import RequestHandler, Application
# Not using AsyncHTTPTestCase because we need control over the IOLoop.
@skipIfNonUnix
class ProcessTest(unittest.TestCase):
    def get_app(self):
        """Build a one-route app whose handler can report, exit, or signal
        its own (child) process — used to drive fork_processes restarts."""

        class ProcessHandler(RequestHandler):
            def get(self):
                if self.get_argument("exit", None):
                    # must use os._exit instead of sys.exit so unittest's
                    # exception handler doesn't catch it
                    os._exit(int(self.get_argument("exit")))
                if self.get_argument("signal", None):
                    # Deliver the requested signal to this child process.
                    os.kill(os.getpid(), int(self.get_argument("signal")))
                # Default: report this process's pid so callers can detect
                # restarts.
                self.write(str(os.getpid()))

        return Application([("/", ProcessHandler)])
    def tearDown(self):
        """Abort hard if we are a forked child; otherwise clear the alarm."""
        if task_id() is not None:
            # We're in a child process, and probably got to this point
            # via an uncaught exception. If we return now, both
            # processes will continue with the rest of the test suite.
            # Exit now so the parent process will restart the child
            # (since we don't have a clean way to signal failure to
            # the parent that won't restart)
            logging.error("aborting child process from tearDown")
            logging.shutdown()
            os._exit(1)
        # In the surviving process, clear the alarm we set earlier
        signal.alarm(0)
        super().tearDown()
def test_multi_process(self):
# This test doesn't work on twisted because we use the global
# reactor and don't restore it to a sane state after the fork
# (asyncio has the same issue, but we have a special case in
# place for it).
with ExpectLog(
gen_log, "(Starting .* processes|child .* exited|uncaught exception)"
):
sock, port = bind_unused_port()
def get_url(path):
return "http://127.0.0.1:%d%s" % (port, path)
# ensure that none of these processes live too long
signal.alarm(5) # master process
try:
id = fork_processes(3, max_restarts=3)
self.assertTrue(id is not None)
signal.alarm(5) # child processes
except SystemExit as e:
# if we exit cleanly from fork_processes, all the child processes
# finished with status 0
self.assertEqual(e.code, 0)
self.assertTrue(task_id() is None)
sock.close()
return
try:
if asyncio is not None:
# Reset the global asyncio event loop, which was put into
# a broken state by the fork.
asyncio.set_event_loop(asyncio.new_event_loop())
if id in (0, 1):
self.assertEqual(id, task_id())
server = HTTPServer(self.get_app())
server.add_sockets([sock])
IOLoop.current().start()
elif id == 2:
self.assertEqual(id, task_id())
sock.close()
# Always use SimpleAsyncHTTPClient here; the curl
# version appears to get confused sometimes if the
# connection gets closed before it's had a chance to
# switch from writing mode to reading mode.
client = HTTPClient(SimpleAsyncHTTPClient)
def fetch(url, fail_ok=False):
try:
return client.fetch(get_url(url))
except HTTPError as e:
if not (fail_ok and e.code == 599):
raise
# Make two processes exit abnormally
fetch("/?exit=2", fail_ok=True)
fetch("/?exit=3", fail_ok=True)
# They've been restarted, so a new fetch will work
int(fetch("/").body)
# Now the same with signals
# Disabled because on the mac a process dying with a signal
# can trigger an "Application exited abnormally; send error
# report to Apple?" prompt.
# fetch("/?signal=%d" % signal.SIGTERM, fail_ok=True)
# fetch("/?signal=%d" % signal.SIGABRT, fail_ok=True)
# int(fetch("/").body)
# Now kill them normally so they won't be restarted
fetch("/?exit=0", fail_ok=True)
# One process left; watch it's pid change
pid = int(fetch("/").body)
fetch("/?exit=4", fail_ok=True)
pid2 = int(fetch("/").body)
self.assertNotEqual(pid, pid2)
# Kill the last one so we shut down cleanly
fetch("/?exit=0", fail_ok=True)
os._exit(0)
except Exception:
logging.error("exception in child process %d", id, exc_info=True)
raise
@skipIfNonUnix
class SubprocessTest(AsyncTestCase):
    """Tests tornado.process.Subprocess: STREAM pipe wiring, SIGCHLD exit
    callbacks/futures, and wait_for_exit error behavior."""

    def term_and_wait(self, subproc):
        # Ensure the child is terminated and reaped before the test returns.
        subproc.proc.terminate()
        subproc.proc.wait()

    @gen_test
    def test_subprocess(self):
        if IOLoop.configured_class().__name__.endswith("LayeredTwistedIOLoop"):
            # This test fails non-deterministically with LayeredTwistedIOLoop.
            # (the read_until('\n') returns '\n' instead of 'hello\n')
            # This probably indicates a problem with either TornadoReactor
            # or TwistedIOLoop, but I haven't been able to track it down
            # and for now this is just causing spurious travis-ci failures.
            raise unittest.SkipTest(
                "Subprocess tests not compatible with " "LayeredTwistedIOLoop"
            )
        # Drive an interactive python interpreter over the STREAM pipes.
        subproc = Subprocess(
            [sys.executable, "-u", "-i"],
            stdin=Subprocess.STREAM,
            stdout=Subprocess.STREAM,
            stderr=subprocess.STDOUT,
        )
        self.addCleanup(lambda: self.term_and_wait(subproc))
        self.addCleanup(subproc.stdout.close)
        self.addCleanup(subproc.stdin.close)
        yield subproc.stdout.read_until(b">>> ")
        subproc.stdin.write(b"print('hello')\n")
        data = yield subproc.stdout.read_until(b"\n")
        self.assertEqual(data, b"hello\n")
        yield subproc.stdout.read_until(b">>> ")
        subproc.stdin.write(b"raise SystemExit\n")
        data = yield subproc.stdout.read_until_close()
        self.assertEqual(data, b"")

    @gen_test
    def test_close_stdin(self):
        # Close the parent's stdin handle and see that the child recognizes it.
        subproc = Subprocess(
            [sys.executable, "-u", "-i"],
            stdin=Subprocess.STREAM,
            stdout=Subprocess.STREAM,
            stderr=subprocess.STDOUT,
        )
        self.addCleanup(lambda: self.term_and_wait(subproc))
        yield subproc.stdout.read_until(b">>> ")
        subproc.stdin.close()
        data = yield subproc.stdout.read_until_close()
        self.assertEqual(data, b"\n")

    @gen_test
    def test_stderr(self):
        # This test is mysteriously flaky on twisted: it succeeds, but logs
        # an error of EBADF on closing a file descriptor.
        subproc = Subprocess(
            [sys.executable, "-u", "-c", r"import sys; sys.stderr.write('hello\n')"],
            stderr=Subprocess.STREAM,
        )
        self.addCleanup(lambda: self.term_and_wait(subproc))
        data = yield subproc.stderr.read_until(b"\n")
        self.assertEqual(data, b"hello\n")
        # More mysterious EBADF: This fails if done with self.addCleanup instead of here.
        subproc.stderr.close()

    def test_sigchild(self):
        """Exit callback fires with status 0 for a child that exits cleanly."""
        Subprocess.initialize()
        self.addCleanup(Subprocess.uninitialize)
        subproc = Subprocess([sys.executable, "-c", "pass"])
        subproc.set_exit_callback(self.stop)
        ret = self.wait()
        self.assertEqual(ret, 0)
        self.assertEqual(subproc.returncode, ret)

    @gen_test
    def test_sigchild_future(self):
        """wait_for_exit() resolves with the child's exit status."""
        Subprocess.initialize()
        self.addCleanup(Subprocess.uninitialize)
        subproc = Subprocess([sys.executable, "-c", "pass"])
        ret = yield subproc.wait_for_exit()
        self.assertEqual(ret, 0)
        self.assertEqual(subproc.returncode, ret)

    def test_sigchild_signal(self):
        """A signalled child reports a negative returncode (-SIGTERM)."""
        Subprocess.initialize()
        self.addCleanup(Subprocess.uninitialize)
        subproc = Subprocess(
            [sys.executable, "-c", "import time; time.sleep(30)"],
            stdout=Subprocess.STREAM,
        )
        self.addCleanup(subproc.stdout.close)
        subproc.set_exit_callback(self.stop)

        # For unclear reasons, killing a process too soon after
        # creating it can result in an exit status corresponding to
        # SIGKILL instead of the actual signal involved. This has been
        # observed on macOS 10.15 with Python 3.8 installed via brew,
        # but not with the system-installed Python 3.7.
        time.sleep(0.1)

        os.kill(subproc.pid, signal.SIGTERM)
        try:
            ret = self.wait()
        except AssertionError:
            # We failed to get the termination signal. This test is
            # occasionally flaky on pypy, so try to get a little more
            # information: did the process close its stdout
            # (indicating that the problem is in the parent process's
            # signal handling) or did the child process somehow fail
            # to terminate?
            fut = subproc.stdout.read_until_close()
            fut.add_done_callback(lambda f: self.stop())  # type: ignore
            try:
                self.wait()
            except AssertionError:
                raise AssertionError("subprocess failed to terminate")
            else:
                raise AssertionError(
                    "subprocess closed stdout but failed to " "get termination signal"
                )
        self.assertEqual(subproc.returncode, ret)
        self.assertEqual(ret, -signal.SIGTERM)

    @gen_test
    def test_wait_for_exit_raise(self):
        """By default a nonzero exit raises CalledProcessError."""
        Subprocess.initialize()
        self.addCleanup(Subprocess.uninitialize)
        subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
        with self.assertRaises(subprocess.CalledProcessError) as cm:
            yield subproc.wait_for_exit()
        self.assertEqual(cm.exception.returncode, 1)

    @gen_test
    def test_wait_for_exit_raise_disabled(self):
        """raise_error=False returns the nonzero status instead of raising."""
        Subprocess.initialize()
        self.addCleanup(Subprocess.uninitialize)
        subproc = Subprocess([sys.executable, "-c", "import sys; sys.exit(1)"])
        ret = yield subproc.wait_for_exit(raise_error=False)
        self.assertEqual(ret, 1)

View file

@ -0,0 +1,431 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
from datetime import timedelta
from random import random
import unittest
from tornado import gen, queues
from tornado.gen import TimeoutError
from tornado.testing import gen_test, AsyncTestCase
class QueueBasicTest(AsyncTestCase):
    """Covers Queue repr/str formatting, FIFO ordering, and maxsize limits."""

    def test_repr_and_str(self):
        q = queues.Queue(maxsize=1)  # type: queues.Queue[None]
        self.assertIn(hex(id(q)), repr(q))
        self.assertNotIn(hex(id(q)), str(q))
        # Register one pending getter (its future is deliberately dropped).
        q.get()
        for q_str in repr(q), str(q):
            self.assertTrue(q_str.startswith("<Queue"))
            self.assertIn("maxsize=1", q_str)
            self.assertIn("getters[1]", q_str)
            self.assertNotIn("putters", q_str)
            self.assertNotIn("tasks", q_str)
        q.put(None)
        q.put(None)
        # Now the queue is full, this putter blocks.
        q.put(None)
        for q_str in repr(q), str(q):
            self.assertNotIn("getters", q_str)
            self.assertIn("putters[1]", q_str)
            self.assertIn("tasks=2", q_str)

    def test_order(self):
        # Items come out in insertion order, not sorted.
        q = queues.Queue()  # type: queues.Queue[int]
        for i in [1, 3, 2]:
            q.put_nowait(i)
        items = [q.get_nowait() for _ in range(3)]
        self.assertEqual([1, 3, 2], items)

    @gen_test
    def test_maxsize(self):
        self.assertRaises(TypeError, queues.Queue, maxsize=None)
        self.assertRaises(ValueError, queues.Queue, maxsize=-1)

        q = queues.Queue(maxsize=2)  # type: queues.Queue[int]
        self.assertTrue(q.empty())
        self.assertFalse(q.full())
        self.assertEqual(2, q.maxsize)
        self.assertTrue(q.put(0).done())
        self.assertTrue(q.put(1).done())
        self.assertFalse(q.empty())
        self.assertTrue(q.full())
        # Third put must block until a get makes room.
        put2 = q.put(2)
        self.assertFalse(put2.done())
        self.assertEqual(0, (yield q.get()))  # Make room.
        self.assertTrue(put2.done())
        self.assertFalse(q.empty())
        self.assertTrue(q.full())
class QueueGetTest(AsyncTestCase):
    """Exercises Queue.get/get_nowait, get timeouts, and waiter cleanup."""

    @gen_test
    def test_blocking_get(self):
        q = queues.Queue()  # type: queues.Queue[int]
        q.put_nowait(0)
        self.assertEqual(0, (yield q.get()))

    def test_nonblocking_get(self):
        q = queues.Queue()  # type: queues.Queue[int]
        q.put_nowait(0)
        self.assertEqual(0, q.get_nowait())

    def test_nonblocking_get_exception(self):
        q = queues.Queue()  # type: queues.Queue[int]
        self.assertRaises(queues.QueueEmpty, q.get_nowait)

    @gen_test
    def test_get_with_putters(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        q.put_nowait(0)
        # Queue is full, so this put blocks until the get below makes room.
        put = q.put(1)
        self.assertEqual(0, (yield q.get()))
        self.assertIsNone((yield put))

    @gen_test
    def test_blocking_get_wait(self):
        q = queues.Queue()  # type: queues.Queue[int]
        q.put(0)
        self.io_loop.call_later(0.01, q.put_nowait, 1)
        self.io_loop.call_later(0.02, q.put_nowait, 2)
        self.assertEqual(0, (yield q.get(timeout=timedelta(seconds=1))))
        self.assertEqual(1, (yield q.get(timeout=timedelta(seconds=1))))

    @gen_test
    def test_get_timeout(self):
        q = queues.Queue()  # type: queues.Queue[int]
        get_timeout = q.get(timeout=timedelta(seconds=0.01))
        get = q.get()
        with self.assertRaises(TimeoutError):
            yield get_timeout
        # The timed-out getter is gone; the untimed one still receives.
        q.put_nowait(0)
        self.assertEqual(0, (yield get))

    @gen_test
    def test_get_timeout_preempted(self):
        q = queues.Queue()  # type: queues.Queue[int]
        get = q.get(timeout=timedelta(seconds=0.01))
        q.put(0)
        # The put resolves the get before its timeout can fire.
        yield gen.sleep(0.02)
        self.assertEqual(0, (yield get))

    @gen_test
    def test_get_clears_timed_out_putters(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        # First putter succeeds, remainder block.
        putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
        put = q.put(10)
        self.assertEqual(10, len(q._putters))
        yield gen.sleep(0.02)
        # Timed-out putters linger in _putters until some operation scans them.
        self.assertEqual(10, len(q._putters))
        self.assertFalse(put.done())  # Final waiter is still active.
        q.put(11)
        self.assertEqual(0, (yield q.get()))  # get() clears the waiters.
        self.assertEqual(1, len(q._putters))
        for putter in putters[1:]:
            self.assertRaises(TimeoutError, putter.result)

    @gen_test
    def test_get_clears_timed_out_getters(self):
        q = queues.Queue()  # type: queues.Queue[int]
        getters = [
            asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10)
        ]
        get = asyncio.ensure_future(q.get())
        self.assertEqual(11, len(q._getters))
        yield gen.sleep(0.02)
        # Timed-out getters linger in _getters until some operation scans them.
        self.assertEqual(11, len(q._getters))
        self.assertFalse(get.done())  # Final waiter is still active.
        q.get()  # get() clears the waiters.
        self.assertEqual(2, len(q._getters))
        for getter in getters:
            self.assertRaises(TimeoutError, getter.result)

    @gen_test
    def test_async_for(self):
        q = queues.Queue()  # type: queues.Queue[int]
        for i in range(5):
            q.put(i)

        async def f():
            results = []
            async for i in q:
                results.append(i)
                if i == 4:
                    return results

        results = yield f()
        self.assertEqual(results, list(range(5)))
class QueuePutTest(AsyncTestCase):
    """Exercises Queue.put/put_nowait, put timeouts, and waiter cleanup."""

    @gen_test
    def test_blocking_put(self):
        q = queues.Queue()  # type: queues.Queue[int]
        q.put(0)
        self.assertEqual(0, q.get_nowait())

    def test_nonblocking_put_exception(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        q.put(0)
        self.assertRaises(queues.QueueFull, q.put_nowait, 1)

    @gen_test
    def test_put_with_getters(self):
        q = queues.Queue()  # type: queues.Queue[int]
        get0 = q.get()
        get1 = q.get()
        # Each put resolves the oldest pending getter in order.
        yield q.put(0)
        self.assertEqual(0, (yield get0))
        yield q.put(1)
        self.assertEqual(1, (yield get1))

    @gen_test
    def test_nonblocking_put_with_getters(self):
        q = queues.Queue()  # type: queues.Queue[int]
        get0 = q.get()
        get1 = q.get()
        q.put_nowait(0)
        # put_nowait does *not* immediately unblock getters.
        yield gen.moment
        self.assertEqual(0, (yield get0))
        q.put_nowait(1)
        yield gen.moment
        self.assertEqual(1, (yield get1))

    @gen_test
    def test_blocking_put_wait(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        q.put_nowait(0)

        def get_and_discard():
            # Drain one item; the returned future is intentionally ignored.
            q.get()

        self.io_loop.call_later(0.01, get_and_discard)
        self.io_loop.call_later(0.02, get_and_discard)
        futures = [q.put(0), q.put(1)]
        self.assertFalse(any(f.done() for f in futures))
        yield futures

    @gen_test
    def test_put_timeout(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        q.put_nowait(0)  # Now it's full.
        put_timeout = q.put(1, timeout=timedelta(seconds=0.01))
        put = q.put(2)
        with self.assertRaises(TimeoutError):
            yield put_timeout
        self.assertEqual(0, q.get_nowait())
        # 1 was never put in the queue.
        self.assertEqual(2, (yield q.get()))
        # Final get() unblocked this putter.
        yield put

    @gen_test
    def test_put_timeout_preempted(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        q.put_nowait(0)
        put = q.put(1, timeout=timedelta(seconds=0.01))
        q.get()
        yield gen.sleep(0.02)
        yield put  # No TimeoutError.

    @gen_test
    def test_put_clears_timed_out_putters(self):
        q = queues.Queue(1)  # type: queues.Queue[int]
        # First putter succeeds, remainder block.
        putters = [q.put(i, timedelta(seconds=0.01)) for i in range(10)]
        put = q.put(10)
        self.assertEqual(10, len(q._putters))
        yield gen.sleep(0.02)
        # Timed-out putters linger until some operation scans the waiter list.
        self.assertEqual(10, len(q._putters))
        self.assertFalse(put.done())  # Final waiter is still active.
        q.put(11)  # put() clears the waiters.
        self.assertEqual(2, len(q._putters))
        for putter in putters[1:]:
            self.assertRaises(TimeoutError, putter.result)

    @gen_test
    def test_put_clears_timed_out_getters(self):
        q = queues.Queue()  # type: queues.Queue[int]
        getters = [
            asyncio.ensure_future(q.get(timedelta(seconds=0.01))) for _ in range(10)
        ]
        get = asyncio.ensure_future(q.get())
        q.get()
        self.assertEqual(12, len(q._getters))
        yield gen.sleep(0.02)
        self.assertEqual(12, len(q._getters))
        self.assertFalse(get.done())  # Final waiters still active.
        q.put(0)  # put() clears the waiters.
        self.assertEqual(1, len(q._getters))
        self.assertEqual(0, (yield get))
        for getter in getters:
            self.assertRaises(TimeoutError, getter.result)

    @gen_test
    def test_float_maxsize(self):
        # If a float is passed for maxsize, a reasonable limit should
        # be enforced, instead of being treated as unlimited.
        # It happens to be rounded up.
        # http://bugs.python.org/issue21723
        q = queues.Queue(maxsize=1.3)  # type: ignore
        self.assertTrue(q.empty())
        self.assertFalse(q.full())
        q.put_nowait(0)
        q.put_nowait(1)
        self.assertFalse(q.empty())
        self.assertTrue(q.full())
        self.assertRaises(queues.QueueFull, q.put_nowait, 2)
        self.assertEqual(0, q.get_nowait())
        self.assertFalse(q.empty())
        self.assertFalse(q.full())

        yield q.put(2)
        put = q.put(3)
        self.assertFalse(put.done())
        self.assertEqual(1, (yield q.get()))
        yield put
        self.assertTrue(q.full())
class QueueJoinTest(AsyncTestCase):
    """Exercises task_done()/join() accounting.

    Subclassed below with queue_class overridden to cover PriorityQueue
    and LifoQueue as well.
    """

    queue_class = queues.Queue

    def test_task_done_underflow(self):
        # task_done() without a matching get() is an error.
        q = self.queue_class()  # type: queues.Queue
        self.assertRaises(ValueError, q.task_done)

    @gen_test
    def test_task_done(self):
        q = self.queue_class()  # type: queues.Queue
        for i in range(100):
            q.put_nowait(i)
        self.accumulator = 0

        @gen.coroutine
        def worker():
            while True:
                item = yield q.get()
                self.accumulator += item
                q.task_done()
                yield gen.sleep(random() * 0.01)

        # Two coroutines share work.
        worker()
        worker()
        yield q.join()
        self.assertEqual(sum(range(100)), self.accumulator)

    @gen_test
    def test_task_done_delay(self):
        # Verify it is task_done(), not get(), that unblocks join().
        q = self.queue_class()  # type: queues.Queue
        q.put_nowait(0)
        join = asyncio.ensure_future(q.join())
        self.assertFalse(join.done())
        yield q.get()
        self.assertFalse(join.done())
        yield gen.moment
        self.assertFalse(join.done())
        q.task_done()
        self.assertTrue(join.done())

    @gen_test
    def test_join_empty_queue(self):
        # join() on an empty queue resolves immediately, repeatedly.
        q = self.queue_class()  # type: queues.Queue
        yield q.join()
        yield q.join()

    @gen_test
    def test_join_timeout(self):
        q = self.queue_class()  # type: queues.Queue
        q.put(0)
        with self.assertRaises(TimeoutError):
            yield q.join(timeout=timedelta(seconds=0.01))
class PriorityQueueJoinTest(QueueJoinTest):
    """Runs the join tests against PriorityQueue and checks priority order."""

    queue_class = queues.PriorityQueue

    @gen_test
    def test_order(self):
        # Items come out lowest-priority-first regardless of insertion order.
        q = self.queue_class(maxsize=2)
        q.put_nowait((1, "a"))
        q.put_nowait((0, "b"))
        self.assertTrue(q.full())
        q.put((3, "c"))
        q.put((2, "d"))
        self.assertEqual((0, "b"), q.get_nowait())
        self.assertEqual((1, "a"), (yield q.get()))
        self.assertEqual((2, "d"), q.get_nowait())
        self.assertEqual((3, "c"), (yield q.get()))
        self.assertTrue(q.empty())
class LifoQueueJoinTest(QueueJoinTest):
    """Runs the join tests against LifoQueue and checks LIFO order."""

    queue_class = queues.LifoQueue

    @gen_test
    def test_order(self):
        # Items come out newest-first.
        q = self.queue_class(maxsize=2)
        q.put_nowait(1)
        q.put_nowait(0)
        self.assertTrue(q.full())
        q.put(3)
        q.put(2)
        self.assertEqual(3, q.get_nowait())
        self.assertEqual(2, (yield q.get()))
        self.assertEqual(0, q.get_nowait())
        self.assertEqual(1, (yield q.get()))
        self.assertTrue(q.empty())
class ProducerConsumerTest(AsyncTestCase):
    """End-to-end producer/consumer flow through a bounded queue."""

    @gen_test
    def test_producer_consumer(self):
        q = queues.Queue(maxsize=3)  # type: queues.Queue[int]
        history = []

        # We don't yield between get() and task_done(), so get() must wait for
        # the next tick. Otherwise we'd immediately call task_done and unblock
        # join() before q.put() resumes, and we'd only process the first four
        # items.
        @gen.coroutine
        def consumer():
            while True:
                history.append((yield q.get()))
                q.task_done()

        @gen.coroutine
        def producer():
            for item in range(10):
                yield q.put(item)

        consumer()
        yield producer()
        yield q.join()
        self.assertEqual(list(range(10)), history)
# Allow this test module to be run directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,10 @@
from tornado.ioloop import IOLoop
from tornado.netutil import ThreadedResolver

# When this module is imported, it runs getaddrinfo on a thread. Since
# the hostname is unicode, getaddrinfo attempts to import encodings.idna
# but blocks on the import lock. Verify that ThreadedResolver avoids
# this deadlock.
# NOTE: resolving "localhost" at import time is this module's entire
# purpose; it is a deliberate module-level side effect.
resolver = ThreadedResolver()
IOLoop.current().run_sync(lambda: resolver.resolve(u"localhost", 80))

View file

@ -0,0 +1,276 @@
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from tornado.httputil import (
HTTPHeaders,
HTTPMessageDelegate,
HTTPServerConnectionDelegate,
ResponseStartLine,
)
from tornado.routing import (
HostMatches,
PathMatches,
ReversibleRouter,
Router,
Rule,
RuleRouter,
)
from tornado.testing import AsyncHTTPTestCase
from tornado.web import Application, HTTPError, RequestHandler
from tornado.wsgi import WSGIContainer
import typing # noqa: F401
class BasicRouter(Router):
    """Minimal Router: answers every request with a fixed 200 "OK" body."""

    def find_handler(self, request, **kwargs):
        class OkDelegate(HTTPMessageDelegate):
            """Ignores the request body and writes a canned response."""

            def __init__(self, connection):
                self.connection = connection

            def finish(self):
                start_line = ResponseStartLine("HTTP/1.1", 200, "OK")
                headers = HTTPHeaders({"Content-Length": "2"})
                self.connection.write_headers(start_line, headers, b"OK")
                self.connection.finish()

        return OkDelegate(request.connection)
class BasicRouterTestCase(AsyncHTTPTestCase):
    """Any path served by BasicRouter returns the canned "OK" body."""

    def get_app(self):
        return BasicRouter()

    def test_basic_router(self):
        response = self.fetch("/any_request")
        self.assertEqual(response.body, b"OK")
# Shared in-memory store mapping request paths to POSTed bodies; written
# by PostResource and read by GetResource below.
resources = {}  # type: typing.Dict[str, bytes]
class GetResource(RequestHandler):
    """Serves previously-POSTed bodies from the module-level resources map."""

    def get(self, path):
        try:
            body = resources[path]
        except KeyError:
            raise HTTPError(404)
        self.finish(body)
class PostResource(RequestHandler):
    """Stores the request body in the module-level resources map."""

    def post(self, path):
        resources[path] = self.request.body
class HTTPMethodRouter(Router):
    """Dispatches by HTTP method: GET to GetResource, all else to PostResource.

    The request path is forwarded verbatim as the handler's positional arg.
    """

    def __init__(self, app):
        self.app = app

    def find_handler(self, request, **kwargs):
        if request.method == "GET":
            handler = GetResource
        else:
            handler = PostResource
        return self.app.get_handler_delegate(
            request, handler, path_args=[request.path]
        )
class HTTPMethodRouterTestCase(AsyncHTTPTestCase):
    """Round-trips a resource through HTTPMethodRouter: POST stores, GET reads."""

    def get_app(self):
        return HTTPMethodRouter(Application())

    def test_http_method_router(self):
        response = self.fetch("/post_resource", method="POST", body="data")
        self.assertEqual(response.code, 200)
        # Nothing was posted under this path yet.
        response = self.fetch("/get_resource")
        self.assertEqual(response.code, 404)
        response = self.fetch("/post_resource")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.body, b"data")
def _get_named_handler(handler_name):
    """Build a RequestHandler class whose GET echoes its own name and
    reversed URL, prefixed with the application's ``app_name`` setting
    when one is defined."""

    class Handler(RequestHandler):
        def get(self, *args, **kwargs):
            app_name = self.application.settings.get("app_name")
            if app_name is not None:
                self.write(app_name + ": ")
            self.finish(handler_name + ": " + self.reverse_url(handler_name))

    return Handler
# Concrete named handlers used by the routing test cases below.
FirstHandler = _get_named_handler("first_handler")
SecondHandler = _get_named_handler("second_handler")
class CustomRouter(ReversibleRouter):
    """Reversible router backed by an explicit path -> (app, handler) table."""

    def __init__(self):
        super().__init__()
        self.routes = {}  # type: typing.Dict[str, typing.Any]

    def add_routes(self, routes):
        """Merge a dict of path -> (application, handler) entries."""
        self.routes.update(routes)

    def find_handler(self, request, **kwargs):
        entry = self.routes.get(request.path)
        if entry is None:
            # Unknown path: fall through to the 404 default.
            return None
        app, handler = entry
        return app.get_handler_delegate(request, handler)

    def reverse_url(self, name, *args):
        candidate = "/" + name
        if candidate in self.routes:
            return candidate
        return None
class CustomRouterTestCase(AsyncHTTPTestCase):
    """Verifies that two Applications can share one ReversibleRouter and that
    reverse_url resolves through it."""

    def get_app(self):
        router = CustomRouter()

        class CustomApplication(Application):
            def reverse_url(self, name, *args):
                # Delegate URL reversal to the shared router.
                return router.reverse_url(name, *args)

        app1 = CustomApplication(app_name="app1")
        app2 = CustomApplication(app_name="app2")

        router.add_routes(
            {
                "/first_handler": (app1, FirstHandler),
                "/second_handler": (app2, SecondHandler),
                "/first_handler_second_app": (app2, FirstHandler),
            }
        )

        return router

    def test_custom_router(self):
        response = self.fetch("/first_handler")
        self.assertEqual(response.body, b"app1: first_handler: /first_handler")
        response = self.fetch("/second_handler")
        self.assertEqual(response.body, b"app2: second_handler: /second_handler")
        response = self.fetch("/first_handler_second_app")
        self.assertEqual(response.body, b"app2: first_handler: /first_handler")
class ConnectionDelegate(HTTPServerConnectionDelegate):
    """Connection-level delegate used to test routing straight to an
    HTTPServerConnectionDelegate rule target; always responds 200 "OK"."""

    def start_request(self, server_conn, request_conn):
        class MessageDelegate(HTTPMessageDelegate):
            def __init__(self, connection):
                self.connection = connection

            def finish(self):
                response_body = b"OK"
                self.connection.write_headers(
                    ResponseStartLine("HTTP/1.1", 200, "OK"),
                    HTTPHeaders({"Content-Length": str(len(response_body))}),
                )
                self.connection.write(response_body)
                self.connection.finish()

        return MessageDelegate(request_conn)
class RuleRouterTest(AsyncHTTPTestCase):
    """Covers rule-based routing with the full range of rule targets:
    host-scoped handler specs, nested routers, handler classes, plain
    callables, and connection delegates. Rule order below is significant."""

    def get_app(self):
        app = Application()

        def request_callable(request):
            # Bare-callable rule target: writes the response by hand.
            request.connection.write_headers(
                ResponseStartLine("HTTP/1.1", 200, "OK"),
                HTTPHeaders({"Content-Length": "2"}),
            )
            request.connection.write(b"OK")
            request.connection.finish()

        router = CustomRouter()
        router.add_routes(
            {"/nested_handler": (app, _get_named_handler("nested_handler"))}
        )

        app.add_handlers(
            ".*",
            [
                (
                    HostMatches("www.example.com"),
                    [
                        (
                            PathMatches("/first_handler"),
                            "tornado.test.routing_test.SecondHandler",
                            {},
                            "second_handler",
                        )
                    ],
                ),
                Rule(PathMatches("/.*handler"), router),
                Rule(PathMatches("/first_handler"), FirstHandler, name="first_handler"),
                Rule(PathMatches("/request_callable"), request_callable),
                ("/connection_delegate", ConnectionDelegate()),
            ],
        )

        return app

    def test_rule_based_router(self):
        response = self.fetch("/first_handler")
        self.assertEqual(response.body, b"first_handler: /first_handler")
        # The host-scoped rule takes precedence for www.example.com.
        response = self.fetch("/first_handler", headers={"Host": "www.example.com"})
        self.assertEqual(response.body, b"second_handler: /first_handler")
        response = self.fetch("/nested_handler")
        self.assertEqual(response.body, b"nested_handler: /nested_handler")
        response = self.fetch("/nested_not_found_handler")
        self.assertEqual(response.code, 404)
        response = self.fetch("/connection_delegate")
        self.assertEqual(response.body, b"OK")
        response = self.fetch("/request_callable")
        self.assertEqual(response.body, b"OK")
        response = self.fetch("/404")
        self.assertEqual(response.code, 404)
class WSGIContainerTestCase(AsyncHTTPTestCase):
    """Routes between a Tornado Application and a WSGI app in one RuleRouter."""

    def get_app(self):
        wsgi_app = WSGIContainer(self.wsgi_app)

        class Handler(RequestHandler):
            def get(self, *args, **kwargs):
                self.finish(self.reverse_url("tornado"))

        return RuleRouter(
            [
                (
                    PathMatches("/tornado.*"),
                    Application([(r"/tornado/test", Handler, {}, "tornado")]),
                ),
                (PathMatches("/wsgi"), wsgi_app),
            ]
        )

    def wsgi_app(self, environ, start_response):
        # Minimal WSGI application used as a rule target.
        start_response("200 OK", [])
        return [b"WSGI"]

    def test_wsgi_container(self):
        response = self.fetch("/tornado/test")
        self.assertEqual(response.body, b"/tornado/test")
        response = self.fetch("/wsgi")
        self.assertEqual(response.body, b"WSGI")

    def test_delegate_not_found(self):
        response = self.fetch("/404")
        self.assertEqual(response.code, 404)

View file

@ -0,0 +1,241 @@
from functools import reduce
import gc
import io
import locale # system locale module, not tornado.locale
import logging
import operator
import textwrap
import sys
import unittest
import warnings
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.netutil import Resolver
from tornado.options import define, add_parse_callback, options
# Modules collected by all() below: doctest-bearing modules first, then
# the unit-test modules.
TEST_MODULES = [
    "tornado.httputil.doctests",
    "tornado.iostream.doctests",
    "tornado.util.doctests",
    "tornado.test.asyncio_test",
    "tornado.test.auth_test",
    "tornado.test.autoreload_test",
    "tornado.test.concurrent_test",
    "tornado.test.curl_httpclient_test",
    "tornado.test.escape_test",
    "tornado.test.gen_test",
    "tornado.test.http1connection_test",
    "tornado.test.httpclient_test",
    "tornado.test.httpserver_test",
    "tornado.test.httputil_test",
    "tornado.test.import_test",
    "tornado.test.ioloop_test",
    "tornado.test.iostream_test",
    "tornado.test.locale_test",
    "tornado.test.locks_test",
    "tornado.test.netutil_test",
    "tornado.test.log_test",
    "tornado.test.options_test",
    "tornado.test.process_test",
    "tornado.test.queues_test",
    "tornado.test.routing_test",
    "tornado.test.simple_httpclient_test",
    "tornado.test.tcpclient_test",
    "tornado.test.tcpserver_test",
    "tornado.test.template_test",
    "tornado.test.testing_test",
    "tornado.test.twisted_test",
    "tornado.test.util_test",
    "tornado.test.web_test",
    "tornado.test.websocket_test",
    "tornado.test.wsgi_test",
]
def all():
    """Return a test suite covering every module in TEST_MODULES.

    NOTE: deliberately shadows the builtin ``all`` — tornado.testing.main's
    autodiscovery looks for a module-level ``all`` (see tornado/test/__main__.py).
    """
    return unittest.defaultTestLoader.loadTestsFromNames(TEST_MODULES)
def test_runner_factory(stderr):
    """Return a TextTestRunner subclass bound to *stderr* that also prints a
    summary of skip reasons after the run.

    The stream is captured here so runner output bypasses the CountingStderr
    wrapper installed by main().
    """

    class TornadoTextTestRunner(unittest.TextTestRunner):
        def __init__(self, *args, **kwargs):
            kwargs["stream"] = stderr
            super().__init__(*args, **kwargs)

        def run(self, test):
            result = super().run(test)
            if result.skipped:
                # Deduplicate reasons; result.skipped is (test, reason) pairs.
                skip_reasons = set(reason for (test, reason) in result.skipped)
                self.stream.write(  # type: ignore
                    textwrap.fill(
                        "Some tests were skipped because: %s"
                        % ", ".join(sorted(skip_reasons))
                    )
                )
                self.stream.write("\n")  # type: ignore
            return result

    return TornadoTextTestRunner
class LogCounter(logging.Filter):
    """Logging filter that tallies records at INFO, WARNING, and ERROR-or-above
    severity (one counter per band) while letting every record through."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.info_count = 0
        self.warning_count = 0
        self.error_count = 0

    def filter(self, record):
        """Increment the counter for the record's severity band; never drop."""
        # Bands are checked highest-first; a record is counted once, in the
        # highest band it reaches. Records below INFO are not counted.
        for threshold, attr in (
            (logging.ERROR, "error_count"),
            (logging.WARNING, "warning_count"),
            (logging.INFO, "info_count"),
        ):
            if record.levelno >= threshold:
                setattr(self, attr, getattr(self, attr) + 1)
                break
        return True
class CountingStderr(io.IOBase):
    """Pass-through stream wrapper that records the total length of all data
    written, forwarding writes and flushes to the wrapped stream."""

    def __init__(self, real):
        # The wrapped underlying stream (e.g. the original sys.stderr).
        self.real = real
        # Running total of len() of everything passed to write().
        self.byte_count = 0

    def write(self, data):
        """Tally the data's length, then forward it to the real stream."""
        self.byte_count += len(data)
        return self.real.write(data)

    def flush(self):
        """Forward flush to the real stream."""
        return self.real.flush()
def main():
    """Configure warning filters, test-suite options, and log accounting,
    then run the tornado test suite; exits nonzero if any test fails or
    (with --fail-if-logs) if anything was logged or written to stderr."""
    # Be strict about most warnings (This is set in our test running
    # scripts to catch import-time warnings, but set it again here to
    # be sure). This also turns on warnings that are ignored by
    # default, including DeprecationWarnings and python 3.2's
    # ResourceWarnings.
    warnings.filterwarnings("error")
    # setuptools sometimes gives ImportWarnings about things that are on
    # sys.path even if they're not being used.
    warnings.filterwarnings("ignore", category=ImportWarning)
    # Tornado generally shouldn't use anything deprecated, but some of
    # our dependencies do (last match wins).
    warnings.filterwarnings("ignore", category=DeprecationWarning)
    warnings.filterwarnings("error", category=DeprecationWarning, module=r"tornado\..*")
    warnings.filterwarnings("ignore", category=PendingDeprecationWarning)
    warnings.filterwarnings(
        "error", category=PendingDeprecationWarning, module=r"tornado\..*"
    )
    # The unittest module is aggressive about deprecating redundant methods,
    # leaving some without non-deprecated spellings that work on both
    # 2.7 and 3.2
    warnings.filterwarnings(
        "ignore", category=DeprecationWarning, message="Please use assert.* instead"
    )
    warnings.filterwarnings(
        "ignore",
        category=PendingDeprecationWarning,
        message="Please use assert.* instead",
    )
    # Twisted 15.0.0 triggers some warnings on py3 with -bb.
    warnings.filterwarnings("ignore", category=BytesWarning, module=r"twisted\..*")
    if (3,) < sys.version_info < (3, 6):
        # Prior to 3.6, async ResourceWarnings were rather noisy
        # and even
        # `python3.4 -W error -c 'import asyncio; asyncio.get_event_loop()'`
        # would generate a warning.
        warnings.filterwarnings(
            "ignore", category=ResourceWarning, module=r"asyncio\..*"
        )
    # This deprecation warning is introduced in Python 3.8 and is
    # triggered by pycurl. Unfortunately, because it is raised in the C
    # layer it can't be filtered by module and we must match the
    # message text instead (Tornado's C module uses PY_SSIZE_T_CLEAN
    # so it's not at risk of running into this issue).
    warnings.filterwarnings(
        "ignore",
        category=DeprecationWarning,
        message="PY_SSIZE_T_CLEAN will be required",
    )

    # Silence per-request access logging during the run.
    logging.getLogger("tornado.access").setLevel(logging.CRITICAL)

    # Command-line options for swapping implementations under test.
    define(
        "httpclient",
        type=str,
        default=None,
        callback=lambda s: AsyncHTTPClient.configure(
            s, defaults=dict(allow_ipv6=False)
        ),
    )
    define("httpserver", type=str, default=None, callback=HTTPServer.configure)
    define("resolver", type=str, default=None, callback=Resolver.configure)
    define(
        "debug_gc",
        type=str,
        multiple=True,
        help="A comma-separated list of gc module debug constants, "
        "e.g. DEBUG_STATS or DEBUG_COLLECTABLE,DEBUG_OBJECTS",
        callback=lambda values: gc.set_debug(
            reduce(operator.or_, (getattr(gc, v) for v in values))
        ),
    )
    define(
        "fail-if-logs",
        default=True,
        help="If true, fail the tests if any log output is produced (unless captured by ExpectLog)",
    )

    def set_locale(x):
        locale.setlocale(locale.LC_ALL, x)

    define("locale", type=str, default=None, callback=set_locale)

    # Install the log counter on the root handler once options are parsed.
    log_counter = LogCounter()
    add_parse_callback(lambda: logging.getLogger().handlers[0].addFilter(log_counter))

    # Certain errors (especially "unclosed resource" errors raised in
    # destructors) go directly to stderr instead of logging. Count
    # anything written by anything but the test runner as an error.
    orig_stderr = sys.stderr
    counting_stderr = CountingStderr(orig_stderr)
    sys.stderr = counting_stderr  # type: ignore

    import tornado.testing

    kwargs = {}

    # HACK: unittest.main will make its own changes to the warning
    # configuration, which may conflict with the settings above
    # or command-line flags like -bb. Passing warnings=False
    # suppresses this behavior, although this looks like an implementation
    # detail. http://bugs.python.org/issue15626
    kwargs["warnings"] = False

    # The runner writes to the real stderr so its own output isn't counted.
    kwargs["testRunner"] = test_runner_factory(orig_stderr)
    try:
        tornado.testing.main(**kwargs)
    finally:
        # The tests should run clean; consider it a failure if they
        # logged anything at info level or above.
        if (
            log_counter.info_count > 0
            or log_counter.warning_count > 0
            or log_counter.error_count > 0
            or counting_stderr.byte_count > 0
        ):
            logging.error(
                "logged %d infos, %d warnings, %d errors, and %d bytes to stderr",
                log_counter.info_count,
                log_counter.warning_count,
                log_counter.error_count,
                counting_stderr.byte_count,
            )
            if options.fail_if_logs:
                sys.exit(1)
# Script entry point: run the tornado test suite via main() defined above.
if __name__ == "__main__":
    main()

View file

@ -0,0 +1,834 @@
import collections
from contextlib import closing
import errno
import logging
import os
import re
import socket
import ssl
import sys
import typing # noqa: F401
from tornado.escape import to_unicode, utf8
from tornado import gen, version
from tornado.httpclient import AsyncHTTPClient
from tornado.httputil import HTTPHeaders, ResponseStartLine
from tornado.ioloop import IOLoop
from tornado.iostream import UnsatisfiableReadError
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import Resolver, bind_sockets
from tornado.simple_httpclient import (
SimpleAsyncHTTPClient,
HTTPStreamClosedError,
HTTPTimeoutError,
)
from tornado.test.httpclient_test import (
ChunkHandler,
CountdownHandler,
HelloWorldHandler,
RedirectHandler,
UserAgentHandler,
)
from tornado.test import httpclient_test
from tornado.testing import (
AsyncHTTPTestCase,
AsyncHTTPSTestCase,
AsyncTestCase,
ExpectLog,
gen_test,
)
from tornado.test.util import skipOnTravis, skipIfNoIPv6, refusing_port
from tornado.web import RequestHandler, Application, url, stream_request_body
class SimpleHTTPClientCommonTestCase(httpclient_test.HTTPClientCommonTestCase):
    """Runs the shared HTTP client test suite against SimpleAsyncHTTPClient."""

    def get_http_client(self):
        # force_instance so this client is not shared with other tests.
        client = SimpleAsyncHTTPClient(force_instance=True)
        self.assertTrue(isinstance(client, SimpleAsyncHTTPClient))
        return client
class TriggerHandler(RequestHandler):
    """Hangs each GET until the test pops and calls its completion callback.

    Used to hold connections open so tests can exercise queueing and
    connection limits deterministically.
    """

    def initialize(self, queue, wake_callback):
        # queue: deque shared with the test; wake_callback: usually test.stop.
        self.queue = queue
        self.wake_callback = wake_callback

    @gen.coroutine
    def get(self):
        logging.debug("queuing trigger")
        event = Event()
        # Expose event.set so the test can release this request later.
        self.queue.append(event.set)
        if self.get_argument("wake", "true") == "true":
            self.wake_callback()
        # Block the handler until the test triggers completion.
        yield event.wait()
class ContentLengthHandler(RequestHandler):
    """Writes a raw HTTP/1.0 response with a caller-chosen Content-Length.

    Detaches from the framework so tests can send arbitrary (possibly
    invalid) Content-Length header values via the ``value`` argument.
    """

    def get(self):
        # Take over the raw stream; tornado will no longer manage framing.
        self.stream = self.detach()
        IOLoop.current().spawn_callback(self.write_response)

    @gen.coroutine
    def write_response(self):
        yield self.stream.write(
            utf8(
                "HTTP/1.0 200 OK\r\nContent-Length: %s\r\n\r\nok"
                % self.get_argument("value")
            )
        )
        self.stream.close()
class HeadHandler(RequestHandler):
    """HEAD responder advertising a Content-Length with no body sent."""

    def head(self):
        self.set_header("Content-Length", "7")
class OptionsHandler(RequestHandler):
    """OPTIONS responder with a CORS header and a small body."""

    def options(self):
        self.set_header("Access-Control-Allow-Origin", "*")
        self.write("ok")
class NoContentHandler(RequestHandler):
    """Returns a bodyless 204 No Content response."""

    def get(self):
        self.set_status(204)
        self.finish()
class SeeOtherPostHandler(RequestHandler):
    """Redirects a POST to /see_other_get with the status given in the body.

    The request body must be b"302" or b"303"; both should cause the client
    to follow the redirect with a GET.
    """

    def post(self):
        redirect_code = int(self.request.body)
        assert redirect_code in (302, 303), "unexpected body %r" % self.request.body
        self.set_header("Location", "/see_other_get")
        self.set_status(redirect_code)
class SeeOtherGetHandler(RequestHandler):
    """Redirect target; fails if the client re-sent the POST body."""

    def get(self):
        if self.request.body:
            raise Exception("unexpected body %r" % self.request.body)
        self.write("ok")
class HostEchoHandler(RequestHandler):
    """Echoes back the Host header the client sent."""

    def get(self):
        self.write(self.request.headers["Host"])
class NoContentLengthHandler(RequestHandler):
    """Serves a body with neither Content-Length nor chunked framing.

    Emulates legacy HTTP/1.0 servers whose responses are delimited only by
    connection close. Only meaningful over HTTP/1.x.
    """

    def get(self):
        if not self.request.version.startswith("HTTP/1"):
            self.finish("HTTP/1 required")
            return
        # Tornado normally manages content-length at the framework level,
        # so detach and write the raw response bytes ourselves.
        conn = self.detach()
        conn.write(b"HTTP/1.0 200 OK\r\n\r\n" b"hello")
        conn.close()
class EchoPostHandler(RequestHandler):
    """Echoes the POST body back verbatim."""

    def post(self):
        self.write(self.request.body)
@stream_request_body
class RespondInPrepareHandler(RequestHandler):
    """Rejects the request in prepare(), before any body is received.

    Used with Expect: 100-continue tests to verify the client never sends
    the body when the server responds early.
    """

    def prepare(self):
        self.set_status(403)
        self.finish("forbidden")
class SimpleHTTPClientTestMixin(object):
    """Shared test cases for SimpleAsyncHTTPClient.

    Mixed into both an HTTP and an HTTPS test case (see
    SimpleHTTPClientTestCase / SimpleHTTPSClientTestCase below), which must
    provide ``create_client`` and the AsyncHTTPTestCase machinery.
    """

    def create_client(self, **kwargs):
        """Return a fresh client instance; implemented by subclasses."""
        raise NotImplementedError()

    def get_app(self: typing.Any):
        # callable objects to finish pending /trigger requests
        self.triggers = (
            collections.deque()
        )  # type: typing.Deque[typing.Callable[[], None]]
        return Application(
            [
                url(
                    "/trigger",
                    TriggerHandler,
                    dict(queue=self.triggers, wake_callback=self.stop),
                ),
                url("/chunk", ChunkHandler),
                url("/countdown/([0-9]+)", CountdownHandler, name="countdown"),
                url("/hello", HelloWorldHandler),
                url("/content_length", ContentLengthHandler),
                url("/head", HeadHandler),
                url("/options", OptionsHandler),
                url("/no_content", NoContentHandler),
                url("/see_other_post", SeeOtherPostHandler),
                url("/see_other_get", SeeOtherGetHandler),
                url("/host_echo", HostEchoHandler),
                url("/no_content_length", NoContentLengthHandler),
                url("/echo_post", EchoPostHandler),
                url("/respond_in_prepare", RespondInPrepareHandler),
                url("/redirect", RedirectHandler),
                url("/user_agent", UserAgentHandler),
            ],
            gzip=True,
        )

    def test_singleton(self: typing.Any):
        # Class "constructor" reuses objects on the same IOLoop
        self.assertTrue(SimpleAsyncHTTPClient() is SimpleAsyncHTTPClient())
        # unless force_instance is used
        self.assertTrue(
            SimpleAsyncHTTPClient() is not SimpleAsyncHTTPClient(force_instance=True)
        )
        # different IOLoops use different objects
        with closing(IOLoop()) as io_loop2:

            async def make_client():
                await gen.sleep(0)
                return SimpleAsyncHTTPClient()

            client1 = self.io_loop.run_sync(make_client)
            client2 = io_loop2.run_sync(make_client)
            self.assertTrue(client1 is not client2)

    def test_connection_limit(self: typing.Any):
        with closing(self.create_client(max_clients=2)) as client:
            self.assertEqual(client.max_clients, 2)
            seen = []
            # Send 4 requests. Two can be sent immediately, while the others
            # will be queued
            for i in range(4):

                def cb(fut, i=i):
                    seen.append(i)
                    self.stop()

                client.fetch(self.get_url("/trigger")).add_done_callback(cb)
            self.wait(condition=lambda: len(self.triggers) == 2)
            self.assertEqual(len(client.queue), 2)

            # Finish the first two requests and let the next two through
            self.triggers.popleft()()
            self.triggers.popleft()()
            self.wait(condition=lambda: (len(self.triggers) == 2 and len(seen) == 2))
            self.assertEqual(set(seen), set([0, 1]))
            self.assertEqual(len(client.queue), 0)

            # Finish all the pending requests
            self.triggers.popleft()()
            self.triggers.popleft()()
            self.wait(condition=lambda: len(seen) == 4)
            self.assertEqual(set(seen), set([0, 1, 2, 3]))
            self.assertEqual(len(self.triggers), 0)

    @gen_test
    def test_redirect_connection_limit(self: typing.Any):
        # following redirects should not consume additional connections
        with closing(self.create_client(max_clients=1)) as client:
            response = yield client.fetch(self.get_url("/countdown/3"), max_redirects=3)
            response.rethrow()

    def test_max_redirects(self: typing.Any):
        response = self.fetch("/countdown/5", max_redirects=3)
        self.assertEqual(302, response.code)
        # We requested 5, followed three redirects for 4, 3, 2, then the last
        # unfollowed redirect is to 1.
        self.assertTrue(response.request.url.endswith("/countdown/5"))
        self.assertTrue(response.effective_url.endswith("/countdown/2"))
        self.assertTrue(response.headers["Location"].endswith("/countdown/1"))

    def test_header_reuse(self: typing.Any):
        # Apps may reuse a headers object if they are only passing in constant
        # headers like user-agent. The header object should not be modified.
        headers = HTTPHeaders({"User-Agent": "Foo"})
        self.fetch("/hello", headers=headers)
        self.assertEqual(list(headers.get_all()), [("User-Agent", "Foo")])

    def test_default_user_agent(self: typing.Any):
        response = self.fetch("/user_agent", method="GET")
        self.assertEqual(200, response.code)
        self.assertEqual(response.body.decode(), "Tornado/{}".format(version))

    def test_see_other_redirect(self: typing.Any):
        for code in (302, 303):
            response = self.fetch("/see_other_post", method="POST", body="%d" % code)
            self.assertEqual(200, response.code)
            self.assertTrue(response.request.url.endswith("/see_other_post"))
            self.assertTrue(response.effective_url.endswith("/see_other_get"))
            # request is the original request, is a POST still
            self.assertEqual("POST", response.request.method)

    @skipOnTravis
    @gen_test
    def test_connect_timeout(self: typing.Any):
        timeout = 0.1
        cleanup_event = Event()
        test = self

        class TimeoutResolver(Resolver):
            async def resolve(self, *args, **kwargs):
                await cleanup_event.wait()
                # Return something valid so the test doesn't raise during shutdown.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        with closing(self.create_client(resolver=TimeoutResolver())) as client:
            with self.assertRaises(HTTPTimeoutError):
                yield client.fetch(
                    self.get_url("/hello"),
                    connect_timeout=timeout,
                    request_timeout=3600,
                    raise_error=True,
                )

        # Let the hanging coroutine clean up after itself. We need to
        # wait more than a single IOLoop iteration for the SSL case,
        # which logs errors on unexpected EOF.
        cleanup_event.set()
        yield gen.sleep(0.2)

    @skipOnTravis
    def test_request_timeout(self: typing.Any):
        timeout = 0.1
        if os.name == "nt":
            timeout = 0.5

        with self.assertRaises(HTTPTimeoutError):
            self.fetch("/trigger?wake=false", request_timeout=timeout, raise_error=True)
        # trigger the hanging request to let it clean up after itself
        self.triggers.popleft()()
        self.io_loop.run_sync(lambda: gen.sleep(0))

    @skipIfNoIPv6
    def test_ipv6(self: typing.Any):
        [sock] = bind_sockets(0, "::1", family=socket.AF_INET6)
        port = sock.getsockname()[1]
        self.http_server.add_socket(sock)
        url = "%s://[::1]:%d/hello" % (self.get_protocol(), port)

        # ipv6 is currently enabled by default but can be disabled
        with self.assertRaises(Exception):
            self.fetch(url, allow_ipv6=False, raise_error=True)

        response = self.fetch(url)
        self.assertEqual(response.body, b"Hello world!")

    def test_multiple_content_length_accepted(self: typing.Any):
        response = self.fetch("/content_length?value=2,2")
        self.assertEqual(response.body, b"ok")
        response = self.fetch("/content_length?value=2,%202,2")
        self.assertEqual(response.body, b"ok")

        with ExpectLog(
            gen_log, ".*Multiple unequal Content-Lengths", level=logging.INFO
        ):
            with self.assertRaises(HTTPStreamClosedError):
                self.fetch("/content_length?value=2,4", raise_error=True)
            with self.assertRaises(HTTPStreamClosedError):
                self.fetch("/content_length?value=2,%202,3", raise_error=True)

    def test_head_request(self: typing.Any):
        response = self.fetch("/head", method="HEAD")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["content-length"], "7")
        self.assertFalse(response.body)

    def test_options_request(self: typing.Any):
        response = self.fetch("/options", method="OPTIONS")
        self.assertEqual(response.code, 200)
        self.assertEqual(response.headers["content-length"], "2")
        self.assertEqual(response.headers["access-control-allow-origin"], "*")
        self.assertEqual(response.body, b"ok")

    def test_no_content(self: typing.Any):
        response = self.fetch("/no_content")
        self.assertEqual(response.code, 204)
        # 204 status shouldn't have a content-length
        #
        # Tests with a content-length header are included below
        # in HTTP204NoContentTestCase.
        self.assertNotIn("Content-Length", response.headers)

    def test_host_header(self: typing.Any):
        # Dots are escaped so the pattern matches the literal address
        # 127.0.0.1 only (an unescaped "." would match any character).
        host_re = re.compile(rb"^127\.0\.0\.1:[0-9]+$")
        response = self.fetch("/host_echo")
        self.assertTrue(host_re.match(response.body))

        url = self.get_url("/host_echo").replace("http://", "http://me:secret@")
        response = self.fetch(url)
        self.assertTrue(host_re.match(response.body), response.body)

    def test_connection_refused(self: typing.Any):
        cleanup_func, port = refusing_port()
        self.addCleanup(cleanup_func)
        with ExpectLog(gen_log, ".*", required=False):
            with self.assertRaises(socket.error) as cm:
                self.fetch("http://127.0.0.1:%d/" % port, raise_error=True)

        if sys.platform != "cygwin":
            # cygwin returns EPERM instead of ECONNREFUSED here
            contains_errno = str(errno.ECONNREFUSED) in str(cm.exception)
            if not contains_errno and hasattr(errno, "WSAECONNREFUSED"):
                contains_errno = str(errno.WSAECONNREFUSED) in str(  # type: ignore
                    cm.exception
                )
            self.assertTrue(contains_errno, cm.exception)
            # This is usually "Connection refused".
            # On windows, strerror is broken and returns "Unknown error".
            expected_message = os.strerror(errno.ECONNREFUSED)
            self.assertTrue(expected_message in str(cm.exception), cm.exception)

    def test_queue_timeout(self: typing.Any):
        with closing(self.create_client(max_clients=1)) as client:
            # Wait for the trigger request to block, not complete.
            fut1 = client.fetch(self.get_url("/trigger"), request_timeout=10)
            self.wait()
            with self.assertRaises(HTTPTimeoutError) as cm:
                self.io_loop.run_sync(
                    lambda: client.fetch(
                        self.get_url("/hello"), connect_timeout=0.1, raise_error=True
                    )
                )

            self.assertEqual(str(cm.exception), "Timeout in request queue")
            self.triggers.popleft()()
            self.io_loop.run_sync(lambda: fut1)

    def test_no_content_length(self: typing.Any):
        response = self.fetch("/no_content_length")
        if response.body == b"HTTP/1 required":
            self.skipTest("requires HTTP/1.x")
        else:
            self.assertEqual(b"hello", response.body)

    def sync_body_producer(self, write):
        # Synchronous producer: both chunks written before returning.
        write(b"1234")
        write(b"5678")

    @gen.coroutine
    def async_body_producer(self, write):
        # Asynchronous producer: yields to the IOLoop between chunks.
        yield write(b"1234")
        yield gen.moment
        yield write(b"5678")

    def test_sync_body_producer_chunked(self: typing.Any):
        response = self.fetch(
            "/echo_post", method="POST", body_producer=self.sync_body_producer
        )
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_sync_body_producer_content_length(self: typing.Any):
        response = self.fetch(
            "/echo_post",
            method="POST",
            body_producer=self.sync_body_producer,
            headers={"Content-Length": "8"},
        )
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_async_body_producer_chunked(self: typing.Any):
        response = self.fetch(
            "/echo_post", method="POST", body_producer=self.async_body_producer
        )
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_async_body_producer_content_length(self: typing.Any):
        response = self.fetch(
            "/echo_post",
            method="POST",
            body_producer=self.async_body_producer,
            headers={"Content-Length": "8"},
        )
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_native_body_producer_chunked(self: typing.Any):
        async def body_producer(write):
            await write(b"1234")
            import asyncio

            await asyncio.sleep(0)
            await write(b"5678")

        response = self.fetch("/echo_post", method="POST", body_producer=body_producer)
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_native_body_producer_content_length(self: typing.Any):
        async def body_producer(write):
            await write(b"1234")
            import asyncio

            await asyncio.sleep(0)
            await write(b"5678")

        response = self.fetch(
            "/echo_post",
            method="POST",
            body_producer=body_producer,
            headers={"Content-Length": "8"},
        )
        response.rethrow()
        self.assertEqual(response.body, b"12345678")

    def test_100_continue(self: typing.Any):
        response = self.fetch(
            "/echo_post", method="POST", body=b"1234", expect_100_continue=True
        )
        self.assertEqual(response.body, b"1234")

    def test_100_continue_early_response(self: typing.Any):
        def body_producer(write):
            raise Exception("should not be called")

        response = self.fetch(
            "/respond_in_prepare",
            method="POST",
            body_producer=body_producer,
            expect_100_continue=True,
        )
        self.assertEqual(response.code, 403)

    def test_streaming_follow_redirects(self: typing.Any):
        # When following redirects, header and streaming callbacks
        # should only be called for the final result.
        # TODO(bdarnell): this test belongs in httpclient_test instead of
        # simple_httpclient_test, but it fails with the version of libcurl
        # available on travis-ci. Move it when that has been upgraded
        # or we have a better framework to skip tests based on curl version.
        headers = []  # type: typing.List[str]
        chunk_bytes = []  # type: typing.List[bytes]
        self.fetch(
            "/redirect?url=/hello",
            header_callback=headers.append,
            streaming_callback=chunk_bytes.append,
        )
        chunks = list(map(to_unicode, chunk_bytes))
        self.assertEqual(chunks, ["Hello world!"])
        # Make sure we only got one set of headers.
        num_start_lines = len([h for h in headers if h.startswith("HTTP/")])
        self.assertEqual(num_start_lines, 1)
class SimpleHTTPClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPTestCase):
    """Runs the mixin's tests over plain HTTP."""

    def setUp(self):
        super().setUp()
        self.http_client = self.create_client()

    def create_client(self, **kwargs):
        return SimpleAsyncHTTPClient(force_instance=True, **kwargs)
class SimpleHTTPSClientTestCase(SimpleHTTPClientTestMixin, AsyncHTTPSTestCase):
    """Runs the mixin's tests over HTTPS, plus SSL-specific cases."""

    def setUp(self):
        super().setUp()
        self.http_client = self.create_client()

    def create_client(self, **kwargs):
        # validate_cert=False because the test server uses a self-signed cert.
        return SimpleAsyncHTTPClient(
            force_instance=True, defaults=dict(validate_cert=False), **kwargs
        )

    def test_ssl_options(self):
        resp = self.fetch("/hello", ssl_options={})
        self.assertEqual(resp.body, b"Hello world!")

    def test_ssl_context(self):
        resp = self.fetch("/hello", ssl_options=ssl.SSLContext(ssl.PROTOCOL_SSLv23))
        self.assertEqual(resp.body, b"Hello world!")

    def test_ssl_options_handshake_fail(self):
        # Requiring cert validation against a self-signed cert must fail.
        with ExpectLog(gen_log, "SSL Error|Uncaught exception", required=False):
            with self.assertRaises(ssl.SSLError):
                self.fetch(
                    "/hello",
                    ssl_options=dict(cert_reqs=ssl.CERT_REQUIRED),
                    raise_error=True,
                )

    def test_ssl_context_handshake_fail(self):
        with ExpectLog(gen_log, "SSL Error|Uncaught exception"):
            ctx = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
            ctx.verify_mode = ssl.CERT_REQUIRED
            with self.assertRaises(ssl.SSLError):
                self.fetch("/hello", ssl_options=ctx, raise_error=True)

    def test_error_logging(self):
        # No stack traces are logged for SSL errors (in this case,
        # failure to validate the testing self-signed cert).
        # The SSLError is exposed through ssl.SSLError.
        with ExpectLog(gen_log, ".*") as expect_log:
            with self.assertRaises(ssl.SSLError):
                self.fetch("/", validate_cert=True, raise_error=True)
        self.assertFalse(expect_log.logged_stack)
class CreateAsyncHTTPClientTestCase(AsyncTestCase):
    """Tests AsyncHTTPClient.configure interaction with constructor kwargs."""

    def setUp(self):
        super().setUp()
        # Save global AsyncHTTPClient configuration so tests don't leak it.
        self.saved = AsyncHTTPClient._save_configuration()

    def tearDown(self):
        AsyncHTTPClient._restore_configuration(self.saved)
        super().tearDown()

    def test_max_clients(self):
        AsyncHTTPClient.configure(SimpleAsyncHTTPClient)
        with closing(AsyncHTTPClient(force_instance=True)) as client:
            self.assertEqual(client.max_clients, 10)  # type: ignore
        with closing(AsyncHTTPClient(max_clients=11, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 11)  # type: ignore

        # Now configure max_clients statically and try overriding it
        # with each way max_clients can be passed
        AsyncHTTPClient.configure(SimpleAsyncHTTPClient, max_clients=12)
        with closing(AsyncHTTPClient(force_instance=True)) as client:
            self.assertEqual(client.max_clients, 12)  # type: ignore
        with closing(AsyncHTTPClient(max_clients=13, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 13)  # type: ignore
        with closing(AsyncHTTPClient(max_clients=14, force_instance=True)) as client:
            self.assertEqual(client.max_clients, 14)  # type: ignore
class HTTP100ContinueTestCase(AsyncHTTPTestCase):
    """Tests client handling of an interim 100 CONTINUE response.

    The server callback writes the raw wire bytes itself, so these tests
    only apply to HTTP/1.x.
    """

    def respond_100(self, request):
        self.http1 = request.version.startswith("HTTP/1.")
        if not self.http1:
            # Close the request cleanly in HTTP/2; the test will be skipped.
            request.connection.write_headers(
                ResponseStartLine("", 200, "OK"), HTTPHeaders()
            )
            request.connection.finish()
            return
        self.request = request
        # Send the interim response first; the real one follows on flush.
        fut = self.request.connection.stream.write(b"HTTP/1.1 100 CONTINUE\r\n\r\n")
        fut.add_done_callback(self.respond_200)

    def respond_200(self, fut):
        fut.result()
        fut = self.request.connection.stream.write(
            b"HTTP/1.1 200 OK\r\nContent-Length: 1\r\n\r\nA"
        )
        fut.add_done_callback(lambda f: self.request.connection.stream.close())

    def get_app(self):
        # Not a full Application, but works as an HTTPServer callback
        return self.respond_100

    def test_100_continue(self):
        res = self.fetch("/")
        if not self.http1:
            self.skipTest("requires HTTP/1.x")
        self.assertEqual(res.body, b"A")
class HTTP204NoContentTestCase(AsyncHTTPTestCase):
    """Tests 204 responses that carry an explicit Content-Length header."""

    def respond_204(self, request):
        self.http1 = request.version.startswith("HTTP/1.")
        if not self.http1:
            # Close the request cleanly in HTTP/2; it will be skipped anyway.
            request.connection.write_headers(
                ResponseStartLine("", 200, "OK"), HTTPHeaders()
            )
            request.connection.finish()
            return
        # A 204 response never has a body, even if doesn't have a content-length
        # (which would otherwise mean read-until-close). We simulate here a
        # server that sends no content length and does not close the connection.
        #
        # Tests of a 204 response with no Content-Length header are included
        # in SimpleHTTPClientTestMixin.
        stream = request.connection.detach()
        stream.write(b"HTTP/1.1 204 No content\r\n")
        if request.arguments.get("error", [False])[-1]:
            # ?error=1 requests an (invalid) non-zero Content-Length.
            stream.write(b"Content-Length: 5\r\n")
        else:
            stream.write(b"Content-Length: 0\r\n")
        stream.write(b"\r\n")
        stream.close()

    def get_app(self):
        return self.respond_204

    def test_204_no_content(self):
        resp = self.fetch("/")
        if not self.http1:
            self.skipTest("requires HTTP/1.x")
        self.assertEqual(resp.code, 204)
        self.assertEqual(resp.body, b"")

    def test_204_invalid_content_length(self):
        # 204 status with non-zero content length is malformed
        with ExpectLog(
            gen_log, ".*Response with code 204 should not have body", level=logging.INFO
        ):
            with self.assertRaises(HTTPStreamClosedError):
                self.fetch("/?error=1", raise_error=True)
                if not self.http1:
                    self.skipTest("requires HTTP/1.x")
                if self.http_client.configured_class != SimpleAsyncHTTPClient:
                    self.skipTest("curl client accepts invalid headers")
class HostnameMappingTestCase(AsyncHTTPTestCase):
    """Tests the hostname_mapping option (a local /etc/hosts equivalent)."""

    def setUp(self):
        super().setUp()
        self.http_client = SimpleAsyncHTTPClient(
            hostname_mapping={
                # Plain host mapping: keeps the requested port.
                "www.example.com": "127.0.0.1",
                # (host, port) mapping: also remaps the port.
                ("foo.example.com", 8000): ("127.0.0.1", self.get_http_port()),
            }
        )

    def get_app(self):
        return Application([url("/hello", HelloWorldHandler)])

    def test_hostname_mapping(self):
        response = self.fetch("http://www.example.com:%d/hello" % self.get_http_port())
        response.rethrow()
        self.assertEqual(response.body, b"Hello world!")

    def test_port_mapping(self):
        response = self.fetch("http://foo.example.com:8000/hello")
        response.rethrow()
        self.assertEqual(response.body, b"Hello world!")
class ResolveTimeoutTestCase(AsyncHTTPTestCase):
    """Verifies connect_timeout also covers time spent in DNS resolution."""

    def setUp(self):
        self.cleanup_event = Event()
        test = self

        # Dummy Resolver subclass that never finishes.
        class BadResolver(Resolver):
            @gen.coroutine
            def resolve(self, *args, **kwargs):
                yield test.cleanup_event.wait()
                # Return something valid so the test doesn't raise during cleanup.
                return [(socket.AF_INET, ("127.0.0.1", test.get_http_port()))]

        super().setUp()
        self.http_client = SimpleAsyncHTTPClient(resolver=BadResolver())

    def get_app(self):
        return Application([url("/hello", HelloWorldHandler)])

    def test_resolve_timeout(self):
        with self.assertRaises(HTTPTimeoutError):
            self.fetch("/hello", connect_timeout=0.1, raise_error=True)

        # Let the hanging coroutine clean up after itself
        self.cleanup_event.set()
        self.io_loop.run_sync(lambda: gen.sleep(0))
class MaxHeaderSizeTest(AsyncHTTPTestCase):
    """Tests the client-side max_header_size limit (1 KiB here)."""

    def get_app(self):
        class SmallHeaders(RequestHandler):
            def get(self):
                # 100-byte filler header: under the 1024-byte limit.
                self.set_header("X-Filler", "a" * 100)
                self.write("ok")

        class LargeHeaders(RequestHandler):
            def get(self):
                # 1000-byte filler header: pushes total headers past the limit.
                self.set_header("X-Filler", "a" * 1000)
                self.write("ok")

        return Application([("/small", SmallHeaders), ("/large", LargeHeaders)])

    def get_http_client(self):
        return SimpleAsyncHTTPClient(max_header_size=1024)

    def test_small_headers(self):
        response = self.fetch("/small")
        response.rethrow()
        self.assertEqual(response.body, b"ok")

    def test_large_headers(self):
        with ExpectLog(gen_log, "Unsatisfiable read", level=logging.INFO):
            with self.assertRaises(UnsatisfiableReadError):
                self.fetch("/large", raise_error=True)
class MaxBodySizeTest(AsyncHTTPTestCase):
    """Tests the client-side max_body_size limit (64 KiB here)."""

    def get_app(self):
        class SmallBody(RequestHandler):
            def get(self):
                # Exactly at the 64 KiB limit.
                self.write("a" * 1024 * 64)

        class LargeBody(RequestHandler):
            def get(self):
                # 100 KiB: exceeds the limit.
                self.write("a" * 1024 * 100)

        return Application([("/small", SmallBody), ("/large", LargeBody)])

    def get_http_client(self):
        return SimpleAsyncHTTPClient(max_body_size=1024 * 64)

    def test_small_body(self):
        response = self.fetch("/small")
        response.rethrow()
        self.assertEqual(response.body, b"a" * 1024 * 64)

    def test_large_body(self):
        with ExpectLog(
            gen_log,
            "Malformed HTTP message from None: Content-Length too long",
            level=logging.INFO,
        ):
            with self.assertRaises(HTTPStreamClosedError):
                self.fetch("/large", raise_error=True)
class MaxBufferSizeTest(AsyncHTTPTestCase):
    """Verifies a body larger than max_buffer_size can still be streamed
    as long as it fits within max_body_size."""

    def get_app(self):
        class LargeBody(RequestHandler):
            def get(self):
                self.write("a" * 1024 * 100)

        return Application([("/large", LargeBody)])

    def get_http_client(self):
        # 100KB body with 64KB buffer
        return SimpleAsyncHTTPClient(
            max_body_size=1024 * 100, max_buffer_size=1024 * 64
        )

    def test_large_body(self):
        response = self.fetch("/large")
        response.rethrow()
        self.assertEqual(response.body, b"a" * 1024 * 100)
class ChunkedWithContentLengthTest(AsyncHTTPTestCase):
    """Verifies the client rejects responses carrying both
    Transfer-Encoding: chunked and Content-Length (request-smuggling risk)."""

    def get_app(self):
        class ChunkedWithContentLength(RequestHandler):
            def get(self):
                # Add an invalid Transfer-Encoding to the response
                self.set_header("Transfer-Encoding", "chunked")
                self.write("Hello world")

        return Application([("/chunkwithcl", ChunkedWithContentLength)])

    def get_http_client(self):
        return SimpleAsyncHTTPClient()

    def test_chunked_with_content_length(self):
        # Make sure the invalid headers are detected
        with ExpectLog(
            gen_log,
            (
                "Malformed HTTP message from None: Response "
                "with both Transfer-Encoding and Content-Length"
            ),
            level=logging.INFO,
        ):
            with self.assertRaises(HTTPStreamClosedError):
                self.fetch("/chunkwithcl", raise_error=True)

View file

@ -0,0 +1 @@
this is the index

View file

@ -0,0 +1,2 @@
User-agent: *
Disallow: /

View file

@ -0,0 +1,23 @@
<?xml version="1.0"?>
<data>
<country name="Liechtenstein">
<rank>1</rank>
<year>2008</year>
<gdppc>141100</gdppc>
<neighbor name="Austria" direction="E"/>
<neighbor name="Switzerland" direction="W"/>
</country>
<country name="Singapore">
<rank>4</rank>
<year>2011</year>
<gdppc>59900</gdppc>
<neighbor name="Malaysia" direction="N"/>
</country>
<country name="Panama">
<rank>68</rank>
<year>2011</year>
<gdppc>13600</gdppc>
<neighbor name="Costa Rica" direction="W"/>
<neighbor name="Colombia" direction="E"/>
</country>
</data>

View file

@ -0,0 +1,2 @@
This file should not be served by StaticFileHandler even though
its name starts with "static".

View file

@ -0,0 +1,438 @@
#
# Copyright 2014 Facebook
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
from contextlib import closing
import getpass
import os
import socket
import unittest
from tornado.concurrent import Future
from tornado.netutil import bind_sockets, Resolver
from tornado.queues import Queue
from tornado.tcpclient import TCPClient, _Connector
from tornado.tcpserver import TCPServer
from tornado.testing import AsyncTestCase, gen_test
from tornado.test.util import skipIfNoIPv6, refusing_port, skipIfNonUnix
from tornado.gen import TimeoutError
import typing
if typing.TYPE_CHECKING:
from tornado.iostream import IOStream # noqa: F401
from typing import List, Dict, Tuple # noqa: F401
# Fake address families for testing. Used in place of AF_INET
# and AF_INET6 because some installations do not have AF_INET6.
# The values themselves are arbitrary; _Connector only groups by them.
AF1, AF2 = 1, 2
class TestTCPServer(TCPServer):
    """TCP server bound to an OS-assigned localhost port for client tests.

    Every accepted stream is recorded (so stop() can close it) and also
    published on ``queue`` for tests to await.
    """

    def __init__(self, family):
        super().__init__()
        self.streams = []  # type: List[IOStream]
        self.queue = Queue()  # type: Queue[IOStream]
        listeners = bind_sockets(0, "localhost", family)
        self.add_sockets(listeners)
        # Remember the kernel-assigned port for clients to connect to.
        self.port = listeners[0].getsockname()[1]

    def handle_stream(self, stream, address):
        self.streams.append(stream)
        self.queue.put(stream)

    def stop(self):
        super().stop()
        for accepted in self.streams:
            accepted.close()
class TCPClientTest(AsyncTestCase):
    """End-to-end tests for TCPClient.connect over real localhost sockets."""

    def setUp(self):
        super().setUp()
        self.server = None
        self.client = TCPClient()

    def start_server(self, family):
        if family == socket.AF_UNSPEC and "TRAVIS" in os.environ:
            self.skipTest("dual-stack servers often have port conflicts on travis")
        self.server = TestTCPServer(family)
        return self.server.port

    def stop_server(self):
        if self.server is not None:
            self.server.stop()
            self.server = None

    def tearDown(self):
        self.client.close()
        self.stop_server()
        super().tearDown()

    def skipIfLocalhostV4(self):
        # The port used here doesn't matter, but some systems require it
        # to be non-zero if we do not also pass AI_PASSIVE.
        addrinfo = self.io_loop.run_sync(lambda: Resolver().resolve("localhost", 80))
        families = set(addr[0] for addr in addrinfo)
        if socket.AF_INET6 not in families:
            self.skipTest("localhost does not resolve to ipv6")

    @gen_test
    def do_test_connect(self, family, host, source_ip=None, source_port=None):
        """Connect to a fresh server and verify bytes round-trip."""
        port = self.start_server(family)
        stream = yield self.client.connect(
            host, port, source_ip=source_ip, source_port=source_port
        )
        assert self.server is not None
        server_stream = yield self.server.queue.get()
        with closing(stream):
            stream.write(b"hello")
            data = yield server_stream.read_bytes(5)
            self.assertEqual(data, b"hello")

    def test_connect_ipv4_ipv4(self):
        self.do_test_connect(socket.AF_INET, "127.0.0.1")

    def test_connect_ipv4_dual(self):
        self.do_test_connect(socket.AF_INET, "localhost")

    @skipIfNoIPv6
    def test_connect_ipv6_ipv6(self):
        self.skipIfLocalhostV4()
        self.do_test_connect(socket.AF_INET6, "::1")

    @skipIfNoIPv6
    def test_connect_ipv6_dual(self):
        self.skipIfLocalhostV4()
        if Resolver.configured_class().__name__.endswith("TwistedResolver"):
            self.skipTest("TwistedResolver does not support multiple addresses")
        self.do_test_connect(socket.AF_INET6, "localhost")

    def test_connect_unspec_ipv4(self):
        self.do_test_connect(socket.AF_UNSPEC, "127.0.0.1")

    @skipIfNoIPv6
    def test_connect_unspec_ipv6(self):
        self.skipIfLocalhostV4()
        self.do_test_connect(socket.AF_UNSPEC, "::1")

    def test_connect_unspec_dual(self):
        self.do_test_connect(socket.AF_UNSPEC, "localhost")

    @gen_test
    def test_refused_ipv4(self):
        cleanup_func, port = refusing_port()
        self.addCleanup(cleanup_func)
        with self.assertRaises(IOError):
            yield self.client.connect("127.0.0.1", port)

    def test_source_ip_fail(self):
        """Fail when trying to use the source IP Address '8.8.8.8'.
        """
        self.assertRaises(
            socket.error,
            self.do_test_connect,
            socket.AF_INET,
            "127.0.0.1",
            source_ip="8.8.8.8",
        )

    def test_source_ip_success(self):
        """Success when trying to use the source IP Address '127.0.0.1'.
        """
        self.do_test_connect(socket.AF_INET, "127.0.0.1", source_ip="127.0.0.1")

    @skipIfNonUnix
    def test_source_port_fail(self):
        """Fail when trying to use source port 1.
        """
        if getpass.getuser() == "root":
            # Root can use any port so we can't easily force this to fail.
            # This is mainly relevant for docker.
            self.skipTest("running as root")
        self.assertRaises(
            socket.error,
            self.do_test_connect,
            socket.AF_INET,
            "127.0.0.1",
            source_port=1,
        )

    @gen_test
    def test_connect_timeout(self):
        timeout = 0.05

        class TimeoutResolver(Resolver):
            def resolve(self, *args, **kwargs):
                return Future()  # never completes

        with self.assertRaises(TimeoutError):
            yield TCPClient(resolver=TimeoutResolver()).connect(
                "1.2.3.4", 12345, timeout=timeout
            )
class TestConnectorSplit(unittest.TestCase):
    """Unit tests for _Connector.split's address-family partitioning."""

    def test_one_family(self):
        # These addresses aren't in the right format, but split doesn't care.
        addrinfo = [(AF1, "a"), (AF1, "b")]
        primary, secondary = _Connector.split(addrinfo)
        self.assertEqual(primary, [(AF1, "a"), (AF1, "b")])
        self.assertEqual(secondary, [])

    def test_mixed(self):
        # Interleaved families: split keeps order within each family.
        addrinfo = [(AF1, "a"), (AF2, "b"), (AF1, "c"), (AF2, "d")]
        primary, secondary = _Connector.split(addrinfo)
        self.assertEqual(primary, [(AF1, "a"), (AF1, "c")])
        self.assertEqual(secondary, [(AF2, "b"), (AF2, "d")])
class ConnectorTest(AsyncTestCase):
    """Tests for tcpclient._Connector's "happy eyeballs" connection logic.

    The connector is driven through a fake ``create_stream`` callback so
    each test can manually succeed or fail individual (family, address)
    connection attempts and assert which attempts are started next.
    """
    class FakeStream(object):
        # Minimal IOStream stand-in: only records whether close() was called.
        def __init__(self):
            self.closed = False
        def close(self):
            self.closed = True
    def setUp(self):
        super().setUp()
        # Maps (family, address) -> Future for each attempt still in flight.
        self.connect_futures = (
            {}
        )  # type: Dict[Tuple[int, typing.Any], Future[ConnectorTest.FakeStream]]
        # Maps address -> FakeStream for every stream ever created.
        self.streams = {}  # type: Dict[typing.Any, ConnectorTest.FakeStream]
        # Two addresses in each of two families; AF1 is the primary family.
        self.addrinfo = [(AF1, "a"), (AF1, "b"), (AF2, "c"), (AF2, "d")]
    def tearDown(self):
        # Unless explicitly checked (and popped) in the test, we shouldn't
        # be closing any streams
        for stream in self.streams.values():
            self.assertFalse(stream.closed)
        super().tearDown()
    def create_stream(self, af, addr):
        """Fake stream factory: records the attempt and returns (stream, future)."""
        stream = ConnectorTest.FakeStream()
        self.streams[addr] = stream
        future = Future()  # type: Future[ConnectorTest.FakeStream]
        self.connect_futures[(af, addr)] = future
        return stream, future
    def assert_pending(self, *keys):
        """Assert exactly the given (family, address) attempts are in flight."""
        self.assertEqual(sorted(self.connect_futures.keys()), sorted(keys))
    def resolve_connect(self, af, addr, success):
        """Complete the pending attempt for (af, addr) with a result or IOError."""
        future = self.connect_futures.pop((af, addr))
        if success:
            future.set_result(self.streams[addr])
        else:
            self.streams.pop(addr)
            future.set_exception(IOError())
        # Run the loop to allow callbacks to be run.
        self.io_loop.add_callback(self.stop)
        self.wait()
    def assert_connector_streams_closed(self, conn):
        """Assert every stream the connector still tracks has been closed."""
        for stream in conn.streams:
            self.assertTrue(stream.closed)
    def start_connect(self, addrinfo):
        """Start a _Connector over addrinfo; return (connector, result future)."""
        conn = _Connector(addrinfo, self.create_stream)
        # Give it a huge timeout; we'll trigger timeouts manually.
        future = conn.start(3600, connect_timeout=self.io_loop.time() + 3600)
        return conn, future
    def test_immediate_success(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assertEqual(list(self.connect_futures.keys()), [(AF1, "a")])
        self.resolve_connect(AF1, "a", True)
        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
    def test_immediate_failure(self):
        # Fail with just one address.
        conn, future = self.start_connect([(AF1, "a")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assertRaises(IOError, future.result)
    def test_one_family_second_try(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        self.resolve_connect(AF1, "b", True)
        self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
    def test_one_family_second_try_failure(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        self.resolve_connect(AF1, "b", False)
        self.assertRaises(IOError, future.result)
    def test_one_family_second_try_timeout(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        # trigger the timeout while the first lookup is pending;
        # nothing happens.
        conn.on_timeout()
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        self.resolve_connect(AF1, "b", True)
        self.assertEqual(future.result(), (AF1, "b", self.streams["b"]))
    def test_two_families_immediate_failure(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        # Failure of the first primary attempt starts the secondary family.
        self.assert_pending((AF1, "b"), (AF2, "c"))
        self.resolve_connect(AF1, "b", False)
        self.resolve_connect(AF2, "c", True)
        self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
    def test_two_families_timeout(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        # The happy-eyeballs timeout launches the secondary family in parallel.
        conn.on_timeout()
        self.assert_pending((AF1, "a"), (AF2, "c"))
        self.resolve_connect(AF2, "c", True)
        self.assertEqual(future.result(), (AF2, "c", self.streams["c"]))
        # resolving 'a' after the connection has completed doesn't start 'b'
        self.resolve_connect(AF1, "a", False)
        self.assert_pending()
    def test_success_after_timeout(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        conn.on_timeout()
        self.assert_pending((AF1, "a"), (AF2, "c"))
        self.resolve_connect(AF1, "a", True)
        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
        # resolving 'c' after completion closes the connection.
        self.resolve_connect(AF2, "c", True)
        self.assertTrue(self.streams.pop("c").closed)
    def test_all_fail(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        conn.on_timeout()
        self.assert_pending((AF1, "a"), (AF2, "c"))
        self.resolve_connect(AF2, "c", False)
        self.assert_pending((AF1, "a"), (AF2, "d"))
        self.resolve_connect(AF2, "d", False)
        # one queue is now empty
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        self.assertFalse(future.done())
        self.resolve_connect(AF1, "b", False)
        self.assertRaises(IOError, future.result)
    def test_one_family_timeout_after_connect_timeout(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        conn.on_connect_timeout()
        # the connector will close all streams on connect timeout, we
        # should explicitly pop the connect_future.
        self.connect_futures.pop((AF1, "a"))
        self.assertTrue(self.streams.pop("a").closed)
        conn.on_timeout()
        # if the future is set with TimeoutError, we will not iterate next
        # possible address.
        self.assert_pending()
        self.assertEqual(len(conn.streams), 1)
        self.assert_connector_streams_closed(conn)
        self.assertRaises(TimeoutError, future.result)
    def test_one_family_success_before_connect_timeout(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", True)
        conn.on_connect_timeout()
        self.assert_pending()
        # The winning stream must survive the connect timeout.
        self.assertEqual(self.streams["a"].closed, False)
        # success stream will be pop
        self.assertEqual(len(conn.streams), 0)
        # streams in connector should be closed after connect timeout
        self.assert_connector_streams_closed(conn)
        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
    def test_one_family_second_try_after_connect_timeout(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        conn.on_connect_timeout()
        self.connect_futures.pop((AF1, "b"))
        self.assertTrue(self.streams.pop("b").closed)
        self.assert_pending()
        self.assertEqual(len(conn.streams), 2)
        self.assert_connector_streams_closed(conn)
        self.assertRaises(TimeoutError, future.result)
    def test_one_family_second_try_failure_before_connect_timeout(self):
        conn, future = self.start_connect([(AF1, "a"), (AF1, "b")])
        self.assert_pending((AF1, "a"))
        self.resolve_connect(AF1, "a", False)
        self.assert_pending((AF1, "b"))
        self.resolve_connect(AF1, "b", False)
        conn.on_connect_timeout()
        self.assert_pending()
        self.assertEqual(len(conn.streams), 2)
        self.assert_connector_streams_closed(conn)
        # All attempts failed before the timeout, so the original IOError
        # (not TimeoutError) is reported.
        self.assertRaises(IOError, future.result)
    def test_two_family_timeout_before_connect_timeout(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        conn.on_timeout()
        self.assert_pending((AF1, "a"), (AF2, "c"))
        conn.on_connect_timeout()
        self.connect_futures.pop((AF1, "a"))
        self.assertTrue(self.streams.pop("a").closed)
        self.connect_futures.pop((AF2, "c"))
        self.assertTrue(self.streams.pop("c").closed)
        self.assert_pending()
        self.assertEqual(len(conn.streams), 2)
        self.assert_connector_streams_closed(conn)
        self.assertRaises(TimeoutError, future.result)
    def test_two_family_success_after_timeout(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        conn.on_timeout()
        self.assert_pending((AF1, "a"), (AF2, "c"))
        self.resolve_connect(AF1, "a", True)
        # if one of streams succeed, connector will close all other streams
        self.connect_futures.pop((AF2, "c"))
        self.assertTrue(self.streams.pop("c").closed)
        self.assert_pending()
        self.assertEqual(len(conn.streams), 1)
        self.assert_connector_streams_closed(conn)
        self.assertEqual(future.result(), (AF1, "a", self.streams["a"]))
    def test_two_family_timeout_after_connect_timeout(self):
        conn, future = self.start_connect(self.addrinfo)
        self.assert_pending((AF1, "a"))
        conn.on_connect_timeout()
        self.connect_futures.pop((AF1, "a"))
        self.assertTrue(self.streams.pop("a").closed)
        self.assert_pending()
        conn.on_timeout()
        # if the future is set with TimeoutError, connector will not
        # trigger secondary address.
        self.assert_pending()
        self.assertEqual(len(conn.streams), 1)
        self.assert_connector_streams_closed(conn)
        self.assertRaises(TimeoutError, future.result)

View file

@ -0,0 +1,192 @@
import socket
import subprocess
import sys
import textwrap
import unittest
from tornado.escape import utf8, to_unicode
from tornado import gen
from tornado.iostream import IOStream
from tornado.log import app_log
from tornado.tcpserver import TCPServer
from tornado.test.util import skipIfNonUnix
from tornado.testing import AsyncTestCase, ExpectLog, bind_unused_port, gen_test
class TCPServerTest(AsyncTestCase):
    """Tests for TCPServer's handle_stream dispatch and stop() behavior."""
    @gen_test
    def test_handle_stream_coroutine_logging(self):
        # handle_stream may be a coroutine and any exception in its
        # Future will be logged.
        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                yield stream.read_bytes(len(b"hello"))
                stream.close()
                1 / 0  # deliberate error: should be logged, not silently lost
        server = client = None
        try:
            sock, port = bind_unused_port()
            server = TestServer()
            server.add_socket(sock)
            client = IOStream(socket.socket())
            # ExpectLog both asserts the error is logged and suppresses it.
            with ExpectLog(app_log, "Exception in callback"):
                yield client.connect(("localhost", port))
                yield client.write(b"hello")
                yield client.read_until_close()
                yield gen.moment
        finally:
            if server is not None:
                server.stop()
            if client is not None:
                client.close()
    @gen_test
    def test_handle_stream_native_coroutine(self):
        # handle_stream may be a native coroutine.
        class TestServer(TCPServer):
            async def handle_stream(self, stream, address):
                stream.write(b"data")
                stream.close()
        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        client = IOStream(socket.socket())
        yield client.connect(("localhost", port))
        result = yield client.read_until_close()
        self.assertEqual(result, b"data")
        server.stop()
        client.close()
    def test_stop_twice(self):
        # stop() must be idempotent.
        sock, port = bind_unused_port()
        server = TCPServer()
        server.add_socket(sock)
        server.stop()
        server.stop()
    @gen_test
    def test_stop_in_callback(self):
        # Issue #2069: calling server.stop() in a loop callback should not
        # raise EBADF when the loop handles other server connection
        # requests in the same loop iteration
        class TestServer(TCPServer):
            @gen.coroutine
            def handle_stream(self, stream, address):
                server.stop()  # type: ignore
                yield stream.read_until_close()
        sock, port = bind_unused_port()
        server = TestServer()
        server.add_socket(sock)
        server_addr = ("localhost", port)
        # Connect many clients at once so some arrive in the same loop
        # iteration as the stop() call.
        N = 40
        clients = [IOStream(socket.socket()) for i in range(N)]
        connected_clients = []
        @gen.coroutine
        def connect(c):
            try:
                yield c.connect(server_addr)
            except EnvironmentError:
                pass
            else:
                connected_clients.append(c)
        yield [connect(c) for c in clients]
        self.assertGreater(len(connected_clients), 0, "all clients failed connecting")
        try:
            if len(connected_clients) == N:
                # Ideally we'd make the test deterministic, but we're testing
                # for a race condition in combination with the system's TCP stack...
                self.skipTest(
                    "at least one client should fail connecting "
                    "for the test to be meaningful"
                )
        finally:
            for c in connected_clients:
                c.close()
        # Here tearDown() would re-raise the EBADF encountered in the IO loop
@skipIfNonUnix
class TestMultiprocess(unittest.TestCase):
    # These tests verify that the two multiprocess examples from the
    # TCPServer docs work. Both tests start a server with three worker
    # processes, each of which prints its task id to stdout (a single
    # byte, so we don't have to worry about atomicity of the shared
    # stdout stream) and then exits.
    def run_subproc(self, code):
        """Run *code* in a child Python interpreter and return its stdout.

        Uses Popen.communicate() rather than wait()+read(): calling wait()
        while stdout=PIPE is still open can deadlock if the child fills the
        pipe buffer (the pattern the subprocess docs explicitly warn about).

        Raises RuntimeError if the child exits with a nonzero status.
        """
        proc = subprocess.Popen(
            [sys.executable], stdin=subprocess.PIPE, stdout=subprocess.PIPE
        )
        # communicate() writes code to stdin, closes it, and drains stdout
        # concurrently, then reaps the process.
        stdout, _ = proc.communicate(utf8(code))
        if proc.returncode != 0:
            raise RuntimeError(
                "Process returned %d. stdout=%r" % (proc.returncode, stdout)
            )
        return to_unicode(stdout)
    def test_single(self):
        # As a sanity check, run the single-process version through this test
        # harness too.
        code = textwrap.dedent(
            """
            from tornado.ioloop import IOLoop
            from tornado.tcpserver import TCPServer
            server = TCPServer()
            server.listen(0, address='127.0.0.1')
            IOLoop.current().run_sync(lambda: None)
            print('012', end='')
        """
        )
        out = self.run_subproc(code)
        self.assertEqual("".join(sorted(out)), "012")
    def test_simple(self):
        # Simple multiprocess pattern: bind, then start(3) forks workers.
        code = textwrap.dedent(
            """
            from tornado.ioloop import IOLoop
            from tornado.process import task_id
            from tornado.tcpserver import TCPServer
            server = TCPServer()
            server.bind(0, address='127.0.0.1')
            server.start(3)
            IOLoop.current().run_sync(lambda: None)
            print(task_id(), end='')
        """
        )
        out = self.run_subproc(code)
        self.assertEqual("".join(sorted(out)), "012")
    def test_advanced(self):
        # Advanced pattern: bind sockets first, fork, then add_sockets.
        code = textwrap.dedent(
            """
            from tornado.ioloop import IOLoop
            from tornado.netutil import bind_sockets
            from tornado.process import fork_processes, task_id
            from tornado.ioloop import IOLoop
            from tornado.tcpserver import TCPServer
            sockets = bind_sockets(0, address='127.0.0.1')
            fork_processes(3)
            server = TCPServer()
            server.add_sockets(sockets)
            IOLoop.current().run_sync(lambda: None)
            print(task_id(), end='')
        """
        )
        out = self.run_subproc(code)
        self.assertEqual("".join(sorted(out)), "012")

View file

@ -0,0 +1,536 @@
import os
import traceback
import unittest
from tornado.escape import utf8, native_str, to_unicode
from tornado.template import Template, DictLoader, ParseError, Loader
from tornado.util import ObjectDict
import typing # noqa: F401
class TemplateTest(unittest.TestCase):
    """Unit tests for the core tornado.template language features.

    Templates always render to bytes, so expected values are byte strings.
    """
    def test_simple(self):
        template = Template("Hello {{ name }}!")
        self.assertEqual(template.generate(name="Ben"), b"Hello Ben!")
    def test_bytes(self):
        template = Template("Hello {{ name }}!")
        self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
    def test_expressions(self):
        template = Template("2 + 2 = {{ 2 + 2 }}")
        self.assertEqual(template.generate(), b"2 + 2 = 4")
    def test_comment(self):
        # {# ... #} comments produce no output.
        template = Template("Hello{# TODO i18n #} {{ name }}!")
        self.assertEqual(template.generate(name=utf8("Ben")), b"Hello Ben!")
    def test_include(self):
        loader = DictLoader(
            {
                "index.html": '{% include "header.html" %}\nbody text',
                "header.html": "header text",
            }
        )
        self.assertEqual(
            loader.load("index.html").generate(), b"header text\nbody text"
        )
    def test_extends(self):
        loader = DictLoader(
            {
                "base.html": """\
<title>{% block title %}default title{% end %}</title>
<body>{% block body %}default body{% end %}</body>
""",
                "page.html": """\
{% extends "base.html" %}
{% block title %}page title{% end %}
{% block body %}page body{% end %}
""",
            }
        )
        self.assertEqual(
            loader.load("page.html").generate(),
            b"<title>page title</title>\n<body>page body</body>\n",
        )
    def test_relative_load(self):
        # include paths are resolved relative to the including template.
        loader = DictLoader(
            {
                "a/1.html": "{% include '2.html' %}",
                "a/2.html": "{% include '../b/3.html' %}",
                "b/3.html": "ok",
            }
        )
        self.assertEqual(loader.load("a/1.html").generate(), b"ok")
    def test_escaping(self):
        # An unterminated tag is an error; '!' escapes the tag opener.
        self.assertRaises(ParseError, lambda: Template("{{"))
        self.assertRaises(ParseError, lambda: Template("{%"))
        self.assertEqual(Template("{{!").generate(), b"{{")
        self.assertEqual(Template("{%!").generate(), b"{%")
        self.assertEqual(Template("{#!").generate(), b"{#")
        self.assertEqual(
            Template("{{ 'expr' }} {{!jquery expr}}").generate(),
            b"expr {{jquery expr}}",
        )
    def test_unicode_template(self):
        template = Template(utf8(u"\u00e9"))
        self.assertEqual(template.generate(), utf8(u"\u00e9"))
    def test_unicode_literal_expression(self):
        # Unicode literals should be usable in templates. Note that this
        # test simulates unicode characters appearing directly in the
        # template file (with utf8 encoding), i.e. \u escapes would not
        # be used in the template file itself.
        template = Template(utf8(u'{{ "\u00e9" }}'))
        self.assertEqual(template.generate(), utf8(u"\u00e9"))
    def test_custom_namespace(self):
        # Loader-level namespace entries are visible in all templates.
        loader = DictLoader(
            {"test.html": "{{ inc(5) }}"}, namespace={"inc": lambda x: x + 1}
        )
        self.assertEqual(loader.load("test.html").generate(), b"6")
    def test_apply(self):
        def upper(s):
            return s.upper()
        template = Template(utf8("{% apply upper %}foo{% end %}"))
        self.assertEqual(template.generate(upper=upper), b"FOO")
    def test_unicode_apply(self):
        def upper(s):
            return to_unicode(s).upper()
        template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
        self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
    def test_bytes_apply(self):
        def upper(s):
            return utf8(to_unicode(s).upper())
        template = Template(utf8(u"{% apply upper %}foo \u00e9{% end %}"))
        self.assertEqual(template.generate(upper=upper), utf8(u"FOO \u00c9"))
    def test_if(self):
        template = Template(utf8("{% if x > 4 %}yes{% else %}no{% end %}"))
        self.assertEqual(template.generate(x=5), b"yes")
        self.assertEqual(template.generate(x=3), b"no")
    def test_if_empty_body(self):
        template = Template(utf8("{% if True %}{% else %}{% end %}"))
        self.assertEqual(template.generate(), b"")
    def test_try(self):
        # try/except/else/finally mirror the Python statement.
        template = Template(
            utf8(
                """{% try %}
try{% set y = 1/x %}
{% except %}-except
{% else %}-else
{% finally %}-finally
{% end %}"""
            )
        )
        self.assertEqual(template.generate(x=1), b"\ntry\n-else\n-finally\n")
        self.assertEqual(template.generate(x=0), b"\ntry-except\n-finally\n")
    def test_comment_directive(self):
        template = Template(utf8("{% comment blah blah %}foo"))
        self.assertEqual(template.generate(), b"foo")
    def test_break_continue(self):
        template = Template(
            utf8(
                """\
{% for i in range(10) %}
{% if i == 2 %}
{% continue %}
{% end %}
{{ i }}
{% if i == 6 %}
{% break %}
{% end %}
{% end %}"""
            )
        )
        result = template.generate()
        # remove extraneous whitespace
        result = b"".join(result.split())
        self.assertEqual(result, b"013456")
    def test_break_outside_loop(self):
        try:
            Template(utf8("{% break %}"))
            raise Exception("Did not get expected exception")
        except ParseError:
            pass
    def test_break_in_apply(self):
        # This test verifies current behavior, although of course it would
        # be nice if apply didn't cause seemingly unrelated breakage
        try:
            Template(
                utf8("{% for i in [] %}{% apply foo %}{% break %}{% end %}{% end %}")
            )
            raise Exception("Did not get expected exception")
        except ParseError:
            pass
    @unittest.skip("no testable future imports")
    def test_no_inherit_future(self):
        # TODO(bdarnell): make a test like this for one of the future
        # imports available in python 3. Unfortunately they're harder
        # to use in a template than division was.
        # This file has from __future__ import division...
        self.assertEqual(1 / 2, 0.5)
        # ...but the template doesn't
        template = Template("{{ 1 / 2 }}")
        self.assertEqual(template.generate(), "0")
    def test_non_ascii_name(self):
        # Template file names may contain non-ASCII characters.
        loader = DictLoader({u"t\u00e9st.html": "hello"})
        self.assertEqual(loader.load(u"t\u00e9st.html").generate(), b"hello")
class StackTraceTest(unittest.TestCase):
    """Verify that exceptions raised while rendering carry template
    filenames and line numbers (as '# file:line' markers) in the traceback."""
    def test_error_line_number_expression(self):
        loader = DictLoader(
            {
                "test.html": """one
two{{1/0}}
three
"""
            }
        )
        try:
            loader.load("test.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            self.assertTrue("# test.html:2" in traceback.format_exc())
    def test_error_line_number_directive(self):
        loader = DictLoader(
            {
                "test.html": """one
two{%if 1/0%}
three{%end%}
"""
            }
        )
        try:
            loader.load("test.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            self.assertTrue("# test.html:2" in traceback.format_exc())
    def test_error_line_number_module(self):
        loader = None  # type: typing.Optional[DictLoader]
        # Stand-in for the UIModule machinery: renders a sub-template by name.
        def load_generate(path, **kwargs):
            assert loader is not None
            return loader.load(path).generate(**kwargs)
        loader = DictLoader(
            {"base.html": "{% module Template('sub.html') %}", "sub.html": "{{1/0}}"},
            namespace={"_tt_modules": ObjectDict(Template=load_generate)},
        )
        try:
            loader.load("base.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            # Both the outer and inner templates should appear in the trace.
            exc_stack = traceback.format_exc()
            self.assertTrue("# base.html:1" in exc_stack)
            self.assertTrue("# sub.html:1" in exc_stack)
    def test_error_line_number_include(self):
        loader = DictLoader(
            {"base.html": "{% include 'sub.html' %}", "sub.html": "{{1/0}}"}
        )
        try:
            loader.load("base.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            self.assertTrue("# sub.html:1 (via base.html:1)" in traceback.format_exc())
    def test_error_line_number_extends_base_error(self):
        loader = DictLoader(
            {"base.html": "{{1/0}}", "sub.html": "{% extends 'base.html' %}"}
        )
        try:
            loader.load("sub.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            exc_stack = traceback.format_exc()
            self.assertTrue("# base.html:1" in exc_stack)
    def test_error_line_number_extends_sub_error(self):
        loader = DictLoader(
            {
                "base.html": "{% block 'block' %}{% end %}",
                "sub.html": """
{% extends 'base.html' %}
{% block 'block' %}
{{1/0}}
{% end %}
""",
            }
        )
        try:
            loader.load("sub.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            self.assertTrue("# sub.html:4 (via base.html:1)" in traceback.format_exc())
    def test_multi_includes(self):
        # The full include chain is reported, innermost first.
        loader = DictLoader(
            {
                "a.html": "{% include 'b.html' %}",
                "b.html": "{% include 'c.html' %}",
                "c.html": "{{1/0}}",
            }
        )
        try:
            loader.load("a.html").generate()
            self.fail("did not get expected exception")
        except ZeroDivisionError:
            self.assertTrue(
                "# c.html:1 (via b.html:1, a.html:1)" in traceback.format_exc()
            )
class ParseErrorDetailTest(unittest.TestCase):
    """ParseError should report its location and stay compatible with the
    pre-4.3 one-argument constructor."""
    def test_details(self):
        loader = DictLoader({"foo.html": "\n\n{{"})
        with self.assertRaises(ParseError) as cm:
            loader.load("foo.html")
        err = cm.exception
        self.assertEqual(str(err), "Missing end expression }} at foo.html:3")
        self.assertEqual(err.filename, "foo.html")
        self.assertEqual(err.lineno, 3)
    def test_custom_parse_error(self):
        # Make sure that ParseErrors remain compatible with their
        # pre-4.3 signature.
        self.assertEqual(str(ParseError("asdf")), "asdf at None:0")
class AutoEscapeTest(unittest.TestCase):
    """Tests for the {% autoescape %} directive, its interaction with
    include/extends, and the whitespace-handling modes."""
    def setUp(self):
        self.templates = {
            "escaped.html": "{% autoescape xhtml_escape %}{{ name }}",
            "unescaped.html": "{% autoescape None %}{{ name }}",
            "default.html": "{{ name }}",
            "include.html": """\
escaped: {% include 'escaped.html' %}
unescaped: {% include 'unescaped.html' %}
default: {% include 'default.html' %}
""",
            "escaped_block.html": """\
{% autoescape xhtml_escape %}\
{% block name %}base: {{ name }}{% end %}""",
            "unescaped_block.html": """\
{% autoescape None %}\
{% block name %}base: {{ name }}{% end %}""",
            # Extend a base template with different autoescape policy,
            # with and without overriding the base's blocks
            "escaped_extends_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}""",
            "escaped_overrides_unescaped.html": """\
{% autoescape xhtml_escape %}\
{% extends "unescaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
            "unescaped_extends_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}""",
            "unescaped_overrides_escaped.html": """\
{% autoescape None %}\
{% extends "escaped_block.html" %}\
{% block name %}extended: {{ name }}{% end %}""",
            "raw_expression.html": """\
{% autoescape xhtml_escape %}\
expr: {{ name }}
raw: {% raw name %}""",
        }
    def test_default_off(self):
        # With loader autoescape disabled, only templates that opt in escape.
        loader = DictLoader(self.templates, autoescape=None)
        name = "Bobby <table>s"
        self.assertEqual(
            loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
        )
        self.assertEqual(
            loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
        )
        self.assertEqual(
            loader.load("default.html").generate(name=name), b"Bobby <table>s"
        )
        self.assertEqual(
            loader.load("include.html").generate(name=name),
            b"escaped: Bobby &lt;table&gt;s\n"
            b"unescaped: Bobby <table>s\n"
            b"default: Bobby <table>s\n",
        )
    def test_default_on(self):
        # With loader autoescape enabled, only templates that opt out don't.
        loader = DictLoader(self.templates, autoescape="xhtml_escape")
        name = "Bobby <table>s"
        self.assertEqual(
            loader.load("escaped.html").generate(name=name), b"Bobby &lt;table&gt;s"
        )
        self.assertEqual(
            loader.load("unescaped.html").generate(name=name), b"Bobby <table>s"
        )
        self.assertEqual(
            loader.load("default.html").generate(name=name), b"Bobby &lt;table&gt;s"
        )
        self.assertEqual(
            loader.load("include.html").generate(name=name),
            b"escaped: Bobby &lt;table&gt;s\n"
            b"unescaped: Bobby <table>s\n"
            b"default: Bobby &lt;table&gt;s\n",
        )
    def test_unextended_block(self):
        loader = DictLoader(self.templates)
        name = "<script>"
        self.assertEqual(
            loader.load("escaped_block.html").generate(name=name),
            b"base: &lt;script&gt;",
        )
        self.assertEqual(
            loader.load("unescaped_block.html").generate(name=name), b"base: <script>"
        )
    def test_extended_block(self):
        # Each block uses the autoescape policy of the file it's defined in.
        loader = DictLoader(self.templates)
        def render(name):
            return loader.load(name).generate(name="<script>")
        self.assertEqual(render("escaped_extends_unescaped.html"), b"base: <script>")
        self.assertEqual(
            render("escaped_overrides_unescaped.html"), b"extended: &lt;script&gt;"
        )
        self.assertEqual(
            render("unescaped_extends_escaped.html"), b"base: &lt;script&gt;"
        )
        self.assertEqual(
            render("unescaped_overrides_escaped.html"), b"extended: <script>"
        )
    def test_raw_expression(self):
        # {% raw %} bypasses autoescaping for a single expression.
        loader = DictLoader(self.templates)
        def render(name):
            return loader.load(name).generate(name='<>&"')
        self.assertEqual(
            render("raw_expression.html"), b"expr: &lt;&gt;&amp;&quot;\n" b'raw: <>&"'
        )
    def test_custom_escape(self):
        # The autoescape function can be any name in the template namespace.
        loader = DictLoader({"foo.py": "{% autoescape py_escape %}s = {{ name }}\n"})
        def py_escape(s):
            self.assertEqual(type(s), bytes)
            return repr(native_str(s))
        def render(template, name):
            return loader.load(template).generate(py_escape=py_escape, name=name)
        self.assertEqual(render("foo.py", "<html>"), b"s = '<html>'\n")
        self.assertEqual(render("foo.py", "';sys.exit()"), b"""s = "';sys.exit()"\n""")
        self.assertEqual(
            render("foo.py", ["not a string"]), b"""s = "['not a string']"\n"""
        )
    def test_manual_minimize_whitespace(self):
        # Whitespace including newlines is allowed within template tags
        # and directives, and this is one way to avoid long lines while
        # keeping extra whitespace out of the rendered output.
        loader = DictLoader(
            {
                "foo.txt": """\
{% for i in items
%}{% if i > 0 %}, {% end %}{#
#}{{i
}}{% end
%}"""
            }
        )
        self.assertEqual(
            loader.load("foo.txt").generate(items=range(5)), b"0, 1, 2, 3, 4"
        )
    def test_whitespace_by_filename(self):
        # Default whitespace handling depends on the template filename.
        loader = DictLoader(
            {
                "foo.html": " \n\t\n asdf\t ",
                "bar.js": " \n\n\n\t qwer ",
                "baz.txt": "\t zxcv\n\n",
                "include.html": " {% include baz.txt %} \n ",
                "include.txt": "\t\t{% include foo.html %} ",
            }
        )
        # HTML and JS files have whitespace compressed by default.
        self.assertEqual(loader.load("foo.html").generate(), b"\nasdf ")
        self.assertEqual(loader.load("bar.js").generate(), b"\nqwer ")
        # TXT files do not.
        self.assertEqual(loader.load("baz.txt").generate(), b"\t zxcv\n\n")
        # Each file maintains its own status even when included in
        # a file of the other type.
        self.assertEqual(loader.load("include.html").generate(), b" \t zxcv\n\n\n")
        self.assertEqual(loader.load("include.txt").generate(), b"\t\t\nasdf ")
    def test_whitespace_by_loader(self):
        # A loader-level whitespace mode overrides the filename default.
        templates = {"foo.html": "\t\tfoo\n\n", "bar.txt": "\t\tbar\n\n"}
        loader = DictLoader(templates, whitespace="all")
        self.assertEqual(loader.load("foo.html").generate(), b"\t\tfoo\n\n")
        self.assertEqual(loader.load("bar.txt").generate(), b"\t\tbar\n\n")
        loader = DictLoader(templates, whitespace="single")
        self.assertEqual(loader.load("foo.html").generate(), b" foo\n")
        self.assertEqual(loader.load("bar.txt").generate(), b" bar\n")
        loader = DictLoader(templates, whitespace="oneline")
        self.assertEqual(loader.load("foo.html").generate(), b" foo ")
        self.assertEqual(loader.load("bar.txt").generate(), b" bar ")
    def test_whitespace_directive(self):
        # {% whitespace %} switches the mode mid-template.
        loader = DictLoader(
            {
                "foo.html": """\
{% whitespace oneline %}
{% for i in range(3) %}
{{ i }}
{% end %}
{% whitespace all %}
pre\tformatted
"""
            }
        )
        self.assertEqual(
            loader.load("foo.html").generate(), b" 0 1 2 \n pre\tformatted\n"
        )
class TemplateLoaderTest(unittest.TestCase):
    """Loads real template files from the test's 'templates' directory."""
    def setUp(self):
        template_dir = os.path.join(os.path.dirname(__file__), "templates")
        self.loader = Loader(template_dir)
    def test_utf8_in_file(self):
        # utf8.html contains a non-ASCII character stored as UTF-8 bytes.
        rendered = self.loader.load("utf8.html").generate()
        self.assertEqual(to_unicode(rendered).strip(), u"H\u00e9llo")

View file

@ -0,0 +1 @@
Héllo

View file

@ -0,0 +1,20 @@
-----BEGIN CERTIFICATE-----
MIIDWzCCAkOgAwIBAgIUV4spou0CenmvKqa7Hml/MC+JKiAwDQYJKoZIhvcNAQEL
BQAwPTELMAkGA1UEBhMCVVMxEzARBgNVBAgMCkNhbGlmb3JuaWExGTAXBgNVBAoM
EFRvcm5hZG8gV2ViIFRlc3QwHhcNMTgwOTI5MTM1NjQ1WhcNMjgwOTI2MTM1NjQ1
WjA9MQswCQYDVQQGEwJVUzETMBEGA1UECAwKQ2FsaWZvcm5pYTEZMBcGA1UECgwQ
VG9ybmFkbyBXZWIgVGVzdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEB
AKT0LdyI8tW5uwP3ahE8BFSz+j3SsKBDv/0cKvqxVVE6sLEST2s3HjArZvIIG5sb
iBkWDrqnZ6UKDvB4jlobLGAkepxDbrxHWxK53n0C28XXGLqJQ01TlTZ5rpjttMeg
5SKNjHbxpOvpUwwQS4br4WjZKKyTGiXpFkFUty+tYVU35/U2yyvreWHmzpHx/25t
H7O2RBARVwJYKOGPtlH62lQjpIWfVfklY4Ip8Hjl3B6rBxPyBULmVQw0qgoZn648
oa4oLjs0wnYBz01gVjNMDHej52SsB/ieH7W1TxFMzqOlcvHh41uFbQJPgcXsruSS
9Z4twzSWkUp2vk/C//4Sz38CAwEAAaNTMFEwHQYDVR0OBBYEFLf8fQ5+u8sDWAd3
r5ZjZ5MmDWJeMB8GA1UdIwQYMBaAFLf8fQ5+u8sDWAd3r5ZjZ5MmDWJeMA8GA1Ud
EwEB/wQFMAMBAf8wDQYJKoZIhvcNAQELBQADggEBADkkm3pIb9IeqVNmQ2uhQOgw
UwyToTYUHNTb/Nm5lzBTBqC8gbXAS24RQ30AB/7G115Uxeo+YMKfITxm/CgR+vhF
F59/YrzwXj+G8bdbuVl/UbB6f9RSp+Zo93rUZAtPWr77gxLUrcwSRzzDwxFjC2nC
6eigbkvt1OQY775RwnFAt7HKPclE0Out+cGJIboJuO1f3r57ZdyFH0GzbZEff/7K
atGXohijWJjYvU4mk0KFHORZrcBpsv9cfkFbmgVmiRwxRJ1tLauHM3Ne+VfqYE5M
4rTStSyz3ASqVKJ2iFMQueNR/tUOuDlfRt+0nhJMuYSSkW+KTgnwyOGU9cv+mxA=
-----END CERTIFICATE-----

View file

@ -0,0 +1,28 @@
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCk9C3ciPLVubsD
92oRPARUs/o90rCgQ7/9HCr6sVVROrCxEk9rNx4wK2byCBubG4gZFg66p2elCg7w
eI5aGyxgJHqcQ268R1sSud59AtvF1xi6iUNNU5U2ea6Y7bTHoOUijYx28aTr6VMM
EEuG6+Fo2Siskxol6RZBVLcvrWFVN+f1Nssr63lh5s6R8f9ubR+ztkQQEVcCWCjh
j7ZR+tpUI6SFn1X5JWOCKfB45dweqwcT8gVC5lUMNKoKGZ+uPKGuKC47NMJ2Ac9N
YFYzTAx3o+dkrAf4nh+1tU8RTM6jpXLx4eNbhW0CT4HF7K7kkvWeLcM0lpFKdr5P
wv/+Es9/AgMBAAECggEABi6AaXtYXloPgB6NgwfUwbfc8OQsalUfpMShd7OdluW0
KW6eO05de0ClIvzay/1EJGyHMMeFQtIVrT1XWFkcWJ4FWkXMqJGkABenFtg8lDVz
X8o1E3jGZrw4ptKBq9mDvL/BO9PiclTUH+ecbPn6AIvi0lTQ7grGIryiAM9mjmLy
jpCwoutF2LD4RPNg8vqWe/Z1rQw5lp8FOHhRwPooHHeoq1bSrp8dqvVAwAam7Mmf
uFgI8jrNycPgr2cwEEtbq2TQ625MhVnCpwT+kErmAStfbXXuqv1X1ZZgiNxf+61C
OL0bhPRVIHmmjiK/5qHRuN4Q5u9/Yp2SJ4W5xadSQQKBgQDR7dnOlYYQiaoPJeD/
7jcLVJbWwbr7bE19O/QpYAtkA/FtGlKr+hQxPhK6OYp+in8eHf+ga/NSAjCWRBoh
MNAVCJtiirHo2tFsLFOmlJpGL9n3sX8UnkJN90oHfWrzJ8BZnXaSw2eOuyw8LLj+
Q+ISl6Go8/xfsuy3EDv4AP1wCwKBgQDJJ4vEV3Kr+bc6N/xeu+G0oHvRAWwuQpcx
9D+XpnqbJbFDnWKNE7oGsDCs8Qjr0CdFUN1pm1ppITDZ5N1cWuDg/47ZAXqEK6D1
z13S7O0oQPlnsPL7mHs2Vl73muAaBPAojFvceHHfccr7Z94BXqKsiyfaWz6kclT/
Nl4JTdsC3QKBgQCeYgozL2J/da2lUhnIXcyPstk+29kbueFYu/QBh2HwqnzqqLJ4
5+t2H3P3plQUFp/DdDSZrvhcBiTsKiNgqThEtkKtfSCvIvBf4a2W/4TJsW6MzxCm
2KQDuK/UqM4Y+APKWN/N6Lln2VWNbNyBkWuuRVKFatccyJyJnSjxeqW7cwKBgGyN
idCYPIrwROAHLItXKvOWE5t0ABRq3TsZC2RkdA/b5HCPs4pclexcEriRjvXrK/Yt
MH94Ve8b+UftSUQ4ytjBMS6MrLg87y0YDhLwxv8NKUq65DXAUOW+8JsAmmWQOqY3
MK+m1BT4TMklgVoN3w3sPsKIsSJ/jLz5cv/kYweFAoGAG4iWU1378tI2Ts/Fngsv
7eoWhoda77Y9D0Yoy20aN9VdMHzIYCBOubtRPEuwgaReNwbUBWap01J63yY/fF3K
8PTz6covjoOJqxQJOvM7nM0CsJawG9ccw3YXyd9KgRIdSt6ooEhb7N8W2EXYoKl3
g1i2t41Q/SC3HUGC5mJjpO8=
-----END PRIVATE KEY-----

View file

@ -0,0 +1,353 @@
from tornado import gen, ioloop
from tornado.httpserver import HTTPServer
from tornado.locks import Event
from tornado.testing import AsyncHTTPTestCase, AsyncTestCase, bind_unused_port, gen_test
from tornado.web import Application
import asyncio
import contextlib
import gc
import os
import platform
import traceback
import unittest
import warnings
@contextlib.contextmanager
def set_environ(name, value):
    """Temporarily set environment variable *name* to *value*.

    On exit the variable is restored to its previous value, or removed
    entirely if it was not set before.
    """
    previous = os.environ.get(name)
    os.environ[name] = value
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(name)
        else:
            os.environ[name] = previous
class AsyncTestCaseTest(AsyncTestCase):
    """Tests for AsyncTestCase.wait() timeout handling."""

    def test_wait_timeout(self):
        time = self.io_loop.time
        # Accept default 5-second timeout, no error
        self.io_loop.add_timeout(time() + 0.01, self.stop)
        self.wait()
        # Timeout passed to wait()
        self.io_loop.add_timeout(time() + 1, self.stop)
        with self.assertRaises(self.failureException):
            self.wait(timeout=0.01)
        # Timeout set with environment variable
        self.io_loop.add_timeout(time() + 1, self.stop)
        with set_environ("ASYNC_TEST_TIMEOUT", "0.01"):
            with self.assertRaises(self.failureException):
                self.wait()

    def test_subsequent_wait_calls(self):
        """
        This test makes sure that a second call to wait()
        clears the first timeout.
        """
        # The first wait ends with time left on the clock
        self.io_loop.add_timeout(self.io_loop.time() + 0.00, self.stop)
        self.wait(timeout=0.1)
        # The second wait has enough time for itself but would fail if the
        # first wait's deadline were still in effect.
        self.io_loop.add_timeout(self.io_loop.time() + 0.2, self.stop)
        self.wait(timeout=0.4)
class LeakTest(AsyncTestCase):
    """Verifies cleanup of coroutines left pending when a test ends."""

    def tearDown(self):
        super().tearDown()
        # Trigger a gc to make warnings more deterministic.
        gc.collect()

    def test_leaked_coroutine(self):
        # This test verifies that "leaked" coroutines are shut down
        # without triggering warnings like "task was destroyed but it
        # is pending". If this test were to fail, it would fail
        # because runtests.py detected unexpected output to stderr.
        event = Event()

        async def callback():
            try:
                await event.wait()
            except asyncio.CancelledError:
                pass

        self.io_loop.add_callback(callback)
        self.io_loop.add_callback(self.stop)
        self.wait()
class AsyncHTTPTestCaseTest(AsyncHTTPTestCase):
    """Tests for AsyncHTTPTestCase.fetch() URL resolution."""

    def setUp(self):
        super().setUp()
        # Bind a second port so we can distinguish absolute URLs from
        # paths routed to the main test server.
        sock, port = bind_unused_port()
        app = Application()
        server = HTTPServer(app, **self.get_httpserver_options())
        server.add_socket(sock)
        self.second_port = port
        self.second_server = server

    def get_app(self):
        return Application()

    def test_fetch_segment(self):
        path = "/path"
        response = self.fetch(path)
        self.assertEqual(response.request.url, self.get_url(path))

    def test_fetch_full_http_url(self):
        # Ensure that self.fetch() recognizes absolute urls and does
        # not transform them into references to our main test server.
        path = "http://localhost:%d/path" % self.second_port
        response = self.fetch(path)
        self.assertEqual(response.request.url, path)

    def tearDown(self):
        self.second_server.stop()
        super().tearDown()
class AsyncTestCaseWrapperTest(unittest.TestCase):
    """Checks AsyncTestCase's detection of misused generator/coroutine tests."""

    def test_undecorated_generator(self):
        class Test(AsyncTestCase):
            def test_gen(self):
                yield

        test = Test("test_gen")
        result = unittest.TestResult()
        test.run(result)
        self.assertEqual(len(result.errors), 1)
        self.assertIn("should be decorated", result.errors[0][1])

    @unittest.skipIf(
        platform.python_implementation() == "PyPy",
        "pypy destructor warnings cannot be silenced",
    )
    def test_undecorated_coroutine(self):
        class Test(AsyncTestCase):
            async def test_coro(self):
                pass

        test = Test("test_coro")
        result = unittest.TestResult()
        # Silence "RuntimeWarning: coroutine 'test_coro' was never awaited".
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            test.run(result)
        self.assertEqual(len(result.errors), 1)
        self.assertIn("should be decorated", result.errors[0][1])

    def test_undecorated_generator_with_skip(self):
        # A skipped generator test must not be reported as an error.
        class Test(AsyncTestCase):
            @unittest.skip("don't run this")
            def test_gen(self):
                yield

        test = Test("test_gen")
        result = unittest.TestResult()
        test.run(result)
        self.assertEqual(len(result.errors), 0)
        self.assertEqual(len(result.skipped), 1)

    def test_other_return(self):
        # Non-None, non-generator return values are flagged.
        class Test(AsyncTestCase):
            def test_other_return(self):
                return 42

        test = Test("test_other_return")
        result = unittest.TestResult()
        test.run(result)
        self.assertEqual(len(result.errors), 1)
        self.assertIn("Return value from test method ignored", result.errors[0][1])
class SetUpTearDownTest(unittest.TestCase):
    """Verifies AsyncTestCase cooperates with super().setUp()/tearDown()."""

    def test_set_up_tear_down(self):
        """
        This test makes sure that AsyncTestCase calls super methods for
        setUp and tearDown.

        InheritBoth is a subclass of both AsyncTestCase and
        SetUpTearDown, with the ordering so that the super of
        AsyncTestCase will be SetUpTearDown.
        """
        events = []
        result = unittest.TestResult()

        class SetUpTearDown(unittest.TestCase):
            def setUp(self):
                events.append("setUp")

            def tearDown(self):
                events.append("tearDown")

        class InheritBoth(AsyncTestCase, SetUpTearDown):
            def test(self):
                events.append("test")

        InheritBoth("test").run(result)
        expected = ["setUp", "test", "tearDown"]
        self.assertEqual(expected, events)
class AsyncHTTPTestCaseSetUpTearDownTest(unittest.TestCase):
    """Ensures tearDown releases references held by AsyncHTTPTestCase."""

    def test_tear_down_releases_app_and_http_server(self):
        result = unittest.TestResult()

        class SetUpTearDown(AsyncHTTPTestCase):
            def get_app(self):
                return Application()

            def test(self):
                # Attributes exist while the test runs...
                self.assertTrue(hasattr(self, "_app"))
                self.assertTrue(hasattr(self, "http_server"))

        test = SetUpTearDown("test")
        test.run(result)
        # ...and are removed afterwards so they can be garbage-collected.
        self.assertFalse(hasattr(test, "_app"))
        self.assertFalse(hasattr(test, "http_server"))
class GenTest(AsyncTestCase):
    """Tests for the @gen_test decorator: timeouts, arguments, coroutines."""

    def setUp(self):
        super().setUp()
        self.finished = False

    def tearDown(self):
        # Each test must set self.finished to prove it ran to completion.
        self.assertTrue(self.finished)
        super().tearDown()

    @gen_test
    def test_sync(self):
        self.finished = True

    @gen_test
    def test_async(self):
        yield gen.moment
        self.finished = True

    def test_timeout(self):
        # Set a short timeout and exceed it.
        @gen_test(timeout=0.1)
        def test(self):
            yield gen.sleep(1)

        # This can't use assertRaises because we need to inspect the
        # exc_info triple (and not just the exception object)
        try:
            test(self)
            self.fail("did not get expected exception")
        except ioloop.TimeoutError:
            # The stack trace should blame the gen.sleep line, not just
            # unrelated IOLoop/testing internals.
            self.assertIn("gen.sleep(1)", traceback.format_exc())

        self.finished = True

    def test_no_timeout(self):
        # A test that does not exceed its timeout should succeed.
        @gen_test(timeout=1)
        def test(self):
            yield gen.sleep(0.1)

        test(self)
        self.finished = True

    def test_timeout_environment_variable(self):
        @gen_test(timeout=0.5)
        def test_long_timeout(self):
            yield gen.sleep(0.25)

        # Uses provided timeout of 0.5 seconds, doesn't time out.
        with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
            test_long_timeout(self)

        self.finished = True

    def test_no_timeout_environment_variable(self):
        @gen_test(timeout=0.01)
        def test_short_timeout(self):
            yield gen.sleep(1)

        # Uses environment-variable timeout of 0.1, times out.
        with set_environ("ASYNC_TEST_TIMEOUT", "0.1"):
            with self.assertRaises(ioloop.TimeoutError):
                test_short_timeout(self)

        self.finished = True

    def test_with_method_args(self):
        # Positional args pass through the decorator to the test method.
        @gen_test
        def test_with_args(self, *args):
            self.assertEqual(args, ("test",))
            yield gen.moment

        test_with_args(self, "test")
        self.finished = True

    def test_with_method_kwargs(self):
        # Keyword args pass through the decorator to the test method.
        @gen_test
        def test_with_kwargs(self, **kwargs):
            self.assertDictEqual(kwargs, {"test": "test"})
            yield gen.moment

        test_with_kwargs(self, test="test")
        self.finished = True

    def test_native_coroutine(self):
        @gen_test
        async def test(self):
            self.finished = True

        test(self)

    def test_native_coroutine_timeout(self):
        # Set a short timeout and exceed it.
        @gen_test(timeout=0.1)
        async def test(self):
            await gen.sleep(1)

        try:
            test(self)
            self.fail("did not get expected exception")
        except ioloop.TimeoutError:
            self.finished = True
class GetNewIOLoopTest(AsyncTestCase):
    """Simulates running under an external asyncio test harness."""

    def get_new_ioloop(self):
        # Use the current loop instead of creating a new one here.
        return ioloop.IOLoop.current()

    def setUp(self):
        # This simulates the effect of an asyncio test harness like
        # pytest-asyncio.
        self.orig_loop = asyncio.get_event_loop()
        self.new_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(self.new_loop)
        super().setUp()

    def tearDown(self):
        super().tearDown()
        # AsyncTestCase must not affect the existing asyncio loop.
        self.assertFalse(asyncio.get_event_loop().is_closed())
        asyncio.set_event_loop(self.orig_loop)
        self.new_loop.close()

    def test_loop(self):
        self.assertIs(self.io_loop.asyncio_loop, self.new_loop)  # type: ignore
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,247 @@
# Author: Ovidiu Predescu
# Date: July 2011
#
# Licensed under the Apache License, Version 2.0 (the "License"); you may
# not use this file except in compliance with the License. You may obtain
# a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import asyncio
import logging
import signal
import unittest
import warnings
from tornado.escape import utf8
from tornado import gen
from tornado.httpclient import AsyncHTTPClient
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.testing import bind_unused_port, AsyncTestCase, gen_test
from tornado.web import RequestHandler, Application
try:
from twisted.internet.defer import ( # type: ignore
Deferred,
inlineCallbacks,
returnValue,
)
from twisted.internet.protocol import Protocol # type: ignore
from twisted.internet.asyncioreactor import AsyncioSelectorReactor # type: ignore
from twisted.web.client import Agent, readBody # type: ignore
from twisted.web.resource import Resource # type: ignore
from twisted.web.server import Site # type: ignore
have_twisted = True
except ImportError:
have_twisted = False
else:
# Not used directly but needed for `yield deferred` to work.
import tornado.platform.twisted # noqa: F401
skipIfNoTwisted = unittest.skipUnless(have_twisted, "twisted module not present")
def save_signal_handlers():
    """Capture the current handlers for signals Twisted may replace.

    Returns a dict mapping signal number to handler, suitable for
    restore_signal_handlers().  Raises if Twisted handlers appear to be
    installed already (i.e. a previous test failed to clean up).
    """
    watched = [signal.SIGINT, signal.SIGTERM]
    if hasattr(signal, "SIGCHLD"):
        watched.append(signal.SIGCHLD)
    saved = {sig: signal.getsignal(sig) for sig in watched}
    if "twisted" in repr(saved):
        # This indicates we're not cleaning up after ourselves properly.
        raise Exception("twisted signal handlers already installed")
    return saved
def restore_signal_handlers(saved):
    """Reinstall the signal handlers captured by save_signal_handlers()."""
    for signum in saved:
        signal.signal(signum, saved[signum])
# Test various combinations of twisted and tornado http servers,
# http clients, and event loop interfaces.
@skipIfNoTwisted
class CompatibilityTests(unittest.TestCase):
    """Cross-compatibility tests mixing Twisted and Tornado HTTP servers,
    HTTP clients, and event-loop interfaces."""

    def setUp(self):
        self.saved_signals = save_signal_handlers()
        self.saved_policy = asyncio.get_event_loop_policy()
        if hasattr(asyncio, "WindowsSelectorEventLoopPolicy"):
            # Twisted requires a selector event loop, even if Tornado is
            # doing its own tricks in AsyncIOLoop to support proactors.
            # Setting an AddThreadSelectorEventLoop exposes various edge
            # cases so just use a regular selector.
            asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())  # type: ignore
        self.io_loop = IOLoop()
        self.io_loop.make_current()
        self.reactor = AsyncioSelectorReactor()

    def tearDown(self):
        self.reactor.disconnectAll()
        self.io_loop.clear_current()
        self.io_loop.close(all_fds=True)
        asyncio.set_event_loop_policy(self.saved_policy)
        restore_signal_handlers(self.saved_signals)

    def start_twisted_server(self):
        # Serve a fixed body via Twisted; port is recorded in self.twisted_port.
        class HelloResource(Resource):
            isLeaf = True

            def render_GET(self, request):
                return b"Hello from twisted!"

        site = Site(HelloResource())
        port = self.reactor.listenTCP(0, site, interface="127.0.0.1")
        self.twisted_port = port.getHost().port

    def start_tornado_server(self):
        # Serve a fixed body via Tornado; port is recorded in self.tornado_port.
        class HelloHandler(RequestHandler):
            def get(self):
                self.write("Hello from tornado!")

        app = Application([("/", HelloHandler)], log_function=lambda x: None)
        server = HTTPServer(app)
        sock, self.tornado_port = bind_unused_port()
        server.add_sockets([sock])

    def run_reactor(self):
        # In theory, we can run the event loop through Tornado,
        # Twisted, or asyncio interfaces. However, since we're trying
        # to avoid installing anything as the global event loop, only
        # the twisted interface gets everything wired up correctly
        # without extra hacks. This method is a part of a
        # no-longer-used generalization that allowed us to test
        # different combinations.
        self.stop_loop = self.reactor.stop
        self.stop = self.reactor.stop
        self.reactor.run()

    def tornado_fetch(self, url, runner):
        # Fetch with Tornado's client, stopping the loop when done.
        client = AsyncHTTPClient()
        fut = asyncio.ensure_future(client.fetch(url))
        fut.add_done_callback(lambda f: self.stop_loop())
        runner()
        return fut.result()

    def twisted_fetch(self, url, runner):
        # http://twistedmatrix.com/documents/current/web/howto/client.html
        chunks = []
        client = Agent(self.reactor)
        d = client.request(b"GET", utf8(url))

        class Accumulator(Protocol):
            def __init__(self, finished):
                self.finished = finished

            def dataReceived(self, data):
                chunks.append(data)

            def connectionLost(self, reason):
                self.finished.callback(None)

        def callback(response):
            finished = Deferred()
            response.deliverBody(Accumulator(finished))
            return finished

        d.addCallback(callback)

        def shutdown(failure):
            if hasattr(self, "stop_loop"):
                self.stop_loop()
            elif failure is not None:
                # loop hasn't been initialized yet; try our best to
                # get an error message out. (the runner() interaction
                # should probably be refactored).
                try:
                    failure.raiseException()
                except:
                    logging.error("exception before starting loop", exc_info=True)

        d.addBoth(shutdown)
        runner()
        self.assertTrue(chunks)
        return b"".join(chunks)

    def twisted_coroutine_fetch(self, url, runner):
        body = [None]

        @gen.coroutine
        def f():
            # This is simpler than the non-coroutine version, but it cheats
            # by reading the body in one blob instead of streaming it with
            # a Protocol.
            client = Agent(self.reactor)
            response = yield client.request(b"GET", utf8(url))
            with warnings.catch_warnings():
                # readBody has a buggy DeprecationWarning in Twisted 15.0:
                # https://twistedmatrix.com/trac/changeset/43379
                warnings.simplefilter("ignore", category=DeprecationWarning)
                body[0] = yield readBody(response)
            self.stop_loop()

        self.io_loop.add_callback(f)
        runner()
        return body[0]

    def testTwistedServerTornadoClientReactor(self):
        self.start_twisted_server()
        response = self.tornado_fetch(
            "http://127.0.0.1:%d" % self.twisted_port, self.run_reactor
        )
        self.assertEqual(response.body, b"Hello from twisted!")

    def testTornadoServerTwistedClientReactor(self):
        self.start_tornado_server()
        response = self.twisted_fetch(
            "http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
        )
        self.assertEqual(response, b"Hello from tornado!")

    def testTornadoServerTwistedCoroutineClientReactor(self):
        self.start_tornado_server()
        response = self.twisted_coroutine_fetch(
            "http://127.0.0.1:%d" % self.tornado_port, self.run_reactor
        )
        self.assertEqual(response, b"Hello from tornado!")
@skipIfNoTwisted
class ConvertDeferredTest(AsyncTestCase):
    """Verifies Twisted Deferreds can be yielded from Tornado coroutines."""

    @gen_test
    def test_success(self):
        @inlineCallbacks
        def fn():
            if False:
                # inlineCallbacks doesn't work with regular functions;
                # must have a yield even if it's unreachable.
                yield
            returnValue(42)

        res = yield fn()
        self.assertEqual(res, 42)

    @gen_test
    def test_failure(self):
        # Exceptions raised inside the Deferred propagate to the yield.
        @inlineCallbacks
        def fn():
            if False:
                yield
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            yield fn()
# Allow running this test module directly.
if __name__ == "__main__":
    unittest.main()

View file

@ -0,0 +1,114 @@
import contextlib
import os
import platform
import socket
import sys
import textwrap
import typing # noqa: F401
import unittest
import warnings
from tornado.testing import bind_unused_port
# Decorators for skipping tests in environments where they cannot run.
skipIfNonUnix = unittest.skipIf(
    os.name != "posix" or sys.platform == "cygwin", "non-unix platform"
)

# travis-ci.org runs our tests in an overworked virtual machine, which makes
# timing-related tests unreliable.
skipOnTravis = unittest.skipIf(
    "TRAVIS" in os.environ, "timing tests unreliable on travis"
)

# Set the environment variable NO_NETWORK=1 to disable any tests that
# depend on an external network.
skipIfNoNetwork = unittest.skipIf("NO_NETWORK" in os.environ, "network access disabled")

skipNotCPython = unittest.skipIf(
    platform.python_implementation() != "CPython", "Not CPython implementation"
)

# Used for tests affected by
# https://bitbucket.org/pypy/pypy/issues/2616/incomplete-error-handling-in
# TODO: remove this after pypy3 5.8 is obsolete.
skipPypy3V58 = unittest.skipIf(
    platform.python_implementation() == "PyPy"
    and sys.version_info > (3,)
    and sys.pypy_version_info < (5, 9),  # type: ignore
    "pypy3 5.8 has buggy ssl module",
)
def _detect_ipv6():
if not socket.has_ipv6:
# socket.has_ipv6 check reports whether ipv6 was present at compile
# time. It's usually true even when ipv6 doesn't work for other reasons.
return False
sock = None
try:
sock = socket.socket(socket.AF_INET6)
sock.bind(("::1", 0))
except socket.error:
return False
finally:
if sock is not None:
sock.close()
return True
skipIfNoIPv6 = unittest.skipIf(not _detect_ipv6(), "ipv6 support not present")
def refusing_port():
    """Returns a local port number that will refuse all connections.

    Return value is (cleanup_func, port); the cleanup function
    must be called to free the port to be reused.
    """
    # On travis-ci, port numbers are reassigned frequently. To avoid
    # collisions with other tests, we use an open client-side socket's
    # ephemeral port number to ensure that nothing can listen on that
    # port.
    listener, listen_port = bind_unused_port()
    listener.setblocking(True)
    client = socket.socket()
    client.connect(("127.0.0.1", listen_port))
    accepted, client_addr = listener.accept()
    accepted.close()
    listener.close()
    return (client.close, client_addr[1])
def exec_test(caller_globals, caller_locals, s):
    """Execute ``s`` in a given context and return the result namespace.

    Used to define functions for tests in particular python
    versions that would be syntax errors in older versions.
    """
    # Flatten the real global and local namespace into our fake
    # globals: it's all global from the perspective of code defined
    # in s.
    merged_globals = dict(caller_globals)
    merged_globals.update(caller_locals)
    defined = {}  # type: typing.Dict[str, typing.Any]
    exec(textwrap.dedent(s), merged_globals, defined)
    return defined
def subTest(test, *args, **kwargs):
    """Compatibility shim for unittest.TestCase.subTest.

    Usage: ``with tornado.test.util.subTest(self, x=x):``
    """
    impl = getattr(test, "subTest", None)  # py34+
    if impl is None:
        # Fall back to a no-op context manager on older unittest.
        impl = contextlib.contextmanager(lambda *a, **kw: (yield))
    return impl(*args, **kwargs)
@contextlib.contextmanager
def ignore_deprecation():
    """Context manager to ignore deprecation warnings."""
    catcher = warnings.catch_warnings()
    catcher.__enter__()
    try:
        warnings.simplefilter("ignore", DeprecationWarning)
        yield
    finally:
        catcher.__exit__(None, None, None)

View file

@ -0,0 +1,308 @@
from io import StringIO
import re
import sys
import datetime
import unittest
import tornado.escape
from tornado.escape import utf8
from tornado.util import (
raise_exc_info,
Configurable,
exec_in,
ArgReplacer,
timedelta_to_seconds,
import_object,
re_unescape,
is_finalizing,
)
import typing
from typing import cast
if typing.TYPE_CHECKING:
from typing import Dict, Any # noqa: F401
class RaiseExcInfoTest(unittest.TestCase):
    """Tests for tornado.util.raise_exc_info."""

    def test_two_arg_exception(self):
        # This test would fail on python 3 if raise_exc_info were simply
        # a three-argument raise statement, because TwoArgException
        # doesn't have a "copy constructor"
        class TwoArgException(Exception):
            def __init__(self, a, b):
                super().__init__()
                self.a, self.b = a, b

        try:
            raise TwoArgException(1, 2)
        except TwoArgException:
            exc_info = sys.exc_info()
        try:
            raise_exc_info(exc_info)
            self.fail("didn't get expected exception")
        except TwoArgException as e:
            # The very same exception instance must be re-raised.
            self.assertIs(e, exc_info[1])
class TestConfigurable(Configurable):
    """Root of the Configurable hierarchy used by ConfigurableTest."""

    @classmethod
    def configurable_base(cls):
        return TestConfigurable

    @classmethod
    def configurable_default(cls):
        # TestConfig1 is instantiated when nothing has been configured.
        return TestConfig1
class TestConfig1(TestConfigurable):
    """Default implementation; accepts keyword ``a`` and a positional arg."""

    def initialize(self, pos_arg=None, a=None):
        self.a = a
        self.pos_arg = pos_arg
class TestConfig2(TestConfigurable):
    """Alternate implementation; accepts keyword ``b`` and a positional arg."""

    def initialize(self, pos_arg=None, b=None):
        self.b = b
        self.pos_arg = pos_arg
class TestConfig3(TestConfigurable):
    # TestConfig3 is a configuration option that is itself configurable.
    @classmethod
    def configurable_base(cls):
        return TestConfig3

    @classmethod
    def configurable_default(cls):
        return TestConfig3A
class TestConfig3A(TestConfig3):
    """Default implementation of the nested configurable TestConfig3."""

    def initialize(self, a=None):
        self.a = a
class TestConfig3B(TestConfig3):
    """Alternate implementation of the nested configurable TestConfig3."""

    def initialize(self, b=None):
        self.b = b
class ConfigurableTest(unittest.TestCase):
    """Tests for tornado.util.Configurable's configure/instantiate machinery."""

    def setUp(self):
        # Save configuration state so tests don't leak into each other.
        self.saved = TestConfigurable._save_configuration()
        self.saved3 = TestConfig3._save_configuration()

    def tearDown(self):
        TestConfigurable._restore_configuration(self.saved)
        TestConfig3._restore_configuration(self.saved3)

    def checkSubclasses(self):
        # no matter how the class is configured, it should always be
        # possible to instantiate the subclasses directly
        self.assertIsInstance(TestConfig1(), TestConfig1)
        self.assertIsInstance(TestConfig2(), TestConfig2)

        obj = TestConfig1(a=1)
        self.assertEqual(obj.a, 1)

        obj2 = TestConfig2(b=2)
        self.assertEqual(obj2.b, 2)

    def test_default(self):
        # In these tests we combine a typing.cast to satisfy mypy with
        # a runtime type-assertion. Without the cast, mypy would only
        # let us access attributes of the base class.
        obj = cast(TestConfig1, TestConfigurable())
        self.assertIsInstance(obj, TestConfig1)
        self.assertIs(obj.a, None)

        obj = cast(TestConfig1, TestConfigurable(a=1))
        self.assertIsInstance(obj, TestConfig1)
        self.assertEqual(obj.a, 1)

        self.checkSubclasses()

    def test_config_class(self):
        TestConfigurable.configure(TestConfig2)
        obj = cast(TestConfig2, TestConfigurable())
        self.assertIsInstance(obj, TestConfig2)
        self.assertIs(obj.b, None)

        obj = cast(TestConfig2, TestConfigurable(b=2))
        self.assertIsInstance(obj, TestConfig2)
        self.assertEqual(obj.b, 2)

        self.checkSubclasses()

    def test_config_str(self):
        # configure() also accepts a fully-qualified class name string.
        TestConfigurable.configure("tornado.test.util_test.TestConfig2")
        obj = cast(TestConfig2, TestConfigurable())
        self.assertIsInstance(obj, TestConfig2)
        self.assertIs(obj.b, None)

        obj = cast(TestConfig2, TestConfigurable(b=2))
        self.assertIsInstance(obj, TestConfig2)
        self.assertEqual(obj.b, 2)

        self.checkSubclasses()

    def test_config_args(self):
        TestConfigurable.configure(None, a=3)
        obj = cast(TestConfig1, TestConfigurable())
        self.assertIsInstance(obj, TestConfig1)
        self.assertEqual(obj.a, 3)

        obj = cast(TestConfig1, TestConfigurable(42, a=4))
        self.assertIsInstance(obj, TestConfig1)
        self.assertEqual(obj.a, 4)
        self.assertEqual(obj.pos_arg, 42)

        self.checkSubclasses()
        # args bound in configure don't apply when using the subclass directly
        obj = TestConfig1()
        self.assertIs(obj.a, None)

    def test_config_class_args(self):
        TestConfigurable.configure(TestConfig2, b=5)
        obj = cast(TestConfig2, TestConfigurable())
        self.assertIsInstance(obj, TestConfig2)
        self.assertEqual(obj.b, 5)

        obj = cast(TestConfig2, TestConfigurable(42, b=6))
        self.assertIsInstance(obj, TestConfig2)
        self.assertEqual(obj.b, 6)
        self.assertEqual(obj.pos_arg, 42)

        self.checkSubclasses()
        # args bound in configure don't apply when using the subclass directly
        obj = TestConfig2()
        self.assertIs(obj.b, None)

    def test_config_multi_level(self):
        TestConfigurable.configure(TestConfig3, a=1)
        obj = cast(TestConfig3A, TestConfigurable())
        self.assertIsInstance(obj, TestConfig3A)
        self.assertEqual(obj.a, 1)

        TestConfigurable.configure(TestConfig3)
        TestConfig3.configure(TestConfig3B, b=2)
        obj2 = cast(TestConfig3B, TestConfigurable())
        self.assertIsInstance(obj2, TestConfig3B)
        self.assertEqual(obj2.b, 2)

    def test_config_inner_level(self):
        # The inner level can be used even when the outer level
        # doesn't point to it.
        obj = TestConfig3()
        self.assertIsInstance(obj, TestConfig3A)

        TestConfig3.configure(TestConfig3B)
        obj = TestConfig3()
        self.assertIsInstance(obj, TestConfig3B)

        # Configuring the base doesn't configure the inner.
        obj2 = TestConfigurable()
        self.assertIsInstance(obj2, TestConfig1)
        TestConfigurable.configure(TestConfig2)

        obj3 = TestConfigurable()
        self.assertIsInstance(obj3, TestConfig2)

        obj = TestConfig3()
        self.assertIsInstance(obj, TestConfig3B)
class UnicodeLiteralTest(unittest.TestCase):
    """Checks that utf8() encodes non-ASCII text as UTF-8 bytes."""

    def test_unicode_escapes(self):
        self.assertEqual(utf8(u"\u00e9"), b"\xc3\xa9")
class ExecInTest(unittest.TestCase):
    """Tests for tornado.util.exec_in (currently all skipped)."""

    # TODO(bdarnell): make a version of this test for one of the new
    # future imports available in python 3.
    @unittest.skip("no testable future imports")
    def test_no_inherit_future(self):
        # This file has from __future__ import print_function...
        f = StringIO()
        print("hello", file=f)
        # ...but the template doesn't
        exec_in('print >> f, "world"', dict(f=f))
        self.assertEqual(f.getvalue(), "hello\nworld\n")
class ArgReplacerTest(unittest.TestCase):
    """Tests for tornado.util.ArgReplacer on the 'callback' argument."""

    def setUp(self):
        def function(x, y, callback=None, z=None):
            pass

        self.replacer = ArgReplacer(function, "callback")

    def test_omitted(self):
        # Argument not supplied at all: old value is None, new goes in kwargs.
        args = (1, 2)
        kwargs = dict()  # type: Dict[str, Any]
        self.assertIs(self.replacer.get_old_value(args, kwargs), None)
        self.assertEqual(
            self.replacer.replace("new", args, kwargs),
            (None, (1, 2), dict(callback="new")),
        )

    def test_position(self):
        # Argument supplied positionally: replaced in the args list.
        args = (1, 2, "old", 3)
        kwargs = dict()  # type: Dict[str, Any]
        self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
        self.assertEqual(
            self.replacer.replace("new", args, kwargs),
            ("old", [1, 2, "new", 3], dict()),
        )

    def test_keyword(self):
        # Argument supplied by keyword: replaced in kwargs.
        args = (1,)
        kwargs = dict(y=2, callback="old", z=3)
        self.assertEqual(self.replacer.get_old_value(args, kwargs), "old")
        self.assertEqual(
            self.replacer.replace("new", args, kwargs),
            ("old", (1,), dict(y=2, callback="new", z=3)),
        )
class TimedeltaToSecondsTest(unittest.TestCase):
    """Checks timedelta_to_seconds conversion."""

    def test_timedelta_to_seconds(self):
        time_delta = datetime.timedelta(hours=1)
        self.assertEqual(timedelta_to_seconds(time_delta), 3600.0)
class ImportObjectTest(unittest.TestCase):
    """Tests for tornado.util.import_object with modules and members."""

    def test_import_member(self):
        self.assertIs(import_object("tornado.escape.utf8"), utf8)

    def test_import_member_unicode(self):
        self.assertIs(import_object(u"tornado.escape.utf8"), utf8)

    def test_import_module(self):
        self.assertIs(import_object("tornado.escape"), tornado.escape)

    def test_import_module_unicode(self):
        # The internal implementation of __import__ differs depending on
        # whether the thing being imported is a module or not.
        # This variant requires a byte string in python 2.
        self.assertIs(import_object(u"tornado.escape"), tornado.escape)
class ReUnescapeTest(unittest.TestCase):
    """Tests for tornado.util.re_unescape (inverse of re.escape)."""

    def test_re_unescape(self):
        test_strings = ("/favicon.ico", "index.html", "Hello, World!", "!$@#%;")
        for string in test_strings:
            # re_unescape must round-trip with re.escape.
            self.assertEqual(string, re_unescape(re.escape(string)))

    def test_re_unescape_raises_error_on_invalid_input(self):
        # Escapes with regex meaning (not produced by re.escape) are rejected.
        with self.assertRaises(ValueError):
            re_unescape("\\d")
        with self.assertRaises(ValueError):
            re_unescape("\\b")
        with self.assertRaises(ValueError):
            re_unescape("\\Z")
class IsFinalizingTest(unittest.TestCase):
    """Checks is_finalizing() reports False while the interpreter is running."""

    def test_basic(self):
        self.assertFalse(is_finalizing())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,840 @@
import asyncio
import functools
import traceback
import typing
import unittest
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler
try:
import tornado.websocket # noqa: F401
from tornado.util import _websocket_mask_python
except ImportError:
# The unittest module presents misleading errors on ImportError
# (it acts as if websocket_test could not be found, hiding the underlying
# error). If we get an ImportError here (which could happen due to
# TORNADO_EXTENSION=1), print some extra information before failing.
traceback.print_exc()
raise
from tornado.websocket import (
WebSocketHandler,
websocket_connect,
WebSocketError,
WebSocketClosedError,
)
try:
from tornado import speedups
except ImportError:
speedups = None # type: ignore
class TestWebSocketHandler(WebSocketHandler):
    """Base class for testing handlers that exposes the on_close event.

    This allows for tests to see the close code and reason on the
    server side.
    """

    def initialize(self, close_future=None, compression_options=None):
        # close_future, if supplied, is resolved from on_close with the
        # tuple (close_code, close_reason).
        self.close_future = close_future
        self.compression_options = compression_options

    def get_compression_options(self):
        return self.compression_options

    def on_close(self):
        if self.close_future is not None:
            self.close_future.set_result((self.close_code, self.close_reason))
class EchoHandler(TestWebSocketHandler):
    """Echoes each message back, preserving its text/binary type."""

    @gen.coroutine
    def on_message(self, message):
        try:
            yield self.write_message(message, isinstance(message, bytes))
        except asyncio.CancelledError:
            # Connection torn down mid-write; nothing to do.
            pass
        except WebSocketClosedError:
            pass
class ErrorInOnMessageHandler(TestWebSocketHandler):
    """Raises ZeroDivisionError on every message, for error-path tests."""

    def on_message(self, message):
        1 / 0
class HeaderHandler(TestWebSocketHandler):
    """On open, verifies plain-HTTP RequestHandler methods are rejected,
    then echoes back the X-Test request header as a message."""

    def open(self):
        methods_to_test = [
            functools.partial(self.write, "This should not work"),
            functools.partial(self.redirect, "http://localhost/elsewhere"),
            functools.partial(self.set_header, "X-Test", ""),
            functools.partial(self.set_cookie, "Chocolate", "Chip"),
            functools.partial(self.set_status, 503),
            self.flush,
            self.finish,
        ]
        for method in methods_to_test:
            try:
                # In a websocket context, many RequestHandler methods
                # raise RuntimeErrors.
                method()  # type: ignore
                raise Exception("did not get expected exception")
            except RuntimeError:
                pass
        self.write_message(self.request.headers.get("X-Test", ""))
class HeaderEchoHandler(TestWebSocketHandler):
    """Copies X-Test* request headers into the handshake response headers."""

    def set_default_headers(self):
        self.set_header("X-Extra-Response-Header", "Extra-Response-Value")

    def prepare(self):
        for k, v in self.request.headers.get_all():
            if k.lower().startswith("x-test"):
                self.set_header(k, v)
class NonWebSocketHandler(RequestHandler):
    """Plain HTTP handler used to test non-websocket endpoints."""

    def get(self):
        self.write("ok")
class RedirectHandler(RequestHandler):
    """Redirects to /echo, for testing websocket redirect behavior."""

    def get(self):
        self.redirect("/echo")
class CloseReasonHandler(TestWebSocketHandler):
    """Immediately closes the connection with a code and reason."""

    def open(self):
        self.on_close_called = False
        self.close(1001, "goodbye")
class AsyncPrepareHandler(TestWebSocketHandler):
    """Echo handler whose prepare() is a coroutine."""

    @gen.coroutine
    def prepare(self):
        yield gen.moment

    def on_message(self, message):
        self.write_message(message)
class PathArgsHandler(TestWebSocketHandler):
    """Sends the captured URL path argument back as the first message."""

    def open(self, arg):
        self.write_message(arg)
class CoroutineOnMessageHandler(TestWebSocketHandler):
    """Coroutine on_message that detects overlapping invocations."""

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Count of currently-sleeping on_message coroutines.
        self.sleeping = 0

    @gen.coroutine
    def on_message(self, message):
        if self.sleeping > 0:
            self.write_message("another coroutine is already sleeping")
        self.sleeping += 1
        yield gen.sleep(0.01)
        self.sleeping -= 1
        self.write_message(message)
class RenderMessageHandler(TestWebSocketHandler):
    """Echoes each message through the message.html template."""

    def on_message(self, message):
        self.write_message(self.render_string("message.html", message=message))
class SubprotocolHandler(TestWebSocketHandler):
    """Selects "goodproto" when offered and reports the chosen subprotocol."""

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        self.select_subprotocol_called = False

    def select_subprotocol(self, subprotocols):
        # Must be called exactly once per connection, before open().
        if self.select_subprotocol_called:
            raise Exception("select_subprotocol called twice")
        self.select_subprotocol_called = True
        if "goodproto" in subprotocols:
            return "goodproto"
        return None

    def open(self):
        if not self.select_subprotocol_called:
            raise Exception("select_subprotocol not called")
        self.write_message("subprotocol=%s" % self.selected_subprotocol)
class OpenCoroutineHandler(TestWebSocketHandler):
    """Coroutine open() used to verify messages are held until open finishes."""

    def initialize(self, test, **kwargs):
        super().initialize(**kwargs)
        self.test = test
        self.open_finished = False

    @gen.coroutine
    def open(self):
        # Wait for the test to send its message before completing open.
        yield self.test.message_sent.wait()
        yield gen.sleep(0.010)
        self.open_finished = True

    def on_message(self, message):
        if not self.open_finished:
            raise Exception("on_message called before open finished")
        self.write_message("ok")
class ErrorInOpenHandler(TestWebSocketHandler):
    """Raises synchronously from open(), for error-path tests."""

    def open(self):
        raise Exception("boom")
class ErrorInAsyncOpenHandler(TestWebSocketHandler):
    """Raises from an async open() after yielding, for error-path tests."""

    async def open(self):
        await asyncio.sleep(0)
        raise Exception("boom")
class NoDelayHandler(TestWebSocketHandler):
    """Enables TCP_NODELAY on open and sends a greeting."""

    def open(self):
        self.set_nodelay(True)
        self.write_message("hello")
class WebSocketBaseTestCase(AsyncHTTPTestCase):
    """Shared base providing a ws_connect() helper against the test server."""

    @gen.coroutine
    def ws_connect(self, path, **kwargs):
        ws = yield websocket_connect(
            "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
        )
        raise gen.Return(ws)
class WebSocketTest(WebSocketBaseTestCase):
    """End-to-end websocket tests covering the handlers defined above."""

    def get_app(self):
        # Resolved with a (code, reason) tuple when a handler's connection
        # closes; tests yield it to observe the server-side close.
        # NOTE(review): the Future[None] type comment below looks stale
        # (tests unpack a 2-tuple from it) -- confirm before changing.
        self.close_future = Future()  # type: Future[None]
        return Application(
            [
                ("/echo", EchoHandler, dict(close_future=self.close_future)),
                ("/non_ws", NonWebSocketHandler),
                ("/redirect", RedirectHandler),
                ("/header", HeaderHandler, dict(close_future=self.close_future)),
                (
                    "/header_echo",
                    HeaderEchoHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/close_reason",
                    CloseReasonHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/error_in_on_message",
                    ErrorInOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/async_prepare",
                    AsyncPrepareHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/path_args/(.*)",
                    PathArgsHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/coroutine",
                    CoroutineOnMessageHandler,
                    dict(close_future=self.close_future),
                ),
                ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
                (
                    "/subprotocol",
                    SubprotocolHandler,
                    dict(close_future=self.close_future),
                ),
                (
                    "/open_coroutine",
                    OpenCoroutineHandler,
                    dict(close_future=self.close_future, test=self),
                ),
                ("/error_in_open", ErrorInOpenHandler),
                ("/error_in_async_open", ErrorInAsyncOpenHandler),
                ("/nodelay", NoDelayHandler),
            ],
            template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
        )

    def get_http_client(self):
        # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
        return SimpleAsyncHTTPClient()

    def tearDown(self):
        super().tearDown()
        # Drop the template loaders cached at class level so our DictLoader
        # cannot leak into other tests.
        RequestHandler._template_loaders.clear()

    def test_http_request(self):
        # WS server, HTTP client: a plain GET to a websocket URL is rejected.
        response = self.fetch("/echo")
        self.assertEqual(response.code, 400)

    def test_missing_websocket_key(self):
        # An upgrade request without Sec-WebSocket-Key is rejected.
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "13",
            },
        )
        self.assertEqual(response.code, 400)

    def test_bad_websocket_version(self):
        # An unsupported protocol version gets 426 Upgrade Required.
        response = self.fetch(
            "/echo",
            headers={
                "Connection": "Upgrade",
                "Upgrade": "WebSocket",
                "Sec-WebSocket-Version": "12",
            },
        )
        self.assertEqual(response.code, 426)

    @gen_test
    def test_websocket_gen(self):
        ws = yield self.ws_connect("/echo")
        yield ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    def test_websocket_callbacks(self):
        # Exercise the callback (non-coroutine) interfaces of
        # websocket_connect and read_message.
        websocket_connect(
            "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
        )
        ws = self.wait().result()
        ws.write_message("hello")
        ws.read_message(self.stop)
        response = self.wait().result()
        self.assertEqual(response, "hello")
        self.close_future.add_done_callback(lambda f: self.stop())
        ws.close()
        self.wait()

    @gen_test
    def test_binary_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(b"hello \xe9", binary=True)
        response = yield ws.read_message()
        self.assertEqual(response, b"hello \xe9")

    @gen_test
    def test_unicode_message(self):
        ws = yield self.ws_connect("/echo")
        ws.write_message(u"hello \u00e9")
        response = yield ws.read_message()
        self.assertEqual(response, u"hello \u00e9")

    @gen_test
    def test_render_message(self):
        ws = yield self.ws_connect("/render")
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "<b>hello</b>")

    @gen_test
    def test_error_in_on_message(self):
        ws = yield self.ws_connect("/error_in_on_message")
        ws.write_message("hello")
        with ExpectLog(app_log, "Uncaught exception"):
            response = yield ws.read_message()
        self.assertIs(response, None)

    @gen_test
    def test_websocket_http_fail(self):
        with self.assertRaises(HTTPError) as cm:
            yield self.ws_connect("/notfound")
        self.assertEqual(cm.exception.code, 404)

    @gen_test
    def test_websocket_http_success(self):
        # Connecting to a plain HTTP handler is a websocket-level error.
        with self.assertRaises(WebSocketError):
            yield self.ws_connect("/non_ws")

    @gen_test
    def test_websocket_http_redirect(self):
        # The websocket client does not follow HTTP redirects.
        with self.assertRaises(HTTPError):
            yield self.ws_connect("/redirect")

    @gen_test
    def test_websocket_network_fail(self):
        # Bind and immediately close a port so the connection is refused.
        sock, port = bind_unused_port()
        sock.close()
        with self.assertRaises(IOError):
            with ExpectLog(gen_log, ".*"):
                yield websocket_connect(
                    "ws://127.0.0.1:%d/" % port, connect_timeout=3600
                )

    @gen_test
    def test_websocket_close_buffered_data(self):
        ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
        ws.write_message("hello")
        ws.write_message("world")
        # Close the underlying stream.
        ws.stream.close()

    @gen_test
    def test_websocket_headers(self):
        # Ensure that arbitrary headers can be passed through websocket_connect.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header" % self.get_http_port(),
                headers={"X-Test": "hello"},
            )
        )
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_websocket_header_echo(self):
        # Ensure that headers can be returned in the response.
        # Specifically, that arbitrary headers passed through websocket_connect
        # can be returned.
        ws = yield websocket_connect(
            HTTPRequest(
                "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
                headers={"X-Test-Hello": "hello"},
            )
        )
        self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
        self.assertEqual(
            ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
        )

    @gen_test
    def test_server_close_reason(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(msg, None)
        self.assertEqual(ws.close_code, 1001)
        self.assertEqual(ws.close_reason, "goodbye")
        # The on_close callback is called no matter which side closed.
        code, reason = yield self.close_future
        # The client echoed the close code it received to the server,
        # so the server's close code (returned via close_future) is
        # the same.
        self.assertEqual(code, 1001)

    @gen_test
    def test_client_close_reason(self):
        ws = yield self.ws_connect("/echo")
        ws.close(1001, "goodbye")
        code, reason = yield self.close_future
        self.assertEqual(code, 1001)
        self.assertEqual(reason, "goodbye")

    @gen_test
    def test_write_after_close(self):
        ws = yield self.ws_connect("/close_reason")
        msg = yield ws.read_message()
        self.assertIs(msg, None)
        # Writing on a closed connection raises rather than silently dropping.
        with self.assertRaises(WebSocketClosedError):
            ws.write_message("hello")

    @gen_test
    def test_async_prepare(self):
        # Previously, an async prepare method triggered a bug that would
        # result in a timeout on test shutdown (and a memory leak).
        ws = yield self.ws_connect("/async_prepare")
        ws.write_message("hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_path_args(self):
        ws = yield self.ws_connect("/path_args/hello")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")

    @gen_test
    def test_coroutine(self):
        ws = yield self.ws_connect("/coroutine")
        # Send both messages immediately, coroutine must process one at a time.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        res = yield ws.read_message()
        self.assertEqual(res, "hello1")
        res = yield ws.read_message()
        self.assertEqual(res, "hello2")

    @gen_test
    def test_check_origin_valid_no_path(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d" % port}
        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_valid_with_path(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        headers = {"Origin": "http://127.0.0.1:%d/something" % port}
        ws = yield websocket_connect(HTTPRequest(url, headers=headers))
        ws.write_message("hello")
        response = yield ws.read_message()
        self.assertEqual(response, "hello")

    @gen_test
    def test_check_origin_invalid_partial_url(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        # An Origin value without a scheme is rejected.
        headers = {"Origin": "127.0.0.1:%d" % port}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid(self):
        port = self.get_http_port()
        url = "ws://127.0.0.1:%d/echo" % port
        # Host is 127.0.0.1, which should not be accessible from some other
        # domain
        headers = {"Origin": "http://somewhereelse.com"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_check_origin_invalid_subdomains(self):
        port = self.get_http_port()
        url = "ws://localhost:%d/echo" % port
        # Subdomains should be disallowed by default. If we could pass a
        # resolver to websocket_connect we could test sibling domains as well.
        headers = {"Origin": "http://subtenant.localhost"}
        with self.assertRaises(HTTPError) as cm:
            yield websocket_connect(HTTPRequest(url, headers=headers))
        self.assertEqual(cm.exception.code, 403)

    @gen_test
    def test_subprotocols(self):
        ws = yield self.ws_connect(
            "/subprotocol", subprotocols=["badproto", "goodproto"]
        )
        self.assertEqual(ws.selected_subprotocol, "goodproto")
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=goodproto")

    @gen_test
    def test_subprotocols_not_offered(self):
        ws = yield self.ws_connect("/subprotocol")
        self.assertIs(ws.selected_subprotocol, None)
        res = yield ws.read_message()
        self.assertEqual(res, "subprotocol=None")

    @gen_test
    def test_open_coroutine(self):
        self.message_sent = Event()
        ws = yield self.ws_connect("/open_coroutine")
        yield ws.write_message("hello")
        self.message_sent.set()
        res = yield ws.read_message()
        self.assertEqual(res, "ok")

    @gen_test
    def test_error_in_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_error_in_async_open(self):
        with ExpectLog(app_log, "Uncaught exception"):
            ws = yield self.ws_connect("/error_in_async_open")
            res = yield ws.read_message()
        self.assertIsNone(res)

    @gen_test
    def test_nodelay(self):
        ws = yield self.ws_connect("/nodelay")
        res = yield ws.read_message()
        self.assertEqual(res, "hello")
class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
    """Echo handler whose on_message is a native (async def) coroutine.

    Invocations must be serialized by the server; the counter detects
    any overlap.
    """

    def initialize(self, **kwargs):
        super().initialize(**kwargs)
        # Number of on_message coroutines currently awaiting the sleep.
        self.pending = 0

    async def on_message(self, message):
        if self.pending > 0:
            self.write_message("another coroutine is already sleeping")
        self.pending += 1
        await gen.sleep(0.01)
        self.pending -= 1
        self.write_message(message)
class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
    """Verifies that an async-def on_message handler serializes frames."""

    def get_app(self):
        return Application([("/native", NativeCoroutineOnMessageHandler)])

    @gen_test
    def test_native_coroutine(self):
        ws = yield self.ws_connect("/native")
        # Queue both frames at once; the handler must process them serially.
        yield ws.write_message("hello1")
        yield ws.write_message("hello2")
        for expected in ("hello1", "hello2"):
            res = yield ws.read_message()
            self.assertEqual(res, expected)
class CompressionTestMixin(object):
    """Shared tests measuring permessage-deflate behavior.

    Subclasses control negotiation by overriding
    get_server_compression_options / get_client_compression_options and
    assert on the measured wire overhead in verify_wire_bytes.
    """

    MESSAGE = "Hello world. Testing 123 123"

    def get_app(self):
        class LimitedHandler(TestWebSocketHandler):
            # Cap incoming messages at 1 KiB to exercise the size check
            # against compressed frames.
            @property
            def max_message_size(self):
                return 1024

            def on_message(self, message):
                # Reply with the decompressed length, not the payload itself.
                self.write_message(str(len(message)))

        return Application(
            [
                (
                    "/echo",
                    EchoHandler,
                    dict(compression_options=self.get_server_compression_options()),
                ),
                (
                    "/limited",
                    LimitedHandler,
                    dict(compression_options=self.get_server_compression_options()),
                ),
            ]
        )

    def get_server_compression_options(self):
        # None: the server side does not enable the compression extension.
        return None

    def get_client_compression_options(self):
        # None: the client side does not enable the compression extension.
        return None

    def verify_wire_bytes(self, bytes_in: int, bytes_out: int) -> None:
        """Assert on raw wire byte counts; implemented by subclasses."""
        raise NotImplementedError()

    @gen_test
    def test_message_sizes(self: typing.Any):
        ws = yield self.ws_connect(
            "/echo", compression_options=self.get_client_compression_options()
        )
        # Send the same message three times so we can measure the
        # effect of the context_takeover options.
        for i in range(3):
            ws.write_message(self.MESSAGE)
            response = yield ws.read_message()
            self.assertEqual(response, self.MESSAGE)
        self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
        self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
        self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)

    @gen_test
    def test_size_limit(self: typing.Any):
        ws = yield self.ws_connect(
            "/limited", compression_options=self.get_client_compression_options()
        )
        # Small messages pass through.
        ws.write_message("a" * 128)
        response = yield ws.read_message()
        self.assertEqual(response, "128")
        # This message is too big after decompression, but it compresses
        # down to a size that will pass the initial checks.
        ws.write_message("a" * 2048)
        response = yield ws.read_message()
        self.assertIsNone(response)
class UncompressedTestMixin(CompressionTestMixin):
    """Specialization of CompressionTestMixin when we expect no compression."""

    def verify_wire_bytes(self: typing.Any, bytes_in, bytes_out):
        frames = 3
        payload = len(self.MESSAGE)
        # Client->server frames carry a 2-byte header plus a 4-byte mask key;
        # server->client frames carry only the 2-byte header.
        self.assertEqual(bytes_out, frames * (payload + 6))
        self.assertEqual(bytes_in, frames * (payload + 2))
class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    # Neither side requests the extension, so frames stay uncompressed.
    pass
# If only one side tries to compress, the extension is not negotiated.
class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    def get_server_compression_options(self):
        # Offer compression from the server side only.
        return {}
class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
    def get_client_compression_options(self):
        # Request compression from the client side only.
        return {}
class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
    """Both sides enable permessage-deflate with default options."""

    def get_server_compression_options(self):
        return {}

    def get_client_compression_options(self):
        return {}

    def verify_wire_bytes(self, bytes_in, bytes_out):
        plain_out = 3 * (len(self.MESSAGE) + 6)
        plain_in = 3 * (len(self.MESSAGE) + 2)
        # Compressed traffic must beat the uncompressed framing both ways.
        self.assertLess(bytes_out, plain_out)
        self.assertLess(bytes_in, plain_in)
        # Client frames each carry a 4-byte mask key that server frames lack,
        # so three messages account for a fixed 12-byte difference.
        self.assertEqual(bytes_out, bytes_in + 12)
class MaskFunctionMixin(object):
    """Shared test cases for websocket frame-masking implementations.

    Subclasses supply the implementation under test via mask().
    """

    def mask(self, mask: bytes, data: bytes) -> bytes:
        raise NotImplementedError()

    def test_mask(self: typing.Any):
        cases = [
            (b"abcd", b"", b""),
            (b"abcd", b"b", b"\x03"),
            (b"abcd", b"54321", b"TVPVP"),
            (b"ZXCV", b"98765432", b"c`t`olpd"),
            # Include test cases with \x00 bytes (to ensure that the C
            # extension isn't depending on null-terminated strings) and
            # bytes with the high bit set (to smoke out signedness issues).
            (
                b"\x00\x01\x02\x03",
                b"\xff\xfb\xfd\xfc\xfe\xfa",
                b"\xff\xfa\xff\xff\xfe\xfb",
            ),
            (
                b"\xff\xfb\xfd\xfc",
                b"\x00\x01\x02\x03\x04\x05",
                b"\xff\xfa\xff\xff\xfb\xfe",
            ),
        ]
        for key, data, expected in cases:
            self.assertEqual(self.mask(key, data), expected)
class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the mask tests against the pure-Python implementation."""

    def mask(self, mask, data):
        return _websocket_mask_python(mask, data)
@unittest.skipIf(speedups is None, "tornado.speedups module not present")
class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
    """Runs the mask tests against the compiled speedups implementation."""

    def mask(self, mask, data):
        return speedups.websocket_mask(mask, data)
class ServerPeriodicPingTest(WebSocketBaseTestCase):
    """With websocket_ping_interval set, the server pings automatically."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_pong(self, data):
                # Report each pong the server receives back to the client.
                self.write_message("got pong")

        return Application([("/", PingHandler)], websocket_ping_interval=0.01)

    @gen_test
    def test_server_ping(self):
        ws = yield self.ws_connect("/")
        for _ in range(3):
            self.assertEqual((yield ws.read_message()), "got pong")
        # TODO: test that the connection gets closed if ping responses stop.
class ClientPeriodicPingTest(WebSocketBaseTestCase):
    """With ping_interval set on the client, it pings automatically."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                # Report each ping the handler receives back to the client.
                self.write_message("got ping")

        return Application([("/", PingHandler)])

    @gen_test
    def test_client_ping(self):
        ws = yield self.ws_connect("/", ping_interval=0.01)
        for _ in range(3):
            self.assertEqual((yield ws.read_message()), "got ping")
        # TODO: test that the connection gets closed if ping responses stop.
class ManualPingTest(WebSocketBaseTestCase):
    """Tests explicit pings sent through the client connection's ping()."""

    def get_app(self):
        class PingHandler(TestWebSocketHandler):
            def on_ping(self, data):
                # Echo the ping payload back as an ordinary message.
                self.write_message(data, binary=isinstance(data, bytes))

        return Application([("/", PingHandler)])

    @gen_test
    def test_manual_ping(self):
        ws = yield self.ws_connect("/")
        # Control-frame payloads longer than 125 bytes are rejected.
        self.assertRaises(ValueError, ws.ping, "a" * 126)
        # on_ping always sees bytes, even when the ping was sent as str.
        for payload, echoed in [
            ("hello", b"hello"),
            (b"binary hello", b"binary hello"),
        ]:
            ws.ping(payload)
            resp = yield ws.read_message()
            self.assertEqual(resp, echoed)
class MaxMessageSizeTest(WebSocketBaseTestCase):
    """Verifies websocket_max_message_size closes oversized connections."""

    def get_app(self):
        return Application([("/", EchoHandler)], websocket_max_message_size=1024)

    @gen_test
    def test_large_message(self):
        ws = yield self.ws_connect("/")

        # Write a message that is allowed.
        msg = "a" * 1024
        ws.write_message(msg)
        resp = yield ws.read_message()
        self.assertEqual(resp, msg)

        # Write a message that is too large.
        ws.write_message(msg + "b")
        resp = yield ws.read_message()
        # A message of None means the other side closed the connection.
        self.assertIs(resp, None)
        # 1009 is the websocket "message too big" close code.
        self.assertEqual(ws.close_code, 1009)
        self.assertEqual(ws.close_reason, "message too big")
        # TODO: Needs tests of messages split over multiple
        # continuation frames.

# --- diff-dump artifact ("View file", "@ -0,0 +1,20 @@"): the lines below
# --- are the beginning of a separate new file, tornado/test/wsgi_test.py.
from wsgiref.validate import validator
from tornado.testing import AsyncHTTPTestCase
from tornado.wsgi import WSGIContainer
class WSGIContainerTest(AsyncHTTPTestCase):
    # TODO: Now that WSGIAdapter is gone, this is a pretty weak test.

    def wsgi_app(self, environ, start_response):
        """Minimal WSGI application returning a fixed plain-text body."""
        start_response("200 OK", [("Content-Type", "text/plain")])
        return [b"Hello world!"]

    def get_app(self):
        # Wrap the app in wsgiref's validator to catch WSGI-spec violations.
        return WSGIContainer(validator(self.wsgi_app))

    def test_simple(self):
        response = self.fetch("/")
        self.assertEqual(response.body, b"Hello world!")