# MatplotlibDraw.py
import os

import matplotlib
matplotlib.use('TkAgg')  # must be selected before pyplot is imported
import matplotlib.pyplot as mpl
import numpy as np
  6. class MatplotlibDraw:
  7. line_colors = {'red': 'r', 'green': 'g', 'blue': 'b', 'cyan': 'c',
  8. 'magenta': 'm', 'purple': 'p',
  9. 'yellow': 'y', 'black': 'k', 'white': 'w',
  10. 'brown': 'brown', '': ''}
  11. def __init__(self):
  12. self.instruction_file = None
  13. def ok(self):
  14. """
  15. Return True if set_coordinate_system is called and
  16. objects can be drawn.
  17. """
  18. def set_coordinate_system(self, xmin, xmax, ymin, ymax, axis=False,
  19. instruction_file=None):
  20. """
  21. Define the drawing area [xmin,xmax]x[ymin,ymax].
  22. axis: None or False means that axes with tickmarks
  23. are not drawn.
  24. instruction_file: name of file where all the instructions
  25. for the plotting program are stored (useful for debugging
  26. a figure or tailoring plots).
  27. """
  28. self.mpl = mpl
  29. self.xmin, self.xmax, self.ymin, self.ymax = \
  30. float(xmin), float(xmax), float(ymin), float(ymax)
  31. self.xrange = self.xmax - self.xmin
  32. self.yrange = self.ymax - self.ymin
  33. self.axis = axis
  34. # Compute the right X11 geometry on the screen based on the
  35. # x-y ratio of axis ranges
  36. ratio = (self.ymax-self.ymin)/(self.xmax-self.xmin)
  37. self.xsize = 800 # pixel size
  38. self.ysize = self.xsize*ratio
  39. geometry = '%dx%d' % (self.xsize, self.ysize)
  40. # See http://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
  41. if isinstance(instruction_file, str):
  42. self.instruction_file = open(instruction_file, 'w')
  43. else:
  44. self.instruction_file = None
  45. self.mpl.ion() # important for interactive drawing and animation
  46. if self.instruction_file:
  47. self.instruction_file.write("""\
  48. import matplotlib.pyplot as mpl
  49. mpl.ion() # for interactive drawing
  50. """)
  51. self._make_axes(new_figure=True)
  52. manager = self.mpl.get_current_fig_manager()
  53. manager.window.wm_geometry(geometry)
  54. # Default properties
  55. self.set_linecolor('red')
  56. self.set_linewidth(2)
  57. self.set_linestyle('solid')
  58. self.set_filled_curves() # no filling
  59. self.arrow_head_width = 0.2*self.xrange/16
  60. def _make_axes(self, new_figure=False):
  61. if new_figure:
  62. self.fig = self.mpl.figure()
  63. self.ax = self.fig.gca()
  64. self.ax.set_xlim(self.xmin, self.xmax)
  65. self.ax.set_ylim(self.ymin, self.ymax)
  66. self.ax.set_aspect('equal') # extent of 1 unit is the same on the axes
  67. if not self.axis:
  68. self.mpl.axis('off')
  69. axis_cmd = "mpl.axis('off') # do not show axes with tickmarks\n"
  70. else:
  71. axis_cmd = ''
  72. if self.instruction_file:
  73. fig = 'fig = mpl.figure()\n' if new_figure else ''
  74. self.instruction_file.write("""\
  75. %s
  76. ax = fig.gca()
  77. xmin, xmax, ymin, ymax = %s, %s, %s, %s
  78. ax.set_xlim(xmin, xmax)
  79. ax.set_ylim(ymin, ymax)
  80. ax.set_aspect('equal')
  81. %s
  82. """ % (fig, self.xmin, self.xmax, self.ymin, self.ymax, axis_cmd))
  83. def inside(self, pt, exception=False):
  84. """Is point pt inside the defined plotting area?"""
  85. area = '[%s,%s]x[%s,%s]' % \
  86. (self.xmin, self.xmax, self.ymin, self.ymax)
  87. tol = 1E-14
  88. pt_inside = True
  89. if self.xmin - tol <= pt[0] <= self.xmax + tol:
  90. pass
  91. else:
  92. pt_inside = False
  93. if self.ymin - tol <= pt[1] <= self.ymax + tol:
  94. pass
  95. else:
  96. pt_inside = False
  97. if pt_inside:
  98. return pt_inside, 'point=%s is inside plotting area %s' % \
  99. (pt, area)
  100. else:
  101. msg = 'point=%s is outside plotting area %s' % (pt, area)
  102. if exception:
  103. raise ValueError(msg)
  104. return pt_inside, msg
  105. def set_linecolor(self, color):
  106. """
  107. Change the color of lines. Available colors are
  108. 'black', 'white', 'red', 'blue', 'green', 'yellow',
  109. 'magenta', 'cyan'.
  110. """
  111. self.linecolor = MatplotlibDraw.line_colors[color]
  112. def set_linestyle(self, style):
  113. """Change line style: 'solid', 'dashed', 'dashdot', 'dotted'."""
  114. if not style in ('solid', 'dashed', 'dashdot', 'dotted'):
  115. raise ValueError('Illegal line style: %s' % style)
  116. self.linestyle = style
  117. def set_linewidth(self, width):
  118. """Change the line width (int, starts at 1)."""
  119. self.linewidth = width
  120. def set_filled_curves(self, color='', pattern=''):
  121. """
  122. Fill area inside curves with specified color and/or pattern.
  123. A common pattern is '/' (45 degree lines). Other patterns
  124. include....
  125. """
  126. if color is False:
  127. self.fillcolor = ''
  128. self.fillpattern = ''
  129. else:
  130. self.fillcolor = color if len(color) == 1 else \
  131. MatplotlibDraw.line_colors[color]
  132. self.fillpattern = pattern
  133. def set_grid(self, on=False):
  134. self.mpl.grid(on)
  135. if self.instruction_file:
  136. self.instruction_file.write("\nmpl.grid(%s)\n" % str(on))
  137. def erase(self):
  138. """Erase the current figure."""
  139. self.mpl.delaxes()
  140. if self.instruction_file:
  141. self.instruction_file.write("\nmpl.delaxes() # erase\n")
  142. self._make_axes(new_figure=False)
  143. def plot_curve(self, x, y,
  144. linestyle=None, linewidth=None,
  145. linecolor=None, arrow=None,
  146. fillcolor=None, fillpattern=None):
  147. """Define a curve with coordinates x and y (arrays)."""
  148. self.xdata = np.asarray(x, dtype=np.float)
  149. self.ydata = np.asarray(y, dtype=np.float)
  150. if linestyle is None:
  151. # use "global" linestyle
  152. linestyle = self.linestyle
  153. if linecolor is None:
  154. linecolor = self.linecolor
  155. if linewidth is None:
  156. linewidth = self.linewidth
  157. if fillcolor is None:
  158. fillcolor = self.fillcolor
  159. if fillpattern is None:
  160. fillpattern = self.fillpattern
  161. if self.instruction_file:
  162. import pprint
  163. self.instruction_file.write('x = %s\n' % \
  164. pprint.pformat(self.xdata.tolist()))
  165. self.instruction_file.write('y = %s\n' % \
  166. pprint.pformat(self.ydata.tolist()))
  167. if fillcolor or fillpattern:
  168. if fillpattern != '':
  169. fillcolor = 'white'
  170. #print '%d coords, fillcolor="%s" linecolor="%s" fillpattern="%s"' % (x.size, fillcolor, linecolor, fillpattern)
  171. self.ax.fill(x, y, fillcolor, edgecolor=linecolor,
  172. linewidth=linewidth, hatch=fillpattern)
  173. if self.instruction_file:
  174. self.instruction_file.write("ax.fill(x, y, '%s', edgecolor='%s', linewidth=%d, hatch='%s')\n" % (fillcolor, linecolor, linewidth, fillpattern))
  175. else:
  176. self.ax.plot(x, y, linecolor, linewidth=linewidth,
  177. linestyle=linestyle)
  178. if self.instruction_file:
  179. self.instruction_file.write("ax.plot(x, y, '%s', linewidth=%d, linestyle='%s')\n" % (linecolor, linewidth, linestyle))
  180. if arrow:
  181. if not arrow in ('->', '<-', '<->'):
  182. raise ValueError("arrow argument must be '->', '<-', or '<->', not %s" % repr(arrow))
  183. # Add arrow to first and/or last segment
  184. start = arrow == '<-' or arrow == '<->'
  185. end = arrow == '->' or arrow == '<->'
  186. if start:
  187. x_s, y_s = x[1], y[1]
  188. dx_s, dy_s = x[0]-x[1], y[0]-y[1]
  189. self.plot_arrow(x_s, y_s, dx_s, dy_s, '->',
  190. linestyle, linewidth, linecolor)
  191. if end:
  192. x_e, y_e = x[-2], y[-2]
  193. dx_e, dy_e = x[-1]-x[-2], y[-1]-y[-2]
  194. self.plot_arrow(x_e, y_e, dx_e, dy_e, '->',
  195. linestyle, linewidth, linecolor)
  196. def display(self, title=None):
  197. """Display the figure. Last possible command."""
  198. if title is not None:
  199. self.mpl.title(title)
  200. if self.instruction_file:
  201. self.instruction_file.write('mpl.title("%s")\n' % title)
  202. self.mpl.draw()
  203. if self.instruction_file:
  204. self.instruction_file.write('mpl.draw()\n')
  205. def savefig(self, filename):
  206. """Save figure in file."""
  207. self.mpl.savefig(filename)
  208. if self.instruction_file:
  209. self.instruction_file.write('mpl.savefig("%s")\n' % filename)
  210. def text(self, text, position, alignment='center', fontsize=18,
  211. arrow_tip=None):
  212. """
  213. Write text at a position (centered, left, right - according
  214. to the alignment string). position is a 2-tuple.
  215. arrow+tip != None draws an arrow from the text to a point
  216. (on a curve, for instance). The arrow_tip argument is then
  217. the (x,y) coordinates for the arrow tip.
  218. """
  219. x, y = position
  220. if arrow_tip is None:
  221. self.ax.text(x, y, text, horizontalalignment=alignment,
  222. fontsize=fontsize)
  223. if self.instruction_file:
  224. self.instruction_file.write("""\
  225. ax.text(%g, %g, %s,
  226. horizontalalignment=%s, fontsize=%d)
  227. """ % (x, y, repr(text), repr(alignment), fontsize))
  228. else:
  229. if not len(arrow_tip) == 2:
  230. raise ValueError('arrow_tip=%s must be (x,y) pt.' % arrow)
  231. pt = arrow_tip
  232. self.ax.annotate(text, xy=pt, xycoords='data',
  233. textcoords='data', xytext=position,
  234. horizontalalignment=alignment,
  235. verticalalignment='top',
  236. fontsize=fontsize,
  237. arrowprops=dict(arrowstyle='->',
  238. facecolor='black',
  239. #linewidth=2,
  240. linewidth=1,
  241. shrinkA=5,
  242. shrinkB=5))
  243. if self.instruction_file:
  244. self.instruction_file.write("""\
  245. ax.annotate('%s', xy=%s, xycoords='data',
  246. textcoords='data', xytext=%s,
  247. horizontalalignment='%s',
  248. verticalalignment='top',
  249. fontsize=%d,
  250. arrowprops=dict(arrowstyle='->',
  251. facecolor='black',
  252. linewidth=2,
  253. shrinkA=5,
  254. shrinkB=5))
  255. """ % (text, pt, position, alignment, fontsize))
  256. # Drawing annotations with arrows:
  257. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  258. #http://matplotlib.sourceforge.net/mpl_examples/pylab_examples/annotation_demo2.py
  259. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  260. #http://matplotlib.sourceforge.net/users/annotations_guide.html#plotting-guide-annotation
  261. def plot_arrow(self, x, y, dx, dy, style='->',
  262. linestyle=None, linewidth=None, linecolor=None):
  263. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  264. if linestyle is None:
  265. # use "global" linestyle
  266. linestyle = self.linestyle
  267. if linecolor is None:
  268. linecolor = self.linecolor
  269. if linewidth is None:
  270. linewidth = self.linewidth
  271. if style == '->' or style == '<->':
  272. self.mpl.arrow(x, y, dx, dy, hold=True,
  273. facecolor=linecolor,
  274. edgecolor=linecolor,
  275. linewidth=linewidth,
  276. head_width=self.arrow_head_width,
  277. #head_width=0.1,
  278. #width=1, # width of arrow body in coordinate scale
  279. length_includes_head=True,
  280. shape='full')
  281. if self.instruction_file:
  282. self.instruction_file.write("""\
  283. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  284. facecolor='%s', edgecolor='%s',
  285. linewidth=%g, head_width=0.1,
  286. length_includes_head=True,
  287. shape='full')
  288. """ % (x, y, dx, dy, linecolor, linecolor, linewidth))
  289. if style == '<-' or style == '<->':
  290. self.mpl.arrow(x+dx, y+dy, -dx, -dy, hold=True,
  291. facecolor=linecolor,
  292. edgecolor=linecolor,
  293. linewidth=linewidth,
  294. head_width=0.1,
  295. #width=1,
  296. length_includes_head=True,
  297. shape='full')
  298. if self.instruction_file:
  299. self.instruction_file.write("""\
  300. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  301. facecolor='%s', edgecolor='%s',
  302. linewidth=%g, head_width=0.1,
  303. length_includes_head=True,
  304. shape='full')
  305. """ % (x+dx, y+dy, -dx, -dy, linecolor, linecolor, linewidth))
  306. def arrow2(self, x, y, dx, dy, style='->'):
  307. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  308. self.ax.annotate('', xy=(x+dx,y+dy), xytext=(x,y),
  309. arrowprops=dict(arrowstyle=style,
  310. facecolor='black',
  311. linewidth=1,
  312. shrinkA=0,
  313. shrinkB=0))
  314. if self.instruction_file:
  315. self.instruction_file.write("""
  316. ax.annotate('', xy=(%s,%s), xytext=(%s,%s),
  317. arrowprops=dict(arrowstyle=%s,
  318. facecolor='black',
  319. linewidth=1,
  320. shrinkA=0,
  321. shrinkB=0))
  322. """ % (x+dx, y+dy, x, y, style))
  323. def _test():
  324. d = MatplotlibDraw(0, 10, 0, 5, instruction_file='tmp3.py', axis=True)
  325. d.set_linecolor('magenta')
  326. d.set_linewidth(6)
  327. # triangle
  328. x = np.array([1, 4, 1, 1]); y = np.array([1, 1, 4, 1])
  329. d.set_filled_curves('magenta')
  330. d.plot_curve(x, y)
  331. d.set_filled_curves(False)
  332. d.plot_curve(x+4, y)
  333. d.text('some text1', position=(8,4), arrow_tip=(6, 1), alignment='left',
  334. fontsize=18)
  335. pos = np.array((7,4.5)) # numpy points work fine
  336. d.text('some text2', position=pos, arrow_tip=(6, 1), alignment='center',
  337. fontsize=12)
  338. d.set_linewidth(2)
  339. d.arrow(0.25, 0.25, 0.45, 0.45)
  340. d.arrow(0.25, 0.25, 0.25, 4, style='<->')
  341. d.arrow2(4.5, 0, 0, 3, style='<->')
  342. x = np.linspace(0, 9, 201)
  343. y = 4.5 + 0.45*np.cos(0.5*np.pi*x)
  344. d.plot_curve(x, y, arrow='end')
  345. d.display()
  346. raw_input()
  347. if __name__ == '__main__':
  348. _test()