#include "sierrachart.h"

#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <system_error>
#include <thread>

// Sierra Chart ACSIL: declares the name of this custom-study DLL.
SCDLLName("ThreadedAccountPolling")

// Process-wide polling state, shared by every instance of the study in this DLL.
// KeepPolling is the run flag for the background thread; PollCondition and
// PollMutex let whoever clears the flag wake the thread before its timeout.
std::atomic<bool> KeepPolling{false};
std::thread PollingThread;
std::condition_variable PollCondition;
std::mutex PollMutex;

// Background thread body. It does no work of its own: it sleeps on
// PollCondition in 250 ms slices and returns as soon as KeepPolling is false.
// PollMutex is held for the whole function except while blocked in wait_for.
void PollingLoop()
{
    const auto StopRequested = [] { return !KeepPolling; };
    std::unique_lock<std::mutex> Guard(PollMutex);

    while (!StopRequested())
    {
        // wait_for returns the predicate's value: true means a stop was seen.
        if (PollCondition.wait_for(Guard, std::chrono::milliseconds(250), StopRequested))
            break;
    }
}

SCSFExport scsf_AccountPollingThread(SCStudyInterfaceRef sc)
{
    static int MenuID_StartStop = 0;
    static bool PollNow = false;

    if (sc.SetDefaults)
    {
        sc.GraphName = "Threaded Account Polling";
        sc.AutoLoop = 0;  
        sc.UpdateAlways = 1;  

        return;
    }

    if (sc.Index == 0 && MenuID_StartStop == 0)
    {
        MenuID_StartStop = sc.AddACSChartShortcutMenuItem(sc.ChartNumber, "Start/Stop Polling");
    }

    if (sc.MenuEventID == MenuID_StartStop)
    {
        if (KeepPolling)
        {
            KeepPolling = false;
            PollCondition.notify_all();
            if (PollingThread.joinable()) 
                PollingThread.join();
            sc.AddMessageToLog("Polling Stopped", 0);
        }
        else
        {
            KeepPolling = true;
            PollingThread = std::thread(PollingLoop);
            sc.AddMessageToLog("Polling Started", 0);
        }
    }

    if (KeepPolling)
    {
        static SCDateTimeMS LastExecutionTime;
        SCDateTimeMS CurrentTime = sc.CurrentSystemDateTimeMS;
        int64_t CurrentTimeMs = CurrentTime.GetTimeInMilliseconds();
        int64_t LastTimeMs = LastExecutionTime.GetTimeInMilliseconds();

        if ((CurrentTimeMs - LastTimeMs) >= 250)
        {
            s_SCPositionData PositionData;
            if (sc.GetTradePosition(PositionData))
            {
                SCString LogMessage;
                LogMessage.Format("Polling - Position: %f, AvgPrice: %f",
                                  PositionData.PositionQuantity,
                                  PositionData.AveragePrice);
                sc.AddMessageToLog(LogMessage, 0);
            }

            LastExecutionTime = CurrentTime; 
        }
    }

    if (sc.LastCallToFunction)
    {
        KeepPolling = false;
        PollCondition.notify_all();  
        if (PollingThread.joinable()) 
            PollingThread.join(); 
    }
}
