//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

#ifndef _CUDA_STD_NUMBERS
#define _CUDA_STD_NUMBERS

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__floating_point/nvfp_types.h>
#include <cuda/std/__type_traits/enable_if.h>
#include <cuda/std/__type_traits/is_floating_point.h>
#include <cuda/std/version>

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Trait that is always `false_type` but still depends on its template argument. Using it as a `static_assert`
// condition postpones the failure until the enclosing template is instantiated with a concrete type.
template <class>
struct __numbers_ill_formed_impl : false_type
{};

// Primary template of the constant provider. It declares no constants: it is only chosen for a `_Tp` that has no
// matching specialization, and instantiating it trips the `static_assert` below, whose condition depends on `_Tp`.
template <class _Tp, class _Enable = void>
struct __numbers
{
  static_assert(__numbers_ill_formed_impl<_Tp>::value,
                "[math.constants] A program that instantiates a primary template of a mathematical constant "
                "variable template is ill-formed.");
};

// Every accessor in the specialization below returns an unsuffixed (`double`) literal and lets the return statement
// convert it to `_Tp`. MSVC reports that implicit conversion as C4305, so the warning is silenced for this region only.
_CCCL_DIAG_PUSH
_CCCL_DIAG_SUPPRESS_MSVC(4305) // truncation from 'double' to 'const _Tp'

// Partial specialization selected when the `is_floating_point` trait holds for `_Tp`. Each static member function
// returns one mathematical constant converted to `_Tp`.
// NOTE(review): the literals are written with 34 significant digits but have type `double`, so a `_Tp` wider than
// `double` only receives double-precision values.
template <class _Tp>
struct __numbers<_Tp, enable_if_t<_CCCL_TRAIT(is_floating_point, _Tp)>>
{
  // e, the base of the natural logarithm
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __e() noexcept
  {
    return 2.718281828459045235360287471352662;
  }
  // log2(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __log2e() noexcept
  {
    return 1.442695040888963407359924681001892;
  }
  // log10(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __log10e() noexcept
  {
    return 0.434294481903251827651128918916605;
  }
  // pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __pi() noexcept
  {
    return 3.141592653589793238462643383279502;
  }
  // 1 / pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_pi() noexcept
  {
    return 0.318309886183790671537767526745028;
  }
  // 1 / sqrt(pi)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_sqrtpi() noexcept
  {
    return 0.564189583547756286948079451560772;
  }
  // ln(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __ln2() noexcept
  {
    return 0.693147180559945309417232121458176;
  }
  // ln(10)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __ln10() noexcept
  {
    return 2.302585092994045684017991454684364;
  }
  // sqrt(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __sqrt2() noexcept
  {
    return 1.414213562373095048801688724209698;
  }
  // sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __sqrt3() noexcept
  {
    return 1.732050807568877293527446341505872;
  }
  // 1 / sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __inv_sqrt3() noexcept
  {
    return 0.577350269189625764509148780501957;
  }
  // Euler-Mascheroni constant
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __egamma() noexcept
  {
    return 0.577215664901532860606512090082402;
  }
  // golden ratio, (1 + sqrt(5)) / 2
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr _Tp __phi() noexcept
  {
    return 1.618033988749894848204586834365638;
  }
};

_CCCL_DIAG_POP

#if _LIBCUDACXX_HAS_NVFP16()
// Full specialization for `__half`. Every constant is spelled as a raw 16-bit pattern wrapped in `__half_raw` and
// converted to `__half` by the return statement, rather than being derived from a floating-point literal.
// NOTE(review): the patterns read as IEEE 754 binary16 encodings (1 sign, 5 exponent, 10 mantissa bits), which is what
// `__half_raw` is expected to carry -- confirm against the CUDA fp16 header.
template <>
struct __numbers<__half>
{
  // e, the base of the natural logarithm
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __e() noexcept
  {
    return __half_raw{0x4170u};
  }
  // log2(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __log2e() noexcept
  {
    return __half_raw{0x3dc5u};
  }
  // log10(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __log10e() noexcept
  {
    return __half_raw{0x36f3u};
  }
  // pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __pi() noexcept
  {
    return __half_raw{0x4248u};
  }
  // 1 / pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_pi() noexcept
  {
    return __half_raw{0x3518u};
  }
  // 1 / sqrt(pi)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_sqrtpi() noexcept
  {
    return __half_raw{0x3883u};
  }
  // ln(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __ln2() noexcept
  {
    return __half_raw{0x398cu};
  }
  // ln(10)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __ln10() noexcept
  {
    return __half_raw{0x409bu};
  }
  // sqrt(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __sqrt2() noexcept
  {
    return __half_raw{0x3da8u};
  }
  // sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __sqrt3() noexcept
  {
    return __half_raw{0x3eeeu};
  }
  // 1 / sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __inv_sqrt3() noexcept
  {
    return __half_raw{0x389eu};
  }
  // Euler-Mascheroni constant. Sharing 0x389e with `__inv_sqrt3()` is not a copy-paste slip: read as binary16 the
  // pattern is 0.5771484375, and both 0.57735... and 0.57721... are closer to it than to either neighbouring value.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __egamma() noexcept
  {
    return __half_raw{0x389eu};
  }
  // golden ratio, (1 + sqrt(5)) / 2
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __half __phi() noexcept
  {
    return __half_raw{0x3e79u};
  }
};
#endif // _LIBCUDACXX_HAS_NVFP16()

#if _LIBCUDACXX_HAS_NVBF16()
// Full specialization for `__nv_bfloat16`. Every constant is spelled as a raw 16-bit pattern wrapped in
// `__nv_bfloat16_raw` and converted to `__nv_bfloat16` by the return statement, rather than being derived from a
// floating-point literal.
// NOTE(review): the patterns read as bfloat16 encodings (1 sign, 8 exponent, 7 mantissa bits), which is what
// `__nv_bfloat16_raw` is expected to carry -- confirm against the CUDA bf16 header.
template <>
struct __numbers<__nv_bfloat16>
{
  // e, the base of the natural logarithm
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __e() noexcept
  {
    return __nv_bfloat16_raw{0x402eu};
  }
  // log2(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __log2e() noexcept
  {
    return __nv_bfloat16_raw{0x3fb9u};
  }
  // log10(e)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __log10e() noexcept
  {
    return __nv_bfloat16_raw{0x3edeu};
  }
  // pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __pi() noexcept
  {
    return __nv_bfloat16_raw{0x4049u};
  }
  // 1 / pi
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __inv_pi() noexcept
  {
    return __nv_bfloat16_raw{0x3ea3u};
  }
  // 1 / sqrt(pi)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __inv_sqrtpi() noexcept
  {
    return __nv_bfloat16_raw{0x3f10u};
  }
  // ln(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __ln2() noexcept
  {
    return __nv_bfloat16_raw{0x3f31u};
  }
  // ln(10)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __ln10() noexcept
  {
    return __nv_bfloat16_raw{0x4013u};
  }
  // sqrt(2)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __sqrt2() noexcept
  {
    return __nv_bfloat16_raw{0x3fb5u};
  }
  // sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __sqrt3() noexcept
  {
    return __nv_bfloat16_raw{0x3fdeu};
  }
  // 1 / sqrt(3)
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __inv_sqrt3() noexcept
  {
    return __nv_bfloat16_raw{0x3f14u};
  }
  // Euler-Mascheroni constant. Sharing 0x3f14 with `__inv_sqrt3()` is not a copy-paste slip: read as bfloat16 the
  // pattern is 0.578125, and both 0.57735... and 0.57721... are closer to it than to either neighbouring value.
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __egamma() noexcept
  {
    return __nv_bfloat16_raw{0x3f14u};
  }
  // golden ratio, (1 + sqrt(5)) / 2
  static _LIBCUDACXX_HIDE_FROM_ABI constexpr __nv_bfloat16 __phi() noexcept
  {
    return __nv_bfloat16_raw{0x3fcfu};
  }
};
#endif // _LIBCUDACXX_HAS_NVBF16()

namespace numbers
{

// Mathematical constant variable templates. Each one is initialised from the accessor of the same name in
// `__numbers<_Fp>`, so a type without a usable `__numbers` specialization is rejected when the template is
// instantiated.

// e, the base of the natural logarithm
template <class _Fp>
inline constexpr _Fp e_v = __numbers<_Fp>::__e();
// log2(e)
template <class _Fp>
inline constexpr _Fp log2e_v = __numbers<_Fp>::__log2e();
// log10(e)
template <class _Fp>
inline constexpr _Fp log10e_v = __numbers<_Fp>::__log10e();
// pi
template <class _Fp>
inline constexpr _Fp pi_v = __numbers<_Fp>::__pi();
// 1 / pi
template <class _Fp>
inline constexpr _Fp inv_pi_v = __numbers<_Fp>::__inv_pi();
// 1 / sqrt(pi)
template <class _Fp>
inline constexpr _Fp inv_sqrtpi_v = __numbers<_Fp>::__inv_sqrtpi();
// ln(2)
template <class _Fp>
inline constexpr _Fp ln2_v = __numbers<_Fp>::__ln2();
// ln(10)
template <class _Fp>
inline constexpr _Fp ln10_v = __numbers<_Fp>::__ln10();
// sqrt(2)
template <class _Fp>
inline constexpr _Fp sqrt2_v = __numbers<_Fp>::__sqrt2();
// sqrt(3)
template <class _Fp>
inline constexpr _Fp sqrt3_v = __numbers<_Fp>::__sqrt3();
// 1 / sqrt(3)
template <class _Fp>
inline constexpr _Fp inv_sqrt3_v = __numbers<_Fp>::__inv_sqrt3();
// Euler-Mascheroni constant
template <class _Fp>
inline constexpr _Fp egamma_v = __numbers<_Fp>::__egamma();
// golden ratio, (1 + sqrt(5)) / 2
template <class _Fp>
inline constexpr _Fp phi_v = __numbers<_Fp>::__phi();

#if !_CCCL_COMPILER(MSVC)
// Explicit specializations of the constant variable templates for the 16-bit NVIDIA floating-point types. They take
// the same values as the primary templates would (the matching `__numbers<T>` accessor) but are declared through
// `_CCCL_GLOBAL_CONSTANT` instead of `inline constexpr`.
// NOTE(review): `_CCCL_GLOBAL_CONSTANT` is defined elsewhere; judging by the MSVC diagnostic quoted below it involves
// a `__device__` declaration -- confirm against its definition.
// With MSVC this whole region is skipped, so `__half` / `__nv_bfloat16` fall back to the primary variable templates.
// MSVC errors here because of "error: A __device__ variable template cannot have a const qualified type on Windows"
#  if _LIBCUDACXX_HAS_NVFP16()
template <>
_CCCL_GLOBAL_CONSTANT __half e_v<__half> = __numbers<__half>::__e();
template <>
_CCCL_GLOBAL_CONSTANT __half log2e_v<__half> = __numbers<__half>::__log2e();
template <>
_CCCL_GLOBAL_CONSTANT __half log10e_v<__half> = __numbers<__half>::__log10e();
template <>
_CCCL_GLOBAL_CONSTANT __half pi_v<__half> = __numbers<__half>::__pi();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_pi_v<__half> = __numbers<__half>::__inv_pi();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_sqrtpi_v<__half> = __numbers<__half>::__inv_sqrtpi();
template <>
_CCCL_GLOBAL_CONSTANT __half ln2_v<__half> = __numbers<__half>::__ln2();
template <>
_CCCL_GLOBAL_CONSTANT __half ln10_v<__half> = __numbers<__half>::__ln10();
template <>
_CCCL_GLOBAL_CONSTANT __half sqrt2_v<__half> = __numbers<__half>::__sqrt2();
template <>
_CCCL_GLOBAL_CONSTANT __half sqrt3_v<__half> = __numbers<__half>::__sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __half inv_sqrt3_v<__half> = __numbers<__half>::__inv_sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __half egamma_v<__half> = __numbers<__half>::__egamma();
template <>
_CCCL_GLOBAL_CONSTANT __half phi_v<__half> = __numbers<__half>::__phi();
#  endif // _LIBCUDACXX_HAS_NVFP16()

#  if _LIBCUDACXX_HAS_NVBF16()
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 e_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__e();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 log2e_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__log2e();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 log10e_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__log10e();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 pi_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__pi();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 inv_pi_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__inv_pi();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 inv_sqrtpi_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__inv_sqrtpi();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 ln2_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__ln2();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 ln10_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__ln10();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 sqrt2_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__sqrt2();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 sqrt3_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 inv_sqrt3_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__inv_sqrt3();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 egamma_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__egamma();
template <>
_CCCL_GLOBAL_CONSTANT __nv_bfloat16 phi_v<__nv_bfloat16> = __numbers<__nv_bfloat16>::__phi();
#  endif // _LIBCUDACXX_HAS_NVBF16()
#endif // !_CCCL_COMPILER(MSVC)

inline constexpr double e          = __numbers<double>::__e();
inline constexpr double log2e      = __numbers<double>::__log2e();
inline constexpr double log10e     = __numbers<double>::__log10e();
inline constexpr double pi         = __numbers<double>::__pi();
inline constexpr double inv_pi     = __numbers<double>::__inv_pi();
inline constexpr double inv_sqrtpi = __numbers<double>::__inv_sqrtpi();
inline constexpr double ln2        = __numbers<double>::__ln2();
inline constexpr double ln10       = __numbers<double>::__ln10();
inline constexpr double sqrt2      = __numbers<double>::__sqrt2();
inline constexpr double sqrt3      = __numbers<double>::__sqrt3();
inline constexpr double inv_sqrt3  = __numbers<double>::__inv_sqrt3();
inline constexpr double egamma     = __numbers<double>::__egamma();
inline constexpr double phi        = __numbers<double>::__phi();

} // namespace numbers

_LIBCUDACXX_END_NAMESPACE_STD

#endif // _CUDA_STD_NUMBERS
