module EllipsisNotation

using PrecompileTools: @compile_workload, @setup_workload
using StaticArrayInterface: StaticArrayInterface

import Base: to_indices, tail

# Singleton marker type behind the `..` index. It carries no data; the indexing
# behaviour is attached to it by the `to_indices` and `StaticArrayInterface`
# methods defined in this module.
struct Ellipsis end

"""
    ..

Index placeholder for arrays that expands to as many `:` (full slices) as are
needed to fill the dimensions not covered by the other indices. It is similar to
Python's `...` in that it means 'all the dimensions before (or after)' the
explicit indices.

`..` slurps dimensions greedily, meaning that the first occurrence
of `..` in an index expression creates as many slices as possible. Other
instances of `..` afterward are treated simply as slices. Usually, you
should only use one instance of `..` in an indexing expression to avoid
possible confusion.

# Example

```jldoctest
julia> A = Array{Int}(undef, 2, 4, 2);

julia> A[.., 1] = [2 1 4 5
                   2 2 3 6];

julia> A[.., 2] = [3 2 6 5
                   3 2 6 6];

julia> A[:, :, 1] == [2 1 4 5
                      2 2 3 6]
true

julia> A[1, ..] = reshape([3 4
                           5 6
                           4 5
                           6 7], 1, 4, 2) # drops singleton dimension
...

julia> B = [3 4
            5 6
            4 5
            6 7];

julia> B == reshape(A[1, ..], 4, 2)
true
```
"""
const .. = Ellipsis()

# Expand an `..` at the head of the index tuple `I` into explicit `Colon()`s and
# hand the resulting tuple back to the generic `to_indices` machinery.
#
# `inds` is the tuple of `M` axes this call was given; `I` is the tuple of
# indices still to be processed, starting with the `..`.
@inline function to_indices(
        A,
        inds::NTuple{M, Any},
        I::Tuple{Ellipsis, Vararg{Any, N}}
    ) where {M, N}
    trailing = tail(I)
    # `_ndims_index(I)` also counts the `..` itself (as one dimension), hence the
    # `+ 1`: the ellipsis receives every axis not claimed by the indices after it,
    # which aligns those indices with the tail of `inds`.
    # NOTE(review): if the indices after `..` claim more than `M` dimensions this
    # count is negative and `ntuple` would throw — confirm that is the intended
    # behaviour for over-long index lists.
    nslices = M - _ndims_index(I) + 1
    slices = ntuple(_ -> Colon(), nslices)
    return to_indices(A, inds, (slices..., trailing...))
end

# Add up the number of array dimensions consumed by each index in a tuple, as
# reported by `StaticArrayInterface.ndims_index`. The recursion peels off one
# index at a time and bottoms out at the empty tuple with a static zero.
@inline _ndims_index(::Tuple{}) = StaticArrayInterface.static(0)
@inline function _ndims_index(inds::Tuple)
    head_dims = StaticArrayInterface.ndims_index(first(inds))
    rest_dims = _ndims_index(tail(inds))
    return head_dims + rest_dims
end

# Hook `Ellipsis` into the StaticArrayInterface indexing traits.

# Flag `Ellipsis` as a splat index.
function StaticArrayInterface.is_splat_index(::Type{Ellipsis})
    return StaticArrayInterface.static(true)
end

# Report a static dimension count of one for `Ellipsis`; `_ndims_index` relies on
# this value when counting the dimensions claimed by an index tuple.
function StaticArrayInterface.ndims_index(::Type{Ellipsis})
    return StaticArrayInterface.static(1)
end

# Convert `..` into a tuple holding `StaticArrayInterface.indices(x, dim)` for
# every dimension `dim` of `x`.
function StaticArrayInterface.to_index(x, ::Ellipsis)
    return ntuple(Val(ndims(x))) do dim
        StaticArrayInterface.indices(x, dim)
    end
end

export ..

@setup_workload begin
    @compile_workload begin
        # Run representative `..` indexing calls so that the common patterns are
        # precompiled. The results are deliberately discarded.

        # Float64, rank 2
        f64_2d = zeros(2, 3)
        f64_2d[.., 1]
        f64_2d[1, ..]

        # Float64, rank 3 (most common use case)
        f64_3d = zeros(2, 3, 4)
        f64_3d[.., 1]
        f64_3d[1, ..]
        f64_3d[:, .., 1]
        f64_3d[1, .., 2]

        # Float64, rank 4
        f64_4d = zeros(2, 3, 4, 5)
        f64_4d[.., 1]
        f64_4d[1, ..]
        f64_4d[.., 1, 2]
        f64_4d[1, 2, ..]

        # Float64, rank 5
        f64_5d = zeros(2, 3, 4, 5, 6)
        f64_5d[.., 1]
        f64_5d[1, ..]

        # `Int` arrays (common in indexing operations), ranks 3 and 4
        int_3d = zeros(Int, 2, 3, 4)
        int_3d[.., 1]
        int_3d[1, ..]

        int_4d = zeros(Int, 2, 3, 4, 5)
        int_4d[.., 1]
        int_4d[1, ..]

        # Float32 arrays (common in GPU/ML workloads), ranks 3 and 4
        f32_3d = zeros(Float32, 2, 3, 4)
        f32_3d[.., 1]
        f32_3d[1, ..]

        f32_4d = zeros(Float32, 2, 3, 4, 5)
        f32_4d[.., 1]
        f32_4d[1, ..]
    end

    # Explicit precompile statements to force native code caching.
    # getindex/setindex!/to_indices are Base methods, so their specializations
    # may not be cached in the package image otherwise.
    for T in (Float64, Float32, Int)
        # getindex with a leading or trailing `..` plus one integer, ranks 2 to 5
        for N in 2:5
            precompile(getindex, (Array{T, N}, Ellipsis, Int))
            precompile(getindex, (Array{T, N}, Int, Ellipsis))
        end

        # getindex, rank 3 only: `..` sandwiched between two other indices
        precompile(getindex, (Array{T, 3}, Colon, Ellipsis, Int))
        precompile(getindex, (Array{T, 3}, Int, Ellipsis, Int))

        # getindex, rank 4 only: `..` combined with two integer indices
        precompile(getindex, (Array{T, 4}, Ellipsis, Int, Int))
        precompile(getindex, (Array{T, 4}, Int, Int, Ellipsis))

        # setindex! for ranks 3 and 4; the assigned value has one dimension fewer
        for N in 3:4
            precompile(setindex!, (Array{T, N}, Array{T, N - 1}, Ellipsis, Int))
            precompile(setindex!, (Array{T, N}, Array{T, N - 1}, Int, Ellipsis))
        end
    end
end

end # module
