diff --git a/piker/data/_sampling.py b/piker/data/_sampling.py index 622a8ad1..f70e4113 100644 --- a/piker/data/_sampling.py +++ b/piker/data/_sampling.py @@ -24,12 +24,17 @@ from collections import ( Counter, defaultdict, ) +from contextlib import asynccontextmanager as acm import time from typing import ( + AsyncIterator, TYPE_CHECKING, ) import tractor +from tractor.trionics import ( + maybe_open_nursery, +) import trio from trio_typing import TaskStatus @@ -37,14 +42,20 @@ from ..log import ( get_logger, get_console_log, ) +from .._daemon import maybe_spawn_daemon if TYPE_CHECKING: - from ._sharedmem import ShmArray + from ._sharedmem import ( + ShmArray, + ) from .feed import _FeedsBus log = get_logger(__name__) +# highest frequency sample step is 1 second by default, though in +# the future we may want to support shorter periods or a dynamic style +# tick-event stream. _default_delay_s: float = 1.0 @@ -55,32 +66,50 @@ class Sampler: Manages state for sampling events, shm incrementing and sample period logic. + This non-instantiated type is meant to be a singleton within + a `samplerd` actor-service spawned once by the user wishing to + time-step sample real-time quote feeds, see + ``._daemon.maybe_open_samplerd()`` and the below + ``register_with_sampler()``. + ''' service_nursery: None | trio.Nursery = None # TODO: we could stick these in a composed type to avoid # angering the "i hate module scoped variables crowd" (yawn). - ohlcv_shms: dict[int, list[ShmArray]] = {} + ohlcv_shms: dict[float, list[ShmArray]] = {} # holds one-task-per-sample-period tasks which are spawned as-needed by # data feed requests with a given detected time step usually from # history loading. - incrementers: dict[int, trio.CancelScope] = {} + incr_task_cs: trio.CancelScope | None = None # holds all the ``tractor.Context`` remote subscriptions for # a particular sample period increment event: all subscribers are # notified on a step. - subscribers: dict[int, tractor.Context] = {} + # subscribers: dict[int, list[tractor.MsgStream]] = {} + subscribers: defaultdict[ + float, + list[ + float, + set[tractor.MsgStream] + ], + ] = defaultdict( + lambda: [ + round(time.time()), + set(), + ] + ) @classmethod async def increment_ohlc_buffer( self, - delay_s: int, + period_s: float, task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, ): ''' Task which inserts new bars into the provide shared memory array - every ``delay_s`` seconds. + every ``period_s`` seconds. This task fulfills 2 purposes: - it takes the subscribed set of shm arrays and increments them @@ -92,66 +121,74 @@ class Sampler: the underlying buffers will actually be incremented. ''' - # # wait for brokerd to signal we should start sampling - # await shm_incrementing(shm_token['shm_name']).wait() - # TODO: right now we'll spin printing bars if the last time stamp is # before a large period of no market activity. Likely the best way # to solve this is to make this task aware of the instrument's # tradable hours? - # adjust delay to compensate for trio processing time - ad = min(self.ohlcv_shms.keys()) - 0.001 - - total_s = 0 # total seconds counted - lowest = min(self.ohlcv_shms.keys()) - lowest_shm = self.ohlcv_shms[lowest][0] - ad = lowest - 0.001 + total_s: float = 0 # total seconds counted + ad = period_s - 0.001 # compensate for trio processing time with trio.CancelScope() as cs: - # register this time period step as active - self.incrementers[delay_s] = cs task_status.started(cs) + # sample step loop: + # includes broadcasting to all connected consumers on every + # new sample step as well incrementing any registered + # buffers by registered sample period. while True: - # TODO: do we want to support dynamically - # adding a "lower" lowest increment period? await trio.sleep(ad) - total_s += delay_s + total_s += period_s # increment all subscribed shm arrays # TODO: # - this in ``numba`` # - just lookup shms for this step instead of iterating? - for this_delay_s, shms in self.ohlcv_shms.items(): + + i_epoch = round(time.time()) + broadcasted: set[float] = set() + + # print(f'epoch: {i_epoch} -> REGISTRY {self.ohlcv_shms}') + for shm_period_s, shms in self.ohlcv_shms.items(): # short-circuit on any not-ready because slower sample # rate consuming shm buffers. - if total_s % this_delay_s != 0: - # print(f'skipping `{this_delay_s}s` sample update') + if total_s % shm_period_s != 0: + # print(f'skipping `{shm_period_s}s` sample update') continue + # update last epoch stamp for this period group + if shm_period_s not in broadcasted: + sub_pair = self.subscribers[shm_period_s] + sub_pair[0] = i_epoch + print(f'skipping `{shm_period_s}s` sample update') + broadcasted.add(shm_period_s) + # TODO: ``numba`` this! for shm in shms: + # print(f'UPDATE {shm_period_s}s STEP for {shm.token}') + # append new entry to buffer thus "incrementing" # the bar array = shm.array last = array[-1:][shm._write_fields].copy() + # guard against startup backfilling race with + # empty buffers. + if not last.size: + continue + (t, close) = last[0][[ 'time', 'close', ]] - next_t = t + this_delay_s - i_epoch = round(time.time()) + next_t = t + shm_period_s - if this_delay_s <= 1: + if shm_period_s <= 1: next_t = i_epoch - # print(f'epoch {shm.token["shm_name"]}: {next_t}') - # this copies non-std fields (eg. vwap) from the # last datum last[[ @@ -185,43 +222,43 @@ class Sampler: # write to the buffer shm.push(last) - await self.broadcast(delay_s, shm=lowest_shm) + # broadcast increment msg to all updated subs per period + for shm_period_s in broadcasted: + await self.broadcast( + period_s=shm_period_s, + time_stamp=i_epoch, + ) @classmethod async def broadcast( self, - delay_s: int, - shm: ShmArray | None = None, + period_s: float, + time_stamp: float | None = None, ) -> None: ''' - Broadcast the given ``shm: ShmArray``'s buffer index step to any + Broadcast the period size and last index step value to all subscribers for a given sample period. - The sent msg will include the first and last index which slice into - the buffer's non-empty data. - ''' - subs = self.subscribers.get(delay_s, ()) - first = last = -1 + pair = self.subscribers[period_s] - if shm is None: - periods = self.ohlcv_shms.keys() - # if this is an update triggered by a history update there - # might not actually be any sampling bus setup since there's - # no "live feed" active yet. - if periods: - lowest = min(periods) - shm = self.ohlcv_shms[lowest][0] - first = shm._first.value - last = shm._last.value + last_ts, subs = pair + task = trio.lowlevel.current_task() + log.debug( + f'SUBS {self.subscribers}\n' + f'PAIR {pair}\n' + f'TASK: {task}: {id(task)}\n' + f'broadcasting {period_s} -> {last_ts}\n' + # f'consumers: {subs}' + ) + borked: set[tractor.MsgStream] = set() for stream in subs: try: await stream.send({ - 'first': first, - 'last': last, - 'index': last, + 'index': time_stamp or last_ts, + 'period': period_s, }) except ( trio.BrokenResourceError, @@ -230,69 +267,232 @@ class Sampler: log.error( f'{stream._ctx.chan.uid} dropped connection' ) - try: - subs.remove(stream) - except ValueError: - log.warning( - f'{stream._ctx.chan.uid} sub already removed!?' - ) + borked.add(stream) - @classmethod - async def broadcast_all(self) -> None: - for delay_s in self.subscribers: - await self.broadcast(delay_s) - - -@tractor.context -async def maybe_open_global_sampler( - ctx: tractor.Context, - brokername: str, - -) -> None: - get_console_log(tractor.current_actor().loglevel) - - global Sampler - - async with trio.open_nursery() as service_nursery: - Sampler.service_nursery = service_nursery - - # unblock caller - await ctx.started() - - # we pin this task to keep the feeds manager active until the - # parent actor decides to tear it down - await trio.sleep_forever() - - -@tractor.context -async def iter_ohlc_periods( - ctx: tractor.Context, - delay_s: int, - -) -> None: - ''' - Subscribe to OHLC sampling "step" events: when the time - aggregation period increments, this event stream emits an index - event. - - ''' - # add our subscription - subs = Sampler.subscribers.setdefault(delay_s, []) - await ctx.started() - async with ctx.open_stream() as stream: - subs.append(stream) - - try: - # stream and block until cancelled - await trio.sleep_forever() - finally: + for stream in borked: try: subs.remove(stream) except ValueError: - log.error( - f'iOHLC step stream was already dropped {ctx.chan.uid}?' + log.warning( + f'{stream._ctx.chan.uid} sub already removed!?' ) + @classmethod + async def broadcast_all(self) -> None: + for period_s in self.subscribers: + await self.broadcast(period_s) + + +@tractor.context +async def register_with_sampler( + ctx: tractor.Context, + period_s: float, + shms_by_period: dict[float, dict] | None = None, + open_index_stream: bool = True, + +) -> None: + + get_console_log(tractor.current_actor().loglevel) + incr_was_started: bool = False + + try: + async with maybe_open_nursery( + Sampler.service_nursery + ) as service_nursery: + + # init startup, create (actor-)local service nursery and start + # increment task + Sampler.service_nursery = service_nursery + + # always ensure a period subs entry exists + last_ts, subs = Sampler.subscribers[float(period_s)] + + async with trio.Lock(): + if Sampler.incr_task_cs is None: + Sampler.incr_task_cs = await service_nursery.start( + Sampler.increment_ohlc_buffer, + 1., + ) + incr_was_started = True + + # insert the base 1s period (for OHLC style sampling) into + # the increment buffer set to update and shift every second. + if shms_by_period is not None: + from ._sharedmem import ( + attach_shm_array, + _Token, + ) + for period in shms_by_period: + + # load and register shm handles + shm_token_msg = shms_by_period[period] + shm = attach_shm_array( + _Token.from_msg(shm_token_msg), + readonly=False, + ) + shms_by_period[period] = shm + Sampler.ohlcv_shms.setdefault(period, []).append(shm) + + assert Sampler.ohlcv_shms + + # unblock caller + await ctx.started(set(Sampler.ohlcv_shms.keys())) + + if open_index_stream: + try: + async with ctx.open_stream() as stream: + subs.add(stream) + + # except broadcast requests from the subscriber + async for msg in stream: + if msg == 'broadcast_all': + await Sampler.broadcast_all() + + finally: + subs.remove(stream) + else: + # if no shms are passed in we just wait until cancelled + # by caller. + await trio.sleep_forever() + + finally: + # TODO: why tf isn't this working? + if shms_by_period is not None: + for period, shm in shms_by_period.items(): + Sampler.ohlcv_shms[period].remove(shm) + + if incr_was_started: + Sampler.incr_task_cs.cancel() + Sampler.incr_task_cs = None + + +async def spawn_samplerd( + + loglevel: str | None = None, + **extra_tractor_kwargs + +) -> bool: + ''' + Daemon-side service task: start a sampling daemon for common step + update and increment count write and stream broadcasting. + + ''' + from piker._daemon import Services + + dname = 'samplerd' + log.info(f'Spawning `{dname}`') + + # singleton lock creation of ``samplerd`` since we only ever want + # one daemon per ``pikerd`` proc tree. + # TODO: make this built-into the service api? + async with Services.locks[dname + '_singleton']: + + if dname not in Services.service_tasks: + + portal = await Services.actor_n.start_actor( + dname, + enable_modules=[ + 'piker.data._sampling', + ], + loglevel=loglevel, + debug_mode=Services.debug_mode, # set by pikerd flag + **extra_tractor_kwargs + ) + + await Services.start_service_task( + dname, + portal, + register_with_sampler, + period_s=1, + ) + return True + + return False + + +@acm +async def maybe_open_samplerd( + + loglevel: str | None = None, + **kwargs, + +) -> tractor._portal.Portal: # noqa + ''' + Client-side helper to maybe startup the ``samplerd`` service + under the ``pikerd`` tree. + + ''' + dname = 'samplerd' + + async with maybe_spawn_daemon( + dname, + service_task_target=spawn_samplerd, + spawn_args={'loglevel': loglevel}, + loglevel=loglevel, + **kwargs, + + ) as portal: + yield portal + + +@acm +async def open_sample_stream( + period_s: int, + shms_by_period: dict[float, dict] | None = None, + open_index_stream: bool = True, + + cache_key: str | None = None, + allow_new_sampler: bool = True, + +) -> AsyncIterator[dict[str, float]]: + ''' + Subscribe to OHLC sampling "step" events: when the time aggregation + period increments, this event stream emits an index event. + + This is a client-side endpoint that does all the work of ensuring + the `samplerd` actor is up and that mult-consumer-tasks are given + a broadcast stream when possible. + + ''' + # TODO: wrap this manager with the following to make it cached + # per client-multitasks entry. + # maybe_open_context( + # acm_func=partial( + # portal.open_context, + # register_with_sampler, + # ), + # key=cache_key or period_s, + # ) + # if cache_hit: + # # add a new broadcast subscription for the quote stream + # # if this feed is likely already in use + # async with istream.subscribe() as bistream: + # yield bistream + # else: + + async with ( + # XXX: this should be singleton on a host, + # a lone broker-daemon per provider should be + # created for all practical purposes + maybe_open_samplerd() as portal, + + portal.open_context( + register_with_sampler, + **{ + 'period_s': period_s, + 'shms_by_period': shms_by_period, + 'open_index_stream': open_index_stream, + }, + ) as (ctx, first) + ): + async with ( + ctx.open_stream() as istream, + + # TODO: we don't need this task-bcasting right? + # istream.subscribe() as istream, + ): + yield istream + async def sample_and_broadcast( @@ -304,7 +504,14 @@ async def sample_and_broadcast( sum_tick_vlm: bool = True, ) -> None: + ''' + `brokerd`-side task which writes latest datum sampled data. + This task is meant to run in the same actor (mem space) as the + `brokerd` real-time quote feed which is being sampled to + a ``ShmArray`` buffer. + + ''' log.info("Started shared mem bar writer") overruns = Counter() @@ -341,7 +548,6 @@ async def sample_and_broadcast( for shm in [rt_shm, hist_shm]: # update last entry # benchmarked in the 4-5 us range - # for shm in [rt_shm, hist_shm]: o, high, low, v = shm.array[-1][ ['open', 'high', 'low', 'volume'] ]