"""Neuroimaging file input and output."""

import collections.abc
import gc
from copy import deepcopy
from pathlib import Path
from warnings import warn

import numpy as np
from nibabel import Nifti1Image, is_proxy, load, spatialimages

from nilearn._utils.helpers import is_gil_enabled, stringify_path
from nilearn._utils.logger import find_stack_level


def _get_data(img) -> np.ndarray:
    """Return ``img``'s data array, loading and caching it on first access.

    Adapted from nibabel's former ``get_data`` method:
    https://github.com/nipy/nibabel/blob/de44a10/nibabel/dataobj_images.py#L204
    (removed from nibabel, see https://github.com/nipy/nibabel/wiki/BIAP8).

    Side effect: when ``img._data_cache`` is None, the array built from
    ``img._dataobj`` is stored in ``img._data_cache``.
    """
    cached = img._data_cache
    if cached is None:
        cached = np.asanyarray(img._dataobj)
        img._data_cache = cached
    return cached


def safe_get_data(
    img: Nifti1Image, ensure_finite: bool = False, copy_data: bool = False
) -> np.ndarray:
    """Return the data array of ``img`` while trying not to alter ``img``.

    When ``img.in_memory`` is False, or when ``copy_data`` is True, the image
    is deep-copied first and the data is loaded into the copy, so the
    original object's data cache is left as it was. Otherwise the array is
    obtained from ``img`` itself, without any deep copy.

    When ``ensure_finite`` is True, non-finite values (NaN, +inf, -inf) are
    replaced by 0 in place. In the non-copying case this writes into the
    array held by ``img``, so ``img`` is no longer left unchanged.

    Parameters
    ----------
    img : Nifti image/object
        Image to get data from.

    ensure_finite : bool, default=False
        If True, NaN and infinite values found in the data are replaced by
        zeros.

    copy_data : bool, default=False
        If True, the returned data comes from a deep copy of ``img``.

    Returns
    -------
    data : numpy array
        The image data.
    """
    work_img = img
    if not img.in_memory or copy_data:
        work_img = deepcopy(img)

    if is_gil_enabled():
        # Loading the data can roughly double memory usage, so force a
        # garbage-collection pass beforehand.
        gc.collect()

    data = _get_data(work_img)

    if ensure_finite:
        # Zeroes NaN/inf entries in place on ``data``.
        ensure_finite_data(data)
    return data


def has_non_finite(data: np.ndarray) -> tuple[bool, np.ndarray]:
    """Return whether ``data`` contains NaN or infinite values, and the mask.

    Parameters
    ----------
    data : numpy.ndarray
        Array to inspect.

    Returns
    -------
    has_not_finite : bool
        True if at least one element of ``data`` is NaN, +inf or -inf;
        False otherwise.

    non_finite_mask : numpy.ndarray of bool
        Element-wise mask, True where the value is not finite.
    """
    non_finite_mask = ~np.isfinite(data)
    # ``ndarray.any()`` returns a NumPy bool; convert it so the value matches
    # the annotated ``bool`` return type.
    has_not_finite = bool(non_finite_mask.any())
    return has_not_finite, non_finite_mask


def ensure_finite_data(
    data: np.ndarray, raise_warning: bool = True
) -> np.ndarray:
    """Replace NaN and infinite values of ``data`` by 0, in place.

    Parameters
    ----------
    data : numpy.ndarray
        Array to sanitize; it is modified in place when non-finite values
        are present.

    raise_warning : bool, default=True
        If True, emit a warning when non-finite values are found.

    Returns
    -------
    data : numpy.ndarray
        The same array object that was passed in.
    """
    found_non_finite, bad_entries = has_non_finite(data)
    if not found_non_finite:
        return data

    if raise_warning:
        warn(
            "Non-finite values detected. "
            "These values will be replaced with zeros.",
            stacklevel=find_stack_level(),
        )
    data[bad_entries] = 0
    return data


def _get_target_dtype(dtype, target_dtype):
    """Return the dtype to convert to, or None when no conversion is needed.

    Parameters
    ----------
    dtype : dtype
        Data type of the original data.

    target_dtype : {None, dtype, "auto"}
        If None, no conversion is requested. If a type is given, it is
        returned unless it already equals ``dtype``. With "auto", data whose
        dtype kind is signed integer (``"i"``) targets int32 and every other
        kind (including unsigned integers and booleans) targets float32.

    Returns
    -------
    dtype : dtype or None
        The data type toward which the original data should be converted,
        or None if it already has that type.
    """
    if target_dtype is None:
        return None

    resolved = target_dtype
    if target_dtype == "auto":
        resolved = np.float32 if dtype.kind != "i" else np.int32

    if resolved == dtype:
        return None
    return resolved


def load_niimg(niimg, dtype=None):
    """Load a niimg, check if it is a nibabel SpatialImage and cast if needed.

    Parameters
    ----------
    niimg : Niimg-like object
        See :ref:`extracting_data`.
        Image to load.

    %(dtype)s

    Returns
    -------
    img : image
        A loaded image object.
    """
    from nilearn.image.image import new_img_like  # avoid circular imports

    niimg = stringify_path(niimg)
    if isinstance(niimg, str):
        # A string is treated as a filename and loaded with nibabel.
        niimg = load(niimg)
    elif not isinstance(niimg, spatialimages.SpatialImage):
        raise TypeError(
            "Data given cannot be loaded because it is"
            " not compatible with nibabel format:\n"
            + repr_niimgs(niimg, shorten=True)
        )

    if dtype is None:
        # No cast requested: return early so the data is not loaded.
        return niimg

    data = _get_data(niimg)
    cast_to = _get_target_dtype(data.dtype, dtype)
    if cast_to is None:
        return niimg

    had_header = niimg.header is not None
    converted = new_img_like(niimg, data.astype(cast_to), niimg.affine)
    if had_header:
        # Keep the on-disk dtype declared in the header in sync with the data.
        converted.header.set_data_dtype(cast_to)
    return converted


def is_binary_niimg(
    niimg: Nifti1Image,
    block_size: int = 1_000_000,
    accept_non_finite: bool = True,
) -> bool:
    """Tell whether an image only contains the values 0 and 1.

    Parameters
    ----------
    niimg : Niimg-like object
        See :ref:`extracting_data`.
        Image to test; it is first passed through ``load_niimg``.

    block_size : int, default=1_000_000
        Number of values checked per iteration, forwarded to
        ``is_binary_data``.

    accept_non_finite : bool, default=True
        If True, NaN and infinite values are ignored during the check.

    Returns
    -------
    is_binary : Boolean
        True if binary, False otherwise.

    """
    image_data = load_niimg(niimg).dataobj
    return is_binary_data(image_data, block_size, accept_non_finite)


def _binary_mask(block: np.ndarray) -> np.ndarray:
    """Create a boolean mask for values equal to 0 or 1."""
    return (block == 0) | (block == 1)


def _binary_mask_with_nonfinite(block: np.ndarray) -> np.ndarray:
    """Create a boolean mask for values equal to 0, 1, nan, or +-inf."""
    return _binary_mask(block) | ~np.isfinite(block)


def is_binary_data(data, block_size=1_000_000, accept_non_finite=True) -> bool:
    """Return whether a given proxy array or ndarray is binary or not.

    Parameters
    ----------
    data : array-like
        Array (or nibabel array proxy) to test. It is flattened with
        ``np.ravel``, which materialises the whole array; the block-wise
        loop only bounds the size of the temporary boolean masks.

    block_size : int, default=1_000_000
        Number of values checked per iteration. Must be positive.

    accept_non_finite : bool, default=True
        If True, NaN and inf values are ignored.

    Returns
    -------
    is_binary : bool
        True if every (considered) value is 0 or 1, False otherwise.

    Raises
    ------
    ValueError
        If ``block_size`` is lower than 1.
    """
    # A non-positive step would make ``range`` below empty (negative step)
    # and the function would wrongly report any data as binary.
    if block_size < 1:
        raise ValueError(
            f"block_size must be a positive integer, got {block_size!r}."
        )

    flat = np.ravel(data)

    mask_func = (
        _binary_mask_with_nonfinite if accept_non_finite else _binary_mask
    )

    for start in range(0, flat.size, block_size):
        if not mask_func(flat[start : start + block_size]).all():
            return False

    return True


def repr_niimgs(niimgs, shorten=True):
    """Pretty printing of niimg or niimgs.

    Parameters
    ----------
    niimgs : str, pathlib.Path, image or iterable of those
        Filename(s) or nibabel SpatialImage(s) to represent. Any iterable is
        accepted; iterables that are not sequences (generators, sets, ...)
        are consumed to build the representation.

    shorten : boolean, default=True
        If True, long filenames are truncated and collections of more than
        3 elements are printed with only their first and last element.

    Returns
    -------
    repr : str
        String representation of the image.
    """
    # Simple string case
    if isinstance(niimgs, (str, Path)):
        return _short_repr(niimgs, shorten=shorten)
    # Collection case
    if isinstance(niimgs, collections.abc.Iterable):
        if not isinstance(niimgs, collections.abc.Sequence):
            # len() and indexing below are not supported by every iterable
            # (e.g. generators, sets), so materialise it first. For NumPy
            # arrays this yields the same sub-arrays that indexing would.
            niimgs = list(niimgs)
        # Maximum number of elements to be displayed
        # Note: should be >= 3 to make sense...
        list_max_display = 3
        if shorten and len(niimgs) > list_max_display:
            tmp = ",\n         ...\n ".join(
                repr_niimgs(niimg, shorten=shorten)
                for niimg in [niimgs[0], niimgs[-1]]
            )
            return f"[{tmp}]"
        elif len(niimgs) > list_max_display:
            tmp = ",\n ".join(
                repr_niimgs(niimg, shorten=shorten) for niimg in niimgs
            )
            return f"[{tmp}]"
        else:
            tmp = [repr_niimgs(niimg, shorten=shorten) for niimg in niimgs]
            return f"[{', '.join(tmp)}]"
    # Nibabel objects have a 'get_filename'
    try:
        filename = niimgs.get_filename()
        if filename is not None:
            return (
                f"{niimgs.__class__.__name__}"
                f"('{_short_repr(filename, shorten=shorten)}')"
            )
        else:
            # No shortening in this case
            return (
                f"{niimgs.__class__.__name__}"
                f"(\nshape={niimgs.shape!r},"
                f"\naffine={niimgs.affine!r}\n)"
            )
    except Exception:
        # Deliberate best effort: anything that does not behave like a
        # nibabel image falls back to its plain repr() below.
        pass
    return _short_repr(repr(niimgs), shorten=shorten)


def _short_repr(niimg_rep, shorten: bool = True, truncate: int = 20) -> str:
    """Give a shorter version of a niimg representation (usually a path).

    ``truncate`` is clamped to at least 10. If the file name alone is longer
    than that limit, it is cut and suffixed with "...". Otherwise parent
    folders are prepended, innermost first, while they fit. The first folder
    that does not fit is replaced by "...".
    """
    limit = max(10, truncate)
    as_path = Path(niimg_rep)
    if not shorten:
        return str(as_path)

    name = as_path.name
    if len(name) > limit:
        return name[: limit - 2] + "..."

    shortened = name
    for parent in reversed(as_path.parts[:-1]):
        if len(shortened) + len(parent) >= limit - 3:
            return str(Path("...", shortened))
        shortened = str(Path(parent, shortened))
    return shortened


def img_data_dtype(img):
    """Predict the dtype of ``np.array(img.dataobj)`` without building it.

    Nibabel array proxies whose ``(slope, inter)`` differ from
    ``(1.0, 0.0)`` apply scaling, so they are reported as ``np.float64``.
    Otherwise ``img.dataobj.dtype`` is used when available, falling back to
    ``img.get_data_dtype()``.
    """
    proxy = img.dataobj

    if is_proxy(proxy):
        is_scaled = (proxy.slope, proxy.inter) != (1.0, 0.0)
        if is_scaled:
            # Neuroimages that scale data should be interpreted as floats.
            return np.float64

    if hasattr(proxy, "dtype"):
        return proxy.dtype
    # ArrayProxy only gained the ``dtype`` attribute in nibabel 2.2.
    return img.get_data_dtype()
