MatplotlibDraw.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. import os
  2. import matplotlib
  3. matplotlib.use('TkAgg')
  4. import matplotlib.pyplot as mpl
  5. import numpy as np
  6. class MatplotlibDraw:
  7. line_colors = {'red': 'r', 'green': 'g', 'blue': 'b', 'cyan': 'c',
  8. 'magenta': 'm', 'purple': 'p',
  9. 'yellow': 'y', 'black': 'k', 'white': 'w', '': ''}
  10. def __init__(self):
  11. self.instruction_file = None
  12. def set_instruction_file(self, filename='tmp_mpl.py'):
  13. """
  14. instruction_file: name of file where all the instructions
  15. are recorded.
  16. """
  17. self.instruction_file = filename
  18. def set_coordinate_system(self, xmin, xmax, ymin, ymax, axis=False):
  19. """
  20. Define the drawing area [xmin,xmax]x[ymin,ymax].
  21. axis: None or False means that axes with tickmarks
  22. are not drawn.
  23. """
  24. self.mpl = mpl
  25. self.xmin, self.xmax, self.ymin, self.ymax = \
  26. float(xmin), float(xmax), float(ymin), float(ymax)
  27. self.xrange = self.xmax - self.xmin
  28. self.yrange = self.ymax - self.ymin
  29. self.axis = axis
  30. if self.instruction_file:
  31. self.instruction_file = open(self.instruction_file, 'w')
  32. else:
  33. self.instruction_file = None
  34. # Compute the right X11 geometry on the screen based on the
  35. # x-y ratio of axis ranges
  36. ratio = (self.ymax-self.ymin)/(self.xmax-self.xmin)
  37. self.xsize = 800 # pixel size
  38. self.ysize = self.xsize*ratio
  39. geometry = '%dx%d' % (self.xsize, self.ysize)
  40. # See http://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
  41. self.mpl.ion() # important for interactive drawing and animation
  42. if self.instruction_file is not None:
  43. self.instruction_file.write("""\
  44. import matplotlib.pyplot as mpl
  45. mpl.ion() # for interactive drawing
  46. """)
  47. self._make_axes(new_figure=True)
  48. manager = self.mpl.get_current_fig_manager()
  49. manager.window.wm_geometry(geometry)
  50. self.set_linecolor('red')
  51. self.set_linewidth(2)
  52. self.set_linestyle('solid')
  53. self.set_filled_curves()
  54. def _make_axes(self, new_figure=False):
  55. if new_figure:
  56. self.fig = self.mpl.figure()
  57. self.ax = self.fig.gca()
  58. self.ax.set_xlim(self.xmin, self.xmax)
  59. self.ax.set_ylim(self.ymin, self.ymax)
  60. self.ax.set_aspect('equal') # extent of 1 unit is the same on the axes
  61. if not self.axis:
  62. self.mpl.axis('off')
  63. axis_cmd = "mpl.axis('off') # do not show axes with tickmarks\n"
  64. else:
  65. axis_cmd = ''
  66. if self.instruction_file is not None:
  67. fig = 'fig = mpl.figure()\n' if new_figure else ''
  68. self.instruction_file.write("""\
  69. %s
  70. ax = fig.gca()
  71. xmin, xmax, ymin, ymax = %s, %s, %s, %s
  72. ax.set_xlim(xmin, xmax)
  73. ax.set_ylim(ymin, ymax)
  74. ax.set_aspect('equal')
  75. %s
  76. """ % (fig, self.xmin, self.xmax, self.ymin, self.ymax, axis_cmd))
  77. def set_linecolor(self, color):
  78. """Change the color of lines."""
  79. self.linecolor = MatplotlibDraw.line_colors[color]
  80. def set_linestyle(self, style):
  81. """Change line style: 'solid', 'dashed', 'dashdot', 'dotted'."""
  82. if not style in ('solid', 'dashed', 'dashdot', 'dotted'):
  83. raise ValueError('Illegal line style: %s' % style)
  84. self.linestyle = style
  85. def set_linewidth(self, width):
  86. """Change the line width (int, starts at 1)."""
  87. self.linewidth = width
  88. def set_filled_curves(self, color='', hatch=''):
  89. """Fill area inside curves with current line color."""
  90. if color is False:
  91. self.fillcolor = ''
  92. self.fillhatch = ''
  93. else:
  94. self.fillcolor = color if len(color) == 1 else \
  95. MatplotlibDraw.line_colors[color]
  96. self.fillhatch = hatch
  97. def set_grid(self, on=False):
  98. self.mpl.grid(on)
  99. if self.instruction_file is not None:
  100. self.instruction_file.write("\nmpl.grid(%s)\n" % str(on))
  101. def erase(self):
  102. """Erase the current figure."""
  103. self.mpl.delaxes()
  104. if self.instruction_file is not None:
  105. self.instruction_file.write("\nmpl.delaxes() # erase\n")
  106. self._make_axes(new_figure=False)
  107. def define_curve(self, x, y,
  108. linestyle=None, linewidth=None,
  109. linecolor=None, arrow=None,
  110. fillcolor=None, fillhatch=None):
  111. """Define a curve with coordinates x and y (arrays)."""
  112. self.xdata = np.asarray(x, dtype=np.float)
  113. self.ydata = np.asarray(y, dtype=np.float)
  114. if linestyle is None:
  115. # use "global" linestyle
  116. linestyle = self.linestyle
  117. if linecolor is None:
  118. linecolor = self.linecolor
  119. if linewidth is None:
  120. linewidth = self.linewidth
  121. if fillcolor is None:
  122. fillcolor = self.fillcolor
  123. if fillhatch is None:
  124. fillhatch = self.fillhatch
  125. if self.instruction_file is not None:
  126. import pprint
  127. self.instruction_file.write('x = %s\n' % \
  128. pprint.pformat(self.xdata.tolist()))
  129. self.instruction_file.write('y = %s\n' % \
  130. pprint.pformat(self.ydata.tolist()))
  131. if fillcolor or fillhatch:
  132. self.ax.fill(x, y, fillcolor, edgecolor=linecolor,
  133. hatch=fillhatch)
  134. if self.instruction_file is not None:
  135. self.instruction_file.write("ax.fill(x, y, '%s', edgecolor='%s', hatch='%s')\n" % (linecolor, fillcolor, fillhatch))
  136. else:
  137. self.ax.plot(x, y, linecolor, linewidth=linewidth,
  138. linestyle=linestyle)
  139. if self.instruction_file is not None:
  140. self.instruction_file.write("ax.plot(x, y, '%s', linewidth=%d, linestyle='%s')\n" % (linecolor, linewidth, linestyle))
  141. if arrow:
  142. if not arrow in ('start', 'end', 'both'):
  143. raise ValueError("arrow argument must be 'start', 'end', or 'both', not %s" % repr(arrow))
  144. # Add arrow to first and/or last segment
  145. start = arrow == 'start' or arrow == 'both'
  146. end = arrow == 'end' or arrow == 'both'
  147. if start:
  148. x_s, y_s = x[1], y[1]
  149. dx_s, dy_s = x[0]-x[1], y[0]-y[1]
  150. self.arrow(x_s, y_s, dx_s, dy_s)
  151. if end:
  152. x_e, y_e = x[-2], y[-2]
  153. dx_e, dy_e = x[-1]-x[-2], y[-1]-y[-2]
  154. self.arrow(x_e, y_e, dx_e, dy_e)
  155. def display(self):
  156. """Display the figure. Last possible command."""
  157. self.mpl.draw()
  158. if self.instruction_file is not None:
  159. self.instruction_file.write('mpl.draw()\n')
  160. def savefig(self, filename):
  161. """Save figure in file."""
  162. self.mpl.savefig(filename)
  163. if self.instruction_file is not None:
  164. self.instruction_file.write('mpl.savefig(%s)\n' % filename)
  165. def text(self, text, position, alignment='center', fontsize=18,
  166. arrow_tip=None):
  167. """
  168. Write text at a position (centered, left, right - according
  169. to the alignment string). position is a 2-tuple.
  170. arrow+tip != None draws an arrow from the text to a point
  171. (on a curve, for instance). The arrow_tip argument is then
  172. the (x,y) coordinates for the arrow tip.
  173. """
  174. x, y = position
  175. if arrow_tip is None:
  176. self.ax.text(x, y, text, horizontalalignment=alignment,
  177. fontsize=fontsize)
  178. if self.instruction_file is not None:
  179. self.instruction_file.write("""\
  180. ax.text(%g, %g, %s,
  181. horizontalalignment=%s, fontsize=%d)
  182. """ % (x, y, repr(text), repr(alignment), fontsize))
  183. else:
  184. if not len(arrow_tip) == 2:
  185. raise ValueError('arrow_tip=%s must be (x,y) pt.' % arrow)
  186. pt = arrow_tip
  187. self.ax.annotate(text, xy=pt, xycoords='data',
  188. textcoords='data', xytext=position,
  189. horizontalalignment=alignment,
  190. verticalalignment='top',
  191. fontsize=fontsize,
  192. arrowprops=dict(arrowstyle='->',
  193. facecolor='black',
  194. #linewidth=2,
  195. linewidth=1,
  196. shrinkA=5,
  197. shrinkB=5))
  198. if self.instruction_file is not None:
  199. self.instruction_file.write("""\
  200. ax.annotate('%s', xy=%s, xycoords='data',
  201. textcoords='data', xytext=%s,
  202. horizontalalignment='%s',
  203. verticalalignment='top',
  204. fontsize=%d,
  205. arrowprops=dict(arrowstyle='->',
  206. facecolor='black',
  207. linewidth=2,
  208. shrinkA=5,
  209. shrinkB=5))
  210. """ % (text, pt, position, alignment, fontsize))
  211. # Drawing annotations with arrows:
  212. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  213. #http://matplotlib.sourceforge.net/mpl_examples/pylab_examples/annotation_demo2.py
  214. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  215. #http://matplotlib.sourceforge.net/users/annotations_guide.html#plotting-guide-annotation
  216. def arrow(self, x, y, dx, dy, style='->',
  217. linestyle=None, linewidth=None, linecolor=None):
  218. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  219. if linestyle is None:
  220. # use "global" linestyle
  221. linestyle = self.linestyle
  222. if linecolor is None:
  223. linecolor = self.linecolor
  224. if linewidth is None:
  225. linewidth = self.linewidth
  226. if style == '->' or style == '<->':
  227. self.mpl.arrow(x, y, dx, dy, hold=True,
  228. facecolor=linecolor,
  229. edgecolor=linecolor,
  230. linewidth=linewidth,
  231. head_width=0.1,
  232. #width=1,
  233. length_includes_head=True,
  234. shape='full')
  235. if self.instruction_file is not None:
  236. self.instruction_file.write("""\
  237. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  238. facecolor='%s', edgecolor='%s',
  239. linewidth=%g, head_width=0.1,
  240. length_includes_head=True,
  241. shape='full')
  242. """ % (x, y, dx, dy, linecolor, linecolor, linewidth))
  243. if style == '<-' or style == '<->':
  244. self.mpl.arrow(x+dx, y+dy, -dx, -dy, hold=True,
  245. facecolor=linecolor,
  246. edgecolor=linecolor,
  247. linewidth=linewidth,
  248. head_width=0.1,
  249. #width=1,
  250. length_includes_head=True,
  251. shape='full')
  252. if self.instruction_file is not None:
  253. self.instruction_file.write("""\
  254. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  255. facecolor='%s', edgecolor='%s',
  256. linewidth=%g, head_width=0.1,
  257. length_includes_head=True,
  258. shape='full')
  259. """ % (x+dx, y+dy, -dx, -dy, linecolor, linecolor, linewidth))
  260. def arrow2(self, x, y, dx, dy, style='->'):
  261. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  262. self.ax.annotate('', xy=(x+dx,y+dy) , xytext=(x,y),
  263. arrowprops=dict(arrowstyle=style,
  264. facecolor='black',
  265. linewidth=1,
  266. shrinkA=0,
  267. shrinkB=0))
  268. if self.instruction_file is not None:
  269. self.instruction_file.write("")
  270. def _test():
  271. d = MatplotlibDraw(0, 10, 0, 5, instruction_file='tmp3.py', axis=True)
  272. d.set_linecolor('magenta')
  273. d.set_linewidth(6)
  274. # triangle
  275. x = np.array([1, 4, 1, 1]); y = np.array([1, 1, 4, 1])
  276. d.set_filled_curves('magenta')
  277. d.define_curve(x, y)
  278. d.set_filled_curves(False)
  279. d.define_curve(x+4, y)
  280. d.text('some text1', position=(8,4), arrow_tip=(6, 1), alignment='left',
  281. fontsize=18)
  282. pos = np.array((7,4.5)) # numpy points work fine
  283. d.text('some text2', position=pos, arrow_tip=(6, 1), alignment='center',
  284. fontsize=12)
  285. d.set_linewidth(2)
  286. d.arrow(0.25, 0.25, 0.45, 0.45)
  287. d.arrow(0.25, 0.25, 0.25, 4, style='<->')
  288. d.arrow2(4.5, 0, 0, 3, style='<->')
  289. x = np.linspace(0, 9, 201)
  290. y = 4.5 + 0.45*np.cos(0.5*np.pi*x)
  291. d.define_curve(x, y, arrow='end')
  292. d.display()
  293. raw_input()
  294. if __name__ == '__main__':
  295. _test()