# MatplotlibDraw.py (23 KB)
from __future__ import division
from __future__ import unicode_literals
from __future__ import print_function
from __future__ import absolute_import
#from future import standard_library
#standard_library.install_aliases()
from builtins import input
from builtins import str
from builtins import *
from builtins import object

import os
import subprocess

import matplotlib
# Select the backend before importing matplotlib.pyplot.
matplotlib.use('module://ipympl.backend_nbagg')
#matplotlib.rcParams['text.latex.preamble'] = '\\usepackage{amsmath}'
#matplotlib.rcParams['verbose.level'] = 'debug-annoying'
import matplotlib.pyplot as mpl
import matplotlib.transforms as transforms
import numpy as np


  19. class MatplotlibDraw(object):
  20. """
  21. Simple interface for plotting. This interface makes use of
  22. Matplotlib for plotting.
  23. Some attributes that must be controlled directly (no set_* method
  24. since these attributes are changed quite seldom).
  25. ========================== ============================================
  26. Attribute Description
  27. ========================== ============================================
  28. allow_screen_graphics False means that no plot is shown on
  29. the screen. (Does not work yet.)
  30. arrow_head_width Size of arrow head.
  31. ========================== ============================================
  32. """
  33. line_colors = {'red': 'r', 'green': 'g', 'blue': 'b', 'cyan': 'c',
  34. 'magenta': 'm', 'purple': 'p',
  35. 'yellow': 'y', 'black': 'k', 'white': 'w',
  36. 'brown': 'brown', '': ''}
  37. def __init__(self):
  38. self.instruction_file = None
  39. self.allow_screen_graphics = True # does not work yet
  40. def __del__(self):
  41. if self.instruction_file:
  42. self.instruction_file.write('\nmpl.draw()\nraw_input()\n')
  43. self.instruction_file.close()
  44. def ok(self):
  45. """
  46. Return True if set_coordinate_system is called and
  47. objects can be drawn.
  48. """
  49. def adjust_coordinate_system(self, minmax, occupation_percent=80):
  50. """
  51. Given a dict of xmin, xmax, ymin, ymax values, and a desired
  52. filling of the plotting area of `occupation_percent` percent,
  53. set new axis limits.
  54. """
  55. x_range = minmax['xmax'] - minmax['xmin']
  56. y_range = minmax['ymax'] - minmax['ymin']
  57. new_x_range = x_range*100./occupation_percent
  58. x_space = new_x_range - x_range
  59. new_y_range = y_range*100./occupation_percent
  60. y_space = new_y_range - y_range
  61. self.ax.set_xlim(minmax['xmin']-x_space/2., minmax['xmax']+x_space/2.)
  62. self.ax.set_ylim(minmax['ymin']-y_space/2., minmax['ymax']+y_space/2.)
  63. def set_coordinate_system(self, xmin, xmax, ymin, ymax, axis=False,
  64. instruction_file=None, new_figure=True,
  65. xkcd=False):
  66. """
  67. Define the drawing area [xmin,xmax]x[ymin,ymax].
  68. axis: None or False means that axes with tickmarks
  69. are not drawn.
  70. instruction_file: name of file where all the instructions
  71. for the plotting program are stored (useful for debugging
  72. a figure or tailoring plots).
  73. """
  74. # Close file for previous figure and start new one
  75. # if not the figure file is the same
  76. if self.instruction_file is not None:
  77. if instruction_file == self.instruction_file.name:
  78. pass # continue with same file
  79. else:
  80. self.instruction_file.close() # make new py file for commands
  81. self.mpl = mpl
  82. if xkcd:
  83. self.mpl.xkcd()
  84. else:
  85. # Allow \boldsymbol{} etc in title, labels, etc
  86. matplotlib.rc('text', usetex=True)
  87. self.xmin, self.xmax, self.ymin, self.ymax = \
  88. float(xmin), float(xmax), float(ymin), float(ymax)
  89. self.xrange = self.xmax - self.xmin
  90. self.yrange = self.ymax - self.ymin
  91. self.axis = axis
  92. # Compute the right X11 geometry on the screen based on the
  93. # x-y ratio of axis ranges
  94. ratio = (self.ymax-self.ymin)/(self.xmax-self.xmin)
  95. self.xsize = 800 # pixel size
  96. self.ysize = self.xsize*ratio
  97. geometry = '%dx%d' % (self.xsize, self.ysize)
  98. # See http://stackoverflow.com/questions/7449585/how-do-you-set-the-absolute-position-of-figure-windows-with-matplotlib
  99. if isinstance(instruction_file, str):
  100. self.instruction_file = open(instruction_file, 'w')
  101. else:
  102. self.instruction_file = None
  103. self.mpl.ioff() # important for interactive drawing and animation
  104. if self.instruction_file:
  105. self.instruction_file.write("""\
  106. import matplotlib
  107. matplotlib.use('module://ipympl.backend_nbagg')
  108. # Allow \boldsymbol{} etc in title, labels, etc
  109. matplotlib.rc('text', usetex=True)
  110. #matplotlib.rcParams['text.latex.preamble'] = '\\usepackage{amsmath}'
  111. import matplotlib.pyplot as mpl
  112. import matplotlib.transforms as transforms
  113. mpl.ion() # for interactive drawing
  114. """)
  115. # Default properties
  116. self.set_linecolor('red')
  117. self.set_linewidth(2)
  118. self.set_linestyle('solid')
  119. self.set_filled_curves() # no filling
  120. self.set_fontsize(14)
  121. self.arrow_head_width = 0.2*self.xrange/16
  122. self._make_axes(new_figure=new_figure)
  123. manager = self.mpl.get_current_fig_manager()
  124. #manager.window.wm_geometry(geometry)
  125. def _make_axes(self, new_figure=False):
  126. if new_figure:
  127. self.fig = self.mpl.figure()
  128. self.ax = self.fig.gca()
  129. self.ax.set_xlim(self.xmin, self.xmax)
  130. self.ax.set_ylim(self.ymin, self.ymax)
  131. self.ax.set_aspect('equal') # extent of 1 unit is the same on the axes
  132. if not self.axis:
  133. self.mpl.axis('off')
  134. axis_cmd = "mpl.axis('off') # do not show axes with tickmarks\n"
  135. else:
  136. axis_cmd = ''
  137. if self.instruction_file:
  138. fig = 'fig = mpl.figure()\n' if new_figure else ''
  139. self.instruction_file.write("""\
  140. %s
  141. ax = fig.gca()
  142. xmin, xmax, ymin, ymax = %s, %s, %s, %s
  143. ax.set_xlim(xmin, xmax)
  144. ax.set_ylim(ymin, ymax)
  145. ax.set_aspect('equal')
  146. %s
  147. """ % (fig, self.xmin, self.xmax, self.ymin, self.ymax, axis_cmd))
  148. def inside(self, pt, exception=False):
  149. """Is point pt inside the defined plotting area?"""
  150. area = '[%s,%s]x[%s,%s]' % \
  151. (self.xmin, self.xmax, self.ymin, self.ymax)
  152. tol = 1E-14
  153. pt_inside = True
  154. if self.xmin - tol <= pt[0] <= self.xmax + tol:
  155. pass
  156. else:
  157. pt_inside = False
  158. if self.ymin - tol <= pt[1] <= self.ymax + tol:
  159. pass
  160. else:
  161. pt_inside = False
  162. if pt_inside:
  163. return pt_inside, 'point=%s is inside plotting area %s' % \
  164. (pt, area)
  165. else:
  166. msg = 'point=%s is outside plotting area %s' % (pt, area)
  167. if exception:
  168. raise ValueError(msg)
  169. return pt_inside, msg
  170. def set_linecolor(self, color):
  171. """
  172. Change the color of lines. Available colors are
  173. 'black', 'white', 'red', 'blue', 'green', 'yellow',
  174. 'magenta', 'cyan'.
  175. """
  176. self.linecolor = MatplotlibDraw.line_colors[color]
  177. def set_linestyle(self, style):
  178. """Change line style: 'solid', 'dashed', 'dashdot', 'dotted'."""
  179. if not style in ('solid', 'dashed', 'dashdot', 'dotted'):
  180. raise ValueError('Illegal line style: %s' % style)
  181. self.linestyle = style
  182. def set_linewidth(self, width):
  183. """Change the line width (int, starts at 1)."""
  184. self.linewidth = width
  185. def set_filled_curves(self, color='', pattern=''):
  186. """
  187. Fill area inside curves with specified color and/or pattern.
  188. A common pattern is '/' (45 degree lines). Other patterns
  189. include '-', '+', 'x', '\\', '*', 'o', 'O', '.'.
  190. """
  191. if color is False:
  192. self.fillcolor = ''
  193. self.fillpattern = ''
  194. else:
  195. self.fillcolor = color if len(color) == 1 else \
  196. MatplotlibDraw.line_colors[color]
  197. self.fillpattern = pattern
  198. def set_fontsize(self, fontsize=18):
  199. """
  200. Method for setting a common fontsize for text, unless
  201. individually specified when calling ``text``.
  202. """
  203. self.fontsize = fontsize
  204. def set_grid(self, on=False):
  205. self.mpl.grid(on)
  206. if self.instruction_file:
  207. self.instruction_file.write("\nmpl.grid(%s)\n" % str(on))
  208. def erase(self):
  209. """Erase the current figure."""
  210. self.mpl.delaxes()
  211. if self.instruction_file:
  212. self.instruction_file.write("\nmpl.delaxes() # erase\n")
  213. self._make_axes(new_figure=False)
  214. def plot_curve(self, x, y,
  215. linestyle=None, linewidth=None,
  216. linecolor=None, arrow=None,
  217. fillcolor=None, fillpattern=None,
  218. shadow=0, name=None):
  219. """Define a curve with coordinates x and y (arrays)."""
  220. #if not self.allow_screen_graphics:
  221. # mpl.ioff()
  222. #else:
  223. # mpl.ion()
  224. self.xdata = np.asarray(x, dtype=np.float)
  225. self.ydata = np.asarray(y, dtype=np.float)
  226. if linestyle is None:
  227. # use "global" linestyle
  228. linestyle = self.linestyle
  229. if linecolor is None:
  230. linecolor = self.linecolor
  231. if linewidth is None:
  232. linewidth = self.linewidth
  233. if fillcolor is None:
  234. fillcolor = self.fillcolor
  235. if fillpattern is None:
  236. fillpattern = self.fillpattern
  237. if shadow == 1:
  238. shadow = 3 # smallest displacement that is visible
  239. # We can plot fillcolor/fillpattern, arrow or line
  240. if self.instruction_file:
  241. import pprint
  242. if name is not None:
  243. self.instruction_file.write('\n# %s\n' % name)
  244. if not arrow:
  245. self.instruction_file.write(
  246. 'x = %s\n' % pprint.pformat(self.xdata.tolist()))
  247. self.instruction_file.write(
  248. 'y = %s\n' % pprint.pformat(self.ydata.tolist()))
  249. if fillcolor or fillpattern:
  250. if fillpattern != '':
  251. fillcolor = 'white'
  252. #print('%d coords, fillcolor="%s" linecolor="%s" fillpattern="%s"' % (x.size, fillcolor, linecolor, fillpattern))
  253. [line] = self.ax.fill(x, y, fillcolor, edgecolor=linecolor,
  254. linewidth=linewidth, hatch=fillpattern)
  255. if self.instruction_file:
  256. self.instruction_file.write("[line] = ax.fill(x, y, '%s', edgecolor='%s', linewidth=%d, hatch='%s')\n" % (fillcolor, linecolor, linewidth, fillpattern))
  257. else:
  258. # Plain line
  259. [line] = self.ax.plot(x, y, linecolor, linewidth=linewidth,
  260. linestyle=linestyle)
  261. if self.instruction_file:
  262. self.instruction_file.write("[line] = ax.plot(x, y, '%s', linewidth=%d, linestyle='%s')\n" % (linecolor, linewidth, linestyle))
  263. if arrow:
  264. # Note that a Matplotlib arrow is a line with the arrow tip
  265. if not arrow in ('->', '<-', '<->'):
  266. raise ValueError("arrow argument must be '->', '<-', or '<->', not %s" % repr(arrow))
  267. # Add arrow to first and/or last segment
  268. start = arrow == '<-' or arrow == '<->'
  269. end = arrow == '->' or arrow == '<->'
  270. if start:
  271. x_s, y_s = x[1], y[1]
  272. dx_s, dy_s = x[0]-x[1], y[0]-y[1]
  273. self._plot_arrow(x_s, y_s, dx_s, dy_s, '->',
  274. linestyle, linewidth, linecolor)
  275. if end:
  276. x_e, y_e = x[-2], y[-2]
  277. dx_e, dy_e = x[-1]-x[-2], y[-1]-y[-2]
  278. self._plot_arrow(x_e, y_e, dx_e, dy_e, '->',
  279. linestyle, linewidth, linecolor)
  280. if shadow:
  281. # http://matplotlib.sourceforge.net/users/transforms_tutorial.html#using-offset-transforms-to-create-a-shadow-effect
  282. # shift the object over 2 points, and down 2 points
  283. dx, dy = shadow/72., -shadow/72.
  284. offset = transforms.ScaledTranslation(
  285. dx, dy, self.fig.dpi_scale_trans)
  286. shadow_transform = self.ax.transData + offset
  287. # now plot the same data with our offset transform;
  288. # use the zorder to make sure we are below the line
  289. if linewidth is None:
  290. linewidth = 3
  291. self.ax.plot(x, y, linewidth=linewidth, color='gray',
  292. transform=shadow_transform,
  293. zorder=0.5*line.get_zorder())
  294. if self.instruction_file:
  295. self.instruction_file.write("""
  296. # Shadow effect for last ax.plot
  297. dx, dy = 3/72., -3/72.
  298. offset = matplotlib.transforms.ScaledTranslation(dx, dy, fig.dpi_scale_trans)
  299. shadow_transform = ax.transData + offset
  300. self.ax.plot(x, y, linewidth=%d, color='gray',
  301. transform=shadow_transform,
  302. zorder=0.5*line.get_zorder())
  303. """ % linewidth)
  304. def display(self, title=None, show=True):
  305. """Display the figure."""
  306. if title is not None:
  307. self.mpl.title(title)
  308. if self.instruction_file:
  309. self.instruction_file.write('mpl.title("%s")\n' % title)
  310. if show:
  311. self.mpl.draw()
  312. if self.instruction_file:
  313. self.instruction_file.write('mpl.draw()\n')
  314. def savefig(self, filename, dpi=None, crop=True):
  315. """Save figure in file. Set dpi=300 for really high resolution."""
  316. # If filename is without extension, generate all important formats
  317. ext = os.path.splitext(filename)[1]
  318. if not ext:
  319. # Create both PNG and PDF file
  320. self.mpl.savefig(filename + '.png', dpi=dpi)
  321. self.mpl.savefig(filename + '.pdf')
  322. if crop:
  323. # Crop the PNG file
  324. failure = os.system('convert -trim %s.png %s.png' %
  325. (filename, filename))
  326. if failure:
  327. print('convert from ImageMagick is not installed - needed for cropping PNG files')
  328. failure = os.system('pdfcrop %s.pdf %s.pdf' %
  329. (filename, filename))
  330. if failure:
  331. print('pdfcrop is not installed - needed for cropping PDF files')
  332. #self.mpl.savefig(filename + '.eps')
  333. if self.instruction_file:
  334. self.instruction_file.write('mpl.savefig("%s.png", dpi=%s)\n'
  335. % (filename, dpi))
  336. self.instruction_file.write('mpl.savefig("%s.pdf")\n'
  337. % filename)
  338. else:
  339. self.mpl.savefig(filename, dpi=dpi)
  340. if ext == '.png':
  341. if crop:
  342. failure = os.system('convert -trim %s %s' % (filename, filename))
  343. if failure:
  344. print('convert from ImageMagick is not installed - needed for cropping PNG files')
  345. elif ext == '.pdf':
  346. if crop:
  347. failure = os.system('pdfcrop %s %s' % (filename, filename))
  348. if failure:
  349. print('pdfcrop is not installed - needed for cropping PDF files')
  350. if self.instruction_file:
  351. self.instruction_file.write('mpl.savefig("%s", dpi=%s)\n'
  352. % (filename, dpi))
  353. def text(self, text, position, alignment='center', fontsize=0,
  354. arrow_tip=None, bgcolor=None, fgcolor=None, fontfamily=None):
  355. """
  356. Write `text` string at a position (centered, left, right - according
  357. to the `alignment` string). `position` is a point in the coordinate
  358. system.
  359. If ``arrow+tip != None``, an arrow is drawn from the text to a point
  360. (on a curve, for instance). The arrow_tip argument is then
  361. the (x,y) coordinates for the arrow tip.
  362. fontsize=0 indicates use of the default font as set by
  363. ``set_fontsize``.
  364. """
  365. if fontsize == 0:
  366. if hasattr(self, 'fontsize'):
  367. fontsize = self.fontsize
  368. else:
  369. raise AttributeError(
  370. 'No self.fontsize attribute to be used when text(...)\n'
  371. 'is called with fontsize=0. Call set_fontsize method.')
  372. kwargs = {}
  373. if fontfamily is not None:
  374. kwargs['family'] = fontfamily
  375. if bgcolor is not None:
  376. kwargs['backgroundcolor'] = bgcolor
  377. if fgcolor is not None:
  378. kwargs['color'] = fgcolor
  379. x, y = position
  380. if arrow_tip is None:
  381. self.ax.text(x, y, text, horizontalalignment=alignment,
  382. fontsize=fontsize, **kwargs)
  383. if self.instruction_file:
  384. self.instruction_file.write("""\
  385. ax.text(%g, %g, %s,
  386. horizontalalignment=%s, fontsize=%d)
  387. """ % (x, y, repr(text), repr(alignment), fontsize))
  388. else:
  389. if not len(arrow_tip) == 2:
  390. raise ValueError('arrow_tip=%s must be (x,y) pt.' % arrow)
  391. pt = arrow_tip
  392. self.ax.annotate(text, xy=pt, xycoords='data',
  393. textcoords='data', xytext=position,
  394. horizontalalignment=alignment,
  395. verticalalignment='top',
  396. fontsize=fontsize,
  397. arrowprops=dict(arrowstyle='->',
  398. facecolor='black',
  399. #linewidth=2,
  400. linewidth=1,
  401. shrinkA=5,
  402. shrinkB=5))
  403. if self.instruction_file:
  404. self.instruction_file.write("""\
  405. ax.annotate('%s', xy=%s, xycoords='data',
  406. textcoords='data', xytext=%s,
  407. horizontalalignment='%s',
  408. verticalalignment='top',
  409. fontsize=%d,
  410. arrowprops=dict(arrowstyle='->',
  411. facecolor='black',
  412. linewidth=2,
  413. shrinkA=5,
  414. shrinkB=5))
  415. """ % (text, pt.tolist() if isinstance(pt, np.ndarray) else pt,
  416. position, alignment, fontsize))
  417. # Drawing annotations with arrows:
  418. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  419. #http://matplotlib.sourceforge.net/mpl_examples/pylab_examples/annotation_demo2.py
  420. #http://matplotlib.sourceforge.net/users/annotations_intro.html
  421. #http://matplotlib.sourceforge.net/users/annotations_guide.html#plotting-guide-annotation
  422. def _plot_arrow(self, x, y, dx, dy, style='->',
  423. linestyle=None, linewidth=None, linecolor=None):
  424. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  425. if linestyle is None:
  426. # use "global" linestyle
  427. linestyle = self.linestyle
  428. if linecolor is None:
  429. linecolor = self.linecolor
  430. if linewidth is None:
  431. linewidth = self.linewidth
  432. if style == '->' or style == '<->':
  433. self.mpl.arrow(x, y, dx, dy,
  434. #hold=True,
  435. facecolor=linecolor,
  436. edgecolor=linecolor,
  437. linestyle=linestyle,
  438. linewidth=linewidth,
  439. head_width=self.arrow_head_width,
  440. #head_width=0.1,
  441. #width=1, # width of arrow body in coordinate scale
  442. length_includes_head=True,
  443. shape='full')
  444. if self.instruction_file:
  445. self.instruction_file.write("""\
  446. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  447. facecolor='%s', edgecolor='%s',
  448. linestyle='%s',
  449. linewidth=%g, head_width=0.1,
  450. length_includes_head=True,
  451. shape='full')
  452. """ % (x, y, dx, dy, linecolor, linecolor, linestyle, linewidth))
  453. if style == '<-' or style == '<->':
  454. self.mpl.arrow(x+dx, y+dy, -dx, -dy, hold=True,
  455. facecolor=linecolor,
  456. edgecolor=linecolor,
  457. linewidth=linewidth,
  458. head_width=0.1,
  459. #width=1,
  460. length_includes_head=True,
  461. shape='full')
  462. if self.instruction_file:
  463. self.instruction_file.write("""\
  464. mpl.arrow(x=%g, y=%g, dx=%g, dy=%g,
  465. facecolor='%s', edgecolor='%s',
  466. linewidth=%g, head_width=0.1,
  467. length_includes_head=True,
  468. shape='full')
  469. """ % (x+dx, y+dy, -dx, -dy, linecolor, linecolor, linewidth))
  470. def arrow2(self, x, y, dx, dy, style='->'):
  471. """Draw arrow (dx,dy) at (x,y). `style` is '->', '<-' or '<->'."""
  472. self.ax.annotate('', xy=(x+dx,y+dy), xytext=(x,y),
  473. arrowprops=dict(arrowstyle=style,
  474. facecolor='black',
  475. linewidth=1,
  476. shrinkA=0,
  477. shrinkB=0))
  478. if self.instruction_file:
  479. self.instruction_file.write("""
  480. ax.annotate('', xy=(%s,%s), xytext=(%s,%s),
  481. arrowprops=dict(arrowstyle=%s,
  482. facecolor='black',
  483. linewidth=1,
  484. shrinkA=0,
  485. shrinkB=0))
  486. """ % (x+dx, y+dy, x, y, style))
  487. def _test():
  488. d = MatplotlibDraw(0, 10, 0, 5, instruction_file='tmp3.py', axis=True)
  489. d.set_linecolor('magenta')
  490. d.set_linewidth(6)
  491. # triangle
  492. x = np.array([1, 4, 1, 1]); y = np.array([1, 1, 4, 1])
  493. d.set_filled_curves('magenta')
  494. d.plot_curve(x, y)
  495. d.set_filled_curves(False)
  496. d.plot_curve(x+4, y)
  497. d.text('some text1', position=(8,4), arrow_tip=(6, 1), alignment='left',
  498. fontsize=18)
  499. pos = np.array((7,4.5)) # numpy points work fine
  500. d.text('some text2', position=pos, arrow_tip=(6, 1), alignment='center',
  501. fontsize=12)
  502. d.set_linewidth(2)
  503. d.arrow(0.25, 0.25, 0.45, 0.45)
  504. d.arrow(0.25, 0.25, 0.25, 4, style='<->')
  505. d.arrow2(4.5, 0, 0, 3, style='<->')
  506. x = np.linspace(0, 9, 201)
  507. y = 4.5 + 0.45*np.cos(0.5*np.pi*x)
  508. d.plot_curve(x, y, arrow='end')
  509. d.display()
  510. input()
  511. if __name__ == '__main__':
  512. _test()