tcpclient.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334
  1. #
  2. # Copyright 2014 Facebook
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  5. # not use this file except in compliance with the License. You may obtain
  6. # a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  12. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  13. # License for the specific language governing permissions and limitations
  14. # under the License.
  15. """A non-blocking TCP connection factory.
  16. """
  17. import functools
  18. import socket
  19. import numbers
  20. import datetime
  21. import ssl
  22. from tornado.concurrent import Future, future_add_done_callback
  23. from tornado.ioloop import IOLoop
  24. from tornado.iostream import IOStream
  25. from tornado import gen
  26. from tornado.netutil import Resolver
  27. from tornado.platform.auto import set_close_exec
  28. from tornado.gen import TimeoutError
  29. import typing
  30. from typing import Any, Union, Dict, Tuple, List, Callable, Iterator
  31. if typing.TYPE_CHECKING:
  32. from typing import Optional, Set # noqa: F401
  33. _INITIAL_CONNECT_TIMEOUT = 0.3
  34. class _Connector(object):
  35. """A stateless implementation of the "Happy Eyeballs" algorithm.
  36. "Happy Eyeballs" is documented in RFC6555 as the recommended practice
  37. for when both IPv4 and IPv6 addresses are available.
  38. In this implementation, we partition the addresses by family, and
  39. make the first connection attempt to whichever address was
  40. returned first by ``getaddrinfo``. If that connection fails or
  41. times out, we begin a connection in parallel to the first address
  42. of the other family. If there are additional failures we retry
  43. with other addresses, keeping one connection attempt per family
  44. in flight at a time.
  45. http://tools.ietf.org/html/rfc6555
  46. """
  47. def __init__(
  48. self,
  49. addrinfo: List[Tuple],
  50. connect: Callable[
  51. [socket.AddressFamily, Tuple], Tuple[IOStream, "Future[IOStream]"]
  52. ],
  53. ) -> None:
  54. self.io_loop = IOLoop.current()
  55. self.connect = connect
  56. self.future = (
  57. Future()
  58. ) # type: Future[Tuple[socket.AddressFamily, Any, IOStream]]
  59. self.timeout = None # type: Optional[object]
  60. self.connect_timeout = None # type: Optional[object]
  61. self.last_error = None # type: Optional[Exception]
  62. self.remaining = len(addrinfo)
  63. self.primary_addrs, self.secondary_addrs = self.split(addrinfo)
  64. self.streams = set() # type: Set[IOStream]
  65. @staticmethod
  66. def split(
  67. addrinfo: List[Tuple],
  68. ) -> Tuple[
  69. List[Tuple[socket.AddressFamily, Tuple]],
  70. List[Tuple[socket.AddressFamily, Tuple]],
  71. ]:
  72. """Partition the ``addrinfo`` list by address family.
  73. Returns two lists. The first list contains the first entry from
  74. ``addrinfo`` and all others with the same family, and the
  75. second list contains all other addresses (normally one list will
  76. be AF_INET and the other AF_INET6, although non-standard resolvers
  77. may return additional families).
  78. """
  79. primary = []
  80. secondary = []
  81. primary_af = addrinfo[0][0]
  82. for af, addr in addrinfo:
  83. if af == primary_af:
  84. primary.append((af, addr))
  85. else:
  86. secondary.append((af, addr))
  87. return primary, secondary
  88. def start(
  89. self,
  90. timeout: float = _INITIAL_CONNECT_TIMEOUT,
  91. connect_timeout: Union[float, datetime.timedelta] = None,
  92. ) -> "Future[Tuple[socket.AddressFamily, Any, IOStream]]":
  93. self.try_connect(iter(self.primary_addrs))
  94. self.set_timeout(timeout)
  95. if connect_timeout is not None:
  96. self.set_connect_timeout(connect_timeout)
  97. return self.future
  98. def try_connect(self, addrs: Iterator[Tuple[socket.AddressFamily, Tuple]]) -> None:
  99. try:
  100. af, addr = next(addrs)
  101. except StopIteration:
  102. # We've reached the end of our queue, but the other queue
  103. # might still be working. Send a final error on the future
  104. # only when both queues are finished.
  105. if self.remaining == 0 and not self.future.done():
  106. self.future.set_exception(
  107. self.last_error or IOError("connection failed")
  108. )
  109. return
  110. stream, future = self.connect(af, addr)
  111. self.streams.add(stream)
  112. future_add_done_callback(
  113. future, functools.partial(self.on_connect_done, addrs, af, addr)
  114. )
  115. def on_connect_done(
  116. self,
  117. addrs: Iterator[Tuple[socket.AddressFamily, Tuple]],
  118. af: socket.AddressFamily,
  119. addr: Tuple,
  120. future: "Future[IOStream]",
  121. ) -> None:
  122. self.remaining -= 1
  123. try:
  124. stream = future.result()
  125. except Exception as e:
  126. if self.future.done():
  127. return
  128. # Error: try again (but remember what happened so we have an
  129. # error to raise in the end)
  130. self.last_error = e
  131. self.try_connect(addrs)
  132. if self.timeout is not None:
  133. # If the first attempt failed, don't wait for the
  134. # timeout to try an address from the secondary queue.
  135. self.io_loop.remove_timeout(self.timeout)
  136. self.on_timeout()
  137. return
  138. self.clear_timeouts()
  139. if self.future.done():
  140. # This is a late arrival; just drop it.
  141. stream.close()
  142. else:
  143. self.streams.discard(stream)
  144. self.future.set_result((af, addr, stream))
  145. self.close_streams()
  146. def set_timeout(self, timeout: float) -> None:
  147. self.timeout = self.io_loop.add_timeout(
  148. self.io_loop.time() + timeout, self.on_timeout
  149. )
  150. def on_timeout(self) -> None:
  151. self.timeout = None
  152. if not self.future.done():
  153. self.try_connect(iter(self.secondary_addrs))
  154. def clear_timeout(self) -> None:
  155. if self.timeout is not None:
  156. self.io_loop.remove_timeout(self.timeout)
  157. def set_connect_timeout(
  158. self, connect_timeout: Union[float, datetime.timedelta]
  159. ) -> None:
  160. self.connect_timeout = self.io_loop.add_timeout(
  161. connect_timeout, self.on_connect_timeout
  162. )
  163. def on_connect_timeout(self) -> None:
  164. if not self.future.done():
  165. self.future.set_exception(TimeoutError())
  166. self.close_streams()
  167. def clear_timeouts(self) -> None:
  168. if self.timeout is not None:
  169. self.io_loop.remove_timeout(self.timeout)
  170. if self.connect_timeout is not None:
  171. self.io_loop.remove_timeout(self.connect_timeout)
  172. def close_streams(self) -> None:
  173. for stream in self.streams:
  174. stream.close()
  175. class TCPClient(object):
  176. """A non-blocking TCP connection factory.
  177. .. versionchanged:: 5.0
  178. The ``io_loop`` argument (deprecated since version 4.1) has been removed.
  179. """
  180. def __init__(self, resolver: Resolver = None) -> None:
  181. if resolver is not None:
  182. self.resolver = resolver
  183. self._own_resolver = False
  184. else:
  185. self.resolver = Resolver()
  186. self._own_resolver = True
  187. def close(self) -> None:
  188. if self._own_resolver:
  189. self.resolver.close()
  190. async def connect(
  191. self,
  192. host: str,
  193. port: int,
  194. af: socket.AddressFamily = socket.AF_UNSPEC,
  195. ssl_options: Union[Dict[str, Any], ssl.SSLContext] = None,
  196. max_buffer_size: int = None,
  197. source_ip: str = None,
  198. source_port: int = None,
  199. timeout: Union[float, datetime.timedelta] = None,
  200. ) -> IOStream:
  201. """Connect to the given host and port.
  202. Asynchronously returns an `.IOStream` (or `.SSLIOStream` if
  203. ``ssl_options`` is not None).
  204. Using the ``source_ip`` kwarg, one can specify the source
  205. IP address to use when establishing the connection.
  206. In case the user needs to resolve and
  207. use a specific interface, it has to be handled outside
  208. of Tornado as this depends very much on the platform.
  209. Raises `TimeoutError` if the input future does not complete before
  210. ``timeout``, which may be specified in any form allowed by
  211. `.IOLoop.add_timeout` (i.e. a `datetime.timedelta` or an absolute time
  212. relative to `.IOLoop.time`)
  213. Similarly, when the user requires a certain source port, it can
  214. be specified using the ``source_port`` arg.
  215. .. versionchanged:: 4.5
  216. Added the ``source_ip`` and ``source_port`` arguments.
  217. .. versionchanged:: 5.0
  218. Added the ``timeout`` argument.
  219. """
  220. if timeout is not None:
  221. if isinstance(timeout, numbers.Real):
  222. timeout = IOLoop.current().time() + timeout
  223. elif isinstance(timeout, datetime.timedelta):
  224. timeout = IOLoop.current().time() + timeout.total_seconds()
  225. else:
  226. raise TypeError("Unsupported timeout %r" % timeout)
  227. if timeout is not None:
  228. addrinfo = await gen.with_timeout(
  229. timeout, self.resolver.resolve(host, port, af)
  230. )
  231. else:
  232. addrinfo = await self.resolver.resolve(host, port, af)
  233. connector = _Connector(
  234. addrinfo,
  235. functools.partial(
  236. self._create_stream,
  237. max_buffer_size,
  238. source_ip=source_ip,
  239. source_port=source_port,
  240. ),
  241. )
  242. af, addr, stream = await connector.start(connect_timeout=timeout)
  243. # TODO: For better performance we could cache the (af, addr)
  244. # information here and re-use it on subsequent connections to
  245. # the same host. (http://tools.ietf.org/html/rfc6555#section-4.2)
  246. if ssl_options is not None:
  247. if timeout is not None:
  248. stream = await gen.with_timeout(
  249. timeout,
  250. stream.start_tls(
  251. False, ssl_options=ssl_options, server_hostname=host
  252. ),
  253. )
  254. else:
  255. stream = await stream.start_tls(
  256. False, ssl_options=ssl_options, server_hostname=host
  257. )
  258. return stream
  259. def _create_stream(
  260. self,
  261. max_buffer_size: int,
  262. af: socket.AddressFamily,
  263. addr: Tuple,
  264. source_ip: str = None,
  265. source_port: int = None,
  266. ) -> Tuple[IOStream, "Future[IOStream]"]:
  267. # Always connect in plaintext; we'll convert to ssl if necessary
  268. # after one connection has completed.
  269. source_port_bind = source_port if isinstance(source_port, int) else 0
  270. source_ip_bind = source_ip
  271. if source_port_bind and not source_ip:
  272. # User required a specific port, but did not specify
  273. # a certain source IP, will bind to the default loopback.
  274. source_ip_bind = "::1" if af == socket.AF_INET6 else "127.0.0.1"
  275. # Trying to use the same address family as the requested af socket:
  276. # - 127.0.0.1 for IPv4
  277. # - ::1 for IPv6
  278. socket_obj = socket.socket(af)
  279. set_close_exec(socket_obj.fileno())
  280. if source_port_bind or source_ip_bind:
  281. # If the user requires binding also to a specific IP/port.
  282. try:
  283. socket_obj.bind((source_ip_bind, source_port_bind))
  284. except socket.error:
  285. socket_obj.close()
  286. # Fail loudly if unable to use the IP/port.
  287. raise
  288. try:
  289. stream = IOStream(socket_obj, max_buffer_size=max_buffer_size)
  290. except socket.error as e:
  291. fu = Future() # type: Future[IOStream]
  292. fu.set_exception(e)
  293. return stream, fu
  294. else:
  295. return stream, stream.connect(addr)