Source code for menpo.transform.groupalign.procrustes

import numpy as np

from menpo.transform import AlignmentSimilarity, UniformScale, Translation
from .base import MultipleAlignment


class GeneralizedProcrustesAnalysis(MultipleAlignment):
    r"""
    Class for aligning multiple source shapes between them.

    After construction, the :map:`AlignmentSimilarity` transforms used to map
    each source optimally to the target can be found at `transforms`.

    The alignment is refined iteratively. On each pass the currently aligned
    sources are averaged. The average is rescaled about its centre so that
    its norm matches the norm of the initial target. Every source is then
    re-aligned to this new target. Iteration stops once the change between
    successive targets (``np.linalg.norm`` of the difference of their points)
    drops below `tolerance`, or after `max_iterations` passes.

    Parameters
    ----------
    sources : list of :map:`PointCloud`
        List of pointclouds to be aligned.
    target : :map:`PointCloud`, optional
        The target :map:`PointCloud` to align each source to.
        If None, then the mean of the sources is used.

        Default: None
    max_iterations : int, optional
        Maximum number of averaging / re-alignment passes to attempt before
        giving up (``converged`` is then False).

        Default: 100
    tolerance : float, optional
        Convergence threshold on the change between successive targets.

        Default: 1e-6

    Raises
    ------
    ValueError
        Need at least two sources to align
    """
    def __init__(self, sources, target=None, max_iterations=100,
                 tolerance=1e-6):
        super(GeneralizedProcrustesAnalysis, self).__init__(sources,
                                                            target=target)
        # Remember the target in effect before iterating so that a
        # user-supplied target can be restored once alignment has finished.
        initial_target = self.target
        self.transforms = [AlignmentSimilarity(source, self.target)
                           for source in self.sources]
        # Every intermediate target is rescaled back to this norm so the
        # average shape cannot shrink or grow from pass to pass.
        self.initial_target_scale = self.target.norm()
        self.n_iterations = 1
        self.max_iterations = max_iterations
        self.tolerance = tolerance
        self.converged = self._recursive_procrustes()
        if target is not None:
            self.target = initial_target

    def _recursive_procrustes(self):
        r"""
        Iteratively refines the shared target until it stops changing.

        This method runs as a loop rather than by recursion, although the name
        is kept for backward compatibility. As a result, a large
        `max_iterations` cannot exhaust Python's recursion limit.

        It updates ``self.n_iterations``, ``self.target`` and the target of
        every transform in ``self.transforms`` in place.

        Returns
        -------
        converged : bool
            True if successive targets differed by less than
            ``self.tolerance`` within ``self.max_iterations`` passes, False
            otherwise.
        """
        # Imported locally rather than at module level; presumably this avoids
        # a circular import with menpo.shape -- NOTE(review): confirm.
        from menpo.shape import PointCloud
        while self.n_iterations <= self.max_iterations:
            av_aligned_source = sum(
                t.aligned_source.points
                for t in self.transforms) / self.n_sources
            new_target = PointCloud(av_aligned_source)
            # rescale the new_target to be the same size as the original about
            # its centre
            rescale = UniformScale(
                self.initial_target_scale / new_target.norm(), self.n_dims)
            centre = Translation(-new_target.centre)
            rescale_about_centre = centre.compose_before(
                rescale).compose_before(centre.pseudoinverse)
            rescale_about_centre.apply_inplace(new_target)
            # check to see if we have converged yet
            delta_target = np.linalg.norm(
                self.target.points - new_target.points)
            if delta_target < self.tolerance:
                return True
            self.n_iterations += 1
            for t in self.transforms:
                t.set_target(new_target)
            self.target = new_target
        return False

    @property
    def mean_aligned_shape(self):
        r"""
        Returns the mean of the aligned shapes.

        NOTE(review): this averages each transform's ``target``, and this
        class gives every transform the same target. The result is therefore
        the shared target, not the mean of ``t.aligned_source``. Confirm this
        is the intended quantity.

        :type: PointCloud
        """
        from menpo.shape import PointCloud
        return PointCloud(np.mean([t.target.points for t in self.transforms],
                                  axis=0))

    @property
    def av_alignment_error(self):
        r"""
        Returns the average error of the recursive procrustes alignment.

        This is the ``alignment_error`` of each transform, summed and divided
        by ``n_sources``.

        :type: float
        """
        return sum(t.alignment_error for t in self.transforms) / self.n_sources

    def __str__(self):
        if self.converged:
            return ('Converged after %d iterations with av. error %f'
                    % (self.n_iterations, self.av_alignment_error))
        else:
            return ('Failed to converge after %d iterations with av. error '
                    '%f' % (self.n_iterations, self.av_alignment_error))