netutil_test.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226
  1. import errno
  2. import os
  3. import signal
  4. import socket
  5. from subprocess import Popen
  6. import sys
  7. import time
  8. import unittest
  9. from tornado.netutil import (
  10. BlockingResolver,
  11. OverrideResolver,
  12. ThreadedResolver,
  13. is_valid_ip,
  14. bind_sockets,
  15. )
  16. from tornado.testing import AsyncTestCase, gen_test, bind_unused_port
  17. from tornado.test.util import skipIfNoNetwork
  18. import typing
  19. if typing.TYPE_CHECKING:
  20. from typing import List # noqa: F401
  21. try:
  22. import pycares # type: ignore
  23. except ImportError:
  24. pycares = None
  25. else:
  26. from tornado.platform.caresresolver import CaresResolver
  27. try:
  28. import twisted # type: ignore
  29. import twisted.names # type: ignore
  30. except ImportError:
  31. twisted = None
  32. else:
  33. from tornado.platform.twisted import TwistedResolver
  34. class _ResolverTestMixin(object):
  35. @gen_test
  36. def test_localhost(self):
  37. addrinfo = yield self.resolver.resolve("localhost", 80, socket.AF_UNSPEC)
  38. self.assertIn((socket.AF_INET, ("127.0.0.1", 80)), addrinfo)
  39. # It is impossible to quickly and consistently generate an error in name
  40. # resolution, so test this case separately, using mocks as needed.
  41. class _ResolverErrorTestMixin(object):
  42. @gen_test
  43. def test_bad_host(self):
  44. with self.assertRaises(IOError):
  45. yield self.resolver.resolve("an invalid domain", 80, socket.AF_UNSPEC)
  46. def _failing_getaddrinfo(*args):
  47. """Dummy implementation of getaddrinfo for use in mocks"""
  48. raise socket.gaierror(errno.EIO, "mock: lookup failed")
  49. @skipIfNoNetwork
  50. class BlockingResolverTest(AsyncTestCase, _ResolverTestMixin):
  51. def setUp(self):
  52. super(BlockingResolverTest, self).setUp()
  53. self.resolver = BlockingResolver()
  54. # getaddrinfo-based tests need mocking to reliably generate errors;
  55. # some configurations are slow to produce errors and take longer than
  56. # our default timeout.
  57. class BlockingResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
  58. def setUp(self):
  59. super(BlockingResolverErrorTest, self).setUp()
  60. self.resolver = BlockingResolver()
  61. self.real_getaddrinfo = socket.getaddrinfo
  62. socket.getaddrinfo = _failing_getaddrinfo
  63. def tearDown(self):
  64. socket.getaddrinfo = self.real_getaddrinfo
  65. super(BlockingResolverErrorTest, self).tearDown()
  66. class OverrideResolverTest(AsyncTestCase, _ResolverTestMixin):
  67. def setUp(self):
  68. super(OverrideResolverTest, self).setUp()
  69. mapping = {
  70. ("google.com", 80): ("1.2.3.4", 80),
  71. ("google.com", 80, socket.AF_INET): ("1.2.3.4", 80),
  72. ("google.com", 80, socket.AF_INET6): (
  73. "2a02:6b8:7c:40c:c51e:495f:e23a:3",
  74. 80,
  75. ),
  76. }
  77. self.resolver = OverrideResolver(BlockingResolver(), mapping)
  78. @gen_test
  79. def test_resolve_multiaddr(self):
  80. result = yield self.resolver.resolve("google.com", 80, socket.AF_INET)
  81. self.assertIn((socket.AF_INET, ("1.2.3.4", 80)), result)
  82. result = yield self.resolver.resolve("google.com", 80, socket.AF_INET6)
  83. self.assertIn(
  84. (socket.AF_INET6, ("2a02:6b8:7c:40c:c51e:495f:e23a:3", 80, 0, 0)), result
  85. )
  86. @skipIfNoNetwork
  87. class ThreadedResolverTest(AsyncTestCase, _ResolverTestMixin):
  88. def setUp(self):
  89. super(ThreadedResolverTest, self).setUp()
  90. self.resolver = ThreadedResolver()
  91. def tearDown(self):
  92. self.resolver.close()
  93. super(ThreadedResolverTest, self).tearDown()
  94. class ThreadedResolverErrorTest(AsyncTestCase, _ResolverErrorTestMixin):
  95. def setUp(self):
  96. super(ThreadedResolverErrorTest, self).setUp()
  97. self.resolver = BlockingResolver()
  98. self.real_getaddrinfo = socket.getaddrinfo
  99. socket.getaddrinfo = _failing_getaddrinfo
  100. def tearDown(self):
  101. socket.getaddrinfo = self.real_getaddrinfo
  102. super(ThreadedResolverErrorTest, self).tearDown()
  103. @skipIfNoNetwork
  104. @unittest.skipIf(sys.platform == "win32", "preexec_fn not available on win32")
  105. class ThreadedResolverImportTest(unittest.TestCase):
  106. def test_import(self):
  107. TIMEOUT = 5
  108. # Test for a deadlock when importing a module that runs the
  109. # ThreadedResolver at import-time. See resolve_test.py for
  110. # full explanation.
  111. command = [sys.executable, "-c", "import tornado.test.resolve_test_helper"]
  112. start = time.time()
  113. popen = Popen(command, preexec_fn=lambda: signal.alarm(TIMEOUT))
  114. while time.time() - start < TIMEOUT:
  115. return_code = popen.poll()
  116. if return_code is not None:
  117. self.assertEqual(0, return_code)
  118. return # Success.
  119. time.sleep(0.05)
  120. self.fail("import timed out")
  121. # We do not test errors with CaresResolver:
  122. # Some DNS-hijacking ISPs (e.g. Time Warner) return non-empty results
  123. # with an NXDOMAIN status code. Most resolvers treat this as an error;
  124. # C-ares returns the results, making the "bad_host" tests unreliable.
  125. # C-ares will try to resolve even malformed names, such as the
  126. # name with spaces used in this test.
  127. @skipIfNoNetwork
  128. @unittest.skipIf(pycares is None, "pycares module not present")
  129. class CaresResolverTest(AsyncTestCase, _ResolverTestMixin):
  130. def setUp(self):
  131. super(CaresResolverTest, self).setUp()
  132. self.resolver = CaresResolver()
  133. # TwistedResolver produces consistent errors in our test cases so we
  134. # could test the regular and error cases in the same class. However,
  135. # in the error cases it appears that cleanup of socket objects is
  136. # handled asynchronously and occasionally results in "unclosed socket"
  137. # warnings if not given time to shut down (and there is no way to
  138. # explicitly shut it down). This makes the test flaky, so we do not
  139. # test error cases here.
  140. @skipIfNoNetwork
  141. @unittest.skipIf(twisted is None, "twisted module not present")
  142. @unittest.skipIf(
  143. getattr(twisted, "__version__", "0.0") < "12.1", "old version of twisted"
  144. )
  145. class TwistedResolverTest(AsyncTestCase, _ResolverTestMixin):
  146. def setUp(self):
  147. super(TwistedResolverTest, self).setUp()
  148. self.resolver = TwistedResolver()
  149. class IsValidIPTest(unittest.TestCase):
  150. def test_is_valid_ip(self):
  151. self.assertTrue(is_valid_ip("127.0.0.1"))
  152. self.assertTrue(is_valid_ip("4.4.4.4"))
  153. self.assertTrue(is_valid_ip("::1"))
  154. self.assertTrue(is_valid_ip("2620:0:1cfe:face:b00c::3"))
  155. self.assertTrue(not is_valid_ip("www.google.com"))
  156. self.assertTrue(not is_valid_ip("localhost"))
  157. self.assertTrue(not is_valid_ip("4.4.4.4<"))
  158. self.assertTrue(not is_valid_ip(" 127.0.0.1"))
  159. self.assertTrue(not is_valid_ip(""))
  160. self.assertTrue(not is_valid_ip(" "))
  161. self.assertTrue(not is_valid_ip("\n"))
  162. self.assertTrue(not is_valid_ip("\x00"))
  163. class TestPortAllocation(unittest.TestCase):
  164. def test_same_port_allocation(self):
  165. if "TRAVIS" in os.environ:
  166. self.skipTest("dual-stack servers often have port conflicts on travis")
  167. sockets = bind_sockets(0, "localhost")
  168. try:
  169. port = sockets[0].getsockname()[1]
  170. self.assertTrue(all(s.getsockname()[1] == port for s in sockets[1:]))
  171. finally:
  172. for sock in sockets:
  173. sock.close()
  174. @unittest.skipIf(
  175. not hasattr(socket, "SO_REUSEPORT"), "SO_REUSEPORT is not supported"
  176. )
  177. def test_reuse_port(self):
  178. sockets = [] # type: List[socket.socket]
  179. socket, port = bind_unused_port(reuse_port=True)
  180. try:
  181. sockets = bind_sockets(port, "127.0.0.1", reuse_port=True)
  182. self.assertTrue(all(s.getsockname()[1] == port for s in sockets))
  183. finally:
  184. socket.close()
  185. for sock in sockets:
  186. sock.close()