module SciMLBaseChainRulesCoreExt

using SciMLBase
using SciMLBase: getobserved
import ChainRulesCore
import ChainRulesCore: NoTangent, @non_differentiable, zero_tangent, rrule_via_ad
using SymbolicIndexingInterface
using RecursiveArrayTools: AbstractVectorOfArray

# Reverse rule for `VA[sym, j]`: reading the quantity `sym` of an `ODESolution` at the
# single saved step `j`.
#
# The pullback turns the cotangent `Δ` of that value into an `ODESolution`-shaped tangent
# that holds one state cotangent per saved step (zero for every step other than `j`) and
# a copy of `VA.prob` whose `p` carries the parameter cotangent. `sym` and `j` themselves
# receive `NoTangent()`.
function ChainRulesCore.rrule(
        config::ChainRulesCore.RuleConfig{
            >:ChainRulesCore.HasReverseMode,
        },
        ::typeof(getindex),
        VA::ODESolution,
        sym,
        j::Integer
    )
    function ODESolution_getindex_pullback(Δ)
        # A non-symbolic `sym` is used as the state index as-is; a symbolic one is
        # resolved through `variable_index`, and a `nothing` result selects the
        # observed-function branch below.
        idx = if symbolic_type(sym) != NotSymbolic()
            variable_index(VA, sym)
        else
            sym
        end

        if idx === nothing
            # Differentiate the observed function at step `j` through the host AD.
            observed = getobserved(VA)
            observed_rule = rrule_via_ad(
                config, observed, sym, VA.u[j], VA.prob.p, VA.t[j]
            )
            cotangents = observed_rule[2](Δ)
            # Slot 3 is the cotangent of `VA.u[j]`, slot 4 the one of `VA.prob.p`.
            state_grads = [
                step == j ? cotangents[3] : zero(VA.u[1]) for step in 1:length(VA.u)
            ]
            param_grad = cotangents[4]
            if param_grad == NoTangent()
                # No parameter dependence reported: substitute a structural zero.
                param_grad = zero_tangent(parameter_values(VA.prob))
            end
        else
            # Direct state access: step `j` gets `Δ` at position `idx` and zeros
            # elsewhere; every other step gets an all-zero state.
            state_grads = map(1:length(VA.u)) do step
                step == j || return zero(VA.u[1])
                return [idx == k ? Δ : zero(VA.u[1][1]) for k in 1:length(VA.u[1])]
            end
            param_grad = zero_tangent(VA.prob.p)
        end
        prob_grad = remake(VA.prob, p = param_grad)

        # Element type and dimensionality of the tangent solution follow the per-step
        # cotangents that were just assembled.
        ElT = eltype(eltype(state_grads))
        rank = ndims(eltype(state_grads)) + 1
        sol_grad = ODESolution{ElT, rank}(
            state_grads, nothing, nothing, VA.t, VA.k, nothing, prob_grad,
            VA.alg, VA.interp, VA.dense, 0, VA.stats, VA.alg_choice, VA.retcode
        )
        return (NoTangent(), sol_grad, NoTangent(), NoTangent())
    end
    return VA[sym, j], ODESolution_getindex_pullback
end

# Reverse rule for `VA[sym]`: reading the quantity `sym` of an `ODESolution` across all
# saved steps.
#
# Arguments
#   - `VA`:  the solution being indexed.
#   - `sym`: a plain index, or a symbolic variable that `variable_index` can resolve.
#
# Returns `(VA[sym], pullback)`. For a cotangent `Δ` (indexable by step number) the
# pullback yields `(NoTangent(), Δ′, NoTangent())`, where `Δ′` holds one vector per entry
# of `VA.u` with `Δ[j]` placed at the resolved index and zeros elsewhere.
#
# Throws an `ErrorException` from the pullback when `sym` is symbolic but does not resolve
# to a state index (observed quantities are not supported by this rule).
function ChainRulesCore.rrule(::typeof(getindex), VA::ODESolution, sym)
    function ODESolution_getindex_pullback(Δ)
        i = symbolic_type(sym) != NotSymbolic() ? variable_index(VA, sym) : sym
        if i === nothing
            # `error` raises by itself, so no surrounding `throw` is needed.
            error("AD of purely-symbolic slicing for observed quantities is not yet supported. Work around this by using `A[sym,i]` to access each element sequentially in the function being differentiated.")
        end
        # Steps are numbered by walking `VA.u` itself, so the result does not depend on
        # how `length` is defined for the solution container.
        # `k` runs over the positions 1..length(x) and is compared against `i`.
        Δ′ = [
            [i == k ? Δ[j] : zero(x[1]) for k in 1:length(x)]
                for (j, x) in enumerate(VA.u)
        ]
        return (NoTangent(), Δ′, NoTangent())
    end
    return VA[sym], ODESolution_getindex_pullback
end

# Reverse rule for the `ODEProblem(args...; kwargs...)` constructor call.
#
# The pullback reads the cotangent `ȳ` by field name and returns `NoTangent()` for the
# type followed by `ȳ.f`, `ȳ.u0`, `ȳ.tspan`, `ȳ.p`, `ȳ.kwargs` and `ȳ.problem_type`.
# This layout is fixed and does not depend on how many positional `args` were given.
function ChainRulesCore.rrule(::Type{ODEProblem}, args...; kwargs...)
    prob = ODEProblem(args...; kwargs...)
    ODEProblemAdjoint(ȳ) = (
        NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type,
    )
    return prob, ODEProblemAdjoint
end

# Reverse rule for constructor calls made on a parameterized type `<:ODEProblem{iip, T}`.
#
# `iip` and `T` take part in dispatch only: the primal is still built through the
# unparameterized `ODEProblem(args...; kwargs...)` call. The pullback returns
# `NoTangent()` for the type followed by the `f`, `u0`, `tspan`, `p`, `kwargs` and
# `problem_type` entries of the incoming cotangent `ȳ`.
function ChainRulesCore.rrule(
        ::Type{
            <:ODEProblem{iip, T},
        }, args...; kwargs...
    ) where {iip, T}
    prob = ODEProblem(args...; kwargs...)
    ODEProblemAdjoint(ȳ) = (
        NoTangent(), ȳ.f, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type,
    )
    return prob, ODEProblemAdjoint
end

# Reverse rule for the `SDEProblem(args...; kwargs...)` constructor call.
#
# The pullback reads the cotangent `ȳ` by field name and returns `NoTangent()` for the
# type followed by `ȳ.f`, `ȳ.g`, `ȳ.u0`, `ȳ.tspan`, `ȳ.p`, `ȳ.kwargs` and
# `ȳ.problem_type`.
function ChainRulesCore.rrule(::Type{SDEProblem}, args...; kwargs...)
    prob = SDEProblem(args...; kwargs...)
    SDEProblemAdjoint(ȳ) = (
        NoTangent(), ȳ.f, ȳ.g, ȳ.u0, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type,
    )
    return prob, SDEProblemAdjoint
end

# Reverse rule for calling a fully parameterized `ODESolution{T1, …, T16}` constructor
# with `u` as its first positional argument.
#
# The incoming cotangent `ȳ` is passed on unchanged in the slot of `u`; the type and
# every remaining positional argument receive `NoTangent()`.
function ChainRulesCore.rrule(
        ::Type{
            <:ODESolution{
                T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
                T11, T12, T13, T14, T15, T16,
            },
        }, u,
        args...
    ) where {
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11,
        T12, T13, T14, T15, T16,
    }
    SolType = ODESolution{
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
        T11, T12, T13, T14, T15, T16,
    }
    # One `NoTangent()` per trailing argument, produced by mapping over the `args` tuple.
    ODESolutionAdjoint(ȳ) = (NoTangent(), ȳ, map(_ -> NoTangent(), args)...)

    return SolType(u, args...), ODESolutionAdjoint
end

# Reverse rule for calling a fully parameterized `RODESolution{T1, …, T14}` constructor
# with `u` as its first positional argument.
#
# The incoming cotangent `ȳ` is passed on unchanged in the slot of `u`; the type and
# every remaining positional argument receive `NoTangent()`.
function ChainRulesCore.rrule(
        ::Type{
            <:RODESolution{
                T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
                T11, T12, T13, T14,
            },
        }, u,
        args...
    ) where {
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
        T11, T12, T13, T14,
    }
    SolType = RODESolution{
        T1, T2, T3, T4, T5, T6, T7, T8, T9, T10,
        T11, T12, T13, T14,
    }
    # One `NoTangent()` per trailing argument, produced by mapping over the `args` tuple.
    RODESolutionAdjoint(ȳ) = (NoTangent(), ȳ, map(_ -> NoTangent(), args)...)

    return SolType(u, args...), RODESolutionAdjoint
end

# Reverse rule for the `EnsembleSolution(sim, time, converged, stats)` constructor.
# The pullback has one method per supported cotangent representation.
# Matches the Zygote extension implementation for consistency.
function ChainRulesCore.rrule(
        ::Type{EnsembleSolution}, sim, time, converged, stats = nothing
    )
    out = EnsembleSolution(sim, time, converged, stats)

    # Every pullback method answers with the same five-slot layout, in which only the
    # second slot (the one for `sim`) carries a cotangent.
    pack(sim_grad) = (NoTangent(), sim_grad, NoTangent(), NoTangent(), NoTangent())

    # Dense array cotangent: split it into a vector of vectors. The outer vector is
    # indexed by the last axis of `p̄`, each inner vector by the second-to-last axis,
    # and every entry keeps all leading axes.
    function EnsembleSolution_adjoint(p̄::AbstractArray{T, N}) where {T, N}
        leading() = ntuple(_ -> Colon(), Val(N - 2))
        nested = [
            [p̄[leading()..., inner, outer] for inner in 1:size(p̄)[end - 1]]
                for outer in 1:size(p̄)[end]
        ]
        return pack(EnsembleSolution(nested, 0.0, true, stats))
    end

    # Vector-of-arrays cotangent: wrap it directly.
    function EnsembleSolution_adjoint(p̄::AbstractArray{<:AbstractArray, 1})
        return pack(EnsembleSolution(p̄, 0.0, true, stats))
    end

    # `AbstractVectorOfArray` cotangent: wrap it directly as well.
    function EnsembleSolution_adjoint(p̄::AbstractVectorOfArray)
        return pack(EnsembleSolution(p̄, 0.0, true, stats))
    end

    # Already an `EnsembleSolution`: forward it untouched.
    EnsembleSolution_adjoint(p̄::EnsembleSolution) = pack(p̄)

    # Structural (`NamedTuple`) cotangent: forward its `u` entry.
    EnsembleSolution_adjoint(p̄::NamedTuple) = pack(p̄.u)

    return out, EnsembleSolution_adjoint
end

# Reverse rule for the `SciMLBase.IntervalNonlinearProblem(args...; kwargs...)`
# constructor call.
#
# The pullback reads the cotangent `ȳ` by field name and returns `NoTangent()` for the
# type followed by `ȳ.f`, `ȳ.tspan`, `ȳ.p`, `ȳ.kwargs` and `ȳ.problem_type`.
function ChainRulesCore.rrule(
        ::Type{SciMLBase.IntervalNonlinearProblem}, args...; kwargs...
    )
    prob = SciMLBase.IntervalNonlinearProblem(args...; kwargs...)
    IntervalNonlinearProblemAdjoint(ȳ) = (
        NoTangent(), ȳ.f, ȳ.tspan, ȳ.p, ȳ.kwargs, ȳ.problem_type,
    )
    return prob, IntervalNonlinearProblemAdjoint
end

# Workaround for `NonlinearProblem` being a mutable struct. SciMLSensitivity calls `back`
# explicitly while a reverse pass is already running, which results in a nested gradient
# call. With a mutable struct, every `getfield`/`getproperty` access then accumulates,
# so the same contribution is added several times. This rule treats the problem as if it
# were immutable for the purposes of reverse-mode AD.
#
# The primal is read with `getfield(x, f)`. The pullback wraps the incoming cotangent in
# a one-entry `NamedTuple` keyed by `f` (or uses a zero tangent of `x` when the cotangent
# is `nothing`) and projects it onto `x`.
function ChainRulesCore.rrule(
        ::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
        ::typeof(Base.getproperty), x::NonlinearProblem, f::Symbol
    )
    val = getfield(x, f)
    function back(der)
        dx = der === nothing ? zero_tangent(x) : NamedTuple{(f,)}((der,))
        # The projector is built here, at pullback time, from the captured `x`.
        projected = ChainRulesCore.ProjectTo(x)(dx)
        return (NoTangent(), projected, NoTangent())
    end
    return val, back
end

end
