Implement a `samplerd` singleton actor service
Now spawned under the `pikerd` tree as a singleton daemon actor, we offer a slew of new routines in support of this micro-service:

- `maybe_open_samplerd()` and `spawn_samplerd()`, which provide the `._daemon.Services` integration to conduct service spawning.
- `open_sample_stream()`, a client-side endpoint which does all the work of (lazily) starting the `samplerd` service (if it doesn't yet exist), registers shm buffers for update, and connects a sample-index stream for iteration by the caller; see the usage sketch below.
- `register_with_sampler()`, the `samplerd`-side service task endpoint implementing all the shm buffer and index-stream registry details as well as the logic ensuring a lone service task runs `Sampler.increment_ohlc_buffer()`; it increments at the shortest registered period which, for now, is the default 1s duration.

Further impl notes:

- fixes to `Sampler.broadcast()` to ensure broken streams get discarded gracefully.
- we use a `pikerd`-side singleton mutex `trio.Lock()` to ensure one-and-only-one `samplerd` is ever spawned per `pikerd` actor tree.
parent: a342f7d2d4 · commit: 5ec1a72a3d
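Before the diff, a brief orientation: the snippet below is a minimal, hypothetical client-side usage sketch of the new `open_sample_stream()` endpoint. It assumes the calling task is already running inside a `piker`/`tractor` actor runtime with a reachable `pikerd` tree; the module path and the `'index'`/`'period'` msg keys are taken from the patch itself, while the function and variable names are illustrative.

```python
from piker.data._sampling import open_sample_stream

async def watch_1s_steps() -> None:
    # lazily spawns the `samplerd` sub-actor under `pikerd`
    # (if not already up) and registers this task as
    # a subscriber for 1s sample-index events.
    async with open_sample_stream(period_s=1) as istream:
        async for msg in istream:
            # each step msg carries the sample period and the
            # last stamped epoch index for that period group.
            print(f"step: {msg['period']}s -> {msg['index']}")
```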
piker/data/_sampling.py

@@ -24,12 +24,17 @@ from collections import (
     Counter,
     defaultdict,
 )
 from contextlib import asynccontextmanager as acm
+import time
 from typing import (
+    AsyncIterator,
     TYPE_CHECKING,
 )
 
 import tractor
+from tractor.trionics import (
+    maybe_open_nursery,
+)
 import trio
 from trio_typing import TaskStatus
 
@@ -37,14 +42,20 @@ from ..log import (
     get_logger,
     get_console_log,
 )
 from .._daemon import maybe_spawn_daemon
 
 if TYPE_CHECKING:
-    from ._sharedmem import ShmArray
+    from ._sharedmem import (
+        ShmArray,
+    )
     from .feed import _FeedsBus
 
 log = get_logger(__name__)
 
 
+# highest frequency sample step is 1 second by default, though in
+# the future we may want to support shorter periods or a dynamic style
+# tick-event stream.
+_default_delay_s: float = 1.0
 
@@ -55,32 +66,50 @@ class Sampler:
     Manages state for sampling events, shm incrementing and
     sample period logic.
 
+    This non-instantiated type is meant to be a singleton within
+    a `samplerd` actor-service spawned once by the user wishing to
+    time-step sample real-time quote feeds, see
+    ``._daemon.maybe_open_samplerd()`` and the below
+    ``register_with_sampler()``.
+
     '''
+    service_nursery: None | trio.Nursery = None
+
     # TODO: we could stick these in a composed type to avoid
     # angering the "i hate module scoped variables crowd" (yawn).
-    ohlcv_shms: dict[int, list[ShmArray]] = {}
+    ohlcv_shms: dict[float, list[ShmArray]] = {}
 
     # holds one-task-per-sample-period tasks which are spawned as-needed by
     # data feed requests with a given detected time step usually from
     # history loading.
-    incrementers: dict[int, trio.CancelScope] = {}
+    incr_task_cs: trio.CancelScope | None = None
 
     # holds all the ``tractor.Context`` remote subscriptions for
     # a particular sample period increment event: all subscribers are
     # notified on a step.
-    subscribers: dict[int, tractor.Context] = {}
+    # subscribers: dict[int, list[tractor.MsgStream]] = {}
+    subscribers: defaultdict[
+        float,
+        list[
+            float,
+            set[tractor.MsgStream]
+        ],
+    ] = defaultdict(
+        lambda: [
+            round(time.time()),
+            set(),
+        ]
+    )
 
     @classmethod
     async def increment_ohlc_buffer(
         self,
-        delay_s: int,
+        period_s: float,
         task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
     ):
         '''
         Task which inserts new bars into the provided shared memory array
-        every ``delay_s`` seconds.
+        every ``period_s`` seconds.
 
         This task fulfills 2 purposes:
         - it takes the subscribed set of shm arrays and increments them
@@ -92,66 +121,74 @@ class Sampler:
         the underlying buffers will actually be incremented.
 
         '''
-        # # wait for brokerd to signal we should start sampling
-        # await shm_incrementing(shm_token['shm_name']).wait()
-
         # TODO: right now we'll spin printing bars if the last time stamp is
         # before a large period of no market activity.  Likely the best way
         # to solve this is to make this task aware of the instrument's
         # tradable hours?
 
-        # adjust delay to compensate for trio processing time
-        ad = min(self.ohlcv_shms.keys()) - 0.001
-
-        total_s = 0  # total seconds counted
-        lowest = min(self.ohlcv_shms.keys())
-        lowest_shm = self.ohlcv_shms[lowest][0]
-        ad = lowest - 0.001
+        total_s: float = 0  # total seconds counted
+        ad = period_s - 0.001  # compensate for trio processing time
 
         with trio.CancelScope() as cs:
 
             # register this time period step as active
-            self.incrementers[delay_s] = cs
             task_status.started(cs)
 
+            # sample step loop:
+            # includes broadcasting to all connected consumers on every
+            # new sample step as well incrementing any registered
+            # buffers by registered sample period.
             while True:
-                # TODO: do we want to support dynamically
-                # adding a "lower" lowest increment period?
                 await trio.sleep(ad)
-                total_s += delay_s
+                total_s += period_s
 
                 # increment all subscribed shm arrays
                 # TODO:
                 # - this in ``numba``
                 # - just lookup shms for this step instead of iterating?
-                for this_delay_s, shms in self.ohlcv_shms.items():
+
+                i_epoch = round(time.time())
+                broadcasted: set[float] = set()
+
+                # print(f'epoch: {i_epoch} -> REGISTRY {self.ohlcv_shms}')
+                for shm_period_s, shms in self.ohlcv_shms.items():
 
                     # short-circuit on any not-ready because slower sample
                     # rate consuming shm buffers.
-                    if total_s % this_delay_s != 0:
-                        # print(f'skipping `{this_delay_s}s` sample update')
+                    if total_s % shm_period_s != 0:
+                        # print(f'skipping `{shm_period_s}s` sample update')
                         continue
 
+                    # update last epoch stamp for this period group
+                    if shm_period_s not in broadcasted:
+                        sub_pair = self.subscribers[shm_period_s]
+                        sub_pair[0] = i_epoch
+                        broadcasted.add(shm_period_s)
+
                     # TODO: ``numba`` this!
                     for shm in shms:
+                        # print(f'UPDATE {shm_period_s}s STEP for {shm.token}')
 
                         # append new entry to buffer thus "incrementing"
                         # the bar
                         array = shm.array
                         last = array[-1:][shm._write_fields].copy()
 
+                        # guard against startup backfilling race with
+                        # empty buffers.
+                        if not last.size:
+                            continue
+
                         (t, close) = last[0][[
                             'time',
                             'close',
                         ]]
 
-                        next_t = t + this_delay_s
-                        i_epoch = round(time.time())
+                        next_t = t + shm_period_s
 
-                        if this_delay_s <= 1:
+                        if shm_period_s <= 1:
                             next_t = i_epoch
 
+                        # print(f'epoch {shm.token["shm_name"]}: {next_t}')
+
                         # this copies non-std fields (eg. vwap) from the
                         # last datum
                         last[[
@@ -185,43 +222,43 @@
                         # write to the buffer
                         shm.push(last)
 
-                await self.broadcast(delay_s, shm=lowest_shm)
+                # broadcast increment msg to all updated subs per period
+                for shm_period_s in broadcasted:
+                    await self.broadcast(
+                        period_s=shm_period_s,
+                        time_stamp=i_epoch,
+                    )
 
     @classmethod
     async def broadcast(
         self,
-        delay_s: int,
-        shm: ShmArray | None = None,
+        period_s: float,
+        time_stamp: float | None = None,
 
     ) -> None:
         '''
-        Broadcast the given ``shm: ShmArray``'s buffer index step to any
+        Broadcast the period size and last index step value to all
         subscribers for a given sample period.
 
-        The sent msg will include the first and last index which slice into
-        the buffer's non-empty data.
-
         '''
-        subs = self.subscribers.get(delay_s, ())
-        first = last = -1
-
-        if shm is None:
-            periods = self.ohlcv_shms.keys()
-            # if this is an update triggered by a history update there
-            # might not actually be any sampling bus setup since there's
-            # no "live feed" active yet.
-            if periods:
-                lowest = min(periods)
-                shm = self.ohlcv_shms[lowest][0]
-                first = shm._first.value
-                last = shm._last.value
+        pair = self.subscribers[period_s]
+
+        last_ts, subs = pair
 
+        task = trio.lowlevel.current_task()
+        log.debug(
+            f'SUBS {self.subscribers}\n'
+            f'PAIR {pair}\n'
+            f'TASK: {task}: {id(task)}\n'
+            f'broadcasting {period_s} -> {last_ts}\n'
+            # f'consumers: {subs}'
+        )
+        borked: set[tractor.MsgStream] = set()
         for stream in subs:
             try:
                 await stream.send({
-                    'first': first,
-                    'last': last,
-                    'index': last,
+                    'index': time_stamp or last_ts,
+                    'period': period_s,
                 })
             except (
                 trio.BrokenResourceError,
@@ -230,69 +267,232 @@
                 log.error(
                     f'{stream._ctx.chan.uid} dropped connection'
                 )
-                try:
-                    subs.remove(stream)
-                except ValueError:
-                    log.warning(
-                        f'{stream._ctx.chan.uid} sub already removed!?'
-                    )
-
-    @classmethod
-    async def broadcast_all(self) -> None:
-        for delay_s in self.subscribers:
-            await self.broadcast(delay_s)
-
-
-@tractor.context
-async def maybe_open_global_sampler(
-    ctx: tractor.Context,
-    brokername: str,
-
-) -> None:
-    get_console_log(tractor.current_actor().loglevel)
-
-    global Sampler
-
-    async with trio.open_nursery() as service_nursery:
-        Sampler.service_nursery = service_nursery
-
-        # unblock caller
-        await ctx.started()
-
-        # we pin this task to keep the feeds manager active until the
-        # parent actor decides to tear it down
-        await trio.sleep_forever()
-
-
-@tractor.context
-async def iter_ohlc_periods(
-    ctx: tractor.Context,
-    delay_s: int,
-
-) -> None:
-    '''
-    Subscribe to OHLC sampling "step" events: when the time
-    aggregation period increments, this event stream emits an index
-    event.
-
-    '''
-    # add our subscription
-    subs = Sampler.subscribers.setdefault(delay_s, [])
-    await ctx.started()
-    async with ctx.open_stream() as stream:
-        subs.append(stream)
-
-        try:
-            # stream and block until cancelled
-            await trio.sleep_forever()
-        finally:
-            try:
-                subs.remove(stream)
-            except ValueError:
-                log.error(
-                    f'iOHLC step stream was already dropped {ctx.chan.uid}?'
-                )
+                borked.add(stream)
+
+        for stream in borked:
+            try:
+                subs.remove(stream)
+            except ValueError:
+                log.warning(
+                    f'{stream._ctx.chan.uid} sub already removed!?'
+                )
+
+    @classmethod
+    async def broadcast_all(self) -> None:
+        for period_s in self.subscribers:
+            await self.broadcast(period_s)
 
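The `borked` set introduced above implements the graceful-discard fix called out in the commit message: broken streams are collected during the send loop and removed afterwards, since mutating a set while iterating it raises a `RuntimeError`. A generic sketch of just that pattern (the stream type and exception set are stand-ins; the diff's `except` clause is cut off by the hunk boundary, so the exact exception list is an assumption):

```python
import trio

async def fanout(subs: set, msg: dict) -> None:
    # phase 1: send to every subscriber, collecting any
    # broken receivers instead of removing them inline.
    borked: set = set()
    for stream in subs:
        try:
            await stream.send(msg)
        except trio.BrokenResourceError:
            borked.add(stream)

    # phase 2: discard outside the send loop so the set is
    # never mutated while being iterated.
    for stream in borked:
        subs.discard(stream)
```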
+@tractor.context
+async def register_with_sampler(
+    ctx: tractor.Context,
+    period_s: float,
+    shms_by_period: dict[float, dict] | None = None,
+    open_index_stream: bool = True,
+
+) -> None:
+
+    get_console_log(tractor.current_actor().loglevel)
+    incr_was_started: bool = False
+
+    try:
+        async with maybe_open_nursery(
+            Sampler.service_nursery
+        ) as service_nursery:
+
+            # init startup, create (actor-)local service nursery and start
+            # increment task
+            Sampler.service_nursery = service_nursery
+
+            # always ensure a period subs entry exists
+            last_ts, subs = Sampler.subscribers[float(period_s)]
+
+            async with trio.Lock():
+                if Sampler.incr_task_cs is None:
+                    Sampler.incr_task_cs = await service_nursery.start(
+                        Sampler.increment_ohlc_buffer,
+                        1.,
+                    )
+                    incr_was_started = True
+
+            # insert the base 1s period (for OHLC style sampling) into
+            # the increment buffer set to update and shift every second.
+            if shms_by_period is not None:
+                from ._sharedmem import (
+                    attach_shm_array,
+                    _Token,
+                )
+                for period in shms_by_period:
+
+                    # load and register shm handles
+                    shm_token_msg = shms_by_period[period]
+                    shm = attach_shm_array(
+                        _Token.from_msg(shm_token_msg),
+                        readonly=False,
+                    )
+                    shms_by_period[period] = shm
+                    Sampler.ohlcv_shms.setdefault(period, []).append(shm)
+
+                assert Sampler.ohlcv_shms
+
+            # unblock caller
+            await ctx.started(set(Sampler.ohlcv_shms.keys()))
+
+            if open_index_stream:
+                try:
+                    async with ctx.open_stream() as stream:
+                        subs.add(stream)
+
+                        # accept broadcast requests from the subscriber
+                        async for msg in stream:
+                            if msg == 'broadcast_all':
+                                await Sampler.broadcast_all()
+
+                finally:
+                    subs.remove(stream)
+            else:
+                # if no shms are passed in we just wait until cancelled
+                # by caller.
+                await trio.sleep_forever()
+
+    finally:
+        # TODO: why tf isn't this working?
+        if shms_by_period is not None:
+            for period, shm in shms_by_period.items():
+                Sampler.ohlcv_shms[period].remove(shm)
+
+        if incr_was_started:
+            Sampler.incr_task_cs.cancel()
+            Sampler.incr_task_cs = None
 
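`register_with_sampler()` leans on `trio`'s `nursery.start()` protocol: the increment task hands its `CancelScope` back through `task_status.started()`, so the first registrant can start it and later cancel it on teardown. A distilled sketch of that pattern with illustrative names; note that for the lock to actually serialize concurrent registrants it must be shared (e.g. a class attribute, as assumed here) rather than constructed per-call:

```python
import trio

class Incr:
    _lock: trio.Lock = trio.Lock()  # shared so registrants serialize
    _cs: trio.CancelScope | None = None

    @classmethod
    async def _run(
        cls,
        period_s: float,
        task_status=trio.TASK_STATUS_IGNORED,
    ) -> None:
        with trio.CancelScope() as cs:
            # unblock the `nursery.start()` call below, handing
            # back this scope for later cancellation.
            task_status.started(cs)
            while True:
                await trio.sleep(period_s)
                # ... increment buffers and broadcast here ...

    @classmethod
    async def maybe_start(cls, nursery: trio.Nursery) -> bool:
        async with cls._lock:
            if cls._cs is None:
                cls._cs = await nursery.start(cls._run, 1.0)
                return True  # this caller owns teardown
            return False
```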
+async def spawn_samplerd(
+
+    loglevel: str | None = None,
+    **extra_tractor_kwargs
+
+) -> bool:
+    '''
+    Daemon-side service task: start a sampling daemon for common step
+    update and increment count write and stream broadcasting.
+
+    '''
+    from piker._daemon import Services
+
+    dname = 'samplerd'
+    log.info(f'Spawning `{dname}`')
+
+    # singleton lock creation of ``samplerd`` since we only ever want
+    # one daemon per ``pikerd`` proc tree.
+    # TODO: make this built into the service api?
+    async with Services.locks[dname + '_singleton']:
+
+        if dname not in Services.service_tasks:
+
+            portal = await Services.actor_n.start_actor(
+                dname,
+                enable_modules=[
+                    'piker.data._sampling',
+                ],
+                loglevel=loglevel,
+                debug_mode=Services.debug_mode,  # set by pikerd flag
+                **extra_tractor_kwargs
+            )
+
+            await Services.start_service_task(
+                dname,
+                portal,
+                register_with_sampler,
+                period_s=1,
+            )
+            return True
+
+        return False
 
+@acm
+async def maybe_open_samplerd(
+
+    loglevel: str | None = None,
+    **kwargs,
+
+) -> tractor._portal.Portal:  # noqa
+    '''
+    Client-side helper to maybe startup the ``samplerd`` service
+    under the ``pikerd`` tree.
+
+    '''
+    dname = 'samplerd'
+
+    async with maybe_spawn_daemon(
+        dname,
+        service_task_target=spawn_samplerd,
+        spawn_args={'loglevel': loglevel},
+        loglevel=loglevel,
+        **kwargs,
+
+    ) as portal:
+        yield portal
 
+@acm
+async def open_sample_stream(
+    period_s: int,
+    shms_by_period: dict[float, dict] | None = None,
+    open_index_stream: bool = True,
+
+    cache_key: str | None = None,
+    allow_new_sampler: bool = True,
+
+) -> AsyncIterator[dict[str, float]]:
+    '''
+    Subscribe to OHLC sampling "step" events: when the time aggregation
+    period increments, this event stream emits an index event.
+
+    This is a client-side endpoint that does all the work of ensuring
+    the `samplerd` actor is up and that multi-consumer-tasks are given
+    a broadcast stream when possible.
+
+    '''
+    # TODO: wrap this manager with the following to make it cached
+    # per client-multitasks entry.
+    # maybe_open_context(
+    #     acm_func=partial(
+    #         portal.open_context,
+    #         register_with_sampler,
+    #     ),
+    #     key=cache_key or period_s,
+    # )
+    # if cache_hit:
+    #     # add a new broadcast subscription for the quote stream
+    #     # if this feed is likely already in use
+    #     async with istream.subscribe() as bistream:
+    #         yield bistream
+    # else:
+
+    async with (
+        # XXX: this should be singleton on a host,
+        # a lone broker-daemon per provider should be
+        # created for all practical purposes
+        maybe_open_samplerd() as portal,
+
+        portal.open_context(
+            register_with_sampler,
+            **{
+                'period_s': period_s,
+                'shms_by_period': shms_by_period,
+                'open_index_stream': open_index_stream,
+            },
+        ) as (ctx, first)
+    ):
+        async with (
+            ctx.open_stream() as istream,
+
+            # TODO: we don't need this task-bcasting right?
+            # istream.subscribe() as istream,
+        ):
+            yield istream
 
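Because `register_with_sampler()` also reads msgs off the registered stream, a subscriber can request a manual re-broadcast over the same channel. A hypothetical follow-on to the earlier usage sketch: the `'broadcast_all'` msg is handled by the service-side loop shown above, while the client-side `.send()` is an assumption based on `tractor`'s bidirectional context streams.

```python
from piker.data._sampling import open_sample_stream

async def poke_sampler() -> None:
    async with open_sample_stream(period_s=1) as istream:
        # ask `samplerd` to re-emit the latest step to every
        # registered subscriber across all periods.
        await istream.send('broadcast_all')
        print(await istream.receive())
```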
@@ -304,7 +504,14 @@ async def sample_and_broadcast(
     sum_tick_vlm: bool = True,
 
 ) -> None:
+    '''
+    `brokerd`-side task which writes latest datum sampled data.
+
+    This task is meant to run in the same actor (mem space) as the
+    `brokerd` real-time quote feed which is being sampled to
+    a ``ShmArray`` buffer.
+
+    '''
     log.info("Started shared mem bar writer")
 
     overruns = Counter()
@@ -341,7 +548,6 @@ async def sample_and_broadcast(
                     for shm in [rt_shm, hist_shm]:
                         # update last entry
                         # benchmarked in the 4-5 us range
-                        # for shm in [rt_shm, hist_shm]:
                         o, high, low, v = shm.array[-1][
                             ['open', 'high', 'low', 'volume']
                         ]