#!/usr/libexec/platform-python

# SPDX-FileCopyrightText: Copyright (c) 2022-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: LicenseRef-NvidiaProprietary
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.

from __future__ import print_function

import os
import sys
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from contextlib import contextmanager


class PackageManager:
    """Identifiers for the package-manager backends selected by get_package_manager()."""

    YUM, DNF = 1, 2


def get_package_manager():
    """Return which RPM package manager's Python bindings can be imported.

    ``yum`` is probed first, then ``dnf``; the imports exist only as probes.

    Returns:
        ``PackageManager.YUM`` or ``PackageManager.DNF``.

    Raises:
        Exception: if neither ``yum`` nor ``dnf`` can be imported.
    """
    try:
        import yum  # noqa: F401  (import used only as an availability probe)
    except ImportError:
        pass
    else:
        return PackageManager.YUM

    try:
        import dnf  # noqa: F401  (import used only as an availability probe)
    except ImportError:
        pass
    else:
        return PackageManager.DNF

    # Raised outside the ``except`` clause so the traceback is not chained to
    # the (uninteresting) ImportError. Adjacent literals are joined without
    # the stray indentation a backslash continuation would embed.
    raise Exception(
        "Neither the yum nor the dnf package manager is available. "
        "To continue, you need to install one of them."
    )


@contextmanager
def silence_stdout():
    """Temporarily point ``sys.stdout`` at the null device.

    Yields the open null-device file object. On exit, normal or via an
    exception, the previous ``sys.stdout`` is restored and the null-device
    file is closed.
    """
    saved_stdout = sys.stdout
    sink = open(os.devnull, "w")
    try:
        sys.stdout = sink
        yield sink
    finally:
        sys.stdout = saved_stdout
        sink.close()


def error_print(*args, **kwargs):
    """``print()`` to standard error; extra keyword arguments go to ``print``.

    The stream is looked up at call time, so a replaced ``sys.stderr`` is
    honoured.
    """
    stream = sys.stderr
    print(*args, file=stream, **kwargs)


def download_single_package_yum(
    dest_dir, package_name, only_list_deps, exist_deps, name_resolved=False
):
    """Resolve one package and its dependencies with yum, and optionally download them.

    The name is first looked up as ``<name>.noarch`` and then as
    ``<name>.<base arch>``. If no available package matches and
    ``name_resolved`` is false, the name is treated as something a package
    provides (for example a library or file). The first item obtained by
    iterating the ``searchPackageProvides`` result is then processed by a
    recursive call with ``name_resolved=True``, so the provider lookup
    happens at most once.

    Args:
        dest_dir: value assigned to ``yb.conf.downloaddir``.
        package_name: package name, or a file/library some package provides.
        only_list_deps: if true, return package names instead of downloading.
        exist_deps: accepted but not used by this function.
            NOTE(review): presumably intended to skip already-present deps;
            confirm with callers.
        name_resolved: true when ``package_name`` already came from a
            provider lookup; suppresses further provider resolution.

    Returns:
        If ``only_list_deps`` is true: ``str()`` of the requested package(s)
        plus every transaction member in state "i"/"u" (install/update).
        Otherwise: the ``localpath`` of each package passed to
        ``downloadPkgs``.
        An empty list is returned in these cases, with a message on stderr:
        the package is already installed, nothing matches, dependency
        resolution fails, or any download reports a problem.
    """
    import yum
    from rpmUtils.arch import getBaseArch

    yb = yum.YumBase()
    yb.preconf.init_plugins = False
    # NOTE(review): presumably picks a cache dir usable without root; confirm.
    yb.setCacheDir()
    yb.conf.downloaddir = dest_dir

    # First try to find the package by name
    # (noarch first, then the base architecture; stop at the first arch with
    # an available or installed match; otherwise the last arch's result is kept).
    for arch in ["noarch", getBaseArch()]:
        packages_list = yb.doPackageLists(patterns=[package_name + "." + arch])
        if packages_list.available:
            break
        if packages_list.installed:
            break

    if packages_list.installed:
        error_print("Package {} is already installed".format(package_name))
        return []

    if not packages_list.available:
        if not name_resolved:
            # Try to find packages that provide this library/file
            # (any exception, including one from the recursive call, is
            # reported here and falls through to the "not found" message).
            try:
                providers = yb.searchPackageProvides(args=[package_name])
                if providers:
                    provider_pkg = next(iter(providers))
                    error_print(
                        "Resolved {} to package {}".format(
                            package_name, provider_pkg.name
                        )
                    )
                    return download_single_package_yum(
                        dest_dir, provider_pkg.name, only_list_deps, exist_deps, True
                    )
            except Exception as e:
                error_print(
                    "Failed to resolve provider for {}: {}".format(package_name, str(e))
                )
        error_print("No package or provider found for {}".format(package_name))
        return []

    dep_name_set = set()
    # Same list object as packages_list.available; dependencies are appended
    # to it below.
    packages_to_download = packages_list.available
    yb.doTsSetup()
    for pkg in packages_to_download:
        yb.tsInfo.addInstall(pkg)
        yb.localPackages.append(pkg)
        dep_name_set.add(str(pkg))

    # A result code of 1 is treated as a dependency-resolution failure.
    result, result_msg = yb.resolveDeps()
    if result == 1:
        for msg in result_msg:
            error_print("Dependency resolution error: {}".format(msg))
        return []

    # Add every package the resolver pulled in for install/update.
    for pkg in yb.tsInfo.getMembers():
        if pkg.ts_state in ("i", "u") and pkg.po not in packages_to_download:
            packages_to_download.append(pkg.po)
            dep_name_set.add(str(pkg.po))

    if only_list_deps:
        return list(dep_name_set)

    # NOTE(review): presumably forces packages to be copied into downloaddir
    # instead of being served from the repo cache; confirm against yum docs.
    for pkg in packages_to_download:
        pkg.repo.copy_local = True
        pkg.repo.cache = 0

    # downloadPkgs returns a mapping of package -> list of error messages.
    probs = yb.downloadPkgs(packages_to_download)
    if probs:
        for key in probs:
            for error in probs[key]:
                error_print("Download error for {}: {}".format(key, error))
        return []

    downloaded_packages = []
    for pkg in packages_to_download:
        downloaded_packages.append(pkg.localpath)

    return downloaded_packages


def download_batch_packages_yum(
    dest_dir, package_names, only_list_deps, exist_deps, max_workers=4
):
    """Resolve (and optionally download) several packages with yum in parallel.

    Each name is handled by ``download_single_package_yum`` on a worker
    thread. A per-package exception is reported on stderr and skipped, so one
    bad name does not abort the batch.

    Args:
        dest_dir: passed through to the per-package function.
        package_names: iterable of package names (or provided files/libs).
        only_list_deps: if true, collect dependency names instead of
            downloading.
        exist_deps: passed through unchanged to the per-package function.
        max_workers: upper bound on worker threads (default 4).

    Returns:
        A list without duplicates, ordered by the order of ``package_names``
        and then by each package's own result order. It holds dependency
        names when ``only_list_deps`` is true, otherwise downloaded RPM paths.
    """
    package_names = list(package_names)
    if not package_names:
        return []

    # NOTE(review): every worker builds its own YumBase in this process;
    # confirm yum tolerates that before raising max_workers.
    workers = max(1, min(max_workers, len(package_names)))

    results = []
    seen = set()
    with ThreadPoolExecutor(max_workers=workers) as executor:
        futures = [
            (
                package_name,
                executor.submit(
                    download_single_package_yum,
                    dest_dir,
                    package_name,
                    only_list_deps,
                    exist_deps,
                ),
            )
            for package_name in package_names
        ]

        # Collect in submission order so output is deterministic, and
        # de-duplicate: packages in one batch often share dependencies.
        for package_name, future in futures:
            try:
                package_results = future.result()
            except Exception as exc:
                error_print(
                    "Package {} generated an exception: {}".format(
                        package_name, str(exc)
                    )
                )
                continue
            for item in package_results:
                if item not in seen:
                    seen.add(item)
                    results.append(item)

    return results


def download_with_yum(
    dest_dir, package_name, only_list_deps, exist_deps, name_resolved=False
):
    """Handle a single package via yum.

    This is a compatibility alias for ``download_single_package_yum``. Every
    argument is forwarded unchanged and its result is returned as-is.
    """
    return download_single_package_yum(
        dest_dir=dest_dir,
        package_name=package_name,
        only_list_deps=only_list_deps,
        exist_deps=exist_deps,
        name_resolved=name_resolved,
    )


def is_package_installed(installed_pkgs, pkg):
    """Return True if ``installed_pkgs`` holds an entry for ``pkg``'s name/arch/version.

    Args:
        installed_pkgs: mapping keyed by ``(name, arch, version)``. Values are
            assumed to be sequences of installed packages (TODO confirm
            against the producer). Only the first element is inspected.
        pkg: object with ``name``, ``arch`` and ``version`` attributes.

    Returns:
        True when the key is present and its first entry is not None.
        False when the key is missing, its sequence is empty, or the first
        entry is None. The key ignores release and epoch.
    """
    matches = installed_pkgs.get((pkg.name, pkg.arch, pkg.version))
    # An empty sequence previously raised IndexError via ``[0]``.
    return bool(matches) and matches[0] is not None


def parse_packages(base, installed_pkgs, pkgs):
    """Resolve package specs to the latest available packages that are not installed.

    Each spec in ``pkgs`` is matched with ``dnf.subject.Subject`` (obsoletes
    included), and the matches from all specs are unioned. The union is then
    narrowed in ``base.sack`` to available ``noarch`` or host-arch packages,
    reduced to the latest versions, and filtered with
    ``is_package_installed``.

    Returns:
        A list of package objects from the sack query.
    """
    import dnf
    import hawkey

    sack = base.sack
    candidates = set()
    for spec in pkgs:
        selector = dnf.subject.Subject(spec).get_best_selector(sack, obsoletes=True)
        candidates.update(selector.matches())

    latest = (
        sack.query()
        .available()
        .filter(pkg=list(candidates), arch=["noarch", hawkey.detect_arch()])
        .latest()
        .run()
    )
    return [p for p in latest if not is_package_installed(installed_pkgs, p)]


def get_centos_stream_version(os_release_path="/etc/os-release"):
    """Return ``VERSION_ID`` from os-release if the system is CentOS Stream.

    The file is parsed as ``KEY=value`` lines. Blank lines, ``#`` comments
    and lines without ``=`` are skipped, and surrounding single or double
    quotes are stripped from values.

    Args:
        os_release_path: file to read (defaults to ``/etc/os-release``).

    Returns:
        The ``VERSION_ID`` string when ``NAME`` contains "centos stream"
        (case-insensitive). ``None`` when the name does not match,
        ``VERSION_ID`` is missing, or the file cannot be read or decoded.
    """
    info = {}
    try:
        with open(os_release_path) as os_release:
            for raw_line in os_release:
                line = raw_line.strip()
                if not line or line.startswith("#") or "=" not in line:
                    continue
                key, value = line.split("=", 1)
                info[key.strip()] = value.strip("'\"")
    except (IOError, OSError, ValueError):
        # Best effort: an unreadable/undecodable file just means "unknown".
        return None

    # The old ``str.find`` test was inverted (-1 is truthy) and
    # case-sensitive against "CentOS Stream", so it matched every distro.
    if "centos stream" in info.get("NAME", "").lower():
        return info.get("VERSION_ID")
    return None


def _fill_sack_with_retries(base, attempts=5, delay_sec=3):
    """Load system and enabled-repo metadata into ``base.sack``, retrying on failure.

    Sleeps ``delay_sec`` seconds between attempts but not after the last
    one. The final exception is re-raised if every attempt fails.
    """
    attempts = max(1, attempts)
    last_exception = None
    for attempt in range(attempts):
        try:
            base.fill_sack(load_system_repo=True, load_available_repos=True)
            return
        except Exception as e:
            last_exception = e
            if attempt + 1 < attempts:
                time.sleep(delay_sec)
    raise last_exception


def _mark_for_install(base, package_names):
    """Mark each name for install in ``base``, falling back to a ``provides`` lookup.

    Returns the names actually marked: either the given name or the name of
    the first available provider. Names that cannot be resolved are reported
    on stderr and skipped. Errors other than ``MarkingError`` from the
    direct ``base.install`` call propagate.
    """
    import dnf

    marked = []
    for package_name in package_names:
        try:
            # Try direct package installation first
            base.install(package_name)
        except dnf.exceptions.MarkingError:
            pass
        else:
            marked.append(package_name)
            continue

        # Not a known package name: look for packages that provide this
        # library/file and install the first one found.
        try:
            providers = base.sack.query().available().filter(provides=package_name)
            if not providers:
                error_print("No provider found for {}".format(package_name))
                continue
            provider_pkg = list(providers)[0]
            base.install(provider_pkg.name)
            marked.append(provider_pkg.name)
            error_print(
                "Resolved {} to package {}".format(package_name, provider_pkg.name)
            )
        except Exception as e:
            error_print(
                "Failed to resolve provider for {}: {}".format(package_name, str(e))
            )
    return marked


def download_batch_packages_dnf(dest_dir, package_names, only_list_deps):
    """Resolve a batch of packages with dnf and list or download them.

    A throw-away cache directory is used. Weak dependencies are not pulled
    in. On CentOS Stream the ``$stream`` substitution is set to
    ``<VERSION_ID>-stream``.

    Args:
        dest_dir: existing directory the RPMs are downloaded into.
        package_names: package names, or files/libraries some package provides.
        only_list_deps: if true, return ``str()`` of every package in the
            resolved install set instead of downloading.

    Returns:
        Package strings when ``only_list_deps`` is true. Otherwise the paths of
        ``.rpm`` files that appeared in ``dest_dir`` during this call; files
        already present beforehand are not reported. An empty list when
        nothing could be marked or dependency resolution failed.

    Raises:
        The last ``fill_sack`` exception if repo metadata could not be loaded.
    """
    # Snapshot first so only RPMs fetched by this call are reported.
    files_in_dest_dir = set(os.listdir(dest_dir))
    import dnf
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_cache_dirname:
        base = dnf.Base()
        try:
            base.conf.cachedir = tmp_cache_dirname
            base.conf.destdir = dest_dir
            base.conf.install_weak_deps = False

            centos_stream_version = get_centos_stream_version()
            if centos_stream_version:
                base.conf.substitutions["stream"] = centos_stream_version + "-stream"

            base.read_all_repos()
            _fill_sack_with_retries(base)

            if not _mark_for_install(base, package_names):
                return []

            try:
                base.resolve()
                not_installed_set = set(base.transaction.install_set)
            except dnf.exceptions.DepsolveError as e:
                error_print("Dependency resolution failed: {}".format(str(e)))
                return []

            if only_list_deps:
                return [str(pkg) for pkg in not_installed_set]

            base.repos.all().pkgdir = base.conf.destdir
            base.download_packages(not_installed_set)

            new_files = set(os.listdir(dest_dir)) - files_in_dest_dir
            return [
                os.path.join(dest_dir, f) for f in new_files if f.endswith(".rpm")
            ]
        finally:
            # Close dnf's handles deterministically while the temporary cache
            # dir still exists, instead of leaving it to garbage collection.
            base.close()


def download_with_dnf_resolve(dest_dir, package_name, only_list_deps, exist_deps):
    """Single-package entry point kept for legacy callers of the dnf path.

    Wraps ``package_name`` in a one-element batch. ``exist_deps`` is accepted
    for signature parity with the yum helpers and is ignored.
    """
    single_batch = [package_name]
    return download_batch_packages_dnf(dest_dir, single_batch, only_list_deps)


def str_to_bool(v):
    """Interpret a command-line flag string as a boolean.

    Matching is case-insensitive against "yes", "true", "t" and "1"; any
    other value (including ones with surrounding whitespace) is False.
    """
    normalized = v.lower()
    return normalized in {"yes", "true", "t", "1"}


def split_package_names(arg):
    """Split the package argument on commas into trimmed, non-empty names.

    >>> split_package_names("a, b,,c,")
    ['a', 'b', 'c']
    >>> split_package_names("single")
    ['single']
    """
    return [name.strip() for name in arg.split(",") if name.strip()]


def main():
    """Command-line entry point.

    Arguments: <dest_dir> <package[,package...]> <only_list_deps> [exist_deps...]

    Uses whichever of yum/dnf is importable to list the packages'
    dependencies (``only_list_deps`` true) or download them into
    ``dest_dir``. Prints one result per line on stdout. Exits with status 50
    when arguments are missing or no package name is given.
    """
    usage_error = (
        "Destination directory and package name should be set to download packages"
    )
    if len(sys.argv) < 4:
        error_print(usage_error)
        sys.exit(50)

    dest_dir = sys.argv[1]
    # len(sys.argv) >= 4 is guaranteed by the check above.
    only_list_deps = str_to_bool(sys.argv[3])
    exist_deps = sys.argv[4:]

    # Support both a single package and a comma-separated batch; blank
    # entries (e.g. from a trailing comma) are dropped.
    package_names = split_package_names(sys.argv[2])
    if not package_names:
        error_print(usage_error)
        sys.exit(50)
    if "," in sys.argv[2]:
        error_print("Batch processing {} packages".format(len(package_names)))

    try:
        if not os.path.isdir(dest_dir):
            os.makedirs(dest_dir)
    except Exception as e:
        # Reported but not fatal; execution continues.
        error_print(
            "Destination {} directory cannot be opened: ".format(dest_dir) + str(e)
        )

    package_list = []
    # Package-manager libraries may print to stdout; keep stdout clean for
    # the result list printed below.
    with silence_stdout():
        package_manager = get_package_manager()
        if package_manager == PackageManager.YUM:
            if len(package_names) > 1:
                package_list = download_batch_packages_yum(
                    dest_dir, package_names, only_list_deps, exist_deps
                )
            else:
                package_list = download_with_yum(
                    dest_dir, package_names[0], only_list_deps, exist_deps
                )
        elif package_manager == PackageManager.DNF:
            package_list = download_batch_packages_dnf(
                dest_dir, package_names, only_list_deps
            )

    print("\n".join(str(pkg) for pkg in package_list))


if __name__ == "__main__":
    main()
