Add a `._sampling.sampler` registry composite type

async_hist_loading
Tyler Goodlet 2022-02-28 11:56:36 -05:00
parent 6f3d78b729
commit c239faf4e5
2 changed files with 38 additions and 29 deletions

View File

@ -32,19 +32,27 @@ from ..log import get_logger
log = get_logger(__name__) log = get_logger(__name__)
class sampler:
'''
Global sampling engine registry.
Manages state for sampling events, shm incrementing and
sample period logic.
'''
# TODO: we could stick these in a composed type to avoid # TODO: we could stick these in a composed type to avoid
# angering the "i hate module scoped variables crowd" (yawn). # angering the "i hate module scoped variables crowd" (yawn).
_ohlcv_shms: dict[int, list[ShmArray]] = {} ohlcv_shms: dict[int, list[ShmArray]] = {}
# holds one-task-per-sample-period tasks which are spawned as-needed by # holds one-task-per-sample-period tasks which are spawned as-needed by
# data feed requests with a given detected time step usually from # data feed requests with a given detected time step usually from
# history loading. # history loading.
_incrementers: dict[int, trio.CancelScope] = {} incrementers: dict[int, trio.CancelScope] = {}
# holds all the ``tractor.Context`` remote subscriptions for # holds all the ``tractor.Context`` remote subscriptions for
# a particular sample period increment event: all subscribers are # a particular sample period increment event: all subscribers are
# notified on a step. # notified on a step.
_subscribers: dict[int, tractor.Context] = {} subscribers: dict[int, tractor.Context] = {}
async def increment_ohlc_buffer( async def increment_ohlc_buffer(
@ -73,19 +81,17 @@ async def increment_ohlc_buffer(
# to solve this is to make this task aware of the instrument's # to solve this is to make this task aware of the instrument's
# tradable hours? # tradable hours?
global _incrementers, _ohlcv_shms, _subscribers
# adjust delay to compensate for trio processing time # adjust delay to compensate for trio processing time
ad = min(_ohlcv_shms.keys()) - 0.001 ad = min(sampler.ohlcv_shms.keys()) - 0.001
total_s = 0 # total seconds counted total_s = 0 # total seconds counted
lowest = min(_ohlcv_shms.keys()) lowest = min(sampler.ohlcv_shms.keys())
ad = lowest - 0.001 ad = lowest - 0.001
with trio.CancelScope() as cs: with trio.CancelScope() as cs:
# register this time period step as active # register this time period step as active
_incrementers[delay_s] = cs sampler.incrementers[delay_s] = cs
task_status.started(cs) task_status.started(cs)
while True: while True:
@ -98,7 +104,7 @@ async def increment_ohlc_buffer(
# TODO: # TODO:
# - this in ``numba`` # - this in ``numba``
# - just lookup shms for this step instead of iterating? # - just lookup shms for this step instead of iterating?
for delay_s, shms in _ohlcv_shms.items(): for delay_s, shms in sampler.ohlcv_shms.items():
if total_s % delay_s != 0: if total_s % delay_s != 0:
continue continue
@ -125,7 +131,7 @@ async def increment_ohlc_buffer(
# broadcast the buffer index step to any subscribers for # broadcast the buffer index step to any subscribers for
# a given sample period. # a given sample period.
subs = _subscribers.get(delay_s, ()) subs = sampler.subscribers.get(delay_s, ())
for ctx in subs: for ctx in subs:
try: try:
@ -151,8 +157,7 @@ async def iter_ohlc_periods(
''' '''
# add our subscription # add our subscription
global _subscribers subs = sampler.subscribers.setdefault(delay_s, [])
subs = _subscribers.setdefault(delay_s, [])
subs.append(ctx) subs.append(ctx)
try: try:
@ -313,6 +318,8 @@ async def uniform_rate_send(
quote_stream: trio.abc.ReceiveChannel, quote_stream: trio.abc.ReceiveChannel,
stream: tractor.MsgStream, stream: tractor.MsgStream,
task_status: TaskStatus = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
# TODO: compute the approx overhead latency per cycle # TODO: compute the approx overhead latency per cycle
@ -323,6 +330,8 @@ async def uniform_rate_send(
last_send = time.time() last_send = time.time()
diff = 0 diff = 0
task_status.started()
while True: while True:
# compute the remaining time to sleep for this throttled cycle # compute the remaining time to sleep for this throttled cycle

View File

@ -56,10 +56,7 @@ from ._source import (
) )
from ..ui import _search from ..ui import _search
from ._sampling import ( from ._sampling import (
# TODO: should probably group these in a compound type at this point XD sampler,
_ohlcv_shms,
_subscribers,
_incrementers,
increment_ohlc_buffer, increment_ohlc_buffer,
iter_ohlc_periods, iter_ohlc_periods,
sample_and_broadcast, sample_and_broadcast,
@ -118,7 +115,7 @@ class _FeedsBus(BaseModel):
trio.CancelScope] = trio.TASK_STATUS_IGNORED, trio.CancelScope] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None:
with trio.CancelScope() as cs: with trio.CancelScope() as cs:
self.nursery.start_soon( await self.nursery.start(
target, target,
*args, *args,
) )
@ -255,23 +252,26 @@ async def manage_history(
await feed_is_live.wait() await feed_is_live.wait()
if opened: if opened:
_ohlcv_shms.setdefault(delay_s, []).append(shm) sampler.ohlcv_shms.setdefault(delay_s, []).append(shm)
# start shm incrementing for OHLC sampling at the current # start shm incrementing for OHLC sampling at the current
# detected sampling period if one dne. # detected sampling period if one dne.
if _incrementers.get(delay_s) is None: if sampler.incrementers.get(delay_s) is None:
cs = await bus.start_task(increment_ohlc_buffer, delay_s) cs = await bus.start_task(
increment_ohlc_buffer,
delay_s,
)
await trio.sleep_forever() await trio.sleep_forever()
cs.cancel() cs.cancel()
async def allocate_persistent_feed( async def allocate_persistent_feed(
bus: _FeedsBus, bus: _FeedsBus,
brokername: str, brokername: str,
symbol: str, symbol: str,
loglevel: str, loglevel: str,
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
) -> None: ) -> None: