#!/usr/bin/env python3

import aiohttp
import asyncio
import datetime
import json
import os
import struct
import sys
import time
import websockets
import zoneinfo
from enum import IntEnum

VERSION = 1
# Packs little-endian to the magic bytes b"SCDD" at the start of the header.
UNIQUE_HEADER_ID = 0x44444353
# Origin of record timestamps: GetSCIDDateTime counts microseconds from here.
SCID_EPOCH = datetime.datetime(1899, 12, 30, tzinfo=datetime.timezone.utc)
RESERVED = bytes(48)
FLAG_END_OF_BATCH = 1
MARKET_DEPTH_FOLDER = "datas/MarketDepthData"

# Explicit little-endian ("<") with standard sizes and no padding, so the
# 64-byte header / 24-byte record layout does not depend on the host's native
# byte order or alignment rules.
HEADER_FORMAT = "<I I I I 48s"
RECORD_FORMAT = "<Q B B H f I I"

# NOTE(review): despite its name, BinanceDepthFileWriter uses this table as the
# *quantity* step (QuantityMultiplier = 1 / value), i.e. quantities are written
# as integer multiples of these values. The figures look like lot-size steps
# rather than price ticks — confirm before adding symbols.
MinPriceIncrement = {
    "BTCUSDT": 0.00001,
    "ETHUSDT": 0.0001
}

class CommandEnum(IntEnum):
    """Values for the ``Command`` byte of a depth record.

    ``NO_COMMAND`` is what the end-of-batch marker records carry; the rest
    clear the book or add/modify/delete a single bid or ask level.
    """

    NO_COMMAND = 0
    COMMAND_CLEAR_BOOK = 1

    # Level additions.
    COMMAND_ADD_BID_LEVEL = 2
    COMMAND_ADD_ASK_LEVEL = 3

    # Level quantity changes.
    COMMAND_MODIFY_BID_LEVEL = 4
    COMMAND_MODIFY_ASK_LEVEL = 5

    # Level removals.
    COMMAND_DELETE_BID_LEVEL = 6
    COMMAND_DELETE_ASK_LEVEL = 7

def GetSCIDDateTime(DateTime):
    """Convert a Unix timestamp in microseconds to microseconds since SCID_EPOCH.

    Args:
        DateTime: microseconds since 1970-01-01 UTC (callers pass the Binance
            event time in ms multiplied by 1000).

    Returns:
        int microseconds since SCID_EPOCH. Fractional input is rounded
        half-to-even to a whole microsecond.
    """
    # Pure integer arithmetic: at ~3.9e15 µs a float round trip through
    # seconds has sub-microsecond error, and int() truncation can then land
    # one microsecond below the true value.
    UnixEpochOffset = (
        datetime.datetime(1970, 1, 1, tzinfo=datetime.timezone.utc) - SCID_EPOCH
    ) // datetime.timedelta(microseconds=1)
    return round(DateTime) + UnixEpochOffset

def CreateHeader(RecordSize):
    """Pack the depth-file header.

    The header carries the magic ID, its own packed size, ``RecordSize`` (the
    size in bytes of each record that follows), the format VERSION and the
    reserved padding.
    """
    Fields = (
        UNIQUE_HEADER_ID,
        struct.calcsize(HEADER_FORMAT),
        RecordSize,
        VERSION,
        RESERVED,
    )
    return struct.pack(HEADER_FORMAT, *Fields)

def CreateRecord(SCIDDateTime=0, Command=CommandEnum.NO_COMMAND, Flags=0, NumOrders=0, Price=0.0, PriceMultiplier=1, Quantity=0, QuantityMultiplier=1, Reserved=0):
    """Pack one depth record in RECORD_FORMAT.

    ``Price`` is scaled by ``PriceMultiplier`` and stored as a float.
    ``Quantity`` is scaled by ``QuantityMultiplier`` and rounded to the nearest
    integer. All other fields are packed as given. Calling with only
    ``SCIDDateTime`` and ``Flags`` yields a command-less marker record.
    """
    Fields = (
        SCIDDateTime,
        Command,
        Flags,
        NumOrders,
        float(Price) * PriceMultiplier,
        round(float(Quantity) * QuantityMultiplier),
        Reserved,
    )
    return struct.pack(RECORD_FORMAT, *Fields)

class BinanceDepthFileWriter:
    """Record a Binance spot order book into daily depth files.

    The book is seeded from the REST ``/depth`` snapshot and kept in sync with
    the ``<symbol>@depth`` WebSocket diff stream. It is written to
    ``MARKET_DEPTH_FOLDER/<SYMBOL>-BINANCE.<YYYY-MM-DD>.depth`` as follows:

    * a full snapshot batch (clear + add levels) when a file is started, after
      every resync, and every ``SNAPSHOT_INTERVAL_MS`` of event time;
    * one add/modify/delete batch per diff event in between.

    ``OrderBook`` maps ``"Bids"`` / ``"Asks"`` to ``{price: quantity}`` dicts
    of floats.
    """

    # Event-time interval (ms) after which WriteUpdate emits a fresh snapshot.
    SNAPSHOT_INTERVAL_MS = 600_000

    def __init__(self, Symbol):
        """Prepare an empty book for ``Symbol`` (an upper-case pair such as ``"BTCUSDT"``).

        Exits the process with status 1 and a message if ``Symbol`` has no
        entry in ``MinPriceIncrement``.
        """
        self.Symbol = Symbol
        self.RESTfulBaseURL = "https://api.binance.com/api/v3"
        # Alternative stream name: f"{Symbol.lower()}@depth@100ms".
        self.WebSocketBaseURL = f"wss://stream.binance.com:9443/ws/{Symbol.lower()}@depth"

        self.OrderBook = {'Bids': {}, 'Asks': {}}
        self.LastUpdateID = 0
        self.CurrentDate = datetime.datetime.now(datetime.timezone.utc).strftime("%Y-%m-%d")
        self.OrderBookFilePath = self.GetFilePath(self.CurrentDate)
        # Event time (ms) the last snapshot was stamped with. None means "no
        # snapshot yet", which makes the first WebSocket message write one.
        self.FirstUpdateDateTime = None
        self.PriceMultiplier = 1

        if Symbol not in MinPriceIncrement:
            # sys.exit(str) prints to stderr and exits with status 1.
            sys.exit(f"unsupported symbol: {Symbol} (known: {', '.join(MinPriceIncrement)})")
        # Quantities are written as integer multiples of the symbol's increment.
        self.QuantityMultiplier = 1 / MinPriceIncrement[Symbol]

        # Per-message counters for the status line.
        self.UpdateCount = 0
        self.DeleteCount = 0

    def GetFilePath(self, Date):
        """Return the depth-file path for ``Date`` (a ``YYYY-MM-DD`` string)."""
        return os.path.join(MARKET_DEPTH_FOLDER, f"{self.Symbol}-BINANCE.{Date}.depth")

    def UpdateOrderBook(self, Side, Update):
        """Apply one ``[price, quantity]`` diff entry to ``self.OrderBook[Side]``.

        A zero quantity removes the level; any other quantity sets it. Sets are
        counted in ``UpdateCount``; removals of existing levels are counted in
        ``DeleteCount``.
        """
        Price, Amount = float(Update[0]), float(Update[1])
        Levels = self.OrderBook[Side]

        if Amount == 0:
            if Price in Levels:
                del Levels[Price]
                self.DeleteCount += 1
        else:
            Levels[Price] = Amount
            self.UpdateCount += 1

    async def FetchInitialOrderBook(self):
        """Replace ``OrderBook`` and ``LastUpdateID`` with a fresh REST depth snapshot.

        The snapshot is requested with ``limit=5000``. Fresh dicts are built
        rather than merged into the current ones, so a caller holding the old
        ``OrderBook`` still sees the pre-fetch state (Restart relies on this).

        Raises:
            aiohttp.ClientResponseError: on an HTTP error status.
        """
        async with aiohttp.ClientSession() as Session:
            async with Session.get(f"{self.RESTfulBaseURL}/depth", params={"symbol": self.Symbol, "limit": 5000}) as Response:
                # Fail clearly instead of a KeyError on the error payload.
                Response.raise_for_status()
                Data = await Response.json()

        self.LastUpdateID = Data['lastUpdateId']
        self.OrderBook = {
            "Bids": {float(Level[0]): float(Level[1]) for Level in Data["bids"]},
            "Asks": {float(Level[0]): float(Level[1]) for Level in Data["asks"]},
        }

    async def ProcessWebSocketMessage(self, Message):
        """Handle one raw diff-depth message (JSON with ``E``, ``U``, ``u``, ``b``, ``a``).

        The message is processed in this order:

        1. Start a new daily file when the event date changes, and write the
           initial snapshot on the first message.
        2. Skip events already covered by the snapshot.
        3. Resync on a sequence gap.
        4. Otherwise write the diff and then apply it.
        """
        Data = json.loads(Message)
        EventDateTime = datetime.datetime.fromtimestamp(Data["E"] / 1000, datetime.timezone.utc)
        EventDate = EventDateTime.strftime("%Y-%m-%d")
        UpdateDateTime = EventDateTime.strftime("%Y-%m-%d %H:%M:%S.%f UTC")

        # The file date follows the event time, so a snapshot and the updates
        # after it land in the same file. The snapshot is stamped 1 ms before
        # this event so it precedes the event's own update batch.
        if EventDate != self.CurrentDate or self.FirstUpdateDateTime is None:
            if EventDate != self.CurrentDate:
                self.CurrentDate = EventDate
                self.OrderBookFilePath = self.GetFilePath(EventDate)
            self.FirstUpdateDateTime = Data["E"] - 1
            await self.WriteSnapshot()

        # Already reflected in the current book.
        if Data["u"] <= self.LastUpdateID:
            return

        # The event must cover LastUpdateID + 1; otherwise updates were missed.
        if not Data["U"] <= self.LastUpdateID + 1 <= Data["u"]:
            print("out of sync, restarting...")
            # Stamp the resync snapshot with this event's time, not a stale one.
            self.FirstUpdateDateTime = Data["E"]
            await self.Restart()
            return

        # Write before applying: WriteUpdate decides add vs modify by looking
        # the price up in the current, pre-update book.
        await self.WriteUpdate(Data)

        for Bid in Data["b"]:
            self.UpdateOrderBook("Bids", Bid)
        for Ask in Data["a"]:
            self.UpdateOrderBook("Asks", Ask)
        self.LastUpdateID = Data['u']

        Bids, Asks = self.OrderBook["Bids"], self.OrderBook["Asks"]
        BidLevels = len(Bids)
        AskLevels = len(Asks)

        # Deepest levels held on each side.
        LowestBid = min(Bids) if Bids else None
        HighestAsk = max(Asks) if Asks else None

        print(f"{UpdateDateTime}; Changed Levels: */+: {str(self.UpdateCount).rjust(5)}, -: {str(self.DeleteCount).rjust(5)}; Snapshot Total Levels[Bid, Ask]: {str(BidLevels).rjust(5)}, {str(AskLevels).rjust(5)}; Edge Prices[Bid, Ask]: {str(LowestBid).rjust(9)}, {str(HighestAsk).rjust(9)}")

        self.UpdateCount = 0
        self.DeleteCount = 0

    async def WriteSnapshot(self):
        """Write the whole book as one batch stamped with ``FirstUpdateDateTime``.

        The batch is a clear record, one add record per level, and an
        end-of-batch marker. A missing or empty file is (re)created with a
        header; otherwise the batch is appended.
        """
        FilePath = self.OrderBookFilePath
        IsFileNonEmpty = os.path.exists(FilePath) and os.path.getsize(FilePath) > 0
        SCIDDateTime = GetSCIDDateTime(self.FirstUpdateDateTime * 1000)

        # Build the whole batch first so a packing error cannot leave a
        # half-written batch in the file.
        Records = [CreateRecord(SCIDDateTime=SCIDDateTime, Command=CommandEnum.COMMAND_CLEAR_BOOK)]
        for Side, Command in (("Bids", CommandEnum.COMMAND_ADD_BID_LEVEL), ("Asks", CommandEnum.COMMAND_ADD_ASK_LEVEL)):
            Records.extend(
                CreateRecord(SCIDDateTime=SCIDDateTime, Command=Command, NumOrders=0,
                             Price=Price, PriceMultiplier=self.PriceMultiplier,
                             Quantity=Quantity, QuantityMultiplier=self.QuantityMultiplier)
                for Price, Quantity in self.OrderBook[Side].items()
            )
        Records.append(CreateRecord(SCIDDateTime=SCIDDateTime, Flags=FLAG_END_OF_BATCH))

        with open(FilePath, "ab" if IsFileNonEmpty else "wb") as f:
            if not IsFileNonEmpty:
                f.write(CreateHeader(struct.calcsize(RECORD_FORMAT)))
            f.write(b"".join(Records))

    async def WriteUpdate(self, Data):
        """Append one diff event as an add/modify/delete batch stamped with ``Data["E"]``.

        Must be called BEFORE the event is applied to ``OrderBook``. Each
        non-zero level is classified as modify if its price is already in the
        book, and add otherwise; zero quantities become deletes. If no snapshot
        has been written yet, or ``SNAPSHOT_INTERVAL_MS`` of event time has
        passed since the last one, the current (pre-event) book is written as a
        snapshot first.
        """
        SCIDDateTime = GetSCIDDateTime(Data["E"] * 1000)

        if self.FirstUpdateDateTime is None or Data["E"] - self.FirstUpdateDateTime >= self.SNAPSHOT_INTERVAL_MS:
            self.FirstUpdateDateTime = Data["E"]
            await self.WriteSnapshot()

        # (delete, modify, add) commands for each side.
        SideCommands = {
            "Bids": (CommandEnum.COMMAND_DELETE_BID_LEVEL, CommandEnum.COMMAND_MODIFY_BID_LEVEL, CommandEnum.COMMAND_ADD_BID_LEVEL),
            "Asks": (CommandEnum.COMMAND_DELETE_ASK_LEVEL, CommandEnum.COMMAND_MODIFY_ASK_LEVEL, CommandEnum.COMMAND_ADD_ASK_LEVEL),
        }

        Records = []
        for Side, Updates in (("Bids", Data["b"]), ("Asks", Data["a"])):
            DeleteCommand, ModifyCommand, AddCommand = SideCommands[Side]
            Levels = self.OrderBook[Side]
            for Update in Updates:
                Price, Quantity = float(Update[0]), float(Update[1])

                if Quantity == 0:
                    Command = DeleteCommand
                elif Price in Levels:
                    Command = ModifyCommand
                else:
                    Command = AddCommand

                Records.append(CreateRecord(
                    SCIDDateTime=SCIDDateTime,
                    Command=Command,
                    NumOrders=0,
                    Price=Price,
                    PriceMultiplier=self.PriceMultiplier,
                    Quantity=Quantity,
                    QuantityMultiplier=self.QuantityMultiplier
                ))
        Records.append(CreateRecord(SCIDDateTime=SCIDDateTime, Flags=FLAG_END_OF_BATCH))

        with open(self.OrderBookFilePath, "ab") as f:
            f.write(b"".join(Records))

    async def MergeSnapshot(self, PreviousBook=None):
        """Carry over levels from ``PreviousBook`` that lie beyond the fresh snapshot.

        The REST snapshot is depth-limited. Previous bids below its lowest bid
        and previous asks above its highest ask are therefore kept. Everything
        inside the snapshot's range comes from the snapshot alone, which drops
        stale levels.

        Without ``PreviousBook``, or for a side where the snapshot is empty
        (nothing to anchor the range on), the snapshot is kept as-is.
        """
        if not PreviousBook:
            return

        NewBids = self.OrderBook["Bids"]
        if NewBids:
            DeepestNewBid = min(NewBids)
            MergedBids = {Price: Quantity for Price, Quantity in PreviousBook.get("Bids", {}).items() if Price < DeepestNewBid}
            MergedBids.update(NewBids)
            self.OrderBook["Bids"] = MergedBids

        NewAsks = self.OrderBook["Asks"]
        if NewAsks:
            DeepestNewAsk = max(NewAsks)
            MergedAsks = {Price: Quantity for Price, Quantity in PreviousBook.get("Asks", {}).items() if Price > DeepestNewAsk}
            MergedAsks.update(NewAsks)
            self.OrderBook["Asks"] = MergedAsks

    async def Restart(self):
        """Resynchronise after a sequence gap: refetch, merge, and write a snapshot."""
        # FetchInitialOrderBook rebinds self.OrderBook to new dicts, so this
        # reference keeps the pre-restart book intact for the merge.
        PreviousBook = self.OrderBook
        await self.FetchInitialOrderBook()
        await self.MergeSnapshot(PreviousBook)
        await self.WriteSnapshot()

    async def StartWebsocket(self):
        """Consume the diff stream forever, reconnecting 5 s after any disconnect or error."""
        while True:
            try:
                async with websockets.connect(self.WebSocketBaseURL) as Websocket:
                    print("ws connection opened")
                    async for Message in Websocket:
                        await self.ProcessWebSocketMessage(Message)

            except websockets.ConnectionClosed:
                print("ws connection closed, reconnecting...")

            except Exception as e:
                # Best-effort recorder: report and reconnect. A gap left by the
                # reconnect is detected and resynced in ProcessWebSocketMessage.
                print(f"error in ws connection: {e}")

            await asyncio.sleep(5)

    async def run(self):
        """Create the output folder, seed the book from REST, and stream forever."""
        os.makedirs(MARKET_DEPTH_FOLDER, exist_ok=True)
        await self.FetchInitialOrderBook()
        await self.StartWebsocket()

async def Main():
    """Record the depth stream for the symbol given as the only CLI argument.

    The symbol is upper-cased before use. A wrong argument count exits with
    status 1 and a usage message on stderr.
    """
    if len(sys.argv) != 2:
        Program = os.path.basename(sys.argv[0]) if sys.argv else "script"
        sys.exit(f"usage: {Program} SYMBOL")

    Symbol = sys.argv[1].upper()
    Writer = BinanceDepthFileWriter(Symbol)
    await Writer.run()

if __name__ == "__main__":
    # Clear the screen only for an interactive terminal. When output is
    # redirected to a log this avoids escape codes / "TERM not set" noise.
    if sys.stdout.isatty():
        os.system("cls" if os.name == "nt" else "clear")
    try:
        asyncio.run(Main())
    except KeyboardInterrupt:
        # Ctrl+C is the normal way to stop the recorder: exit quietly with
        # the conventional "interrupted" status instead of a traceback.
        sys.exit(130)
