# websocket_test.py (27 KB)
# Standard-library imports.
import asyncio
import functools
import traceback
import unittest

# Tornado components used by the handlers and test cases below.
from tornado.concurrent import Future
from tornado import gen
from tornado.httpclient import HTTPError, HTTPRequest
from tornado.locks import Event
from tornado.log import gen_log, app_log
from tornado.simple_httpclient import SimpleAsyncHTTPClient
from tornado.template import DictLoader
from tornado.testing import AsyncHTTPTestCase, gen_test, bind_unused_port, ExpectLog
from tornado.web import Application, RequestHandler

try:
    import tornado.websocket  # noqa: F401
    from tornado.util import _websocket_mask_python
except ImportError:
    # The unittest module presents misleading errors on ImportError
    # (it acts as if websocket_test could not be found, hiding the underlying
    # error). If we get an ImportError here (which could happen due to
    # TORNADO_EXTENSION=1), print some extra information before failing.
    traceback.print_exc()
    raise

from tornado.websocket import (
    WebSocketHandler,
    websocket_connect,
    WebSocketError,
    WebSocketClosedError,
)

try:
    from tornado import speedups
except ImportError:
    # Optional compiled extension; when it is missing, speedups is None and
    # the test case that depends on it is skipped.
    speedups = None  # type: ignore
  34. class TestWebSocketHandler(WebSocketHandler):
  35. """Base class for testing handlers that exposes the on_close event.
  36. This allows for tests to see the close code and reason on the
  37. server side.
  38. """
  39. def initialize(self, close_future=None, compression_options=None):
  40. self.close_future = close_future
  41. self.compression_options = compression_options
  42. def get_compression_options(self):
  43. return self.compression_options
  44. def on_close(self):
  45. if self.close_future is not None:
  46. self.close_future.set_result((self.close_code, self.close_reason))
  47. class EchoHandler(TestWebSocketHandler):
  48. @gen.coroutine
  49. def on_message(self, message):
  50. try:
  51. yield self.write_message(message, isinstance(message, bytes))
  52. except asyncio.CancelledError:
  53. pass
  54. except WebSocketClosedError:
  55. pass
  56. class ErrorInOnMessageHandler(TestWebSocketHandler):
  57. def on_message(self, message):
  58. 1 / 0
  59. class HeaderHandler(TestWebSocketHandler):
  60. def open(self):
  61. methods_to_test = [
  62. functools.partial(self.write, "This should not work"),
  63. functools.partial(self.redirect, "http://localhost/elsewhere"),
  64. functools.partial(self.set_header, "X-Test", ""),
  65. functools.partial(self.set_cookie, "Chocolate", "Chip"),
  66. functools.partial(self.set_status, 503),
  67. self.flush,
  68. self.finish,
  69. ]
  70. for method in methods_to_test:
  71. try:
  72. # In a websocket context, many RequestHandler methods
  73. # raise RuntimeErrors.
  74. method()
  75. raise Exception("did not get expected exception")
  76. except RuntimeError:
  77. pass
  78. self.write_message(self.request.headers.get("X-Test", ""))
  79. class HeaderEchoHandler(TestWebSocketHandler):
  80. def set_default_headers(self):
  81. self.set_header("X-Extra-Response-Header", "Extra-Response-Value")
  82. def prepare(self):
  83. for k, v in self.request.headers.get_all():
  84. if k.lower().startswith("x-test"):
  85. self.set_header(k, v)
  86. class NonWebSocketHandler(RequestHandler):
  87. def get(self):
  88. self.write("ok")
  89. class CloseReasonHandler(TestWebSocketHandler):
  90. def open(self):
  91. self.on_close_called = False
  92. self.close(1001, "goodbye")
  93. class AsyncPrepareHandler(TestWebSocketHandler):
  94. @gen.coroutine
  95. def prepare(self):
  96. yield gen.moment
  97. def on_message(self, message):
  98. self.write_message(message)
  99. class PathArgsHandler(TestWebSocketHandler):
  100. def open(self, arg):
  101. self.write_message(arg)
  102. class CoroutineOnMessageHandler(TestWebSocketHandler):
  103. def initialize(self, **kwargs):
  104. super(CoroutineOnMessageHandler, self).initialize(**kwargs)
  105. self.sleeping = 0
  106. @gen.coroutine
  107. def on_message(self, message):
  108. if self.sleeping > 0:
  109. self.write_message("another coroutine is already sleeping")
  110. self.sleeping += 1
  111. yield gen.sleep(0.01)
  112. self.sleeping -= 1
  113. self.write_message(message)
  114. class RenderMessageHandler(TestWebSocketHandler):
  115. def on_message(self, message):
  116. self.write_message(self.render_string("message.html", message=message))
  117. class SubprotocolHandler(TestWebSocketHandler):
  118. def initialize(self, **kwargs):
  119. super(SubprotocolHandler, self).initialize(**kwargs)
  120. self.select_subprotocol_called = False
  121. def select_subprotocol(self, subprotocols):
  122. if self.select_subprotocol_called:
  123. raise Exception("select_subprotocol called twice")
  124. self.select_subprotocol_called = True
  125. if "goodproto" in subprotocols:
  126. return "goodproto"
  127. return None
  128. def open(self):
  129. if not self.select_subprotocol_called:
  130. raise Exception("select_subprotocol not called")
  131. self.write_message("subprotocol=%s" % self.selected_subprotocol)
  132. class OpenCoroutineHandler(TestWebSocketHandler):
  133. def initialize(self, test, **kwargs):
  134. super(OpenCoroutineHandler, self).initialize(**kwargs)
  135. self.test = test
  136. self.open_finished = False
  137. @gen.coroutine
  138. def open(self):
  139. yield self.test.message_sent.wait()
  140. yield gen.sleep(0.010)
  141. self.open_finished = True
  142. def on_message(self, message):
  143. if not self.open_finished:
  144. raise Exception("on_message called before open finished")
  145. self.write_message("ok")
  146. class ErrorInOpenHandler(TestWebSocketHandler):
  147. def open(self):
  148. raise Exception("boom")
  149. class ErrorInAsyncOpenHandler(TestWebSocketHandler):
  150. async def open(self):
  151. await asyncio.sleep(0)
  152. raise Exception("boom")
  153. class NoDelayHandler(TestWebSocketHandler):
  154. def open(self):
  155. self.set_nodelay(True)
  156. self.write_message("hello")
  157. class WebSocketBaseTestCase(AsyncHTTPTestCase):
  158. @gen.coroutine
  159. def ws_connect(self, path, **kwargs):
  160. ws = yield websocket_connect(
  161. "ws://127.0.0.1:%d%s" % (self.get_http_port(), path), **kwargs
  162. )
  163. raise gen.Return(ws)
  164. class WebSocketTest(WebSocketBaseTestCase):
  165. def get_app(self):
  166. self.close_future = Future() # type: Future[None]
  167. return Application(
  168. [
  169. ("/echo", EchoHandler, dict(close_future=self.close_future)),
  170. ("/non_ws", NonWebSocketHandler),
  171. ("/header", HeaderHandler, dict(close_future=self.close_future)),
  172. (
  173. "/header_echo",
  174. HeaderEchoHandler,
  175. dict(close_future=self.close_future),
  176. ),
  177. (
  178. "/close_reason",
  179. CloseReasonHandler,
  180. dict(close_future=self.close_future),
  181. ),
  182. (
  183. "/error_in_on_message",
  184. ErrorInOnMessageHandler,
  185. dict(close_future=self.close_future),
  186. ),
  187. (
  188. "/async_prepare",
  189. AsyncPrepareHandler,
  190. dict(close_future=self.close_future),
  191. ),
  192. (
  193. "/path_args/(.*)",
  194. PathArgsHandler,
  195. dict(close_future=self.close_future),
  196. ),
  197. (
  198. "/coroutine",
  199. CoroutineOnMessageHandler,
  200. dict(close_future=self.close_future),
  201. ),
  202. ("/render", RenderMessageHandler, dict(close_future=self.close_future)),
  203. (
  204. "/subprotocol",
  205. SubprotocolHandler,
  206. dict(close_future=self.close_future),
  207. ),
  208. (
  209. "/open_coroutine",
  210. OpenCoroutineHandler,
  211. dict(close_future=self.close_future, test=self),
  212. ),
  213. ("/error_in_open", ErrorInOpenHandler),
  214. ("/error_in_async_open", ErrorInAsyncOpenHandler),
  215. ("/nodelay", NoDelayHandler),
  216. ],
  217. template_loader=DictLoader({"message.html": "<b>{{ message }}</b>"}),
  218. )
  219. def get_http_client(self):
  220. # These tests require HTTP/1; force the use of SimpleAsyncHTTPClient.
  221. return SimpleAsyncHTTPClient()
  222. def tearDown(self):
  223. super(WebSocketTest, self).tearDown()
  224. RequestHandler._template_loaders.clear()
  225. def test_http_request(self):
  226. # WS server, HTTP client.
  227. response = self.fetch("/echo")
  228. self.assertEqual(response.code, 400)
  229. def test_missing_websocket_key(self):
  230. response = self.fetch(
  231. "/echo",
  232. headers={
  233. "Connection": "Upgrade",
  234. "Upgrade": "WebSocket",
  235. "Sec-WebSocket-Version": "13",
  236. },
  237. )
  238. self.assertEqual(response.code, 400)
  239. def test_bad_websocket_version(self):
  240. response = self.fetch(
  241. "/echo",
  242. headers={
  243. "Connection": "Upgrade",
  244. "Upgrade": "WebSocket",
  245. "Sec-WebSocket-Version": "12",
  246. },
  247. )
  248. self.assertEqual(response.code, 426)
  249. @gen_test
  250. def test_websocket_gen(self):
  251. ws = yield self.ws_connect("/echo")
  252. yield ws.write_message("hello")
  253. response = yield ws.read_message()
  254. self.assertEqual(response, "hello")
  255. def test_websocket_callbacks(self):
  256. websocket_connect(
  257. "ws://127.0.0.1:%d/echo" % self.get_http_port(), callback=self.stop
  258. )
  259. ws = self.wait().result()
  260. ws.write_message("hello")
  261. ws.read_message(self.stop)
  262. response = self.wait().result()
  263. self.assertEqual(response, "hello")
  264. self.close_future.add_done_callback(lambda f: self.stop())
  265. ws.close()
  266. self.wait()
  267. @gen_test
  268. def test_binary_message(self):
  269. ws = yield self.ws_connect("/echo")
  270. ws.write_message(b"hello \xe9", binary=True)
  271. response = yield ws.read_message()
  272. self.assertEqual(response, b"hello \xe9")
  273. @gen_test
  274. def test_unicode_message(self):
  275. ws = yield self.ws_connect("/echo")
  276. ws.write_message(u"hello \u00e9")
  277. response = yield ws.read_message()
  278. self.assertEqual(response, u"hello \u00e9")
  279. @gen_test
  280. def test_render_message(self):
  281. ws = yield self.ws_connect("/render")
  282. ws.write_message("hello")
  283. response = yield ws.read_message()
  284. self.assertEqual(response, "<b>hello</b>")
  285. @gen_test
  286. def test_error_in_on_message(self):
  287. ws = yield self.ws_connect("/error_in_on_message")
  288. ws.write_message("hello")
  289. with ExpectLog(app_log, "Uncaught exception"):
  290. response = yield ws.read_message()
  291. self.assertIs(response, None)
  292. @gen_test
  293. def test_websocket_http_fail(self):
  294. with self.assertRaises(HTTPError) as cm:
  295. yield self.ws_connect("/notfound")
  296. self.assertEqual(cm.exception.code, 404)
  297. @gen_test
  298. def test_websocket_http_success(self):
  299. with self.assertRaises(WebSocketError):
  300. yield self.ws_connect("/non_ws")
  301. @gen_test
  302. def test_websocket_network_fail(self):
  303. sock, port = bind_unused_port()
  304. sock.close()
  305. with self.assertRaises(IOError):
  306. with ExpectLog(gen_log, ".*"):
  307. yield websocket_connect(
  308. "ws://127.0.0.1:%d/" % port, connect_timeout=3600
  309. )
  310. @gen_test
  311. def test_websocket_close_buffered_data(self):
  312. ws = yield websocket_connect("ws://127.0.0.1:%d/echo" % self.get_http_port())
  313. ws.write_message("hello")
  314. ws.write_message("world")
  315. # Close the underlying stream.
  316. ws.stream.close()
  317. @gen_test
  318. def test_websocket_headers(self):
  319. # Ensure that arbitrary headers can be passed through websocket_connect.
  320. ws = yield websocket_connect(
  321. HTTPRequest(
  322. "ws://127.0.0.1:%d/header" % self.get_http_port(),
  323. headers={"X-Test": "hello"},
  324. )
  325. )
  326. response = yield ws.read_message()
  327. self.assertEqual(response, "hello")
  328. @gen_test
  329. def test_websocket_header_echo(self):
  330. # Ensure that headers can be returned in the response.
  331. # Specifically, that arbitrary headers passed through websocket_connect
  332. # can be returned.
  333. ws = yield websocket_connect(
  334. HTTPRequest(
  335. "ws://127.0.0.1:%d/header_echo" % self.get_http_port(),
  336. headers={"X-Test-Hello": "hello"},
  337. )
  338. )
  339. self.assertEqual(ws.headers.get("X-Test-Hello"), "hello")
  340. self.assertEqual(
  341. ws.headers.get("X-Extra-Response-Header"), "Extra-Response-Value"
  342. )
  343. @gen_test
  344. def test_server_close_reason(self):
  345. ws = yield self.ws_connect("/close_reason")
  346. msg = yield ws.read_message()
  347. # A message of None means the other side closed the connection.
  348. self.assertIs(msg, None)
  349. self.assertEqual(ws.close_code, 1001)
  350. self.assertEqual(ws.close_reason, "goodbye")
  351. # The on_close callback is called no matter which side closed.
  352. code, reason = yield self.close_future
  353. # The client echoed the close code it received to the server,
  354. # so the server's close code (returned via close_future) is
  355. # the same.
  356. self.assertEqual(code, 1001)
  357. @gen_test
  358. def test_client_close_reason(self):
  359. ws = yield self.ws_connect("/echo")
  360. ws.close(1001, "goodbye")
  361. code, reason = yield self.close_future
  362. self.assertEqual(code, 1001)
  363. self.assertEqual(reason, "goodbye")
  364. @gen_test
  365. def test_write_after_close(self):
  366. ws = yield self.ws_connect("/close_reason")
  367. msg = yield ws.read_message()
  368. self.assertIs(msg, None)
  369. with self.assertRaises(WebSocketClosedError):
  370. ws.write_message("hello")
  371. @gen_test
  372. def test_async_prepare(self):
  373. # Previously, an async prepare method triggered a bug that would
  374. # result in a timeout on test shutdown (and a memory leak).
  375. ws = yield self.ws_connect("/async_prepare")
  376. ws.write_message("hello")
  377. res = yield ws.read_message()
  378. self.assertEqual(res, "hello")
  379. @gen_test
  380. def test_path_args(self):
  381. ws = yield self.ws_connect("/path_args/hello")
  382. res = yield ws.read_message()
  383. self.assertEqual(res, "hello")
  384. @gen_test
  385. def test_coroutine(self):
  386. ws = yield self.ws_connect("/coroutine")
  387. # Send both messages immediately, coroutine must process one at a time.
  388. yield ws.write_message("hello1")
  389. yield ws.write_message("hello2")
  390. res = yield ws.read_message()
  391. self.assertEqual(res, "hello1")
  392. res = yield ws.read_message()
  393. self.assertEqual(res, "hello2")
  394. @gen_test
  395. def test_check_origin_valid_no_path(self):
  396. port = self.get_http_port()
  397. url = "ws://127.0.0.1:%d/echo" % port
  398. headers = {"Origin": "http://127.0.0.1:%d" % port}
  399. ws = yield websocket_connect(HTTPRequest(url, headers=headers))
  400. ws.write_message("hello")
  401. response = yield ws.read_message()
  402. self.assertEqual(response, "hello")
  403. @gen_test
  404. def test_check_origin_valid_with_path(self):
  405. port = self.get_http_port()
  406. url = "ws://127.0.0.1:%d/echo" % port
  407. headers = {"Origin": "http://127.0.0.1:%d/something" % port}
  408. ws = yield websocket_connect(HTTPRequest(url, headers=headers))
  409. ws.write_message("hello")
  410. response = yield ws.read_message()
  411. self.assertEqual(response, "hello")
  412. @gen_test
  413. def test_check_origin_invalid_partial_url(self):
  414. port = self.get_http_port()
  415. url = "ws://127.0.0.1:%d/echo" % port
  416. headers = {"Origin": "127.0.0.1:%d" % port}
  417. with self.assertRaises(HTTPError) as cm:
  418. yield websocket_connect(HTTPRequest(url, headers=headers))
  419. self.assertEqual(cm.exception.code, 403)
  420. @gen_test
  421. def test_check_origin_invalid(self):
  422. port = self.get_http_port()
  423. url = "ws://127.0.0.1:%d/echo" % port
  424. # Host is 127.0.0.1, which should not be accessible from some other
  425. # domain
  426. headers = {"Origin": "http://somewhereelse.com"}
  427. with self.assertRaises(HTTPError) as cm:
  428. yield websocket_connect(HTTPRequest(url, headers=headers))
  429. self.assertEqual(cm.exception.code, 403)
  430. @gen_test
  431. def test_check_origin_invalid_subdomains(self):
  432. port = self.get_http_port()
  433. url = "ws://localhost:%d/echo" % port
  434. # Subdomains should be disallowed by default. If we could pass a
  435. # resolver to websocket_connect we could test sibling domains as well.
  436. headers = {"Origin": "http://subtenant.localhost"}
  437. with self.assertRaises(HTTPError) as cm:
  438. yield websocket_connect(HTTPRequest(url, headers=headers))
  439. self.assertEqual(cm.exception.code, 403)
  440. @gen_test
  441. def test_subprotocols(self):
  442. ws = yield self.ws_connect(
  443. "/subprotocol", subprotocols=["badproto", "goodproto"]
  444. )
  445. self.assertEqual(ws.selected_subprotocol, "goodproto")
  446. res = yield ws.read_message()
  447. self.assertEqual(res, "subprotocol=goodproto")
  448. @gen_test
  449. def test_subprotocols_not_offered(self):
  450. ws = yield self.ws_connect("/subprotocol")
  451. self.assertIs(ws.selected_subprotocol, None)
  452. res = yield ws.read_message()
  453. self.assertEqual(res, "subprotocol=None")
  454. @gen_test
  455. def test_open_coroutine(self):
  456. self.message_sent = Event()
  457. ws = yield self.ws_connect("/open_coroutine")
  458. yield ws.write_message("hello")
  459. self.message_sent.set()
  460. res = yield ws.read_message()
  461. self.assertEqual(res, "ok")
  462. @gen_test
  463. def test_error_in_open(self):
  464. with ExpectLog(app_log, "Uncaught exception"):
  465. ws = yield self.ws_connect("/error_in_open")
  466. res = yield ws.read_message()
  467. self.assertIsNone(res)
  468. @gen_test
  469. def test_error_in_async_open(self):
  470. with ExpectLog(app_log, "Uncaught exception"):
  471. ws = yield self.ws_connect("/error_in_async_open")
  472. res = yield ws.read_message()
  473. self.assertIsNone(res)
  474. @gen_test
  475. def test_nodelay(self):
  476. ws = yield self.ws_connect("/nodelay")
  477. res = yield ws.read_message()
  478. self.assertEqual(res, "hello")
  479. class NativeCoroutineOnMessageHandler(TestWebSocketHandler):
  480. def initialize(self, **kwargs):
  481. super().initialize(**kwargs)
  482. self.sleeping = 0
  483. async def on_message(self, message):
  484. if self.sleeping > 0:
  485. self.write_message("another coroutine is already sleeping")
  486. self.sleeping += 1
  487. await gen.sleep(0.01)
  488. self.sleeping -= 1
  489. self.write_message(message)
  490. class WebSocketNativeCoroutineTest(WebSocketBaseTestCase):
  491. def get_app(self):
  492. return Application([("/native", NativeCoroutineOnMessageHandler)])
  493. @gen_test
  494. def test_native_coroutine(self):
  495. ws = yield self.ws_connect("/native")
  496. # Send both messages immediately, coroutine must process one at a time.
  497. yield ws.write_message("hello1")
  498. yield ws.write_message("hello2")
  499. res = yield ws.read_message()
  500. self.assertEqual(res, "hello1")
  501. res = yield ws.read_message()
  502. self.assertEqual(res, "hello2")
  503. class CompressionTestMixin(object):
  504. MESSAGE = "Hello world. Testing 123 123"
  505. def get_app(self):
  506. class LimitedHandler(TestWebSocketHandler):
  507. @property
  508. def max_message_size(self):
  509. return 1024
  510. def on_message(self, message):
  511. self.write_message(str(len(message)))
  512. return Application(
  513. [
  514. (
  515. "/echo",
  516. EchoHandler,
  517. dict(compression_options=self.get_server_compression_options()),
  518. ),
  519. (
  520. "/limited",
  521. LimitedHandler,
  522. dict(compression_options=self.get_server_compression_options()),
  523. ),
  524. ]
  525. )
  526. def get_server_compression_options(self):
  527. return None
  528. def get_client_compression_options(self):
  529. return None
  530. @gen_test
  531. def test_message_sizes(self):
  532. ws = yield self.ws_connect(
  533. "/echo", compression_options=self.get_client_compression_options()
  534. )
  535. # Send the same message three times so we can measure the
  536. # effect of the context_takeover options.
  537. for i in range(3):
  538. ws.write_message(self.MESSAGE)
  539. response = yield ws.read_message()
  540. self.assertEqual(response, self.MESSAGE)
  541. self.assertEqual(ws.protocol._message_bytes_out, len(self.MESSAGE) * 3)
  542. self.assertEqual(ws.protocol._message_bytes_in, len(self.MESSAGE) * 3)
  543. self.verify_wire_bytes(ws.protocol._wire_bytes_in, ws.protocol._wire_bytes_out)
  544. @gen_test
  545. def test_size_limit(self):
  546. ws = yield self.ws_connect(
  547. "/limited", compression_options=self.get_client_compression_options()
  548. )
  549. # Small messages pass through.
  550. ws.write_message("a" * 128)
  551. response = yield ws.read_message()
  552. self.assertEqual(response, "128")
  553. # This message is too big after decompression, but it compresses
  554. # down to a size that will pass the initial checks.
  555. ws.write_message("a" * 2048)
  556. response = yield ws.read_message()
  557. self.assertIsNone(response)
  558. class UncompressedTestMixin(CompressionTestMixin):
  559. """Specialization of CompressionTestMixin when we expect no compression."""
  560. def verify_wire_bytes(self, bytes_in, bytes_out):
  561. # Bytes out includes the 4-byte mask key per message.
  562. self.assertEqual(bytes_out, 3 * (len(self.MESSAGE) + 6))
  563. self.assertEqual(bytes_in, 3 * (len(self.MESSAGE) + 2))
  564. class NoCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  565. pass
  566. # If only one side tries to compress, the extension is not negotiated.
  567. class ServerOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  568. def get_server_compression_options(self):
  569. return {}
  570. class ClientOnlyCompressionTest(UncompressedTestMixin, WebSocketBaseTestCase):
  571. def get_client_compression_options(self):
  572. return {}
  573. class DefaultCompressionTest(CompressionTestMixin, WebSocketBaseTestCase):
  574. def get_server_compression_options(self):
  575. return {}
  576. def get_client_compression_options(self):
  577. return {}
  578. def verify_wire_bytes(self, bytes_in, bytes_out):
  579. self.assertLess(bytes_out, 3 * (len(self.MESSAGE) + 6))
  580. self.assertLess(bytes_in, 3 * (len(self.MESSAGE) + 2))
  581. # Bytes out includes the 4 bytes mask key per message.
  582. self.assertEqual(bytes_out, bytes_in + 12)
  583. class MaskFunctionMixin(object):
  584. # Subclasses should define self.mask(mask, data)
  585. def test_mask(self):
  586. self.assertEqual(self.mask(b"abcd", b""), b"")
  587. self.assertEqual(self.mask(b"abcd", b"b"), b"\x03")
  588. self.assertEqual(self.mask(b"abcd", b"54321"), b"TVPVP")
  589. self.assertEqual(self.mask(b"ZXCV", b"98765432"), b"c`t`olpd")
  590. # Include test cases with \x00 bytes (to ensure that the C
  591. # extension isn't depending on null-terminated strings) and
  592. # bytes with the high bit set (to smoke out signedness issues).
  593. self.assertEqual(
  594. self.mask(b"\x00\x01\x02\x03", b"\xff\xfb\xfd\xfc\xfe\xfa"),
  595. b"\xff\xfa\xff\xff\xfe\xfb",
  596. )
  597. self.assertEqual(
  598. self.mask(b"\xff\xfb\xfd\xfc", b"\x00\x01\x02\x03\x04\x05"),
  599. b"\xff\xfa\xff\xff\xfb\xfe",
  600. )
  601. class PythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
  602. def mask(self, mask, data):
  603. return _websocket_mask_python(mask, data)
  604. @unittest.skipIf(speedups is None, "tornado.speedups module not present")
  605. class CythonMaskFunctionTest(MaskFunctionMixin, unittest.TestCase):
  606. def mask(self, mask, data):
  607. return speedups.websocket_mask(mask, data)
  608. class ServerPeriodicPingTest(WebSocketBaseTestCase):
  609. def get_app(self):
  610. class PingHandler(TestWebSocketHandler):
  611. def on_pong(self, data):
  612. self.write_message("got pong")
  613. return Application([("/", PingHandler)], websocket_ping_interval=0.01)
  614. @gen_test
  615. def test_server_ping(self):
  616. ws = yield self.ws_connect("/")
  617. for i in range(3):
  618. response = yield ws.read_message()
  619. self.assertEqual(response, "got pong")
  620. # TODO: test that the connection gets closed if ping responses stop.
  621. class ClientPeriodicPingTest(WebSocketBaseTestCase):
  622. def get_app(self):
  623. class PingHandler(TestWebSocketHandler):
  624. def on_ping(self, data):
  625. self.write_message("got ping")
  626. return Application([("/", PingHandler)])
  627. @gen_test
  628. def test_client_ping(self):
  629. ws = yield self.ws_connect("/", ping_interval=0.01)
  630. for i in range(3):
  631. response = yield ws.read_message()
  632. self.assertEqual(response, "got ping")
  633. # TODO: test that the connection gets closed if ping responses stop.
  634. class ManualPingTest(WebSocketBaseTestCase):
  635. def get_app(self):
  636. class PingHandler(TestWebSocketHandler):
  637. def on_ping(self, data):
  638. self.write_message(data, binary=isinstance(data, bytes))
  639. return Application([("/", PingHandler)])
  640. @gen_test
  641. def test_manual_ping(self):
  642. ws = yield self.ws_connect("/")
  643. self.assertRaises(ValueError, ws.ping, "a" * 126)
  644. ws.ping("hello")
  645. resp = yield ws.read_message()
  646. # on_ping always sees bytes.
  647. self.assertEqual(resp, b"hello")
  648. ws.ping(b"binary hello")
  649. resp = yield ws.read_message()
  650. self.assertEqual(resp, b"binary hello")
  651. class MaxMessageSizeTest(WebSocketBaseTestCase):
  652. def get_app(self):
  653. return Application([("/", EchoHandler)], websocket_max_message_size=1024)
  654. @gen_test
  655. def test_large_message(self):
  656. ws = yield self.ws_connect("/")
  657. # Write a message that is allowed.
  658. msg = "a" * 1024
  659. ws.write_message(msg)
  660. resp = yield ws.read_message()
  661. self.assertEqual(resp, msg)
  662. # Write a message that is too large.
  663. ws.write_message(msg + "b")
  664. resp = yield ws.read_message()
  665. # A message of None means the other side closed the connection.
  666. self.assertIs(resp, None)
  667. self.assertEqual(ws.close_code, 1009)
  668. self.assertEqual(ws.close_reason, "message too big")
  669. # TODO: Needs tests of messages split over multiple
  670. # continuation frames.