//===----------------------------------------------------------------------===//
//
// Part of libcu++, the C++ Standard Library for your entire system,
// under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
// SPDX-FileCopyrightText: Copyright (c) 2023 NVIDIA CORPORATION & AFFILIATES.
//
//===----------------------------------------------------------------------===//

// This header contains a preview of a portability system that enables
// CUDA C++ development with NVC++, NVCC, and supported host compilers.
// These interfaces are not guaranteed to be stable.

#ifndef __NV_TARGET_H
#define __NV_TARGET_H

// Identify the compiler processing this translation unit. The branches form a
// single #if/#elif chain, so at most one _NV_COMPILER_* macro is defined and
// an earlier match takes precedence over a later one. When no branch matches,
// none of the _NV_COMPILER_* macros is defined.
#if defined(__NVCC__) || defined(__CUDACC_RTC__)
// NVCC, or a compiler that defines __CUDACC_RTC__ (NVRTC).
#  define _NV_COMPILER_NVCC
#elif defined(__NVCOMPILER) && __cplusplus >= 201103L
// NVC++ (__NVCOMPILER) compiling C++11 or newer.
#  define _NV_COMPILER_NVCXX
#elif defined(__clang__) && defined(__CUDA__) && defined(__CUDA_ARCH__)
// clang compiling CUDA code, device mode.
#  define _NV_COMPILER_CLANG_CUDA
#endif

// _NV_TARGET_CPP11 is defined when the language level is at least C++11.
// The level is read from __cplusplus, or from _MSVC_LANG when _MSC_VER is
// defined.
#if ((defined(__cplusplus) && __cplusplus >= 201103L) || (defined(_MSC_VER) && _MSVC_LANG >= 201103L))
#  define _NV_TARGET_CPP11
#endif

// Hide `if target` support from NVRTC
#if defined(_NV_TARGET_CPP11) && !defined(__CUDACC_RTC__)

// _NV_BITSET_ATTRIBUTE expands to the [[nv::__target_bitset]] attribute when
// _NV_COMPILER_NVCXX is defined and to nothing otherwise. It decorates the
// struct that stores a set of targets.
#  if defined(_NV_COMPILER_NVCXX)
#    define _NV_BITSET_ATTRIBUTE [[nv::__target_bitset]]
#  else
#    define _NV_BITSET_ATTRIBUTE
#  endif

namespace nv
{
namespace target
{
namespace detail
{

// Unsigned integer type that holds a set of targets, one bit per target.
using base_int_t = unsigned long long;

// Bit 0 represents the host; hosts are not differentiated any further.
constexpr base_int_t all_hosts = base_int_t(1) << 0;

// One bit per NVIDIA GPU architecture, handed out in ascending architecture
// order starting at bit 1.
constexpr base_int_t sm_35_bit  = base_int_t(1) << 1;
constexpr base_int_t sm_37_bit  = base_int_t(1) << 2;
constexpr base_int_t sm_50_bit  = base_int_t(1) << 3;
constexpr base_int_t sm_52_bit  = base_int_t(1) << 4;
constexpr base_int_t sm_53_bit  = base_int_t(1) << 5;
constexpr base_int_t sm_60_bit  = base_int_t(1) << 6;
constexpr base_int_t sm_61_bit  = base_int_t(1) << 7;
constexpr base_int_t sm_62_bit  = base_int_t(1) << 8;
constexpr base_int_t sm_70_bit  = base_int_t(1) << 9;
constexpr base_int_t sm_72_bit  = base_int_t(1) << 10;
constexpr base_int_t sm_75_bit  = base_int_t(1) << 11;
constexpr base_int_t sm_80_bit  = base_int_t(1) << 12;
constexpr base_int_t sm_86_bit  = base_int_t(1) << 13;
constexpr base_int_t sm_87_bit  = base_int_t(1) << 14;
constexpr base_int_t sm_89_bit  = base_int_t(1) << 15;
constexpr base_int_t sm_90_bit  = base_int_t(1) << 16;
constexpr base_int_t sm_100_bit = base_int_t(1) << 17;
constexpr base_int_t sm_101_bit = base_int_t(1) << 18;
constexpr base_int_t sm_103_bit = base_int_t(1) << 19;
constexpr base_int_t sm_110_bit = base_int_t(1) << 20;
constexpr base_int_t sm_120_bit = base_int_t(1) << 21;

// Union of every GPU architecture bit declared above (the host bit excluded).
constexpr base_int_t all_devices =
  sm_35_bit | sm_37_bit                              //
  | sm_50_bit | sm_52_bit | sm_53_bit                //
  | sm_60_bit | sm_61_bit | sm_62_bit                //
  | sm_70_bit | sm_72_bit | sm_75_bit                //
  | sm_80_bit | sm_86_bit | sm_87_bit | sm_89_bit    //
  | sm_90_bit                                        //
  | sm_100_bit | sm_101_bit | sm_103_bit             //
  | sm_110_bit | sm_120_bit;

// Store a set of targets as a set of bits
struct _NV_BITSET_ATTRIBUTE target_description
{
  base_int_t targets;

  constexpr target_description(base_int_t a)
      : targets(a)
  {}
};

// User-facing identifiers for the NVIDIA GPU targets. Each enumerator's value
// is the number in its name (sm_XY has the value XY), so selectors can be
// ordered by comparing their values.
enum class sm_selector : base_int_t
{
  sm_35 = 35, sm_37 = 37,
  sm_50 = 50, sm_52 = 52, sm_53 = 53,
  sm_60 = 60, sm_61 = 61, sm_62 = 62,
  sm_70 = 70, sm_72 = 72, sm_75 = 75,
  sm_80 = 80, sm_86 = 86, sm_87 = 87, sm_89 = 89,
  sm_90 = 90,
  sm_100 = 100, sm_101 = 101, sm_103 = 103,
  sm_110 = 110,
  sm_120 = 120,
};

// Returns the numeric value carried by a selector (e.g. 80 for sm_80).
constexpr base_int_t toint(sm_selector selector)
{
  return static_cast<base_int_t>(selector);
}

// Maps a selector to the single bit reserved for exactly that architecture.
// Yields 0 (the empty set) when the value matches none of the enumerators.
// Written as one conditional chain so it stays a valid C++11 constexpr
// function (single return statement).
constexpr base_int_t bitexact(sm_selector a)
{
  return a == sm_selector::sm_35  ? sm_35_bit
       : a == sm_selector::sm_37  ? sm_37_bit
       : a == sm_selector::sm_50  ? sm_50_bit
       : a == sm_selector::sm_52  ? sm_52_bit
       : a == sm_selector::sm_53  ? sm_53_bit
       : a == sm_selector::sm_60  ? sm_60_bit
       : a == sm_selector::sm_61  ? sm_61_bit
       : a == sm_selector::sm_62  ? sm_62_bit
       : a == sm_selector::sm_70  ? sm_70_bit
       : a == sm_selector::sm_72  ? sm_72_bit
       : a == sm_selector::sm_75  ? sm_75_bit
       : a == sm_selector::sm_80  ? sm_80_bit
       : a == sm_selector::sm_86  ? sm_86_bit
       : a == sm_selector::sm_87  ? sm_87_bit
       : a == sm_selector::sm_89  ? sm_89_bit
       : a == sm_selector::sm_90  ? sm_90_bit
       : a == sm_selector::sm_100 ? sm_100_bit
       : a == sm_selector::sm_101 ? sm_101_bit
       : a == sm_selector::sm_103 ? sm_103_bit
       : a == sm_selector::sm_110 ? sm_110_bit
       : a == sm_selector::sm_120 ? sm_120_bit
                                  : 0;
}

// Maps a selector to the bit of the highest known architecture whose number
// does not exceed the selector's value. Yields 0 when the value is below the
// lowest known architecture (sm_35). The checks run from newest to oldest so
// the first match is the closest architecture at or below the selector.
constexpr base_int_t bitrounddown(sm_selector a)
{
  return a >= sm_selector::sm_120 ? sm_120_bit
       : a >= sm_selector::sm_110 ? sm_110_bit
       : a >= sm_selector::sm_103 ? sm_103_bit
       : a >= sm_selector::sm_101 ? sm_101_bit
       : a >= sm_selector::sm_100 ? sm_100_bit
       : a >= sm_selector::sm_90  ? sm_90_bit
       : a >= sm_selector::sm_89  ? sm_89_bit
       : a >= sm_selector::sm_87  ? sm_87_bit
       : a >= sm_selector::sm_86  ? sm_86_bit
       : a >= sm_selector::sm_80  ? sm_80_bit
       : a >= sm_selector::sm_75  ? sm_75_bit
       : a >= sm_selector::sm_72  ? sm_72_bit
       : a >= sm_selector::sm_70  ? sm_70_bit
       : a >= sm_selector::sm_62  ? sm_62_bit
       : a >= sm_selector::sm_61  ? sm_61_bit
       : a >= sm_selector::sm_60  ? sm_60_bit
       : a >= sm_selector::sm_53  ? sm_53_bit
       : a >= sm_selector::sm_52  ? sm_52_bit
       : a >= sm_selector::sm_50  ? sm_50_bit
       : a >= sm_selector::sm_37  ? sm_37_bit
       : a >= sm_selector::sm_35  ? sm_35_bit
                                  : 0;
}

// Public API for NVIDIA GPUs

// Builds the target set that contains only the architecture named by `a`
// (the empty set if `a` matches no known architecture).
constexpr target_description is_exactly(sm_selector a)
{
  return target_description{bitexact(a)};
}

// Builds the target set holding the architecture `a` rounds down to together
// with every architecture assigned a higher bit.
//
// Subtracting 1 from the single rounded-down bit sets all bits below it; the
// complement therefore keeps that bit and everything above, and the mask with
// all_devices discards bits that do not belong to a device. When the
// rounded-down bit is 0 the subtraction wraps to all ones, so the result is
// the empty set.
constexpr target_description provides(sm_selector a)
{
  return target_description{all_devices & ~(bitrounddown(a) - 1)};
}

// Boolean operations on target sets

// Intersection: the targets contained in both `a` and `b`.
constexpr target_description operator&&(target_description a, target_description b)
{
  return target_description{a.targets & b.targets};
}

// Union: the targets contained in `a`, in `b`, or in both.
constexpr target_description operator||(target_description a, target_description b)
{
  return target_description{a.targets | b.targets};
}

// Complement: every known target (host and devices) that is not in `a`.
// The mask keeps the result within the bits that stand for real targets.
constexpr target_description operator!(target_description a)
{
  return target_description{(all_hosts | all_devices) & ~a.targets};
}
} // namespace detail

// Re-export the selector and target-set types into nv::target.
using detail::sm_selector;
using detail::target_description;

// Predefined target sets for basic host/device selection:
// host only, all devices, everything, and nothing.
constexpr target_description is_host(detail::all_hosts);
constexpr target_description is_device(detail::all_devices);
constexpr target_description any_target(detail::all_hosts | detail::all_devices);
constexpr target_description no_target(0);

// Public, unscoped names for the NVIDIA GPU architectures. Each constant
// simply aliases the enumerator of the same name.
constexpr sm_selector sm_35  = detail::sm_selector::sm_35;
constexpr sm_selector sm_37  = detail::sm_selector::sm_37;
constexpr sm_selector sm_50  = detail::sm_selector::sm_50;
constexpr sm_selector sm_52  = detail::sm_selector::sm_52;
constexpr sm_selector sm_53  = detail::sm_selector::sm_53;
constexpr sm_selector sm_60  = detail::sm_selector::sm_60;
constexpr sm_selector sm_61  = detail::sm_selector::sm_61;
constexpr sm_selector sm_62  = detail::sm_selector::sm_62;
constexpr sm_selector sm_70  = detail::sm_selector::sm_70;
constexpr sm_selector sm_72  = detail::sm_selector::sm_72;
constexpr sm_selector sm_75  = detail::sm_selector::sm_75;
constexpr sm_selector sm_80  = detail::sm_selector::sm_80;
constexpr sm_selector sm_86  = detail::sm_selector::sm_86;
constexpr sm_selector sm_87  = detail::sm_selector::sm_87;
constexpr sm_selector sm_89  = detail::sm_selector::sm_89;
constexpr sm_selector sm_90  = detail::sm_selector::sm_90;
constexpr sm_selector sm_100 = detail::sm_selector::sm_100;
constexpr sm_selector sm_101 = detail::sm_selector::sm_101;
constexpr sm_selector sm_103 = detail::sm_selector::sm_103;
constexpr sm_selector sm_110 = detail::sm_selector::sm_110;
constexpr sm_selector sm_120 = detail::sm_selector::sm_120;

// Re-export the architecture predicates into nv::target.
using detail::is_exactly;
using detail::provides;
} // namespace target
} // namespace nv

#endif // defined(_NV_TARGET_CPP11) && !defined(__CUDACC_RTC__)

#include <nv/detail/__target_macros>

#endif // __NV_TARGET_H
