From 5766dd518d47b3d756bc01c66e5b80cc572ddcbf Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Mon, 10 May 2021 10:17:06 -0400 Subject: [PATCH] Enforce lower case symbols across providers --- piker/data/feed.py | 31 +++++++++++++++++++------------ piker/ui/_chart.py | 4 ++++ 2 files changed, 23 insertions(+), 12 deletions(-) diff --git a/piker/data/feed.py b/piker/data/feed.py index b9df8595..cdd19070 100644 --- a/piker/data/feed.py +++ b/piker/data/feed.py @@ -379,13 +379,25 @@ class Feed: @asynccontextmanager async def open_symbol_search(self) -> AsyncIterator[dict]: + open_search = getattr(self.mod, 'open_symbol_search', None) + if open_search is None: + + # just return a pure pass through searcher + async def passthru(text: str) -> Dict[str, Any]: + return text + + self.search = passthru + yield self.search + self.search = None + return + async with self._brokerd_portal.open_context( - - self.mod.open_symbol_search, - + open_search, ) as (ctx, cache): - async with ctx.open_stream() as stream: + # shield here since we expect the search rpc to be + # cancellable by the user as they see fit. + async with ctx.open_stream(shield=True) as stream: async def search(text: str) -> Dict[str, Any]: await stream.send(text) @@ -448,7 +460,7 @@ async def open_feed( """ global _cache, _cache_lock - sym = symbols[0] + sym = symbols[0].lower() # TODO: feed cache locking, right now this is causing # issues when reconncting to a long running emsd? @@ -526,11 +538,6 @@ async def open_feed( _cache[(brokername, sym)] = feed - try: - async with feed.open_symbol_search(): - yield feed + async with feed.open_symbol_search(): + yield feed - finally: - # always cancel the far end producer task - with trio.CancelScope(shield=True): - await stream.aclose() diff --git a/piker/ui/_chart.py b/piker/ui/_chart.py index 01d3af35..7b574422 100644 --- a/piker/ui/_chart.py +++ b/piker/ui/_chart.py @@ -94,6 +94,7 @@ class ChartSpace(QtGui.QWidget): # self.init_strategy_ui() self.vbox.addLayout(self.toolbar_layout) self.vbox.addLayout(self.hbox) + self._chart_cache = {} self.linkedcharts: 'LinkedSplitCharts' = None self.symbol_label: Optional[QtGui.QLabel] = None @@ -135,6 +136,9 @@ class ChartSpace(QtGui.QWidget): Expects a ``numpy`` structured array containing all the ohlcv fields. """ + # our symbol key style is always lower case + symbol_key = symbol_key.lower() + linkedcharts = self._chart_cache.get(symbol_key) if not self.vbox.isEmpty():