# queues_test.py (13 KB)
  1. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  2. # not use this file except in compliance with the License. You may obtain
  3. # a copy of the License at
  4. #
  5. # http://www.apache.org/licenses/LICENSE-2.0
  6. #
  7. # Unless required by applicable law or agreed to in writing, software
  8. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  9. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  10. # License for the specific language governing permissions and limitations
  11. # under the License.
  12. import asyncio
  13. from datetime import timedelta
  14. from random import random
  15. import unittest
  16. from tornado import gen, queues
  17. from tornado.gen import TimeoutError
  18. from tornado.testing import gen_test, AsyncTestCase
  class QueueBasicTest(AsyncTestCase):
      """Sanity checks for string forms, FIFO ordering and ``maxsize`` handling."""

      def test_repr_and_str(self):
          # repr() must embed the object's id, str() must not; both forms
          # report the waiting getters/putters and the task count.
          queue = queues.Queue(maxsize=1)  # type: queues.Queue[None]
          address = hex(id(queue))
          self.assertIn(address, repr(queue))
          self.assertNotIn(address, str(queue))

          # Nothing has been put yet, so this get() is left waiting.
          queue.get()

          for text in (repr(queue), str(queue)):
              self.assertTrue(text.startswith("<Queue"))
              self.assertIn("maxsize=1", text)
              self.assertIn("getters[1]", text)
              self.assertNotIn("putters", text)
              self.assertNotIn("tasks", text)

          queue.put(None)
          queue.put(None)
          # Now the queue is full, this putter blocks.
          queue.put(None)

          for text in (repr(queue), str(queue)):
              self.assertNotIn("getters", text)
              self.assertIn("putters[1]", text)
              self.assertIn("tasks=2", text)

      def test_order(self):
          # Items come back in the order they were inserted.
          queue = queues.Queue()  # type: queues.Queue[int]
          for value in (1, 3, 2):
              queue.put_nowait(value)

          drained = []
          for _ in range(3):
              drained.append(queue.get_nowait())
          self.assertEqual([1, 3, 2], drained)

      @gen_test
      def test_maxsize(self):
          # The constructor rejects maxsize=None and a negative maxsize.
          self.assertRaises(TypeError, queues.Queue, maxsize=None)
          self.assertRaises(ValueError, queues.Queue, maxsize=-1)

          queue = queues.Queue(maxsize=2)  # type: queues.Queue[int]
          self.assertTrue(queue.empty())
          self.assertFalse(queue.full())
          self.assertEqual(2, queue.maxsize)

          # Two puts fit, so the returned futures are already resolved.
          first_put = queue.put(0)
          self.assertTrue(first_put.done())
          second_put = queue.put(1)
          self.assertTrue(second_put.done())
          self.assertFalse(queue.empty())
          self.assertTrue(queue.full())

          # The third put stays pending until a get() frees a slot.
          blocked_put = queue.put(2)
          self.assertFalse(blocked_put.done())
          head = yield queue.get()  # Make room.
          self.assertEqual(0, head)
          self.assertTrue(blocked_put.done())
          self.assertFalse(queue.empty())
          self.assertTrue(queue.full())
  class QueueGetTest(AsyncTestCase):
      """Covers ``get()`` / ``get_nowait()``: blocking, timeouts and waiter cleanup."""

      @gen_test
      def test_blocking_get(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          queue.put_nowait(0)
          item = yield queue.get()
          self.assertEqual(0, item)

      def test_nonblocking_get(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          queue.put_nowait(0)
          self.assertEqual(0, queue.get_nowait())

      def test_nonblocking_get_exception(self):
          # get_nowait() on an empty queue must raise QueueEmpty.
          queue = queues.Queue()  # type: queues.Queue[int]
          with self.assertRaises(queues.QueueEmpty):
              queue.get_nowait()

      @gen_test
      def test_get_with_putters(self):
          # A get() on a full queue lets the pending put() complete.
          queue = queues.Queue(1)  # type: queues.Queue[int]
          queue.put_nowait(0)
          pending_put = queue.put(1)
          item = yield queue.get()
          self.assertEqual(0, item)
          put_result = yield pending_put
          self.assertIsNone(put_result)

      @gen_test
      def test_blocking_get_wait(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          queue.put(0)
          # Later items arrive from the IOLoop while this coroutine waits.
          self.io_loop.call_later(0.01, queue.put, 1)
          self.io_loop.call_later(0.02, queue.put, 2)
          first = yield queue.get(timeout=timedelta(seconds=1))
          self.assertEqual(0, first)
          second = yield queue.get(timeout=timedelta(seconds=1))
          self.assertEqual(1, second)

      @gen_test
      def test_get_timeout(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          timed_get = queue.get(timeout=timedelta(seconds=0.01))
          plain_get = queue.get()
          with self.assertRaises(TimeoutError):
              yield timed_get

          # The getter without a timeout is still served by a later put.
          queue.put_nowait(0)
          item = yield plain_get
          self.assertEqual(0, item)

      @gen_test
      def test_get_timeout_preempted(self):
          # An item arriving before the deadline wins over the timeout.
          queue = queues.Queue()  # type: queues.Queue[int]
          timed_get = queue.get(timeout=timedelta(seconds=0.01))
          queue.put(0)
          yield gen.sleep(0.02)
          item = yield timed_get
          self.assertEqual(0, item)

      @gen_test
      def test_get_clears_timed_out_putters(self):
          queue = queues.Queue(1)  # type: queues.Queue[int]
          # First putter succeeds, remainder block.
          putters = []
          for value in range(10):
              putters.append(queue.put(value, timedelta(seconds=0.01)))
          last_put = queue.put(10)
          self.assertEqual(10, len(queue._putters))
          yield gen.sleep(0.02)
          # Timed-out putters are still listed until something sweeps them.
          self.assertEqual(10, len(queue._putters))
          self.assertFalse(last_put.done())  # Final waiter is still active.
          queue.put(11)
          item = yield queue.get()  # get() clears the waiters.
          self.assertEqual(0, item)
          self.assertEqual(1, len(queue._putters))
          for timed_out in putters[1:]:
              with self.assertRaises(TimeoutError):
                  timed_out.result()

      @gen_test
      def test_get_clears_timed_out_getters(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          getters = []
          for _ in range(10):
              getters.append(asyncio.ensure_future(queue.get(timedelta(seconds=0.01))))
          last_get = asyncio.ensure_future(queue.get())
          self.assertEqual(11, len(queue._getters))
          yield gen.sleep(0.02)
          # Timed-out getters are still listed until something sweeps them.
          self.assertEqual(11, len(queue._getters))
          self.assertFalse(last_get.done())  # Final waiter is still active.
          queue.get()  # get() clears the waiters.
          self.assertEqual(2, len(queue._getters))
          for timed_out in getters:
              with self.assertRaises(TimeoutError):
                  timed_out.result()

      @gen_test
      def test_async_for(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          for value in range(5):
              queue.put(value)

          async def collect():
              # Stop explicitly after the last expected item.
              seen = []
              async for item in queue:
                  seen.append(item)
                  if item == 4:
                      return seen

          results = yield collect()
          self.assertEqual(results, list(range(5)))
  class QueuePutTest(AsyncTestCase):
      """Covers ``put()`` / ``put_nowait()``: blocking, timeouts and waiter cleanup."""

      @gen_test
      def test_blocking_put(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          queue.put(0)
          self.assertEqual(0, queue.get_nowait())

      def test_nonblocking_put_exception(self):
          # put_nowait() on a full queue must raise QueueFull.
          queue = queues.Queue(1)  # type: queues.Queue[int]
          queue.put(0)
          with self.assertRaises(queues.QueueFull):
              queue.put_nowait(1)

      @gen_test
      def test_put_with_getters(self):
          # Waiting getters receive items in the order they asked.
          queue = queues.Queue()  # type: queues.Queue[int]
          first_get = queue.get()
          second_get = queue.get()
          yield queue.put(0)
          first = yield first_get
          self.assertEqual(0, first)
          yield queue.put(1)
          second = yield second_get
          self.assertEqual(1, second)

      @gen_test
      def test_nonblocking_put_with_getters(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          first_get = queue.get()
          second_get = queue.get()
          queue.put_nowait(0)
          # put_nowait does *not* immediately unblock getters.
          yield gen.moment
          first = yield first_get
          self.assertEqual(0, first)
          queue.put_nowait(1)
          yield gen.moment
          second = yield second_get
          self.assertEqual(1, second)

      @gen_test
      def test_blocking_put_wait(self):
          queue = queues.Queue(1)  # type: queues.Queue[int]
          queue.put_nowait(0)
          # Scheduled get() calls free the slots the pending puts wait for.
          self.io_loop.call_later(0.01, queue.get)
          self.io_loop.call_later(0.02, queue.get)
          pending = [queue.put(0), queue.put(1)]
          self.assertFalse(any(fut.done() for fut in pending))
          yield pending

      @gen_test
      def test_put_timeout(self):
          queue = queues.Queue(1)  # type: queues.Queue[int]
          queue.put_nowait(0)  # Now it's full.
          timed_put = queue.put(1, timeout=timedelta(seconds=0.01))
          plain_put = queue.put(2)
          with self.assertRaises(TimeoutError):
              yield timed_put

          self.assertEqual(0, queue.get_nowait())
          # 1 was never put in the queue.
          item = yield queue.get()
          self.assertEqual(2, item)

          # Final get() unblocked this putter.
          yield plain_put

      @gen_test
      def test_put_timeout_preempted(self):
          queue = queues.Queue(1)  # type: queues.Queue[int]
          queue.put_nowait(0)
          timed_put = queue.put(1, timeout=timedelta(seconds=0.01))
          queue.get()
          yield gen.sleep(0.02)
          yield timed_put  # No TimeoutError.

      @gen_test
      def test_put_clears_timed_out_putters(self):
          queue = queues.Queue(1)  # type: queues.Queue[int]
          # First putter succeeds, remainder block.
          putters = []
          for value in range(10):
              putters.append(queue.put(value, timedelta(seconds=0.01)))
          last_put = queue.put(10)
          self.assertEqual(10, len(queue._putters))
          yield gen.sleep(0.02)
          # Timed-out putters are still listed until something sweeps them.
          self.assertEqual(10, len(queue._putters))
          self.assertFalse(last_put.done())  # Final waiter is still active.
          queue.put(11)  # put() clears the waiters.
          self.assertEqual(2, len(queue._putters))
          for timed_out in putters[1:]:
              with self.assertRaises(TimeoutError):
                  timed_out.result()

      @gen_test
      def test_put_clears_timed_out_getters(self):
          queue = queues.Queue()  # type: queues.Queue[int]
          getters = []
          for _ in range(10):
              getters.append(asyncio.ensure_future(queue.get(timedelta(seconds=0.01))))
          last_get = asyncio.ensure_future(queue.get())
          queue.get()
          self.assertEqual(12, len(queue._getters))
          yield gen.sleep(0.02)
          # Timed-out getters are still listed until something sweeps them.
          self.assertEqual(12, len(queue._getters))
          self.assertFalse(last_get.done())  # Final waiters still active.
          queue.put(0)  # put() clears the waiters.
          self.assertEqual(1, len(queue._getters))
          item = yield last_get
          self.assertEqual(0, item)
          for timed_out in getters:
              with self.assertRaises(TimeoutError):
                  timed_out.result()

      @gen_test
      def test_float_maxsize(self):
          # If a float is passed for maxsize, a reasonable limit should
          # be enforced, instead of being treated as unlimited.
          # It happens to be rounded up.
          # http://bugs.python.org/issue21723
          queue = queues.Queue(maxsize=1.3)  # type: ignore
          self.assertTrue(queue.empty())
          self.assertFalse(queue.full())

          queue.put_nowait(0)
          queue.put_nowait(1)
          self.assertFalse(queue.empty())
          self.assertTrue(queue.full())
          with self.assertRaises(queues.QueueFull):
              queue.put_nowait(2)

          self.assertEqual(0, queue.get_nowait())
          self.assertFalse(queue.empty())
          self.assertFalse(queue.full())

          yield queue.put(2)
          blocked_put = queue.put(3)
          self.assertFalse(blocked_put.done())
          item = yield queue.get()
          self.assertEqual(1, item)
          yield blocked_put
          self.assertTrue(queue.full())
  class QueueJoinTest(AsyncTestCase):
      """Covers ``join()`` / ``task_done()``; subclasses swap in ``queue_class``."""

      # Queue implementation under test; overridden by subclasses.
      queue_class = queues.Queue

      def test_task_done_underflow(self):
          # task_done() with no outstanding task must raise ValueError.
          queue = self.queue_class()
          with self.assertRaises(ValueError):
              queue.task_done()

      @gen_test
      def test_task_done(self):
          queue = self.queue_class()
          for value in range(100):
              queue.put_nowait(value)

          self.accumulator = 0

          @gen.coroutine
          def consume():
              # Runs forever; the test ends once join() resolves.
              while True:
                  value = yield queue.get()
                  self.accumulator += value
                  queue.task_done()
                  yield gen.sleep(random() * 0.01)

          # Two coroutines share work.
          consume()
          consume()
          yield queue.join()
          self.assertEqual(sum(range(100)), self.accumulator)

      @gen_test
      def test_task_done_delay(self):
          # Verify it is task_done(), not get(), that unblocks join().
          queue = self.queue_class()
          queue.put_nowait(0)
          joined = queue.join()
          self.assertFalse(joined.done())
          yield queue.get()
          self.assertFalse(joined.done())
          yield gen.moment
          self.assertFalse(joined.done())
          queue.task_done()
          self.assertTrue(joined.done())

      @gen_test
      def test_join_empty_queue(self):
          # join() on a queue with no pending tasks resolves, repeatedly.
          queue = self.queue_class()
          yield queue.join()
          yield queue.join()

      @gen_test
      def test_join_timeout(self):
          # With an unfinished task, join() with a timeout raises TimeoutError.
          queue = self.queue_class()
          queue.put(0)
          with self.assertRaises(TimeoutError):
              yield queue.join(timeout=timedelta(seconds=0.01))
  class PriorityQueueJoinTest(QueueJoinTest):
      """Runs the join tests against ``PriorityQueue`` and checks its ordering."""

      queue_class = queues.PriorityQueue

      @gen_test
      def test_order(self):
          # Entries are retrieved lowest first, including the two puts
          # that were made while the queue was already full.
          queue = self.queue_class(maxsize=2)
          queue.put_nowait((1, "a"))
          queue.put_nowait((0, "b"))
          self.assertTrue(queue.full())
          queue.put((3, "c"))
          queue.put((2, "d"))

          self.assertEqual((0, "b"), queue.get_nowait())
          second = yield queue.get()
          self.assertEqual((1, "a"), second)
          self.assertEqual((2, "d"), queue.get_nowait())
          fourth = yield queue.get()
          self.assertEqual((3, "c"), fourth)
          self.assertTrue(queue.empty())
  class LifoQueueJoinTest(QueueJoinTest):
      """Runs the join tests against ``LifoQueue`` and checks its ordering."""

      queue_class = queues.LifoQueue

      @gen_test
      def test_order(self):
          # The two puts made while the queue was full come out first
          # (3, then 2), followed by the original contents newest-first.
          queue = self.queue_class(maxsize=2)
          queue.put_nowait(1)
          queue.put_nowait(0)
          self.assertTrue(queue.full())
          queue.put(3)
          queue.put(2)

          self.assertEqual(3, queue.get_nowait())
          second = yield queue.get()
          self.assertEqual(2, second)
          self.assertEqual(0, queue.get_nowait())
          fourth = yield queue.get()
          self.assertEqual(1, fourth)
          self.assertTrue(queue.empty())
  class ProducerConsumerTest(AsyncTestCase):
      """End-to-end check of one producer and one consumer on a bounded queue."""

      @gen_test
      def test_producer_consumer(self):
          queue = queues.Queue(maxsize=3)  # type: queues.Queue[int]
          consumed = []

          # We don't yield between get() and task_done(), so get() must wait for
          # the next tick. Otherwise we'd immediately call task_done and unblock
          # join() before q.put() resumes, and we'd only process the first four
          # items.
          @gen.coroutine
          def consume_forever():
              while True:
                  item = yield queue.get()
                  consumed.append(item)
                  queue.task_done()

          @gen.coroutine
          def produce_all():
              for value in range(10):
                  yield queue.put(value)

          consume_forever()
          yield produce_all()
          yield queue.join()
          self.assertEqual(list(range(10)), consumed)
  # Run the whole suite when this module is executed as a script.
  if __name__ == "__main__":
      unittest.main()