import collections
import contextlib
import numbers
import warnings
from pathlib import Path
from typing import ClassVar

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.colors import ListedColormap
from matplotlib.transforms import Bbox

from nilearn._utils.docs import fill_doc
from nilearn._utils.logger import find_stack_level
from nilearn._utils.niimg import _get_data, is_binary_niimg, safe_get_data
from nilearn._utils.param_validation import check_params
from nilearn.image import check_niimg_3d, get_data, new_img_like, reorder_img
from nilearn.image.image import _check_fov
from nilearn.image.resampling import get_bounds, get_mask_bounds, resample_img
from nilearn.plotting._engine_utils import create_colorbar_for_fig
from nilearn.plotting._utils import (
    DEFAULT_TICK_FORMAT,
    check_threshold_not_negative,
)
from nilearn.plotting.displays import CutAxes
from nilearn.plotting.displays._utils import (
    coords_3d_to_2d,
    get_create_display_fun,
)
from nilearn.plotting.displays.edge_detect import edge_map
from nilearn.plotting.find_cuts import find_cut_slices, find_xyz_cut_coords
from nilearn.typing import NiimgLike


@fill_doc
class BaseSlicer:
    """BaseSlicer implementation which main purpose is to auto adjust \
    the axes size to the data with different layout of cuts.

    It creates 3 linked axes for plotting orthogonal cuts.

    Attributes
    ----------
    cut_coords : :obj:`list`, :obj:`tuple` or :obj:`dict`
        The world coordinates to be used. The concrete type depends on the
        subclass.

    axes : :obj:`dict` of :class:`~matplotlib.axes.Axes`
        The axes used for plotting in each direction of the cut.

    %(displays_partial_attributes)s

    """

    # This actually encodes the figsize for a single axes only
    _default_figsize: ClassVar[list[float]] = [2.2, 2.6]
    _axes_class: type[CutAxes] = CutAxes

    def __init__(
        self,
        cut_coords,
        axes=None,
        black_bg=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        """Store the slicer state and build the per-cut axes.

        Parameters
        ----------
        cut_coords : :obj:`list`, :obj:`tuple` or :obj:`dict`
            Cut positions; stored unchanged on the instance.

        axes : :class:`~matplotlib.axes.Axes` or None, default=None
            Frame axes hosting the cuts. When ``None``, a new axes is
            created with the rectangle ``(0, 0, 1, 1)`` and its axis is
            turned off.

        black_bg : :obj:`bool`, default=False
            Value returned by the ``black_bg`` property.

        brain_color : :obj:`tuple`, default=(0.5, 0.5, 0.5)
            Value returned by the ``brain_color`` property.

        kwargs : :obj:`dict`
            Forwarded unchanged to ``self._init_axes``.

        """
        self.cut_coords = cut_coords

        if axes is None:
            # No host axes supplied: create one.
            axes = plt.axes((0.0, 0.0, 1.0, 1.0))
            axes.axis("off")
        self.frame_axes = axes
        axes.set_zorder(1)

        frame = axes.get_position()
        frame_width, frame_height = frame.width, frame.height
        self.rect = (frame.x0, frame.y0, frame.x1, frame.y1)

        self._black_bg = black_bg
        self._brain_color = brain_color

        # Colorbar geometry is expressed relative to the frame size.
        self._colorbar = False
        self._colorbar_width = 0.05 * frame_width
        self._cbar_tick_format = DEFAULT_TICK_FORMAT
        self._colorbar_margin = {
            "left": 0.25 * frame_width,
            "right": 0.02 * frame_width,
            "top": 0.05 * frame_height,
            "bottom": 0.05 * frame_height,
        }

        self._init_axes(**kwargs)

    @property
    def brain_color(self):
        """Return brain color."""
        return self._brain_color

    @property
    def black_bg(self):
        """Return black background."""
        return self._black_bg

    @classmethod
    @fill_doc
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Find world coordinates of cut positions compatible with this slicer.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, default=None
            The brain image.

        %(threshold)s

        %(cut_coords)s

        Returns
        -------
        cut_coords : :obj:`list`, :obj:`tuple` or :obj:`dict`
            Cut positions depending on slicer type.

        """
        raise NotImplementedError()

    @classmethod
    def _check_cut_coords_in_bounds(cls, img, cut_coords) -> None:
        """Validate ``cut_coords`` against the bounds of ``img``.

        Nothing is checked when either ``img`` or ``cut_coords`` is ``None``.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image` or None
            The brain image whose bounds are used.

        cut_coords : :obj:`list`, :obj:`tuple`, :obj:`dict` or None
            Cut positions to validate. The per-coordinate test is delegated
            to ``cls._get_coords_in_bounds``.

        Raises
        ------
        ValueError
            If none of the coordinates is within the image bounds.

        Warns
        -----
        UserWarning
            If some, but not all, of the coordinates are within the bounds.

        """
        if img is None or cut_coords is None:
            return

        bounds = get_bounds(_get_data(img).shape, img.affine)
        inside = cls._get_coords_in_bounds(bounds, cut_coords)

        # Human-readable summary of the valid range along x, y and z.
        bounds_str = "".join(
            f"\n\t{axis}: [{bounds[idx][0]:.2f}, {bounds[idx][1]:.2f}]"
            for idx, axis in enumerate("xyz")
        )

        # Not a single coordinate is usable: this is an error.
        if not any(inside):
            raise ValueError(
                f"Specified {cut_coords=} is out of the bounds of the image."
                "\nPlease specify coordinates within the bounds:"
                f"{bounds_str}"
            )

        # Everything is inside: nothing to report.
        if all(inside):
            return

        # Mixed situation: keep going but tell the user.
        warnings.warn(
            "Some of the specified cut_coords "
            "seem to be out of the image bounds:"
            f"{bounds_str}",
            UserWarning,
            stacklevel=find_stack_level(),
        )

    @classmethod
    def _cut_count(cls):
        """Return the number of cut directions for this slicer."""
        raise NotImplementedError()

    @classmethod
    def _sanitize_cut_coords(cls, cut_coords):
        """Sanitize the cut coordinates.

        Check if `cut_coords` is compatible with this slicer and adjust its
        value if necessary.

        Parameters
        ----------
        %(cut_coords)s

        Raises
        ------
        ValueError
            If `cut_coords` is not compatible with this slicer.

        """
        raise NotImplementedError()

    @classmethod
    def _get_coords_in_bounds(cls, bounds, cut_coords):
        """Return a list that has boolean values corresponding to each cut
        coordinate indicating if it is within the bounds of its direction or
        not.

        Parameters
        ----------
        bounds:
            valid bounds for the cut coordinates

        %(cut_coords)s

        Returns
        -------
        list[bool]
            a list of boolean values corresponding to each coordinate
        indicating if it is within the bounds or not

        """
        raise NotImplementedError()

    @classmethod
    @fill_doc  # fill_doc must sit directly on the function, below classmethod
    def init_with_figure(
        cls,
        img,
        threshold=None,
        cut_coords=None,
        figure=None,
        axes=None,
        black_bg=False,
        leave_space=False,
        colorbar=False,
        brain_color=(0.5, 0.5, 0.5),
        **kwargs,
    ):
        """Build a slicer for an image, creating figure and axes as needed.

        Parameters
        ----------
        %(img)s

        %(threshold)s

        %(cut_coords)s

        figure : :class:`matplotlib.figure.Figure`
            Figure to be used for plots. Any other value is handed to
            :func:`matplotlib.pyplot.figure` to create one.

        %(axes)s
            default=None
            The axes that will be subdivided in 3.

        %(black_bg)s
            default=False

        leave_space : :obj:`bool`, default=False
            If ``True``, leave space between the plots.

        %(colorbar)s
            Default=False.

        %(brain_color)s

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to
            :class:`~nilearn.plotting.displays.CutAxes` used for plotting in
            each direction.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        """
        # Must run first: it inspects the arguments through locals().
        check_params(locals())
        check_threshold_not_negative(threshold)

        # deal with "fake" 4D images
        if img is not None and img is not False:
            img = check_niimg_3d(img)

        cut_coords = cls.find_cut_coords(img, threshold, cut_coords)

        axes_is_mpl = isinstance(axes, plt.Axes)
        if axes_is_mpl and figure is None:
            figure = axes.figure

        if not isinstance(figure, plt.Figure):
            # No usable figure yet: size one from the per-axes default.
            figsize = list(cls._default_figsize)
            # One default width per cut ...
            figsize[0] *= len(cut_coords)
            # ... plus room for the colorbar ...
            if colorbar:
                figsize[0] += 0.7
            # ... plus the requested free space.
            if leave_space:
                figsize[0] += 3.4
            figure = plt.figure(
                figure,
                figsize=figsize,
                facecolor="k" if black_bg else "w",
            )

        if axes_is_mpl:
            assert axes.figure is figure, (
                "The axes passed are not in the figure"
            )

        if axes is None:
            if leave_space:
                axes = [0.3, 0, 0.7, 1.0]
            else:
                axes = [0.0, 0.0, 1.0, 1.0]
        if isinstance(axes, collections.abc.Sequence):
            # A rectangle was given (or defaulted): turn it into real axes.
            axes = figure.add_axes(axes)

        # People forget to turn their axis off, or to set the zorder, and
        # then they cannot see their slicer
        axes.axis("off")
        return cls(cut_coords, axes, black_bg, brain_color, **kwargs)

    def title(
        self,
        text,
        x=0.01,
        y=0.99,
        size=15,
        color=None,
        bgcolor=None,
        alpha=1,
        **kwargs,
    ):
        """Draw a title box on the view.

        Parameters
        ----------
        text : :obj:`str`
            The text of the title.

        x : :obj:`float`, default=0.01
            Horizontal position of the title, as a fraction of the frame
            width.

        y : :obj:`float`, default=0.99
            Vertical position of the title, as a fraction of the frame
            height.

        size : :obj:`int`, default=15
            The size of the title text.

        color : matplotlib color specifier, default=None
            Font color. When ``None``, ``"k"`` is used if the slicer has a
            black background and ``"w"`` otherwise.

        bgcolor : matplotlib color specifier, default=None
            Color of the box behind the title. When ``None``, ``"w"`` is
            used if the slicer has a black background and ``"k"`` otherwise.

        alpha : :obj:`float`, default=1
            The alpha value for the background box.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to matplotlib's text
            function. The alignment and ``zorder`` entries are always
            overwritten by this method.

        """
        fallback_color, fallback_bgcolor = (
            ("k", "w") if self._black_bg else ("w", "k")
        )
        if color is None:
            color = fallback_color
        if bgcolor is None:
            bgcolor = fallback_bgcolor

        # Pick the key of the axes that will carry the text.
        if not hasattr(self, "_cut_displayed"):
            axes_key = self.cut_coords[0]
        elif isinstance(self.cut_coords, dict):
            # Adapt to the case of mosaic plotting
            direction = self._cut_displayed[-1]
            axes_key = (direction, self.cut_coords[direction][0])
        else:
            axes_key = self._cut_displayed[0]
        target_ax = self.axes[axes_key].ax

        kwargs.update(
            horizontalalignment="left",
            verticalalignment="top",
            zorder=1000,
        )
        box_style = {
            "boxstyle": "square,pad=.3",
            "ec": bgcolor,
            "fc": bgcolor,
            "alpha": alpha,
        }

        target_ax.text(
            x,
            y,
            text,
            transform=self.frame_axes.transAxes,
            size=size,
            color=color,
            bbox=box_style,
            **kwargs,
        )
        target_ax.set_zorder(1000)

    @fill_doc
    def add_overlay(
        self,
        img,
        threshold=1e-6,
        colorbar=False,
        cbar_tick_format=DEFAULT_TICK_FORMAT,
        cbar_vmin=None,
        cbar_vmax=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ) -> None:
        """Display a 3D map as an image overlay in all the views.

        Parameters
        ----------
        %(img)s
            If it is a masked array, only the non-masked part will be plotted.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
                value is used to threshold the image: values below the
                threshold (in absolute value) are plotted as transparent.

        %(colorbar)s
            Default=False.

        cbar_tick_format : str, default="%%.2g" (scientific notation)
            Controls how to format the tick labels of the colorbar.
            Ex: use "%%i" to display as integers.

        cbar_vmin : :obj:`float`, default=None
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, default=None
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.

        %(transparency)s

        %(transparency_range)s

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.imshow`.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number, or if a
            colorbar is requested while this figure already has one.

        """
        check_threshold_not_negative(threshold)

        already_has_colorbar = self._colorbar
        if colorbar and already_has_colorbar:
            raise ValueError(
                "This figure already has an overlay with a colorbar."
            )

        self._colorbar = colorbar
        self._cbar_tick_format = cbar_tick_format

        img = check_niimg_3d(img)

        # Make sure that add_overlay shows consistent default behavior
        # with plot_stat_map
        if "interpolation" not in kwargs:
            kwargs["interpolation"] = "nearest"

        drawn = self._map_show(
            img,
            type="imshow",
            threshold=threshold,
            transparency=transparency,
            transparency_range=transparency_range,
            **kwargs,
        )

        # `drawn` can be empty in some corner cases,
        # look at test_img_plotting.test_outlier_cut_coords.
        if colorbar and drawn:
            reference = drawn[0]
            self._show_colorbar(
                reference.cmap,
                reference.norm,
                cbar_vmin,
                cbar_vmax,
                threshold,
            )

        plt.draw_if_interactive()

    @fill_doc
    def add_contours(
        self, img, threshold=1e-6, filled=False, **kwargs
    ) -> None:
        """Contour a 3D map in all the views.

        Contour lines are always drawn; when ``filled=True`` a second pass
        draws the contour fillings on top.

        Parameters
        ----------
        %(img)s
            Provides image to plot.

        threshold : :obj:`int` or :obj:`float` or ``None``, default=1e-6
            Threshold to apply. It is only used when ``filled=True``;
            otherwise it is ignored and the maps are not thresholded.

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
                value is used to threshold the image: values below the
                threshold (in absolute value) are plotted as transparent.

        filled : :obj:`bool`, default=False
            If ``filled=True``, contours are displayed with color fillings.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.contour`, or function
            :func:`~matplotlib.pyplot.contourf`.
            Useful arguments are typically "levels", which is a
            list of values to use for plotting a contour or contour
            fillings (if ``filled=True``), and
            "colors", which is one color or a list of colors for
            these contours.
            The "levels" sequence passed by the caller is never modified.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        Warns
        -----
        UserWarning
            If "linewidths" is given together with ``filled=True``; it is
            dropped for the filled pass.

        Notes
        -----
        If colors are not specified, default coloring choices
        (from matplotlib) for contours and contour_fillings can be
        different.

        """
        if not filled:
            threshold = None
        else:
            check_threshold_not_negative(threshold)

        self._map_show(img, type="contour", threshold=threshold, **kwargs)
        if filled:
            if "levels" in kwargs:
                levels = kwargs["levels"]
                if len(levels) <= 1:
                    # contour fillings levels
                    # should be given as (lower, upper).
                    # Build a new list instead of appending in place: the
                    # caller's sequence stays untouched and tuples or
                    # arrays are accepted as well.
                    kwargs["levels"] = [*levels, np.inf]

            if "linewidths" in kwargs:
                warnings.warn(
                    "'linewidths' is not supported for filled contours",
                    UserWarning,
                    stacklevel=find_stack_level(),
                )
                kwargs.pop("linewidths")

            self._map_show(img, type="contourf", threshold=threshold, **kwargs)

        plt.draw_if_interactive()

    def _map_show(
        self,
        img,
        type="imshow",
        resampling_interpolation="continuous",
        threshold=None,
        transparency=None,
        transparency_range=None,
        **kwargs,
    ):
        """Draw ``img`` on every cut axes and return what was drawn.

        Parameters
        ----------
        img : Niimg-like object
            Image to display. It is passed through ``reorder_img`` before
            anything else.

        type : :obj:`str`, default="imshow"
            Drawing mode forwarded to ``draw_2d`` of each cut axes. The
            values ``"contour"`` and ``"contourf"`` additionally trigger
            the computation of a tight bounding box from the data.

        resampling_interpolation : :obj:`str`, default="continuous"
            Value given as ``resample`` to ``reorder_img``. It is replaced
            by ``"nearest"`` when ``is_binary_niimg(img)`` is true.

        threshold : :obj:`float` or None, default=None
            When not ``None``, it is converted to ``float`` and applied to
            the whole volume through ``self._threshold``. When it is also
            non-zero, it is applied a second time to each 2D cut together
            with ``vmin`` and ``vmax``.

        transparency, transparency_range
            Passed to ``self._sanitize_transparency``; the sanitized
            transparency is cut like the data and handed to ``draw_2d``.

        kwargs : :obj:`dict`
            Forwarded to ``draw_2d``. ``"vmin"`` and ``"vmax"`` are filled
            in from the cut data when missing or ``None``; ``"levels"``,
            when present, is used to derive the contour bounding box.

        Returns
        -------
        list
            The objects returned by ``draw_2d``, one per cut axes whose cut
            could be extracted from the data.

        """
        # In the special case where the affine of img is not diagonal,
        # the function `reorder_img` will trigger a resampling
        # of the provided image with a continuous interpolation
        # since this is the default value here. In the special
        # case where this image is binary, such as when this function
        # is called from `add_contours`, continuous interpolation
        # does not make sense and we turn to nearest interpolation instead.

        if is_binary_niimg(img):
            resampling_interpolation = "nearest"

        # Image reordering should be done before sanitizing transparency
        img = reorder_img(img, resample=resampling_interpolation)

        transparency, transparency_affine = self._sanitize_transparency(
            img,
            transparency,
            transparency_range,
            resampling_interpolation,
        )

        affine = img.affine

        # First thresholding pass, on the full volume (no vmin / vmax yet).
        if threshold is not None:
            threshold = float(threshold)
            data = safe_get_data(img, ensure_finite=True)
            data = self._threshold(data, threshold, None, None)
            img = new_img_like(img, data, affine)

        data = safe_get_data(img, ensure_finite=True)
        data_bounds = get_bounds(data.shape, affine)
        (xmin, xmax), (ymin, ymax), (zmin, zmax) = data_bounds

        # Default bounding box: the full extent of the data. It is
        # tightened below for contours and for masked data.
        xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = (
            xmin,
            xmax,
            ymin,
            ymax,
            zmin,
            zmax,
        )

        # Compute tight bounds
        if type in ("contour", "contourf"):
            # Define a pseudo threshold to have a tight bounding box
            thr = (
                0.9 * np.min(np.abs(kwargs["levels"]))
                if "levels" in kwargs
                else 1e-6
            )
            not_mask = np.logical_or(data > thr, data < -thr)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )

        elif hasattr(data, "mask") and isinstance(data.mask, np.ndarray):
            # Masked array with an explicit mask: bound the unmasked part.
            not_mask = np.logical_not(data.mask)
            xmin_, xmax_, ymin_, ymax_, zmin_, zmax_ = get_mask_bounds(
                new_img_like(img, not_mask, affine)
            )

        # Extract the 2D cut (and matching transparency) for every axes.
        data_2d_list = []
        transparency_list = []
        for display_ax in self.axes.values():
            # Scalar or absent transparency is used as-is for every cut.
            if transparency is None or isinstance(transparency, (float, int)):
                transparency_2d = transparency

            try:
                data_2d = display_ax.transform_to_2d(data, affine)
                if isinstance(transparency, np.ndarray):
                    transparency_2d = display_ax.transform_to_2d(
                        transparency, transparency_affine
                    )
            except IndexError:
                # We are cutting outside the indices of the data
                data_2d = None
                transparency_2d = None

            data_2d_list.append(data_2d)
            transparency_list.append(transparency_2d)

        # Share one color range across all cuts unless the caller set it.
        if kwargs.get("vmin") is None:
            kwargs["vmin"] = np.ma.min(
                [d.min() for d in data_2d_list if d is not None]
            )
        if kwargs.get("vmax") is None:
            kwargs["vmax"] = np.ma.max(
                [d.max() for d in data_2d_list if d is not None]
            )

        bounding_box = (xmin_, xmax_), (ymin_, ymax_), (zmin_, zmax_)
        ims = []
        to_iterate_over = zip(
            self.axes.values(), data_2d_list, transparency_list, strict=False
        )
        # A falsy threshold (None or 0) disables the per-cut pass below.
        threshold = float(threshold) if threshold else None
        for display_ax, data_2d, transparency_2d in to_iterate_over:
            # data_2d is None when the cut fell outside the data (see the
            # IndexError handling above): there is nothing to draw then.
            if data_2d is not None:
                data_2d = self._threshold(
                    data_2d,
                    threshold,
                    vmin=float(kwargs.get("vmin")),
                    vmax=float(kwargs.get("vmax")),
                )

                im = display_ax.draw_2d(
                    data_2d,
                    data_bounds,
                    bounding_box,
                    type=type,
                    transparency=transparency_2d,
                    **kwargs,
                )
                ims.append(im)
        return ims

    def _sanitize_transparency(
        self, img, transparency, transparency_range, resampling_interpolation
    ):
        """Return transparency as None, float or an array.

        Parameters
        ----------
        img : Niimg-like object
            Data image; its affine and shape are the target when a
            transparency image has to be resampled.

        transparency : None, :obj:`int`, :obj:`float` or Niimg-like object
            - ``None`` is returned unchanged.
            - A number is converted to ``float`` and brought back into
              ``[0, 1]`` (with a warning when it was outside).
            - An image is reordered, resampled to ``img`` if needed, and
              its absolute values are rescaled to ``[0, 1]`` according to
              ``transparency_range``.

        transparency_range : None or :obj:`list` / :obj:`tuple` of 2 numbers
            Only used when ``transparency`` is an image. Values are clamped
            to the range found in the transparency data. ``None`` means
            ``[0, max(abs(transparency))]``. The sequence passed by the
            caller is never modified.

        resampling_interpolation : :obj:`str`
            Interpolation used to reorder / resample a transparency image;
            replaced by ``"nearest"`` when that image is binary.

        Return
        ------
        transparency: None, float or np.ndarray

        transparency_affine: None or np.ndarray
            Affine of the transparency image, ``None`` when ``transparency``
            was not an image.

        Raises
        ------
        ValueError
            If ``transparency_range`` does not have exactly 2 elements or
            if, once clamped to the data, its first value is not strictly
            smaller than its second value.

        """
        transparency_affine = None
        if isinstance(transparency, NiimgLike):
            transparency = check_niimg_3d(transparency, dtype="auto")
            if is_binary_niimg(transparency):
                resampling_interpolation = "nearest"
            transparency = reorder_img(
                transparency, resample=resampling_interpolation
            )
            if not _check_fov(transparency, img.affine, img.shape[:3]):
                warnings.warn(
                    "resampling transparency image to data image...",
                    stacklevel=find_stack_level(),
                )
                transparency = resample_img(
                    transparency,
                    img.affine,
                    img.shape,
                    interpolation=resampling_interpolation,
                )

            transparency_affine = transparency.affine
            transparency = safe_get_data(transparency, ensure_finite=True)

        assert transparency is None or isinstance(
            transparency, (int, float, np.ndarray)
        )

        if isinstance(transparency, (float, int)):
            transparency = float(transparency)
            base_warning_message = (
                "'transparency' must be in the interval [0, 1]. "
            )
            if transparency > 1.0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 1.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 1.0
            if transparency < 0:
                warnings.warn(
                    f"{base_warning_message} Setting it to 0.0.",
                    stacklevel=find_stack_level(),
                )
                transparency = 0.0

        elif isinstance(transparency, np.ndarray):
            transparency = np.abs(transparency)
            data_min = np.min(transparency)
            data_max = np.max(transparency)

            if transparency_range is None:
                transparency_range = [0.0, data_max]

            error_msg = (
                "'transparency_range' must be "
                "a list or tuple of 2 non-negative numbers "
                "with 'first value < second value'."
            )

            if len(transparency_range) != 2:
                raise ValueError(f"{error_msg} Got '{transparency_range}'.")

            # Clamp the requested range to the data, using local values
            # rather than item assignment: tuples are accepted (as the
            # error message promises) and the caller's list is not altered.
            upper = min(transparency_range[1], data_max)
            lower = max(transparency_range[0], data_min)

            if lower >= upper:
                raise ValueError(f"{error_msg} Got '{[lower, upper]}'.")

            # make sure that 0 <= transparency <= 1
            # taking into account the requested transparency_range
            transparency = np.clip(transparency, lower, upper)
            transparency = (transparency - lower) / (upper - lower)

        return transparency, transparency_affine

    @classmethod
    def _threshold(cls, data, threshold=None, vmin=None, vmax=None):
        """Mask the values of ``data`` that fall under the threshold.

        Parameters
        ----------
        data : ndarray
            Data to be thresholded.

        threshold : :obj:`float` or None, default=None
            When ``None``, ``data`` is returned untouched. Otherwise the
            entries whose absolute value is lower than or equal to
            ``threshold`` are masked.

        vmin : :obj:`float` or None, default=None
            When given and ``vmin >= -threshold``, entries lower than
            ``vmin`` are masked as well.

        vmax : :obj:`float` or None, default=None
            When given and ``vmax <= threshold``, entries greater than
            ``vmax`` are masked as well.

        Returns
        -------
        ndarray or masked array
            ``data`` itself when ``threshold`` is ``None``, otherwise a
            masked array built with ``copy=False``.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        """
        # Must run first: it inspects the arguments through locals().
        check_params(locals())
        check_threshold_not_negative(threshold)

        if threshold is None:
            return data

        data = np.ma.masked_where(np.abs(data) <= threshold, data, copy=False)

        if vmin is not None and vmin >= -threshold:
            data = np.ma.masked_where(data < vmin, data, copy=False)
        if vmax is not None and vmax <= threshold:
            data = np.ma.masked_where(data > vmax, data, copy=False)

        return data

    @fill_doc
    def _show_colorbar(
        self, cmap, norm, cbar_vmin=None, cbar_vmax=None, threshold=None
    ):
        """Add a vertical colorbar on the right side of the frame.

        The new axes is stored in ``self._colorbar_ax`` and the colorbar in
        ``self._cbar``.

        Parameters
        ----------
        %(cmap)s
        norm : :class:`~matplotlib.colors.Normalize`
            This object is typically found as the ``norm`` attribute of
            :class:`~matplotlib.image.AxesImage`.

        cbar_vmin : :obj:`float`, default=None
            Minimal value for the colorbar. If None, the minimal value
            is computed based on the data.

        cbar_vmax : :obj:`float`, default=None
            Maximal value for the colorbar. If None, the maximal value
            is computed based on the data.

        threshold : :obj:`float` or ``None``, default=None
            The absolute value at which the colorbar is thresholded.

        """
        figure = self.frame_axes.figure
        _, y0, x1, y1 = self.rect

        # Width and right margin shrink with the number of cut axes.
        n_axes = len(self.axes)
        bar_width = self._colorbar_width / n_axes
        right_margin = self._colorbar_margin["right"] / n_axes
        top_margin = self._colorbar_margin["top"]
        bottom_margin = self._colorbar_margin["bottom"]

        # Rectangle of the new axes, as expected by ``figure.add_axes``.
        cbar_rect = [
            x1 - (bar_width + right_margin),
            y0 + top_margin,
            bar_width,
            (y1 - y0) - (top_margin + bottom_margin),
        ]
        self._colorbar_ax = figure.add_axes(cbar_rect)
        self._colorbar_ax.set_facecolor("w")

        self._cbar = create_colorbar_for_fig(
            figure,
            self._colorbar_ax,
            cmap,
            norm,
            threshold,
            cbar_vmin,
            cbar_vmax,
            tick_format=self._cbar_tick_format,
            spacing="proportional",
            orientation="vertical",
            # Brain color with zero alpha for the thresholded band.
            threshold_color=(*self._brain_color, 0.0),
        )

        self._cbar.ax.set_facecolor(self._brain_color)
        self._colorbar_ax.yaxis.tick_left()

        # Tick labels and outline contrast with the background.
        contrast_color = "w" if self._black_bg else "k"
        for label in self._colorbar_ax.yaxis.get_ticklabels():
            label.set_color(contrast_color)
        self._colorbar_ax.yaxis.set_tick_params(width=0)
        self._cbar.outline.set_edgecolor(contrast_color)

    @fill_doc
    def add_edges(self, img, color="r"):
        """Draw the edges of a 3D map in all the views.

        Parameters
        ----------
        %(img)s
            The 3D map to be plotted.
            If it is a masked array, only the non-masked part will be plotted.

        color : matplotlib color: :obj:`str` or (r, g, b) value, default='r'
            The color used to display the edge map.

        """
        img = reorder_img(img, resample="continuous")
        volume = get_data(img)
        affine = img.affine
        # A colormap holding a single entry: every edge gets ``color``.
        edge_cmap = ListedColormap([color])
        extent = get_bounds(volume.shape, img.affine)

        for cut_ax in self.axes.values():
            try:
                edges_2d = edge_map(cut_ax.transform_to_2d(volume, affine))
            except IndexError:
                # We are cutting outside the indices of the data
                continue

            # The full data extent serves both as data bounds and as
            # bounding box.
            cut_ax.draw_2d(
                edges_2d,
                extent,
                extent,
                type="imshow",
                cmap=edge_cmap,
            )

        plt.draw_if_interactive()

    def add_markers(
        self, marker_coords, marker_color="r", marker_size=30, **kwargs
    ):
        """Add markers to the plot.

        Parameters
        ----------
        marker_coords : :class:`~numpy.ndarray` of shape ``(n_markers, 3)``
            Coordinates of the markers to plot. On axes whose cut
            coordinate is a number, only markers within 2 millimeters of
            the slice are plotted.

        marker_color : pyplot compatible color or \
                     :obj:`list` of shape ``(n_markers,)``, default='r'
            List of colors for each marker
            that can be string or matplotlib colors.

        marker_size : :obj:`float` or \
                    :obj:`list` of :obj:`float` of shape ``(n_markers,)``, \
                    default=30
            Size for each marker, forwarded as ``s`` to
            :func:`matplotlib.pyplot.scatter`.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to
            :func:`matplotlib.pyplot.scatter`. ``marker="o"`` and
            ``zorder=1000`` are used unless overridden.
        """
        defaults = {"marker": "o", "zorder": 1000}
        marker_coords = np.asanyarray(marker_coords)
        for k, v in defaults.items():
            kwargs.setdefault(k, v)

        for display_ax in self.axes.values():
            direction = display_ax.direction
            coord = display_ax.coord
            marker_coords_2d, third_d = coords_3d_to_2d(
                marker_coords, direction, return_direction=True
            )
            xdata, ydata = marker_coords_2d.T
            # Allow markers only in their respective hemisphere
            # when appropriate
            marker_color_ = marker_color
            marker_size_ = marker_size
            # Real tuple membership: ``direction in ("lr")`` was a substring
            # test on the string "lr", which also matched "" and "lr" and
            # then left ``relevant_coords`` undefined.
            if direction in ("l", "r"):
                if not isinstance(marker_color, str) and not isinstance(
                    marker_color, np.ndarray
                ):
                    marker_color_ = np.asarray(marker_color)
                xcoords, *_ = marker_coords.T
                if direction == "r":
                    relevant_coords = xcoords >= 0
                else:
                    relevant_coords = xcoords <= 0
                xdata = xdata[relevant_coords]
                ydata = ydata[relevant_coords]
                # Per-marker colors / sizes must follow the same selection.
                if (
                    not isinstance(marker_color, str)
                    and len(marker_color) != 1
                ):
                    marker_color_ = marker_color_[relevant_coords]
                if not isinstance(marker_size, numbers.Number):
                    marker_size_ = np.asarray(marker_size_)[relevant_coords]

            # Check if coord has integer represents a cut in direction
            # to follow the heuristic. If no foreground image is given
            # coordinate is empty or None. This case is valid for plotting
            # markers on glass brain without any foreground image.
            if isinstance(coord, numbers.Number):
                # Heuristic that plots only markers that are 2mm away
                # from the current slice.
                # XXX: should we keep this heuristic?
                # NOTE(review): this mask is not applied to per-marker
                # colors / sizes -- confirm whether that is intended.
                mask = np.abs(third_d - coord) <= 2.0
                xdata = xdata[mask]
                ydata = ydata[mask]
            display_ax.ax.scatter(
                xdata, ydata, s=marker_size_, c=marker_color_, **kwargs
            )

    def annotate(
        self,
        left_right=True,
        positions=True,
        scalebar=False,
        size=12,
        scale_size=5.0,
        scale_units="cm",
        scale_loc=4,
        decimals=0,
        **kwargs,
    ):
        """Annotate every cut of the display.

        Parameters
        ----------
        left_right : :obj:`bool`, default=True
            Whether to draw the markers telling which side is left
            and which side is right.

        positions : :obj:`bool`, default=True
            Whether to write the position of each cut.

        scalebar : :obj:`bool`, default=False
            Whether to add a reference scale bar to the cuts.
            For finer control of the scale bar, use the ``draw_scale_bar``
            method of the axes stored in the ``axes`` attribute
            of this object.

        size : :obj:`int`, default=12
            Size of the annotation text.

        scale_size : :obj:`int` or :obj:`float`, default=5.0
            Length of the scale bar, expressed in ``scale_units``.

        scale_units : {'cm', 'mm'}, default='cm'
            Units used for the scale bar.

        scale_loc : :obj:`int`, default=4
            Location code of the scale bar.
            Valid location codes are:

            - 1: "upper right"
            - 2: "upper left"
            - 3: "lower left"
            - 4: "lower right"
            - 5: "right"
            - 6: "center left"
            - 7: "center right"
            - 8: "lower center"
            - 9: "upper center"
            - 10: "center"

        decimals : :obj:`int`, default=0
            Number of decimal places used for the cut position. With zero,
            the position is written as an integer without decimal point.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to matplotlib's text
            function. When ``color`` is not given, it is chosen to contrast
            with the background of the display.

        """
        dark_background = self._black_bg
        text_kwargs = dict(kwargs)
        # Default text color contrasts with the background.
        text_kwargs.setdefault("color", "w" if dark_background else "k")
        bg_color = "k" if dark_background else "w"

        # Each kind of annotation is drawn on all the axes before moving
        # on to the next kind.
        if left_right:
            for cut_ax in self.axes.values():
                cut_ax.draw_left_right(
                    size=size, bg_color=bg_color, **text_kwargs
                )

        if positions:
            for cut_ax in self.axes.values():
                cut_ax.draw_position(
                    size=size,
                    bg_color=bg_color,
                    decimals=decimals,
                    **text_kwargs,
                )

        if scalebar:
            for cut_ax in self.axes.values():
                cut_ax.draw_scale_bar(
                    bg_color=bg_color,
                    fontsize=size,
                    size=scale_size,
                    units=scale_units,
                    loc=scale_loc,
                    **text_kwargs,
                )

    def close(self) -> None:
        """Close the matplotlib figure holding this display.

        This is necessary to avoid leaking memory.

        """
        figure = self.frame_axes.figure
        plt.close(figure.number)

    def savefig(self, filename, dpi=None, **kwargs) -> None:
        """Write the figure of this display to a file.

        Parameters
        ----------
        filename : :obj:`str`
            The file name to save to. Its extension determines the
            file type, typically '.png', '.svg' or '.pdf'.

        dpi : None or scalar, default=None
            The resolution in dots per inch.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to
            :func:`matplotlib.pyplot.savefig`.

        """
        # Face and edge of the saved figure follow the display background.
        background = "k" if self._black_bg else "w"
        figure = self.frame_axes.figure
        figure.savefig(
            filename,
            dpi=dpi,
            facecolor=background,
            edgecolor=background,
            **kwargs,
        )


class _MultiDSlicer(BaseSlicer):
    """Base class of the slicers showing one cut per displayed direction.

    The displayed directions are listed, in order, in ``_cut_displayed``.

    """

    # Directions displayed ('x', 'y', 'z'); set by each inheriting Slicer
    _cut_displayed: ClassVar[str] = ""

    @classmethod
    def find_cut_coords(cls, img=None, threshold=None, cut_coords=None):
        """Find world coordinates of cut positions compatible with this slicer.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, default=None
            The brain image.

        threshold : :obj:`int` or :obj:`float` or None, default=None
            Threshold to apply, forwarded as ``activation_threshold`` when
            the cut positions are searched in ``img``:

            - If ``None`` is given, the maps are not thresholded.
            - If number is given, it must be non-negative. The specified
                value is used to threshold the image: values below the
                threshold (in absolute value) are plotted as transparent.

        cut_coords : sequence of :obj:`float` or :obj:`int`, or None, \
                    default=None
            The world coordinates of the point where the cut is performed.
            If ``None``, they are computed from ``img``, or set to the
            origin when no image is given.

        Returns
        -------
        cut_coords : :obj:`tuple` of :obj:`float` or :obj:`int`
            The tuple of cut position corresponding to this slicer.

        Raises
        ------
        ValueError
            if the specified threshold is a negative number

        """
        # Validate the user input against the directions of this slicer.
        cut_coords = cls._sanitize_cut_coords(cut_coords)

        if cut_coords is not None:
            # Coordinates provided by the caller (not generated here)
            # must be checked against the image bounds.
            cls._check_cut_coords_in_bounds(img, cut_coords)
            return tuple(cut_coords)

        if img is None or img is False:
            xyz_coords = (0, 0, 0)
        else:
            xyz_coords = find_xyz_cut_coords(
                img, activation_threshold=threshold
            )
        # Keep only the coordinates of the displayed directions.
        return tuple(
            [xyz_coords["xyz".find(axis)] for axis in cls._cut_displayed]
        )

    @classmethod
    def _cut_count(cls):
        """Return the number of directions displayed by this slicer."""
        return len(cls._cut_displayed)

    @classmethod
    def _sanitize_cut_coords(cls, cut_coords):
        """Sanitize the cut coordinates.

        Check that `cut_coords` is either None or a list, tuple or array
        holding one value per cut direction of this slicer.

        Parameters
        ----------
        cut_coords : sequence of :obj:`float` or :obj:`int`, or None
            The world coordinates of the point where the cut is performed.

        Returns
        -------
        cut_coords : sequence of :obj:`float` or :obj:`int`, or None
            The input, unchanged.

        Raises
        ------
        ValueError
            If `cut_coords` is neither None nor a list, tuple or array
            whose length equals the number of cut directions of this
            slicer.

        """
        if cut_coords is None:
            return cut_coords

        is_sequence = isinstance(cut_coords, (list, tuple, np.ndarray))
        if is_sequence and len(cut_coords) == cls._cut_count():
            return cut_coords

        raise ValueError(
            "cut_coords passed does not match the display mode."
            f" {cls.__name__} plotting expects tuple of length "
            f"{cls._cut_count()} or None.\n"
            f"You provided cut_coords={cut_coords}."
        )

    @classmethod
    def _get_coords_in_bounds(cls, bounds, cut_coords) -> list[bool]:
        """Tell, for each displayed direction, if its cut is within bounds.

        ``bounds`` is indexed in 'x', 'y', 'z' order and each entry holds
        the lower and upper limit of that direction; ``cut_coords`` follows
        the order of ``_cut_displayed``.
        """
        limits_per_cut = (
            bounds["xyz".find(direction)] for direction in cls._cut_displayed
        )
        return [
            limits[0] <= cut_coords[position] <= limits[1]
            for position, limits in enumerate(limits_per_cut)
        ]


@fill_doc
class OrthoSlicer(_MultiDSlicer):
    """Class to create 3 linked axes for plotting orthogonal \
    cuts of 3D maps.

    Nilearn plotting functions, like
    :func:`~nilearn.plotting.plot_img`, use this visualization mode
    when called with ``display_mode='ortho'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the OrthoSlicer class
        display = plot_img(img, display_mode="ortho")


    Parameters
    ----------
    cut_coords : 3-sequence of :obj:`float` or :obj:`int` or None
        The world coordinates ``(x, y, z)`` of the point where the cut is
        performed.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : 3- :obj:`tuple` of :obj:`float` or :obj:`int`
        The world coordinates ``(x, y, z)`` of the point where the cut is
        performed.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x', 'y' and 'z' here).

    %(displays_partial_attributes)s

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    See Also
    --------
    nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
    along multiple rows and columns.
    nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
    and arranged in a 2x2 grid.

    """

    _default_figsize: ClassVar[list[float]] = [2.2, 3.5]
    _cut_displayed: ClassVar[str] = "xyz"

    def _init_axes(self, **kwargs):
        """Create one axes per displayed direction, side by side.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        """
        self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)
        x0, y0, x1, y1 = self.rect
        rect_width = x1 - x0
        rect_height = y1 - y0
        background = "k" if self._black_bg else "w"

        self.axes = {}
        for position, direction in enumerate(self._cut_displayed):
            figure = self.frame_axes.get_figure()
            # Initial placement only: the locator installed below decides
            # the final position of the axes.
            mpl_ax = figure.add_axes(
                [
                    0.3 * position * rect_width + x0,
                    y0,
                    0.3 * rect_width,
                    rect_height,
                ],
                aspect="equal",
            )
            mpl_ax.set_facecolor(background)
            mpl_ax.axis("off")

            self.axes[direction] = self._axes_class(
                mpl_ax, direction, self.cut_coords[position], **kwargs
            )
            mpl_ax.set_axes_locator(self._locator)

        if not self._black_bg:
            return

        def black_patch(aspect):
            # Arguments of a large black image drawn behind everything else.
            return {
                "extent": [-5000, 5000, -5000, 5000],
                "zorder": -500,
                "aspect": aspect,
            }

        for cut_ax in self.axes.values():
            cut_ax.ax.imshow(np.zeros((2, 2, 3)), **black_patch("equal"))

        # To have a black background in PDF, we need to create a
        # patch in black for the background
        self.frame_axes.imshow(np.zeros((2, 2, 3)), **black_patch("auto"))
        self.frame_axes.set_zorder(-1000)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Compute the position of ``axes`` within the display.

        This is the locator function used by matplotlib to position axes:
        the available width is shared between the cuts proportionally to
        the width of the objects they display.

        ``renderer`` is required to match the matplotlib API.

        """
        x0, y0, x1, y1 = self.rect
        # Placeholder of zero width standing for the directions of
        # (x, y, z) that are not plotted.
        placeholder = self._axes_class(None, None, None)
        shown = self.axes

        if self._colorbar:
            # Leave room on the right for the colorbar and its margins.
            n_axes = len(shown)
            x1 = x1 - (
                self._colorbar_width / n_axes
                + self._colorbar_margin["left"] / n_axes
                + self._colorbar_margin["right"] / n_axes
            )

        data_widths = {placeholder.ax: 0}
        for cut_ax in shown.values():
            # Empty bounds happen if the call to _map_show was not
            # successful. As it happens asynchronously (during a
            # refresh of the figure) we capture the problem and
            # ignore it: it only adds a non informative traceback
            xmin, xmax, _, _ = cut_ax.get_object_bounds() or [0, 1, 0, 1]
            data_widths[cut_ax.ax] = xmax - xmin

        # Rescale the data widths so that they fill the available width.
        total = float(sum(data_widths.values()))
        widths = {
            mpl_ax: data_width / total * (x1 - x0)
            for mpl_ax, data_width in data_widths.items()
        }

        # Stack the axes from left to right, in the order of the cuts.
        lefts = {}
        left = x0
        for direction in self._cut_displayed:
            mpl_ax = shown.get(direction, placeholder).ax
            lefts[mpl_ax] = left
            left += widths[mpl_ax]

        return Bbox([[lefts[axes], y0], [lefts[axes] + widths[axes], y1]])

    def draw_cross(self, cut_coords=None, **kwargs) -> None:
        """Draw a crossbar on the plot to show where the cut is performed.

        Parameters
        ----------
        cut_coords : 3-sequence of :obj:`float` or :obj:`int`, or None, \
                     default=None
            The position of the cross to draw in world coordinates
            ``(x, y, z)``.
            If ``None`` is passed, the ``OrthoSlicer``'s cut coordinates are
            used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.axhline`.

        """
        if cut_coords is None:
            cut_coords = self.cut_coords

        # World coordinate of the cross along each direction, None for the
        # directions this slicer does not display.
        world = {
            direction: (
                cut_coords[self._cut_displayed.index(direction)]
                if direction in self._cut_displayed
                else None
            )
            for direction in "xyz"
        }

        line_kwargs = dict(kwargs)
        line_kwargs.setdefault("color", ".8" if self._black_bg else "k")

        # (view, direction of the vertical line, direction of the
        # horizontal line, extra arguments of the horizontal line)
        layout = (
            ("y", "x", "z", {}),
            ("x", "y", "z", {"xmax": 0.95}),
            ("z", "x", "y", {}),
        )
        for view, vertical, horizontal, horizontal_extra in layout:
            if view not in self.axes:
                continue
            mpl_ax = self.axes[view].ax
            if world[vertical] is not None:
                mpl_ax.axvline(
                    world[vertical], ymin=0.05, ymax=0.95, **line_kwargs
                )
            if world[horizontal] is not None:
                mpl_ax.axhline(
                    world[horizontal], **horizontal_extra, **line_kwargs
                )


@fill_doc
class TiledSlicer(_MultiDSlicer):
    """A class to create 3 axes for plotting orthogonal \
    cuts of 3D maps, organized in a 2x2 grid.

    This visualization mode can be activated from Nilearn plotting functions,
    like :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='tiled'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the TiledSlicer class
        display = plot_img(img, display_mode="tiled")

    Parameters
    ----------
    cut_coords : 3-sequence of :obj:`float` or :obj:`int` or None
        The world coordinates ``(x, y, z)`` of the point where the cut is
        performed.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : 3- :obj:`tuple` of :obj:`float` or :obj:`int`
        The world coordinates ``(x, y, z)`` of the point where the cut is
        performed.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x', 'y' and 'z' here).

    %(displays_partial_attributes)s

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    See Also
    --------
    nilearn.plotting.displays.MosaicSlicer : Three cuts are performed \
    along multiple rows and columns.
    nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
    and displayed side by side.

    """

    _cut_displayed: ClassVar[str] = "xyz"
    _default_figsize: ClassVar[list[float]] = [2.0, 7.6]

    def _find_initial_axes_coord(self, index):
        """Find the rectangle used to create the axes of one xyz cut.

        Parameters
        ----------
        index : :obj:`int`
            Index corresponding to current cut 'x', 'y' or 'z'
            (0, 1 or 2).

        Returns
        -------
        [coord1, coord2, coord3, coord4] : :obj:`list` of :obj:`float`
            Rectangle handed to ``Figure.add_axes`` by ``_init_axes`` when
            the axes are created. ``_init_axes`` also installs ``_locator``
            on these axes, which computes their position in the figure.

        Raises
        ------
        ValueError
            If ``index`` is not 0, 1 or 2.

        """
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect
        width = rect_x1 - rect_x0
        height = rect_y1 - rect_y0
        half_x = 0.5 * width + rect_x0
        half_y = 0.5 * height + rect_y0

        if index == 0:
            return [width, half_y, half_x, height]
        if index == 1:
            return [half_x, half_y, width, height]
        if index == 2:
            return [width, height, half_x, half_y]
        # Without this guard an unexpected index would only surface as an
        # UnboundLocalError on the return statement.
        raise ValueError(
            "index must be 0, 1 or 2 (one per 'x', 'y' and 'z' cut). "
            f"You provided index={index!r}."
        )

    def _init_axes(self, **kwargs):
        """Initialize and place axes for display of 'xyz' cuts.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        """
        self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)
        facecolor = "k" if self._black_bg else "w"

        self.axes = {}
        for index, direction in enumerate(self._cut_displayed):
            fh = self.frame_axes.get_figure()
            axes_coords = self._find_initial_axes_coord(index)
            ax = fh.add_axes(axes_coords, aspect="equal")

            ax.set_facecolor(facecolor)

            ax.axis("off")
            display_ax = self._axes_class(
                ax, direction, self.cut_coords[index], **kwargs
            )
            self.axes[direction] = display_ax
            # The locator computes the position of the axes.
            ax.set_axes_locator(self._locator)

    def _adjust_width_height(
        self, width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
    ):
        """Adjust absolute image width and height to ratios.

        The dictionaries are updated in place and returned.

        Parameters
        ----------
        width_dict : :obj:`dict`
            Width of image cuts displayed in axes.

        height_dict : :obj:`dict`
            Height of image cuts displayed in axes.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.

        """
        total_height = 0
        total_width = 0

        # The 'y' cut counts in both totals; 'x' only adds to the width
        # and 'z' only adds to the height.
        if "y" in self.axes:
            ax = self.axes["y"].ax
            total_height += height_dict[ax]
            total_width += width_dict[ax]

        if "x" in self.axes:
            ax = self.axes["x"].ax
            total_width = total_width + width_dict[ax]

        if "z" in self.axes:
            ax = self.axes["z"].ax
            total_height = total_height + height_dict[ax]

        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (rect_x1 - rect_x0)

        for ax, height in height_dict.items():
            height_dict[ax] = height / total_height * (rect_y1 - rect_y0)

        return (width_dict, height_dict)

    def _find_axes_coord(
        self,
        rel_width_dict,
        rel_height_dict,
        rect_x0,
        rect_y0,
        rect_x1,
        rect_y1,
    ):
        """Find coordinates for the placement of the axes of xyz cuts.

        Parameters
        ----------
        rel_width_dict : :obj:`dict`
            Width ratios of image cuts for optimal positioning of axes.

        rel_height_dict : :obj:`dict`
            Height ratios of image cuts for optimal positioning of axes.

        rect_x0, rect_y0, rect_x1, rect_y1 : :obj:`float`
            Matplotlib figure boundaries.

        Returns
        -------
        coord1, coord2, coord3, coord4 : :obj:`dict`
            x0, y0, x1, y1 coordinates per axes used by matplotlib
            to position axes in figure.

        """
        coord1 = {}
        coord2 = {}
        coord3 = {}
        coord4 = {}

        # 'y' cut: anchored to the top-left corner.
        if "y" in self.axes:
            ax = self.axes["y"].ax
            coord1[ax] = rect_x0
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y1

        # 'x' cut: anchored to the top-right corner.
        if "x" in self.axes:
            ax = self.axes["x"].ax
            coord1[ax] = (rect_x1) - rel_width_dict[ax]
            coord2[ax] = (rect_y1) - rel_height_dict[ax]
            coord3[ax] = rect_x1
            coord4[ax] = rect_y1

        # 'z' cut: anchored to the bottom-left corner.
        if "z" in self.axes:
            ax = self.axes["z"].ax
            coord1[ax] = rect_x0
            coord2[ax] = rect_y0
            coord3[ax] = rect_x0 + rel_width_dict[ax]
            coord4[ax] = rect_y0 + rel_height_dict[ax]

        return (coord1, coord2, coord3, coord4)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        Here we put the logic used to adjust the size of the axes.

        ``renderer`` is required to match the matplotlib API.

        """
        rect_x0, rect_y0, rect_x1, rect_y1 = self.rect

        # A dummy axes, for the situation in which we are not plotting
        # all three (x, y, z) cuts
        dummy_ax = self._axes_class(None, None, None)
        width_dict = {dummy_ax.ax: 0}
        height_dict = {dummy_ax.ax: 0}
        display_ax_dict = self.axes

        if self._colorbar:
            # Leave room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            rect_x1 = rect_x1 - (adjusted_width + ticks_margin + right_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, ymin, ymax = bounds
            width_dict[display_ax.ax] = xmax - xmin
            height_dict[display_ax.ax] = ymax - ymin

        # relative image height and width
        rel_width_dict, rel_height_dict = self._adjust_width_height(
            width_dict, height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        coord1, coord2, coord3, coord4 = self._find_axes_coord(
            rel_width_dict, rel_height_dict, rect_x0, rect_y0, rect_x1, rect_y1
        )

        return Bbox(
            [[coord1[axes], coord2[axes]], [coord3[axes], coord4[axes]]]
        )

    def draw_cross(self, cut_coords=None, **kwargs) -> None:
        """Draw a crossbar on the plot to show where the cut is performed.

        Parameters
        ----------
        cut_coords : 3-sequence of :obj:`float` or :obj:`int`, or None, \
                     default=None
            The position of the cross to draw in world coordinates
            ``(x, y, z)``.
            If ``None`` is passed, the ``TiledSlicer``'s cut coordinates are
            used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`~matplotlib.pyplot.axhline`.

        """
        if cut_coords is None:
            cut_coords = self.cut_coords
        # Coordinate of the cross along each direction, None for the
        # directions that are not displayed.
        coords = {}
        for direction in "xyz":
            coords[direction] = (
                cut_coords[self._cut_displayed.index(direction)]
                if direction in self._cut_displayed
                else None
            )
        x, y, z = coords["x"], coords["y"], coords["z"]

        kwargs = kwargs.copy()
        # Assigning a key cannot raise KeyError, so no suppression is needed.
        if "color" not in kwargs:
            kwargs["color"] = ".8" if self._black_bg else "k"

        if "y" in self.axes:
            ax = self.axes["y"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "x" in self.axes:
            ax = self.axes["x"].ax
            if y is not None:
                ax.axvline(y, **kwargs)
            if z is not None:
                ax.axhline(z, **kwargs)

        if "z" in self.axes:
            ax = self.axes["z"].ax
            if x is not None:
                ax.axvline(x, **kwargs)
            if y is not None:
                ax.axhline(y, **kwargs)


class BaseStackedSlicer(BaseSlicer):
    """A class to create linked axes for plotting stacked cuts of 2D maps.

    Notes
    -----
    The extent of the different axes are adjusted to fit the data
    best in the viewing area.

    """

    # This should be set by each inheriting Slicer
    _direction: ClassVar[str] = ""

    @classmethod
    def find_cut_coords(
        cls,
        img=None,
        threshold=None,  # noqa: ARG003
        cut_coords=None,
    ):
        """Find world coordinates of cut positions compatible with this slicer.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, default=None
            The brain image.

        threshold : :obj:`float`, default=None
            Not used by this slicer.

        cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or \
                     None, default=None
            The number of cuts to perform or the list of cut positions in the
            direction of this slicer.
            If ``None`` is given, the cuts are calculated automatically.

        Returns
        -------
        cut_coords : :obj:`list` of :obj:`float` or :obj:`int`
            The list of cut positions in the direction of this slicer.

        """
        # checks if cut_coords is compatible with this slicer
        # and adjust value if necessary
        cut_coords = cls._sanitize_cut_coords(cut_coords)
        if img is None or img is False:
            if isinstance(cut_coords, numbers.Number):
                # No image to look at: spread the requested number of cuts
                # evenly over default bounds for this direction.
                bounds = ((-40, 40), (-30, 30), (-30, 75))
                lower, upper = bounds["xyz".index(cls._direction)]
                cut_coords = np.linspace(lower, upper, cut_coords).tolist()
        elif isinstance(cut_coords, numbers.Number):
            cut_coords = find_cut_slices(
                img, direction=cls._direction, n_cuts=cut_coords
            )
        else:
            # check if cut_coords is within image bounds
            # when it is not generated by this function
            cls._check_cut_coords_in_bounds(img, cut_coords)

        return list(cut_coords)

    @classmethod
    def _sanitize_cut_coords(cls, cut_coords):
        """Sanitize the cut coordinates.

        Check if `cut_coords` is a number or a sequence of numbers and adjust
        its value if necessary.

        Parameters
        ----------
        cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or None
            The number of cuts to perform or the list of cut positions in the
            direction of this slicer.
            If ``None`` is given, 7 cuts are requested.

        Returns
        -------
        cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int`
            The number of cuts, or the cut positions; duplicated positions
            are dropped (with a warning), keeping the first occurrence.

        Raises
        ------
        ValueError
            If `cut_coords` is not a number or a sequence of :obj:`float` or
            :obj:`int` or `None`.

        """
        if cut_coords is None:
            cut_coords = 7
        elif isinstance(cut_coords, (list, tuple, np.ndarray)):
            # use dict.fromkeys to preserve order
            unique = dict.fromkeys(cut_coords)
            if len(cut_coords) != len(unique):
                warnings.warn(
                    f"Dropping duplicates cuts from: {cut_coords=}",
                    stacklevel=find_stack_level(),
                )
                cut_coords = list(unique)
        elif not isinstance(cut_coords, numbers.Number):
            # The newline keeps the two sentences of the message apart.
            raise ValueError(
                "cut_coords passed does not match the display mode."
                f" {cls.__name__} plotting expects a number, list, tuple or "
                "array of numbers or None.\n"
                f"You provided cut_coords={cut_coords}."
            )

        return cut_coords

    @classmethod
    def _cut_count(cls):
        """Return the number of directions cut by this slicer."""
        return len(cls._direction)

    @classmethod
    def _get_coords_in_bounds(cls, bounds, cut_coords) -> list[bool]:
        """Tell, for each cut position, if it lies within the bounds.

        ``bounds`` is indexed in 'x', 'y', 'z' order; only the entry of the
        direction of this slicer is used, as (lower, upper) limits.
        """
        coord_bounds = bounds["xyz".find(cls._direction)]
        return [
            coord_bounds[0] <= coord <= coord_bounds[1] for coord in cut_coords
        ]

    def _init_axes(self, **kwargs):
        """Create one axes per cut position, side by side.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        """
        self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)
        x0, y0, x1, y1 = self.rect
        # Create our axes:
        self.axes = {}
        # Initial placement: every cut gets the same share of the width;
        # the locator set below computes the position of the axes.
        fraction = 1.0 / len(self.cut_coords)
        for index, coord in enumerate(self.cut_coords):
            coord = float(coord)
            fh = self.frame_axes.get_figure()
            ax = fh.add_axes(
                [
                    fraction * index * (x1 - x0) + x0,
                    y0,
                    fraction * (x1 - x0),
                    y1 - y0,
                ]
            )
            ax.axis("off")
            display_ax = self._axes_class(ax, self._direction, coord, **kwargs)
            # Axes are keyed by the (float) position of their cut.
            self.axes[coord] = display_ax
            ax.set_axes_locator(self._locator)

        if self._black_bg:
            for ax in self.axes.values():
                ax.ax.imshow(
                    np.zeros((2, 2, 3)),
                    extent=[-5000, 5000, -5000, 5000],
                    zorder=-500,
                    aspect="equal",
                )

            # To have a black background in PDF, we need to create a
            # patch in black for the background
            self.frame_axes.imshow(
                np.zeros((2, 2, 3)),
                extent=[-5000, 5000, -5000, 5000],
                zorder=-500,
                aspect="auto",
            )
            self.frame_axes.set_zorder(-1000)

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        The locator function used by matplotlib to position axes.

        Here we put the logic used to adjust the size of the axes:
        the available width is shared between the cuts proportionally to
        the width of the objects they display.

        ``renderer`` is required to match the matplotlib API.

        """
        x0, y0, x1, y1 = self.rect
        width_dict = {}
        display_ax_dict = self.axes

        if self._colorbar:
            # Leave room on the right for the colorbar and its margins.
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        for display_ax in display_ax_dict.values():
            bounds = display_ax.get_object_bounds()
            if not bounds:
                # This happens if the call to _map_show was not
                # successful. As it happens asynchronously (during a
                # refresh of the figure) we capture the problem and
                # ignore it: it only adds a non informative traceback
                bounds = [0, 1, 0, 1]
            xmin, xmax, _, _ = bounds
            width_dict[display_ax.ax] = xmax - xmin
        # Rescale the widths so that they fill the available width.
        total_width = float(sum(width_dict.values()))
        for ax, width in width_dict.items():
            width_dict[ax] = width / total_width * (x1 - x0)
        # Stack the axes from left to right, in insertion order.
        left_dict = {}
        left = float(x0)
        for display_ax in display_ax_dict.values():
            left_dict[display_ax.ax] = left
            this_width = width_dict[display_ax.ax]
            left += this_width
        return Bbox(
            [[left_dict[axes], y0], [left_dict[axes] + width_dict[axes], y1]]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Not implemented for this slicer: calling it does nothing.

        Parameters
        ----------
        cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or \
                     None, default=None
            The list of positions of the crosses to draw in the direction of
            this slicer.
            If ``None`` is passed, this slicer's cut coordinates are
            used.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`matplotlib.pyplot.axhline`.

        """


@fill_doc
class XSlicer(BaseStackedSlicer):
    """The ``XSlicer`` class enables sagittal visualization with \
    plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='x'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the XSlicer class
        display = plot_img(img, display_mode="x")

    Parameters
    ----------
    cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or None
        The number of cuts to perform or the list of cut positions in the
        direction 'x'.
        If ``None``, 7 cut positions are calculated automatically.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float` or :obj:`int`
        The list of cut positions in direction x.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot each view in direction 'x'.

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.YSlicer : Coronal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    # Direction of the cuts shown by this slicer (``display_mode='x'``).
    _direction: ClassVar[str] = "x"
    # Default figure size; ``BaseSlicer`` documents this attribute as the
    # size budget for a single axes.
    _default_figsize: ClassVar[list[float]] = [2.6, 2.3]


@fill_doc
class YSlicer(BaseStackedSlicer):
    """The ``YSlicer`` class enables coronal visualization with \
    plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='y'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YSlicer class
        display = plot_img(img, display_mode="y")

    Parameters
    ----------
    cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or None
        The number of cuts to perform or the list of cut positions in the
        direction 'y'.
        If ``None``, 7 cut positions are calculated automatically.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float` or :obj:`int`
        The list of cut positions in direction y.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot each view in direction 'y'.

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.ZSlicer : Axial view

    """

    # Direction of the cuts shown by this slicer (``display_mode='y'``).
    _direction: ClassVar[str] = "y"
    # Default figure size; ``BaseSlicer`` documents this attribute as the
    # size budget for a single axes.
    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]


@fill_doc
class ZSlicer(BaseStackedSlicer):
    """The ``ZSlicer`` class enables axial visualization with \
    plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='z'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the ZSlicer class
        display = plot_img(img, display_mode="z")

    Parameters
    ----------
    cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or None
        The number of cuts to perform or the list of cut positions in the
        direction 'z'.
        If ``None``, 7 cut positions are calculated automatically.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : :obj:`list` of :obj:`float` or :obj:`int`
        The list of cut positions in direction z.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used to plot each view in direction 'z'.

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.XSlicer : Sagittal view
    nilearn.plotting.displays.YSlicer : Coronal view

    """

    # Direction of the cuts shown by this slicer (``display_mode='z'``).
    _direction: ClassVar[str] = "z"
    # Default figure size; ``BaseSlicer`` documents this attribute as the
    # size budget for a single axes.
    _default_figsize: ClassVar[list[float]] = [2.2, 3.2]


@fill_doc
class XZSlicer(OrthoSlicer):
    """The ``XZSlicer`` class combines sagittal and axial views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='xz'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the XZSlicer class
        display = plot_img(img, display_mode="xz")

    Parameters
    ----------
    cut_coords : 2-sequence of :obj:`float` or :obj:`int` or None
        The world coordinates ``(x, z)`` of the point where the cut is
        performed.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : 2- :obj:`tuple` of :obj:`float` or :obj:`int`
        The world coordinates (x, z) of the point where the cut is performed.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'z' here).

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Cut directions shown by this slicer: sagittal ('x') and axial ('z').
    _cut_displayed = "xz"


@fill_doc
class YXSlicer(OrthoSlicer):
    """The ``YXSlicer`` class combines coronal and sagittal views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='yx'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YXSlicer class
        display = plot_img(img, display_mode="yx")

    Parameters
    ----------
    cut_coords : 2-sequence of :obj:`float` or :obj:`int` or None
        The world coordinates ``(x, y)`` of the point where the cut is
        performed.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : 2- :obj:`tuple` of :obj:`float` or :obj:`int`
        The world coordinates (x, y) of the point where the cut is performed.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x' and 'y' here).

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YZSlicer : Coronal + Axial views

    """

    # Cut directions shown by this slicer: coronal ('y') first, then
    # sagittal ('x'), matching the class name and ``display_mode='yx'``,
    # like the sibling ``XZSlicer`` ("xz") and ``YZSlicer`` ("yz").
    # NOTE(review): the string order is presumably the display order of the
    # views, as it is in ``MosaicSlicer._init_axes`` -- confirm against
    # ``OrthoSlicer``.
    _cut_displayed = "yx"


@fill_doc
class YZSlicer(OrthoSlicer):
    """The ``YZSlicer`` class combines coronal and axial views \
    on the same figure with plotting functions of Nilearn like \
    :func:`nilearn.plotting.plot_img`.

    This visualization mode
    can be activated by setting ``display_mode='yz'``:

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the YZSlicer class
        display = plot_img(img, display_mode="yz")

    Parameters
    ----------
    cut_coords : 2-sequence of :obj:`float` or :obj:`int` or None
        The world coordinates ``(y, z)`` of the point where the cut is
        performed.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : 2- :obj:`tuple` of :obj:`float` or :obj:`int`
        The world coordinates (y, z) of the point where the cut is performed.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('y' and 'z' here).

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.XZSlicer : Sagittal + Axial views
    nilearn.plotting.displays.YXSlicer : Coronal + Sagittal views

    """

    # Cut directions shown by this slicer: coronal ('y') and axial ('z').
    _cut_displayed: ClassVar[str] = "yz"
    # Default figure size; ``BaseSlicer`` documents this attribute as the
    # size budget for a single axes.
    _default_figsize: ClassVar[list[float]] = [2.2, 3.0]


@fill_doc
class MosaicSlicer(BaseSlicer):
    """A class to create 3 :class:`~matplotlib.axes.Axes` for \
    plotting cuts of 3D maps, in multiple rows and columns.

    This visualization mode can be activated from Nilearn plotting
    functions, like :func:`~nilearn.plotting.plot_img`, by setting
    ``display_mode='mosaic'``.

    .. code-block:: python

        from nilearn.datasets import load_mni152_template
        from nilearn.plotting import plot_img

        img = load_mni152_template()
        # display is an instance of the MosaicSlicer class
        display = plot_img(img, display_mode="mosaic")

    Parameters
    ----------
    cut_coords : :obj:`int`, 3-sequence of :obj:`float` or :obj:`int`, \
                 :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`> or None
        Either a number to indicate number of cuts in each direction, or a
        sequence of length 3 indicating the number of cuts in each direction
        ``(x, y, z)``, or a dictionary where keys are the directions
        ('x', 'y', 'z') and the values are sequences holding the cut
        coordinates.

    %(slicer_init_parameters_partial)s

    Attributes
    ----------
    cut_coords : :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`>
        The cut coordinates in a dictionary. The keys are the directions
        ('x', 'y', 'z'), and the values are sequences holding the cut
        coordinates.

    axes : :obj:`dict` of :class:`~nilearn.plotting.displays.CutAxes`
        The axes used for plotting in each direction ('x', 'y' and 'z' here).

    %(displays_partial_attributes)s

    See Also
    --------
    nilearn.plotting.displays.TiledSlicer : Three cuts are performed \
    and arranged in a 2x2 grid.
    nilearn.plotting.displays.OrthoSlicer : Three cuts are performed \
    in orthogonal directions.

    """

    # Directions shown, one row of axes per direction (see ``_init_axes``).
    _cut_displayed: ClassVar[str] = "xyz"
    _default_figsize: ClassVar[list[float]] = [4.0, 5.0]

    @classmethod
    def _sanitize_cut_coords(cls, cut_coords):
        """Sanitize the cut coordinates.

        Check if `cut_coords` is a number, a sequence of numbers and adjust its
        value if necessary.

        Parameters
        ----------
        cut_coords : :obj:`int`, 3-sequence of :obj:`float` or :obj:`int` or \
                     :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`> or \
                     `None`
            The world coordinates of the points where the cuts are performed.

        Returns
        -------
        cut_coords : :obj:`list`, :obj:`tuple`, :class:`~numpy.ndarray` \
                     or :obj:`dict`
            ``None`` is replaced by 7 and a single number is expanded to one
            entry per direction. A dictionary is returned as the same object,
            modified in place so that duplicated coordinates are dropped
            (first occurrence kept).

        Raises
        ------
        ValueError
            If `cut_coords` is a sequence or a dictionary that does not hold
            exactly one entry per direction, or a dictionary that lacks one
            of the keys 'x', 'y', 'z' or whose values are not all lists,
            tuples or arrays.
        """
        if cut_coords is None:
            cut_coords = 7
        if isinstance(cut_coords, numbers.Number):
            # same number of cuts in every direction
            cut_coords = [cut_coords] * cls._cut_count()

        if any(
            (
                (
                    isinstance(cut_coords, (list, tuple, np.ndarray, dict))
                    and len(cut_coords) != cls._cut_count()
                ),
                (
                    isinstance(cut_coords, dict)
                    and not {"x", "y", "z"}.issubset(cut_coords)
                ),
                (
                    isinstance(cut_coords, dict)
                    and not all(
                        isinstance(value, (list, tuple, np.ndarray))
                        for value in cut_coords.values()
                    )
                ),
            )
        ):
            raise ValueError(
                "cut_coords passed does not match the display mode. "
                f" {cls.__name__} plotting expects a number, a list, tuple, "
                "or array of 3 numbers, or a dictionary "
                "with keys 'x', 'y', 'z' and values as array. "
                f"You provided cut_coords={cut_coords}."
            )

        if isinstance(cut_coords, dict):
            # Only existing keys are reassigned, so the dictionary size does
            # not change while iterating over it.
            for key, value in cut_coords.items():
                # use dict.fromkeys to preserve order
                unique = dict.fromkeys(value)
                if len(value) != len(unique):
                    warnings.warn(
                        f"Dropping duplicates cuts from direction '{key}' "
                        f"values {value}",
                        stacklevel=find_stack_level(),
                    )
                    cut_coords[key] = list(unique)

        return cut_coords

    @classmethod
    def find_cut_coords(
        cls,
        img=None,
        threshold=None,  # noqa: ARG003
        cut_coords=None,
    ):
        """Find world coordinates of cut positions compatible with this slicer.

        Parameters
        ----------
        img : 3D :class:`~nibabel.nifti1.Nifti1Image`, default=None
            The brain image.

        threshold : :obj:`float`, default=None
            Not used by this slicer.

        cut_coords : :obj:`int`, sequence of :obj:`float` or :obj:`int` or \
                     :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`> or \
                     `None`, default=None
            The world coordinates of the points where the cuts are performed.

            If `cut_coords` is not provided, 7 coordinates of cuts are
            automatically calculated for each direction ('x', 'y', 'z').
            If an integer is provided, specified number of cuts are calculated
            for each direction.
            If a sequence is provided, each entry is the number of cuts for
            the corresponding direction.
            If a dictionary is provided, it is sanitized and, when `img` is
            given, checked against the image bounds.

        Returns
        -------
        cut_coords : :obj:`dict`  <:obj:`str`: 1D :class:`~numpy.ndarray`>
            xyz world coordinates of cuts in a direction.
            Each key denotes the direction.

        """
        cut_coords = cls._sanitize_cut_coords(cut_coords)

        if isinstance(cut_coords, (list, tuple, np.ndarray)):
            # cut_coords holds the number of cuts wanted in each direction
            coords = {}
            if img is None or img is False:
                # no image to look at: spread the cuts evenly over
                # hard-coded (lower, upper) bounds for x, y and z
                bounds = ((-40, 40), (-30, 30), (-30, 75))
                for direction, n_cuts in zip(
                    cls._cut_displayed, cut_coords, strict=False
                ):
                    lower, upper = bounds["xyz".index(direction)]
                    coords[direction] = np.linspace(
                        lower, upper, n_cuts
                    ).tolist()
            else:
                for direction, n_cuts in zip(
                    cls._cut_displayed, cut_coords, strict=False
                ):
                    coords[direction] = find_cut_slices(
                        img, direction=direction, n_cuts=n_cuts
                    )
            cut_coords = coords
        elif img is not None and img is not False:
            cls._check_cut_coords_in_bounds(img, cut_coords)
        return cut_coords

    @classmethod
    def _get_coords_in_bounds(cls, bounds, cut_coords) -> list[bool]:
        """Tell, for each cut coordinate, whether it lies within bounds.

        Parameters
        ----------
        bounds : sequence of 2-sequences
            ``(lower, upper)`` pairs, indexed in the same order as the
            directions in ``_cut_displayed``.

        cut_coords : :obj:`dict`
            Cut coordinates, keyed by direction.

        Returns
        -------
        :obj:`list` of :obj:`bool`
            One flag per coordinate, direction after direction; ``True`` when
            ``lower <= coord <= upper``.
        """
        coord_in = []

        for index, direction in enumerate(cls._cut_displayed):
            lower, upper = bounds[index][0], bounds[index][1]
            coord_in.extend(
                lower <= coord <= upper for coord in cut_coords[direction]
            )
        return coord_in

    @classmethod
    def _cut_count(cls):
        """Return the number of directions displayed by this slicer."""
        return len(cls._cut_displayed)

    def _init_axes(self, **kwargs):
        """Initialize and place axes for display of 'xyz' multiple cuts.

        Also adapts the width of the color bar relative to the axes.

        Parameters
        ----------
        kwargs : :obj:`dict`
            Additional arguments to pass to ``self._axes_class``.

        """
        # find_cut_coords turns every supported input (None, number,
        # sequence of counts, dictionary) into a dictionary keyed by
        # direction, so a single call is enough.
        self.cut_coords = self.find_cut_coords(cut_coords=self.cut_coords)

        x0, y0, x1, y1 = self.rect

        # Create our axes:
        self.axes = {}
        # portions for main axes
        fraction = y1 / len(self.cut_coords)
        height = fraction
        for index, direction in enumerate(self._cut_displayed):
            coords = self.cut_coords[direction]
            # portions allotment for each of 'x', 'y', 'z' coordinate
            fraction_c = 1.0 / len(coords)
            fh = self.frame_axes.get_figure()
            indices = [
                x0,
                fraction * index * (y1 - y0) + y0,
                x1,
                fraction * (y1 - y0),
            ]
            # invisible axes for the whole row of this direction
            ax = fh.add_axes(indices)
            ax.axis("off")
            this_x0, this_y0, this_x1, _ = indices
            for index_c, raw_coord in enumerate(coords):
                coord = float(raw_coord)
                fh_c = ax.get_figure()
                # indices for each sub axes within main axes
                indices = [
                    fraction_c * index_c * (this_x1 - this_x0) + this_x0,
                    this_y0,
                    fraction_c * (this_x1 - this_x0),
                    height,
                ]
                ax = fh_c.add_axes(indices)
                ax.axis("off")
                display_ax = self._axes_class(ax, direction, coord, **kwargs)
                self.axes[direction, coord] = display_ax
                # the final position is computed by ``_locator``
                ax.set_axes_locator(self._locator)

        # increase color bar width to adapt to the number of cuts
        #  see issue https://github.com/nilearn/nilearn/pull/4284
        # (``coords`` holds the cuts of the last direction displayed)
        self._colorbar_width *= len(coords) ** 1.1

    def _locator(
        self,
        axes,
        renderer,  # noqa: ARG002
    ):
        """Adjust the size of the axes.

        Locator function used by matplotlib to position axes.

        Here we put the logic used to adjust the size of the axes.

        ``renderer`` is required to match the matplotlib API.

        Parameters
        ----------
        axes : :class:`~matplotlib.axes.Axes`
            The axes to position; it must be the ``ax`` of one of the
            entries of ``self.axes``.

        renderer
            Unused.

        Returns
        -------
        :class:`~matplotlib.transforms.Bbox`
            The bounding box assigned to ``axes``.

        """
        x0, y0, x1, y1 = self.rect
        display_ax_dict = self.axes

        if self._colorbar:
            # shrink the usable width to leave room for the colorbar
            # and its margins on the right
            adjusted_width = self._colorbar_width / len(self.axes)
            right_margin = self._colorbar_margin["right"] / len(self.axes)
            ticks_margin = self._colorbar_margin["left"] / len(self.axes)
            x1 = x1 - (adjusted_width + right_margin + ticks_margin)

        # capture widths for each axes for anchoring Bbox
        width_dict = {}
        for direction in self._cut_displayed:
            row_widths = {}
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    bounds = display_ax.get_object_bounds()
                    if not bounds:
                        # This happens if the call to _map_show was not
                        # successful. As it happens asynchronously (during a
                        # refresh of the figure) we capture the problem and
                        # ignore it: it only adds a non informative traceback
                        bounds = [0, 1, 0, 1]
                    xmin, xmax, _, _ = bounds
                    row_widths[display_ax.ax] = xmax - xmin
            # within a row, each axes gets a share of the available width
            # proportional to the extent of the object it displays
            total_width = float(sum(row_widths.values()))
            for ax, w in row_widths.items():
                width_dict[ax] = w / total_width * (x1 - x0)

        left_dict = {}
        # bottom positions in Bbox according to cuts
        bottom_dict = {}
        # fraction is divided by the cut directions 'y', 'x', 'z'
        fraction = y1 / len(self._cut_displayed)
        height_dict = {}
        for index, direction in enumerate(self._cut_displayed):
            # every row starts again from the left edge
            left = float(x0)
            this_height = fraction + fraction * index
            for display_ax in display_ax_dict.values():
                if direction == display_ax.direction:
                    left_dict[display_ax.ax] = left
                    left += width_dict[display_ax.ax]
                    bottom_dict[display_ax.ax] = fraction * index * (y1 - y0)
                    height_dict[display_ax.ax] = this_height
        return Bbox(
            [
                [left_dict[axes], bottom_dict[axes]],
                [left_dict[axes] + width_dict[axes], height_dict[axes]],
            ]
        )

    def draw_cross(self, cut_coords=None, **kwargs):
        """Draw a crossbar on the plot to show where the cut is performed.

        Not implemented for this slicer: the method has no body besides this
        docstring, so calling it does nothing and returns ``None``.

        Parameters
        ----------
        cut_coords : :obj:`dict` <:obj:`str`: 1D :class:`~numpy.ndarray`> or \
                     `None`, default=None
            The positions of the crosses to draw.
            If ``None`` is passed, the ``MosaicSlicer``'s cut coordinates are
            used.
            Currently unused.

        kwargs : :obj:`dict`
            Extra keyword arguments are passed to function
            :func:`matplotlib.pyplot.axhline`.
            Currently unused.

        """


# Registry mapping each ``display_mode`` string to the slicer class that
# implements it; ``get_slicer`` passes it to ``get_create_display_fun``.
SLICERS = {
    "ortho": OrthoSlicer,
    "tiled": TiledSlicer,
    "mosaic": MosaicSlicer,
    "xz": XZSlicer,
    "yz": YZSlicer,
    "yx": YXSlicer,
    "x": XSlicer,
    "y": YSlicer,
    "z": ZSlicer,
}


def get_slicer(display_mode):
    """Retrieve a slicer from a given display mode.

    Parameters
    ----------
    display_mode : :obj:`str`
        The desired display mode.
        Possible options are:

        - "ortho": Three cuts are performed in orthogonal directions.
        - "tiled": Three cuts are performed and arranged in a 2x2 grid.
        - "mosaic": Three cuts are performed along multiple rows and columns.
        - "x": Sagittal
        - "y": Coronal
        - "z": Axial
        - "xz": Sagittal + Axial
        - "yz": Coronal + Axial
        - "yx": Coronal + Sagittal

    Returns
    -------
    slicer : An instance of one of the subclasses of\
    :class:`~nilearn.plotting.displays.BaseSlicer`

        The slicer corresponding to the requested display mode:

        - "ortho": Returns an
            :class:`~nilearn.plotting.displays.OrthoSlicer`.
        - "tiled": Returns a
            :class:`~nilearn.plotting.displays.TiledSlicer`.
        - "mosaic": Returns a
            :class:`~nilearn.plotting.displays.MosaicSlicer`.
        - "xz": Returns a
            :class:`~nilearn.plotting.displays.XZSlicer`.
        - "yz": Returns a
            :class:`~nilearn.plotting.displays.YZSlicer`.
        - "yx": Returns a
            :class:`~nilearn.plotting.displays.YXSlicer`.
        - "x": Returns a
            :class:`~nilearn.plotting.displays.XSlicer`.
        - "y": Returns a
            :class:`~nilearn.plotting.displays.YSlicer`.
        - "z": Returns a
            :class:`~nilearn.plotting.displays.ZSlicer`.

    """
    # The lookup itself is delegated to the shared display helper, using
    # the SLICERS registry defined in this module.
    return get_create_display_fun(display_mode, SLICERS)


def save_figure_if_needed(fig, output_file):
    """Write a figure to disk when an output path is supplied.

    Missing parent directories of ``output_file`` are created first.
    Once saved, the figure (or display) is closed.

    Parameters
    ----------
    fig: figure, axes, or display instance
        Anything that is neither a matplotlib figure nor a ``BaseSlicer``
        is saved through its ``figure`` attribute.

    output_file: str, Path or None
        Destination file. ``None`` means nothing is saved.

    Returns
    -------
    ``fig`` unchanged if ``output_file`` is None, None otherwise.

    """
    if output_file is None:
        # nothing to save: hand the object back to the caller
        return fig

    destination = Path(output_file)
    destination.parent.mkdir(parents=True, exist_ok=True)

    # axes-like objects are saved via the figure they belong to
    target = fig if isinstance(fig, (plt.Figure, BaseSlicer)) else fig.figure

    target.savefig(destination)
    if isinstance(target, plt.Figure):
        plt.close(target)
    else:
        # displays expose their own ``close`` method
        target.close()

    return None
