/*
 * Copyright (c) 2024-2025, NVIDIA CORPORATION & AFFILIATES
 * SPDX-License-Identifier: BSD-3-Clause
 */

#include "mpi.h"
#include "cudss.h"

#include <stdio.h>

extern "C" {

/*
 * Translate a CUDA data type tag into the MPI datatype handle used by the
 * wrappers in this file.
 *
 * Supported inputs: CUDA_R_32F, CUDA_R_64F, CUDA_R_32I, CUDA_R_64I.
 *
 * Any other value is treated as a fatal error: a diagnostic is printed to
 * stderr and the MPI job is aborted. MPI_Abort is used (instead of a bare
 * exit) so that peer ranks are torn down as well rather than being left
 * blocked in a matching MPI call. exit(1) is kept as a last resort in case
 * MPI_Abort returns.
 */
static inline MPI_Datatype cuda_to_mpi_type(cudaDataType_t type) {
    switch(type) {
    case CUDA_R_32F: return MPI_FLOAT;
    case CUDA_R_64F: return MPI_DOUBLE;
    case CUDA_R_32I: return MPI_INT;
    case CUDA_R_64I: return MPI_LONG_LONG_INT;
    default: break;
    }

    fprintf(stderr,
        "cuDSS MPI communication layer: unsupported cudaDataType_t value %d\n",
        (int)type);
    fflush(stderr);
    MPI_Abort(MPI_COMM_WORLD, 1);
    exit(1);
}

/*
 * Translate a cuDSS reduction operation tag into the matching MPI_Op handle.
 *
 * Supported inputs: CUDSS_SUM, CUDSS_MAX, CUDSS_MIN.
 *
 * Any other value is treated as a fatal error: a diagnostic is printed to
 * stderr and the MPI job is aborted. MPI_Abort is used (instead of a bare
 * exit) so that peer ranks are torn down as well rather than being left
 * blocked in the collective. exit(1) is kept as a last resort in case
 * MPI_Abort returns.
 */
static inline MPI_Op cudss_to_mpi_op(cudssOpType_t op) {
    switch(op) {
    case CUDSS_SUM: return MPI_SUM;
    case CUDSS_MAX: return MPI_MAX;
    case CUDSS_MIN: return MPI_MIN;
    default: break;
    }

    fprintf(stderr,
        "cuDSS MPI communication layer: unsupported cudssOpType_t value %d\n",
        (int)op);
    fflush(stderr);
    MPI_Abort(MPI_COMM_WORLD, 1);
    exit(1);
}

int cudssCommRank(void *comm, int *rank)
{
    return MPI_Comm_rank(*((MPI_Comm*)comm), rank);
}

int cudssCommSize(void *comm, int *size)
{
    return MPI_Comm_size(*((MPI_Comm*)comm), size);
}

int cudssSend(const void *buffer, int count, cudaDataType_t datatype, int dest,
    int tag, void *comm, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    return MPI_Send(buffer, count, cuda_to_mpi_type(datatype), dest, tag,
        *((MPI_Comm*)comm));
}

int cudssRecv(void *buffer, int count, cudaDataType_t datatype, int root,
    int tag, void *comm, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    return MPI_Recv(buffer, count, cuda_to_mpi_type(datatype), root, tag,
        *((MPI_Comm*)comm), MPI_STATUS_IGNORE);
}

int cudssBcast(void *buffer, int count, cudaDataType_t datatype, int root,
    void *comm, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    return MPI_Bcast(buffer, count, cuda_to_mpi_type(datatype), root,
        *((MPI_Comm*)comm));
}

int cudssReduce(const void *sendbuf, void *recvbuf, int count,
    cudaDataType_t datatype, cudssOpType_t op, int root, void *comm,
    cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    int rank;
    MPI_Comm_rank(*((MPI_Comm*)comm), &rank);
    if (sendbuf == recvbuf && rank == root)
        return MPI_Reduce(MPI_IN_PLACE, recvbuf, count, cuda_to_mpi_type(datatype),
            cudss_to_mpi_op(op), root, *((MPI_Comm*)comm));
    else
        return MPI_Reduce(sendbuf, recvbuf, count, cuda_to_mpi_type(datatype),
            cudss_to_mpi_op(op), root, *((MPI_Comm*)comm));
}

int cudssAllreduce(const void *sendbuf, void *recvbuf, int count,
    cudaDataType_t datatype, cudssOpType_t op, void *comm, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    if (sendbuf == recvbuf)
        return MPI_Allreduce(MPI_IN_PLACE, recvbuf, count, cuda_to_mpi_type(datatype),
            cudss_to_mpi_op(op), *((MPI_Comm*)comm));
    else
        return MPI_Allreduce(sendbuf, recvbuf, count, cuda_to_mpi_type(datatype),
            cudss_to_mpi_op(op), *((MPI_Comm*)comm));
}

int cudssScatterv(const void *sendbuf, const int *sendcounts,
    const int *displs, cudaDataType_t sendtype, void *recvbuf, int recvcount,
    cudaDataType_t recvtype, int root, void *comm, cudaStream_t stream)
{
    cudaStreamSynchronize(stream);
    int rank;
    MPI_Comm_rank(*((MPI_Comm*)comm), &rank);
    if (sendbuf == recvbuf && rank == root)
        return MPI_Scatterv(sendbuf, sendcounts, displs, cuda_to_mpi_type(sendtype),
            MPI_IN_PLACE, recvcount, cuda_to_mpi_type(recvtype), root,
            *((MPI_Comm*)comm));
    else
        return MPI_Scatterv(sendbuf, sendcounts, displs, cuda_to_mpi_type(sendtype),
            recvbuf, recvcount, cuda_to_mpi_type(recvtype), root,
            *((MPI_Comm*)comm));
}

int cudssCommSplit(const void *comm, int color, int key, void *newcomm)
{
    return MPI_Comm_split(*((MPI_Comm*)comm), color, key, (MPI_Comm*)newcomm);
}

int cudssCommFree(void *comm)
{
    return MPI_Comm_free((MPI_Comm*)comm);
}

/*
 * Binding table for the distributed communication service API wrappers
 * defined above (imported by cuDSS).
 * The exposed C symbol must be named "cudssDistributedInterface".
 *
 * The initializer is positional, so the order of the entries has to match
 * the field order of cudssDistributedInterface_t (declared outside this
 * file -- verify against cudss.h before reordering or adding entries).
 */
cudssDistributedInterface_t cudssDistributedInterface = {
    cudssCommRank,   // wraps MPI_Comm_rank
    cudssCommSize,   // wraps MPI_Comm_size
    cudssSend,       // stream sync + MPI_Send
    cudssRecv,       // stream sync + MPI_Recv
    cudssBcast,      // stream sync + MPI_Bcast
    cudssReduce,     // stream sync + MPI_Reduce
    cudssAllreduce,  // stream sync + MPI_Allreduce
    cudssScatterv,   // stream sync + MPI_Scatterv
    cudssCommSplit,  // wraps MPI_Comm_split
    cudssCommFree    // wraps MPI_Comm_free
};

} // extern "C"
