# Source code for menpo.transform.piecewiseaffine.base

import numpy as np
from copy import deepcopy
from menpo.base import Copyable
from menpo.transform.base import Alignment, Invertible, Transform

# TODO View is broken for PWA (TriangleContainmentError)


class TriangleContainmentError(Exception):
    r"""
    Exception that is thrown when an attempt is made to map a point with a
    PWATransform that does not lie in a source triangle.

    Parameters
    ----------
    points_outside_source_domain : ``(d,)`` `ndarray`
        A `bool` value for the ``d`` points that were attempted to be applied.
        If ``True``, the point was outside of the domain.
    """

    def __init__(self, points_outside_source_domain):
        super(TriangleContainmentError, self).__init__()
        self.points_outside_source_domain = points_outside_source_domain

    def __reduce__(self):
        # The base ``Exception`` pickles itself as ``cls(*self.args)``, but
        # ``args`` is empty here, so unpickling would call ``__init__`` with
        # no arguments and fail. Rebuild from the mask (plus any extra
        # instance state) instead.
        return (
            type(self),
            (self.points_outside_source_domain,),
            self.__dict__,
        )

    def __str__(self):
        # ``args`` is empty, so the default message would be blank; report
        # how many points fell outside of the source domain instead.
        n_outside = int(np.count_nonzero(self.points_outside_source_domain))
        return "{} point(s) lie outside of the source domain".format(n_outside)


def containment_from_alpha_beta(alpha, beta):
    r"""
    Check `alpha` and `beta` are within a triangle (``alpha >= 0``,
    ``beta >= 0``, ``alpha + beta <= 1``). Returns, for every point, the index
    of a triangle that contains it. If any of the points are not contained in
    a triangle, raises a `TriangleContainmentError`.

    When a point is contained by more than one triangle (e.g. it lies on an
    edge or vertex shared by several triangles), the containing triangle with
    the highest index is returned.

    Parameters
    ----------
    alpha : ``(K, n_tris)`` `ndarray`
        Alpha for each point and triangle being tested.
    beta : ``(K, n_tris)`` `ndarray`
        Beta for each point and triangle being tested.

    Returns
    -------
    tri_index : ``(K,)`` `ndarray` of `uint32`
        Triangle index for each of the ``K`` points, assigning each point to
        a triangle that contains it.

    Raises
    ------
    TriangleContainmentError
        All `points` must be contained in a source triangle. Check
        `error.points_outside_source_domain` to handle this case.
    """
    # (K, n_tris), boolean for whether a given triangle contains a given point
    point_containment = np.logical_and(
        np.logical_and(alpha >= 0, beta >= 0), alpha + beta <= 1
    )
    # (K,), True where a point lies in no triangle at all
    point_outside = ~np.any(point_containment, axis=1)
    if np.any(point_outside):
        raise TriangleContainmentError(point_outside)

    n_points, n_tris = point_containment.shape
    if n_points == 0:
        # argmax cannot reduce over an empty triangle axis; nothing to index.
        return np.zeros(0, dtype=np.uint32)
    # Pick exactly one triangle per point. Searching the reversed triangle
    # axis selects the highest-index containing triangle deterministically,
    # which fancy assignment with repeated indices does not guarantee.
    last_containing = n_tris - 1 - np.argmax(point_containment[:, ::-1], axis=1)
    return last_containing.astype(np.uint32)


def alpha_beta(i, ij, ik, points):
    r"""
    Calculates the `alpha` and `beta` values (barycentric coordinates) for each
    triangle for all points provided. Note that this does not raise a
    `TriangleContainmentError`.

    The triangle vectors are laid out with the coordinate axis first and the
    triangle axis last, as returned by `barycentric_vectors`.

    Parameters
    ----------
    i : ``(n_dims, n_tris)`` `ndarray`
        The coordinate of the i'th point of each triangle
    ij : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the j'th point of each
        triangle
    ik : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the k'th point of each
        triangle
    points : ``(n_points, n_dims)`` `ndarray`
        Points to calculate the barycentric coordinates for.

    Returns
    -------
    alpha : ``(n_points, n_tris)`` `ndarray`
        The `alpha` for each point and triangle. Alpha can be interpreted
        as the contribution of the `ij` vector to the position of the point in
        question.
    beta : ``(n_points, n_tris)`` `ndarray`
        The beta for each point and triangle. Beta can be interpreted as
        the contribution of the ik vector to the position of the point in
        question.
    """
    # (n_points, n_dims, n_tris): vector from each triangle's i'th vertex to
    # each point
    ip = points[..., None] - i
    # (n_tris,) Gram-matrix entries of each triangle's (ij, ik) edge basis
    dot_jj = np.einsum("dt, dt -> t", ij, ij)
    dot_kk = np.einsum("dt, dt -> t", ik, ik)
    dot_jk = np.einsum("dt, dt -> t", ij, ik)
    # (n_points, n_tris) projections of each point vector onto each edge
    dot_pj = np.einsum("vdt, dt -> vt", ip, ij)
    dot_pk = np.einsum("vdt, dt -> vt", ip, ik)

    # Solve [[jj, jk], [jk, kk]] @ [alpha, beta] = [pj, pk] by Cramer's rule.
    # A zero-area triangle has a zero determinant, so its coordinates come
    # out non-finite rather than raising.
    d = 1.0 / (dot_jj * dot_kk - dot_jk * dot_jk)
    alpha = (dot_kk * dot_pj - dot_jk * dot_pk) * d
    beta = (dot_jj * dot_pk - dot_jk * dot_pj) * d
    return alpha, beta


def index_alpha_beta(i, ij, ik, points):
    r"""
    Finds for each input point the index of its bounding triangle and the
    `alpha` and `beta` value for that point in the triangle. Note this means
    that the following statements will always be true::

        alpha + beta <= 1
        alpha >= 0
        beta >= 0

    for each triangle result.

    Trying to map a point that does not exist in a triangle throws a
    `TriangleContainmentError`.

    Parameters
    ----------
    i : ``(n_dims, n_tris)`` `ndarray`
        The coordinate of the i'th point of each triangle
    ij : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the j'th point of each
        triangle
    ik : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the k'th point of each
        triangle
    points : ``(n_points, n_dims)`` `ndarray`
        Points to calculate the barycentric coordinates for.

    Returns
    -------
    tri_index : ``(n_points,)`` `ndarray`
        Triangle index for each of the `points`, assigning each point to its
        containing triangle.
    alpha : ``(n_points,)`` `ndarray`
        Alpha for containing triangle of each point.
    beta : ``(n_points,)`` `ndarray`
        Beta for containing triangle of each point.

    Raises
    ------
    TriangleContainmentError
        All `points` must be contained in a source triangle. Check
        `error.points_outside_source_domain` to handle this case.
    """
    # (n_points, n_tris) barycentric coordinates of every point in every
    # triangle
    alpha, beta = alpha_beta(i, ij, ik, points)
    tri_index = containment_from_alpha_beta(alpha, beta)
    # For each point (row), pick the column of its containing triangle
    each_point = np.arange(points.shape[0])
    return tri_index, alpha[each_point, tri_index], beta[each_point, tri_index]


def barycentric_vectors(points, trilist):
    r"""
    Compute, for every triangle of a triangulated point set, the position of
    its i'th vertex and the two edge vectors leaving it. These are the inputs
    needed to compute barycentric coordinates within each triangle.

    The results are laid out with the coordinate axis first and the triangle
    axis last, so that column ``t`` describes triangle ``t``.

    Parameters
    ----------
    points : ``(n_points, n_dims)`` `ndarray`
        The vertices of the triangulation.
    trilist: ``(n_tris, 3)`` `ndarray`
        The 0-based index triangulation joining the points.

    Returns
    -------
    i : ``(n_dims, n_tris)`` `ndarray`
        The coordinate of the i'th point of each triangle
    ij : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the j'th point of each
        triangle
    ik : ``(n_dims, n_tris)`` `ndarray`
        The vector between the i'th point and the k'th point of each
        triangle
    """
    # points[trilist] is (n_tris, 3, n_dims); permute it to
    # (3, n_dims, n_tris) so each vertex slot indexes out directly.
    x = np.transpose(points[trilist], axes=[1, 2, 0])
    return x[0], x[1] - x[0], x[2] - x[0]


# Note: we inherit from Alignment first to get its n_dims behavior
class AbstractPWA(Alignment, Transform, Invertible):
    r"""
    A piecewise affine transformation.

    This is composed of a number of triangles defined by a set of `source` and
    `target` vertices. These vertices are related by a common triangle `list`.
    No limitations on the nature of the triangle `list` are imposed. Points can
    then be mapped via barycentric coordinates from the `source` to the `target`
    space. Trying to map points that are not contained by any source triangle
    throws a `TriangleContainmentError`, which contains diagnostic information.

    Subclasses must implement :meth:`index_alpha_beta`.

    Parameters
    ----------
    source : :map:`PointCloud` or :map:`TriMesh`
        The source points. If a TriMesh is provided, the triangulation on
        the TriMesh is used. If a PointCloud is provided, a Delaunay
        triangulation of the source is performed automatically.
    target : :map:`PointCloud`
        The target points. Note that the trilist is entirely decided by the
        source.

    Raises
    ------
    ValueError
        Source and target must both be 2D.
    TriangleContainmentError
        All points to apply must be contained in a source triangle. Check
        `error.points_outside_source_domain` to handle this case.
    """

    def __init__(self, source, target):
        from menpo.shape import TriMesh  # to avoid circular import

        if not isinstance(source, TriMesh):
            source = TriMesh(source.points)
        Alignment.__init__(self, source, target)
        if self.n_dims != 2:
            raise ValueError("source and target must be 2 " "dimensional")
        # Target triangle data used by _apply, filled in just below.
        self.ti, self.tij, self.tik = None, None, None
        self._rebuild_target_vectors()

    @property
    def n_tris(self):
        r"""
        The number of triangles in the triangle list.

        :type: `int`
        """
        return self.source.n_tris

    @property
    def trilist(self):
        r"""
        The triangle list.

        :type: ``(n_tris, 3)`` `ndarray`
        """
        return self.source.trilist

    def _rebuild_target_vectors(self):
        r"""
        Rebuild the vectors that are used in the apply method. This needs to
        be called whenever the target is changed.
        """
        # (n_tris, 3, 2): the three target vertices of every triangle. Unlike
        # the source barycentric vectors, these are stored triangle-major,
        # ``(n_tris, 2)``, so they can be gathered by triangle index.
        t = self.target.points[self.trilist]
        # get vectors ij ik for the target
        self.tij, self.tik = t[:, 1] - t[:, 0], t[:, 2] - t[:, 0]
        # target i'th vertex positions
        self.ti = t[:, 0]

    def _sync_state_from_target(self):
        r"""
        PWA is particularly efficient to sync from target - we don't have to
        do much at all, just rebuild the target vectors.
        """
        self._rebuild_target_vectors()

    def _apply(self, x, **kwargs):
        r"""
        Applies this transform to a new set of vectors.

        Parameters
        ----------
        x : ``(K, 2)`` `ndarray`
            Points to apply this transform to.

        Returns
        -------
        transformed : ``(K, 2)`` `ndarray`
            The transformed array.
        """
        # Locate each point in the source, then rebuild it from the same
        # barycentric coordinates in the corresponding target triangle.
        tri_index, alpha, beta = self.index_alpha_beta(x)
        return (
            self.ti[tri_index]
            + alpha[:, None] * self.tij[tri_index]
            + beta[:, None] * self.tik[tri_index]
        )

    def _apply_batched(self, x, batch_size, **kwargs):
        # This is a rare case where we need to override the batched apply
        # method: we want a possibly raised TriangleContainmentError to
        # describe ALL the points that were considered, and not just those of
        # the first failing batch.
        n_points = x.shape[0]
        if batch_size is None or n_points == 0:
            # Nothing to batch. An empty input must not reach the loop, which
            # would leave no outputs for np.vstack to join.
            return self._apply(x, **kwargs)

        outputs = []
        points_outside_source_domain = []
        exception_thrown = False
        for lo_ind in range(0, n_points, batch_size):
            batch = x[lo_ind:lo_ind + batch_size]
            try:
                outputs.append(self._apply(batch, **kwargs))
            except TriangleContainmentError as e:
                exception_thrown = True
                points_outside_source_domain.append(e.points_outside_source_domain)
            else:
                # No exception was thrown, so all points were inside. Size
                # the mask by the actual batch: the final one may be short.
                points_outside_source_domain.append(
                    np.zeros(batch.shape[0], dtype=bool)
                )

        if exception_thrown:
            raise TriangleContainmentError(np.hstack(points_outside_source_domain))
        return np.vstack(outputs)

    def index_alpha_beta(self, points):
        r"""
        Finds for each input point the index of its bounding triangle and the
        `alpha` and `beta` value for that point in the triangle. Note this
        means that the following statements will always be true::

            alpha + beta <= 1
            alpha >= 0
            beta >= 0

        for each triangle result.

        Trying to map a point that does not exist in a triangle throws a
        `TriangleContainmentError`.

        Parameters
        ----------
        points : ``(K, 2)`` `ndarray`
            Points to test.

        Returns
        -------
        tri_index : ``(K,)`` `ndarray`
            Triangle index for each of the `points`, assigning each
            point to its containing triangle.
        alpha : ``(K,)`` `ndarray`
            Alpha for containing triangle of each point.
        beta : ``(K,)`` `ndarray`
            Beta for containing triangle of each point.

        Raises
        ------
        TriangleContainmentError
            All `points` must be contained in a source triangle. Check
            `error.points_outside_source_domain` to handle this case.
        """
        raise NotImplementedError()

    @property
    def has_true_inverse(self):
        r"""
        The inverse is true.

        :type: ``True``
        """
        return True

    def pseudoinverse(self):
        r"""
        The pseudoinverse of the transform - that is, the transform that
        results from swapping `source` and `target`, or more formally, negating
        the transform's parameters. If the transform has a true inverse this
        is returned instead.

        :type: ``type(self)``
        """
        from menpo.shape import PointCloud, TriMesh  # to avoid circular import

        # Keep the same triangulation so triangles correspond one-to-one.
        new_source = TriMesh(self.target.points, self.source.trilist)
        new_target = PointCloud(self.source.points)
        return type(self)(new_source, new_target)


class PythonPWA(AbstractPWA):
    r"""
    Piecewise affine transform whose point-location queries are answered with
    the NumPy barycentric routines defined in this module.

    The source barycentric vectors are computed once, at construction, and are
    reused by every later :meth:`index_alpha_beta` call.
    """

    def __init__(self, source, target):
        super(PythonPWA, self).__init__(source, target)
        # Per-triangle i'th vertex and the ij / ik edge vectors of the source
        # triangulation, as returned by barycentric_vectors.
        (self.s, self.sij, self.sik) = barycentric_vectors(
            self.source.points, self.trilist
        )

    def index_alpha_beta(self, points):
        r"""
        Locate each of `points` in the source triangulation.

        Returns the containing triangle index plus the `alpha` and `beta`
        barycentric coordinates of every point, raising a
        `TriangleContainmentError` if any point lies outside all triangles.
        """
        # Delegates to the module-level function of the same name.
        source_vectors = (self.s, self.sij, self.sik)
        return index_alpha_beta(*source_vectors, points)


class CachedPWA(PythonPWA):
    r"""
    A `PythonPWA` that memoises the result of its most recent
    :meth:`index_alpha_beta` query.

    The triangle index and barycentric coordinates depend only on the source
    triangulation and the queried points. Re-applying the transform to
    exactly the same points therefore skips the containment search, even
    after the target has changed.
    """

    def __init__(self, source, target):
        super(CachedPWA, self).__init__(source, target)
        # A private copy of the last points queried, and the
        # (tri_index, alpha, beta) result computed for them.
        self._applied_points, self._iab = None, None

    def index_alpha_beta(self, points):
        r"""
        Same as `PythonPWA.index_alpha_beta`, but returns the cached result
        when `points` is exactly equal (same shape and values) to the
        previously queried points.
        """
        # Exact comparison: an approximate match (e.g. np.allclose) would hand
        # back coordinates computed for slightly different points.
        if self._applied_points is None or not np.array_equal(
            points, self._applied_points
        ):
            # This must happen first in case index_alpha_beta throws a
            # TriangleContainmentError, so the cache is never left holding
            # points without a matching result.
            self._iab = PythonPWA.index_alpha_beta(self, points)
            # Store a copy: keeping a reference would let in-place mutation of
            # the caller's array silently turn the cache stale.
            self._applied_points = np.copy(points)
        return self._iab