/*
 * Copyright 2010-2025 NVIDIA Corporation. All rights reserved
 *
 * Sample application demonstrating the use of CUPTI Callback API to capture the duration 
 * of CUDA runtime APIs, including tracing calls such as cudaMalloc, cudaMemcpy and cudaLaunchKernel.
 *
 */

// System headers
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>

// CUDA headers
#include <cuda.h>
#include <cuda_runtime.h>

// CUPTI headers
#include "cupti.h"
#include "helper_cupti.h"
#include "helper_cupti_activity.h"

// CUDA API Trace Structure
//
// One record per traced CUDA runtime API call. CallbackHandler fills a record
// and TraceBufferPush stores a copy of it; PrintTraceBuffer later reads it back.
typedef struct
{
    // Name of the API, taken from the CUPTI callback data (functionName).
    // PrintTraceBuffer compares it against "cudaMalloc", "cudaMemcpy" and
    // "cudaLaunchKernel" to decide which member of `params` to read.
    const char *apiName;
    // Timestamp obtained from cuptiGetTimestamp when the API call is entered.
    uint64_t startTimestamp;
    // Timestamp obtained from cuptiGetTimestamp when the API call exits.
    uint64_t endTimestamp;
    // API-specific arguments. Only the member matching the traced API is written,
    // so only that member may be read for a given record.
    union 
    {
        // Arguments of cudaMalloc.
        struct 
        {
            void **devPtr;  // address of the caller's device pointer variable
            size_t size;    // requested allocation size
        } malloc;
        // Arguments of cudaMemcpy.
        struct 
        {
            size_t count;              // `count` argument of the copy
            enum cudaMemcpyKind kind;  // direction of the copy
        } memcpy;
        // Arguments of cudaLaunchKernel.
        struct 
        {
            void *func;     // `func` argument of the launch
            dim3 gridDim;   // grid dimensions of the launch
            dim3 blockDim;  // block dimensions of the launch
            void **args;    // `args` argument of the launch (pointer is stored, not copied)
        } kernel;
        // More fields can be added for other APIs
    } params;
} CudaApiEventTrace;

// Dynamic List for Traces
//
// Growable array of CudaApiEventTrace records. It is set up by TraceBufferInit,
// appended to by TraceBufferPush (which enlarges `data` when it is full) and
// released by TraceBufferFree.
typedef struct
{
    CudaApiEventTrace *data;  // heap-allocated storage for the records
    size_t size;              // number of records currently stored
    size_t capacity;          // number of records `data` has room for
} TraceBuffer;

// Initialize the trace buffer
//
// Allocates zero-filled storage for 64 events and marks the buffer as empty.
// Prints a message to stderr and terminates the process with exit code 1 if
// the allocation fails.
//
// tb - buffer to initialize; any previous contents are not released.
void 
TraceBufferInit(
    TraceBuffer *tb) 
{
    const size_t initialCapacity = 64;

    CudaApiEventTrace *storage =
        (CudaApiEventTrace*)calloc(initialCapacity, sizeof(CudaApiEventTrace));
    if (storage == NULL)
    {
        fprintf(stderr, "TraceBuffer alloc failed\n");
        exit(1);
    }

    tb->data = storage;
    tb->capacity = initialCapacity;
    tb->size = 0;
}

// Release the storage owned by the trace buffer
//
// Frees the event array and resets the buffer to an empty state with no
// storage (data == NULL, size == 0, capacity == 0).
//
// tb - buffer whose storage is released.
void 
TraceBufferFree(
    TraceBuffer *tb) 
{
    free(tb->data);

    tb->capacity = 0;
    tb->size = 0;
    tb->data = NULL;
}

// Push an event in the trace buffer, and resize it as needed
//
// Appends a copy of *evt to the buffer. When the buffer is full its capacity is
// doubled; a buffer with zero capacity (for example one that went through
// TraceBufferFree) is given room for 64 events instead of staying at zero.
// Prints a message to stderr and terminates the process with exit code 1 if the
// storage cannot be grown.
//
// tb  - buffer to append to.
// evt - event to copy into the buffer.
void 
TraceBufferPush(
    TraceBuffer *tb, 
    CudaApiEventTrace *evt) 
{
    if (tb->size >= tb->capacity)
    {
        // Largest element count whose byte size still fits in size_t.
        const size_t maxElements = ((size_t)-1) / sizeof(CudaApiEventTrace);
        size_t newCapacity = 0;
        CudaApiEventTrace *grown = NULL;

        if (tb->capacity == 0)
        {
            // Doubling zero would request a zero-sized block and then write past it.
            newCapacity = 64;
        }
        else if (tb->capacity <= maxElements / 2)
        {
            newCapacity = tb->capacity * 2;
        }
        // Otherwise newCapacity stays 0: the size computation would overflow.

        if (newCapacity != 0)
        {
            // Keep the old pointer until realloc is known to have succeeded.
            grown = (CudaApiEventTrace*)realloc(tb->data, newCapacity * sizeof(CudaApiEventTrace));
        }
        if (!grown)
        { 
            fprintf(stderr, "TraceBuffer realloc failed\n");
            exit(1);
        }

        tb->data = grown;
        tb->capacity = newCapacity;
    }
    tb->data[tb->size++] = *evt;
}

// Vector addition kernel
//
// Computes pC[k] = pA[k] + pB[k]. Each thread handles at most one element,
// selected from the x components of its block and thread indices; threads whose
// index is N or larger do nothing. Because there is no loop, the launch must
// provide at least N threads along x for every element to be written.
//
// pA, pB - input vectors.
// pC     - output vector.
// N      - number of elements to process.
__global__ void
VectorAdd(
    const int *pA,
    const int *pB,
    int *pC,
    int N)
{
    const int element = blockDim.x * blockIdx.x + threadIdx.x;

    // Threads past the end of the vectors have nothing to do.
    if (element >= N)
    {
        return;
    }

    pC[element] = pA[element] + pB[element];
}


// Initialize a vector
//
// Stores each element's own index in it: pVector[k] = k for k in [0, N).
// Nothing is written when N is zero or negative.
//
// pVector - array with room for at least N ints.
// N       - number of elements to fill.
static void
InitializeVector(
    int *pVector,
    int N)
{
    int position = 0;

    while (position < N)
    {
        pVector[position] = position;
        ++position;
    }
}

// CUPTI Callback Handler
//
// Records one CudaApiEventTrace in the TraceBuffer passed as `userdata` for
// every completed cudaMalloc, cudaMemcpy and cudaLaunchKernel runtime API call.
//
// The start timestamp is taken in the API-enter callback and handed to the
// matching API-exit callback through cbInfo->correlationData, so the handler
// keeps no state of its own between invocations. The event is assembled and
// stored in the API-exit callback. If a timestamp cannot be obtained, a message
// is written to stderr and the event is recorded with a zero start timestamp
// (failure at entry) or with a zero duration (failure at exit).
//
// userdata - TraceBuffer that receives the events.
// domain   - callback domain; anything but CUPTI_CB_DOMAIN_RUNTIME_API is ignored.
// cbid     - runtime API identifier; APIs other than the three above are ignored.
// cbInfo   - per-call data supplied by CUPTI (call site, function name, parameters).
//
// NOTE(review): TraceBufferPush does no locking, so the buffer itself is still
// only safe when the traced API calls come from one thread at a time.
void CUPTIAPI 
CallbackHandler(
    void *userdata,
    CUpti_CallbackDomain domain,
    CUpti_CallbackId cbid,
    const CUpti_CallbackData *cbInfo)
{
    TraceBuffer *tb = (TraceBuffer *)userdata;

    if (domain != CUPTI_CB_DOMAIN_RUNTIME_API)
    {
        return;
    }

    // Store only entered/exited API events of interest
    switch (cbid)
    { 
        case CUPTI_RUNTIME_TRACE_CBID_cudaMalloc_v3020:
        case CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020:
        case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000:
            break;
        // More CUDA APIs can be handled here
        default:
            return;
    }

    if (cbInfo->callbackSite == CUPTI_API_ENTER)
    {
        uint64_t startTimestamp = 0;
        if (cuptiGetTimestamp(&startTimestamp) != CUPTI_SUCCESS)
        {
            fprintf(stderr, "cuptiGetTimestamp failed at entry of %s\n", cbInfo->functionName);
            startTimestamp = 0;
        }

        // Hand the start timestamp to the exit callback of this same invocation.
        if (cbInfo->correlationData)
        {
            *cbInfo->correlationData = startTimestamp;
        }
        return;
    }

    if (cbInfo->callbackSite != CUPTI_API_EXIT)
    {
        return;
    }

    // Take the end timestamp before doing any bookkeeping so that the
    // bookkeeping is not counted as part of the API call.
    uint64_t endTimestamp = 0;
    const int haveEndTimestamp = (cuptiGetTimestamp(&endTimestamp) == CUPTI_SUCCESS);
    if (!haveEndTimestamp)
    {
        fprintf(stderr, "cuptiGetTimestamp failed at exit of %s\n", cbInfo->functionName);
    }

    CudaApiEventTrace ev = {0};
    ev.apiName = cbInfo->functionName;
    ev.startTimestamp = cbInfo->correlationData ? *cbInfo->correlationData : 0;
    // Without an end timestamp, report a zero duration rather than a bogus one.
    ev.endTimestamp = haveEndTimestamp ? endTimestamp : ev.startTimestamp;

    switch (cbid)
    { 
        case CUPTI_RUNTIME_TRACE_CBID_cudaMalloc_v3020:
        {
            cudaMalloc_v3020_params *params = (cudaMalloc_v3020_params *)cbInfo->functionParams;
            ev.params.malloc.devPtr = params->devPtr;
            ev.params.malloc.size = params->size;
            break;
        }
        case CUPTI_RUNTIME_TRACE_CBID_cudaMemcpy_v3020:
        {
            cudaMemcpy_v3020_params *params = (cudaMemcpy_v3020_params *)cbInfo->functionParams;
            ev.params.memcpy.count = params->count;
            ev.params.memcpy.kind = params->kind;
            break;
        }
        case CUPTI_RUNTIME_TRACE_CBID_cudaLaunchKernel_v7000:
        {
            cudaLaunchKernel_v7000_params *params = (cudaLaunchKernel_v7000_params *)cbInfo->functionParams;
            ev.params.kernel.func = (void*)params->func;
            ev.params.kernel.gridDim = params->gridDim;
            ev.params.kernel.blockDim = params->blockDim;
            ev.params.kernel.args = (void **)params->args;
            break;
        }
        default:
            break;
    }

    TraceBufferPush(tb, &ev);
}

// Return a printable name for a cudaMemcpyKind value.
//
// memcpyKind - copy direction to describe.
//
// Returns a pointer to a string literal: the enumerator name without its
// "cudaMemcpy" prefix, or "<unknown>" for a value that is not one of the
// enumerators handled below.
static const char *
GetMemcpyKind(
    enum cudaMemcpyKind memcpyKind)
{
    switch (memcpyKind)
    {
        case cudaMemcpyHostToHost:
            return "HostToHost";
        case cudaMemcpyHostToDevice:
            return "HostToDevice";
        case cudaMemcpyDeviceToHost:
            return "DeviceToHost";
        case cudaMemcpyDeviceToDevice:
            return "DeviceToDevice";
        case cudaMemcpyDefault:
            return "Default";
        default:
            return "<unknown>";
    }
}

// Print events from the trace buffer
//
// Writes a table to stdout with one row per stored event: API name, start
// timestamp, duration (end minus start) and API-specific details. The API name
// selects the row layout: allocation size for cudaMalloc, byte count and copy
// direction for cudaMemcpy, grid and block dimensions for cudaLaunchKernel, and
// a generic row for anything else.
//
// tb - buffer whose first tb->size events are printed; it is not modified.
void 
PrintTraceBuffer(
    const TraceBuffer *tb)
{
    printf("========= CUDA API Call Timeline =========\n");
    printf("%-25s %-20s %-12s %-12s %-10s\n",
          "API", "Start(ns)", "Dur(ns)", "Bytes", "MemcpyKind / Grid");

    size_t row = 0;
    while (row < tb->size)
    {
        const CudaApiEventTrace *rec = tb->data + row;
        ++row;

        // Convert once to the type the %llu conversions expect.
        const char *name = rec->apiName;
        const unsigned long long started = (unsigned long long)rec->startTimestamp;
        const unsigned long long elapsed =
            (unsigned long long)(uint64_t)(rec->endTimestamp - rec->startTimestamp);

        if (!strcmp(name, "cudaMalloc"))
        {
            printf("%-25s %-20llu %-12llu %-12zu %-10s\n",
                name, started, elapsed,
                rec->params.malloc.size, "N/A");
            continue;
        }

        if (!strcmp(name, "cudaMemcpy"))
        {
            printf("%-25s %-20llu %-12llu %-12zu %-10s\n",
                name, started, elapsed,
                rec->params.memcpy.count, GetMemcpyKind(rec->params.memcpy.kind));
            continue;
        }

        if (!strcmp(name, "cudaLaunchKernel"))
        {
            const dim3 grid = rec->params.kernel.gridDim;
            const dim3 block = rec->params.kernel.blockDim;
            printf("%-25s %-20llu %-12llu %-12s (%u,%u,%u) (%u,%u,%u)\n",
                name, started, elapsed,
                "N/A",
                grid.x, grid.y, grid.z,
                block.x, block.y, block.z);
            continue;
        }

        // Any other API: no parameter details are available.
        printf("%-25s %-20llu %-12llu      ...\n",
            name, started, elapsed);
    }

    printf("===========================================\n");
}

// Release the host and device buffers used by the sample.
//
// Host buffers are released with free() and device buffers with cudaFree();
// NULL pointers are skipped. The cudaFree calls go through RUNTIME_API_CALL.
//
// pHostA, pHostB, pHostC       - host buffers to free (may be NULL).
// pDeviceA, pDeviceB, pDeviceC - device buffers to free (may be NULL).
static void
CleanUp(
    int *pHostA,
    int *pHostB,
    int *pHostC,
    int *pDeviceA,
    int *pDeviceB,
    int *pDeviceC)
{
    // Free host memory. free() ignores NULL, so no guard is needed.
    free(pHostA);
    free(pHostB);
    free(pHostC);

    // Free device memory. No runtime call is issued for a NULL pointer.
    if (pDeviceA != NULL)
    {
        RUNTIME_API_CALL(cudaFree(pDeviceA));
    }

    if (pDeviceB != NULL)
    {
        RUNTIME_API_CALL(cudaFree(pDeviceB));
    }

    if (pDeviceC != NULL)
    {
        RUNTIME_API_CALL(cudaFree(pDeviceC));
    }
}

// Sample entry point.
//
// Subscribes CallbackHandler to the CUDA runtime API callback domain, runs a
// vector addition on the device (allocations, host-to-device copies, kernel
// launch, device-to-host copy), verifies the result on the host and, when the
// result is correct, prints the API timeline collected by the callback.
//
// The same teardown (buffer release, CUPTI unsubscribe, trace buffer release)
// runs whether or not verification succeeds; only the exit status differs.
//
// argc, argv - unused.
//
// Does not return: the process exits with EXIT_SUCCESS when the device result
// matches the host reference and with EXIT_FAILURE otherwise.
int
main(
    int argc,
    char *argv[])
{
    (void)argc;
    (void)argv;

    CUcontext context = NULL;
    CUdevice device = 0;
    int N = 50000;
    size_t size = N * sizeof(int);

    int threadsPerBlock = 0;
    int blocksPerGrid = 0;

    int exitStatus = EXIT_SUCCESS;
    int sum, i;

    // NULL-initialized so that CleanUp skips anything that was never allocated.
    int *pHostA = NULL, *pHostB = NULL, *pHostC = NULL;
    int *pDeviceA = NULL, *pDeviceB = NULL, *pDeviceC = NULL;

    TraceBuffer tb;
    TraceBufferInit(&tb);

    CUpti_SubscriberHandle subscriber;
    CUPTI_API_CALL_VERBOSE(cuptiSubscribe(&subscriber, (CUpti_CallbackFunc)CallbackHandler, &tb));

    // Enable all callbacks for CUDA Runtime APIs.
    // Callback will be invoked at the entry and exit points of each of the CUDA Runtime API.
    CUPTI_API_CALL_VERBOSE(cuptiEnableDomain(1, subscriber, CUPTI_CB_DOMAIN_RUNTIME_API));

    // Enable the state domain callbacks for instantaneous error reporting.
    CUPTI_API_CALL_VERBOSE(cuptiEnableDomain(1, subscriber, CUPTI_CB_DOMAIN_STATE));

    DRIVER_API_CALL(cuInit(0));
    DRIVER_API_CALL(cuDeviceGet(&device, 0));
    DRIVER_API_CALL(cuCtxCreate(&context, (CUctxCreateParams*)0, 0, device));

    // Allocate input vectors pHostA and pHostB in host memory.
    pHostA = (int *)malloc(size);
    MEMORY_ALLOCATION_CALL(pHostA);

    pHostB = (int *)malloc(size);
    MEMORY_ALLOCATION_CALL(pHostB);

    pHostC = (int *)malloc(size);
    MEMORY_ALLOCATION_CALL(pHostC);

    // Initialize input vectors
    InitializeVector(pHostA, N);
    InitializeVector(pHostB, N);
    memset(pHostC, 0, size);

    // Allocate vectors in device memory.
    RUNTIME_API_CALL(cudaMalloc((void **)&pDeviceA, size));
    RUNTIME_API_CALL(cudaMalloc((void **)&pDeviceB, size));
    RUNTIME_API_CALL(cudaMalloc((void **)&pDeviceC, size));

    // Copy vectors from host memory to device memory.
    RUNTIME_API_CALL(cudaMemcpy(pDeviceA, pHostA, size, cudaMemcpyHostToDevice));
    RUNTIME_API_CALL(cudaMemcpy(pDeviceB, pHostB, size, cudaMemcpyHostToDevice));

    // Invoke kernel
    threadsPerBlock = 256;
    blocksPerGrid = (N + threadsPerBlock - 1) / threadsPerBlock;

    VectorAdd <<< blocksPerGrid, threadsPerBlock >>> (pDeviceA, pDeviceB, pDeviceC, N);
    RUNTIME_API_CALL(cudaGetLastError());
    RUNTIME_API_CALL(cudaDeviceSynchronize());

    // Copy result from device memory to host memory.
    // pHostC contains the result in host memory.
    RUNTIME_API_CALL(cudaMemcpy(pHostC, pDeviceC, size, cudaMemcpyDeviceToHost));

    // Verify result
    for (i = 0; i < N; ++i)
    {
        sum = pHostA[i] + pHostB[i];
        if (pHostC[i] != sum)
        {
            printf("Error: Kernel execution failed.\n");
            exitStatus = EXIT_FAILURE;
            break;
        }
    }

    // Display timestamps and API params collected in the callback.
    // The timeline is only shown when the result was verified.
    if (exitStatus == EXIT_SUCCESS)
    {
        PrintTraceBuffer(&tb);
    }

    // Common teardown for both outcomes, so the failure path no longer skips
    // the CUPTI unsubscribe and the trace buffer release.
    CleanUp(pHostA, pHostB, pHostC, pDeviceA, pDeviceB, pDeviceC);
    RUNTIME_API_CALL(cudaDeviceSynchronize());

    CUPTI_API_CALL_VERBOSE(cuptiUnsubscribe(subscriber));

    TraceBufferFree(&tb);
    exit(exitStatus);
}

