# Source code for menpo.fitmultilevel.fittingresult

from __future__ import division
import numpy as np

from menpo.transform import Scale
from menpo.visualize.base import GraphPlotter, MultipleImageViewer
from menpo.fit.fittingresult import FittingResult


class MultilevelFittingResult(FittingResult):
    r"""
    Object that holds the state of a multilevel fitter object (to which it
    is linked) after it has fitted a particular image.

    The result aggregates one basic fitting result per level and maps the
    shapes they hold back to the scale of the original image.

    Parameters
    -----------
    image: :class:`menpo.image.masked.MaskedImage`
        The fitted image.
    multiple_fitter: :class:`menpo.fitter.base.Fitter`
        The fitter object used to fit the image.
    fitting_results: :class:`menpo.fitter.fittingresult.FittingResult` list
        A list of basic fitting objects, one per level.
    affine_correction: :class:`menpo.transforms.affine.Affine`
        An affine transform that maps the result of the top resolution
        fitting level to the space scale of the original image.
    gt_shape: :class:`menpo.shape.PointCloud`, optional
        The ground truth shape associated to the image.

        Default: None
    error_type: 'me_norm', 'me' or 'rmse', optional
        Specifies the way in which the error between the fitted and
        ground truth shapes is to be computed. Only 'me_norm' is currently
        implemented.

        Default: 'me_norm'
    """

    def __init__(self, image, multiple_fitter, fitting_results,
                 affine_correction, gt_shape=None, error_type='me_norm'):
        # These attributes are assigned before delegating to the base
        # class. NOTE(review): presumably the base initializer assigns
        # ``error_type``, whose setter below iterates ``fitting_results``
        # -- confirm against FittingResult.__init__.
        self._error_stop = None  # Internal attribute of error_type setter
        self.fitting_results = fitting_results
        self._affine_correction = affine_correction
        super(MultilevelFittingResult, self).__init__(
            image, multiple_fitter, gt_shape=gt_shape,
            error_type=error_type)

    @property
    def n_levels(self):
        r"""
        Returns the number of levels of the fitter object.
        """
        return self.fitter.n_levels

    @property
    def downscale(self):
        r"""
        Returns the downscale factor used by the multiple fitter.
        """
        return self.fitter.downscale

    @property
    def scaled_levels(self):
        r"""
        Returns True if the shape results returned by the basic
        fitting_results must be scaled.

        The value is currently hard-coded to True rather than read from
        the fitter.
        """
        return True  # self.fitter.scaled_levels

    @property
    def fitted(self):
        r"""
        Returns the fitted state of each fitting object.
        """
        return [f.fitted for f in self.fitting_results]

    @FittingResult.error_type.setter
    def error_type(self, error_type):
        r"""
        Sets the error type according to a set of predefined options and
        propagates it to every per-level fitting result.

        Parameters
        -----------
        error_type: 'me_norm', 'me' or 'rmse'
            The error measure to use.

        Raises
        ------
        NotImplementedError
            If ``error_type`` is 'me' or 'rmse', which are recognised but
            not implemented yet.
        ValueError
            If ``error_type`` is not one of the predefined options.
        """
        if error_type == 'me_norm':
            for f in self.fitting_results:
                f.error_type = error_type
            self._error_stop = 0.1
            self._error_text = 'Point-to-point error normalized by object ' \
                               'size'
        elif error_type == 'me':
            # The exception used to be constructed but never raised, so an
            # unsupported error type was silently accepted.
            raise NotImplementedError(
                "error_type 'me' not implemented yet")
        elif error_type == 'rmse':
            raise NotImplementedError(
                "error_type 'rmse' not implemented yet")
        else:
            raise ValueError("Unknown error_type string selected. Valid "
                             "options are: 'me_norm', 'me', 'rmse'")
        self._error_type = error_type

    @property
    def n_iters(self):
        r"""
        Returns the total number of iterations used to fit the image,
        i.e. the sum of the iterations of every level.
        """
        return sum(f.n_iters for f in self.fitting_results)

    def shapes(self, as_points=False):
        r"""
        Generates a list containing the shapes obtained at each fitting
        iteration, over all levels, in level order.

        When ``scaled_levels`` is True, the shapes of level ``j`` are first
        scaled by ``downscale ** (n_levels - 1 - j)``; every shape is then
        passed through the affine correction.

        Parameters
        -----------
        as_points: boolean, optional
            Whether the results is returned as a list of PointClouds or
            ndarrays. The flag is forwarded unchanged to each level's
            ``shapes`` method.

            Default: False

        Returns
        -------
        shapes: :class:`menpo.shape.PointCloud` or ndarray list
            A list containing the shapes obtained at each fitting
            iteration.
        """
        n = self.n_levels - 1
        shapes = []
        for j, f in enumerate(self.fitting_results):
            level_shapes = f.shapes(as_points=as_points)
            if self.scaled_levels:
                # The number of dimensions is hard-coded to 2 here.
                transform = Scale(self.downscale ** (n - j), 2)
                # Non-mutating ``apply`` so the objects handed out by the
                # level's fitting result are left untouched.
                level_shapes = [transform.apply(t) for t in level_shapes]
            shapes.extend(self._affine_correction.apply(t)
                          for t in level_shapes)
        return shapes

    @property
    def final_shape(self):
        r"""
        Returns the final fitted shape, i.e. the final shape of the last
        level after applying the affine correction.
        """
        return self._affine_correction.apply(
            self.fitting_results[-1].final_shape)

    @property
    def initial_shape(self):
        r"""
        Returns the initial shape from which the fitting started, i.e. the
        initial shape of the first level, scaled (when ``scaled_levels`` is
        True) by ``downscale ** (n_levels - 1)`` and passed through the
        affine correction.
        """
        n = self.n_levels - 1
        initial_shape = self.fitting_results[0].initial_shape
        if self.scaled_levels:
            # Use the non-mutating ``apply``: scaling in place would alter
            # the object returned by the first level's result every time
            # this property is read.
            initial_shape = Scale(self.downscale ** n,
                                  initial_shape.n_dims).apply(initial_shape)
        return self._affine_correction.apply(initial_shape)

    @FittingResult.gt_shape.setter
    def gt_shape(self, value):
        r"""
        Setter for the ground truth shape associated to the image.
        """
        self._gt_shape = value

    def plot_error(self, figure_id=None, new_figure=False, **kwargs):
        r"""
        Plots the error evolution throughout the fitting.

        Parameters
        -----------
        figure_id: optional
            Forwarded to :class:`GraphPlotter`.
        new_figure: boolean, optional
            Forwarded to :class:`GraphPlotter`.
        kwargs:
            Forwarded to the ``render`` call of the plotter.

        Returns
        -------
        The value returned by ``GraphPlotter(...).render(**kwargs)``.

        Raises
        ------
        ValueError
            If the ground truth shape has not been set.
        """
        if self.gt_shape is None:
            raise ValueError('Ground truth shape has not been set, error '
                             'cannot be plotted')

        title = 'Error evolution'
        legend = [self.algorithm]
        x_label = 'Number of iterations'
        y_label = self._error_text
        errors = self.errors
        # NOTE(review): presumably one extra point per level accounts for
        # each level's starting shape -- confirm against ``errors``.
        x_limit = self.n_iters + self.n_levels
        axis_limits = [0, x_limit, 0, np.max(errors)]
        return GraphPlotter(figure_id, new_figure, range(0, x_limit),
                            [errors], title=title, legend=legend,
                            x_label=x_label, y_label=y_label,
                            axis_limits=axis_limits).render(**kwargs)

    def __str__(self):
        out = "Initial error: {0:.4f}\nFinal error: {1:.4f}".format(
            self.initial_error, self.final_error)
        return out
class AAMMultilevelFittingResult(MultilevelFittingResult):
    r"""
    Object that lets us recover the state of an AAM fitter after the latter
    has fitted a particular image.

    No initializer is defined here, so construction takes the same
    arguments as :class:`MultilevelFittingResult`: the fitted image, the
    fitter, the list of per-level fitting results, the affine correction
    and, optionally, the ground truth shape and the error type.
    """

    @property
    def residual(self):
        r"""
        Returns the type of residual reported by the last per-level
        fitting result.
        """
        # TODO: ensure that all basic_fitting residuals are the same?
        last_result = self.fitting_results[-1]
        return last_result.residual.type

    @property
    def costs(self):
        r"""
        Intended to return a list containing the cost at each fitting
        iteration. Not implemented: accessing it always raises.

        Raises
        ------
        ValueError
            Always.
        """
        raise ValueError('costs not implemented yet.')

    @staticmethod
    def _flatten_out(list_of_lists):
        # Concatenate the inner iterables, in order, into one flat list.
        flat = []
        for inner in list_of_lists:
            flat.extend(inner)
        return flat

    @property
    def final_cost(self):
        r"""
        Returns the final fitting cost, taken from the last level.
        """
        last_result = self.fitting_results[-1]
        return last_result.final_cost

    @property
    def initial_cost(self):
        r"""
        Returns the initial fitting cost, taken from the first level.
        """
        first_result = self.fitting_results[0]
        return first_result.initial_cost

    def warped_images(self, from_basic_fittings=False, as_pixels=False):
        r"""
        Generates a list containing the warped images obtained at each
        fitting iteration.

        Parameters
        -----------
        from_basic_fittings : `boolean`, optional
            If ``True``, the warped images are collected from each
            per-level fitting result and concatenated. If ``False``, the
            original image is warped once per shape returned by
            ``shapes()``, using the mask, transform and interpolator of
            the last level's fitter.
        as_pixels : `boolean`, optional
            Whether the result is returned as a list of :map:`Image` or
            `ndarray`.

        Returns
        -------
        warped_images : :map:`MaskedImage` or `ndarray` list
            A list containing the warped images obtained at each fitting
            iteration.
        """
        if from_basic_fittings:
            per_level = [result.warped_images(as_pixels=as_pixels)
                         for result in self.fitting_results]
            return self._flatten_out(per_level)

        top_fitter = self.fitting_results[-1].fitter
        mask = top_fitter.template.mask
        transform = top_fitter.transform
        interpolator = top_fitter.interpolator

        warped = []
        for shape in self.shapes():
            # The same transform object is re-targeted for every shape.
            transform.set_target(shape)
            frame = self.image.warp_to(mask, transform,
                                       interpolator=interpolator)
            warped.append(frame.pixels if as_pixels else frame)
        return warped

    def appearance_reconstructions(self, as_pixels=False):
        r"""
        Generates a list containing the appearance reconstruction obtained
        at each fitting iteration, concatenated over all levels.

        Parameters
        ----------
        as_pixels : `boolean`, optional
            Whether the result is returned as a list of :map:`Image` or
            `ndarray`.

        Returns
        -------
        appearance_reconstructions : :map:`MaskedImage` or `ndarray` list
            A list containing the appearance reconstructions obtained at
            each fitting iteration.
        """
        per_level = [result.appearance_reconstructions(as_pixels=as_pixels)
                     for result in self.fitting_results]
        return self._flatten_out(per_level)

    def plot_cost(self, figure_id=None, new_figure=False, **kwargs):
        r"""
        Plots the cost evolution throughout the fitting.

        Reads the ``costs`` property, which currently always raises
        ``ValueError``.
        """
        title = 'Cost evolution'
        legend = self.algorithm
        x_label = 'Number of iterations'
        y_label = 'Normalized cost'
        costs = []
        for level_costs in self.costs:
            costs.extend(level_costs)
        total_n_iters = self.n_iters + self.n_levels
        axis_limits = [0, total_n_iters, 0, max(costs)]
        x_axis = range(0, self.n_iters + self.n_levels)
        plotter = GraphPlotter(figure_id, new_figure, x_axis, costs,
                               title=title, legend=legend,
                               x_label=x_label, y_label=y_label,
                               axis_limits=axis_limits)
        return plotter.render(**kwargs)

    def view_warped_images(self, figure_id=None, new_figure=False,
                           channels=None, from_basic_fittings=False,
                           **kwargs):
        r"""
        Displays the warped images.
        """
        pixels_list = self.warped_images(
            from_basic_fittings=from_basic_fittings, as_pixels=True)
        return self._view_images(pixels_list, figure_id=figure_id,
                                 new_figure=new_figure, channels=channels,
                                 **kwargs)

    def view_appearance_reconstructions(self, figure_id=None,
                                        new_figure=False, channels=None,
                                        **kwargs):
        r"""
        Displays the appearance reconstructions.
        """
        pixels_list = self.appearance_reconstructions(as_pixels=True)
        return self._view_images(pixels_list, figure_id=figure_id,
                                 new_figure=new_figure, channels=channels,
                                 **kwargs)

    def view_error_images(self, figure_id=None, new_figure=False,
                          channels=None, **kwargs):
        r"""
        Displays the error images, computed as appearance reconstruction
        minus warped image for each pair of corresponding entries.
        """
        warped = self.warped_images(as_pixels=True)
        reconstructions = self.appearance_reconstructions(as_pixels=True)
        pixels_list = [reconstruction - image
                       for reconstruction, image
                       in zip(reconstructions, warped)]
        return self._view_images(pixels_list, figure_id=figure_id,
                                 new_figure=new_figure, channels=channels,
                                 **kwargs)

    def _view_images(self, pixels_list, figure_id=None, new_figure=False,
                     channels=None, **kwargs):
        # Shared rendering helper for the view_* methods.
        viewer = MultipleImageViewer(figure_id, new_figure,
                                     self.image.n_dims, pixels_list,
                                     channels=channels)
        return viewer.render(**kwargs)

    def print_fitting_info(self):
        # Only the base class information is printed; printing of the
        # initial and final cost is currently disabled.
        super(AAMMultilevelFittingResult, self).print_fitting_info()