from tornado import gen, netutil
from tornado.escape import (
    json_decode,
    json_encode,
    utf8,
    _unicode,
    recursive_unicode,
    native_str,
)
from tornado.http1connection import HTTP1Connection
from tornado.httpclient import HTTPError
from tornado.httpserver import HTTPServer
from tornado.httputil import (
    HTTPHeaders,
    HTTPMessageDelegate,
    HTTPServerConnectionDelegate,
    ResponseStartLine,
)
from tornado.iostream import IOStream
from tornado.locks import Event
from tornado.log import gen_log
from tornado.netutil import ssl_options_to_context
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.testing import (
    AsyncHTTPTestCase,
    AsyncHTTPSTestCase,
    AsyncTestCase,
    ExpectLog,
    gen_test,
)
from tornado.test.util import skipOnTravis
from tornado.web import Application, RequestHandler, stream_request_body

# The bare name ``HTTPError`` above is the HTTP *client* error.  Handlers that
# want the server to answer with a specific status code must raise the
# ``tornado.web`` variant, imported here under a distinct name.
from tornado.web import HTTPError as WebHTTPError

from contextlib import closing
import datetime
import gzip
import logging
import os
import shutil
import socket
import ssl
import sys
import tempfile
import unittest
import urllib.parse
from io import BytesIO

import typing

if typing.TYPE_CHECKING:
    from typing import Dict, List  # noqa: F401


async def read_stream_body(stream):
    """Read one HTTP response from ``stream``.

    Returns a ``(start_line, headers, body)`` tuple, where ``body`` is the
    concatenation of every chunk delivered to the message delegate.
    """
    chunks = []

    class Delegate(HTTPMessageDelegate):
        def headers_received(self, start_line, headers):
            self.headers = headers
            self.start_line = start_line

        def data_received(self, chunk):
            chunks.append(chunk)

        def finish(self):
            # Hand the stream back instead of letting the connection own it.
            conn.detach()  # type: ignore

    conn = HTTP1Connection(stream, True)
    delegate = Delegate()
    await conn.read_response(delegate)
    return delegate.start_line, delegate.headers, b"".join(chunks)


class HandlerBaseTestCase(AsyncHTTPTestCase):
    """Base class serving a single handler, ``cls.Handler``, at ``/``."""

    # Looked up on the concrete test class by ``get_app``.
    Handler = None

    def get_app(self):
        return Application([("/", self.__class__.Handler)])

    def fetch_json(self, *args, **kwargs):
        """Fetch like ``self.fetch`` and return the JSON-decoded body.

        ``response.rethrow()`` is called first, so an error stored on the
        response is raised rather than decoded.
        """
        response = self.fetch(*args, **kwargs)
        response.rethrow()
        return json_decode(response.body)


class HelloWorldRequestHandler(RequestHandler):
    """Replies "Hello world" to GET and reports the body size on POST."""

    def initialize(self, protocol="http"):
        self.expected_protocol = protocol

    def get(self):
        if self.request.protocol != self.expected_protocol:
            raise Exception("unexpected protocol")
        self.finish("Hello world")

    def post(self):
        self.finish("Got %d bytes in POST" % len(self.request.body))


# In pre-1.0 versions of openssl, SSLv23 clients always send SSLv2
# ClientHello messages, which are rejected by SSLv3 and TLSv1
# servers.  Note that while the OPENSSL_VERSION_INFO was formally
# introduced in python3.2, it was present but undocumented in
# python 2.7
skipIfOldSSL = unittest.skipIf(
    getattr(ssl, "OPENSSL_VERSION_INFO", (0, 0)) < (1, 0),
    "old version of ssl module and/or openssl",
)


class BaseSSLTest(AsyncHTTPSTestCase):
    """HTTPS test case serving ``HelloWorldRequestHandler`` at ``/``."""

    def get_app(self):
        return Application([("/", HelloWorldRequestHandler, dict(protocol="https"))])


class SSLTestMixin(object):
    """Tests shared by the SSL test cases below.

    Concrete classes supply the protocol through ``get_ssl_version`` (or
    override ``get_ssl_options`` entirely).
    """

    def get_ssl_options(self):
        return dict(
            ssl_version=self.get_ssl_version(),
            **AsyncHTTPSTestCase.default_ssl_options()
        )

    def get_ssl_version(self):
        raise NotImplementedError()

    def test_ssl(self: typing.Any):
        response = self.fetch("/")
        self.assertEqual(response.body, b"Hello world")

    def test_large_post(self: typing.Any):
        response = self.fetch("/", method="POST", body="A" * 5000)
        self.assertEqual(response.body, b"Got 5000 bytes in POST")

    def test_non_ssl_request(self: typing.Any):
        # Make sure the server closes the connection when it gets a non-ssl
        # connection, rather than waiting for a timeout or otherwise
        # misbehaving.
        with ExpectLog(gen_log, "(SSL Error|uncaught exception)"):
            with ExpectLog(gen_log, "Uncaught exception", required=False):
                with self.assertRaises((IOError, HTTPError)):  # type: ignore
                    self.fetch(
                        self.get_url("/").replace("https:", "http:"),
                        request_timeout=3600,
                        connect_timeout=3600,
                        raise_error=True,
                    )

    def test_error_logging(self: typing.Any):
        # No stack traces are logged for SSL errors.
        with ExpectLog(gen_log, "SSL Error") as expect_log:
            with self.assertRaises((IOError, HTTPError)):  # type: ignore
                self.fetch(
                    self.get_url("/").replace("https:", "http:"), raise_error=True
                )
        self.assertFalse(expect_log.logged_stack)


# Python's SSL implementation differs significantly between versions.
# For example, SSLv3 and TLSv1 throw an exception if you try to read
# from the socket before the handshake is complete, but the default
# of SSLv23 allows it.


class SSLv23Test(BaseSSLTest, SSLTestMixin):
    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv23


@skipIfOldSSL
class SSLv3Test(BaseSSLTest, SSLTestMixin):
    def get_ssl_version(self):
        return ssl.PROTOCOL_SSLv3


@skipIfOldSSL
class TLSv1Test(BaseSSLTest, SSLTestMixin):
    def get_ssl_version(self):
        return ssl.PROTOCOL_TLSv1


class SSLContextTest(BaseSSLTest, SSLTestMixin):
    """Runs the shared SSL tests with an ``ssl.SSLContext`` as the options."""

    def get_ssl_options(self):
        context = ssl_options_to_context(AsyncHTTPSTestCase.get_ssl_options(self))
        assert isinstance(context, ssl.SSLContext)
        return context


class BadSSLOptionsTest(unittest.TestCase):
    """Invalid ``ssl_options`` must make ``HTTPServer(...)`` raise at once."""

    def test_missing_arguments(self):
        application = Application()
        self.assertRaises(
            KeyError,
            HTTPServer,
            application,
            ssl_options={"keyfile": "/__missing__.crt"},
        )

    def test_missing_key(self):
        """A missing SSL key should cause an immediate exception."""

        application = Application()
        module_dir = os.path.dirname(__file__)
        existing_certificate = os.path.join(module_dir, "test.crt")
        existing_key = os.path.join(module_dir, "test.key")

        self.assertRaises(
            (ValueError, IOError),
            HTTPServer,
            application,
            ssl_options={"certfile": "/__mising__.crt"},
        )
        self.assertRaises(
            (ValueError, IOError),
            HTTPServer,
            application,
            ssl_options={
                "certfile": existing_certificate,
                "keyfile": "/__missing__.key",
            },
        )

        # This actually works because both files exist
        HTTPServer(
            application,
            ssl_options={"certfile": existing_certificate, "keyfile": existing_key},
        )


class MultipartTestHandler(RequestHandler):
    """Echoes one header, one form argument and the first uploaded file."""

    def post(self):
        self.finish(
            {
                "header": self.request.headers["X-Header-Encoding-Test"],
                "argument": self.get_argument("argument"),
                "filename": self.request.files["files"][0].filename,
                "filebody": _unicode(self.request.files["files"][0]["body"]),
            }
        )


# This test is also called from wsgi_test
class HTTPConnectionTest(AsyncHTTPTestCase):
    """Sends hand-written requests over a raw ``IOStream``."""

    def get_handlers(self):
        return [
            ("/multipart", MultipartTestHandler),
            ("/hello", HelloWorldRequestHandler),
        ]

    def get_app(self):
        return Application(self.get_handlers())

    def raw_fetch(self, headers, body, newline=b"\r\n"):
        """Send a raw request and return the response body as bytes.

        ``headers`` is a list of byte strings (request line first).  A
        ``Content-Length`` header for ``body`` is appended, lines are joined
        with ``newline``, and a blank line separates headers from ``body``.
        """
        with closing(IOStream(socket.socket())) as stream:
            self.io_loop.run_sync(
                lambda: stream.connect(("127.0.0.1", self.get_http_port()))
            )
            stream.write(
                newline.join(headers + [utf8("Content-Length: %d" % len(body))])
                + newline
                + newline
                + body
            )
            # Distinct names: do not rebind the ``headers``/``body`` parameters
            # to the *response* values.
            _start_line, _response_headers, response_body = self.io_loop.run_sync(
                lambda: read_stream_body(stream)
            )
            return response_body

    def test_multipart_form(self):
        # Encodings here are tricky:  Headers are latin1, bodies can be
        # anything (we use utf8 by default).
        response = self.raw_fetch(
            [
                b"POST /multipart HTTP/1.0",
                b"Content-Type: multipart/form-data; boundary=1234567890",
                b"X-Header-encoding-test: \xe9",
            ],
            b"\r\n".join(
                [
                    b"Content-Disposition: form-data; name=argument",
                    b"",
                    u"\u00e1".encode("utf-8"),
                    b"--1234567890",
                    u'Content-Disposition: form-data; name="files"; filename="\u00f3"'.encode(
                        "utf8"
                    ),
                    b"",
                    u"\u00fa".encode("utf-8"),
                    b"--1234567890--",
                    b"",
                ]
            ),
        )
        data = json_decode(response)
        self.assertEqual(u"\u00e9", data["header"])
        self.assertEqual(u"\u00e1", data["argument"])
        self.assertEqual(u"\u00f3", data["filename"])
        self.assertEqual(u"\u00fa", data["filebody"])

    def test_newlines(self):
        # We support both CRLF and bare LF as line separators.
        for newline in (b"\r\n", b"\n"):
            response = self.raw_fetch([b"GET /hello HTTP/1.0"], b"", newline=newline)
            self.assertEqual(response, b"Hello world")

    @gen_test
    def test_100_continue(self):
        # Run through a 100-continue interaction by hand:
        # When given Expect: 100-continue, we get a 100 response after the
        # headers, and then the real response after the body.
        #
        # ``closing`` guarantees the stream is closed even when one of the
        # assertions or reads below fails part-way through.
        with closing(IOStream(socket.socket())) as stream:
            yield stream.connect(("127.0.0.1", self.get_http_port()))
            yield stream.write(
                b"\r\n".join(
                    [
                        b"POST /hello HTTP/1.1",
                        b"Content-Length: 1024",
                        b"Expect: 100-continue",
                        b"Connection: close",
                        b"\r\n",
                    ]
                )
            )
            data = yield stream.read_until(b"\r\n\r\n")
            self.assertTrue(data.startswith(b"HTTP/1.1 100 "), data)
            stream.write(b"a" * 1024)
            first_line = yield stream.read_until(b"\r\n")
            self.assertTrue(first_line.startswith(b"HTTP/1.1 200"), first_line)
            header_data = yield stream.read_until(b"\r\n\r\n")
            headers = HTTPHeaders.parse(native_str(header_data.decode("latin1")))
            body = yield stream.read_bytes(int(headers["Content-Length"]))
            self.assertEqual(body, b"Got 1024 bytes in POST")


class EchoHandler(RequestHandler):
    """Writes the request arguments back, converted by ``recursive_unicode``."""

    def get(self):
        self.write(recursive_unicode(self.request.arguments))

    def post(self):
        self.write(recursive_unicode(self.request.arguments))


class TypeCheckHandler(RequestHandler):
    """Records type mismatches of request attributes and writes them out.

    An empty JSON object in the response means every check passed.
    """

    def prepare(self):
        self.errors = {}  # type: Dict[str, str]
        fields = [
            ("method", str),
            ("uri", str),
            ("version", str),
            ("remote_ip", str),
            ("protocol", str),
            ("host", str),
            ("path", str),
            ("query", str),
        ]
        for field, expected_type in fields:
            self.check_type(field, getattr(self.request, field), expected_type)

        self.check_type("header_key", list(self.request.headers.keys())[0], str)
        self.check_type("header_value", list(self.request.headers.values())[0], str)

        self.check_type("cookie_key", list(self.request.cookies.keys())[0], str)
        self.check_type(
            "cookie_value", list(self.request.cookies.values())[0].value, str
        )
        # secure cookies

        self.check_type("arg_key", list(self.request.arguments.keys())[0], str)
        self.check_type("arg_value", list(self.request.arguments.values())[0][0], bytes)

    def post(self):
        self.check_type("body", self.request.body, bytes)
        self.write(self.errors)

    def get(self):
        self.write(self.errors)

    def check_type(self, name, obj, expected_type):
        """Record an error under ``name`` unless ``type(obj)`` is exactly
        ``expected_type`` (a subclass instance is reported as a mismatch)."""
        actual_type = type(obj)
        if expected_type != actual_type:
            self.errors[name] = "expected %s, got %s" % (expected_type, actual_type)


class PostEchoHandler(RequestHandler):
    """Echoes the ``data`` argument of a POST as ``{"echo": ...}``."""

    def post(self, *path_args):
        self.write(dict(echo=self.get_argument("data")))


class PostEchoGBKHandler(PostEchoHandler):
    """``PostEchoHandler`` variant that decodes arguments as GBK."""

    def decode_argument(self, value, name=None):
        """Decode ``value`` as GBK; undecodable bytes raise an HTTP 400.

        The error must be ``tornado.web.HTTPError``: the module-level
        ``HTTPError`` is the HTTP client's exception type, which the request
        handler would treat as an unexpected failure instead of a 400.
        """
        try:
            return value.decode("gbk")
        except UnicodeDecodeError as exc:
            raise WebHTTPError(400, "invalid gbk bytes: %r", value) from exc


class HTTPServerTest(AsyncHTTPTestCase):
    """Argument, type and encoding tests run through the regular client."""

    def get_app(self):
        return Application(
            [
                ("/echo", EchoHandler),
                ("/typecheck", TypeCheckHandler),
                ("//doubleslash", EchoHandler),
                ("/post_utf8", PostEchoHandler),
                ("/post_gbk", PostEchoGBKHandler),
            ]
        )

    def test_query_string_encoding(self):
        response = self.fetch("/echo?foo=%C3%A9")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u"\u00e9"]})

    def test_empty_query_string(self):
        response = self.fetch("/echo?foo=&foo=")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u"", u""]})

    def test_empty_post_parameters(self):
        response = self.fetch("/echo", method="POST", body="foo=&bar=")
        data = json_decode(response.body)
        self.assertEqual(data, {u"foo": [u""], u"bar": [u""]})

    def test_types(self):
        headers = {"Cookie": "foo=bar"}
        response = self.fetch("/typecheck?foo=bar", headers=headers)
        data = json_decode(response.body)
        self.assertEqual(data, {})

        response = self.fetch(
            "/typecheck", method="POST", body="foo=bar", headers=headers
        )
        data = json_decode(response.body)
        self.assertEqual(data, {})

    def test_double_slash(self):
        # urlparse.urlsplit (which tornado.httpserver used to use
        # incorrectly) would parse paths beginning with "//" as
        # protocol-relative urls.
        response = self.fetch("//doubleslash")
        self.assertEqual(200, response.code)
        self.assertEqual(json_decode(response.body), {})

    def test_post_encodings(self):
        # Each encoding is posted both percent-quoted and as raw bytes.
        headers = {"Content-Type": "application/x-www-form-urlencoded"}
        uni_text = "chinese: \u5f20\u4e09"
        for enc in ("utf8", "gbk"):
            for quote in (True, False):
                with self.subTest(enc=enc, quote=quote):
                    bin_text = uni_text.encode(enc)
                    if quote:
                        bin_text = urllib.parse.quote(bin_text).encode("ascii")
                    response = self.fetch(
                        "/post_" + enc,
                        method="POST",
                        headers=headers,
                        body=(b"data=" + bin_text),
                    )
                    self.assertEqual(json_decode(response.body), {"echo": uni_text})


class HTTPServerRawTest(AsyncHTTPTestCase):
    """Talks to the server over a raw ``IOStream`` opened in ``setUp``."""

    def get_app(self):
        return Application([("/echo", EchoHandler)])

    def setUp(self):
        super().setUp()
        self.stream = IOStream(socket.socket())
        self.io_loop.run_sync(
            lambda: self.stream.connect(("127.0.0.1", self.get_http_port()))
        )

    def tearDown(self):
        self.stream.close()
        super().tearDown()

    def test_empty_request(self):
        # Close without sending anything, then let the loop run briefly.
        self.stream.close()
        self.io_loop.add_timeout(datetime.timedelta(seconds=0.001), self.stop)
        self.wait()

    def test_malformed_first_line_response(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
            self.stream.write(b"asdf\r\n\r\n")
            start_line, headers, response = self.io_loop.run_sync(
                lambda: read_stream_body(self.stream)
            )
            self.assertEqual("HTTP/1.1", start_line.version)
            self.assertEqual(400, start_line.code)
            self.assertEqual("Bad Request", start_line.reason)

    def test_malformed_first_line_log(self):
        with ExpectLog(gen_log, ".*Malformed HTTP request line", level=logging.INFO):
            self.stream.write(b"asdf\r\n\r\n")
            # TODO: need an async version of ExpectLog so we don't need
            # hard-coded timeouts here.
self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_malformed_headers(self): with ExpectLog( gen_log, ".*Malformed HTTP message.*no colon in header line", level=logging.INFO, ): self.stream.write(b"GET / HTTP/1.0\r\nasdf\r\n\r\n") self.io_loop.add_timeout(datetime.timedelta(seconds=0.05), self.stop) self.wait() def test_chunked_request_body(self): # Chunked requests are not widely supported and we don't have a way # to generate them in AsyncHTTPClient, but HTTPServer will read them. self.stream.write( b"""\ POST /echo HTTP/1.1 Transfer-Encoding: chunked Content-Type: application/x-www-form-urlencoded 4 foo= 3 bar 0 """.replace( b"\n", b"\r\n" ) ) start_line, headers, response = self.io_loop.run_sync( lambda: read_stream_body(self.stream) ) self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) def test_chunked_request_uppercase(self): # As per RFC 2616 section 3.6, "Transfer-Encoding" header's value is # case-insensitive. self.stream.write( b"""\ POST /echo HTTP/1.1 Transfer-Encoding: Chunked Content-Type: application/x-www-form-urlencoded 4 foo= 3 bar 0 """.replace( b"\n", b"\r\n" ) ) start_line, headers, response = self.io_loop.run_sync( lambda: read_stream_body(self.stream) ) self.assertEqual(json_decode(response), {u"foo": [u"bar"]}) @gen_test def test_invalid_content_length(self): with ExpectLog( gen_log, ".*Only integer Content-Length is allowed", level=logging.INFO ): self.stream.write( b"""\ POST /echo HTTP/1.1 Content-Length: foo bar """.replace( b"\n", b"\r\n" ) ) yield self.stream.read_until_close() class XHeaderTest(HandlerBaseTestCase): class Handler(RequestHandler): def get(self): self.set_header("request-version", self.request.version) self.write( dict( remote_ip=self.request.remote_ip, remote_protocol=self.request.protocol, ) ) def get_httpserver_options(self): return dict(xheaders=True, trusted_downstream=["5.5.5.5"]) def test_ip_headers(self): self.assertEqual(self.fetch_json("/")["remote_ip"], 
"127.0.0.1") valid_ipv4 = {"X-Real-IP": "4.4.4.4"} self.assertEqual( self.fetch_json("/", headers=valid_ipv4)["remote_ip"], "4.4.4.4" ) valid_ipv4_list = {"X-Forwarded-For": "127.0.0.1, 4.4.4.4"} self.assertEqual( self.fetch_json("/", headers=valid_ipv4_list)["remote_ip"], "4.4.4.4" ) valid_ipv6 = {"X-Real-IP": "2620:0:1cfe:face:b00c::3"} self.assertEqual( self.fetch_json("/", headers=valid_ipv6)["remote_ip"], "2620:0:1cfe:face:b00c::3", ) valid_ipv6_list = {"X-Forwarded-For": "::1, 2620:0:1cfe:face:b00c::3"} self.assertEqual( self.fetch_json("/", headers=valid_ipv6_list)["remote_ip"], "2620:0:1cfe:face:b00c::3", ) invalid_chars = {"X-Real-IP": "4.4.4.4