// -*- C++ -*-
//===--------------------------- semaphore --------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#ifndef _LIBCUDACXX_SEMAPHORE
#define _LIBCUDACXX_SEMAPHORE

/*
    semaphore synopsis

namespace std {

template<ptrdiff_t least_max_value = implementation-defined>
class counting_semaphore
{
public:
static constexpr ptrdiff_t max() noexcept;

constexpr explicit counting_semaphore(ptrdiff_t desired);
~counting_semaphore();

counting_semaphore(const counting_semaphore&) = delete;
counting_semaphore& operator=(const counting_semaphore&) = delete;

void release(ptrdiff_t __update = 1);
void acquire();
bool try_acquire() noexcept;
template<class Rep, class Period>
    bool try_acquire_for(const chrono::duration<Rep, Period>& __rel_time);
template<class Clock, class Duration>
    bool try_acquire_until(const chrono::time_point<Clock, Duration>& __abs_time);

private:
ptrdiff_t counter; // exposition only
};

using binary_semaphore = counting_semaphore<1>;

}

*/

#include <cuda/std/detail/__config>

#if defined(_CCCL_IMPLICIT_SYSTEM_HEADER_GCC)
#  pragma GCC system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_CLANG)
#  pragma clang system_header
#elif defined(_CCCL_IMPLICIT_SYSTEM_HEADER_MSVC)
#  pragma system_header
#endif // no system header

#include <cuda/std/__type_traits/conditional.h>
#include <cuda/std/atomic>
#include <cuda/std/detail/libcxx/include/__assert> // all public C++ headers provide the assertion handler
#include <cuda/std/type_traits>

_CCCL_PUSH_MACROS

#ifdef _LIBCUDACXX_HAS_NO_THREADS
#  error <semaphore> is not supported on this single threaded system
#endif

_LIBCUDACXX_BEGIN_NAMESPACE_STD

// Semaphore implemented on a single atomic counter that holds the number of
// permits currently available.
//
// _Sco is forwarded to the underlying atomic object. __least_max_value is not
// consulted by this primary template: max() always reports
// numeric_limits<ptrdiff_t>::max().
template <thread_scope _Sco, ptrdiff_t __least_max_value>
class __atomic_semaphore_base
{
  // Retries the decrement until it succeeds or the counter is observed to be
  // zero. __old is the caller's latest observation of the counter; the loop
  // relies on a failed compare-exchange refreshing it.
  // Returns true if a permit was taken.
  _LIBCUDACXX_HIDE_FROM_ABI bool __fetch_sub_if_slow(ptrdiff_t __old)
  {
    while (__old != 0)
    {
      if (__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
      {
        return true;
      }
    }
    return false;
  }

  // Takes one permit if the counter is non-zero. Makes a single attempt
  // inline and hands over to __fetch_sub_if_slow only when that attempt loses
  // a race (or fails spuriously).
  _LIBCUDACXX_HIDE_FROM_ABI bool __fetch_sub_if()
  {
    ptrdiff_t __old = __count.load(memory_order_acquire);
    if (__old == 0)
    {
      return false;
    }
    if (__count.compare_exchange_weak(__old, __old - 1, memory_order_acquire, memory_order_relaxed))
    {
      return true;
    }
    return __fetch_sub_if_slow(__old); // fail only if not __available
  }

  // Returns once the counter has been observed to be non-zero. This does not
  // take a permit; the caller has to retry the acquisition afterwards.
  _LIBCUDACXX_HIDE_FROM_ABI void __wait_slow()
  {
    for (;;)
    {
      ptrdiff_t const __old = __count.load(memory_order_acquire);
      if (__old != 0)
      {
        break;
      }
      // __old is 0 here, so this waits on the "no permits" value.
      __count.wait(__old, memory_order_relaxed);
    }
  }

  // Repeatedly tries to take a permit via __libcpp_thread_poll_with_backoff.
  // NOTE(review): presumably the helper re-invokes the predicate with backoff
  // until it returns true or __rel_time has elapsed, and returns whether the
  // predicate succeeded - confirm against its definition.
  _LIBCUDACXX_HIDE_FROM_ABI bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
  {
    return __libcpp_thread_poll_with_backoff(
      [this]() {
        ptrdiff_t const __old = __count.load(memory_order_acquire);
        return __old != 0 && __fetch_sub_if_slow(__old);
      },
      __rel_time);
  }

  // Number of permits currently available.
  __atomic_impl<ptrdiff_t, _Sco> __count;

public:
  // Largest counter value this implementation reports as supported.
  _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
  {
    return numeric_limits<ptrdiff_t>::max();
  }

  // Creates the semaphore with __count permits initially available.
  _LIBCUDACXX_HIDE_FROM_ABI constexpr __atomic_semaphore_base(ptrdiff_t __count) noexcept
      : __count(__count)
  {}

  _CCCL_HIDE_FROM_ABI ~__atomic_semaphore_base() = default;

  __atomic_semaphore_base(__atomic_semaphore_base const&)            = delete;
  __atomic_semaphore_base& operator=(__atomic_semaphore_base const&) = delete;

  // Adds __update permits, then notifies waiters: all of them when more than
  // one permit was added, a single one otherwise.
  _LIBCUDACXX_HIDE_FROM_ABI void release(ptrdiff_t __update = 1)
  {
    __count.fetch_add(__update, memory_order_release);
    if (__update > 1)
    {
      __count.notify_all();
    }
    else
    {
      __count.notify_one();
    }
  }

  // Takes one permit, waiting for as long as none is available.
  _LIBCUDACXX_HIDE_FROM_ABI void acquire()
  {
    while (!try_acquire())
    {
      __wait_slow();
    }
  }

  // Takes one permit without waiting. Returns false if the counter was zero.
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire() noexcept
  {
    return __fetch_sub_if();
  }

  // Takes one permit, polling until __abs_time if none is immediately
  // available. Returns whether a permit was taken.
  template <class _Clock, class _Duration>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_until(chrono::time_point<_Clock, _Duration> const& __abs_time)
  {
    if (try_acquire())
    {
      return true;
    }
    // Explicit cast so that durations which do not convert implicitly to
    // nanoseconds (floating-point or sub-nanosecond representations) work too.
    return __acquire_slow_timed(chrono::duration_cast<chrono::nanoseconds>(__abs_time - _Clock::now()));
  }

  // Takes one permit, polling for up to __rel_time if none is immediately
  // available. Returns whether a permit was taken.
  template <class _Rep, class _Period>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_for(chrono::duration<_Rep, _Period> const& __rel_time)
  {
    if (try_acquire())
    {
      return true;
    }
    // Explicit cast so that durations which do not convert implicitly to
    // nanoseconds (floating-point or sub-nanosecond representations) work too.
    return __acquire_slow_timed(chrono::duration_cast<chrono::nanoseconds>(__rel_time));
  }
};

#ifndef _LIBCUDACXX_USE_NATIVE_SEMAPHORES

// Specialization for a least-max value of 1: the state is a single atomic flag
// that is 1 while the permit is available and 0 while it is taken.
template <thread_scope _Sco>
class __atomic_semaphore_base<_Sco, 1>
{
  // Repeatedly calls try_acquire() via __libcpp_thread_poll_with_backoff.
  // NOTE(review): presumably the helper re-invokes the predicate with backoff
  // until it returns true or __rel_time has elapsed, and returns whether the
  // predicate succeeded - confirm against its definition.
  _LIBCUDACXX_HIDE_FROM_ABI bool __acquire_slow_timed(chrono::nanoseconds const& __rel_time)
  {
    return __libcpp_thread_poll_with_backoff(
      [this]() {
        return try_acquire();
      },
      __rel_time);
  }

  // 1 while the permit is available, 0 while it is taken.
  __atomic_impl<int, _Sco> __available;

public:
  // This implementation holds at most one permit.
  _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
  {
    return 1;
  }

  // Creates the semaphore with __available permits initially available.
  // The value is stored in an int, so the conversion is spelled out instead of
  // being left to an implicit narrowing.
  _LIBCUDACXX_HIDE_FROM_ABI constexpr __atomic_semaphore_base(ptrdiff_t __available) noexcept
      : __available(static_cast<int>(__available))
  {}

  _CCCL_HIDE_FROM_ABI ~__atomic_semaphore_base() = default;

  __atomic_semaphore_base(__atomic_semaphore_base const&)            = delete;
  __atomic_semaphore_base& operator=(__atomic_semaphore_base const&) = delete;

  // Makes the permit available and notifies one waiter. Only __update == 1 is
  // accepted (checked by the assertion); the flag is stored, not incremented.
  _LIBCUDACXX_HIDE_FROM_ABI void release(ptrdiff_t __update = 1)
  {
    _LIBCUDACXX_ASSERT(__update == 1, "");
    __available.store(1, memory_order_release);
    __available.notify_one();
    (void) __update; // only read by the assertion above
  }

  // Takes the permit, waiting on the flag for as long as it reads 0.
  _LIBCUDACXX_HIDE_FROM_ABI void acquire()
  {
    while (!try_acquire())
    {
      __available.wait(0, memory_order_relaxed);
    }
  }

  // Takes the permit without waiting: unconditionally swaps in 0 and reports
  // whether the previous value was 1.
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire() noexcept
  {
    return 1 == __available.exchange(0, memory_order_acquire);
  }

  // Takes the permit, polling until __abs_time if it is not immediately
  // available. Returns whether the permit was taken.
  template <class _Clock, class _Duration>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_until(chrono::time_point<_Clock, _Duration> const& __abs_time)
  {
    if (try_acquire())
    {
      return true;
    }
    // Explicit cast so that durations which do not convert implicitly to
    // nanoseconds (floating-point or sub-nanosecond representations) work too.
    return __acquire_slow_timed(chrono::duration_cast<chrono::nanoseconds>(__abs_time - _Clock::now()));
  }

  // Takes the permit, polling for up to __rel_time if it is not immediately
  // available. Returns whether the permit was taken.
  template <class _Rep, class _Period>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_for(chrono::duration<_Rep, _Period> const& __rel_time)
  {
    if (try_acquire())
    {
      return true;
    }
    // Explicit cast so that durations which do not convert implicitly to
    // nanoseconds (floating-point or sub-nanosecond representations) work too.
    return __acquire_slow_timed(chrono::duration_cast<chrono::nanoseconds>(__rel_time));
  }
};

#else

// Semaphore backed by a native __libcpp_semaphore_t, optionally combined with
// two atomic side buffers:
//
//  * __frontbuffer (unless _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER is
//    defined): bits 32 and up count permits that can be taken without touching
//    the native semaphore; the low 32 bits count threads that have registered
//    as waiters. NOTE(review): this layout assumes ptrdiff_t is at least
//    64 bits wide - confirm for every supported target.
//  * __backbuffer (unless _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER is defined):
//    permits from bulk releases that have not been posted to the native
//    semaphore yet. Every successful waiter forwards up to two of them.
template <thread_scope _Sco>
class __sem_semaphore_base
{
  // Called after a wait on the native semaphore completed with result
  // __success. On success, moves up to two permits from the back buffer to the
  // native semaphore. Returns __success unchanged.
  _LIBCUDACXX_HIDE_FROM_ABI bool __backfill(bool __success)
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
    if (__success)
    {
      // Optimistically claim two permits; the returned previous value tells
      // how many of them really existed.
      auto const __back_amount = __backbuffer.fetch_sub(2, memory_order_acquire);
      bool const __post_one    = __back_amount > 0;
      bool const __post_two    = __back_amount > 1;
      auto const __posted =
        (!__post_one || __libcpp_semaphore_post(&__semaphore)) && (!__post_two || __libcpp_semaphore_post(&__semaphore));
      _LIBCUDACXX_ASSERT(__posted, "");
      (void) __posted; // only read by the assertion above
      if (!__post_one || !__post_two)
      {
        // Give back whatever was claimed but did not exist.
        __backbuffer.fetch_add(!__post_one ? 2 : 1, memory_order_relaxed);
      }
    }
#  endif
    return __success;
  }

  // Tries to take a permit from the front buffer. Returns true if a permit was
  // taken. Returns false if the caller has to wait on the native semaphore; in
  // that case (with the front buffer enabled) the caller has been registered
  // as a waiter and must call __try_done afterwards.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_acquire_fast()
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER

    ptrdiff_t __old;
    // Briefly poll for a permit to show up in the front buffer.
    __libcpp_thread_poll_with_backoff(
      [&]() {
        __old = __frontbuffer.load(memory_order_relaxed);
        return 0 != (__old >> 32);
      },
      chrono::microseconds(5));

    // always steal if you can
    while (__old >> 32)
    {
      if (__frontbuffer.compare_exchange_weak(__old, __old - (1ll << 32), memory_order_acquire))
      {
        return true;
      }
    }
    // record we're waiting
    __old = __frontbuffer.fetch_add(1ll, memory_order_release);
    // ALWAYS steal if you can!
    while (__old >> 32)
    {
      if (__frontbuffer.compare_exchange_weak(__old, __old - (1ll << 32), memory_order_acquire))
      {
        break;
      }
    }
    // not going to wait after all
    if (__old >> 32)
    {
      return __try_done(true);
    }
#  endif
    // the wait has begun...
    return false;
  }

  // Ends a wait: removes the caller from the waiter count (front buffer
  // enabled) and runs __backfill with the outcome of the wait.
  _LIBCUDACXX_HIDE_FROM_ABI bool __try_done(bool __success)
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
    // record we're NOT waiting
    __frontbuffer.fetch_sub(1ll, memory_order_release);
#  endif
    return __backfill(__success);
  }

  // Releases __post_amount permits through the native semaphore. With the back
  // buffer enabled at most two permits are posted directly and the remainder
  // is parked in __backbuffer; otherwise every permit is posted individually.
  _LIBCUDACXX_HIDE_FROM_ABI void __release_slow(ptrdiff_t __post_amount)
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
    bool const __post_one = __post_amount > 0;
    bool const __post_two = __post_amount > 1;
    if (__post_amount > 2)
    {
      __backbuffer.fetch_add(__post_amount - 2, memory_order_acq_rel);
    }
    auto const __success =
      (!__post_one || __libcpp_semaphore_post(&__semaphore)) && (!__post_two || __libcpp_semaphore_post(&__semaphore));
    _LIBCUDACXX_ASSERT(__success, "");
    (void) __success; // only read by the assertion above
#  else
    for (; __post_amount; --__post_amount)
    {
      auto const __success = __libcpp_semaphore_post(&__semaphore);
      _LIBCUDACXX_ASSERT(__success, "");
      (void) __success; // only read by the assertion above
    }
#  endif
  }

  __libcpp_semaphore_t __semaphore;
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
  __atomic_impl<ptrdiff_t, _Sco> __frontbuffer;
#  endif
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
  __atomic_impl<ptrdiff_t, _Sco> __backbuffer;
#  endif

public:
  // Largest counter value this implementation reports as supported.
  _LIBCUDACXX_HIDE_FROM_ABI static constexpr ptrdiff_t max() noexcept
  {
    return _LIBCUDACXX_SEMAPHORE_MAX;
  }

  // Creates the semaphore with __count permits. With the front buffer enabled
  // the initial permits live in the front buffer and the native semaphore
  // starts at 0; otherwise the native semaphore starts at __count.
  _LIBCUDACXX_HIDE_FROM_ABI __sem_semaphore_base(ptrdiff_t __count = 0)
      : __semaphore()
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
      , __frontbuffer(__count << 32)
#  endif
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_BACK_BUFFER
      , __backbuffer(0)
#  endif
  {
    _LIBCUDACXX_ASSERT(__count <= max(), "");
    auto const __success =
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
      __libcpp_semaphore_init(&__semaphore, 0);
#  else
      __libcpp_semaphore_init(&__semaphore, __count);
#  endif
    _LIBCUDACXX_ASSERT(__success, "");
    (void) __success; // only read by the assertion above
  }

  // Destroys the native semaphore. Asserts that no waiter is still registered
  // in the front buffer.
  _LIBCUDACXX_HIDE_FROM_ABI ~__sem_semaphore_base()
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
    _LIBCUDACXX_ASSERT(0 == (__frontbuffer.load(memory_order_relaxed) & ~0u), "");
#  endif
    auto const __success = __libcpp_semaphore_destroy(&__semaphore);
    _LIBCUDACXX_ASSERT(__success, "");
    (void) __success; // only read by the assertion above
  }

  __sem_semaphore_base(const __sem_semaphore_base&)            = delete;
  __sem_semaphore_base& operator=(const __sem_semaphore_base&) = delete;

  // Releases __update permits. With the front buffer enabled the permits are
  // first offered to the front buffer; if that fast path is not taken they go
  // through __release_slow.
  _LIBCUDACXX_HIDE_FROM_ABI void release(ptrdiff_t __update = 1)
  {
#  ifndef _LIBCUDACXX_HAS_NO_SEMAPHORE_FRONT_BUFFER
    // boldly assume the semaphore is taken but uncontended
    ptrdiff_t __old = 0;
    // try to fast-release as long as it's uncontended
    // NOTE(review): `~0ul` is as wide as unsigned long. Where that type is
    // 64 bits this tests the whole word, so the fast path is only taken while
    // the front buffer is exactly 0, whereas the destructor masks the waiter
    // bits with `~0u`. Confirm which mask is intended before changing it.
    while (0 == (__old & ~0ul))
    {
      if (__frontbuffer.compare_exchange_weak(__old, __old + (__update << 32), memory_order_acq_rel))
      {
        return;
      }
    }
#  endif
    // slow-release it is
    __release_slow(__update);
  }

  // Takes one permit: tries the fast path first, otherwise waits on the native
  // semaphore and then finishes the wait through __try_done.
  _LIBCUDACXX_HIDE_FROM_ABI void acquire()
  {
    if (!__try_acquire_fast())
    {
      __try_done(__libcpp_semaphore_wait(&__semaphore));
    }
  }

  // Equivalent to try_acquire_for with a zero timeout.
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire() noexcept
  {
    return try_acquire_for(chrono::nanoseconds(0));
  }

  // Tries to take one permit, waiting no longer than until __abs_time.
  // Returns whether a permit was taken.
  template <class _Clock, class _Duration>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_until(chrono::time_point<_Clock, _Duration> const& __abs_time)
  {
    auto const __now = _Clock::now();
    // A deadline that has already passed is treated as a zero timeout rather
    // than being turned into a negative duration.
    if (__abs_time <= __now)
    {
      return try_acquire_for(chrono::nanoseconds(0));
    }
    return try_acquire_for(chrono::duration_cast<chrono::nanoseconds>(__abs_time - __now));
  }

  // Tries to take one permit: fast path first, otherwise a timed wait on the
  // native semaphore bounded by __rel_time. Returns whether a permit was taken.
  template <class _Rep, class _Period>
  _LIBCUDACXX_HIDE_FROM_ABI bool try_acquire_for(chrono::duration<_Rep, _Period> const& __rel_time)
  {
    return __try_acquire_fast() || __try_done(__libcpp_semaphore_wait_timed(&__semaphore, __rel_time));
  }
};

#endif // _LIBCUDACXX_USE_NATIVE_SEMAPHORES

// Implementation type used for a semaphore with the given least-max value and
// thread scope.
#ifdef _LIBCUDACXX_USE_NATIVE_SEMAPHORES
// The native implementation is chosen whenever its max() covers
// __least_max_value; the atomic-counter implementation is the fallback.
template <ptrdiff_t __least_max_value, thread_scope _Sco>
using __semaphore_base =
  __conditional_t<(__least_max_value <= __sem_semaphore_base<_Sco>::max()),
                  __sem_semaphore_base<_Sco>,
                  __atomic_semaphore_base<_Sco, __least_max_value>>;
#else
// Without native semaphores the atomic-counter implementation is always used.
template <ptrdiff_t __least_max_value, thread_scope _Sco>
using __semaphore_base = __atomic_semaphore_base<_Sco, __least_max_value>;
#endif

// Counting semaphore at system thread scope. All operations (release, acquire,
// try_acquire, try_acquire_for, try_acquire_until, max) are inherited from the
// implementation selected by __semaphore_base.
template <ptrdiff_t __least_max_value = INT_MAX>
class counting_semaphore : public __semaphore_base<__least_max_value, thread_scope_system>
{
  // A negative least-max value makes the program ill-formed, so reject it here
  // instead of silently instantiating an implementation for it.
  static_assert(__least_max_value >= 0, "counting_semaphore: the least maximum value must be non-negative");
  static_assert(__least_max_value <= __semaphore_base<__least_max_value, thread_scope_system>::max(),
                "counting_semaphore: the least maximum value exceeds what the implementation supports");

public:
  // Creates the semaphore with __count permits initially available.
  _LIBCUDACXX_HIDE_FROM_ABI constexpr counting_semaphore(ptrdiff_t __count = 0)
      : __semaphore_base<__least_max_value, thread_scope_system>(__count)
  {}
  _CCCL_HIDE_FROM_ABI ~counting_semaphore() = default;

  counting_semaphore(const counting_semaphore&)            = delete;
  counting_semaphore& operator=(const counting_semaphore&) = delete;
};

// Counting semaphore whose least maximum value is 1.
using binary_semaphore = counting_semaphore<1>;

_LIBCUDACXX_END_NAMESPACE_STD

#include <cuda/std/__cuda/semaphore.h>
_CCCL_POP_MACROS

#endif //_LIBCUDACXX_SEMAPHORE
