/*******************************************************************************
* Copyright (C) 2025 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
*
*  Content:
*       This example demonstrates use of oneAPI Math Kernel Library (oneMKL)
*       API oneapi::mkl::experimental::dft::distributed_descriptor to perform
*       3-D Double Precision Real to Complex Fast Fourier Transform
*       distributed across SYCL GPU devices.
*
*       The supported floating point data types for data are:
*           double
*           std::complex<double>
*
*******************************************************************************/

#include <mpi.h>
#include <sycl/sycl.hpp>
#include <vector>
#include <iostream>
#include <stdexcept>
#include <cfloat>

#include "oneapi/mkl/experimental/distributed_dft.hpp"
#include "oneapi/mkl/exceptions.hpp"
#include "common_for_examples.hpp"
#include "mkl.h" // mkl_malloc

// Distributed DFT descriptor type used by this example: double precision,
// real forward domain.
using distributed_desc_t = oneapi::mkl::experimental::dft::distributed_descriptor<oneapi::mkl::dft::precision::DOUBLE, oneapi::mkl::dft::domain::REAL>;

// Boolean status values returned by the verify_* and run_dft_* functions.
constexpr bool SUCCESS = true;
constexpr bool FAILURE = false;
// 2*pi; multiplies the normalized phase before it is passed to cos()/sin().
constexpr double TWOPI = 6.2831853071795864769;

// Fill the local slab of the real array data(N) with a cosine wave of
// harmonic (H0, H1, H2) so that it produces unit peaks at data(H) and
// data(N-H).
//
// The slab holds the rows owned by this rank along dimension 2; each row
// along dimension 0 is addressed with a padded length of 2*(N0/2+1) doubles,
// of which only the first N0 are written.
static void init_r(double *data,
                   int N0, int N1, int N2,
                   int H0, int H1, int H2,
                   int mpi_rank, int mpi_nproc)
{
    // Row-major element strides within the local slab
    const int padded_row = (N0/2+1)*2;
    const int stride0 = 1;
    const int stride1 = padded_row;
    const int stride2 = N1*padded_row;

    // Amplitude is 1.0 when 2*(N-H) is a multiple of N in every dimension,
    // and 2.0 otherwise.
    const bool doubled_harmonic_wraps =
        (2 * (N0 - H0) % N0 == 0) &&
        (2 * (N1 - H1) % N1 == 0) &&
        (2 * (N2 - H2) % N2 == 0);
    const double factor = doubled_harmonic_wraps ? 1.0 : 2.0;

    // Number of dimension-2 planes stored on this rank
    const int local_planes = distribute(N2, mpi_rank, mpi_nproc);
    // Total number of points of the global transform, used for scaling
    const int total_points = N2*N1*N0;

    for (int k = 0; k < local_planes; ++k) {
        for (int j = 0; j < N1; ++j) {
            const int row_start = k*stride2 + j*stride1;
            for (int i = 0; i < N0; ++i) {
                const double phase = moda<double>(i, H0, N0) / N0 +
                                     moda<double>(j, H1, N1) / N1 +
                                     moda<double>(global_index(k, N2, mpi_rank, mpi_nproc), H2, N2) / N2;
                data[row_start + i*stride0] = factor * cos(TWOPI * phase) / total_points;
            }
        }
    }
}

// Verify that the local slab of complex data has a unit peak at harmonic H
// (or at its conjugate counterpart -H) and is zero everywhere else.
//
// data       local slab of interleaved (re, im) doubles, addressed as
//            [N2][N1_local][N0/2+1] complex elements, where N1_local is the
//            share of dimension 1 owned by this rank
// N0,N1,N2   global transform sizes
// H0,H1,H2   harmonic at which the peak is expected
// mpi_rank   rank of the calling process (rank 0 prints and receives maxerr)
// mpi_nproc  number of processes sharing the data
//
// Returns SUCCESS only if every rank verified its slab and the MPI
// reductions succeeded; FAILURE otherwise. Must be called by all ranks of
// MPI_COMM_WORLD since it performs collective operations.
static bool verify_c(const double* data,
                    int N0, int N1, int N2,
                    int H0, int H1, int H2,
                    int mpi_rank, int mpi_nproc)
{
    // Note: this simple error bound doesn't take into account error of
    //       input data
    double errthr = 2.5 * log(static_cast<double>(N2)*N1*N0) / log(2.0) * DBL_EPSILON;
    if(mpi_rank == 0) std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    // Strides for row-major addressing of data
    // within the local slab
    int N1_local = distribute(N1, mpi_rank, mpi_nproc);
    int S0 = 1, S1 = N0/2+1;
    int S2 = N1_local*S1;

    bool status = SUCCESS;
    double maxerr = 0.0;
    // The loop conditions stop the scan at the first element that fails.
    for (int n2 = 0; n2 < N2 && status == SUCCESS; n2++) {
        for (int n1 = 0; n1 < N1_local && status == SUCCESS; n1++) {
            // Global dimension-1 index of this local row; invariant in n0
            const auto g1 = global_index(n1, N1, mpi_rank, mpi_nproc);
            for (int n0 = 0; n0 < N0/2+1; n0++) {
                double re_exp = (
                        ((n0 - H0) % N0 == 0) &&
                        ((g1 - H1) % N1 == 0) &&
                        ((n2 - H2) % N2 == 0)
                    ) || (
                        ((-n0 - H0) % N0 == 0) &&
                        ((-g1 - H1) % N1 == 0) &&
                        ((-n2 - H2) % N2 == 0)
                    ) ? 1.0 : 0.0;
                double im_exp = 0.0;

                int index = n2*S2 + n1*S1 + n0*S0;
                double re_got = data[index*2+0];
                double im_got = data[index*2+1];
                double err  = fabs(re_got - re_exp) + fabs(im_got - im_exp);
                if (err > maxerr) maxerr = err;
                // Written as !(err < errthr) so that a NaN error also fails
                if (!(err < errthr)) {
                    std::cout << "\t\t On process:" << mpi_rank << ", data"
                              << "[" << n2 << "][" << n1 << "][" << n0 << "]: "
                              << " expected (" << re_exp << "," << im_exp << ")"
                              << " got (" << re_got << "," << im_got << ")"
                              << " err " << err << std::endl;
                    std::cout << "\t\tVerification FAILED" << std::endl;
                    status = FAILURE;
                    break;
                }
            }
        }
    }

    // Logical AND across ranks: every rank learns whether any rank failed,
    // so all ranks take the same return path below.
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &status, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_AllReduce error" << std::endl;
        return FAILURE;
    }
    if(status == FAILURE) return FAILURE;

    // MPI_IN_PLACE is only a valid send buffer at the root of MPI_Reduce;
    // the other ranks must send their own value.
    void *send_buf = (mpi_rank == 0) ? MPI_IN_PLACE : static_cast<void*>(&maxerr);
    void *recv_buf = (mpi_rank == 0) ? static_cast<void*>(&maxerr) : nullptr;
    mpi_err = MPI_Reduce(send_buf, recv_buf, 1, MPI_DOUBLE,
                         MPI_MAX, 0, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Reduce error" << std::endl;
        return FAILURE;
    }

    if(mpi_rank == 0)
        std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;

    return SUCCESS;
}

// Fill the local slab of the complex array data(N) with a complex
// exponential of harmonic (H0, H1, H2) so that it produces a unit peak at
// data(H).
//
// The slab is addressed as [N2][N1_local][N0/2+1] complex elements stored as
// interleaved (re, im) doubles, where N1_local is the share of dimension 1
// owned by this rank.
static void init_c(double *data,
                   int N0, int N1, int N2,
                   int H0, int H1, int H2,
                   int mpi_rank, int mpi_nproc)
{
    // Number of dimension-1 rows stored on this rank
    const int local_rows = distribute(N1, mpi_rank, mpi_nproc);

    // Row-major strides, in complex elements, within the local slab
    const int half_len = N0/2+1;
    const int stride0 = 1;
    const int stride1 = half_len;
    const int stride2 = local_rows*half_len;

    // Total number of points of the global transform, used for scaling
    const int total_points = N2*N1*N0;

    for (int k = 0; k < N2; ++k) {
        for (int j = 0; j < local_rows; ++j) {
            const int row_start = k*stride2 + j*stride1;
            for (int i = 0; i < half_len; ++i) {
                const double phase = moda<double>(i, H0, N0) / N0 +
                                     moda<double>(global_index(j, N1, mpi_rank, mpi_nproc), H1, N1) / N1 +
                                     moda<double>(k, H2, N2) / N2;
                // Each complex element occupies two consecutive doubles
                double *element = data + (row_start + i*stride0)*2;
                element[0] =  cos(TWOPI * phase) / total_points;
                element[1] = -sin(TWOPI * phase) / total_points;
            }
        }
    }
}

// Verify that the local slab of real data has a unit peak at harmonic H and
// is zero everywhere else.
//
// data       local slab of doubles addressed as [N2_local][N1][2*(N0/2+1)],
//            where N2_local is the share of dimension 2 owned by this rank;
//            only the first N0 entries of each row are checked
// N0,N1,N2   global transform sizes
// H0,H1,H2   harmonic at which the peak is expected
// mpi_rank   rank of the calling process (rank 0 prints and receives maxerr)
// mpi_nproc  number of processes sharing the data
//
// Returns SUCCESS only if every rank verified its slab and the MPI
// reductions succeeded; FAILURE otherwise. Must be called by all ranks of
// MPI_COMM_WORLD since it performs collective operations.
static bool verify_r(const double* data,
                    int N0, int N1, int N2,
                    int H0, int H1, int H2,
                    int mpi_rank, int mpi_nproc)
{
    // Strides for row-major addressing of data
    // within the local slab
    int S0 = 1, S1 = (N0/2+1)*2, S2 = N1*(N0/2+1)*2;
    int N2_local = distribute(N2, mpi_rank, mpi_nproc);
    // Note: this simple error bound doesn't take into account error of
    //       input data
    double errthr = 2.5 * log(static_cast<double>(N2)*N1*N0) / log(2.0) * DBL_EPSILON;
    if(mpi_rank == 0) std::cout << "\t\tVerify the result, errthr = " << errthr << std::endl;

    bool status = SUCCESS;
    double maxerr = 0.0;
    // The loop conditions stop the scan at the first element that fails.
    for (int n2 = 0; n2 < N2_local && status == SUCCESS; n2++) {
        // Global dimension-2 index of this local plane; invariant in n1, n0
        const auto g2 = global_index(n2, N2, mpi_rank, mpi_nproc);
        for (int n1 = 0; n1 < N1 && status == SUCCESS; n1++) {
            for (int n0 = 0; n0 < N0; n0++) {
                double re_exp = (
                    ((n0 - H0) % N0 == 0) &&
                    ((n1 - H1) % N1 == 0) &&
                    ((g2 - H2) % N2 == 0)) ? 1.0 : 0.0;

                int index = n2*S2 + n1*S1 + n0*S0;
                double re_got = data[index];
                double err  = fabs(re_got - re_exp);
                if (err > maxerr) maxerr = err;
                // Written as !(err < errthr) so that a NaN error also fails
                if (!(err < errthr)) {
                    std::cout << "\t\t On process:" << mpi_rank << ", data"
                              << "[" << n2 << "][" << n1 << "][" << n0 << "]: "
                              << " expected (" << re_exp << ")"
                              << " got (" << re_got << ")"
                              << " err " << err << std::endl;
                    std::cout << "\t\tVerification FAILED" << std::endl;
                    status = FAILURE;
                    break;
                }
            }
        }
    }

    // Logical AND across ranks: every rank learns whether any rank failed,
    // so all ranks take the same return path below.
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &status, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_AllReduce error" << std::endl;
        return FAILURE;
    }
    if(status == FAILURE) return FAILURE;

    // MPI_IN_PLACE is only a valid send buffer at the root of MPI_Reduce;
    // the other ranks must send their own value.
    void *send_buf = (mpi_rank == 0) ? MPI_IN_PLACE : static_cast<void*>(&maxerr);
    void *recv_buf = (mpi_rank == 0) ? static_cast<void*>(&maxerr) : nullptr;
    mpi_err = MPI_Reduce(send_buf, recv_buf, 1, MPI_DOUBLE,
                         MPI_MAX, 0, MPI_COMM_WORLD);
    if(mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Reduce error" << std::endl;
        return FAILURE;
    }

    if(mpi_rank == 0)
        std::cout << "\t\tVerified, maximum error was " << maxerr << std::endl;

    return SUCCESS;
}

// Run the forward (real-to-complex) distributed 3-D FFT on `dev` and verify
// the result against the expected unit peak.
//
// dev        SYCL device on which the execution queue is created
// mpi_rank   rank of the calling process
// mpi_nproc  number of participating processes
//
// Returns SUCCESS if verification passed, FAILURE if verification or an MPI
// reduction failed. Throws std::runtime_error when a host or device
// allocation failed on any rank; exceptions raised by SYCL or oneMKL
// propagate to the caller. All buffers allocated here are released on every
// exit path.
bool run_dft_forward_example(sycl::device &dev, int mpi_rank, int mpi_nproc) {
    //
    // Initialize data for DFT
    //
    int N0 = 64, N1 = 128, N2 = 120;
    // Arbitrary harmonic used to verify FFT
    int H0 = -1, H1 = 2, H2 = 4;
    // Distributed DFT assumes the same default forward and backward
    // strides as the DFT SYCL APIs corresponding to packed data layouts, i.e.
    // minimal padding in forward domain to satisfy the *in-place* consistency
    // requirement for the real transforms. In this particular example, this
    // translates into the following default strides for the global data
    // std::vector<std::int64_t> fwd_strides{0, 2*N1*(N0/2+1), 2*(N0/2+1), 1};
    // std::vector<std::int64_t> bwd_strides{0,   N1*(N0/2+1),   (N0/2+1), 1};
    // being set when creating the relevant descriptor. This example uses the
    // above default strides.
    bool result = FAILURE;
    bool alloc_result = SUCCESS;

    //
    // Execute DFT
    //
    // Catch asynchronous exceptions
    auto exception_handler = [] (sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            } catch(sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:" << std::endl
                          << e.what() << std::endl;
            }
        }
    };

    // create execution queue with asynchronous error handling
    sycl::queue queue(dev, exception_handler);

    // Allocate local memory and initialize
    // based on default slab decomposition
    auto N0_bwd = N0/2+1;
    auto N2_local = distribute(N2, mpi_rank, mpi_nproc);
    auto N1_local = distribute(N1, mpi_rank, mpi_nproc);
    auto fwd_size = 2 * N0_bwd * N1 * N2_local;
    auto bwd_size = 2 * N0_bwd * N1_local * N2;

    double *x = static_cast<double*>(mkl_malloc(fwd_size*sizeof(double), 64));
    double *y = static_cast<double*>(mkl_malloc(bwd_size*sizeof(double), 64));
    double *x_usm = nullptr;
    double *y_usm = nullptr;

    // Frees whichever host/device buffers are currently allocated and resets
    // the pointers, so it can be called from any exit path.
    auto release_buffers = [&]() {
        if (x_usm) { sycl::free(x_usm, queue); x_usm = nullptr; }
        if (y_usm) { sycl::free(y_usm, queue); y_usm = nullptr; }
        if (x) { mkl_free(x); x = nullptr; }
        if (y) { mkl_free(y); y = nullptr; }
    };

    // Agree across ranks on whether the host allocations succeeded, so that
    // all ranks either continue or bail out together.
    if (!x || !y) alloc_result = FAILURE;
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Allreduce error" << std::endl;
        release_buffers();
        return FAILURE;
    }

    if(alloc_result == FAILURE) {
        release_buffers();
        throw std::runtime_error("Failed to allocate memory using mkl_malloc");
    }

    try {
        init_r(x, N0, N1, N2, H0, H1, H2, mpi_rank, mpi_nproc);

        distributed_desc_t desc(MPI_COMM_WORLD, {N2, N1, N0});
        // Default behavior is as if
        // desc.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides);
        // desc.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides);
        // were used with the values of fwd_strides and bwd_strides as specified
        // above.
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                       oneapi::mkl::dft::config_value::NOT_INPLACE);
        // Default slab decomposition behavior is as if the following were set,
        // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_divided_dimension,
        //                0);
        // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_divided_dimension,
        //                1);
        desc.commit(queue);

        // Get the size of local USM memory to be allocated after commit
        std::int64_t fwd_usm_bytes = 0, bwd_usm_bytes = 0;
        desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_local_data_size_bytes, &fwd_usm_bytes);
        desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_local_data_size_bytes, &bwd_usm_bytes);
        x_usm = static_cast<double*>(sycl::malloc_device(fwd_usm_bytes, queue));
        y_usm = static_cast<double*>(sycl::malloc_device(bwd_usm_bytes, queue));
        if(!x_usm || !y_usm) alloc_result = FAILURE;
        mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
        if (mpi_err != MPI_SUCCESS) {
            std::cout << "MPI_Allreduce error" << std::endl;
            release_buffers();
            return FAILURE;
        }

        if(alloc_result == FAILURE) {
            // Buffers are released by the handler below
            throw std::runtime_error("Failed to allocate USM memory");
        }

        sycl::event copy_ev = queue.memcpy(x_usm, x, fwd_size*sizeof(double));

        // The transform depends on the host-to-device copy having completed
        oneapi::mkl::experimental::dft::compute_forward(desc, x_usm, y_usm, {copy_ev}).wait();
        queue.memcpy(y, y_usm, bwd_size*sizeof(double)).wait();
        result = verify_c(y, N0, N1, N2, H0, H1, H2, mpi_rank, mpi_nproc);
    } catch (...) {
        // Do not leak host/device memory when SYCL or oneMKL throws
        release_buffers();
        throw;
    }

    release_buffers();

    return result;
}

// Run the backward (complex-to-real) distributed 3-D FFT on `dev` and verify
// the result against the expected unit peak.
//
// dev        SYCL device on which the execution queue is created
// mpi_rank   rank of the calling process
// mpi_nproc  number of participating processes
//
// Returns SUCCESS if verification passed, FAILURE if verification or an MPI
// reduction failed. Throws std::runtime_error when a host or device
// allocation failed on any rank; exceptions raised by SYCL or oneMKL
// propagate to the caller. All buffers allocated here are released on every
// exit path.
bool run_dft_backward_example(sycl::device &dev, int mpi_rank, int mpi_nproc) {
    //
    // Initialize data for DFT
    //
    int N0 = 64, N1 = 128, N2 = 120;
    // Arbitrary harmonic used to verify FFT
    int H0 = -1, H1 = 2, H2 = 4;
    // Distributed DFT assumes the same default forward and backward
    // strides as the DFT SYCL APIs corresponding to packed data layouts, i.e.
    // minimal padding in forward domain to satisfy the *in-place* consistency
    // requirement for the real transforms. In this particular example, this
    // translates into the following default strides for the global data
    // std::vector<std::int64_t> fwd_strides{0, 2*N1*(N0/2+1), 2*(N0/2+1), 1};
    // std::vector<std::int64_t> bwd_strides{0,   N1*(N0/2+1),   (N0/2+1), 1};
    // being set when creating the relevant descriptor. This example uses the
    // above default strides.
    bool result = FAILURE;
    bool alloc_result = SUCCESS;

    //
    // Execute DFT
    //
    // Catch asynchronous exceptions
    auto exception_handler = [] (sycl::exception_list exceptions) {
        for (std::exception_ptr const& e : exceptions) {
            try {
                std::rethrow_exception(e);
            } catch(sycl::exception const& e) {
                std::cout << "Caught asynchronous SYCL exception:" << std::endl
                          << e.what() << std::endl;
            }
        }
    };

    // create execution queue with asynchronous error handling
    sycl::queue queue(dev, exception_handler);

    // Allocate local memory and initialize
    // based on default slab decomposition
    auto N0_bwd = N0/2+1;
    auto N2_local = distribute(N2, mpi_rank, mpi_nproc);
    auto N1_local = distribute(N1, mpi_rank, mpi_nproc);
    auto fwd_size = 2 * N0_bwd * N1 * N2_local;
    auto bwd_size = 2 * N0_bwd * N1_local * N2;

    double *x = static_cast<double*>(mkl_malloc(fwd_size*sizeof(double), 64));
    double *y = static_cast<double*>(mkl_malloc(bwd_size*sizeof(double), 64));
    double *x_usm = nullptr;
    double *y_usm = nullptr;

    // Frees whichever host/device buffers are currently allocated and resets
    // the pointers, so it can be called from any exit path.
    auto release_buffers = [&]() {
        if (x_usm) { sycl::free(x_usm, queue); x_usm = nullptr; }
        if (y_usm) { sycl::free(y_usm, queue); y_usm = nullptr; }
        if (x) { mkl_free(x); x = nullptr; }
        if (y) { mkl_free(y); y = nullptr; }
    };

    // Agree across ranks on whether the host allocations succeeded, so that
    // all ranks either continue or bail out together.
    if (!x || !y) alloc_result = FAILURE;
    int mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI_Allreduce error" << std::endl;
        release_buffers();
        return FAILURE;
    }

    if(alloc_result == FAILURE) {
        release_buffers();
        throw std::runtime_error("Failed to allocate memory using mkl_malloc");
    }

    try {
        init_c(y, N0, N1, N2, H0, H1, H2, mpi_rank, mpi_nproc);

        distributed_desc_t desc(MPI_COMM_WORLD, {N2, N1, N0});
        // Default behavior is as if
        // desc.set_value(oneapi::mkl::dft::config_param::FWD_STRIDES, fwd_strides);
        // desc.set_value(oneapi::mkl::dft::config_param::BWD_STRIDES, bwd_strides);
        // were used with the values of fwd_strides and bwd_strides as specified
        // above
        desc.set_value(oneapi::mkl::dft::config_param::PLACEMENT,
                       oneapi::mkl::dft::config_value::NOT_INPLACE);
        // Default slab decomposition behavior is as if the following were set,
        // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_divided_dimension,
        //                0);
        // desc.set_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_divided_dimension,
        //                1);
        desc.commit(queue);

        // Get the size of local USM memory to be allocated after commit
        std::int64_t fwd_usm_bytes = 0, bwd_usm_bytes = 0;
        desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::fwd_local_data_size_bytes, &fwd_usm_bytes);
        desc.get_value(oneapi::mkl::experimental::dft::distributed_config_param::bwd_local_data_size_bytes, &bwd_usm_bytes);
        x_usm = static_cast<double*>(sycl::malloc_device(fwd_usm_bytes, queue));
        y_usm = static_cast<double*>(sycl::malloc_device(bwd_usm_bytes, queue));
        if(!x_usm || !y_usm) alloc_result = FAILURE;
        mpi_err = MPI_Allreduce(MPI_IN_PLACE, &alloc_result, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
        if (mpi_err != MPI_SUCCESS) {
            std::cout << "MPI_Allreduce error" << std::endl;
            release_buffers();
            return FAILURE;
        }

        if(alloc_result == FAILURE) {
            // Buffers are released by the handler below
            throw std::runtime_error("Failed to allocate USM memory");
        }

        sycl::event copy_ev = queue.memcpy(y_usm, y, bwd_size*sizeof(double));

        // The queue is not declared in-order, so the transform must depend
        // explicitly on the host-to-device copy of its input.
        oneapi::mkl::experimental::dft::compute_backward(desc, y_usm, x_usm, {copy_ev}).wait();
        queue.memcpy(x, x_usm, fwd_size*sizeof(double)).wait();
        result = verify_r(x, N0, N1, N2, H0, H1, H2, mpi_rank, mpi_nproc);
    } catch (...) {
        // Do not leak host/device memory when SYCL or oneMKL throws
        release_buffers();
        throw;
    }

    release_buffers();

    return result;
}

// Print the example banner (title, APIs used, supported precisions) to
// standard output, preceded and followed by an empty line.
void print_example_banner() {
    // One entry per output line; each is written and flushed in order.
    static const char *const banner_lines[] = {
        "",
        "########################################################################",
        "# Distributed 3D GPU FFT Real-Complex Double-Precision Example: ",
        "# ",
        "# Using apis:",
        "#   experimental::dft",
        "# ",
        "# Supported floating point type precisions:",
        "#   double",
        "#   std::complex<double>",
        "# ",
        "########################################################################",
        "",
    };

    for (const char *text : banner_lines) {
        std::cout << text << std::endl;
    }
}

//
// Main entry point for example.
//
// Dispatches to appropriate device types as set at build time with flag:
// -DSYCL_DEVICES_gpu -- only supports SYCL GPU implementation
//
int main(int argc, char **argv) {
    int mpi_err = MPI_Init(&argc, &argv);
    if (mpi_err != MPI_SUCCESS) {
        std::cout << "MPI initialization error" << std::endl;
        std::cout << "Test Failed" << std::endl;
        return mpi_err;
    }

    int mpi_rank, mpi_nproc;
    MPI_Comm_rank(MPI_COMM_WORLD, &mpi_rank);
    MPI_Comm_size(MPI_COMM_WORLD, &mpi_nproc);

    if(mpi_rank == 0) print_example_banner();

    std::list<my_sycl_device_types> list_of_devices;
    set_list_of_devices(list_of_devices);

    int returnCode = 0;
    for (auto it = list_of_devices.begin(); it != list_of_devices.end(); ++it) {
        sycl::device my_dev;
        bool my_dev_is_found = false;
        get_sycl_device(my_dev, my_dev_is_found, *it);
        bool dev_found_and_is_gpu = my_dev_is_found && my_dev.is_gpu();
        // Check if all processes found the GPU
        mpi_err = MPI_Allreduce(MPI_IN_PLACE, &dev_found_and_is_gpu, 1,
                                MPI_CXX_BOOL, MPI_LAND, MPI_COMM_WORLD);
        if (mpi_err != MPI_SUCCESS) {
            std::cout << "MPI_Allreduce error" << std::endl;
            std::cout << "Test Failed" << std::endl;
            return mpi_err;
        }

        if (dev_found_and_is_gpu) {
            bool status;
            if(mpi_rank == 0) {
                std::cout << "Running tests on " << sycl_device_names[*it] << " with " << mpi_nproc << " processes" <<".\n";
                std::cout << "\tRunning with double precision real-to-complex distributed 3-D FFT:" << std::endl;
            }
            try {
                status = run_dft_forward_example(my_dev, mpi_rank, mpi_nproc);
                mpi_err = MPI_Reduce(MPI_IN_PLACE, &status, 1, MPI_CXX_BOOL,
                                     MPI_LAND, 0, MPI_COMM_WORLD);
                if (mpi_err != MPI_SUCCESS) {
                    std::cout << "MPI_Reduce error" << std::endl;
                    std::cout << "Test Failed" << std::endl;
                    return mpi_err;
                }

                if(mpi_rank == 0) {
                    if (status != SUCCESS) {
                        std::cout << "\tTest Forward Failed" << std::endl << std::endl;
                        returnCode = 1;
                    } else {
                        std::cout << "\tTest Forward Passed" << std::endl << std::endl;
                    }
                }
                status = run_dft_backward_example(my_dev, mpi_rank, mpi_nproc);
                mpi_err = MPI_Reduce(MPI_IN_PLACE, &status, 1, MPI_CXX_BOOL,
                                     MPI_LAND, 0, MPI_COMM_WORLD);
                if (mpi_err != MPI_SUCCESS) {
                    std::cout << "MPI_Reduce error" << std::endl;
                    std::cout << "Test Failed" << std::endl;
                    return mpi_err;
                }

                if(mpi_rank == 0) {
                    if (status != SUCCESS) {
                        std::cout << "\tTest Backward Failed" << std::endl << std::endl;
                        returnCode = 1;
                    } else {
                        std::cout << "\tTest Backward Passed" << std::endl << std::endl;
                    }
                }
            } catch(sycl::exception const& e) {
                std::cout << "\t\tSYCL exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                std::cout << "\t\tError code: " << e.code().value() << std::endl;
                returnCode = 1;
            }
            catch(oneapi::mkl::exception const& e) {
                std::cout << "\t\toneMKL exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                returnCode = 1;
            }
            catch(std::runtime_error const& e) {
                std::cout << "\t\tRuntime exception during FFT" << std::endl;
                std::cout << "\t\t" << e.what() << std::endl;
                returnCode = 1;
            }
        } else if(my_dev_is_found) {
            std::cout << "Distributed DFT does not support " << sycl_device_names[*it] << " device for now; skipping tests" << std::endl;
        } else {
            if(mpi_rank == 0) {
#ifdef FAIL_ON_MISSING_DEVICES
                std::cout << "No " << sycl_device_names[*it] << " devices found; Fail on missing devices is enabled." << std::endl;
                return 1;
#else
                std::cout << "No " << sycl_device_names[*it] << " devices found; skipping " << sycl_device_names[*it] << " tests." << std::endl;
#endif
            }
        }
    }

    mkl_free_buffers();
    MPI_Finalize();
    return returnCode;
}
