ssl_servers.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207
  1. from __future__ import absolute_import, division, print_function, unicode_literals
  2. from future.builtins import filter, str
  3. from future import utils
  4. import os
  5. import sys
  6. import ssl
  7. import pprint
  8. import socket
  9. from future.backports.urllib import parse as urllib_parse
  10. from future.backports.http.server import (HTTPServer as _HTTPServer,
  11. SimpleHTTPRequestHandler, BaseHTTPRequestHandler)
  12. from future.backports.test import support
  # ``threading`` is resolved through the test-support helper rather than a
  # plain ``import`` statement.
  threading = support.import_module("threading")

  # Host name the test servers bind to, taken from the support module.
  HOST = support.HOST

  # Directory containing this module, and the PEM file looked up next to it.
  here = os.path.dirname(__file__)
  CERTFILE = os.path.join(here, 'keycert.pem')
  # This one's based on HTTPServer, which is based on SocketServer
  class HTTPSServer(_HTTPServer):
      """HTTP server that wraps every accepted socket with an SSL context."""

      def __init__(self, server_address, handler_class, context):
          """Initialise the base HTTP server and remember *context*.

          *context* is the object whose ``wrap_socket`` method is applied
          to each accepted connection in ``get_request``.
          """
          _HTTPServer.__init__(self, server_address, handler_class)
          self.context = context

      def __str__(self):
          """Return ``<ClassName server_name:server_port>``."""
          fields = (self.__class__.__name__,
                    self.server_name,
                    self.server_port)
          return '<%s %s:%s>' % fields

      def get_request(self):
          """Accept one connection and return ``(ssl_connection, address)``.

          A ``socket.error`` raised while accepting or wrapping is written
          to stderr when ``support.verbose`` is set, then re-raised.
          """
          # override this to wrap socket with SSL
          try:
              raw_conn, peer = self.socket.accept()
              wrapped = self.context.wrap_socket(raw_conn, server_side=True)
          except socket.error as err:
              # socket errors are silenced by the caller, print them here
              if support.verbose:
                  sys.stderr.write("Got an error:\n%s\n" % err)
              raise
          else:
              return wrapped, peer
  class RootedHTTPRequestHandler(SimpleHTTPRequestHandler):
      """File-serving request handler that resolves paths under ``root``."""

      # need to override translate_path to get a known root,
      # instead of using os.curdir, since the test could be
      # run from anywhere
      server_version = "TestHTTPS/1.0"
      root = here
      # Avoid hanging when a request gets interrupted by the client
      timeout = 5

      def translate_path(self, path):
          """Translate a /-separated PATH to the local filename syntax.

          Components that mean special things to the local file system
          (e.g. drive or directory names) are ignored.  (XXX They should
          probably be diagnosed.)

          The query string is dropped, the remaining path is
          percent-decoded and normalised, and the final name of each
          non-empty component is joined onto ``self.root``.  Components
          equal to ``os.curdir`` or ``os.pardir`` are skipped so the
          result cannot climb out of ``self.root``.
          """
          # abandon query parameters
          # The URL parser is imported by this module as ``urllib_parse``;
          # the bare name ``urllib`` is never bound here.
          path = urllib_parse.urlparse(path)[2]
          path = os.path.normpath(urllib_parse.unquote(path))
          resolved = self.root
          for word in path.split('/'):
              if not word:
                  # empty pieces come from leading/duplicate slashes
                  continue
              drive, word = os.path.splitdrive(word)
              head, word = os.path.split(word)
              if word in (os.curdir, os.pardir):
                  # normpath() keeps a leading ".." on paths that do not
                  # start with "/"; never let it escape the root
                  continue
              resolved = os.path.join(resolved, word)
          return resolved

      def log_message(self, format, *args):
          """Write a log line to stdout, but only when ``support.verbose``.

          The line carries the server address and port, the cipher
          reported by the request socket, a timestamp, and ``format % args``.
          """
          # we override this to suppress logging unless "verbose"
          if support.verbose:
              sys.stdout.write(" server (%s:%d %s):\n [%s] %s\n" %
                               (self.server.server_address,
                                self.server.server_port,
                                self.request.cipher(),
                                self.log_date_time_string(),
                                format%args))
  class StatsRequestHandler(BaseHTTPRequestHandler):
      """Example HTTP request handler which returns SSL statistics on GET
      requests.
      """

      server_version = "StatsHTTPS/1.0"

      def do_GET(self, send_body=True):
          """Serve a GET request.

          Replies 200 with a pretty-printed, UTF-8 encoded dict holding the
          context's session statistics plus the cipher and compression
          reported by the connection's socket.  When *send_body* is false
          only the status line and headers are sent.
          """
          ssl_sock = self.rfile.raw._sock
          ssl_context = ssl_sock.context
          payload = pprint.pformat({
              'session_cache': ssl_context.session_stats(),
              'cipher': ssl_sock.cipher(),
              'compression': ssl_sock.compression(),
          }).encode('utf-8')

          self.send_response(200)
          for header, value in (
                  ("Content-type", "text/plain; charset=utf-8"),
                  ("Content-Length", str(len(payload))),
          ):
              self.send_header(header, value)
          self.end_headers()

          if not send_body:
              return
          self.wfile.write(payload)

      def do_HEAD(self):
          """Serve a HEAD request."""
          # Same status line and headers as GET, without the body.
          self.do_GET(send_body=False)

      def log_request(self, format, *args):
          """Delegate to the base class logger only when ``support.verbose``."""
          if not support.verbose:
              return
          BaseHTTPRequestHandler.log_request(self, format, *args)
  class HTTPSServerThread(threading.Thread):
      """Daemon thread that owns and runs an ``HTTPSServer``."""

      def __init__(self, context, host=HOST, handler_class=None):
          """Create the server on ``(host, 0)`` and record its actual port.

          ``RootedHTTPRequestHandler`` is used when *handler_class* is
          falsy.  The port is read back from the server's ``server_port``.
          """
          threading.Thread.__init__(self)
          self.daemon = True
          self.flag = None
          if not handler_class:
              handler_class = RootedHTTPRequestHandler
          self.server = HTTPSServer((host, 0), handler_class, context)
          self.port = self.server.server_port

      def __str__(self):
          """Return ``<ClassName server>``."""
          return "<%s %s>" % (self.__class__.__name__, self.server)

      def start(self, flag=None):
          """Store *flag* and start the thread; ``run`` sets it if truthy."""
          self.flag = flag
          threading.Thread.start(self)

      def run(self):
          """Set the stored flag, then serve until shut down.

          The server is closed when serving ends, whether normally or
          through an exception.
          """
          ready = self.flag
          if ready:
              ready.set()
          try:
              self.server.serve_forever(0.05)
          finally:
              self.server.server_close()

      def stop(self):
          """Ask the server's ``serve_forever`` loop to shut down."""
          self.server.shutdown()
  def make_https_server(case, certfile=CERTFILE, host=HOST, handler_class=None):
      """Start an HTTPS server thread and register its teardown on *case*.

      An SSL context is built from *certfile*, an ``HTTPSServerThread`` is
      started with it, and the call blocks until the thread has set its
      start-up event.  A cleanup that stops and joins the thread is added
      through ``case.addCleanup``.  Returns the server thread.
      """
      def _say(text):
          # progress output is only wanted in verbose runs
          if support.verbose:
              sys.stdout.write(text)

      def cleanup():
          _say('stopping HTTPS server\n')
          server.stop()
          _say('joining HTTPS thread\n')
          server.join()

      # we assume the certfile contains both private key and certificate
      context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
      context.load_cert_chain(certfile)

      server = HTTPSServerThread(context, host, handler_class)
      started = threading.Event()
      server.start(started)
      started.wait()

      case.addCleanup(cleanup)
      return server
  if __name__ == "__main__":
      # Command-line entry point: serve files (or the stats page) over HTTPS
      # using the certificate in CERTFILE.
      import argparse

      parser = argparse.ArgumentParser(
          description='Run a test HTTPS server. '
                      'By default, the current directory is served.')
      parser.add_argument('-p', '--port', type=int, default=4433,
                          help='port to listen on (default: %(default)s)')
      parser.add_argument('-q', '--quiet', dest='verbose', default=True,
                          action='store_false', help='be less verbose')
      parser.add_argument('-s', '--stats', dest='use_stats_handler', default=False,
                          action='store_true', help='always return stats page')
      parser.add_argument('--curve-name', dest='curve_name', type=str,
                          action='store',
                          help='curve name for EC-based Diffie-Hellman')
      parser.add_argument('--dh', dest='dh_file', type=str, action='store',
                          help='PEM file containing DH parameters')
      args = parser.parse_args()

      support.verbose = args.verbose
      if args.use_stats_handler:
          handler_class = StatsRequestHandler
      else:
          handler_class = RootedHTTPRequestHandler
          # Serve the current working directory instead of the module's
          # directory; Python 2 needs the unicode variant of getcwd.
          if utils.PY2:
              handler_class.root = os.getcwdu()
          else:
              handler_class.root = os.getcwd()

      # Let client and server negotiate the protocol version, exactly as
      # make_https_server() does, instead of pinning the server to TLS 1.0
      # only (which many current clients refuse to speak).
      context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
      context.load_cert_chain(CERTFILE)
      if args.curve_name:
          context.set_ecdh_curve(args.curve_name)
      if args.dh_file:
          context.load_dh_params(args.dh_file)

      server = HTTPSServer(("", args.port), handler_class, context)
      if args.verbose:
          print("Listening on https://localhost:{0.port}".format(args))
      try:
          server.serve_forever(0.1)
      finally:
          # release the listening socket even when interrupted
          server.server_close()