/*******************************************************************************
* Copyright (C) 2020 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

/*
! Content:
!       Example of using the fftwf_plan_many_dft function on a
!       (GPU) device using the OpenMP target (offload) interface.
!
!****************************************************************************/

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <float.h>
#include "fftw/fftw3.h"
#include "fftw/offload/fftw3_omp_offload.h"

/*
 * Forward declarations of the helpers defined after main():
 *   init   - fills x with the harmonic H used as FFT input;
 *   verify - checks that the transformed x has a unit peak at H,
 *            returning 0 on success and 1 on failure.
 * Both address element (n1,n2) of transform m at
 * n1*EN[1]*stride + n2*stride + m*dist.
 */
static void init(fftwf_complex *x,
                 int *N, int M, int *EN, int stride, int dist, int *H);
static int verify(fftwf_complex *x,
                 int *N, int M, int *EN, int stride, int dist, int *H);

/*
 * Runs the example: M forward, in-place, single-precision complex 2D FFTs
 * of size N[0] x N[1] are planned and executed through the OpenMP dispatch
 * interface on device devNum, first for the "parallel" data layout and then
 * for the "vector" layout.  After each run the data mapped back from the
 * target data region is checked with verify().
 *
 * Returns 0 when both layouts verify, 1 on any allocation, planning or
 * verification failure.
 */
int main(void)
{
    /*
     * In this example we perform M in-place 2D FFTs on the data
     * contained in a larger array.  The sizes of the FFT are defined
     * by N, the embedding array has dimensions defined by
     * EN and EM, not necessarily in this order.
     */

    /* Sizes of 2D transform and the number of them */
    int N[2] = {7, 13};
    int M = 3;

    /* Sizes of embedding array, stride and distance, to be defined */
    int EN[2], EM, stride, dist;

    /* Arbitrary harmonic used to verify FFT */
    int H[2] = {-1, 1};

    /* FFTW plan handle */
    fftwf_plan plan = 0;

    /* Pointer to input/output data */
    fftwf_complex *x = 0;

    /* Execution status */
    int status = 0;

    /* OpenMP device the transforms are dispatched to */
    const int devNum = 0;

    printf("Example sp_plan_many_dft_2d\n");
    printf("Forward multiple 2D complex in-place FFT\n");
    printf("Configuration parameters:\n");
    printf(" N  = {%d, %d}\n", N[0], N[1]);
    printf(" M  = %d\n", M);
    printf(" H  = {%d, %d}\n", H[0], H[1]);

    printf("Define data layout for PARALLEL transforms\n");
    /*
     * Leading dimension is N (parallel transforms): embedding array
     * has dimensions (EM,EN).
     */
    stride  = 1;
    EN[1]   = N[1] * stride;
    EN[0]   = N[0];
    dist    = EN[0] * EN[1];
    EM     = M;
    printf(" EM=%i, dist=%i, EN={%i, %i}, stride=%i\n",
           EM,dist,EN[0],EN[1],stride);


    printf("Allocate x(%i)\n", EM*dist );
    x  = (fftwf_complex*)fftwf_malloc(sizeof(fftwf_complex) * EM*dist );
    if (0 == x) goto failed;

    printf("Initialize input for forward transform\n");
    init(x, N, M, EN, stride, dist, H);

    printf("Create FFTW plan for forward transform\n");
#pragma omp target data map(tofrom:x[0:EM*dist]) device(devNum)
    {
#pragma omp dispatch device(devNum)
    plan = fftwf_plan_many_dft(2, N, M,
                               x, EN, stride, dist,
                               x, EN, stride, dist,
                               FFTW_FORWARD, FFTW_ESTIMATE);
    if (plan == 0)
    {
        printf("Call to fftwf_plan_many_dft has failed\n");
    }
    else
    {
        /* Only execute when planning succeeded */
        printf("Compute forward FFT\n");
#pragma omp dispatch device(devNum)
        fftwf_execute(plan);
    }
    }
    /*
     * A goto must not leave the target data region, so a planning failure
     * is acted upon only here, after the region has been closed.
     */
    if (plan == 0) goto failed;

    printf("Verify the result of forward FFT\n");
    status = verify(x, N, M, EN, stride, dist, H);
    if (0 != status) goto failed;

    printf("Destroy FFTW plan\n");
    fftwf_destroy_plan(plan);
    /* Reset so that the shared cleanup path cannot destroy it twice */
    plan = 0;

    printf("Free data array\n");
    fftwf_free(x);
    /* Reset so that the shared cleanup path cannot free it twice */
    x = 0;


    printf("\nDefine data layout for VECTOR transforms\n");
    /*
     * Leading dimension is M (vector transforms): embedding array
     * has dimensions (EN,EM).
     */
    dist   = 1;
    EM     = M * dist;
    stride = EM;
    EN[1]  = N[1];
    EN[0]  = N[0];

    printf(" EN={%i,%i}, stride=%i, EM=%i, dist=%i\n",
           EN[0],EN[1],stride,EM,dist);

    printf("Allocate x(%i)\n", EN[0]*EN[1]*stride );
    x  = (fftwf_complex*)fftwf_malloc(sizeof(fftwf_complex) * EN[0]*EN[1]*stride );
    if (0 == x) goto failed;

    printf("Initialize input for forward transform\n");
    /* Pass the same stride that is given to the plan below */
    init(x, N, M, EN, stride, dist, H);

    printf("Create FFTW plan for forward transform\n");
#pragma omp target data map(tofrom:x[0:EN[0]*EN[1]*stride]) device(devNum)
    {
#pragma omp dispatch device(devNum)
    plan = fftwf_plan_many_dft(2, N, M,
                               x, EN, stride, dist,
                               x, EN, stride, dist,
                               FFTW_FORWARD, FFTW_ESTIMATE);
    if (plan == 0)
    {
        printf("Call to fftwf_plan_many_dft has failed\n");
    }
    else
    {
        /* Only execute when planning succeeded */
        printf("Compute forward FFT\n");
#pragma omp dispatch device(devNum)
        fftwf_execute(plan);
    }
    }
    /* See above: the failure is handled outside the target data region */
    if (plan == 0) goto failed;

    printf("Verify the result of forward FFT\n");
    status = verify(x, N, M, EN, stride, dist, H);
    if (0 != status) goto failed;

 cleanup:

    /*
     * plan and x are null here whenever they have already been released
     * or were never created, so each resource is released at most once.
     */
    printf("Destroy FFTW plan\n");
    if (plan != 0) fftwf_destroy_plan(plan);

    printf("Free data array\n");
    if (x != 0) fftwf_free(x);

    printf("TEST %s\n",0==status ? "PASSED" : "FAILED");
    return status;

 failed:
    printf(" ERROR\n");
    status = 1;
    goto cleanup;
}

/*
 * Compute (K*L)%M accurately.
 *
 * The product is formed in long long so that it cannot overflow int; the
 * remainder keeps the sign C's % operator gives it and is returned as float.
 */
static float moda(int K, int L, int M)
{
    const long long product   = (long long)K * (long long)L;
    const long long remainder = product % (long long)M;

    return (float)remainder;
}

/*
 * Initialize arrays x with harmonic H.
 *
 * For each of the M transforms, element (n1,n2) with 0 <= n1 < N[0] and
 * 0 <= n2 < N[1] is stored at index n1*EN[1]*stride + n2*stride + m*dist
 * and set to
 *     (cos(2*pi*p), sin(2*pi*p)) / (N[0]*N[1]),
 * where p = moda(n1,H[0],N[0])/N[0] + moda(n2,H[1],N[1])/N[1].
 */
static void init(fftwf_complex *x,
                 int *N, int M, int *EN, int stride, int dist, int *H)
{
    const float two_pi = 6.2831853071795864769f;

    /* Transform sizes and the harmonic, per dimension */
    const int rows = N[0], cols = N[1];
    const int harm_row = H[0], harm_col = H[1];

    /* Index increments for one step along each dimension and per transform */
    const int col_step   = stride;
    const int row_step   = EN[1] * col_step;
    const int batch_step = dist;

    int batch, r, c;

    for (batch = 0; batch < M; ++batch)
    {
        for (r = 0; r < rows; ++r)
        {
            for (c = 0; c < cols; ++c)
            {
                const int at = r * row_step + c * col_step + batch * batch_step;
                float angle;

                /* Fraction of a full turn contributed by each dimension */
                angle  = moda(r, harm_row, rows) / rows;
                angle += moda(c, harm_col, cols) / cols;

                x[at][0] = cosf(two_pi * angle) / (rows * cols);
                x[at][1] = sinf(two_pi * angle) / (rows * cols);
            }
        }
    }
}

/*
 * Verify that x has unit peak at H.
 *
 * For each of the M transforms, element (n1,n2) is read from index
 * n1*EN[1]*stride + n2*stride + m*dist and compared with the expected value:
 * (1,0) where (n1-H[0]) is a multiple of N[0] and (n2-H[1]) is a multiple of
 * N[1], and (0,0) everywhere else.  The error of an element is the sum of
 * the absolute differences of its real and imaginary parts.
 *
 * Returns 0 if every element is below the error threshold; otherwise prints
 * the first offending element and returns 1.
 */
static int verify(fftwf_complex *x,
                  int *N, int M, int *EN, int stride, int dist, int *H)
{
    float err, errthr, maxerr;
    int n1, n2, m, N1, N2, S1, S2, SM, H1, H2, index;

    SM = dist;
    S2 = stride;   N2 = N[1]; H2 = H[1];
    S1 = EN[1]*S2; N1 = N[0]; H1 = H[0];

    /*
     * Note, this simple error bound doesn't take into account error of
     * input data.  It is based on log2 of the total 2D transform length
     * N1*N2 rather than on the first dimension alone.
     */
    errthr = 5.0f * logf( (float)N1 * (float)N2 ) / logf(2.0f) * FLT_EPSILON;
    printf(" Check if err is below errthr %.3g\n", errthr);

    maxerr = 0;
    for (m = 0; m < M; m++)
    {
        for (n1 = 0; n1 < N1; n1++)
        {
            for (n2 = 0; n2 < N2; n2++)
            {
                float re_exp = 0.0f, im_exp = 0.0f, re_got, im_got;

                /* The only non-zero output is expected at index H modulo N */
                if ((n1-H1)%N1==0 && (n2-H2)%N2==0)
                {
                    re_exp = 1;
                }

                index = n1*S1 + n2*S2 + m*SM;
                re_got = x[index][0];
                im_got = x[index][1];
                err  = fabsf(re_got - re_exp) + fabsf(im_got - im_exp);
                if (err > maxerr) maxerr = err;
                /* Written as !(err < errthr) so that a NaN error also fails */
                if (!(err < errthr))
                {
                    /* Report both indices of the 2D element that failed */
                    printf(" x(n1=%i,n2=%i,m=%i): ",n1,n2,m);
                    printf(" expected (%.7g,%.7g), ",re_exp,im_exp);
                    printf(" got (%.7g,%.7g), ",re_got,im_got);
                    printf(" err %.3g\n", err);
                    printf(" Verification FAILED\n");
                    return 1;
                }
            }
        }
    }
    printf(" Verified, maximum error was %.3g\n", maxerr);
    return 0;
}
