diff --git a/piker/data/_sampling.py b/piker/data/_sampling.py index faa17da9..622a8ad1 100644 --- a/piker/data/_sampling.py +++ b/piker/data/_sampling.py @@ -33,7 +33,10 @@ import tractor import trio from trio_typing import TaskStatus -from ..log import get_logger +from ..log import ( + get_logger, + get_console_log, +) if TYPE_CHECKING: from ._sharedmem import ShmArray @@ -45,7 +48,7 @@ log = get_logger(__name__) _default_delay_s: float = 1.0 -class sampler: +class Sampler: ''' Global sampling engine registry. @@ -53,6 +56,8 @@ class sampler: sample period logic. ''' + service_nursery: None | trio.Nursery = None + # TODO: we could stick these in a composed type to avoid # angering the "i hate module scoped variables crowd" (yawn). ohlcv_shms: dict[int, list[ShmArray]] = {} @@ -67,165 +72,196 @@ class sampler: # notified on a step. subscribers: dict[int, tractor.Context] = {} + @classmethod + async def increment_ohlc_buffer( + self, + delay_s: int, + task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, + ): + ''' + Task which inserts new bars into the provide shared memory array + every ``delay_s`` seconds. -async def increment_ohlc_buffer( - delay_s: int, - task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, -): - ''' - Task which inserts new bars into the provide shared memory array - every ``delay_s`` seconds. + This task fulfills 2 purposes: + - it takes the subscribed set of shm arrays and increments them + on a common time period + - broadcast of this increment "signal" message to other actor + subscribers - This task fulfills 2 purposes: - - it takes the subscribed set of shm arrays and increments them - on a common time period - - broadcast of this increment "signal" message to other actor - subscribers + Note that if **no** actor has initiated this task then **none** of + the underlying buffers will actually be incremented. - Note that if **no** actor has initiated this task then **none** of - the underlying buffers will actually be incremented. + ''' + # # wait for brokerd to signal we should start sampling + # await shm_incrementing(shm_token['shm_name']).wait() - ''' - # # wait for brokerd to signal we should start sampling - # await shm_incrementing(shm_token['shm_name']).wait() + # TODO: right now we'll spin printing bars if the last time stamp is + # before a large period of no market activity. Likely the best way + # to solve this is to make this task aware of the instrument's + # tradable hours? - # TODO: right now we'll spin printing bars if the last time stamp is - # before a large period of no market activity. Likely the best way - # to solve this is to make this task aware of the instrument's - # tradable hours? + # adjust delay to compensate for trio processing time + ad = min(self.ohlcv_shms.keys()) - 0.001 - # adjust delay to compensate for trio processing time - ad = min(sampler.ohlcv_shms.keys()) - 0.001 + total_s = 0 # total seconds counted + lowest = min(self.ohlcv_shms.keys()) + lowest_shm = self.ohlcv_shms[lowest][0] + ad = lowest - 0.001 - total_s = 0 # total seconds counted - lowest = min(sampler.ohlcv_shms.keys()) - lowest_shm = sampler.ohlcv_shms[lowest][0] - ad = lowest - 0.001 + with trio.CancelScope() as cs: - with trio.CancelScope() as cs: + # register this time period step as active + self.incrementers[delay_s] = cs + task_status.started(cs) - # register this time period step as active - sampler.incrementers[delay_s] = cs - task_status.started(cs) + while True: + # TODO: do we want to support dynamically + # adding a "lower" lowest increment period? + await trio.sleep(ad) + total_s += delay_s - while True: - # TODO: do we want to support dynamically - # adding a "lower" lowest increment period? - await trio.sleep(ad) - total_s += delay_s + # increment all subscribed shm arrays + # TODO: + # - this in ``numba`` + # - just lookup shms for this step instead of iterating? + for this_delay_s, shms in self.ohlcv_shms.items(): - # increment all subscribed shm arrays - # TODO: - # - this in ``numba`` - # - just lookup shms for this step instead of iterating? - for this_delay_s, shms in sampler.ohlcv_shms.items(): + # short-circuit on any not-ready because slower sample + # rate consuming shm buffers. + if total_s % this_delay_s != 0: + # print(f'skipping `{this_delay_s}s` sample update') + continue - # short-circuit on any not-ready because slower sample - # rate consuming shm buffers. - if total_s % this_delay_s != 0: - # print(f'skipping `{this_delay_s}s` sample update') - continue + # TODO: ``numba`` this! + for shm in shms: + # append new entry to buffer thus "incrementing" + # the bar + array = shm.array + last = array[-1:][shm._write_fields].copy() - # TODO: ``numba`` this! - for shm in shms: - # append new entry to buffer thus "incrementing" the bar - array = shm.array - last = array[-1:][shm._write_fields].copy() + (t, close) = last[0][[ + 'time', + 'close', + ]] - (t, close) = last[0][[ - 'time', - 'close', - ]] + next_t = t + this_delay_s + i_epoch = round(time.time()) - next_t = t + this_delay_s - i_epoch = round(time.time()) + if this_delay_s <= 1: + next_t = i_epoch - if this_delay_s <= 1: - next_t = i_epoch + # print(f'epoch {shm.token["shm_name"]}: {next_t}') - # print(f'epoch {shm.token["shm_name"]}: {next_t}') + # this copies non-std fields (eg. vwap) from the + # last datum + last[[ + 'time', - # this copies non-std fields (eg. vwap) from the last datum - last[[ - 'time', + 'open', + 'high', + 'low', + 'close', - 'open', - 'high', - 'low', - 'close', + 'volume', + ]][0] = ( + # epoch timestamp + next_t, - 'volume', - ]][0] = ( - # epoch timestamp - next_t, + # OHLC + close, + close, + close, + close, - # OHLC - close, - close, - close, - close, + 0, # vlm + ) - 0, # vlm + # TODO: in theory we could make this faster by + # copying the "last" readable value into the + # underlying larger buffer's next value and then + # incrementing the counter instead of using + # ``.push()``? + + # write to the buffer + shm.push(last) + + await self.broadcast(delay_s, shm=lowest_shm) + + @classmethod + async def broadcast( + self, + delay_s: int, + shm: ShmArray | None = None, + + ) -> None: + ''' + Broadcast the given ``shm: ShmArray``'s buffer index step to any + subscribers for a given sample period. + + The sent msg will include the first and last index which slice into + the buffer's non-empty data. + + ''' + subs = self.subscribers.get(delay_s, ()) + first = last = -1 + + if shm is None: + periods = self.ohlcv_shms.keys() + # if this is an update triggered by a history update there + # might not actually be any sampling bus setup since there's + # no "live feed" active yet. + if periods: + lowest = min(periods) + shm = self.ohlcv_shms[lowest][0] + first = shm._first.value + last = shm._last.value + + for stream in subs: + try: + await stream.send({ + 'first': first, + 'last': last, + 'index': last, + }) + except ( + trio.BrokenResourceError, + trio.ClosedResourceError + ): + log.error( + f'{stream._ctx.chan.uid} dropped connection' + ) + try: + subs.remove(stream) + except ValueError: + log.warning( + f'{stream._ctx.chan.uid} sub already removed!?' ) - # TODO: in theory we could make this faster by copying the - # "last" readable value into the underlying larger buffer's - # next value and then incrementing the counter instead of - # using ``.push()``? - - # write to the buffer - shm.push(last) - - await broadcast(delay_s, shm=lowest_shm) + @classmethod + async def broadcast_all(self) -> None: + for delay_s in self.subscribers: + await self.broadcast(delay_s) -async def broadcast( - delay_s: int, - shm: ShmArray | None = None, +@tractor.context +async def maybe_open_global_sampler( + ctx: tractor.Context, + brokername: str, ) -> None: - ''' - Broadcast the given ``shm: ShmArray``'s buffer index step to any - subscribers for a given sample period. + get_console_log(tractor.current_actor().loglevel) - The sent msg will include the first and last index which slice into - the buffer's non-empty data. + global Sampler - ''' - subs = sampler.subscribers.get(delay_s, ()) - first = last = -1 + async with trio.open_nursery() as service_nursery: + Sampler.service_nursery = service_nursery - if shm is None: - periods = sampler.ohlcv_shms.keys() - # if this is an update triggered by a history update there - # might not actually be any sampling bus setup since there's - # no "live feed" active yet. - if periods: - lowest = min(periods) - shm = sampler.ohlcv_shms[lowest][0] - first = shm._first.value - last = shm._last.value + # unblock caller + await ctx.started() - for stream in subs: - try: - await stream.send({ - 'first': first, - 'last': last, - 'index': last, - }) - except ( - trio.BrokenResourceError, - trio.ClosedResourceError - ): - log.error( - f'{stream._ctx.chan.uid} dropped connection' - ) - try: - subs.remove(stream) - except ValueError: - log.warning( - f'{stream._ctx.chan.uid} sub already removed!?' - ) + # we pin this task to keep the feeds manager active until the + # parent actor decides to tear it down + await trio.sleep_forever() @tractor.context @@ -241,7 +277,7 @@ async def iter_ohlc_periods( ''' # add our subscription - subs = sampler.subscribers.setdefault(delay_s, []) + subs = Sampler.subscribers.setdefault(delay_s, []) await ctx.started() async with ctx.open_stream() as stream: subs.append(stream) diff --git a/piker/data/feed.py b/piker/data/feed.py index 744d301f..89330475 100644 --- a/piker/data/feed.py +++ b/piker/data/feed.py @@ -74,9 +74,7 @@ from ._source import ( ) from ..ui import _search from ._sampling import ( - sampler, - broadcast, - increment_ohlc_buffer, + Sampler, sample_and_broadcast, uniform_rate_send, _default_delay_s, @@ -327,8 +325,7 @@ async def start_backfill( # TODO: *** THIS IS A BUG *** # we need to only broadcast to subscribers for this fqsn.. # otherwise all fsps get reset on every chart.. - for delay_s in sampler.subscribers: - await broadcast(delay_s) + await Sampler.broadcast_all() # signal that backfilling to tsdb's end datum is complete bf_done = trio.Event() @@ -496,8 +493,7 @@ async def start_backfill( # in the block above to avoid entering new ``frames`` # values while we're pipelining the current ones to # memory... - for delay_s in sampler.subscribers: - await broadcast(delay_s) + await Sampler.broadcast_all() # short-circuit (for now) bf_done.set() @@ -738,8 +734,7 @@ async def tsdb_backfill( # (usually a chart showing graphics for said fsp) # which tells the chart to conduct a manual full # graphics loop cycle. - for delay_s in sampler.subscribers: - await broadcast(delay_s) + await Sampler.broadcast_all() # TODO: write new data to tsdb to be ready to for next read. @@ -1037,7 +1032,7 @@ async def allocate_persistent_feed( # insert 1s ohlc into the increment buffer set # to update and shift every second - sampler.ohlcv_shms.setdefault( + Sampler.ohlcv_shms.setdefault( 1, [] ).append(rt_shm) @@ -1053,13 +1048,13 @@ async def allocate_persistent_feed( # insert 1m ohlc into the increment buffer set # to shift every 60s. - sampler.ohlcv_shms.setdefault(60, []).append(hist_shm) + Sampler.ohlcv_shms.setdefault(60, []).append(hist_shm) # create buffer a single incrementer task broker backend # (aka `brokerd`) using the lowest sampler period. - if sampler.incrementers.get(_default_delay_s) is None: + if Sampler.incrementers.get(_default_delay_s) is None: await bus.start_task( - increment_ohlc_buffer, + Sampler.increment_ohlc_buffer, _default_delay_s, )