diff --git a/piker/data/feed.py b/piker/data/feed.py index d9ce062e..9455f7ae 100644 --- a/piker/data/feed.py +++ b/piker/data/feed.py @@ -51,6 +51,7 @@ from ._sampling import ( iter_ohlc_periods, sample_and_broadcast, ) +from .ingest import get_ingestormod log = get_logger(__name__) @@ -302,6 +303,7 @@ class Feed: async def receive(self) -> dict: return await self.stream.__anext__() + @asynccontextmanager async def index_stream( self, delay_s: Optional[int] = None @@ -312,14 +314,16 @@ class Feed: # XXX: this should be singleton on a host, # a lone broker-daemon per provider should be # created for all practical purposes - self._index_stream = await self._brokerd_portal.run( + async with self._brokerd_portal.open_stream_from( iter_ohlc_periods, delay_s=delay_s or self._max_sample_rate, - ) + ) as self._index_stream: + yield self._index_stream + else: + yield self._index_stream - return self._index_stream - - async def recv_trades_data(self) -> AsyncIterator[dict]: + @asynccontextmanager + async def receive_trades_data(self) -> AsyncIterator[dict]: if not getattr(self.mod, 'stream_trades', False): log.warning( @@ -333,7 +337,7 @@ class Feed: # using the ``_.set_fake_trades_stream()`` method if self._trade_stream is None: - self._trade_stream = await self._brokerd_portal.run( + async with self._brokerd_portal.open_stream_from( self.mod.stream_trades, @@ -342,9 +346,10 @@ class Feed: # in messages, though we could probably use # more then one? topics=['local_trades'], - ) - - return self._trade_stream + ) as self._trade_stream: + yield self._trade_stream + else: + yield self._trade_stream def sym_to_shm_key( @@ -373,64 +378,64 @@ async def open_feed( # TODO: do all! sym = symbols[0] - async with maybe_spawn_brokerd( - brokername, - loglevel=loglevel, - ) as portal: + # TODO: compress these to one line with py3.9+ + async with maybe_spawn_brokerd(brokername, loglevel=loglevel) as portal: + + async with portal.open_stream_from( - stream = await portal.run( attach_feed_bus, brokername=brokername, symbol=sym, - loglevel=loglevel, - ) + loglevel=loglevel - # TODO: can we make this work better with the proposed - # context based bidirectional streaming style api proposed in: - # https://github.com/goodboy/tractor/issues/53 - init_msg = await stream.receive() + ) as stream: - # we can only read from shm - shm = attach_shm_array( - token=init_msg[sym]['shm_token'], - readonly=True, - ) + # TODO: can we make this work better with the proposed + # context based bidirectional streaming style api proposed in: + # https://github.com/goodboy/tractor/issues/53 + init_msg = await stream.receive() - feed = Feed( - name=brokername, - stream=stream, - shm=shm, - mod=mod, - _brokerd_portal=portal, - ) - ohlc_sample_rates = [] - - for sym, data in init_msg.items(): - - si = data['symbol_info'] - ohlc_sample_rates.append(data['sample_rate']) - - symbol = Symbol( - key=sym, - type_key=si.get('asset_type', 'forex'), - tick_size=si.get('price_tick_size', 0.01), - lot_tick_size=si.get('lot_tick_size', 0.0), + # we can only read from shm + shm = attach_shm_array( + token=init_msg[sym]['shm_token'], + readonly=True, ) - symbol.broker_info[brokername] = si - feed.symbols[sym] = symbol + feed = Feed( + name=brokername, + stream=stream, + shm=shm, + mod=mod, + _brokerd_portal=portal, + ) + ohlc_sample_rates = [] - # cast shm dtype to list... can't member why we need this - shm_token = data['shm_token'] - shm_token['dtype_descr'] = list(shm_token['dtype_descr']) - assert shm_token == shm.token # sanity + for sym, data in init_msg.items(): - feed._max_sample_rate = max(ohlc_sample_rates) + si = data['symbol_info'] + ohlc_sample_rates.append(data['sample_rate']) - try: - yield feed + symbol = Symbol( + key=sym, + type_key=si.get('asset_type', 'forex'), + tick_size=si.get('price_tick_size', 0.01), + lot_tick_size=si.get('lot_tick_size', 0.0), + ) + symbol.broker_info[brokername] = si - finally: - # always cancel the far end producer task - with trio.CancelScope(shield=True): - await stream.aclose() + feed.symbols[sym] = symbol + + # cast shm dtype to list... can't member why we need this + shm_token = data['shm_token'] + shm_token['dtype_descr'] = list(shm_token['dtype_descr']) + assert shm_token == shm.token # sanity + + feed._max_sample_rate = max(ohlc_sample_rates) + + try: + yield feed + + finally: + # always cancel the far end producer task + with trio.CancelScope(shield=True): + await stream.aclose()