MatplotlibDraw.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453
  1. import os
  2. import matplotlib
  3. matplotlib.use('TkAgg')
  4. import matplotlib.pyplot as mpl
  5. import numpy as np
  6. class MatplotlibDraw:
  7. """
  8. Simple interface for plotting. This interface makes use of
  9. Matplotlib for plotting.
  10. Some attributes that must be controlled directly (no set_* method
  11. since these attributes are changed quite seldom).
  12. ========================== ============================================
  13. Attribute Description
  14. ========================== ============================================
  15. allow_screen_graphics False means that no plot is shown on
  16. the screen. (Does not work yet.)
  17. arrow_head_width Size of arrow head.
  18. ========================== ============================================
  19. """
  20. line_colors = {'red': 'r', 'green': 'g', 'blue': 'b', 'cyan': 'c',
  21. 'magenta': 'm', 'purple': 'p',
  22. 'yellow': 'y', 'black': 'k', 'white': 'w',
  23. 'brown': 'brown', '': ''}
  24. def __init__(self):
  25. self.instruction_file = None
  26. self.allow_screen_graphics = True # does not work yet
  27. def ok(self):
  28. """
  29. Return True if set_coordinate_system is called and
  30. objects can be drawn.
  31. """
  32. def adjust_coordinate_system(self, minmax, occupation_percent=80):
  33. """
  34. Given a dict of xmin, xmax, ymin, ymax values, and a desired
  35. filling of the plotting area of `occupation_percent` percent,
  36. set new axis limits.
  37. """
  38. x_range = minmax['xmax'] - minmax['xmin']
  39. y_range = minmax['ymax'] - minmax['ymin']
  40. new_x_range = x_range*100./occupation_percent
  41. x_space = new_x_range - x_range
  42. new_y_range = y_range*100./occupation_percent
  43. y_space = new_y_range - y_range
  44. self.ax.set_xlim(minmax['xmin']-x_space/2., minmax['xmax']+x_space/2.)
  45. self.ax.set_ylim(minmax['ymin']-y_space/2., minmax['ymax']+y_space/2.)
  46. def set_coordinate_system(self, xmin, xmax, ymin, ymax, axis=False,
  47. instruction_file=None):
  48. """
  49. Define the drawing area [xmin,xmax]x[ymin,ymax].
  50. axis: None or False means that axes with tickmarks
  51. are not drawn.
  52. instruction_file: name of file where all the instructions
  53. for the plotting program are stored (useful for debugging
  54. a figure or tailoring plots).
  55. """
  56. # Close file for previous figure and start new one
  57. # if not the figure file is the same
  58. if self.instruction_file is not None:
  59. if instruction_file == self.instruction_file.name:
  60. pass # continue with same file
  61. else:
  62. self.instruction_file.close() # make new py file for commands
  63. self.mpl = mpl
  64. self.xmin, self.xmax, self.ymin, self.ymax = \
  65. float(xmin), float(xmax), float(ymin), float(ymax)
  66. self.xrange = self.xmax - self.xmin
  67. self.yrange = self.ymax - self.ymin
  68. self.axis = axis
  69. # Compute the right X11 geometry on the screen based on the
  70. # x-y ratio of axis ranges
  71. ratio = (self.ymax-self.ymin)/(self.xmax-self.xmin)
  72. self.xsize = 800 # pixel size
  73. self.ysize = self.xsize*ratio
  74. geometry = '%dx%d' % (self.xsize, self.ysize)
  75. # See http://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
  76. if isinstance(instruction_file, str):
  77. self.instruction_file = open(instruction_file, 'w')
  78. else:
  79. self.instruction_file = None
  80. self.mpl.ion() # important for interactive drawing and animation
  81. if self.instruction_file:
  82. self.instruction_file.write("""\
  83. import matplotlib.pyplot as mpl
  84. mpl.ion() # for interactive drawing
  85. """)
  86. # Default properties
  87. self.set_linecolor('red')
  88. self.set_linewidth(2)
  89. self.set_linestyle('solid')
  90. self.set_filled_curves() # no filling
  91. self.set_fontsize(14)
  92. self.arrow_head_width = 0.2*self.xrange/16
  93. self._make_axes(new_figure=True)
  94. manager = self.mpl.get_current_fig_manager()
  95. manager.window.wm_geometry(geometry)
  96. def _make_axes(self, new_figure=False):
  97. if new_figure:
  98. self.fig = self.mpl.figure()
  99. self.ax = self.fig.gca()
  100. self.ax.set_xlim(self.xmin, self.xmax)
  101. self.ax.set_ylim(self.ymin, self.ymax)
  102. self.ax.set_aspect('equal') # extent of 1 unit is the same on the axes
  103. if not self.axis:
  104. self.mpl.axis('off')
  105. axis_cmd = "mpl.axis('off') # do not show axes with tickmarks\n"
  106. else:
  107. axis_cmd = ''
  108. if self.instruction_file:
  109. fig = 'fig = mpl.figure()\n' if new_figure else ''
  110. self.instruction_file.write("""\
  111. %s
  112. ax = fig.gca()
  113. xmin, xmax, ymin, ymax = %s, %s, %s, %s
  114. ax.set_xlim(xmin, xmax)
  115. ax.set_ylim(ymin, ymax)
  116. ax.set_aspect('equal')
  117. %s
  118. """ % (fig, self.xmin, self.xmax, self.ymin, self.ymax, axis_cmd))
  119. def inside(self, pt, exception=False):
  120. """Is point pt inside the defined plotting area?"""
  121. area = '[%s,%s]x[%s,%s]' % \
  122. (self.xmin, self.xmax, self.ymin, self.ymax)
  123. tol = 1E-14
  124. pt_inside = True
  125. if self.xmin - tol <= pt[0] <= self.xmax + tol:
  126. pass
  127. else:
  128. pt_inside = False
  129. if self.ymin - tol <= pt[1] <= self.ymax + tol:
  130. pass
  131. else:
  132. pt_inside = False
  133. if pt_inside:
  134. return pt_inside, 'point=%s is inside plotting area %s' % \
  135. (pt, area)
  136. else:
  137. msg = 'point=%s is outside plotting area %s' % (pt, area)
  138. if exception:
  139. raise ValueError(msg)
  140. return pt_inside, msg
  141. def set_linecolor(self, color):
  142. """
  143. Change the color of lines. Available colors are
  144. 'black', 'white', 'red', 'blue', 'green', 'yellow',
  145. 'magenta', 'cyan'.
  146. """
  147. self.linecolor = MatplotlibDraw.line_colors[color]
  148. def set_linestyle(self, style):
  149. """Change line style: 'solid', 'dashed', 'dashdot', 'dotted'."""
  150. if not style in ('solid', 'dashed', 'dashdot', 'dotted'):
  151. raise ValueError('Illegal line style: %s' % style)
  152. self.linestyle = style
  153. def set_linewidth(self, width):
  154. """Change the line width (int, starts at 1)."""
  155. self.linewidth = width
  156. def set_filled_curves(self, color='', pattern=''):
  157. """
  158. Fill area inside curves with specified color and/or pattern.
  159. A common pattern is '/' (45 degree lines). Other patterns
  160. include....
  161. """
  162. if color is False:
  163. self.fillcolor = ''
  164. self.fillpattern = ''
  165. else:
  166. self.fillcolor = color if len(color) == 1 else \
  167. MatplotlibDraw.line_colors[color]
  168. self.fillpattern = pattern
  169. def set_fontsize(self, fontsize=18):
  170. """
  171. Method for setting a common fontsize for text, unless
  172. individually specified when calling ``text``.
  173. """
  174. self.fontsize = fontsize
  175. def set_grid(self, on=False):
  176. self.mpl.grid(on)
  177. if self.instruction_file:
  178. self.instruction_file.write("\nmpl.grid(%s)\n" % str(on))
  179. def erase(self):
  180. """Erase the current figure."""
  181. self.mpl.delaxes()
  182. if self.instruction_file:
  183. self.instruction_file.write("\nmpl.delaxes() # erase\n")
  184. self._make_axes(new_figure=False)
  185. def plot_curve(self, x, y,
  186. linestyle=None, linewidth=None,
  187. linecolor=None, arrow=None,
  188. fillcolor=None, fillpattern=None):
  189. """Define a curve with coordinates x and y (arrays)."""
  190. #if not self.allow_screen_graphics:
  191. # mpl.ioff()
  192. #else:
  193. # mpl.ion()
  194. self.xdata = np.asarray(x, dtype=np.float)
  195. self.ydata = np.asarray(y, dtype=np.float)
  196. if linestyle is None:
  197. # use "global" linestyle
  198. linestyle = self.linestyle
  199. if linecolor is None:
  200. linecolor = self.linecolor
  201. if linewidth is None:
  202. linewidth = self.linewidth
  203. if fillcolor is None:
  204. fillcolor = self.fillcolor
  205. if fillpattern is None:
  206. fillpattern = self.fillpattern
  207. if self.instruction_file:
  208. import pprint
  209. self.instruction_file.write('x = %s\n' % \
  210. pprint.pformat(self.xdata.tolist()))
  211. self.instruction_file.write('y = %s\n' % \
  212. pprint.pformat(self.ydata.tolist()))
  213. if fillcolor or fillpattern:
  214. if fillpattern != '':
  215. fillcolor = 'white'
  216. #print '%d coords, fillcolor="%s" linecolor="%s" fillpattern="%s"' % (x.size, fillcolor, linecolor, fillpattern)
  217. self.ax.fill(x, y, fillcolor, edgecolor=linecolor,
  218. linewidth=linewidth, hatch=fillpattern)
  219. if self.instruction_file:
  220. self.instruction_file.write("ax.fill(x, y, '%s', edgecolor='%s', linewidth=%d, hatch='%s')\n" % (fillcolor, linecolor, linewidth, fillpattern))
  221. else:
  222. self.ax.plot(x, y, linecolor, linewidth=linewidth,
  223. linestyle=linestyle)
  224. if self.instruction_file:
  225. self.instruction_file.write("ax.plot(x, y, '%s', linewidth=%d, linestyle='%s')\n" % (linecolor, linewidth, linestyle))
  226. if arrow:
  227. if not arrow in ('->', '<-', '<->'):
  228. raise ValueError("arrow argument must be '->', '<-', or '<->', not %s" % repr(arrow))
  229. # Add arrow to first and/or last segment
  230. start = arrow == '<-' or arrow == '<->'
  231. end = arrow == '->' or arrow == '<->'
  232. if start:
  233. x_s, y_s = x[1], y[1]
  234. dx_s, dy_s = x[0]-x[1], y[0]-y[1]
  235. self.plot_arrow(x_s, y_s, dx_s, dy_s, '->',
  236. linestyle, linewidth, linecolor)
  237. if end:
  238. x_e, y_e = x[-2], y[-2]
  239. dx_e, dy_e = x[-1]-x[-2], y[-1]-y[-2]
  240. self.plot_arrow(x_e, y_e, dx_e, dy_e, '->',
  241. linestyle, linewidth, linecolor)
  242. def display(self, title=None):
  243. """Display the figure. Last possible command."""
  244. if title is not None:
  245. self.mpl.title(title)
  246. if self.instruction_file:
  247. self.instruction_file.write('mpl.title("%s")\n' % title)
  248. self.mpl.draw()
  249. if self.instruction_file:
  250. self.instruction_file.write('mpl.draw()\n')
  251. def savefig(self, filename):
  252. """Save figure in file."""
  253. self.mpl.savefig(filename)
  254. if self.instruction_file:
  255. self.instruction_file.write('mpl.savefig("%s")\n' % filename)
  256. def text(self, text, position, alignment='center', fontsize=0,
  257. arrow_tip=None):
  258. """
  259. Write `text` string at a position (centered, left, right - according
  260. to the `alignment` string). `position` is a point in the coordinate
  261. system.
  262. If ``arrow+tip != None``, an arrow is drawn from the text to a point
  263. (on a curve, for instance). The arrow_tip argument is then
  264. the (x,y) coordinates for the arrow tip.
  265. fontsize=0 indicates use of the default font as set by
  266. ``set_fontsize``.
  267. """
  268. if fontsize == 0:
  269. if hasattr(self, 'fontsize'):
  270. fontsize = self.fontsize
  271. else:
  272. raise AttributeError(
  273. 'No self.fontsize attribute to be used when text(...)\n'
  274. 'is called with fontsize=0. Call set_fontsize method.')
  275. x, y = position
  276. if arrow_tip is None:
  277. self.ax.text(x, y, text, horizontalalignment=alignment,
  278. fontsize=fontsize)
  279. if self.instruction_file:
  280. self.instruction_file.write("""\
  281. ax.text(%g, %g, %s,
  282. horizontalalignment=%s, fontsize=%d)
  283. """ % (x, y, repr(text), repr(alignment), fontsize))
  284. else:
  285. if not len(arrow_tip) == 2:
  286. raise ValueError('arrow_tip=%s must be (x,y) pt.' % arrow)
  287. pt = arrow_tip
  288. self.ax.annotate(text, xy=pt, xycoords='data',
  289. textcoords='data', xytext=position,
  290. horizontalalignment=alignment,
  291. verticalalignment='top',
  292. fontsize=fontsize,
  293. arrowprops=dict(arrowstyle='->',
  294. facecolor='black',
  295. #linewidth=2,
  296. linewidth=1,
  297. shrinkA=5,
  298. shrinkB=5))
  299. if self.instruction_file:
  300. self.instruction_file.write("""\
  301. ax.annotate('%s', xy=%s, xycoords='data',
  302. textcoords='data', xytext=%s,
  303. horizontalalignment='%s',
  304. verticalalignment='top',
  305. fontsize=%d,
  306. arrowprops=dict(arrowstyle='->',
  307. facecolor='black',
  308. linewidth=2,
  309. shrinkA=5,
  310. shrinkB=5))
  311. """ % (text, pt, position, alignment, fontsize))
  312. # Drawing annotations with arrows:
  313. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  314. #http://matplotlib.sourceforge.net/mpl_examples/pylab_examples/annotation_demo2.py
  315. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  316. #http://matplotlib.sourceforge.net/users/annotations_guide.html#plotting-guide-annotation
  317. def plot_arrow(self, x, y, dx, dy, style='->',
  318. linestyle=None, linewidth=None, linecolor=None):
  319. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  320. if linestyle is None:
  321. # use "global" linestyle
  322. linestyle = self.linestyle
  323. if linecolor is None:
  324. linecolor = self.linecolor
  325. if linewidth is None:
  326. linewidth = self.linewidth
  327. if style == '->' or style == '<->':
  328. self.mpl.arrow(x, y, dx, dy, hold=True,
  329. facecolor=linecolor,
  330. edgecolor=linecolor,
  331. linewidth=linewidth,
  332. head_width=self.arrow_head_width,
  333. #head_width=0.1,
  334. #width=1, # width of arrow body in coordinate scale
  335. length_includes_head=True,
  336. shape='full')
  337. if self.instruction_file:
  338. self.instruction_file.write("""\
  339. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  340. facecolor='%s', edgecolor='%s',
  341. linewidth=%g, head_width=0.1,
  342. length_includes_head=True,
  343. shape='full')
  344. """ % (x, y, dx, dy, linecolor, linecolor, linewidth))
  345. if style == '<-' or style == '<->':
  346. self.mpl.arrow(x+dx, y+dy, -dx, -dy, hold=True,
  347. facecolor=linecolor,
  348. edgecolor=linecolor,
  349. linewidth=linewidth,
  350. head_width=0.1,
  351. #width=1,
  352. length_includes_head=True,
  353. shape='full')
  354. if self.instruction_file:
  355. self.instruction_file.write("""\
  356. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  357. facecolor='%s', edgecolor='%s',
  358. linewidth=%g, head_width=0.1,
  359. length_includes_head=True,
  360. shape='full')
  361. """ % (x+dx, y+dy, -dx, -dy, linecolor, linecolor, linewidth))
  362. def arrow2(self, x, y, dx, dy, style='->'):
  363. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  364. self.ax.annotate('', xy=(x+dx,y+dy), xytext=(x,y),
  365. arrowprops=dict(arrowstyle=style,
  366. facecolor='black',
  367. linewidth=1,
  368. shrinkA=0,
  369. shrinkB=0))
  370. if self.instruction_file:
  371. self.instruction_file.write("""
  372. ax.annotate('', xy=(%s,%s), xytext=(%s,%s),
  373. arrowprops=dict(arrowstyle=%s,
  374. facecolor='black',
  375. linewidth=1,
  376. shrinkA=0,
  377. shrinkB=0))
  378. """ % (x+dx, y+dy, x, y, style))
  379. def _test():
  380. d = MatplotlibDraw(0, 10, 0, 5, instruction_file='tmp3.py', axis=True)
  381. d.set_linecolor('magenta')
  382. d.set_linewidth(6)
  383. # triangle
  384. x = np.array([1, 4, 1, 1]); y = np.array([1, 1, 4, 1])
  385. d.set_filled_curves('magenta')
  386. d.plot_curve(x, y)
  387. d.set_filled_curves(False)
  388. d.plot_curve(x+4, y)
  389. d.text('some text1', position=(8,4), arrow_tip=(6, 1), alignment='left',
  390. fontsize=18)
  391. pos = np.array((7,4.5)) # numpy points work fine
  392. d.text('some text2', position=pos, arrow_tip=(6, 1), alignment='center',
  393. fontsize=12)
  394. d.set_linewidth(2)
  395. d.arrow(0.25, 0.25, 0.45, 0.45)
  396. d.arrow(0.25, 0.25, 0.25, 4, style='<->')
  397. d.arrow2(4.5, 0, 0, 3, style='<->')
  398. x = np.linspace(0, 9, 201)
  399. y = 4.5 + 0.45*np.cos(0.5*np.pi*x)
  400. d.plot_curve(x, y, arrow='end')
  401. d.display()
  402. raw_input()
  403. if __name__ == '__main__':
  404. _test()