"""
Test if figure in report output have changed.

See the  maintenance page of our documentation for more information
https://nilearn.github.io/dev/maintenance.html#generating-new-baseline-figures-for-plotting-tests
"""

import numpy as np
import pytest

from nilearn.datasets import (
    load_fsaverage_data,
    load_mni152_gm_mask,
    load_sample_motor_activation_image,
)
from nilearn.image import load_img, math_img, new_img_like, threshold_img
from nilearn.maskers import (
    MultiNiftiLabelsMasker,
    MultiNiftiMapsMasker,
    MultiNiftiMasker,
    MultiSurfaceMapsMasker,
    MultiSurfaceMasker,
    NiftiLabelsMasker,
    NiftiMapsMasker,
    NiftiMasker,
    NiftiSpheresMasker,
    SurfaceMapsMasker,
    SurfaceMasker,
)
from nilearn.surface.surface import at_least_2d, find_surface_clusters


def loaded_motor_activation_image():
    """Return the sample motor activation image, already loaded.

    Wrapping the two calls in a named function keeps the image name
    consistent when it is used in test parametrization.
    """
    source = load_sample_motor_activation_image()
    loaded = load_img(source)
    return loaded


@pytest.mark.slow
@pytest.mark.mpl_image_compare
@pytest.mark.thread_unsafe
@pytest.mark.parametrize(
    "mask_img, img",
    (
        [load_mni152_gm_mask(), None],
        [None, loaded_motor_activation_image()],
        [load_mni152_gm_mask(), loaded_motor_activation_image()],
    ),
)
@pytest.mark.parametrize("src_masker", [NiftiMasker, MultiNiftiMasker])
def test_nifti_masker_create_figure_for_report(src_masker, mask_img, img):
    """Compare the (Multi)NiftiMasker report figure to its baseline.

    The masker is built from ``mask_img`` (possibly None), fitted on
    ``img`` (possibly None), and its report figure is returned for
    pytest-mpl to compare.
    """
    fitted_masker = src_masker(mask_img)
    fitted_masker.fit(img)
    report_figure = fitted_masker._create_figure_for_report()
    return report_figure


@pytest.mark.slow
@pytest.mark.thread_unsafe
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("mask_img", [load_mni152_gm_mask(), None])
@pytest.mark.parametrize("img", [None, loaded_motor_activation_image()])
@pytest.mark.parametrize(
    "src_masker", [NiftiLabelsMasker, MultiNiftiLabelsMasker]
)
def test_nifti_labels_masker_create_figure_for_report(
    src_masker, mask_img, img
):
    """Compare the (Multi)NiftiLabelsMasker report figure to its baseline.

    A two-label dummy atlas is derived from the sample motor activation
    map so that the figure is meaningful for human inspection:
    voxels of the positive thresholded map get label 1, and voxels of the
    negative thresholded map get label 2.
    """
    activation = load_sample_motor_activation_image()

    # label 1: positive side, threshold 3, cluster_threshold=300
    pos_img = threshold_img(
        activation, 3, cluster_threshold=300, two_sided=False
    )
    pos_values = pos_img.get_fdata()
    pos_values[pos_values > 0] = 1
    pos_img = new_img_like(pos_img, data=pos_values)

    # label 2: negative side, threshold -3, cluster_threshold=100
    neg_img = threshold_img(
        activation, -3, cluster_threshold=100, two_sided=False
    )
    neg_values = neg_img.get_fdata()
    neg_values[neg_values < 0] = 2
    neg_img = new_img_like(neg_img, data=neg_values)

    atlas_img = math_img("img1 + img2", img1=pos_img, img2=neg_img)

    labels_masker = src_masker(atlas_img, mask_img=mask_img)
    labels_masker.fit(img)

    reported_labels = labels_masker._reporting_data["labels_image"]
    return labels_masker._create_figure_for_report(reported_labels)


@pytest.mark.slow
@pytest.mark.mpl_image_compare
@pytest.mark.thread_unsafe
@pytest.mark.parametrize("mask_img", [load_mni152_gm_mask(), None])
@pytest.mark.parametrize("img", [None, loaded_motor_activation_image()])
@pytest.mark.parametrize("src_masker", [NiftiMapsMasker, MultiNiftiMapsMasker])
def test_nifti_maps_masker_create_figure_for_report(src_masker, mask_img, img):
    """Compare the (Multi)NiftiMapsMasker report figure to its baseline.

    The dummy maps image is the positive side of the sample motor
    activation map thresholded at 3. Only map 0 is displayed, and the first
    returned figure is compared.
    """
    dummy_maps = threshold_img(
        load_sample_motor_activation_image(),
        3,
        cluster_threshold=300,
        two_sided=False,
    )

    maps_masker = src_masker(dummy_maps, mask_img=mask_img)
    maps_masker.fit(img)
    maps_masker._report_content["displayed_maps"] = [0]
    figures = maps_masker._create_figure_for_report()
    return figures[0]


@pytest.mark.thread_unsafe
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize("mask_img", [load_mni152_gm_mask(), None])
@pytest.mark.parametrize("img", [None, loaded_motor_activation_image()])
def test_nifti_spheres_masker_create_figure_for_report(mask_img, img):
    """Compare a NiftiSpheresMasker report figure to its baseline.

    A single seed at the origin is used. Entries 0 and 1 are requested as
    displayed maps, and the second returned figure is compared.
    """
    origin_seed = [(0, 0, 0)]
    spheres_masker = NiftiSpheresMasker(seeds=origin_seed, mask_img=mask_img)
    spheres_masker.fit(img)
    spheres_masker._report_content["displayed_maps"] = [0, 1]
    figures = spheres_masker._create_figure_for_report()
    return figures[1]


@pytest.mark.mpl_image_compare
def test_nifti_spheres_masker_create_summary_figure_for_report():
    """Compare the NiftiSpheresMasker figure showing all spheres.

    Three seeds are fitted without any image. Only entry 0 is displayed,
    and the first returned figure is compared.
    """
    seed_coords = [(0, 0, 0), (0, 10, 20), (20, 10, 0)]
    spheres_masker = NiftiSpheresMasker(seeds=seed_coords)
    spheres_masker.fit()
    spheres_masker._report_content["displayed_maps"] = [0]
    figures = spheres_masker._create_figure_for_report()
    return figures[0]


def _fs_inflated_sulcal():
    """Return fsaverage data loaded on the ``inflated`` mesh type.

    The function name indicates this is sulcal data; the call only sets
    ``mesh_type`` and relies on the loader's other defaults.
    """
    mesh_kind = "inflated"
    return load_fsaverage_data(mesh_type=mesh_kind)


def _surface_mask_img():
    """Build a dummy surface mask from the inflated fsaverage sulcal data.

    Values are thresholded one-sidedly at 0.5 with ``cluster_threshold=50``.
    This keeps only the regions with the highest sulcal values.
    """
    sulcal_img = _fs_inflated_sulcal()
    mask_img = threshold_img(
        sulcal_img,
        0.5,
        cluster_threshold=50,
        two_sided=False,
    )
    return mask_img


@pytest.mark.thread_unsafe
@pytest.mark.mpl_image_compare
@pytest.mark.parametrize(
    "mask_img, img",
    (
        [_surface_mask_img(), None],
        [None, _fs_inflated_sulcal()],
        [_surface_mask_img(), _fs_inflated_sulcal()],
    ),
)
@pytest.mark.parametrize("src_masker", [SurfaceMasker, MultiSurfaceMasker])
def test_surface_masker_create_figure_for_report(src_masker, mask_img, img):
    """Compare the (Multi)SurfaceMasker report figure to its baseline.

    The masker is built from ``mask_img`` (possibly None) and fitted on
    ``img`` (possibly None) before its report figure is returned.
    """
    surface_masker = src_masker(mask_img)
    surface_masker.fit(img)
    report_figure = surface_masker._create_figure_for_report()
    return report_figure


# TODO: add later, as this test seems to cause flaky failures.
# Re-enabling it also requires importing SurfaceLabelsMasker.
# @pytest.mark.mpl_image_compare
# @pytest.mark.parametrize("mask_img", [_surface_mask_img(), None])
# @pytest.mark.parametrize("img", [None, _fs_inflated_sulcal()])
# def test_surface_labels_masker_create_figure_for_report(mask_img, img):
#     """Check figure generated in report of SurfaceLabelsMasker."""
#     # generate dummy labels image
#     tmp = _surface_mask_img()
#     data = {}
#     for hemi in tmp.data.parts:
#         _, labels = find_surface_clusters(
#             tmp.mesh.parts[hemi], tmp.data.parts[hemi]
#         )
#         data[hemi] = labels
#     labels_img = new_img_like(tmp, data)

#     masker = SurfaceLabelsMasker(labels_img, mask_img=mask_img)
#     masker.fit(img)
#     return masker._create_figure_for_report()[0]


@pytest.mark.mpl_image_compare
@pytest.mark.thread_unsafe
@pytest.mark.parametrize("hemi", ["left", "right"])
@pytest.mark.parametrize("mask_img", [_surface_mask_img(), None])
@pytest.mark.parametrize("img", [None, _fs_inflated_sulcal()])
@pytest.mark.parametrize(
    "src_masker", [SurfaceMapsMasker, MultiSurfaceMapsMasker]
)
def test_surface_maps_masker_create_figure_for_report(
    src_masker, mask_img, img, hemi
):
    """Check figure generated in report of (Multi)SurfaceMapsMasker.

    A dummy maps image is built from the surface mask. On hemisphere
    ``hemi``, only the values of the largest mask cluster are kept; the
    other hemisphere is all zeros. The first returned figure is compared.
    """
    tmp = _surface_mask_img()

    data = {
        "right": np.zeros(tmp.data.parts["right"].shape, dtype=np.float32),
        "left": np.zeros(tmp.data.parts["left"].shape, dtype=np.float32),
    }
    data[hemi] = tmp.data.parts[hemi].astype(np.float32)

    clusters, labels = find_surface_clusters(
        tmp.mesh.parts[hemi], tmp.data.parts[hemi]
    )
    # Use a scalar cluster label. Boolean indexing on the maximum size would
    # return one label per tied cluster. Comparing `labels` against such a
    # multi-element array would fail to broadcast. idxmax picks the first
    # largest cluster, which is the same label when the maximum is unique.
    idx_biggest_cluster = clusters.loc[clusters["size"].idxmax(), "index"]

    data[hemi][labels != idx_biggest_cluster] = 0

    maps_imgs = at_least_2d(new_img_like(tmp, data))

    masker = src_masker(maps_imgs, mask_img=mask_img)
    masker.fit(img)
    # Force the matplotlib engine; presumably required so a static figure
    # is returned for baseline comparison.
    masker._report_content["engine"] = "matplotlib"
    return masker._create_figure_for_report(maps_imgs, bg_img=img)[0]
