import numpy as np
from menpo.visualize.base import Renderer
# The colour map used for all lines and markers
GLOBAL_CMAP = "jet"
[docs]class MatplotlibRenderer(Renderer):
r"""
Abstract class for rendering visualizations using Matplotlib.
Parameters
----------
figure_id : `int` or ``None``
A figure id or ``None``. ``None`` assumes we maintain the Matplotlib
state machine and use `plt.gcf()`.
new_figure : `bool`
If ``True``, it creates a new figure to render on.
"""
def __init__(self, figure_id, new_figure):
super(MatplotlibRenderer, self).__init__(figure_id, new_figure)
# Set up data for saving
self._supported_ext = self.figure.canvas.get_supported_filetypes().keys()
# Create the extensions map, have to add . in front of the extensions
# and map every extension to the savefig method
n_ext = len(self._supported_ext)
func_list = [lambda obj, fp, **kwargs: self.figure.savefig(fp, **obj)] * n_ext
self._extensions_map = dict(
zip(["." + s for s in self._supported_ext], func_list)
)
[docs] def force_draw(self):
r"""
Method for forcing the current figure to render.
"""
import matplotlib.pyplot as plt
plt.show()
class MatplotlibSubplots(object):
def _subplot_layout(self, num_subplots):
if num_subplots < 2:
return [1, 1]
while self._is_prime(num_subplots) and num_subplots > 4:
num_subplots += 1
p = self._factor(num_subplots)
if len(p) == 1:
p.insert(0, 1)
return p
while len(p) > 2:
if len(p) >= 4:
p[0] = p[0] * p[-2]
p[1] = p[1] * p[-1]
del p[-2:]
else:
p[0] = p[0] * p[1]
del p[1]
p.sort()
# Reformat if the column/row ratio is too large: we want a roughly
# square design
while (p[1] / p[0]) > 2.5:
p = self._subplot_layout(num_subplots + 1)
return p
def _factor(self, n):
gaps = [1, 2, 2, 4, 2, 4, 2, 4, 6, 2, 6]
length, cycle = 11, 3
f, fs, next_ind = 2, [], 0
while f * f <= n:
while n % f == 0:
fs.append(f)
n /= f
f += gaps[next_ind]
next_ind += 1
if next_ind == length:
next_ind = cycle
if n > 1:
fs.append(n)
return fs
def _is_prime(self, n):
if n == 2 or n == 3:
return True
if n < 2 or n % 2 == 0:
return False
if n < 9:
return True
if n % 3 == 0:
return False
r = int(n ** 0.5)
f = 5
while f <= r:
if n % f == 0:
return False
if n % (f + 2) == 0:
return False
f += 6
return True
def _parse_cmap(cmap_name=None, image_shape_len=3):
import matplotlib.cm as cm
if cmap_name is not None:
return cm.get_cmap(cmap_name)
else:
if image_shape_len == 2:
# Single channels are viewed in Gray by default
return cm.gray
else:
return None
def _parse_axes_limits(min_x, max_x, min_y, max_y, axes_x_limits, axes_y_limits):
if isinstance(axes_x_limits, int):
axes_x_limits = float(axes_x_limits)
if isinstance(axes_y_limits, int):
axes_y_limits = float(axes_y_limits)
if isinstance(axes_x_limits, float):
pad = (max_x - min_x) * axes_x_limits
axes_x_limits = [min_x - pad, max_x + pad]
if isinstance(axes_y_limits, float):
pad = (max_y - min_y) * axes_y_limits
axes_y_limits = [min_y - pad, max_y + pad]
return axes_x_limits, axes_y_limits
def _set_axes_options(
ax,
render_axes=True,
inverted_y_axis=False,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
axes_x_limits=None,
axes_y_limits=None,
axes_x_ticks=None,
axes_y_ticks=None,
axes_x_label=None,
axes_y_label=None,
title=None,
):
if render_axes:
# render axes
ax.set_axis_on()
# set font options
for l in ax.get_xticklabels() + ax.get_yticklabels():
l.set_fontsize(axes_font_size)
l.set_fontname(axes_font_name)
l.set_fontstyle(axes_font_style)
l.set_fontweight(axes_font_weight)
# set ticks
if axes_x_ticks is not None:
ax.set_xticks(axes_x_ticks)
if axes_y_ticks is not None:
ax.set_yticks(axes_y_ticks)
# set labels and title
if axes_x_label is None:
axes_x_label = ""
if axes_y_label is None:
axes_y_label = ""
if title is None:
title = ""
ax.set_xlabel(
axes_x_label,
fontsize=axes_font_size,
fontname=axes_font_name,
fontstyle=axes_font_style,
fontweight=axes_font_weight,
)
ax.set_ylabel(
axes_y_label,
fontsize=axes_font_size,
fontname=axes_font_name,
fontstyle=axes_font_style,
fontweight=axes_font_weight,
)
ax.set_title(
title,
fontsize=axes_font_size,
fontname=axes_font_name,
fontstyle=axes_font_style,
fontweight=axes_font_weight,
)
else:
# do not render axes
ax.set_axis_off()
# also remove the ticks to get rid of the white area
ax.set_xticks([])
ax.set_yticks([])
# set axes limits
if axes_x_limits is not None:
ax.set_xlim(np.sort(axes_x_limits))
if axes_y_limits is None:
axes_y_limits = ax.get_ylim()
if inverted_y_axis:
ax.set_ylim(np.sort(axes_y_limits)[::-1])
else:
ax.set_ylim(np.sort(axes_y_limits))
def _set_grid_options(render_grid=True, grid_line_style="--", grid_line_width=2):
import matplotlib.pyplot as plt
if render_grid:
plt.grid("on", linestyle=grid_line_style, linewidth=grid_line_width)
else:
plt.grid("off")
def _set_figure_size(fig, figure_size=(7, 7)):
if figure_size is not None:
fig.set_size_inches(np.asarray(figure_size))
def _set_numbering(
ax,
centers,
render_numbering=True,
numbers_horizontal_align="center",
numbers_vertical_align="bottom",
numbers_font_name="sans-serif",
numbers_font_size=10,
numbers_font_style="normal",
numbers_font_weight="normal",
numbers_font_colour="k",
):
if render_numbering:
for k, p in enumerate(centers):
ax.annotate(
str(k),
xy=(p[0], p[1]),
horizontalalignment=numbers_horizontal_align,
verticalalignment=numbers_vertical_align,
size=numbers_font_size,
family=numbers_font_name,
fontstyle=numbers_font_style,
fontweight=numbers_font_weight,
color=numbers_font_colour,
)
def _set_legend(
ax,
legend_handles,
render_legend=True,
legend_title="",
legend_font_name="sans-serif",
legend_font_style="normal",
legend_font_size=10,
legend_font_weight="normal",
legend_marker_scale=None,
legend_location=2,
legend_bbox_to_anchor=(1.05, 1.0),
legend_border_axes_pad=None,
legend_n_columns=1,
legend_horizontal_spacing=None,
legend_vertical_spacing=None,
legend_border=True,
legend_border_padding=None,
legend_shadow=False,
legend_rounded_corners=False,
):
if render_legend:
# Options related to legend's font
prop = {
"family": legend_font_name,
"size": legend_font_size,
"style": legend_font_style,
"weight": legend_font_weight,
}
# Render legend
ax.legend(
handles=legend_handles,
title=legend_title,
prop=prop,
loc=legend_location,
bbox_to_anchor=legend_bbox_to_anchor,
borderaxespad=legend_border_axes_pad,
ncol=legend_n_columns,
columnspacing=legend_horizontal_spacing,
labelspacing=legend_vertical_spacing,
frameon=legend_border,
borderpad=legend_border_padding,
shadow=legend_shadow,
fancybox=legend_rounded_corners,
markerscale=legend_marker_scale,
)
class MatplotlibImageViewer2d(MatplotlibRenderer):
def __init__(self, figure_id, new_figure, image):
super(MatplotlibImageViewer2d, self).__init__(figure_id, new_figure)
self.image = image
self.axes_list = []
def render(
self,
interpolation="bilinear",
cmap_name=None,
alpha=1.0,
render_axes=False,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
axes_x_limits=None,
axes_y_limits=None,
axes_x_ticks=None,
axes_y_ticks=None,
figure_size=(7, 7),
):
import matplotlib.pyplot as plt
# parse colour map argument
cmap = _parse_cmap(cmap_name=cmap_name, image_shape_len=len(self.image.shape))
# parse axes limits
axes_x_limits, axes_y_limits = _parse_axes_limits(
0.0,
self.image.shape[1],
0.0,
self.image.shape[0],
axes_x_limits,
axes_y_limits,
)
# render image
plt.imshow(self.image, cmap=cmap, interpolation=interpolation, alpha=alpha)
# store axes object
ax = plt.gca()
self.axes_list = [ax]
# set axes options
_set_axes_options(
ax,
render_axes=render_axes,
inverted_y_axis=True,
axes_font_name=axes_font_name,
axes_font_size=axes_font_size,
axes_font_style=axes_font_style,
axes_font_weight=axes_font_weight,
axes_x_limits=axes_x_limits,
axes_y_limits=axes_y_limits,
axes_x_ticks=axes_x_ticks,
axes_y_ticks=axes_y_ticks,
)
# set figure size
_set_figure_size(self.figure, figure_size)
return self
class MatplotlibImageSubplotsViewer2d(MatplotlibRenderer, MatplotlibSubplots):
def __init__(self, figure_id, new_figure, image):
super(MatplotlibImageSubplotsViewer2d, self).__init__(figure_id, new_figure)
self.image = image
self.num_subplots = self.image.shape[2]
self.plot_layout = self._subplot_layout(self.num_subplots)
self.axes_list = []
def render(
self,
interpolation="bilinear",
cmap_name=None,
alpha=1.0,
render_axes=False,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
axes_x_limits=None,
axes_y_limits=None,
axes_x_ticks=None,
axes_y_ticks=None,
figure_size=(7, 7),
):
import matplotlib.pyplot as plt
# parse colour map argument
cmap = _parse_cmap(cmap_name=cmap_name, image_shape_len=2)
# parse axes limits
axes_x_limits, axes_y_limits = _parse_axes_limits(
0.0,
self.image.shape[1],
0.0,
self.image.shape[0],
axes_x_limits,
axes_y_limits,
)
p = self.plot_layout
for i in range(self.image.shape[2]):
# create subplot and append the axes object
ax = plt.subplot(p[0], p[1], 1 + i)
self.axes_list.append(ax)
# render image
plt.imshow(
self.image[:, :, i], cmap=cmap, interpolation=interpolation, alpha=alpha
)
# set axes options
_set_axes_options(
ax,
render_axes=render_axes,
inverted_y_axis=True,
axes_font_name=axes_font_name,
axes_font_size=axes_font_size,
axes_font_style=axes_font_style,
axes_font_weight=axes_font_weight,
axes_x_limits=axes_x_limits,
axes_y_limits=axes_y_limits,
axes_x_ticks=axes_x_ticks,
axes_y_ticks=axes_y_ticks,
)
# set figure size
_set_figure_size(self.figure, figure_size)
return self
class MatplotlibPointGraphViewer2d(MatplotlibRenderer):
def __init__(self, figure_id, new_figure, points, edges):
super(MatplotlibPointGraphViewer2d, self).__init__(figure_id, new_figure)
self.points = points
self.edges = edges
def render(
self,
image_view=False,
render_lines=True,
line_colour="r",
line_style="-",
line_width=1,
render_markers=True,
marker_style="o",
marker_size=5,
marker_face_colour="r",
marker_edge_colour="k",
marker_edge_width=1.0,
render_numbering=False,
numbers_horizontal_align="center",
numbers_vertical_align="bottom",
numbers_font_name="sans-serif",
numbers_font_size=10,
numbers_font_style="normal",
numbers_font_weight="normal",
numbers_font_colour="k",
render_axes=True,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
axes_x_limits=None,
axes_y_limits=None,
axes_x_ticks=None,
axes_y_ticks=None,
figure_size=(7, 7),
label=None,
):
from matplotlib import collections as mc
import matplotlib.pyplot as plt
# Flip x and y for viewing if points are tied to an image
points = self.points[:, ::-1] if image_view else self.points
# parse axes limits
min_x, min_y = np.min(points, axis=0)
max_x, max_y = np.max(points, axis=0)
axes_x_limits, axes_y_limits = _parse_axes_limits(
min_x, max_x, min_y, max_y, axes_x_limits, axes_y_limits
)
# get current axes object
ax = plt.gca()
# Check if graph has edges to be rendered (for example a PointCloud
# won't have any edges)
if render_lines and np.array(self.edges).shape[0] > 0:
# Get edges to be rendered
lines = zip(points[self.edges[:, 0], :], points[self.edges[:, 1], :])
# Draw line objects
lc = mc.LineCollection(
lines,
colors=line_colour,
linestyles=line_style,
linewidths=line_width,
cmap=GLOBAL_CMAP,
label=label,
)
ax.add_collection(lc)
# If a label is defined, it should only be applied to the lines, of
# a PointGraph, which represent each one of the labels, unless a
# PointCloud is passed in.
label = None
ax.autoscale()
if render_markers:
plt.plot(
points[:, 0],
points[:, 1],
linewidth=0,
markersize=marker_size,
marker=marker_style,
markeredgewidth=marker_edge_width,
markeredgecolor=marker_edge_colour,
markerfacecolor=marker_face_colour,
label=label,
)
# set numbering
_set_numbering(
ax,
points,
render_numbering=render_numbering,
numbers_horizontal_align=numbers_horizontal_align,
numbers_vertical_align=numbers_vertical_align,
numbers_font_name=numbers_font_name,
numbers_font_size=numbers_font_size,
numbers_font_style=numbers_font_style,
numbers_font_weight=numbers_font_weight,
numbers_font_colour=numbers_font_colour,
)
# set axes options
_set_axes_options(
ax,
render_axes=render_axes,
inverted_y_axis=image_view,
axes_font_name=axes_font_name,
axes_font_size=axes_font_size,
axes_font_style=axes_font_style,
axes_font_weight=axes_font_weight,
axes_x_limits=axes_x_limits,
axes_y_limits=axes_y_limits,
axes_x_ticks=axes_x_ticks,
axes_y_ticks=axes_y_ticks,
)
# set equal aspect ratio
ax.set_aspect("equal", adjustable="box")
# set figure size
_set_figure_size(self.figure, figure_size)
return self
class MatplotlibLandmarkViewer2d(MatplotlibRenderer):
def __init__(self, figure_id, new_figure, group, landmark_group):
super(MatplotlibLandmarkViewer2d, self).__init__(figure_id, new_figure)
self.group = group
self.landmark_group = landmark_group
def render(
self,
image_view=False,
render_lines=True,
line_colour="r",
line_style="-",
line_width=1,
render_markers=True,
marker_style="o",
marker_size=5,
marker_face_colour="r",
marker_edge_colour="k",
marker_edge_width=1.0,
render_numbering=False,
numbers_horizontal_align="center",
numbers_vertical_align="bottom",
numbers_font_name="sans-serif",
numbers_font_size=10,
numbers_font_style="normal",
numbers_font_weight="normal",
numbers_font_colour="k",
render_legend=True,
legend_title="",
legend_font_name="sans-serif",
legend_font_style="normal",
legend_font_size=10,
legend_font_weight="normal",
legend_marker_scale=None,
legend_location=2,
legend_bbox_to_anchor=(1.05, 1.0),
legend_border_axes_pad=None,
legend_n_columns=1,
legend_horizontal_spacing=None,
legend_vertical_spacing=None,
legend_border=True,
legend_border_padding=None,
legend_shadow=False,
legend_rounded_corners=False,
render_axes=True,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
axes_x_limits=None,
axes_y_limits=None,
axes_x_ticks=None,
axes_y_ticks=None,
figure_size=(7, 7),
):
import matplotlib.lines as mlines
import matplotlib.pyplot as plt
# Regarding the labels colours, we may get passed either no colours (in
# which case we generate random colours) or a single colour to colour
# all the labels with
# TODO: All marker and line options could be defined as lists...
n_labels = self.landmark_group.n_labels
line_colour = _check_colours_list(
render_lines,
line_colour,
n_labels,
"Must pass a list of line colours with length n_labels or a single "
"line colour for all labels.",
)
marker_face_colour = _check_colours_list(
render_markers,
marker_face_colour,
n_labels,
"Must pass a list of marker face colours with length n_labels or "
"a single marker face colour for all labels.",
)
marker_edge_colour = _check_colours_list(
render_markers,
marker_edge_colour,
n_labels,
"Must pass a list of marker edge colours with length n_labels or "
"a single marker edge colour for all labels.",
)
# check axes limits
points = self.landmark_group.points
if image_view:
min_y, min_x = np.min(points, axis=0)
max_y, max_x = np.max(points, axis=0)
else:
min_x, min_y = np.min(points, axis=0)
max_x, max_y = np.max(points, axis=0)
axes_x_limits, axes_y_limits = _parse_axes_limits(
min_x, max_x, min_y, max_y, axes_x_limits, axes_y_limits
)
# get pointcloud of each label
sub_pointclouds = self._build_sub_pointclouds()
# initialize legend_handles list
legend_handles = []
# for each pointcloud
for i, (label, pc) in enumerate(sub_pointclouds):
# render pointcloud
pc.view(
figure_id=self.figure_id,
image_view=image_view,
render_lines=render_lines,
line_colour=line_colour[i],
line_style=line_style,
line_width=line_width,
render_markers=render_markers,
marker_style=marker_style,
marker_size=marker_size,
marker_face_colour=marker_face_colour[i],
marker_edge_colour=marker_edge_colour[i],
marker_edge_width=marker_edge_width,
render_numbering=render_numbering,
numbers_horizontal_align=numbers_horizontal_align,
numbers_vertical_align=numbers_vertical_align,
numbers_font_name=numbers_font_name,
numbers_font_size=numbers_font_size,
numbers_font_style=numbers_font_style,
numbers_font_weight=numbers_font_weight,
numbers_font_colour=numbers_font_colour,
render_axes=render_axes,
axes_font_name=axes_font_name,
axes_font_size=axes_font_size,
axes_font_style=axes_font_style,
axes_font_weight=axes_font_weight,
axes_x_limits=axes_x_limits,
axes_y_limits=axes_y_limits,
axes_x_ticks=axes_x_ticks,
axes_y_ticks=axes_y_ticks,
figure_size=None,
)
# set legend entry
if render_legend:
tmp_line = line_style
tmp_marker = marker_style if render_markers else "None"
legend_handles.append(
mlines.Line2D(
[],
[],
linewidth=line_width,
linestyle=tmp_line,
color=line_colour[i],
marker=tmp_marker,
markersize=marker_size ** 0.5,
markeredgewidth=marker_edge_width,
markeredgecolor=marker_edge_colour[i],
markerfacecolor=marker_face_colour[i],
label="{0}: {1}".format(self.group, label),
)
)
# set legend
_set_legend(
plt.gca(),
legend_handles,
render_legend=render_legend,
legend_title=legend_title,
legend_font_name=legend_font_name,
legend_font_style=legend_font_style,
legend_font_size=legend_font_size,
legend_font_weight=legend_font_weight,
legend_marker_scale=legend_marker_scale,
legend_location=legend_location,
legend_bbox_to_anchor=legend_bbox_to_anchor,
legend_border_axes_pad=legend_border_axes_pad,
legend_n_columns=legend_n_columns,
legend_horizontal_spacing=legend_horizontal_spacing,
legend_vertical_spacing=legend_vertical_spacing,
legend_border=legend_border,
legend_border_padding=legend_border_padding,
legend_shadow=legend_shadow,
legend_rounded_corners=legend_rounded_corners,
)
# set figure size
_set_figure_size(self.figure, figure_size)
# If no limits are set then ensure that all sub-pointclouds fit in the
# view
if axes_x_limits is None and axes_y_limits is None:
plt.autoscale()
return self
def _build_sub_pointclouds(self):
return [
(label, self.landmark_group.get_label(label))
for label in self.landmark_group.labels
]
class MatplotlibAlignmentViewer2d(MatplotlibRenderer):
def __init__(self, figure_id, new_figure, alignment_transform):
super(MatplotlibAlignmentViewer2d, self).__init__(figure_id, new_figure)
self.alignment_transform = alignment_transform
def render(self, image=False, **kwargs):
r"""
Visualize how points are affected by the warp in 2 dimensions.
"""
import matplotlib.pyplot as plt
source = self.alignment_transform.source.points
target = self.alignment_transform.target.points
# a factor by which the minimum and maximum x and y values of the warp
# will be increased by.
x_margin_factor, y_margin_factor = 0.5, 0.5
# the number of x and y samples to take
n_x, n_y = 50, 50
# {x y}_{min max} is the actual bounds on either source or target
# landmarks
x_min, y_min = np.vstack([target.min(0), source.min(0)]).min(0)
x_max, y_max = np.vstack([target.max(0), source.max(0)]).max(0)
x_margin = x_margin_factor * (x_max - x_min)
y_margin = y_margin_factor * (y_max - y_min)
# {x y}_{min max}_m is the bound once it has been grown by the factor
# of the spread in that dimension
x_min_m = x_min - x_margin
x_max_m = x_max + x_margin
y_min_m = y_min - y_margin
y_max_m = y_max + y_margin
# build sample points for the selected region
x = np.linspace(x_min_m, x_max_m, n_x)
y = np.linspace(y_min_m, y_max_m, n_y)
xx, yy = np.meshgrid(x, y)
sample_points = np.concatenate(
[xx.reshape([-1, 1]), yy.reshape([-1, 1])], axis=1
)
warped_points = self.alignment_transform.apply(sample_points)
delta = warped_points - sample_points
# plot the sample points result
x, y, = (
0,
1,
)
if image:
# if we are overlaying points onto an image,
# we have to account for the fact that axis 0 is typically
# called 'y' and axis 1 is typically called 'x'. Flip them here
x, y = y, x
plt.quiver(sample_points[:, x], sample_points[:, y], delta[:, x], delta[:, y])
delta = target - source
# plot how the landmarks move from source to target
plt.quiver(
source[:, x],
source[:, y],
delta[:, x],
delta[:, y],
angles="xy",
scale_units="xy",
scale=1,
)
# rescale to the bounds
plt.xlim((x_min_m, x_max_m))
plt.ylim((y_min_m, y_max_m))
if image:
# if we are overlaying points on an image, axis0 (the 'y' axis)
# is flipped.
plt.gca().invert_yaxis()
return self
class MatplotlibGraphPlotter(MatplotlibRenderer):
def __init__(
self,
figure_id,
new_figure,
x_axis,
y_axis,
title=None,
legend_entries=None,
x_label=None,
y_label=None,
x_axis_limits=None,
y_axis_limits=None,
x_axis_ticks=None,
y_axis_ticks=None,
):
super(MatplotlibGraphPlotter, self).__init__(figure_id, new_figure)
self.x_axis = x_axis
self.y_axis = y_axis
if legend_entries is None:
legend_entries = ["Curve {}".format(i) for i in range(len(y_axis))]
self.legend_entries = legend_entries
self.title = title
self.x_label = x_label
self.y_label = y_label
self.x_axis_ticks = x_axis_ticks
self.y_axis_ticks = y_axis_ticks
# parse axes limits
min_x = np.min(x_axis)
max_x = np.max(x_axis)
min_y = np.min([np.min(l) for l in y_axis])
max_y = np.max([np.max(l) for l in y_axis])
self.x_axis_limits, self.y_axis_limits = _parse_axes_limits(
min_x, max_x, min_y, max_y, x_axis_limits, y_axis_limits
)
def render(
self,
render_lines=True,
line_colour="r",
line_style="-",
line_width=1,
render_markers=True,
marker_style="o",
marker_size=6,
marker_face_colour="r",
marker_edge_colour="k",
marker_edge_width=1.0,
render_legend=True,
legend_title="",
legend_font_name="sans-serif",
legend_font_style="normal",
legend_font_size=10,
legend_font_weight="normal",
legend_marker_scale=None,
legend_location=2,
legend_bbox_to_anchor=(1.05, 1.0),
legend_border_axes_pad=None,
legend_n_columns=1,
legend_horizontal_spacing=None,
legend_vertical_spacing=None,
legend_border=True,
legend_border_padding=None,
legend_shadow=False,
legend_rounded_corners=False,
render_axes=True,
axes_font_name="sans-serif",
axes_font_size=10,
axes_font_style="normal",
axes_font_weight="normal",
figure_size=(7, 7),
render_grid=True,
grid_line_style="--",
grid_line_width=1,
):
import matplotlib.pyplot as plt
# Check the viewer options that can be different for each plotted curve
n_curves = len(self.y_axis)
render_lines = _check_render_flag(
render_lines,
n_curves,
"Must pass a list of different "
"render_lines flag for each curve or "
"a single render_lines flag for all "
"curves.",
)
render_markers = _check_render_flag(
render_markers,
n_curves,
"Must pass a list of different "
"render_markers flag for each "
"curve or a single render_markers "
"flag for all curves.",
)
line_colour = _check_colours_list(
True,
line_colour,
n_curves,
"Must pass a list of line colours with length n_curves or a single "
"line colour for all curves.",
)
line_style = _check_colours_list(
True,
line_style,
n_curves,
"Must pass a list of line styles with length n_curves or a single "
"line style for all curves.",
)
line_width = _check_colours_list(
True,
line_width,
n_curves,
"Must pass a list of line widths with length n_curves or a single "
"line width for all curves.",
)
marker_style = _check_colours_list(
True,
marker_style,
n_curves,
"Must pass a list of marker styles with length n_curves or a "
"single marker style for all curves.",
)
marker_size = _check_colours_list(
True,
marker_size,
n_curves,
"Must pass a list of marker sizes with length n_curves or a single "
"marker size for all curves.",
)
marker_face_colour = _check_colours_list(
True,
marker_face_colour,
n_curves,
"Must pass a list of marker face colours with length n_curves or a "
"single marker face colour for all curves.",
)
marker_edge_colour = _check_colours_list(
True,
marker_edge_colour,
n_curves,
"Must pass a list of marker edge colours with length n_curves or a "
"single marker edge colour for all curves.",
)
marker_edge_width = _check_colours_list(
True,
marker_edge_width,
n_curves,
"Must pass a list of marker edge widths with length n_curves or a "
"single marker edge width for all curves.",
)
# plot all curves
ax = plt.gca()
for i, y in enumerate(self.y_axis):
linestyle = line_style[i]
if not render_lines[i]:
linestyle = "None"
marker = marker_style[i]
if not render_markers[i]:
marker = "None"
plt.plot(
self.x_axis,
y,
color=line_colour[i],
linestyle=linestyle,
linewidth=line_width[i],
marker=marker,
markeredgecolor=marker_edge_colour[i],
markerfacecolor=marker_face_colour[i],
markeredgewidth=marker_edge_width[i],
markersize=marker_size[i],
label=self.legend_entries[i],
)
# set legend
_set_legend(
ax,
legend_handles=None,
render_legend=render_legend,
legend_title=legend_title,
legend_font_name=legend_font_name,
legend_font_style=legend_font_style,
legend_font_size=legend_font_size,
legend_font_weight=legend_font_weight,
legend_marker_scale=legend_marker_scale,
legend_location=legend_location,
legend_bbox_to_anchor=legend_bbox_to_anchor,
legend_border_axes_pad=legend_border_axes_pad,
legend_n_columns=legend_n_columns,
legend_horizontal_spacing=legend_horizontal_spacing,
legend_vertical_spacing=legend_vertical_spacing,
legend_border=legend_border,
legend_border_padding=legend_border_padding,
legend_shadow=legend_shadow,
legend_rounded_corners=legend_rounded_corners,
)
# set axes options
_set_axes_options(
ax,
render_axes=render_axes,
inverted_y_axis=False,
axes_font_name=axes_font_name,
axes_font_size=axes_font_size,
axes_font_style=axes_font_style,
axes_font_weight=axes_font_weight,
axes_x_limits=self.x_axis_limits,
axes_y_limits=self.y_axis_limits,
axes_x_ticks=self.x_axis_ticks,
axes_y_ticks=self.y_axis_ticks,
axes_x_label=self.x_label,
axes_y_label=self.y_label,
title=self.title,
)
# set grid options
_set_grid_options(
render_grid=render_grid,
grid_line_style=grid_line_style,
grid_line_width=grid_line_width,
)
# set figure size
_set_figure_size(self.figure, figure_size)
return self
class MatplotlibMultiImageViewer2d(MatplotlibRenderer):
def __init__(self, figure_id, new_figure, image_list):
super(MatplotlibMultiImageViewer2d, self).__init__(figure_id, new_figure)
self.image_list = image_list
def render(self, interval=50, **kwargs):
import matplotlib.pyplot as plt
import matplotlib.cm as cm
import matplotlib.animation as animation
if len(self.image_list[0].shape) == 2:
# Single channels are viewed in Gray
_ax = plt.imshow(self.image_list[0], cmap=cm.Greys_r, **kwargs)
else:
_ax = plt.imshow(self.image_list[0], **kwargs)
def init():
return (_ax,)
def animate(j):
_ax.set_data(self.image_list[j])
return (_ax,)
self._ani = animation.FuncAnimation(
self.figure,
animate,
init_func=init,
frames=len(self.image_list),
interval=interval,
blit=True,
)
return self
class MatplotlibMultiImageSubplotsViewer2d(MatplotlibRenderer, MatplotlibSubplots):
def __init__(self, figure_id, new_figure, image_list):
super(MatplotlibMultiImageSubplotsViewer2d, self).__init__(
figure_id, new_figure
)
self.image_list = image_list
self.num_subplots = self.image_list[0].shape[2]
self.plot_layout = self._subplot_layout(self.num_subplots)
def render(self, interval=50, **kwargs):
import matplotlib.cm as cm
import matplotlib.animation as animation
import matplotlib.pyplot as plt
p = self.plot_layout
_axs = []
for i in range(self.image_list[0].shape[2]):
plt.subplot(p[0], p[1], 1 + i)
# Hide the x and y labels
plt.axis("off")
_ax = plt.imshow(self.image_list[0][:, :, i], cmap=cm.Greys_r, **kwargs)
_axs.append(_ax)
def init():
return _axs
def animate(j):
for k, _ax in enumerate(_axs):
_ax.set_data(self.image_list[j][:, :, k])
return _axs
self._ani = animation.FuncAnimation(
self.figure,
animate,
init_func=init,
frames=len(self.image_list),
interval=interval,
blit=True,
)
return self
def sample_colours_from_colourmap(n_colours, colour_map):
import matplotlib.pyplot as plt
cm = plt.get_cmap(colour_map)
return [cm(1.0 * i / n_colours)[:3] for i in range(n_colours)]
def _check_colours_list(render_flag, colours_list, n_objects, error_str):
if render_flag:
if colours_list is None:
# sample colours from jet colour map
colours_list = sample_colours_from_colourmap(n_objects, GLOBAL_CMAP)
if isinstance(colours_list, list):
if len(colours_list) == 1:
colours_list *= n_objects
elif len(colours_list) != n_objects:
raise ValueError(error_str)
else:
colours_list = [colours_list] * n_objects
else:
colours_list = [None] * n_objects
return colours_list
def _check_render_flag(render_flag, n_objects, error_str):
if isinstance(render_flag, bool):
render_flag = [render_flag] * n_objects
elif isinstance(render_flag, list):
if len(render_flag) == 1:
render_flag *= n_objects
elif len(render_flag) != n_objects:
raise ValueError(error_str)
else:
raise ValueError(error_str)
return render_flag