Implement a `samplerd` singleton actor service

Now spawned under the `pikerd` tree as a singleton-daemon-actor we offer
a slew of new routines in support of this micro-service:

- `maybe_open_samplerd()` and `spawn_samplerd()` which provide the
  `._daemon.Services` integration to conduct service spawning.
- `open_sample_stream()` which is a client-side endpoint which does all
  the work of (lazily) starting the `samplerd` service (if dne) and
  registers shm buffers for update as well as connect a sample-index
  stream for iterator by the caller.
- `register_with_sampler()` which is the `samplerd`-side service task
  endpoint implementing all the shm buffer and index-stream registry
  details as well as logic to ensure a lone service task runs
  `Services.increment_ohlc_buffer()`; it increments at the shortest period
  registered which, for now, is the default 1s duration.

Further impl notes:
- fixes to `Services.broadcast()` to ensure broken streams get discarded
  gracefully.
- we use a `pikerd` side singleton mutex `trio.Lock()` to ensure
  one-and-only-one `samplerd` is ever spawned per `pikerd` actor tree.
samplerd_service
Tyler Goodlet 2023-01-04 22:04:51 -05:00
parent a342f7d2d4
commit 5ec1a72a3d
1 changed files with 315 additions and 109 deletions

View File

@ -24,12 +24,17 @@ from collections import (
Counter, Counter,
defaultdict, defaultdict,
) )
from contextlib import asynccontextmanager as acm
import time import time
from typing import ( from typing import (
AsyncIterator,
TYPE_CHECKING, TYPE_CHECKING,
) )
import tractor import tractor
from tractor.trionics import (
maybe_open_nursery,
)
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
@ -37,14 +42,20 @@ from ..log import (
get_logger, get_logger,
get_console_log, get_console_log,
) )
from .._daemon import maybe_spawn_daemon
if TYPE_CHECKING: if TYPE_CHECKING:
from ._sharedmem import ShmArray from ._sharedmem import (
ShmArray,
)
from .feed import _FeedsBus from .feed import _FeedsBus
log = get_logger(__name__) log = get_logger(__name__)
# highest frequency sample step is 1 second by default, though in
# the future we may want to support shorter periods or a dynamic style
# tick-event stream.
_default_delay_s: float = 1.0 _default_delay_s: float = 1.0
@ -55,32 +66,50 @@ class Sampler:
Manages state for sampling events, shm incrementing and Manages state for sampling events, shm incrementing and
sample period logic. sample period logic.
This non-instantiated type is meant to be a singleton within
a `samplerd` actor-service spawned once by the user wishing to
time-step sample real-time quote feeds, see
``._daemon.maybe_open_samplerd()`` and the below
``register_with_sampler()``.
''' '''
service_nursery: None | trio.Nursery = None service_nursery: None | trio.Nursery = None
# TODO: we could stick these in a composed type to avoid # TODO: we could stick these in a composed type to avoid
# angering the "i hate module scoped variables crowd" (yawn). # angering the "i hate module scoped variables crowd" (yawn).
ohlcv_shms: dict[int, list[ShmArray]] = {} ohlcv_shms: dict[float, list[ShmArray]] = {}
# holds one-task-per-sample-period tasks which are spawned as-needed by # holds one-task-per-sample-period tasks which are spawned as-needed by
# data feed requests with a given detected time step usually from # data feed requests with a given detected time step usually from
# history loading. # history loading.
incrementers: dict[int, trio.CancelScope] = {} incr_task_cs: trio.CancelScope | None = None
# holds all the ``tractor.Context`` remote subscriptions for # holds all the ``tractor.Context`` remote subscriptions for
# a particular sample period increment event: all subscribers are # a particular sample period increment event: all subscribers are
# notified on a step. # notified on a step.
subscribers: dict[int, tractor.Context] = {} # subscribers: dict[int, list[tractor.MsgStream]] = {}
subscribers: defaultdict[
float,
list[
float,
set[tractor.MsgStream]
],
] = defaultdict(
lambda: [
round(time.time()),
set(),
]
)
@classmethod @classmethod
async def increment_ohlc_buffer( async def increment_ohlc_buffer(
self, self,
delay_s: int, period_s: float,
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
): ):
''' '''
Task which inserts new bars into the provide shared memory array Task which inserts new bars into the provide shared memory array
every ``delay_s`` seconds. every ``period_s`` seconds.
This task fulfills 2 purposes: This task fulfills 2 purposes:
- it takes the subscribed set of shm arrays and increments them - it takes the subscribed set of shm arrays and increments them
@ -92,66 +121,74 @@ class Sampler:
the underlying buffers will actually be incremented. the underlying buffers will actually be incremented.
''' '''
# # wait for brokerd to signal we should start sampling
# await shm_incrementing(shm_token['shm_name']).wait()
# TODO: right now we'll spin printing bars if the last time stamp is # TODO: right now we'll spin printing bars if the last time stamp is
# before a large period of no market activity. Likely the best way # before a large period of no market activity. Likely the best way
# to solve this is to make this task aware of the instrument's # to solve this is to make this task aware of the instrument's
# tradable hours? # tradable hours?
# adjust delay to compensate for trio processing time total_s: float = 0 # total seconds counted
ad = min(self.ohlcv_shms.keys()) - 0.001 ad = period_s - 0.001 # compensate for trio processing time
total_s = 0 # total seconds counted
lowest = min(self.ohlcv_shms.keys())
lowest_shm = self.ohlcv_shms[lowest][0]
ad = lowest - 0.001
with trio.CancelScope() as cs: with trio.CancelScope() as cs:
# register this time period step as active # register this time period step as active
self.incrementers[delay_s] = cs
task_status.started(cs) task_status.started(cs)
# sample step loop:
# includes broadcasting to all connected consumers on every
# new sample step as well incrementing any registered
# buffers by registered sample period.
while True: while True:
# TODO: do we want to support dynamically
# adding a "lower" lowest increment period?
await trio.sleep(ad) await trio.sleep(ad)
total_s += delay_s total_s += period_s
# increment all subscribed shm arrays # increment all subscribed shm arrays
# TODO: # TODO:
# - this in ``numba`` # - this in ``numba``
# - just lookup shms for this step instead of iterating? # - just lookup shms for this step instead of iterating?
for this_delay_s, shms in self.ohlcv_shms.items():
i_epoch = round(time.time())
broadcasted: set[float] = set()
# print(f'epoch: {i_epoch} -> REGISTRY {self.ohlcv_shms}')
for shm_period_s, shms in self.ohlcv_shms.items():
# short-circuit on any not-ready because slower sample # short-circuit on any not-ready because slower sample
# rate consuming shm buffers. # rate consuming shm buffers.
if total_s % this_delay_s != 0: if total_s % shm_period_s != 0:
# print(f'skipping `{this_delay_s}s` sample update') # print(f'skipping `{shm_period_s}s` sample update')
continue continue
# update last epoch stamp for this period group
if shm_period_s not in broadcasted:
sub_pair = self.subscribers[shm_period_s]
sub_pair[0] = i_epoch
print(f'skipping `{shm_period_s}s` sample update')
broadcasted.add(shm_period_s)
# TODO: ``numba`` this! # TODO: ``numba`` this!
for shm in shms: for shm in shms:
# print(f'UPDATE {shm_period_s}s STEP for {shm.token}')
# append new entry to buffer thus "incrementing" # append new entry to buffer thus "incrementing"
# the bar # the bar
array = shm.array array = shm.array
last = array[-1:][shm._write_fields].copy() last = array[-1:][shm._write_fields].copy()
# guard against startup backfilling race with
# empty buffers.
if not last.size:
continue
(t, close) = last[0][[ (t, close) = last[0][[
'time', 'time',
'close', 'close',
]] ]]
next_t = t + this_delay_s next_t = t + shm_period_s
i_epoch = round(time.time())
if this_delay_s <= 1: if shm_period_s <= 1:
next_t = i_epoch next_t = i_epoch
# print(f'epoch {shm.token["shm_name"]}: {next_t}')
# this copies non-std fields (eg. vwap) from the # this copies non-std fields (eg. vwap) from the
# last datum # last datum
last[[ last[[
@ -185,43 +222,43 @@ class Sampler:
# write to the buffer # write to the buffer
shm.push(last) shm.push(last)
await self.broadcast(delay_s, shm=lowest_shm) # broadcast increment msg to all updated subs per period
for shm_period_s in broadcasted:
await self.broadcast(
period_s=shm_period_s,
time_stamp=i_epoch,
)
@classmethod @classmethod
async def broadcast( async def broadcast(
self, self,
delay_s: int, period_s: float,
shm: ShmArray | None = None, time_stamp: float | None = None,
) -> None: ) -> None:
''' '''
Broadcast the given ``shm: ShmArray``'s buffer index step to any Broadcast the period size and last index step value to all
subscribers for a given sample period. subscribers for a given sample period.
The sent msg will include the first and last index which slice into
the buffer's non-empty data.
''' '''
subs = self.subscribers.get(delay_s, ()) pair = self.subscribers[period_s]
first = last = -1
if shm is None: last_ts, subs = pair
periods = self.ohlcv_shms.keys()
# if this is an update triggered by a history update there
# might not actually be any sampling bus setup since there's
# no "live feed" active yet.
if periods:
lowest = min(periods)
shm = self.ohlcv_shms[lowest][0]
first = shm._first.value
last = shm._last.value
task = trio.lowlevel.current_task()
log.debug(
f'SUBS {self.subscribers}\n'
f'PAIR {pair}\n'
f'TASK: {task}: {id(task)}\n'
f'broadcasting {period_s} -> {last_ts}\n'
# f'consumers: {subs}'
)
borked: set[tractor.MsgStream] = set()
for stream in subs: for stream in subs:
try: try:
await stream.send({ await stream.send({
'first': first, 'index': time_stamp or last_ts,
'last': last, 'period': period_s,
'index': last,
}) })
except ( except (
trio.BrokenResourceError, trio.BrokenResourceError,
@ -230,6 +267,9 @@ class Sampler:
log.error( log.error(
f'{stream._ctx.chan.uid} dropped connection' f'{stream._ctx.chan.uid} dropped connection'
) )
borked.add(stream)
for stream in borked:
try: try:
subs.remove(stream) subs.remove(stream)
except ValueError: except ValueError:
@ -239,59 +279,219 @@ class Sampler:
@classmethod @classmethod
async def broadcast_all(self) -> None: async def broadcast_all(self) -> None:
for delay_s in self.subscribers: for period_s in self.subscribers:
await self.broadcast(delay_s) await self.broadcast(period_s)
@tractor.context @tractor.context
async def maybe_open_global_sampler( async def register_with_sampler(
ctx: tractor.Context, ctx: tractor.Context,
brokername: str, period_s: float,
shms_by_period: dict[float, dict] | None = None,
open_index_stream: bool = True,
) -> None: ) -> None:
get_console_log(tractor.current_actor().loglevel) get_console_log(tractor.current_actor().loglevel)
incr_was_started: bool = False
global Sampler try:
async with maybe_open_nursery(
Sampler.service_nursery
) as service_nursery:
async with trio.open_nursery() as service_nursery: # init startup, create (actor-)local service nursery and start
# increment task
Sampler.service_nursery = service_nursery Sampler.service_nursery = service_nursery
# unblock caller # always ensure a period subs entry exists
await ctx.started() last_ts, subs = Sampler.subscribers[float(period_s)]
# we pin this task to keep the feeds manager active until the async with trio.Lock():
# parent actor decides to tear it down if Sampler.incr_task_cs is None:
await trio.sleep_forever() Sampler.incr_task_cs = await service_nursery.start(
Sampler.increment_ohlc_buffer,
1.,
@tractor.context
async def iter_ohlc_periods(
ctx: tractor.Context,
delay_s: int,
) -> None:
'''
Subscribe to OHLC sampling "step" events: when the time
aggregation period increments, this event stream emits an index
event.
'''
# add our subscription
subs = Sampler.subscribers.setdefault(delay_s, [])
await ctx.started()
async with ctx.open_stream() as stream:
subs.append(stream)
try:
# stream and block until cancelled
await trio.sleep_forever()
finally:
try:
subs.remove(stream)
except ValueError:
log.error(
f'iOHLC step stream was already dropped {ctx.chan.uid}?'
) )
incr_was_started = True
# insert the base 1s period (for OHLC style sampling) into
# the increment buffer set to update and shift every second.
if shms_by_period is not None:
from ._sharedmem import (
attach_shm_array,
_Token,
)
for period in shms_by_period:
# load and register shm handles
shm_token_msg = shms_by_period[period]
shm = attach_shm_array(
_Token.from_msg(shm_token_msg),
readonly=False,
)
shms_by_period[period] = shm
Sampler.ohlcv_shms.setdefault(period, []).append(shm)
assert Sampler.ohlcv_shms
# unblock caller
await ctx.started(set(Sampler.ohlcv_shms.keys()))
if open_index_stream:
try:
async with ctx.open_stream() as stream:
subs.add(stream)
# except broadcast requests from the subscriber
async for msg in stream:
if msg == 'broadcast_all':
await Sampler.broadcast_all()
finally:
subs.remove(stream)
else:
# if no shms are passed in we just wait until cancelled
# by caller.
await trio.sleep_forever()
finally:
# TODO: why tf isn't this working?
if shms_by_period is not None:
for period, shm in shms_by_period.items():
Sampler.ohlcv_shms[period].remove(shm)
if incr_was_started:
Sampler.incr_task_cs.cancel()
Sampler.incr_task_cs = None
async def spawn_samplerd(
loglevel: str | None = None,
**extra_tractor_kwargs
) -> bool:
'''
Daemon-side service task: start a sampling daemon for common step
update and increment count write and stream broadcasting.
'''
from piker._daemon import Services
dname = 'samplerd'
log.info(f'Spawning `{dname}`')
# singleton lock creation of ``samplerd`` since we only ever want
# one daemon per ``pikerd`` proc tree.
# TODO: make this built-into the service api?
async with Services.locks[dname + '_singleton']:
if dname not in Services.service_tasks:
portal = await Services.actor_n.start_actor(
dname,
enable_modules=[
'piker.data._sampling',
],
loglevel=loglevel,
debug_mode=Services.debug_mode, # set by pikerd flag
**extra_tractor_kwargs
)
await Services.start_service_task(
dname,
portal,
register_with_sampler,
period_s=1,
)
return True
return False
@acm
async def maybe_open_samplerd(
loglevel: str | None = None,
**kwargs,
) -> tractor._portal.Portal: # noqa
'''
Client-side helper to maybe startup the ``samplerd`` service
under the ``pikerd`` tree.
'''
dname = 'samplerd'
async with maybe_spawn_daemon(
dname,
service_task_target=spawn_samplerd,
spawn_args={'loglevel': loglevel},
loglevel=loglevel,
**kwargs,
) as portal:
yield portal
@acm
async def open_sample_stream(
period_s: int,
shms_by_period: dict[float, dict] | None = None,
open_index_stream: bool = True,
cache_key: str | None = None,
allow_new_sampler: bool = True,
) -> AsyncIterator[dict[str, float]]:
'''
Subscribe to OHLC sampling "step" events: when the time aggregation
period increments, this event stream emits an index event.
This is a client-side endpoint that does all the work of ensuring
the `samplerd` actor is up and that mult-consumer-tasks are given
a broadcast stream when possible.
'''
# TODO: wrap this manager with the following to make it cached
# per client-multitasks entry.
# maybe_open_context(
# acm_func=partial(
# portal.open_context,
# register_with_sampler,
# ),
# key=cache_key or period_s,
# )
# if cache_hit:
# # add a new broadcast subscription for the quote stream
# # if this feed is likely already in use
# async with istream.subscribe() as bistream:
# yield bistream
# else:
async with (
# XXX: this should be singleton on a host,
# a lone broker-daemon per provider should be
# created for all practical purposes
maybe_open_samplerd() as portal,
portal.open_context(
register_with_sampler,
**{
'period_s': period_s,
'shms_by_period': shms_by_period,
'open_index_stream': open_index_stream,
},
) as (ctx, first)
):
async with (
ctx.open_stream() as istream,
# TODO: we don't need this task-bcasting right?
# istream.subscribe() as istream,
):
yield istream
async def sample_and_broadcast( async def sample_and_broadcast(
@ -304,7 +504,14 @@ async def sample_and_broadcast(
sum_tick_vlm: bool = True, sum_tick_vlm: bool = True,
) -> None: ) -> None:
'''
`brokerd`-side task which writes latest datum sampled data.
This task is meant to run in the same actor (mem space) as the
`brokerd` real-time quote feed which is being sampled to
a ``ShmArray`` buffer.
'''
log.info("Started shared mem bar writer") log.info("Started shared mem bar writer")
overruns = Counter() overruns = Counter()
@ -341,7 +548,6 @@ async def sample_and_broadcast(
for shm in [rt_shm, hist_shm]: for shm in [rt_shm, hist_shm]:
# update last entry # update last entry
# benchmarked in the 4-5 us range # benchmarked in the 4-5 us range
# for shm in [rt_shm, hist_shm]:
o, high, low, v = shm.array[-1][ o, high, low, v = shm.array[-1][
['open', 'high', 'low', 'volume'] ['open', 'high', 'low', 'volume']
] ]