/*
Module : CPPCThreadPool.h
Purpose: Provides a thread pool framework based on the thread_pool class from the book C++ Concurrency In 
         Action by Anthony Williams
History: PJN / 08-11-2020 1. Initial Public release.
         PJN / 21-12-2020 1. Fixed a bug in CThreadPool::OtherQueuesAreEmpty where the return value was
                          being calculated incorrectly. This bug was causing the threads in the thread pool
                          to do busy waits instead of waiting correctly on the "_TasksAvailable" condition 
                          variable.
                          2. Fixed an issue in the CThreadPool constructor where a lock was not taken prior 
                          to modifying the _Dones array.
                          3. Fixed an issue in CThreadPool::pause where a lock was not taken prior to 
                          modifying the _Paused variable.
                          4. Fixed an issue in CThreadPool::resize where a lock was not taken prior to
                          modifying the _Dones array.
         PJN / 13-02-2022 1. Updated the code to use C++ uniform initialization for all variable declarations
         PJN / 21-03-2023 1. Optimized code to emplace_back call in resize method.
                          2. Updated copyright details.
                          3. The thread pool now does not start by default when you call the CThreadPool
                          constructor. Instead now client code is expected to call a new "Start" method. This
                          breaking change was implemented to allow various virtual functions in the 
                          CThreadPool framework to work correctly as calling virtual functions is of course not
                          supported from C++ constructors.
         PJN / 18-05-2023 1. Updated modules to indicate that it needs to be compiled using /std:c++17. Thanks
                          to Martin Richter for reporting this issue.

Copyright (c) 2020 - 2023 by PJ Naughter (Web: www.naughter.com, Email: pjna@naughter.com)

All rights reserved.

Copyright / Usage Details:

You are allowed to include the source code in any product (commercial, shareware, freeware or otherwise)
when your product is released in binary form. You are allowed to modify the source code in any way you want
except you cannot modify the copyright details at the top of each module. If you want to distribute source
code with your application, then you are only allowed to distribute versions released by the author. This is
to maintain a single distribution point for the source code.

*/


/////////////////////////// Macros / Defines //////////////////////////////////

#if defined(_MSC_VER) && (_MSC_VER > 1000)
#pragma once
#endif //#if defined(_MSC_VER) && (_MSC_VER > 1000)

//Verify the C++ language standard in use. _MSVC_LANG is only defined by Microsoft compilers (where __cplusplus
//is only accurate when compiling with /Zc:__cplusplus). Other compilers must be checked via __cplusplus instead,
//because an undefined identifier evaluates to 0 in a #if expression, which would otherwise always trigger the
//#error below on non-Microsoft compilers
#if defined(_MSVC_LANG)
#if _MSVC_LANG < 201703
#error CppConcurrency::CThreadPool requires a minimum C++ language standard of /std:c++17
#endif //#if _MSVC_LANG < 201703
#elif __cplusplus < 201703L
#error CppConcurrency::CThreadPool requires a minimum C++ language standard of C++17
#endif //#if defined(_MSVC_LANG)

#ifndef __CPPCTHREADPOOL_H__
#define __CPPCTHREADPOOL_H__

//CPPCTHREADPOOL_EXT_CLASS is placed in the declaration of the CFunctionWrapper class. Client code may predefine
//it before including this header (presumably to a DLL export/import specifier - TODO confirm); by default it
//expands to nothing
#ifndef CPPCTHREADPOOL_EXT_CLASS
#define CPPCTHREADPOOL_EXT_CLASS
#endif //#ifndef CPPCTHREADPOOL_EXT_CLASS


/////////////////////////// Includes //////////////////////////////////////////

#include <memory>
#include <atomic>
#include <thread>
#include <vector>
#include <queue>
#include <deque>
#include <mutex>
#include <future>
#include <functional>
#include <algorithm>
#include <condition_variable>
#include <type_traits>
#include <utility>
#include <cassert>


/////////////////////////// Classes ///////////////////////////////////////////

namespace CppConcurrency
{

/// @brief A move-only, type-erased wrapper around a callable which takes no arguments. Unlike std::function it does
/// not require the wrapped callable to be copyable, which allows it to hold a std::packaged_task (the thread pool
/// queues store CFunctionWrapper instances constructed from std::packaged_task objects).
class CPPCTHREADPOOL_EXT_CLASS CFunctionWrapper
{
  //Abstract interface through which the wrapped callable is invoked
  struct Implbase
  {
    Implbase() = default;
    Implbase(const Implbase&) = delete;
    Implbase(Implbase&&) = delete;
    virtual ~Implbase() = default;
    Implbase& operator=(const Implbase&) = delete;
    Implbase& operator=(Implbase&&) = delete;
    virtual void call() = 0;
  };
  std::unique_ptr<Implbase> _Impl;

  //Holds a callable of (already decayed) type F by value
  template<typename F>
  struct Impltype : Implbase
  {
    F _f;

    template<typename U>
    explicit Impltype(U&& f) : _f{std::forward<U>(f)}
    {
    }

    void call() final
    {
      _f();
    }
  };

public:
//Constructors / Destructors
  CFunctionWrapper() = default;
  CFunctionWrapper(const CFunctionWrapper&) = delete;
  CFunctionWrapper(CFunctionWrapper&&) noexcept = default;

  /// @brief Wraps a callable. Rvalue callables are moved into the wrapper and lvalue callables are copied (the
  /// stored callable is always of the decayed type, never a reference). The constraint stops this constructor
  /// from hijacking copy construction from a non-const CFunctionWrapper lvalue, which must remain deleted.
  /// Note this constructor is deliberately not explicit, as the thread pool relies on implicit conversion
  /// when pushing a std::packaged_task onto its queues.
  /// @param f The callable to wrap
  template<typename F, typename = std::enable_if_t<!std::is_same_v<std::decay_t<F>, CFunctionWrapper>>>
  CFunctionWrapper(F&& f) : _Impl{std::make_unique<Impltype<std::decay_t<F>>>(std::forward<F>(f))}
  {
  }
  ~CFunctionWrapper() = default;

//Methods
  CFunctionWrapper& operator=(const CFunctionWrapper&) = delete;
  CFunctionWrapper& operator=(CFunctionWrapper&& other) noexcept = default;

  /// @brief Invokes the wrapped callable. Must not be called on an empty (default constructed or moved from) wrapper
  void operator()()
  {
    assert(_Impl != nullptr);
    _Impl->call();
  }
};

class CThreadPool
{
protected:
//Member variables
#ifdef _WIN32
  bool _PumpWindowsMessages;
#endif //#ifdef _WIN32
  std::atomic<bool>_Paused;
  std::atomic<size_t> _ThreadsBusy;
  mutable std::mutex _MutexTasks;
  std::condition_variable _TasksAvailable;
  std::condition_variable _QueueNotFull;
  std::queue<CFunctionWrapper> _Tasks;
  std::vector<std::unique_ptr<std::atomic<bool>>> _Dones;
  std::vector<std::unique_ptr<std::deque<CFunctionWrapper>>> _Queues;
  std::vector<std::thread> _Threads;
  std::vector<std::promise<bool>> _PromiseInitInstanceThreads;
  bool _PermitTaskStealing;
  std::atomic<bool> _Started;
  size_t _StartCount;
  static thread_local std::deque<CFunctionWrapper>* _LocalQueue;
  static thread_local size_t _ThreadIndex;

//Methods

#ifdef _WIN32
#pragma warning(suppress: 26440)
/// @brief [Only valid for Windows compilation] pumps one windows message
  virtual bool PumpMessage()
  {
#ifdef _AFX
    return AfxGetApp()->PumpMessage();
#else
    MSG msg{};
    if (!GetMessage(&msg, nullptr, 0, 0))
      return false;

    //process this message
    TranslateMessage(&msg);
    DispatchMessage(&msg);

    return true;
#endif //#ifdef _AFX
  }

/// @brief [Only valid for Windows compilation] pumps windows messages while there are messages on the windows message queue
  virtual bool PumpMessages()
  {
    //Pump the message queue if necessary
#pragma warning(suppress: 26477)
    MSG msg{};
    while (PeekMessage(&msg, nullptr, 0, 0, PM_NOREMOVE))
    {
      if (!PumpMessage())
        return false;
    }

    return true;
  }
#endif //#ifdef _WIN32

/// @brief This method implements the actual worker thread used by the thread pool
/// @param threadIndex The thread index to be used for this worker thread
  virtual void WorkerThread(size_t threadIndex)
  {
    //Hive away the thread index in the thread local member variable
    _ThreadIndex = threadIndex;

    //Hive away the local queue in the thread local member variable
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
    _LocalQueue = _Queues[_ThreadIndex].get();

#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
    auto& _bDone{*(_Dones[_ThreadIndex])};

    //Call the virtual InitInstanceThread method to allow custom per thread initialization
    const bool initInstance{InitInstanceThread(_ThreadIndex)};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
    _PromiseInitInstanceThreads[_ThreadIndex].set_value(initInstance);
    if (!initInstance)
    {
      ExitInstanceThread(_ThreadIndex);
      return;
    }

    while (!_bDone)
    {
      std::unique_lock<std::mutex> lock{_MutexTasks};
#ifdef _WIN32
      if (_PumpWindowsMessages)
        PumpMessages();
      else
#endif //#ifdef _WIN32
        _TasksAvailable.wait(lock, [&]() noexcept { return _bDone || !_LocalQueue->empty() || !_Tasks.empty() || !OtherQueuesAreEmpty(threadIndex); });

      if (_bDone)
        break;
      if (_Paused)
      {
#ifdef _WIN32
        if (_PumpWindowsMessages)
          std::this_thread::yield();
#endif //#ifdef _WIN32
        continue;
      }
      else if (!_LocalQueue->empty())
      {
        auto task{std::move(_LocalQueue->front())};
        _LocalQueue->pop_front();
        lock.unlock();
        ++_ThreadsBusy;
        task();
        --_ThreadsBusy;
        _QueueNotFull.notify_all();
      }
      else if (!_Tasks.empty())
      {
        auto task{std::move(_Tasks.front())};
        _Tasks.pop();
        lock.unlock();
        ++_ThreadsBusy;
        task();
        --_ThreadsBusy;
        _QueueNotFull.notify_all();
      }
      else if (_PermitTaskStealing)
      {
        const auto nQueues{_Queues.size()};
        bool bFoundTask{false};
        for (size_t i=0; (i<nQueues) && !bFoundTask; ++i)
        {
          const size_t nIndex{(_ThreadIndex + i + 1) % nQueues};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
          auto& queue{_Queues[nIndex]};
          if (!queue->empty())
          {
            auto task{std::move(queue->back())};
            queue->pop_back();
            lock.unlock();
            bFoundTask = true;
            ++_ThreadsBusy;
            task();
            --_ThreadsBusy;
            _QueueNotFull.notify_all();
          }
        }
#ifdef _WIN32
        if (!bFoundTask && _PumpWindowsMessages)
          std::this_thread::yield();
#endif //#ifdef _WIN32
      }
#ifdef _WIN32
      else if (_PumpWindowsMessages)
        std::this_thread::yield();
#endif //#ifdef _WIN32
    }

    //Call the virtual ExitInstanceThread method to allow custom per thread cleanup
    ExitInstanceThread(_ThreadIndex);
  }

/// @brief Note that there is an assumption that this method will be called with the "_mutexTasks" already
/// locked for thread safety reasons. This is ok because this method is only called
/// from the condition_variable wait lambda in WorkerThread.
/// @param currentThreadIndex The thread index of the current thread to perform the search from. Note 
/// that this thread will be explicitly excluded from the search
/// @return true if all the other queues are empty else false
  bool OtherQueuesAreEmpty(size_t currentThreadIndex) const noexcept
  {
    const auto nQueues{_Queues.size()};
    for (size_t i=1; i<nQueues; ++i) //Note we start the loop at 1 instead of 0 to exclude checking the queue at index nCurrentThreadIndex
    {
      const size_t nIndex{(currentThreadIndex + i + 1) % nQueues};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
      auto& localQueue{_Queues[nIndex]};
      if (!localQueue->empty())
        return false;
    }
    return true;
  }

#ifdef _MSC_VER
#pragma warning(suppress: 26440)
#endif //#ifdef _MSC_VER
/// @brief This method is called when each worker thread is starting up. It is modelled on the CWinThread::InitInstance
/// mechanism in MFC. Derived classes of CThreadPool are free to override this method to implement their own custom
/// worker thread startup code
/// @param threadIndex The thread index of the current thread
/// @return true if you want to worker thread to start up or false if you would like this worker thread to fail
  virtual bool InitInstanceThread(size_t /*threadIndex*/)
  {
    return true;
  }

#ifdef _MSC_VER
#pragma warning(suppress: 26440)
#endif //#ifdef _MSC_VER
/// @brief This method is called when the worker thread is shutting down. It is modelled on the CWinThread::ExitInstance
/// mechanism in MFC. Derived classes of CThreadPool are free to override this method to implement their own
/// custom worker thread cleanup code
/// @param threadIndex The thread index of the current thread
  virtual void ExitInstanceThread(size_t /*threadIndex*/)
  {
  }

public:
//Constructors / Destructors

#ifdef _MSC_VER
#pragma warning(suppress: 26455)
#endif //#ifdef _MSC_VER
/// @brief Standard constructor for the class
/// @param threadCount The number of threads to create in the thread pool. If you provide the value 0, then the value return from std::thread::hardware_concurrency is actually used
/// @param permitTaskStealing Should worker threads in the thread pool by allowed to steal tasks from other worker threads
/// @param paused Should the worker threads by created in an initially "paused" state
/// @param pumpWindowsMessages [Only valid for Windows compilation] Should the worker thread pump windows messages
  CThreadPool(size_t threadCount = 0, bool permitTaskStealing = true, bool paused = false
#ifdef _WIN32
                                                                                           , bool pumpWindowsMessages = false) : _PumpWindowsMessages{pumpWindowsMessages},
#else
                                                                                                                             ) :
#endif //#ifdef _WIN32
                                                                                                                                 _Paused{paused},
                                                                                                                                 _ThreadsBusy{0},
                                                                                                                                 _PermitTaskStealing{permitTaskStealing},
                                                                                                                                 _Started{false},
                                                                                                                                 _StartCount{0}
  {
    if (threadCount == 0) //a thread count of 0 means use the thread count as returned from std::thread::hardware_concurrency()
      threadCount = std::thread::hardware_concurrency();
    _StartCount = threadCount; //Hive away the requested thread count
  }
  CThreadPool(const CThreadPool&) = delete;
  CThreadPool(CThreadPool&&) = delete;

///@brief Standard destructor for the class. Internally stop() will be called to stop and wait
/// for all the threads in the thread pool to exit.
  virtual ~CThreadPool()
  {
#ifdef _MSC_VER
#pragma warning(suppress: 26447)
#endif //#ifdef _MSC_VER
    stop();
  }

//Methods
  CThreadPool& operator=(const CThreadPool&) = delete;
  CThreadPool& operator=(CThreadPool&&) = delete;

/// @brief Returns how many threads are in the thread pool
  size_t size() const noexcept { return _Threads.size(); }

#ifdef _WIN32
/// @brief returns true if the worker threads will be pumping windows messages otherwise false
  bool pumpWindowsMessages() const noexcept { return _PumpWindowsMessages; };
#endif //#ifdef _WIN32

#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
/// @brief Provides access to a specific thread in the thread pool
/// @param threadIndex the index of the thread to retrieve
  std::thread& get_thread(size_t threadIndex) noexcept { return _Threads[threadIndex]; };

/// @brief Returns how many items are on the main thread pool queue
  size_t queueSize() const
  {
    std::lock_guard<std::mutex> lock{_MutexTasks};
    return _Tasks.size();
  }

/// @brief Returns how many items are on a specific thread's local queue
/// @param threadIndex the index of the thread to retrieve
  size_t threadQueueSize(size_t threadIndex) const
  {
    std::lock_guard<std::mutex> lock{_MutexTasks};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
    return _Queues[threadIndex]->size();
  }

/// @brief Returns how many threads are busy executing tasks in the thread pool
  size_t threadsBusy() const noexcept { return _ThreadsBusy; }

/// @brief returns true if the thread pool is currently paused or false otherwise
  bool paused() const noexcept { return _Paused; };

/// @brief By default the thread pool does not start when constructed. This behaviour
/// was changed in v1.11 of the thread pool. This change is to allow virtual functions
/// such as InitInstanceThread which would be called when each thread in the thread pool
/// is started to work correctly. Instead now an explicit call to "start" needs to 
/// done to actually start the thread pool. Note this method should only be called on
/// the same thread on which the CThreadPool instance was created.
  void start()
  {
    if (!_Started)
    {
      try
      {
        resize(_StartCount);
      }
      catch (...)
      {
        //Signal whatever worker threads got created to exit
        {
          std::lock_guard<std::mutex> lock{_MutexTasks};
          std::for_each(_Dones.begin(), _Dones.end(), [](auto& element) noexcept { *element = true; });
        }
        _TasksAvailable.notify_all();
        throw;
      }
      _Started = true;
    }
  }

/// @brief Returns the started state of the thread pool has started
  bool started() const noexcept { return _Started; };

/// @brief Sets the pause state of the thread pool. When a thread pool is paused
/// no tasks will be processed by its worker threads even though task submission
/// can continue
/// @param pause true to pause the thread pool or false to un-pause
  void pause(bool pause)
  {
    {
      std::lock_guard<std::mutex> lock{_MutexTasks};
      _Paused = pause;
    }
    if (!pause)
      _TasksAvailable.notify_all();
  }

/// @brief Clears down the main thread pool queue and all threads local queues.
/// Note this does not affect the count of threads in the thread pool
/// or any currently running tasks.
  void clear()
  {
    std::lock_guard<std::mutex> lock{_MutexTasks};
    while (!_Tasks.empty())
      _Tasks.pop();
    for (auto& queue : _Queues)
    {
      while (!queue->empty())
        queue->pop_front();
    }
  }

/// @brief Shuts down the thread pool. Called in the destructor if you do not explicitly call it yourself.
  void stop()
  {
    //Delegate the work to the resize method
    resize(0);
  }

/// @brief Changes the number of threads in the thread pool. Note this method should only be called on the same
/// thread on which the CThreadPool instance was created.
/// @param threadCount The number of threads you want in the thread pool
  void resize(size_t threadCount)
  {
#ifdef _MSC_VER
#pragma warning(suppress: 26472)
#endif //#ifdef _MSC_VER
    const auto nOldThreads{_Threads.size()};

    if (threadCount > nOldThreads) //The number of threads is increasing
    {
      _Dones.reserve(threadCount);
      std::generate_n(std::back_inserter(_Dones), threadCount, []() { return std::make_unique<std::atomic<bool>>(false); });
      _Queues.reserve(threadCount);
      std::generate_n(std::back_inserter(_Queues), threadCount, []() { return std::make_unique<std::deque<CFunctionWrapper>>(); });
      _PromiseInitInstanceThreads.resize(threadCount);
      _Threads.reserve(threadCount);
      std::lock_guard<std::mutex> lock{_MutexTasks}; //We need to lock the mutex here to prevent any existing worker threads from 
                                                     //dotting into the _Threads array while we are modifying it here. Note this
                                                     //will not occur when resize is called by the CThreadPool constructor but
                                                     //can occur when resize is called on an already non-paused CThreadPool
                                                     //instance
      for (size_t i=nOldThreads; i<threadCount; ++i)
        _Threads.emplace_back(&CThreadPool::WorkerThread, this, i);
    }
    else if (threadCount < nOldThreads) //The number of threads is decreasing
    {
      {
        std::lock_guard<std::mutex> lock{_MutexTasks};
        for (size_t i=threadCount; i<nOldThreads; ++i)
        {
          //Signal the specified worker thread to exit
  #ifdef _MSC_VER
  #pragma warning(suppress: 26446)
  #endif //#ifdef _MSC_VER
          *(_Dones[i]) = true;
        }
      }
      _TasksAvailable.notify_all();

      //Wait for the specified worker threads to exit
      for (size_t i=threadCount; i<nOldThreads; ++i)
      {
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
        auto& thread{_Threads[i]};
        if (thread.joinable())
#ifdef _MSC_VER
#pragma warning(suppress: 26447)
#endif //#ifdef _MSC_VER
          thread.join();
      }
      _Threads.resize(threadCount);
      _PromiseInitInstanceThreads.resize(threadCount);
      _Queues.resize(threadCount);
      _Dones.resize(threadCount);
    }
  }

/// @brief Submit a task to any thread in the thread pool
/// @param f The first variable argument to the callable you want to execute in the thread pool
/// @param args The remaining arguments to the callable you want to execute in the thread pool
/// @return A std::future wrapper for the callable
  template<typename Func, typename... Args>
  auto submit(Func&& f, Args&&... args) -> std::future<typename std::invoke_result_t<Func, Args...>>
  {
    assert(_Started); //You forgot to call Start on this thread pool instance

    using ResultT = std::invoke_result_t<Func, Args...>;

    std::packaged_task<ResultT()> task{std::bind(f, args...)}; //NOLINT(modernize-avoid-bind)
    auto future{task.get_future()};
    if (_LocalQueue != nullptr)
    {
      std::lock_guard<std::mutex> lock{_MutexTasks};
      _LocalQueue->push_front(std::move(task));
    }
    else
    {
      std::lock_guard<std::mutex> lock{_MutexTasks};
      _Tasks.push(std::move(task));
    }
    _TasksAvailable.notify_one();
    return future;
  }

/// @brief Submit a task to any thread in the thread pool. Note this override will block until
/// the queue being pushed onto contains less than "MaxQueueSize" tasks.
/// @param maxQueueSize The maximum items which can be put on the queue
/// @param f The first variable argument to the callable you want to execute in the thread pool
/// @param args The remaining arguments to the callable you want to execute in the thread pool
/// @return A std::future wrapper for the callable
  template<typename Func, typename... Args>
  auto blocked_submit(size_t maxQueueSize, Func&& f, Args&&... args) -> std::future<typename std::invoke_result_t<Func, Args...>>
  {
    assert(_Started); //You forgot to call Start on this thread pool instance

    using ResultT = std::invoke_result_t<Func, Args...>;

    std::packaged_task<ResultT()> task{std::bind(f, args...)}; //NOLINT(modernize-avoid-bind)
    auto future{task.get_future()};
    if (_LocalQueue != nullptr)
    {
      std::unique_lock<std::mutex> lock{_MutexTasks};
      _QueueNotFull.wait(lock, [&]() noexcept { return _LocalQueue->size() < maxQueueSize; } );
      _LocalQueue->push_front(std::move(task));
    }
    else
    {
      std::unique_lock<std::mutex> lock{_MutexTasks};
      _QueueNotFull.wait(lock, [&]() noexcept { return _Tasks.size() < maxQueueSize; });
      _Tasks.push(std::move(task));
    }
    _TasksAvailable.notify_one();
    return future;
  }

/// @brief Submit a task to a specific thread in the thread pool
/// @param threadIndex The thread index of the specific thread on which you want the task executed
/// @param f The first variable argument to the thing / callable you want to execute in the thread pool
/// @param args The remaining arguments to the thing / callable you want to execute in the thread pool
/// @return A std::future wrapper for the callable
  template<typename Func, typename... Args>
  auto submit(size_t threadIndex, Func&& f, Args&&... args) -> std::future<typename std::invoke_result_t<Func, Args...>>
  {
    assert(_Started); //You forgot to call Start on this thread pool instance

    using ResultT = std::invoke_result_t<Func, Args...>;

    std::packaged_task<ResultT()> task{std::bind(f, args...)}; //NOLINT(modernize-avoid-bind)
    auto future{task.get_future()};
    {
      std::lock_guard<std::mutex> lock{_MutexTasks};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
      _Queues[threadIndex]->push_front(std::move(task));
    }
    _TasksAvailable.notify_one();
    return future;
  }

/// @brief Submit a task to a specific thread in the thread pool. Note this override will block until
/// the queue being pushed onto contains less than "MaxQueueSize" tasks.
/// @param maxQueueSize The maximum items which can be put on the queue
/// @param threadIndex The thread index of the specific thread on which you want the task executed
/// @param f The first variable argument to the thing / callable you want to execute in the thread pool
/// @param args The remaining arguments to the thing / callable you want to execute in the thread pool
/// @return A std::future wrapper for the callable
  template<typename Func, typename... Args>
  auto blocked_submit(size_t maxQueueSize, size_t threadIndex, Func&& f, Args&&... args) -> std::future<typename std::invoke_result_t<Func, Args...>>
  {
    assert(_Started); //You forgot to call Start on this thread pool instance

    using ResultT = std::invoke_result_t<Func, Args...>;

    std::packaged_task<ResultT()> task{std::bind(f, args...)}; //NOLINT(modernize-avoid-bind)
    auto future{task.get_future()};
    {
      std::unique_lock<std::mutex> lock{_MutexTasks};
#ifdef _MSC_VER
#pragma warning(suppress: 26446)
#endif //#ifdef _MSC_VER
      auto& localQueue{_Queues[threadIndex]};
      _QueueNotFull.wait(lock, [&]() noexcept { return localQueue->size() < maxQueueSize; });
      localQueue->push_front(std::move(task));
    }
    _TasksAvailable.notify_one();
    return future;
  }

/// @brief Waits for the all threads in the thread pool to have called InitInstanceThread.
/// @return true if all threads return true from their InitInstanceThread method, otherwise false
  bool WaitForInitInstanceThreads()
  {
    assert(_Started); //You forgot to call Start on this thread pool instance

    //Wait for all the promises for the return value from the InitInstanceThread method
    //from all the threads in the thread pool
    bool bSuccess{true};
    for (auto& promise : _PromiseInitInstanceThreads)
      bSuccess &= promise.get_future().get();
    return bSuccess;
  }
};


}; //namespace CppConcurrency

#endif //#define __CPPCTHREADPOOL_H__
