Compare commits
	
		
			22 Commits 
		
	
	
		
			179a42ac78
			...
			e4a68bdec9
		
	
	| Author | SHA1 | Date | 
|---|---|---|
| 
							
							
								 | 
						e4a68bdec9 | |
| 
							
							
								 | 
						8efbfb7038 | |
| 
							
							
								 | 
						98c068d757 | |
| 
							
							
								 | 
						363a7cf4e0 | |
| 
							
							
								 | 
						f524912882 | |
| 
							
							
								 | 
						2f2b87b3d3 | |
| 
							
							
								 | 
						513544f5f7 | |
| 
							
							
								 | 
						9dc8006f10 | |
| 
							
							
								 | 
						6f17ef0c34 | |
| 
							
							
								 | 
						877f334b85 | |
| 
							
							
								 | 
						0508d08c49 | |
| 
							
							
								 | 
						9c14b08b81 | |
| 
							
							
								 | 
						5ad8863e32 | |
| 
							
							
								 | 
						5fad27c30d | |
| 
							
							
								 | 
						eb1a0a685d | |
| 
							
							
								 | 
						5c3f2750c9 | |
| 
							
							
								 | 
						9b7b062dfd | |
| 
							
							
								 | 
						0538e420cb | |
| 
							
							
								 | 
						4848dc40cc | |
| 
							
							
								 | 
						407efe3b10 | |
| 
							
							
								 | 
						00108010c9 | |
| 
							
							
								 | 
						8a4901c517 | 
| 
						 | 
				
			
			@ -30,7 +30,8 @@ from types import ModuleType
 | 
			
		|||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Iterator,
 | 
			
		||||
    Generator
 | 
			
		||||
    Generator,
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import pendulum
 | 
			
		||||
| 
						 | 
				
			
			@ -59,8 +60,10 @@ from ..clearing._messages import (
 | 
			
		|||
    BrokerdPosition,
 | 
			
		||||
)
 | 
			
		||||
from piker.types import Struct
 | 
			
		||||
from piker.data._symcache import SymbologyCache
 | 
			
		||||
from ..log import get_logger
 | 
			
		||||
from piker.log import get_logger
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from piker.data._symcache import SymbologyCache
 | 
			
		||||
 | 
			
		||||
log = get_logger(__name__)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -493,6 +496,17 @@ class Account(Struct):
 | 
			
		|||
 | 
			
		||||
        _mktmap_table: dict[str, MktPair] | None = None,
 | 
			
		||||
 | 
			
		||||
        only_require: list[str]|True = True,
 | 
			
		||||
        # ^list of fqmes that are "required" to be processed from
 | 
			
		||||
        # this ledger pass; we often don't care about others and
 | 
			
		||||
        # definitely shouldn't always error in such cases.
 | 
			
		||||
        # (eg. broker backend loaded that doesn't yet supsport the
 | 
			
		||||
        # symcache but also, inside the paper engine we don't ad-hoc
 | 
			
		||||
        # request `get_mkt_info()` for every symbol in the ledger,
 | 
			
		||||
        # only the one for which we're simulating against).
 | 
			
		||||
        # TODO, not sure if there's a better soln for this, ideally
 | 
			
		||||
        # all backends get symcache support afap i guess..
 | 
			
		||||
 | 
			
		||||
    ) -> dict[str, Position]:
 | 
			
		||||
        '''
 | 
			
		||||
        Update the internal `.pps[str, Position]` table from input
 | 
			
		||||
| 
						 | 
				
			
			@ -535,11 +549,32 @@ class Account(Struct):
 | 
			
		|||
                if _mktmap_table is None:
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
                required: bool = (
 | 
			
		||||
                    only_require is True
 | 
			
		||||
                    or (
 | 
			
		||||
                        only_require is not True
 | 
			
		||||
                        and
 | 
			
		||||
                        fqme in only_require
 | 
			
		||||
                    )
 | 
			
		||||
                )
 | 
			
		||||
                # XXX: caller is allowed to provide a fallback
 | 
			
		||||
                # mktmap table for the case where a new position is
 | 
			
		||||
                # being added and the preloaded symcache didn't
 | 
			
		||||
                # have this entry prior (eg. with frickin IB..)
 | 
			
		||||
                mkt = _mktmap_table[fqme]
 | 
			
		||||
                if (
 | 
			
		||||
                    not (mkt := _mktmap_table.get(fqme))
 | 
			
		||||
                    and
 | 
			
		||||
                    required
 | 
			
		||||
                ):
 | 
			
		||||
                    raise
 | 
			
		||||
 | 
			
		||||
                elif not required:
 | 
			
		||||
                    continue
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
                    # should be an entry retreived somewhere
 | 
			
		||||
                    assert mkt
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
            if not (pos := pps.get(bs_mktid)):
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -656,7 +691,7 @@ class Account(Struct):
 | 
			
		|||
    def write_config(self) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Write the current account state to the user's account TOML file, normally
 | 
			
		||||
        something like ``pps.toml``.
 | 
			
		||||
        something like `pps.toml`.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # TODO: show diff output?
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -98,13 +98,14 @@ async def open_cached_client(
 | 
			
		|||
    If one has not been setup do it and cache it.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    brokermod = get_brokermod(brokername)
 | 
			
		||||
    brokermod: ModuleType = get_brokermod(brokername)
 | 
			
		||||
 | 
			
		||||
    # TODO: make abstract or `typing.Protocol`
 | 
			
		||||
    # client: Client
 | 
			
		||||
    async with maybe_open_context(
 | 
			
		||||
        acm_func=brokermod.get_client,
 | 
			
		||||
        kwargs=kwargs,
 | 
			
		||||
 | 
			
		||||
    ) as (cache_hit, client):
 | 
			
		||||
 | 
			
		||||
        if cache_hit:
 | 
			
		||||
            log.runtime(f'Reusing existing {client}')
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -42,7 +42,6 @@ from trio_typing import TaskStatus
 | 
			
		|||
from pendulum import (
 | 
			
		||||
    from_timestamp,
 | 
			
		||||
)
 | 
			
		||||
from rapidfuzz import process as fuzzy
 | 
			
		||||
import numpy as np
 | 
			
		||||
import tractor
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -111,6 +110,7 @@ class AggTrade(Struct, frozen=True):
 | 
			
		|||
 | 
			
		||||
async def stream_messages(
 | 
			
		||||
    ws: NoBsWs,
 | 
			
		||||
 | 
			
		||||
) -> AsyncGenerator[NoBsWs, dict]:
 | 
			
		||||
 | 
			
		||||
    # TODO: match syntax here!
 | 
			
		||||
| 
						 | 
				
			
			@ -221,6 +221,8 @@ def make_sub(pairs: list[str], sub_name: str, uid: int) -> dict[str, str]:
 | 
			
		|||
    }
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO, why aren't frame resp `log.info()`s showing in upstream
 | 
			
		||||
# code?!
 | 
			
		||||
@acm
 | 
			
		||||
async def open_history_client(
 | 
			
		||||
    mkt: MktPair,
 | 
			
		||||
| 
						 | 
				
			
			@ -463,6 +465,8 @@ async def stream_quotes(
 | 
			
		|||
    ):
 | 
			
		||||
        init_msgs: list[FeedInit] = []
 | 
			
		||||
        for sym in symbols:
 | 
			
		||||
            mkt: MktPair
 | 
			
		||||
            pair: Pair
 | 
			
		||||
            mkt, pair = await get_mkt_info(sym)
 | 
			
		||||
 | 
			
		||||
            # build out init msgs according to latest spec
 | 
			
		||||
| 
						 | 
				
			
			@ -511,7 +515,6 @@ async def stream_quotes(
 | 
			
		|||
 | 
			
		||||
            # start streaming
 | 
			
		||||
            async for typ, quote in msg_gen:
 | 
			
		||||
 | 
			
		||||
                # period = time.time() - last
 | 
			
		||||
                # hz = 1/period if period else float('inf')
 | 
			
		||||
                # if hz > 60:
 | 
			
		||||
| 
						 | 
				
			
			@ -547,7 +550,7 @@ async def open_symbol_search(
 | 
			
		|||
                )
 | 
			
		||||
 | 
			
		||||
                # repack in fqme-keyed table
 | 
			
		||||
                byfqme: dict[start, Pair] = {}
 | 
			
		||||
                byfqme: dict[str, Pair] = {}
 | 
			
		||||
                for pair in pairs.values():
 | 
			
		||||
                    byfqme[pair.bs_fqme] = pair
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -471,11 +471,15 @@ def search(
 | 
			
		|||
 | 
			
		||||
    '''
 | 
			
		||||
    # global opts
 | 
			
		||||
    brokermods = list(config['brokermods'].values())
 | 
			
		||||
    brokermods: list[ModuleType] = list(config['brokermods'].values())
 | 
			
		||||
 | 
			
		||||
    # TODO: this is coming from the `search --pdb` NOT from
 | 
			
		||||
    # the `piker --pdb` XD ..
 | 
			
		||||
    # -[ ] pull from the parent click ctx's values..dumdum
 | 
			
		||||
    # assert pdb
 | 
			
		||||
 | 
			
		||||
    # define tractor entrypoint
 | 
			
		||||
    async def main(func):
 | 
			
		||||
 | 
			
		||||
        async with maybe_open_pikerd(
 | 
			
		||||
            loglevel=config['loglevel'],
 | 
			
		||||
            debug_mode=pdb,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -22,7 +22,9 @@ routines should be primitive data types where possible.
 | 
			
		|||
"""
 | 
			
		||||
import inspect
 | 
			
		||||
from types import ModuleType
 | 
			
		||||
from typing import List, Dict, Any, Optional
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import trio
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -34,8 +36,10 @@ from ..accounting import MktPair
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def api(brokername: str, methname: str, **kwargs) -> dict:
 | 
			
		||||
    """Make (proxy through) a broker API call by name and return its result.
 | 
			
		||||
    """
 | 
			
		||||
    '''
 | 
			
		||||
    Make (proxy through) a broker API call by name and return its result.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    brokermod = get_brokermod(brokername)
 | 
			
		||||
    async with brokermod.get_client() as client:
 | 
			
		||||
        meth = getattr(client, methname, None)
 | 
			
		||||
| 
						 | 
				
			
			@ -62,10 +66,14 @@ async def api(brokername: str, methname: str, **kwargs) -> dict:
 | 
			
		|||
 | 
			
		||||
async def stocks_quote(
 | 
			
		||||
    brokermod: ModuleType,
 | 
			
		||||
    tickers: List[str]
 | 
			
		||||
) -> Dict[str, Dict[str, Any]]:
 | 
			
		||||
    """Return quotes dict for ``tickers``.
 | 
			
		||||
    """
 | 
			
		||||
    tickers: list[str]
 | 
			
		||||
 | 
			
		||||
) -> dict[str, dict[str, Any]]:
 | 
			
		||||
    '''
 | 
			
		||||
    Return a `dict` of snapshot quotes for the provided input
 | 
			
		||||
    `tickers`: a `list` of fqmes.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async with brokermod.get_client() as client:
 | 
			
		||||
        return await client.quote(tickers)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -74,13 +82,15 @@ async def stocks_quote(
 | 
			
		|||
async def option_chain(
 | 
			
		||||
    brokermod: ModuleType,
 | 
			
		||||
    symbol: str,
 | 
			
		||||
    date: Optional[str] = None,
 | 
			
		||||
) -> Dict[str, Dict[str, Dict[str, Any]]]:
 | 
			
		||||
    """Return option chain for ``symbol`` for ``date``.
 | 
			
		||||
    date: str|None = None,
 | 
			
		||||
) -> dict[str, dict[str, dict[str, Any]]]:
 | 
			
		||||
    '''
 | 
			
		||||
    Return option chain for ``symbol`` for ``date``.
 | 
			
		||||
 | 
			
		||||
    By default all expiries are returned. If ``date`` is provided
 | 
			
		||||
    then contract quotes for that single expiry are returned.
 | 
			
		||||
    """
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async with brokermod.get_client() as client:
 | 
			
		||||
        if date:
 | 
			
		||||
            id = int((await client.tickers2ids([symbol]))[symbol])
 | 
			
		||||
| 
						 | 
				
			
			@ -98,7 +108,7 @@ async def option_chain(
 | 
			
		|||
# async def contracts(
 | 
			
		||||
#     brokermod: ModuleType,
 | 
			
		||||
#     symbol: str,
 | 
			
		||||
# ) -> Dict[str, Dict[str, Dict[str, Any]]]:
 | 
			
		||||
# ) -> dict[str, dict[str, dict[str, Any]]]:
 | 
			
		||||
#     """Return option contracts (all expiries) for ``symbol``.
 | 
			
		||||
#     """
 | 
			
		||||
#     async with brokermod.get_client() as client:
 | 
			
		||||
| 
						 | 
				
			
			@ -110,15 +120,24 @@ async def bars(
 | 
			
		|||
    brokermod: ModuleType,
 | 
			
		||||
    symbol: str,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
) -> Dict[str, Dict[str, Dict[str, Any]]]:
 | 
			
		||||
    """Return option contracts (all expiries) for ``symbol``.
 | 
			
		||||
    """
 | 
			
		||||
) -> dict[str, dict[str, dict[str, Any]]]:
 | 
			
		||||
    '''
 | 
			
		||||
    Return option contracts (all expiries) for ``symbol``.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async with brokermod.get_client() as client:
 | 
			
		||||
        return await client.bars(symbol, **kwargs)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
async def search_w_brokerd(name: str, pattern: str) -> dict:
 | 
			
		||||
async def search_w_brokerd(
 | 
			
		||||
    name: str,
 | 
			
		||||
    pattern: str,
 | 
			
		||||
) -> dict:
 | 
			
		||||
 | 
			
		||||
    # TODO: WHY NOT WORK!?!
 | 
			
		||||
    # when we `step` through the next block?
 | 
			
		||||
    # import tractor
 | 
			
		||||
    # await tractor.pause()
 | 
			
		||||
    async with open_cached_client(name) as client:
 | 
			
		||||
 | 
			
		||||
        # TODO: support multiple asset type concurrent searches.
 | 
			
		||||
| 
						 | 
				
			
			@ -130,12 +149,12 @@ async def symbol_search(
 | 
			
		|||
    pattern: str,
 | 
			
		||||
    **kwargs,
 | 
			
		||||
 | 
			
		||||
) -> Dict[str, Dict[str, Dict[str, Any]]]:
 | 
			
		||||
) -> dict[str, dict[str, dict[str, Any]]]:
 | 
			
		||||
    '''
 | 
			
		||||
    Return symbol info from broker.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    results = []
 | 
			
		||||
    results: list[str] = []
 | 
			
		||||
 | 
			
		||||
    async def search_backend(
 | 
			
		||||
        brokermod: ModuleType
 | 
			
		||||
| 
						 | 
				
			
			@ -143,6 +162,13 @@ async def symbol_search(
 | 
			
		|||
 | 
			
		||||
        brokername: str = mod.name
 | 
			
		||||
 | 
			
		||||
        # TODO: figure this the FUCK OUT
 | 
			
		||||
        # -> ok so obvi in the root actor any async task that's
 | 
			
		||||
        # spawned outside the main tractor-root-actor task needs to
 | 
			
		||||
        # call this..
 | 
			
		||||
        # await tractor.devx._debug.maybe_init_greenback()
 | 
			
		||||
        # tractor.pause_from_sync()
 | 
			
		||||
 | 
			
		||||
        async with maybe_spawn_brokerd(
 | 
			
		||||
            mod.name,
 | 
			
		||||
            infect_asyncio=getattr(
 | 
			
		||||
| 
						 | 
				
			
			@ -162,7 +188,6 @@ async def symbol_search(
 | 
			
		|||
            ))
 | 
			
		||||
 | 
			
		||||
    async with trio.open_nursery() as n:
 | 
			
		||||
 | 
			
		||||
        for mod in brokermods:
 | 
			
		||||
            n.start_soon(search_backend, mod.name)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -172,11 +197,13 @@ async def symbol_search(
 | 
			
		|||
async def mkt_info(
 | 
			
		||||
    brokermod: ModuleType,
 | 
			
		||||
    fqme: str,
 | 
			
		||||
 | 
			
		||||
    **kwargs,
 | 
			
		||||
 | 
			
		||||
) -> MktPair:
 | 
			
		||||
    '''
 | 
			
		||||
    Return MktPair info from broker including src and dst assets.
 | 
			
		||||
    Return the `piker.accounting.MktPair` info struct from a given
 | 
			
		||||
    backend broker tradable src/dst asset pair.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    async with open_cached_client(brokermod.name) as client:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -168,7 +168,6 @@ class OrderClient(Struct):
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def relay_orders_from_sync_code(
 | 
			
		||||
 | 
			
		||||
    client: OrderClient,
 | 
			
		||||
    symbol_key: str,
 | 
			
		||||
    to_ems_stream: tractor.MsgStream,
 | 
			
		||||
| 
						 | 
				
			
			@ -242,6 +241,11 @@ async def open_ems(
 | 
			
		|||
 | 
			
		||||
    async with maybe_open_emsd(
 | 
			
		||||
        broker,
 | 
			
		||||
        # XXX NOTE, LOL so this determines the daemon `emsd` loglevel
 | 
			
		||||
        # then FYI.. that's kinda wrong no?
 | 
			
		||||
        # -[ ] shouldn't it be set by `pikerd -l` or no?
 | 
			
		||||
        # -[ ] would make a lot more sense to have a subsys ctl for
 | 
			
		||||
        #     levels.. like `-l emsd.info` or something?
 | 
			
		||||
        loglevel=loglevel,
 | 
			
		||||
    ) as portal:
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -653,7 +653,11 @@ class Router(Struct):
 | 
			
		|||
            flume = feed.flumes[fqme]
 | 
			
		||||
            first_quote: dict = flume.first_quote
 | 
			
		||||
            book: DarkBook = self.get_dark_book(broker)
 | 
			
		||||
            book.lasts[fqme]: float = float(first_quote['last'])
 | 
			
		||||
 | 
			
		||||
            if not (last := first_quote.get('last')):
 | 
			
		||||
                last: float = flume.rt_shm.array[-1]['close']
 | 
			
		||||
 | 
			
		||||
            book.lasts[fqme]: float = float(last)
 | 
			
		||||
 | 
			
		||||
            async with self.maybe_open_brokerd_dialog(
 | 
			
		||||
                brokermod=brokermod,
 | 
			
		||||
| 
						 | 
				
			
			@ -716,7 +720,7 @@ class Router(Struct):
 | 
			
		|||
            subs = self.subscribers[sub_key]
 | 
			
		||||
 | 
			
		||||
        sent_some: bool = False
 | 
			
		||||
        for client_stream in subs:
 | 
			
		||||
        for client_stream in subs.copy():
 | 
			
		||||
            try:
 | 
			
		||||
                await client_stream.send(msg)
 | 
			
		||||
                sent_some = True
 | 
			
		||||
| 
						 | 
				
			
			@ -1010,6 +1014,10 @@ async def translate_and_relay_brokerd_events(
 | 
			
		|||
                status_msg.brokerd_msg = msg
 | 
			
		||||
                status_msg.src = msg.broker_details['name']
 | 
			
		||||
 | 
			
		||||
                if not status_msg.req:
 | 
			
		||||
                    # likely some order change state?
 | 
			
		||||
                    await tractor.pause()
 | 
			
		||||
                else:
 | 
			
		||||
                    await router.client_broadcast(
 | 
			
		||||
                        status_msg.req.symbol,
 | 
			
		||||
                        status_msg,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -297,6 +297,8 @@ class PaperBoi(Struct):
 | 
			
		|||
 | 
			
		||||
        # transmit pp msg to ems
 | 
			
		||||
        pp: Position = self.acnt.pps[bs_mktid]
 | 
			
		||||
        # TODO, this will break if `require_only=True` was passed to
 | 
			
		||||
        # `.update_from_ledger()`
 | 
			
		||||
 | 
			
		||||
        pp_msg = BrokerdPosition(
 | 
			
		||||
            broker=self.broker,
 | 
			
		||||
| 
						 | 
				
			
			@ -653,6 +655,7 @@ async def open_trade_dialog(
 | 
			
		|||
                # in) use manually constructed table from calling
 | 
			
		||||
                # the `.get_mkt_info()` provider EP above.
 | 
			
		||||
                _mktmap_table=mkt_by_fqme,
 | 
			
		||||
                only_require=list(mkt_by_fqme),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            pp_msgs: list[BrokerdPosition] = []
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -30,6 +30,7 @@ subsys: str = 'piker.clearing'
 | 
			
		|||
 | 
			
		||||
log = get_logger(subsys)
 | 
			
		||||
 | 
			
		||||
# TODO, oof doesn't this ignore the `loglevel` then???
 | 
			
		||||
get_console_log = partial(
 | 
			
		||||
    get_console_log,
 | 
			
		||||
    name=subsys,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -140,11 +140,10 @@ def pikerd(
 | 
			
		|||
 | 
			
		||||
        if pdb:
 | 
			
		||||
            log.warning((
 | 
			
		||||
                "\n"
 | 
			
		||||
                "!!! YOU HAVE ENABLED DAEMON DEBUG MODE !!!\n"
 | 
			
		||||
                "When a `piker` daemon crashes it will block the "
 | 
			
		||||
                "task-thread until resumed from console!\n"
 | 
			
		||||
                "\n"
 | 
			
		||||
                '\n\n'
 | 
			
		||||
                '!!! YOU HAVE ENABLED DAEMON DEBUG MODE !!!\n'
 | 
			
		||||
                'When a `piker` daemon crashes it will block the '
 | 
			
		||||
                'task-thread until resumed from console!\n'
 | 
			
		||||
            ))
 | 
			
		||||
 | 
			
		||||
        # service-actor registry endpoint socket-address set
 | 
			
		||||
| 
						 | 
				
			
			@ -177,7 +176,7 @@ def pikerd(
 | 
			
		|||
        from .. import service
 | 
			
		||||
 | 
			
		||||
        async def main():
 | 
			
		||||
            service_mngr: service.Services
 | 
			
		||||
            service_mngr: service.ServiceMngr
 | 
			
		||||
 | 
			
		||||
            async with (
 | 
			
		||||
                service.open_pikerd(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -104,14 +104,15 @@ def get_app_dir(
 | 
			
		|||
    # `tractor`) with the testing dir and check for it whenever we
 | 
			
		||||
    # detect `pytest` is being used (which it isn't under normal
 | 
			
		||||
    # operation).
 | 
			
		||||
    if "pytest" in sys.modules:
 | 
			
		||||
        import tractor
 | 
			
		||||
        actor = tractor.current_actor(err_on_no_runtime=False)
 | 
			
		||||
        if actor:  # runtime is up
 | 
			
		||||
            rvs = tractor._state._runtime_vars
 | 
			
		||||
            testdirpath = Path(rvs['piker_vars']['piker_test_dir'])
 | 
			
		||||
            assert testdirpath.exists(), 'piker test harness might be borked!?'
 | 
			
		||||
            app_name = str(testdirpath)
 | 
			
		||||
    # if "pytest" in sys.modules:
 | 
			
		||||
    #     import tractor
 | 
			
		||||
    #     actor = tractor.current_actor(err_on_no_runtime=False)
 | 
			
		||||
    #     if actor:  # runtime is up
 | 
			
		||||
    #         rvs = tractor._state._runtime_vars
 | 
			
		||||
    #         import pdbp; pdbp.set_trace()
 | 
			
		||||
    #         testdirpath = Path(rvs['piker_vars']['piker_test_dir'])
 | 
			
		||||
    #         assert testdirpath.exists(), 'piker test harness might be borked!?'
 | 
			
		||||
    #         app_name = str(testdirpath)
 | 
			
		||||
 | 
			
		||||
    if platform.system() == 'Windows':
 | 
			
		||||
        key = "APPDATA" if roaming else "LOCALAPPDATA"
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -95,6 +95,12 @@ class Sampler:
 | 
			
		|||
    # history loading.
 | 
			
		||||
    incr_task_cs: trio.CancelScope | None = None
 | 
			
		||||
 | 
			
		||||
    bcast_errors: tuple[Exception] = (
 | 
			
		||||
        trio.BrokenResourceError,
 | 
			
		||||
        trio.ClosedResourceError,
 | 
			
		||||
        trio.EndOfChannel,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # holds all the ``tractor.Context`` remote subscriptions for
 | 
			
		||||
    # a particular sample period increment event: all subscribers are
 | 
			
		||||
    # notified on a step.
 | 
			
		||||
| 
						 | 
				
			
			@ -258,14 +264,15 @@ class Sampler:
 | 
			
		|||
        subs: set
 | 
			
		||||
        last_ts, subs = pair
 | 
			
		||||
 | 
			
		||||
        task = trio.lowlevel.current_task()
 | 
			
		||||
        log.debug(
 | 
			
		||||
            f'SUBS {self.subscribers}\n'
 | 
			
		||||
            f'PAIR {pair}\n'
 | 
			
		||||
            f'TASK: {task}: {id(task)}\n'
 | 
			
		||||
            f'broadcasting {period_s} -> {last_ts}\n'
 | 
			
		||||
        # NOTE, for debugging pub-sub issues
 | 
			
		||||
        # task = trio.lowlevel.current_task()
 | 
			
		||||
        # log.debug(
 | 
			
		||||
        #     f'AlL-SUBS@{period_s!r}: {self.subscribers}\n'
 | 
			
		||||
        #     f'PAIR: {pair}\n'
 | 
			
		||||
        #     f'TASK: {task}: {id(task)}\n'
 | 
			
		||||
        #     f'broadcasting {period_s} -> {last_ts}\n'
 | 
			
		||||
        #     f'consumers: {subs}'
 | 
			
		||||
        )
 | 
			
		||||
        # )
 | 
			
		||||
        borked: set[MsgStream] = set()
 | 
			
		||||
        sent: set[MsgStream] = set()
 | 
			
		||||
        while True:
 | 
			
		||||
| 
						 | 
				
			
			@ -282,12 +289,11 @@ class Sampler:
 | 
			
		|||
                        await stream.send(msg)
 | 
			
		||||
                        sent.add(stream)
 | 
			
		||||
 | 
			
		||||
                    except (
 | 
			
		||||
                        trio.BrokenResourceError,
 | 
			
		||||
                        trio.ClosedResourceError
 | 
			
		||||
                    ):
 | 
			
		||||
                    except self.bcast_errors as err:
 | 
			
		||||
                        log.error(
 | 
			
		||||
                            f'{stream._ctx.chan.uid} dropped connection'
 | 
			
		||||
                            f'Connection dropped for IPC ctx\n'
 | 
			
		||||
                            f'{stream._ctx}\n\n'
 | 
			
		||||
                            f'Due to {type(err)}'
 | 
			
		||||
                        )
 | 
			
		||||
                        borked.add(stream)
 | 
			
		||||
                else:
 | 
			
		||||
| 
						 | 
				
			
			@ -394,7 +400,8 @@ async def register_with_sampler(
 | 
			
		|||
                finally:
 | 
			
		||||
                    if (
 | 
			
		||||
                        sub_for_broadcasts
 | 
			
		||||
                        and subs
 | 
			
		||||
                        and
 | 
			
		||||
                        subs
 | 
			
		||||
                    ):
 | 
			
		||||
                        try:
 | 
			
		||||
                            subs.remove(stream)
 | 
			
		||||
| 
						 | 
				
			
			@ -561,8 +568,7 @@ async def open_sample_stream(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def sample_and_broadcast(
 | 
			
		||||
 | 
			
		||||
    bus: _FeedsBus,  # noqa
 | 
			
		||||
    bus: _FeedsBus,
 | 
			
		||||
    rt_shm: ShmArray,
 | 
			
		||||
    hist_shm: ShmArray,
 | 
			
		||||
    quote_stream: trio.abc.ReceiveChannel,
 | 
			
		||||
| 
						 | 
				
			
			@ -582,11 +588,33 @@ async def sample_and_broadcast(
 | 
			
		|||
 | 
			
		||||
    overruns = Counter()
 | 
			
		||||
 | 
			
		||||
    # NOTE, only used for debugging live-data-feed issues, though
 | 
			
		||||
    # this should be resolved more correctly in the future using the
 | 
			
		||||
    # new typed-msgspec feats of `tractor`!
 | 
			
		||||
    #
 | 
			
		||||
    # XXX, a multiline nested `dict` formatter (since rn quote-msgs
 | 
			
		||||
    # are just that).
 | 
			
		||||
    # pfmt: Callable[[str], str] = mk_repr()
 | 
			
		||||
 | 
			
		||||
    # iterate stream delivered by broker
 | 
			
		||||
    async for quotes in quote_stream:
 | 
			
		||||
        # print(quotes)
 | 
			
		||||
 | 
			
		||||
        # TODO: ``numba`` this!
 | 
			
		||||
        # XXX WARNING XXX only enable for debugging bc ow can cost
 | 
			
		||||
        # ALOT of perf with HF-feedz!!!
 | 
			
		||||
        #
 | 
			
		||||
        # log.info(
 | 
			
		||||
        #     'Rx live quotes:\n'
 | 
			
		||||
        #     f'{pfmt(quotes)}'
 | 
			
		||||
        # )
 | 
			
		||||
 | 
			
		||||
        # TODO,
 | 
			
		||||
        # -[ ] `numba` or `cython`-nize this loop possibly?
 | 
			
		||||
        #  |_alternatively could we do it in rust somehow by upacking
 | 
			
		||||
        #    arrow msgs instead of using `msgspec`?
 | 
			
		||||
        # -[ ] use `msgspec.Struct` support in new typed-msging from
 | 
			
		||||
        #     `tractor` to ensure only allowed msgs are transmitted?
 | 
			
		||||
        #
 | 
			
		||||
        for broker_symbol, quote in quotes.items():
 | 
			
		||||
            # TODO: in theory you can send the IPC msg *before* writing
 | 
			
		||||
            # to the sharedmem array to decrease latency, however, that
 | 
			
		||||
| 
						 | 
				
			
			@ -659,6 +687,21 @@ async def sample_and_broadcast(
 | 
			
		|||
            sub_key: str = broker_symbol.lower()
 | 
			
		||||
            subs: set[Sub] = bus.get_subs(sub_key)
 | 
			
		||||
 | 
			
		||||
            # TODO, figure out how to make this useful whilst
 | 
			
		||||
            # incoporating feed "pausing" ..
 | 
			
		||||
            #
 | 
			
		||||
            # if not subs:
 | 
			
		||||
            #     all_bs_fqmes: list[str] = list(
 | 
			
		||||
            #         bus._subscribers.keys()
 | 
			
		||||
            #     )
 | 
			
		||||
            #     log.warning(
 | 
			
		||||
            #         f'No subscribers for {brokername!r} live-quote ??\n'
 | 
			
		||||
            #         f'broker_symbol: {broker_symbol}\n\n'
 | 
			
		||||
 | 
			
		||||
            #         f'Maybe the backend-sys symbol does not match one of,\n'
 | 
			
		||||
            #         f'{pfmt(all_bs_fqmes)}\n'
 | 
			
		||||
            #     )
 | 
			
		||||
 | 
			
		||||
            # NOTE: by default the broker backend doesn't append
 | 
			
		||||
            # it's own "name" into the fqme schema (but maybe it
 | 
			
		||||
            # should?) so we have to manually generate the correct
 | 
			
		||||
| 
						 | 
				
			
			@ -728,18 +771,14 @@ async def sample_and_broadcast(
 | 
			
		|||
                        if lags > 10:
 | 
			
		||||
                            await tractor.pause()
 | 
			
		||||
 | 
			
		||||
                except (
 | 
			
		||||
                    trio.BrokenResourceError,
 | 
			
		||||
                    trio.ClosedResourceError,
 | 
			
		||||
                    trio.EndOfChannel,
 | 
			
		||||
                ):
 | 
			
		||||
                except Sampler.bcast_errors as ipc_err:
 | 
			
		||||
                    ctx: Context = ipc._ctx
 | 
			
		||||
                    chan: Channel = ctx.chan
 | 
			
		||||
                    if ctx:
 | 
			
		||||
                        log.warning(
 | 
			
		||||
                            'Dropped `brokerd`-quotes-feed connection:\n'
 | 
			
		||||
                            f'{broker_symbol}:'
 | 
			
		||||
                            f'{ctx.cid}@{chan.uid}'
 | 
			
		||||
                            f'Dropped `brokerd`-feed for {broker_symbol!r} due to,\n'
 | 
			
		||||
                            f'x>) {ctx.cid}@{chan.uid}'
 | 
			
		||||
                            f'|_{ipc_err!r}\n\n'
 | 
			
		||||
                        )
 | 
			
		||||
                    if sub.throttle_rate:
 | 
			
		||||
                        assert ipc._closed
 | 
			
		||||
| 
						 | 
				
			
			@ -756,12 +795,11 @@ async def sample_and_broadcast(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
async def uniform_rate_send(
 | 
			
		||||
 | 
			
		||||
    rate: float,
 | 
			
		||||
    quote_stream: trio.abc.ReceiveChannel,
 | 
			
		||||
    stream: MsgStream,
 | 
			
		||||
 | 
			
		||||
    task_status: TaskStatus = trio.TASK_STATUS_IGNORED,
 | 
			
		||||
    task_status: TaskStatus[None] = trio.TASK_STATUS_IGNORED,
 | 
			
		||||
 | 
			
		||||
) -> None:
 | 
			
		||||
    '''
 | 
			
		||||
| 
						 | 
				
			
			@ -779,13 +817,16 @@ async def uniform_rate_send(
 | 
			
		|||
    https://gist.github.com/njsmith/7ea44ec07e901cb78ebe1dd8dd846cb9
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    # TODO: compute the approx overhead latency per cycle
 | 
			
		||||
    left_to_sleep = throttle_period = 1/rate - 0.000616
 | 
			
		||||
    # ?TODO? dynamically compute the **actual** approx overhead latency per cycle
 | 
			
		||||
    # instead of this magic # bidinezz?
 | 
			
		||||
    throttle_period: float = 1/rate - 0.000616
 | 
			
		||||
    left_to_sleep: float = throttle_period
 | 
			
		||||
 | 
			
		||||
    # send cycle state
 | 
			
		||||
    first_quote: dict|None
 | 
			
		||||
    first_quote = last_quote = None
 | 
			
		||||
    last_send = time.time()
 | 
			
		||||
    diff = 0
 | 
			
		||||
    last_send: float = time.time()
 | 
			
		||||
    diff: float = 0
 | 
			
		||||
 | 
			
		||||
    task_status.started()
 | 
			
		||||
    ticks_by_type: dict[
 | 
			
		||||
| 
						 | 
				
			
			@ -796,22 +837,28 @@ async def uniform_rate_send(
 | 
			
		|||
    clear_types = _tick_groups['clears']
 | 
			
		||||
 | 
			
		||||
    while True:
 | 
			
		||||
 | 
			
		||||
        # compute the remaining time to sleep for this throttled cycle
 | 
			
		||||
        left_to_sleep = throttle_period - diff
 | 
			
		||||
        left_to_sleep: float = throttle_period - diff
 | 
			
		||||
 | 
			
		||||
        if left_to_sleep > 0:
 | 
			
		||||
            cs: trio.CancelScope
 | 
			
		||||
            with trio.move_on_after(left_to_sleep) as cs:
 | 
			
		||||
                sym: str
 | 
			
		||||
                last_quote: dict
 | 
			
		||||
                try:
 | 
			
		||||
                    sym, last_quote = await quote_stream.receive()
 | 
			
		||||
                except trio.EndOfChannel:
 | 
			
		||||
                    log.exception(f"feed for {stream} ended?")
 | 
			
		||||
                    log.exception(
 | 
			
		||||
                        f'Live stream for feed for ended?\n'
 | 
			
		||||
                        f'<=c\n'
 | 
			
		||||
                        f'  |_[{stream!r}\n'
 | 
			
		||||
                    )
 | 
			
		||||
                    break
 | 
			
		||||
 | 
			
		||||
                diff = time.time() - last_send
 | 
			
		||||
                diff: float = time.time() - last_send
 | 
			
		||||
 | 
			
		||||
                if not first_quote:
 | 
			
		||||
                    first_quote = last_quote
 | 
			
		||||
                    first_quote: float = last_quote
 | 
			
		||||
                    # first_quote['tbt'] = ticks_by_type
 | 
			
		||||
 | 
			
		||||
                if (throttle_period - diff) > 0:
 | 
			
		||||
| 
						 | 
				
			
			@ -872,7 +919,9 @@ async def uniform_rate_send(
 | 
			
		|||
        # TODO: now if only we could sync this to the display
 | 
			
		||||
        # rate timing exactly lul
 | 
			
		||||
        try:
 | 
			
		||||
            await stream.send({sym: first_quote})
 | 
			
		||||
            await stream.send({
 | 
			
		||||
                sym: first_quote
 | 
			
		||||
            })
 | 
			
		||||
        except tractor.RemoteActorError as rme:
 | 
			
		||||
            if rme.type is not tractor._exceptions.StreamOverrun:
 | 
			
		||||
                raise
 | 
			
		||||
| 
						 | 
				
			
			@ -883,19 +932,28 @@ async def uniform_rate_send(
 | 
			
		|||
                f'{sym}:{ctx.cid}@{chan.uid}'
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
        except (
 | 
			
		||||
            # NOTE: any of these can be raised by ``tractor``'s IPC
 | 
			
		||||
        # NOTE: any of these can be raised by `tractor`'s IPC
 | 
			
		||||
        # transport-layer and we want to be highly resilient
 | 
			
		||||
        # to consumers which crash or lose network connection.
 | 
			
		||||
        # I.e. we **DO NOT** want to crash and propagate up to
 | 
			
		||||
        # ``pikerd`` these kinds of errors!
 | 
			
		||||
            trio.ClosedResourceError,
 | 
			
		||||
            trio.BrokenResourceError,
 | 
			
		||||
        except (
 | 
			
		||||
            ConnectionResetError,
 | 
			
		||||
        ):
 | 
			
		||||
        ) + Sampler.bcast_errors as ipc_err:
 | 
			
		||||
            match ipc_err:
 | 
			
		||||
                case trio.EndOfChannel():
 | 
			
		||||
                    log.info(
 | 
			
		||||
                        f'{stream} terminated by peer,\n'
 | 
			
		||||
                        f'{ipc_err!r}'
 | 
			
		||||
                    )
 | 
			
		||||
                case _:
 | 
			
		||||
                    # if the feed consumer goes down then drop
 | 
			
		||||
                    # out of this rate limiter
 | 
			
		||||
            log.warning(f'{stream} closed')
 | 
			
		||||
                    log.warning(
 | 
			
		||||
                        f'{stream} closed due to,\n'
 | 
			
		||||
                        f'{ipc_err!r}'
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
            await stream.aclose()
 | 
			
		||||
            return
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -31,6 +31,7 @@ from pathlib import Path
 | 
			
		|||
from pprint import pformat
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any,
 | 
			
		||||
    Callable,
 | 
			
		||||
    Sequence,
 | 
			
		||||
    Hashable,
 | 
			
		||||
    TYPE_CHECKING,
 | 
			
		||||
| 
						 | 
				
			
			@ -56,7 +57,7 @@ from piker.brokers import (
 | 
			
		|||
)
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ..accounting import (
 | 
			
		||||
    from piker.accounting import (
 | 
			
		||||
        Asset,
 | 
			
		||||
        MktPair,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -149,19 +150,36 @@ class SymbologyCache(Struct):
 | 
			
		|||
                    'Implement `Client.get_assets()`!'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if get_mkt_pairs := getattr(client, 'get_mkt_pairs', None):
 | 
			
		||||
            get_mkt_pairs: Callable|None = getattr(
 | 
			
		||||
                client,
 | 
			
		||||
                'get_mkt_pairs',
 | 
			
		||||
                None,
 | 
			
		||||
            )
 | 
			
		||||
            if not get_mkt_pairs:
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    'No symbology cache `Pair` support for `{provider}`..\n'
 | 
			
		||||
                    'Implement `Client.get_mkt_pairs()`!'
 | 
			
		||||
                )
 | 
			
		||||
                return self
 | 
			
		||||
 | 
			
		||||
            pairs: dict[str, Struct] = await get_mkt_pairs()
 | 
			
		||||
                for bs_fqme, pair in pairs.items():
 | 
			
		||||
            if not pairs:
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    'No pairs from intial {provider!r} sym-cache request?\n\n'
 | 
			
		||||
                    '`Client.get_mkt_pairs()` -> {pairs!r} ?'
 | 
			
		||||
                )
 | 
			
		||||
                return self
 | 
			
		||||
 | 
			
		||||
                    # NOTE: every backend defined pair should
 | 
			
		||||
                    # declare it's ns path for roundtrip
 | 
			
		||||
                    # serialization lookup.
 | 
			
		||||
            for bs_fqme, pair in pairs.items():
 | 
			
		||||
                if not getattr(pair, 'ns_path', None):
 | 
			
		||||
                    # XXX: every backend defined pair must declare
 | 
			
		||||
                    # a `.ns_path: tractor.NamespacePath` to enable
 | 
			
		||||
                    # roundtrip serialization lookup from a local
 | 
			
		||||
                    # cache file.
 | 
			
		||||
                    raise TypeError(
 | 
			
		||||
                        f'Pair-struct for {self.mod.name} MUST define a '
 | 
			
		||||
                            '`.ns_path: str`!\n'
 | 
			
		||||
                            f'{pair}'
 | 
			
		||||
                        '`.ns_path: str`!\n\n'
 | 
			
		||||
                        f'{pair!r}'
 | 
			
		||||
                    )
 | 
			
		||||
 | 
			
		||||
                entry = await self.mod.get_mkt_info(pair.bs_fqme)
 | 
			
		||||
| 
						 | 
				
			
			@ -195,12 +213,6 @@ class SymbologyCache(Struct):
 | 
			
		|||
                pair,
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    'No symbology cache `Pair` support for `{provider}`..\n'
 | 
			
		||||
                    'Implement `Client.get_mkt_pairs()`!'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
        return self
 | 
			
		||||
 | 
			
		||||
    @classmethod
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -273,7 +273,7 @@ async def _reconnect_forever(
 | 
			
		|||
                nobsws._connected.set()
 | 
			
		||||
                await trio.sleep_forever()
 | 
			
		||||
        except HandshakeError:
 | 
			
		||||
            log.exception(f'Retrying connection')
 | 
			
		||||
            log.exception('Retrying connection')
 | 
			
		||||
 | 
			
		||||
        # ws & nursery block ends
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -359,8 +359,8 @@ async def open_autorecon_ws(
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
JSONRPC response-request style machinery for transparent multiplexing of msgs
 | 
			
		||||
over a NoBsWs.
 | 
			
		||||
JSONRPC response-request style machinery for transparent multiplexing
 | 
			
		||||
of msgs over a NoBsWs.
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -377,16 +377,25 @@ async def open_jsonrpc_session(
 | 
			
		|||
    url: str,
 | 
			
		||||
    start_id: int = 0,
 | 
			
		||||
    response_type: type = JSONRPCResult,
 | 
			
		||||
    request_type: Optional[type] = None,
 | 
			
		||||
    request_hook: Optional[Callable] = None,
 | 
			
		||||
    error_hook: Optional[Callable] = None,
 | 
			
		||||
) -> Callable[[str, dict], dict]:
 | 
			
		||||
    '''
 | 
			
		||||
    Init a json-RPC-over-websocket connection to the provided `url`.
 | 
			
		||||
 | 
			
		||||
    A `json_rpc: Callable[[str, dict], dict` is delivered to the
 | 
			
		||||
    caller for sending requests and a bg-`trio.Task` handles
 | 
			
		||||
    processing of response msgs including error reporting/raising in
 | 
			
		||||
    the parent/caller task.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    # NOTE, store all request msgs so we can raise errors on the
 | 
			
		||||
    # caller side!
 | 
			
		||||
    req_msgs: dict[int, dict] = {}
 | 
			
		||||
 | 
			
		||||
    async with (
 | 
			
		||||
        trio.open_nursery() as n,
 | 
			
		||||
        trio.open_nursery() as tn,
 | 
			
		||||
        open_autorecon_ws(url) as ws
 | 
			
		||||
    ):
 | 
			
		||||
        rpc_id: Iterable = count(start_id)
 | 
			
		||||
        rpc_id: Iterable[int] = count(start_id)
 | 
			
		||||
        rpc_results: dict[int, dict] = {}
 | 
			
		||||
 | 
			
		||||
        async def json_rpc(method: str, params: dict) -> dict:
 | 
			
		||||
| 
						 | 
				
			
			@ -394,27 +403,41 @@ async def open_jsonrpc_session(
 | 
			
		|||
            perform a json rpc call and wait for the result, raise exception in
 | 
			
		||||
            case of error field present on response
 | 
			
		||||
            '''
 | 
			
		||||
            nonlocal req_msgs
 | 
			
		||||
 | 
			
		||||
            req_id: int = next(rpc_id)
 | 
			
		||||
            msg = {
 | 
			
		||||
                'jsonrpc': '2.0',
 | 
			
		||||
                'id': next(rpc_id),
 | 
			
		||||
                'id': req_id,
 | 
			
		||||
                'method': method,
 | 
			
		||||
                'params': params
 | 
			
		||||
            }
 | 
			
		||||
            _id = msg['id']
 | 
			
		||||
 | 
			
		||||
            rpc_results[_id] = {
 | 
			
		||||
            result = rpc_results[_id] = {
 | 
			
		||||
                'result': None,
 | 
			
		||||
                'event': trio.Event()
 | 
			
		||||
                'error': None,
 | 
			
		||||
                'event': trio.Event(),  # signal caller resp arrived
 | 
			
		||||
            }
 | 
			
		||||
            req_msgs[_id] = msg
 | 
			
		||||
 | 
			
		||||
            await ws.send_msg(msg)
 | 
			
		||||
 | 
			
		||||
            # wait for reponse before unblocking requester code
 | 
			
		||||
            await rpc_results[_id]['event'].wait()
 | 
			
		||||
 | 
			
		||||
            ret = rpc_results[_id]['result']
 | 
			
		||||
 | 
			
		||||
            if (maybe_result := result['result']):
 | 
			
		||||
                ret = maybe_result
 | 
			
		||||
                del rpc_results[_id]
 | 
			
		||||
 | 
			
		||||
            else:
 | 
			
		||||
                err = result['error']
 | 
			
		||||
                raise Exception(
 | 
			
		||||
                    f'JSONRPC request failed\n'
 | 
			
		||||
                    f'req: {msg}\n'
 | 
			
		||||
                    f'resp: {err}\n'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            if ret.error is not None:
 | 
			
		||||
                raise Exception(json.dumps(ret.error, indent=4))
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -428,6 +451,7 @@ async def open_jsonrpc_session(
 | 
			
		|||
            the server side.
 | 
			
		||||
 | 
			
		||||
            '''
 | 
			
		||||
            nonlocal req_msgs
 | 
			
		||||
            async for msg in ws:
 | 
			
		||||
                match msg:
 | 
			
		||||
                    case {
 | 
			
		||||
| 
						 | 
				
			
			@ -451,19 +475,28 @@ async def open_jsonrpc_session(
 | 
			
		|||
                        'params': _,
 | 
			
		||||
                    }:
 | 
			
		||||
                        log.debug(f'Recieved\n{msg}')
 | 
			
		||||
                        if request_hook:
 | 
			
		||||
                            await request_hook(request_type(**msg))
 | 
			
		||||
 | 
			
		||||
                    case {
 | 
			
		||||
                        'error': error
 | 
			
		||||
                    }:
 | 
			
		||||
                        log.warning(f'Recieved\n{error}')
 | 
			
		||||
                        if error_hook:
 | 
			
		||||
                            await error_hook(response_type(**msg))
 | 
			
		||||
                        # retreive orig request msg, set error
 | 
			
		||||
                        # response in original "result" msg,
 | 
			
		||||
                        # THEN FINALLY set the event to signal caller
 | 
			
		||||
                        # to raise the error in the parent task.
 | 
			
		||||
                        req_id: int = error['id']
 | 
			
		||||
                        req_msg: dict = req_msgs[req_id]
 | 
			
		||||
                        result: dict = rpc_results[req_id]
 | 
			
		||||
                        result['error'] = error
 | 
			
		||||
                        result['event'].set()
 | 
			
		||||
                        log.error(
 | 
			
		||||
                            f'JSONRPC request failed\n'
 | 
			
		||||
                            f'req: {req_msg}\n'
 | 
			
		||||
                            f'resp: {error}\n'
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                    case _:
 | 
			
		||||
                        log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
 | 
			
		||||
 | 
			
		||||
        n.start_soon(recv_task)
 | 
			
		||||
        tn.start_soon(recv_task)
 | 
			
		||||
        yield json_rpc
 | 
			
		||||
        n.cancel_scope.cancel()
 | 
			
		||||
        tn.cancel_scope.cancel()
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -786,7 +786,6 @@ async def install_brokerd_search(
 | 
			
		|||
 | 
			
		||||
@acm
 | 
			
		||||
async def maybe_open_feed(
 | 
			
		||||
 | 
			
		||||
    fqmes: list[str],
 | 
			
		||||
    loglevel: str | None = None,
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -840,13 +839,12 @@ async def maybe_open_feed(
 | 
			
		|||
 | 
			
		||||
@acm
 | 
			
		||||
async def open_feed(
 | 
			
		||||
 | 
			
		||||
    fqmes: list[str],
 | 
			
		||||
 | 
			
		||||
    loglevel: str | None = None,
 | 
			
		||||
    loglevel: str|None = None,
 | 
			
		||||
    allow_overruns: bool = True,
 | 
			
		||||
    start_stream: bool = True,
 | 
			
		||||
    tick_throttle: float | None = None,  # Hz
 | 
			
		||||
    tick_throttle: float|None = None,  # Hz
 | 
			
		||||
 | 
			
		||||
    allow_remote_ctl_ui: bool = False,
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -36,10 +36,10 @@ from ._sharedmem import (
 | 
			
		|||
    ShmArray,
 | 
			
		||||
    _Token,
 | 
			
		||||
)
 | 
			
		||||
from piker.accounting import MktPair
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from ..accounting import MktPair
 | 
			
		||||
    from .feed import Feed
 | 
			
		||||
    from piker.data.feed import Feed
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Flume(Struct):
 | 
			
		||||
| 
						 | 
				
			
			@ -82,7 +82,7 @@ class Flume(Struct):
 | 
			
		|||
 | 
			
		||||
    # TODO: do we need this really if we can pull the `Portal` from
 | 
			
		||||
    # ``tractor``'s internals?
 | 
			
		||||
    feed: Feed | None = None
 | 
			
		||||
    feed: Feed|None = None
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def rt_shm(self) -> ShmArray:
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -113,9 +113,9 @@ def validate_backend(
 | 
			
		|||
            )
 | 
			
		||||
            if ep is None:
 | 
			
		||||
                log.warning(
 | 
			
		||||
                    f'Provider backend {mod.name} is missing '
 | 
			
		||||
                    f'{daemon_name} support :(\n'
 | 
			
		||||
                    f'The following endpoint is missing: {name}'
 | 
			
		||||
                    f'Provider backend {mod.name!r} is missing '
 | 
			
		||||
                    f'{daemon_name!r} support?\n'
 | 
			
		||||
                    f'|_module endpoint-func missing: {name!r}\n'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
    inits: list[
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										30
									
								
								piker/log.py
								
								
								
								
							
							
						
						
									
										30
									
								
								piker/log.py
								
								
								
								
							| 
						 | 
				
			
			@ -19,6 +19,10 @@ Log like a forester!
 | 
			
		|||
"""
 | 
			
		||||
import logging
 | 
			
		||||
import json
 | 
			
		||||
import reprlib
 | 
			
		||||
from typing import (
 | 
			
		||||
    Callable,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
import tractor
 | 
			
		||||
from pygments import (
 | 
			
		||||
| 
						 | 
				
			
			@ -84,3 +88,29 @@ def colorize_json(
 | 
			
		|||
        # likeable styles: algol_nu, tango, monokai
 | 
			
		||||
        formatters.TerminalTrueColorFormatter(style=style)
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
# TODO, eventually defer to the version in `modden` once
 | 
			
		||||
# it becomes a dep!
 | 
			
		||||
def mk_repr(
 | 
			
		||||
    **repr_kws,
 | 
			
		||||
) -> Callable[[str], str]:
 | 
			
		||||
    '''
 | 
			
		||||
    Allocate and deliver a `repr.Repr` instance with provided input
 | 
			
		||||
    settings using the std-lib's `reprlib` mod,
 | 
			
		||||
     * https://docs.python.org/3/library/reprlib.html
 | 
			
		||||
 | 
			
		||||
    ------ Ex. ------
 | 
			
		||||
    An up to 6-layer-nested `dict` as multi-line:
 | 
			
		||||
    - https://stackoverflow.com/a/79102479
 | 
			
		||||
    - https://docs.python.org/3/library/reprlib.html#reprlib.Repr.maxlevel
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def_kws: dict[str, int] = dict(
 | 
			
		||||
        indent=2,
 | 
			
		||||
        maxlevel=6,  # recursion levels
 | 
			
		||||
        maxstring=66,  # match editor line-len limit
 | 
			
		||||
    )
 | 
			
		||||
    def_kws |= repr_kws
 | 
			
		||||
    reprr = reprlib.Repr(**def_kws)
 | 
			
		||||
    return reprr.repr
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -119,6 +119,10 @@ async def open_piker_runtime(
 | 
			
		|||
                # spawn other specialized daemons I think?
 | 
			
		||||
                enable_modules=enable_modules,
 | 
			
		||||
 | 
			
		||||
                # TODO: how to configure this?
 | 
			
		||||
                # keep it on by default if debug mode is set?
 | 
			
		||||
                maybe_enable_greenback=False,
 | 
			
		||||
 | 
			
		||||
                **tractor_kwargs,
 | 
			
		||||
            ) as actor,
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -386,6 +386,8 @@ def ldshm(
 | 
			
		|||
            open_annot_ctl() as actl,
 | 
			
		||||
        ):
 | 
			
		||||
            shm_df: pl.DataFrame | None = None
 | 
			
		||||
            tf2aids: dict[float, dict] = {}
 | 
			
		||||
 | 
			
		||||
            for (
 | 
			
		||||
                shmfile,
 | 
			
		||||
                shm,
 | 
			
		||||
| 
						 | 
				
			
			@ -526,16 +528,17 @@ def ldshm(
 | 
			
		|||
                            new_df,
 | 
			
		||||
                            step_gaps,
 | 
			
		||||
                        )
 | 
			
		||||
 | 
			
		||||
                        # last chance manual overwrites in REPL
 | 
			
		||||
                        await tractor.pause()
 | 
			
		||||
                        # await tractor.pause()
 | 
			
		||||
                        assert aids
 | 
			
		||||
                        tf2aids[period_s] = aids
 | 
			
		||||
 | 
			
		||||
                else:
 | 
			
		||||
                    # allow interaction even when no ts problems.
 | 
			
		||||
                    await tractor.pause()
 | 
			
		||||
                    # assert not diff
 | 
			
		||||
                    assert not diff
 | 
			
		||||
 | 
			
		||||
            await tractor.pause()
 | 
			
		||||
            log.info('Exiting TSP shm anal-izer!')
 | 
			
		||||
 | 
			
		||||
            if shm_df is None:
 | 
			
		||||
                log.error(
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -161,7 +161,13 @@ class NativeStorageClient:
 | 
			
		|||
 | 
			
		||||
    def index_files(self):
 | 
			
		||||
        for path in self._datadir.iterdir():
 | 
			
		||||
            if path.name in {'borked', 'expired',}:
 | 
			
		||||
            if (
 | 
			
		||||
                path.is_dir()
 | 
			
		||||
                or
 | 
			
		||||
                '.parquet' not in str(path)
 | 
			
		||||
                # or
 | 
			
		||||
                # path.name in {'borked', 'expired',}
 | 
			
		||||
            ):
 | 
			
		||||
                continue
 | 
			
		||||
 | 
			
		||||
            key: str = path.name.rstrip('.parquet')
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -458,13 +458,15 @@ async def start_backfill(
 | 
			
		|||
                    'bf_until <- last_start_dt:\n'
 | 
			
		||||
                    f'{backfill_until_dt} <- {last_start_dt}\n'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # ugh, what's a better way?
 | 
			
		||||
                # TODO: fwiw, we probably want a way to signal a throttle
 | 
			
		||||
                # condition (eg. with ib) so that we can halt the
 | 
			
		||||
                # request loop until the condition is resolved?
 | 
			
		||||
                if timeframe > 1:
 | 
			
		||||
                    await tractor.pause()
 | 
			
		||||
                # UGH: what's a better way?
 | 
			
		||||
                # TODO: backends are responsible for being correct on
 | 
			
		||||
                # this right!?
 | 
			
		||||
                # -[ ] in the `ib` case we could maybe offer some way
 | 
			
		||||
                #     to halt the request loop until the condition is
 | 
			
		||||
                #     resolved or should the backend be entirely in
 | 
			
		||||
                #     charge of solving such faults? yes, right?
 | 
			
		||||
                # if timeframe > 1:
 | 
			
		||||
                #     await tractor.pause()
 | 
			
		||||
                return
 | 
			
		||||
 | 
			
		||||
            assert (
 | 
			
		||||
| 
						 | 
				
			
			@ -572,15 +574,19 @@ async def start_backfill(
 | 
			
		|||
                    f'{next_start_dt} -> {last_start_dt}'
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
                # always drop the src asset token for
 | 
			
		||||
                # NOTE, always drop the src asset token for
 | 
			
		||||
                # non-currency-pair like market types (for now)
 | 
			
		||||
                #
 | 
			
		||||
                # THAT IS, for now our table key schema is NOT
 | 
			
		||||
                # including the dst[/src] source asset token. SO,
 | 
			
		||||
                # 'tsla.nasdaq.ib' over 'tsla/usd.nasdaq.ib' for
 | 
			
		||||
                # historical reasons ONLY.
 | 
			
		||||
                if mkt.dst.atype not in {
 | 
			
		||||
                    'crypto',
 | 
			
		||||
                    'crypto_currency',
 | 
			
		||||
                    'fiat',  # a "forex pair"
 | 
			
		||||
                    'perpetual_future',  # stupid "perps" from cex land
 | 
			
		||||
                }:
 | 
			
		||||
                    # for now, our table key schema is not including
 | 
			
		||||
                    # the dst[/src] source asset token.
 | 
			
		||||
                    col_sym_key: str = mkt.get_fqme(
 | 
			
		||||
                        delim_char='',
 | 
			
		||||
                        without_src=True,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -616,6 +616,18 @@ def detect_price_gaps(
 | 
			
		|||
    # ])
 | 
			
		||||
    ...
 | 
			
		||||
 | 
			
		||||
# TODO: probably just use the null_segs impl above?
 | 
			
		||||
def detect_vlm_gaps(
 | 
			
		||||
    df: pl.DataFrame,
 | 
			
		||||
    col: str = 'volume',
 | 
			
		||||
 | 
			
		||||
) -> pl.DataFrame:
 | 
			
		||||
 | 
			
		||||
    vnull: pl.DataFrame = w_dts.filter(
 | 
			
		||||
        pl.col(col) == 0
 | 
			
		||||
    )
 | 
			
		||||
    return vnull
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def dedupe(
 | 
			
		||||
    src_df: pl.DataFrame,
 | 
			
		||||
| 
						 | 
				
			
			@ -626,7 +638,6 @@ def dedupe(
 | 
			
		|||
 | 
			
		||||
) -> tuple[
 | 
			
		||||
    pl.DataFrame,  # with dts
 | 
			
		||||
    pl.DataFrame,  # gaps
 | 
			
		||||
    pl.DataFrame,  # with deduplicated dts (aka gap/repeat removal)
 | 
			
		||||
    int,  # len diff between input and deduped
 | 
			
		||||
]:
 | 
			
		||||
| 
						 | 
				
			
			@ -639,19 +650,22 @@ def dedupe(
 | 
			
		|||
    '''
 | 
			
		||||
    wdts: pl.DataFrame = with_dts(src_df)
 | 
			
		||||
 | 
			
		||||
    # maybe sort on any time field
 | 
			
		||||
    if sort:
 | 
			
		||||
        wdts = wdts.sort(by='time')
 | 
			
		||||
        # TODO: detect out-of-order segments which were corrected!
 | 
			
		||||
        # -[ ] report in log msg
 | 
			
		||||
        # -[ ] possibly return segment sections which were moved?
 | 
			
		||||
    deduped = wdts
 | 
			
		||||
 | 
			
		||||
    # remove duplicated datetime samples/sections
 | 
			
		||||
    deduped: pl.DataFrame = wdts.unique(
 | 
			
		||||
        subset=['dt'],
 | 
			
		||||
        # subset=['dt'],
 | 
			
		||||
        subset=['time'],
 | 
			
		||||
        maintain_order=True,
 | 
			
		||||
    )
 | 
			
		||||
 | 
			
		||||
    # maybe sort on any time field
 | 
			
		||||
    if sort:
 | 
			
		||||
        deduped = deduped.sort(by='time')
 | 
			
		||||
        # TODO: detect out-of-order segments which were corrected!
 | 
			
		||||
        # -[ ] report in log msg
 | 
			
		||||
        # -[ ] possibly return segment sections which were moved?
 | 
			
		||||
 | 
			
		||||
    diff: int = (
 | 
			
		||||
        wdts.height
 | 
			
		||||
        -
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
							
								
								
									
										228
									
								
								piker/types.py
								
								
								
								
							
							
						
						
									
										228
									
								
								piker/types.py
								
								
								
								
							| 
						 | 
				
			
			@ -21,230 +21,4 @@ Extensions to built-in or (heavily used but 3rd party) friend-lib
 | 
			
		|||
types.
 | 
			
		||||
 | 
			
		||||
'''
 | 
			
		||||
from __future__ import annotations
 | 
			
		||||
from collections import UserList
 | 
			
		||||
from pprint import (
 | 
			
		||||
    saferepr,
 | 
			
		||||
)
 | 
			
		||||
from typing import Any
 | 
			
		||||
 | 
			
		||||
from msgspec import (
 | 
			
		||||
    msgpack,
 | 
			
		||||
    Struct as _Struct,
 | 
			
		||||
    structs,
 | 
			
		||||
)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class DiffDump(UserList):
 | 
			
		||||
    '''
 | 
			
		||||
    Very simple list delegator that repr() dumps (presumed) tuple
 | 
			
		||||
    elements of the form `tuple[str, Any, Any]` in a nice
 | 
			
		||||
    multi-line readable form for analyzing `Struct` diffs.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def __repr__(self) -> str:
 | 
			
		||||
        if not len(self):
 | 
			
		||||
            return super().__repr__()
 | 
			
		||||
 | 
			
		||||
        # format by displaying item pair's ``repr()`` on multiple,
 | 
			
		||||
        # indented lines such that they are more easily visually
 | 
			
		||||
        # comparable when printed to console when printed to
 | 
			
		||||
        # console.
 | 
			
		||||
        repstr: str = '[\n'
 | 
			
		||||
        for k, left, right in self:
 | 
			
		||||
            repstr += (
 | 
			
		||||
                f'({k},\n'
 | 
			
		||||
                f'\t{repr(left)},\n'
 | 
			
		||||
                f'\t{repr(right)},\n'
 | 
			
		||||
                ')\n'
 | 
			
		||||
            )
 | 
			
		||||
        repstr += ']\n'
 | 
			
		||||
        return repstr
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Struct(
 | 
			
		||||
    _Struct,
 | 
			
		||||
 | 
			
		||||
    # https://jcristharif.com/msgspec/structs.html#tagged-unions
 | 
			
		||||
    # tag='pikerstruct',
 | 
			
		||||
    # tag=True,
 | 
			
		||||
):
 | 
			
		||||
    '''
 | 
			
		||||
    A "human friendlier" (aka repl buddy) struct subtype.
 | 
			
		||||
 | 
			
		||||
    '''
 | 
			
		||||
    def _sin_props(self) -> Iterator[
 | 
			
		||||
        tuple[
 | 
			
		||||
            structs.FieldIinfo,
 | 
			
		||||
            str,
 | 
			
		||||
            Any,
 | 
			
		||||
        ]
 | 
			
		||||
    ]:
 | 
			
		||||
        '''
 | 
			
		||||
        Iterate over all non-@property fields of this struct.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        fi: structs.FieldInfo
 | 
			
		||||
        for fi in structs.fields(self):
 | 
			
		||||
            key: str = fi.name
 | 
			
		||||
            val: Any = getattr(self, key)
 | 
			
		||||
            yield fi, key, val
 | 
			
		||||
 | 
			
		||||
    def to_dict(
 | 
			
		||||
        self,
 | 
			
		||||
        include_non_members: bool = True,
 | 
			
		||||
 | 
			
		||||
    ) -> dict:
 | 
			
		||||
        '''
 | 
			
		||||
        Like it sounds.. direct delegation to:
 | 
			
		||||
        https://jcristharif.com/msgspec/api.html#msgspec.structs.asdict
 | 
			
		||||
 | 
			
		||||
        BUT, by default we pop all non-member (aka not defined as
 | 
			
		||||
        struct fields) fields by default.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        asdict: dict = structs.asdict(self)
 | 
			
		||||
        if include_non_members:
 | 
			
		||||
            return asdict
 | 
			
		||||
 | 
			
		||||
        # only return a dict of the struct members
 | 
			
		||||
        # which were provided as input, NOT anything
 | 
			
		||||
        # added as type-defined `@property` methods!
 | 
			
		||||
        sin_props: dict = {}
 | 
			
		||||
        fi: structs.FieldInfo
 | 
			
		||||
        for fi, k, v in self._sin_props():
 | 
			
		||||
            sin_props[k] = asdict[k]
 | 
			
		||||
 | 
			
		||||
        return sin_props
 | 
			
		||||
 | 
			
		||||
    def pformat(
 | 
			
		||||
        self,
 | 
			
		||||
        field_indent: int = 2,
 | 
			
		||||
        indent: int = 0,
 | 
			
		||||
 | 
			
		||||
    ) -> str:
 | 
			
		||||
        '''
 | 
			
		||||
        Recursion-safe `pprint.pformat()` style formatting of
 | 
			
		||||
        a `msgspec.Struct` for sane reading by a human using a REPL.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # global whitespace indent
 | 
			
		||||
        ws: str = ' '*indent
 | 
			
		||||
 | 
			
		||||
        # field whitespace indent
 | 
			
		||||
        field_ws: str = ' '*(field_indent + indent)
 | 
			
		||||
 | 
			
		||||
        # qtn: str = ws + self.__class__.__qualname__
 | 
			
		||||
        qtn: str = self.__class__.__qualname__
 | 
			
		||||
 | 
			
		||||
        obj_str: str = ''  # accumulator
 | 
			
		||||
        fi: structs.FieldInfo
 | 
			
		||||
        k: str
 | 
			
		||||
        v: Any
 | 
			
		||||
        for fi, k, v in self._sin_props():
 | 
			
		||||
 | 
			
		||||
            # TODO: how can we prefer `Literal['option1',  'option2,
 | 
			
		||||
            # ..]` over .__name__ == `Literal` but still get only the
 | 
			
		||||
            # latter for simple types like `str | int | None` etc..?
 | 
			
		||||
            ft: type = fi.type
 | 
			
		||||
            typ_name: str = getattr(ft, '__name__', str(ft))
 | 
			
		||||
 | 
			
		||||
            # recurse to get sub-struct's `.pformat()` output Bo
 | 
			
		||||
            if isinstance(v, Struct):
 | 
			
		||||
                val_str: str =  v.pformat(
 | 
			
		||||
                    indent=field_indent + indent,
 | 
			
		||||
                    field_indent=indent + field_indent,
 | 
			
		||||
                )
 | 
			
		||||
 | 
			
		||||
            else:  # the `pprint` recursion-safe format:
 | 
			
		||||
                # https://docs.python.org/3.11/library/pprint.html#pprint.saferepr
 | 
			
		||||
                val_str: str = saferepr(v)
 | 
			
		||||
 | 
			
		||||
            obj_str += (field_ws + f'{k}: {typ_name} = {val_str},\n')
 | 
			
		||||
 | 
			
		||||
        return (
 | 
			
		||||
            f'{qtn}(\n'
 | 
			
		||||
            f'{obj_str}'
 | 
			
		||||
            f'{ws})'
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    # TODO: use a pprint.PrettyPrinter instance around ONLY rendering
 | 
			
		||||
    # inside a known tty?
 | 
			
		||||
    # def __repr__(self) -> str:
 | 
			
		||||
    #     ...
 | 
			
		||||
 | 
			
		||||
    # __str__ = __repr__ = pformat
 | 
			
		||||
    __repr__ = pformat
 | 
			
		||||
 | 
			
		||||
    def copy(
 | 
			
		||||
        self,
 | 
			
		||||
        update: dict | None = None,
 | 
			
		||||
 | 
			
		||||
    ) -> Struct:
 | 
			
		||||
        '''
 | 
			
		||||
        Validate-typecast all self defined fields, return a copy of
 | 
			
		||||
        us with all such fields.
 | 
			
		||||
 | 
			
		||||
        NOTE: This is kinda like the default behaviour in
 | 
			
		||||
        `pydantic.BaseModel` except a copy of the object is
 | 
			
		||||
        returned making it compat with `frozen=True`.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        if update:
 | 
			
		||||
            for k, v in update.items():
 | 
			
		||||
                setattr(self, k, v)
 | 
			
		||||
 | 
			
		||||
        # NOTE: roundtrip serialize to validate
 | 
			
		||||
        # - enode to msgpack binary format,
 | 
			
		||||
        # - decode that back to a struct.
 | 
			
		||||
        return msgpack.Decoder(type=type(self)).decode(
 | 
			
		||||
            msgpack.Encoder().encode(self)
 | 
			
		||||
        )
 | 
			
		||||
 | 
			
		||||
    def typecast(
 | 
			
		||||
        self,
 | 
			
		||||
 | 
			
		||||
        # TODO: allow only casting a named subset?
 | 
			
		||||
        # fields: set[str] | None = None,
 | 
			
		||||
 | 
			
		||||
    ) -> None:
 | 
			
		||||
        '''
 | 
			
		||||
        Cast all fields using their declared type annotations
 | 
			
		||||
        (kinda like what `pydantic` does by default).
 | 
			
		||||
 | 
			
		||||
        NOTE: this of course won't work on frozen types, use
 | 
			
		||||
        ``.copy()`` above in such cases.
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        # https://jcristharif.com/msgspec/api.html#msgspec.structs.fields
 | 
			
		||||
        fi: structs.FieldInfo
 | 
			
		||||
        for fi in structs.fields(self):
 | 
			
		||||
            setattr(
 | 
			
		||||
                self,
 | 
			
		||||
                fi.name,
 | 
			
		||||
                fi.type(getattr(self, fi.name)),
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
    def __sub__(
 | 
			
		||||
        self,
 | 
			
		||||
        other: Struct,
 | 
			
		||||
 | 
			
		||||
    ) -> DiffDump[tuple[str, Any, Any]]:
 | 
			
		||||
        '''
 | 
			
		||||
        Compare fields/items key-wise and return a ``DiffDump``
 | 
			
		||||
        for easy visual REPL comparison B)
 | 
			
		||||
 | 
			
		||||
        '''
 | 
			
		||||
        diffs: DiffDump[tuple[str, Any, Any]] = DiffDump()
 | 
			
		||||
        for fi in structs.fields(self):
 | 
			
		||||
            attr_name: str = fi.name
 | 
			
		||||
            ours: Any = getattr(self, attr_name)
 | 
			
		||||
            theirs: Any = getattr(other, attr_name)
 | 
			
		||||
            if ours != theirs:
 | 
			
		||||
                diffs.append((
 | 
			
		||||
                    attr_name,
 | 
			
		||||
                    ours,
 | 
			
		||||
                    theirs,
 | 
			
		||||
                ))
 | 
			
		||||
 | 
			
		||||
        return diffs
 | 
			
		||||
from tractor.msg import Struct as Struct
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue