# I grabbed this from Leo Singer as it is a nicely written tool,
# but I do not want the user to have to install GSL and
# the whole LIGO framework.

"""Axes subclasses for astronomical mapmaking.

This module adds several :class:`astropy.visualization.wcsaxes.WCSAxes`
subclasses to the Matplotlib projection registry. The projections have names of
the form :samp:`{astro_or_geo} [{lon_units}] {projection}`.

:samp:`{astro_or_geo}` may be ``astro`` or ``geo``. It controls the
reference frame, either celestial (ICRS) or terrestrial (ITRS).

:samp:`{lon_units}` may be ``hours`` or ``degrees``. It controls the units of
the longitude axis. If omitted, ``astro`` implies ``hours`` and ``geo`` implies
``degrees``.

:samp:`{projection}` may be any of the following:

* ``aitoff`` for the Aitoff all-sky projection

* ``mollweide`` for the Mollweide all-sky projection

* ``globe`` for an orthographic projection, like the three-dimensional view of
  the Earth from a distant satellite

* ``zoom`` for a gnomonic projection suitable for visualizing small zoomed-in
  patches

Some of the projections support additional optional arguments. The ``globe``
projections support the options ``center`` and ``rotate``. The ``zoom``
projections support the options ``center``, ``radius``, and ``rotate``.

.. plot::
   :context: reset
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro hours mollweide')

.. plot::
   :context: reset
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo aitoff')

.. plot::
   :context: reset
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='astro zoom',
                  center='5h -32d', radius='5 deg', rotate='20 deg')

.. plot::
   :context: reset
   :align: center

    import ligo.skymap.plot
    from matplotlib import pyplot as plt
    ax = plt.axes(projection='geo globe', center='-50d +23d')

Complete Example
----------------

The following example demonstrates most of the features of this module.

.. plot::
   :context: reset
   :align: center

    from astropy.coordinates import SkyCoord
    from astropy.io import fits
    from astropy import units as u
    import ligo.skymap.plot
    from matplotlib import pyplot as plt

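    # NOTE: the sky-map URL is missing here; supply the URL or path of a
    # HEALPix FITS file before running this example.
    url = ''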
    center = SkyCoord.from_name('NGC 4993')

    fig = plt.figure(figsize=(4, 4), dpi=100)

    ax = plt.axes(
        [0.05, 0.05, 0.9, 0.9],
        projection='astro globe',
        center=center)

    ax_inset = plt.axes(
        [0.59, 0.3, 0.4, 0.4],
        projection='astro zoom',
        center=center,
        radius=10*u.deg)

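    for key in ['ra', 'dec']:
        ax_inset.coords[key].set_ticklabel_visible(False)
        ax_inset.coords[key].set_ticks_visible(False)
    ax.grid()
    ax.mark_inset_axes(ax_inset)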
    ax.connect_inset_axes(ax_inset, 'upper left')
    ax.connect_inset_axes(ax_inset, 'lower left')
    ax_inset.scalebar((0.1, 0.1), 5 * u.deg).label()
    ax_inset.compass(0.9, 0.1, 0.2)

    ax.imshow_hpx(url, cmap='cylon')
    ax_inset.imshow_hpx(url, cmap='cylon')
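    ax_inset.plot(
        center.ra.deg, center.dec.deg,
        transform=ax_inset.get_transform('world'),
        marker=ligo.skymap.plot.reticle(),
        markersize=30,
        markeredgewidth=3)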

"""  # noqa: E501
from itertools import product

from astropy.convolution import convolve_fft, Gaussian2DKernel
from astropy.coordinates import SkyCoord
from astropy.io.fits import Header
from astropy.time import Time
from astropy.visualization.wcsaxes import WCSAxes
from astropy.visualization.wcsaxes.formatter_locator import AngleFormatterLocator
from astropy.visualization.wcsaxes.frame import EllipticalFrame
from astropy.wcs import WCS
from astropy import units as u
from matplotlib import rcParams
from matplotlib.offsetbox import AnchoredOffsetbox
from matplotlib.patches import ConnectionPatch, FancyArrowPatch, PathPatch
from matplotlib.projections import projection_registry
import numpy as np
from reproject import reproject_from_healpix
from scipy.optimize import minimize_scalar

from .angle import reference_angle_deg
from . import itrs_frame_monkeypatch


# Public names exported by this module. This starts out as a list because the
# projection-registration code at the bottom of the file appends the names of
# the generated axes classes and then freezes the result into a tuple.
__all__ = ["AutoScaledWCSAxes", "ScaleBar"]

class WCSInsetPatch(PathPatch):
    """Subclass of `matplotlib.patches.PathPatch` for marking the outline of
    one `astropy.visualization.wcsaxes.WCSAxes` inside another.

    Parameters
    ----------
    ax : `astropy.visualization.wcsaxes.WCSAxes`
        The axes whose frame outline is to be drawn.

    Other parameters
    ----------------
    args :
        Extra arguments for `matplotlib.patches.PathPatch`
    kwargs :
        Extra keyword arguments for `matplotlib.patches.PathPatch`
    """

    def __init__(self, ax, *args, **kwargs):
        self._ax = ax
        # NOTE(review): the parent constructor call was missing from this copy
        # of the file. Reconstructed so that the patch is unfilled and
        # inherits the colour and line width of the inset's frame -- confirm
        # against the upstream implementation.
        # The path is passed as None because get_path() computes it lazily.
        super().__init__(
            None,
            *args,
            fill=False,
            edgecolor=ax.coords.frame.get_color(),
            linewidth=ax.coords.frame.get_linewidth(),
            **kwargs
        )

    def get_path(self):
        """Return the frame outline of the wrapped axes.

        The frame's patch path is interpolated (50 steps per segment) and
        then passed through the frame's own transform.
        """
        frame = self._ax.coords.frame
        outline = frame.patch.get_path().interpolated(50)
        return outline.transformed(frame.transform)

class WCSInsetConnectionPatch(ConnectionPatch):
    """Patch to connect an inset WCS axes inside another WCS axes.

    Parameters
    ----------
    ax : `astropy.visualization.wcsaxes.WCSAxes`
        The parent axes.
    ax_inset : `astropy.visualization.wcsaxes.WCSAxes`
        The inset axes.
    loc : int, str
        Which corner of the inset to connect, either as a location name
        understood by `matplotlib.offsetbox.AnchoredOffsetbox` or as the
        equivalent integer code.

    Other parameters
    ----------------
    args :
        Extra arguments for `matplotlib.patches.ConnectionPatch`
    kwargs :
        Extra keyword arguments for `matplotlib.patches.ConnectionPatch`
    """

    # Maps a location code to an index into ``viewLim.corners()``.
    _corners_map = {1: 3, 2: 1, 3: 0, 4: 2}

    def __init__(self, ax, ax_inset, loc, *args, **kwargs):
        # Accept either a location name ('upper left', ...) or a numeric code.
        try:
            loc = AnchoredOffsetbox.codes[loc]
        except KeyError:
            loc = int(loc)
        corners = ax_inset.viewLim.corners()
        # Inset pixel coordinates -> world -> parent pixel coordinates.
        transform = (
            ax_inset.coords.frame.transform + ax.coords.frame.transform.inverted()
        )
        xy_inset = corners[self._corners_map[loc]]
        # Transform.transform_point was deprecated in Matplotlib 3.3;
        # Transform.transform accepts a single point as well.
        xy = transform.transform(xy_inset)
        # NOTE(review): the parent constructor call was missing from this copy
        # of the file. Reconstructed so that the line joins the matching
        # points in both axes and inherits the inset frame's colour and line
        # width -- confirm against the upstream implementation.
        super().__init__(
            xy,
            xy_inset,
            "data",
            "data",
            *args,
            axesA=ax,
            axesB=ax_inset,
            color=ax_inset.coords.frame.get_color(),
            linewidth=ax_inset.coords.frame.get_linewidth(),
            **kwargs
        )

class AutoScaledWCSAxes(WCSAxes):
    """Axes base class. The pixel scale is adjusted to the DPI of the image,
    and there are a variety of convenience methods.

    Parameters
    ----------
    header : dict or `astropy.io.fits.Header`
        FITS header describing the projection. It is copied, and must contain
        ``NAXIS1``, ``NAXIS2``, ``CRPIX1``, ``CRPIX2``, ``CDELT1`` and
        ``CDELT2``.
    obstime : optional
        Anything accepted by `astropy.time.Time`. If given, it is written to
        the header as ``DATE-OBS`` (UTC, ISOT format).

    Other parameters
    ----------------
    args :
        Extra arguments for `astropy.visualization.wcsaxes.WCSAxes`
    kwargs :
        Extra keyword arguments for `astropy.visualization.wcsaxes.WCSAxes`
    """

    name = "astro wcs"

    def __init__(self, *args, header, obstime=None, **kwargs):
        super().__init__(*args, aspect=1, **kwargs)
        h = Header(header, copy=True)
        naxis1 = h["NAXIS1"]
        naxis2 = h["NAXIS2"]
        # Largest uniform scale at which the image still fits inside the axes
        # bounding box.
        scale = min(self.bbox.width / naxis1, self.bbox.height / naxis2)
        h["NAXIS1"] = int(np.ceil(naxis1 * scale))
        h["NAXIS2"] = int(np.ceil(naxis2 * scale))
        # The effective per-axis scales differ slightly from ``scale`` because
        # the new image sizes were rounded up to whole pixels.
        scale1 = h["NAXIS1"] / naxis1
        scale2 = h["NAXIS2"] / naxis2
        # Move the reference pixel to the same relative position in the
        # resized image (FITS pixel indices are 1-based).
        h["CRPIX1"] = (h["CRPIX1"] - 1) * (h["NAXIS1"] - 1) / (naxis1 - 1) + 1
        h["CRPIX2"] = (h["CRPIX2"] - 1) * (h["NAXIS2"] - 1) / (naxis2 - 1) + 1
        h["CDELT1"] /= scale1
        h["CDELT2"] /= scale2
        if obstime is not None:
            h["DATE-OBS"] = Time(obstime).utc.isot
        self.reset_wcs(WCS(h))
        self.set_xlim(-0.5, h["NAXIS1"] - 0.5)
        self.set_ylim(-0.5, h["NAXIS2"] - 0.5)
        self._header = h

    @property
    def header(self):
        """The rescaled FITS header that defines this axes' WCS."""
        return self._header

    def mark_inset_axes(self, ax, *args, **kwargs):
        """Outline the footprint of another WCSAxes inside this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.PathPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.PathPatch`

        Returns
        -------
        patch : `matplotlib.patches.PathPatch`
        """
        return self.add_patch(
            WCSInsetPatch(ax, *args, transform=self.get_transform("world"), **kwargs)
        )

    def connect_inset_axes(self, ax, loc, *args, **kwargs):
        """Connect a corner of another WCSAxes to the matching point inside
        this one.

        Parameters
        ----------
        ax : `astropy.visualization.wcsaxes.WCSAxes`
            The other axes.
        loc : int, str
            Which corner to connect. For valid values, see
            `matplotlib.offsetbox.AnchoredOffsetbox`.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.ConnectionPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.ConnectionPatch`

        Returns
        -------
        patch : `matplotlib.patches.ConnectionPatch`
        """
        return self.add_patch(WCSInsetConnectionPatch(self, ax, loc, *args, **kwargs))

    def compass(self, x, y, size):
        """Add a compass to indicate the north and east directions.

        Parameters
        ----------
        x, y : float
            Position of compass vertex in axes coordinates.
        size : float
            Size of compass in axes coordinates.

        Returns
        -------
        annotations : list
            The two annotations returned by `matplotlib.axes.Axes.annotate`,
            labelled "E" and "N" in that order.
        """
        xy = x, y
        scale = self.wcs.pixel_scale_matrix
        # Normalise the matrix to unit determinant (in absolute value) so the
        # arrow lengths are governed by ``size`` alone.
        scale /= np.sqrt(np.abs(np.linalg.det(scale)))
        # NOTE(review): the original loop also iterated over per-arrow ha/va
        # alignment lists but never used them; both labels have always been
        # drawn centred, and that behaviour is kept here.
        return [
            self.annotate(
                label,
                xy,
                xytext=xy + size * n,
                xycoords=self.transAxes,
                textcoords=self.transAxes,
                ha="center",
                va="center",
                arrowprops=dict(arrowstyle="<-", shrinkA=0.0, shrinkB=0.0),
            )
            for n, label in zip(scale, "EN")
        ]

    def scalebar(self, *args, **kwargs):
        """Add scale bar.

        Parameters
        ----------
        xy : tuple
            The axes coordinates of the scale bar.
        length : `astropy.units.Quantity`
            The length of the scale bar in angle-compatible units.

        Other parameters
        ----------------
        args :
            Extra arguments for `matplotlib.patches.FancyArrowPatch`
        kwargs :
            Extra keyword arguments for `matplotlib.patches.FancyArrowPatch`

        Returns
        -------
        patch : `matplotlib.patches.FancyArrowPatch`
        """
        return self.add_patch(ScaleBar(self, *args, **kwargs))

    def _reproject_hpx(
        self, data, hdu_in=None, order="bilinear", nested=False, field=0, smooth=None
    ):
        """Reproject a HEALPix data set onto the pixel grid of these axes.

        Returns a masked array in which pixels not covered by the
        reprojection are masked. If ``smooth`` is given, the image is
        convolved with a Gaussian kernel of that angular width and a plain
        array is returned by `astropy.convolution.convolve_fft`.
        """
        if isinstance(data, np.ndarray):
            # A bare array is taken to be in the coordinate frame of the axes.
            data = (data, self.header["RADESYS"])

        # It's normal for reproject_from_healpix to produce some Numpy invalid
        # value warnings for points that land outside the projection.
        with np.errstate(invalid="ignore"):
            img, mask = reproject_from_healpix(
                data,
                self.header,
                hdu_in=hdu_in,
                order=order,
                nested=nested,
                field=field,
            )
        # ``mask`` is the footprint returned by reproject; mask out every
        # pixel that it does not cover.
        img = np.ma.array(img, mask=~mask.astype(bool))

        if smooth is not None:
            # Convert the smoothing length from an angle to a number of pixels.
            pixsize = np.mean(np.abs(self.wcs.wcs.cdelt)) * u.deg
            smooth = (smooth / pixsize).to(u.dimensionless_unscaled).value
            kernel = Gaussian2DKernel(smooth)
            # Ignore divide by zero warnings for pixels that have no valid
            # neighbors.
            with np.errstate(invalid="ignore"):
                img = convolve_fft(img, kernel, fill_value=np.nan)

        return img

    def contour_hpx(
        self,
        data,
        hdu_in=None,
        order="bilinear",
        nested=False,
        field=0,
        smooth=None,
        **kwargs
    ):
        """Add contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `astropy.io.fits.TableHDU` or `astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contour`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`
        """  # noqa: E501
        img = self._reproject_hpx(
            data, hdu_in=hdu_in, order=order, nested=nested, field=field, smooth=smooth
        )
        return self.contour(img, **kwargs)

    def contourf_hpx(
        self,
        data,
        hdu_in=None,
        order="bilinear",
        nested=False,
        field=0,
        smooth=None,
        **kwargs
    ):
        """Add filled contour levels for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `astropy.io.fits.TableHDU` or `astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.contourf`

        Returns
        -------
        contours : `matplotlib.contour.QuadContourSet`
        """  # noqa: E501
        img = self._reproject_hpx(
            data, hdu_in=hdu_in, order=order, nested=nested, field=field, smooth=smooth
        )
        return self.contourf(img, **kwargs)

    def imshow_hpx(
        self,
        data,
        hdu_in=None,
        order="bilinear",
        nested=False,
        field=0,
        smooth=None,
        **kwargs
    ):
        """Add an image for a HEALPix data set.

        Parameters
        ----------
        data : `numpy.ndarray` or str or `astropy.io.fits.TableHDU` or `astropy.io.fits.BinTableHDU` or tuple
            The HEALPix data set. If this is a `numpy.ndarray`, then it is
            interpreted as the HEALPix array in the same coordinate system as
            the axes. Otherwise, the input data can be any type that is
            understood by `reproject.reproject_from_healpix`.
        smooth : `astropy.units.Quantity`, optional
            An optional smoothing length in angle-compatible units.

        Other parameters
        ----------------
        hdu_in, order, nested, field :
            Extra arguments for `reproject.reproject_from_healpix`
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.imshow`

        Returns
        -------
        image : `matplotlib.image.AxesImage`
        """  # noqa: E501
        img = self._reproject_hpx(
            data, hdu_in=hdu_in, order=order, nested=nested, field=field, smooth=smooth
        )
        return self.imshow(img, **kwargs)
class ScaleBar(FancyArrowPatch):
    """Horizontal bar spanning a given angular length on a WCS axes.

    Parameters
    ----------
    ax : `astropy.visualization.wcsaxes.WCSAxes`
        The axes on which the bar is drawn.
    xy : tuple
        Axes coordinates of the left end of the bar.
    length : `astropy.units.Quantity`
        The length of the scale bar in angle-compatible units.

    Other parameters
    ----------------
    args :
        Extra arguments for `matplotlib.patches.FancyArrowPatch`
    kwargs :
        Extra keyword arguments for `matplotlib.patches.FancyArrowPatch`.
        These override the defaults ``capstyle='round'``, ``color='black'``
        and ``linewidth=rcParams['lines.linewidth']``.
    """

    def _func(self, dx, x, y):
        """Objective for the length solver.

        Returns the squared difference between the requested length and the
        on-sky separation of the axes-coordinate points ``(x, y)`` and
        ``(x + dx, y)``.
        """
        p1, p2 = self._transAxesToWorld.transform([[x, y], [x + dx, y]])
        p1 = SkyCoord(*p1, unit=u.deg)
        p2 = SkyCoord(*p2, unit=u.deg)
        return np.square((p1.separation(p2) - self._length).value)

    def __init__(self, ax, xy, length, *args, **kwargs):
        x, y = xy
        self._ax = ax
        self._length = u.Quantity(length)
        # Axes coordinates -> data (pixel) coordinates -> world coordinates.
        self._transAxesToWorld = (
            ax.transAxes - ax.transData
        ) + ax.coords.frame.transform
        # Find the horizontal extent, in axes coordinates, whose on-sky
        # separation matches ``length``. The upper bound keeps the bar from
        # extending past the right-hand edge of the axes.
        dx = minimize_scalar(
            self._func, args=xy, bounds=[0, 1 - x], method="bounded"
        ).x
        # Caller-supplied keyword arguments take precedence over the defaults.
        custom_kwargs = kwargs
        kwargs = dict(
            capstyle="round",
            color="black",
            linewidth=rcParams["lines.linewidth"],
        )
        kwargs.update(custom_kwargs)
        super().__init__(
            xy,
            (x + dx, y),
            *args,
            arrowstyle="-",
            shrinkA=0.0,
            shrinkB=0.0,
            transform=ax.transAxes,
            **kwargs
        )

    def label(self, **kwargs):
        """Add a text label showing the bar's length, centred above the bar.

        Other parameters
        ----------------
        kwargs :
            Extra keyword arguments for `matplotlib.axes.Axes.text`

        Returns
        -------
        text : `matplotlib.text.Text`
        """
        (x0, y), (x1, _) = self._posA_posB
        s = u" {0.value:g}{0.unit:unicode}".format(self._length)
        return self._ax.text(
            0.5 * (x0 + x1),
            y,
            s,
            ha="center",
            va="bottom",
            transform=self._ax.transAxes,
            **kwargs
        )
class Astro:
    """Mixin selecting the celestial (ICRS) frame for a projection."""

    _crval1 = 180
    _xcoord = "RA--"
    _ycoord = "DEC-"
    _radesys = "ICRS"


class GeoAngleFormatterLocator(AngleFormatterLocator):
    """Angle formatter that passes tick values through
    `reference_angle_deg` before formatting them.
    """

    def formatter(self, values, spacing):
        return super().formatter(
            reference_angle_deg(values.to(u.deg).value) * u.deg, spacing
        )


class Geo:
    """Mixin selecting the terrestrial (ITRS) frame for a projection."""

    _crval1 = 0
    _radesys = "ITRS"
    _xcoord = "TLON"
    _ycoord = "TLAT"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.invert_xaxis()

        # Swap the longitude axis' formatter/locator for the Geo-specific
        # subclass, carrying over its current settings.
        fl = self.coords[0]._formatter_locator
        self.coords[0]._formatter_locator = GeoAngleFormatterLocator(
            values=fl.values,
            number=fl.number,
            spacing=fl.spacing,
            format=fl.format,
            format_unit=fl.format_unit,
        )


class Degrees:
    """WCS axes with longitude axis in degrees."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.coords[0].set_format_unit(u.degree)


class Hours:
    """WCS axes with longitude axis in hour angle."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.coords[0].set_format_unit(u.hourangle)


class Globe(AutoScaledWCSAxes):
    """Orthographic (``SIN``) projection centred on ``center``.

    Parameters
    ----------
    center : str or `astropy.coordinates.SkyCoord`
        Centre of the projection; anything accepted by `SkyCoord`.
    rotate : str or `astropy.units.Quantity`, optional
        Angle written to the ``LONPOLE`` header keyword, in degrees.
    """

    def __init__(self, *args, center="0d 0d", rotate=None, **kwargs):
        center = SkyCoord(center).icrs
        header = {
            "NAXIS": 2,
            "NAXIS1": 180,
            "NAXIS2": 180,
            "CRPIX1": 90.5,
            "CRPIX2": 90.5,
            "CRVAL1": center.ra.deg,
            "CRVAL2": center.dec.deg,
            "CDELT1": -2 / np.pi,
            "CDELT2": 2 / np.pi,
            "CTYPE1": self._xcoord + "-SIN",
            "CTYPE2": self._ycoord + "-SIN",
            "RADESYS": self._radesys,
        }
        if rotate is not None:
            header["LONPOLE"] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(*args, frame_class=EllipticalFrame, header=header, **kwargs)


class Zoom(AutoScaledWCSAxes):
    """Gnomonic (``TAN``) projection of a small patch around ``center``.

    Parameters
    ----------
    center : str or `astropy.coordinates.SkyCoord`
        Centre of the projection; anything accepted by `SkyCoord`.
    radius : str or `astropy.units.Quantity`
        Angular half-width of the field of view.
    rotate : str or `astropy.units.Quantity`, optional
        Angle written to the ``LONPOLE`` header keyword, in degrees.
    """

    def __init__(self, *args, center="0d 0d", radius="1 deg", rotate=None, **kwargs):
        center = SkyCoord(center).icrs
        radius = u.Quantity(radius).to(u.deg).value
        header = {
            "NAXIS": 2,
            "NAXIS1": 512,
            "NAXIS2": 512,
            "CRPIX1": 256.5,
            "CRPIX2": 256.5,
            "CRVAL1": center.ra.deg,
            "CRVAL2": center.dec.deg,
            "CDELT1": -radius / 256,
            "CDELT2": radius / 256,
            "CTYPE1": self._xcoord + "-TAN",
            "CTYPE2": self._ycoord + "-TAN",
            "RADESYS": self._radesys,
        }
        if rotate is not None:
            header["LONPOLE"] = u.Quantity(rotate).to_value(u.deg)
        super().__init__(*args, header=header, **kwargs)


class AllSkyAxes(AutoScaledWCSAxes):
    """Base class for a multi-purpose all-sky projection."""

    def __init__(self, *args, **kwargs):
        header = {
            "NAXIS": 2,
            "NAXIS1": 360,
            "NAXIS2": 180,
            "CRPIX1": 180.5,
            "CRPIX2": 90.5,
            "CRVAL1": self._crval1,
            "CRVAL2": 0.0,
            "CDELT1": -2 * np.sqrt(2) / np.pi,
            "CDELT2": 2 * np.sqrt(2) / np.pi,
            "CTYPE1": self._xcoord + "-" + self._wcsprj,
            "CTYPE2": self._ycoord + "-" + self._wcsprj,
            "RADESYS": self._radesys,
        }
        super().__init__(*args, frame_class=EllipticalFrame, header=header, **kwargs)
        self.coords[0].set_ticks(spacing=45 * u.deg)
        self.coords[1].set_ticks(spacing=30 * u.deg)
        self.coords[0].set_ticklabel(exclude_overlapping=True)
        self.coords[1].set_ticklabel(exclude_overlapping=True)


class Aitoff(AllSkyAxes):
    """All-sky axes using the ``AIT`` WCS projection code."""

    _wcsprj = "AIT"


class Mollweide(AllSkyAxes):
    """All-sky axes using the ``MOL`` WCS projection code."""

    _wcsprj = "MOL"


moddict = globals()

#
# Create subclasses and register all projections:
# '{astro|geo} {hours|degrees} {aitoff|globe|mollweide|zoom}'
#
bases1 = (Astro, Geo)
bases2 = (Hours, Degrees)
bases3 = (Aitoff, Globe, Mollweide, Zoom)
for bases in product(bases1, bases2, bases3):
    class_name = "".join(cls.__name__ for cls in bases) + "Axes"
    projection = " ".join(cls.__name__.lower() for cls in bases)
    new_class = type(class_name, bases, {"name": projection})
    projection_registry.register(new_class)
    moddict[class_name] = new_class
    __all__.append(class_name)

#
# Create some synonyms:
# 'astro' will be short for 'astro hours',
# 'geo' will be short for 'geo degrees'
#
for base1, base2 in zip(bases1, bases2):
    for base3 in (Aitoff, Globe, Mollweide, Zoom):
        bases = (base1, base2, base3)
        orig_class_name = "".join(cls.__name__ for cls in bases) + "Axes"
        orig_class = moddict[orig_class_name]
        class_name = "".join(cls.__name__ for cls in (base1, base3)) + "Axes"
        projection = " ".join(cls.__name__.lower() for cls in (base1, base3))
        new_class = type(class_name, (orig_class,), {"name": projection})
        projection_registry.register(new_class)
        moddict[class_name] = new_class
        __all__.append(class_name)

del class_name, moddict, projection, projection_registry, new_class

__all__ = tuple(__all__)