From 312169e79063d544cf232ee75b76baa507d2d41d Mon Sep 17 00:00:00 2001 From: Tyler Goodlet Date: Mon, 20 Jul 2020 16:58:40 -0400 Subject: [PATCH] Support the `stream_quotes()` api in questrade backend --- piker/brokers/questrade.py | 162 ++++++++++++++++++++++++++++++++++++- 1 file changed, 158 insertions(+), 4 deletions(-) diff --git a/piker/brokers/questrade.py b/piker/brokers/questrade.py index 1fc3cbc8..961190c1 100644 --- a/piker/brokers/questrade.py +++ b/piker/brokers/questrade.py @@ -3,15 +3,20 @@ Questrade API backend. """ from __future__ import annotations import inspect +import contextlib import time from datetime import datetime from functools import partial import itertools import configparser -from typing import List, Tuple, Dict, Any, Iterator, NamedTuple +from typing import ( + List, Tuple, Dict, Any, Iterator, NamedTuple, + AsyncGenerator, +) import arrow import trio +import tractor from async_generator import asynccontextmanager import pandas as pd import numpy as np @@ -23,6 +28,7 @@ from . import config from ._util import resproc, BrokerError, SymbolNotFound from ..log import get_logger, colorize_json from .._async_utils import async_lifo_cache +from . import get_brokermod log = get_logger(__name__) @@ -408,10 +414,10 @@ class Client: return symbols2ids - async def symbol_info(self, tickers: List[str]): - """Return symbol data for ``tickers``. + async def symbol_info(self, symbols: List[str]): + """Return symbol data for ``symbols``. """ - t2ids = await self.tickers2ids(tickers) + t2ids = await self.tickers2ids(symbols) ids = ','.join(t2ids.values()) symbols = {} for pkt in (await self.api.symbols(ids=ids))['symbols']: @@ -1004,3 +1010,151 @@ def format_option_quote( displayable[new_key] = display_value return new, displayable + + +@asynccontextmanager +async def get_cached_client( + brokername: str, + *args, + **kwargs, +) -> 'Client': + """Get a cached broker client from the current actor's local vars. + + If one has not been setup do it and cache it. + """ + # check if a cached client is in the local actor's statespace + ss = tractor.current_actor().statespace + clients = ss.setdefault('clients', {'_lock': trio.Lock()}) + lock = clients['_lock'] + client = None + try: + log.info(f"Loading existing `{brokername}` daemon") + async with lock: + client = clients[brokername] + except KeyError: + log.info(f"Creating new client for broker {brokername}") + async with lock: + brokermod = get_brokermod(brokername) + exit_stack = contextlib.AsyncExitStack() + client = await exit_stack.enter_async_context( + brokermod.get_client()) + client._exit_stack = exit_stack + clients[brokername] = client + else: + client._consumers += 1 + yield client + finally: + client._consumers -= 1 + if client._consumers <= 0: + # teardown the client + await client._exit_stack.aclose() + + +async def smoke_quote(get_quotes, tickers): # , broker): + """Do an initial "smoke" request for symbols in ``tickers`` filtering + out any symbols not supported by the broker queried in the call to + ``get_quotes()``. + """ + from operator import itemgetter + # TODO: trim out with #37 + ################################################# + # get a single quote filtering out any bad tickers + # NOTE: this code is always run for every new client + # subscription even when a broker quoter task is already running + # since the new client needs to know what symbols are accepted + log.warn(f"Retrieving smoke quote for symbols {tickers}") + quotes = await get_quotes(tickers) + + # report any tickers that aren't returned in the first quote + invalid_tickers = set(tickers) - set(map(itemgetter('key'), quotes)) + for symbol in invalid_tickers: + tickers.remove(symbol) + log.warn( + f"Symbol `{symbol}` not found") # by broker `{broker}`" + # ) + + # pop any tickers that return "empty" quotes + payload = {} + for quote in quotes: + symbol = quote['symbol'] + if quote is None: + log.warn( + f"Symbol `{symbol}` not found") + # XXX: not this mutates the input list (for now) + tickers.remove(symbol) + continue + + # report any unknown/invalid symbols (QT specific) + if quote.get('low52w', False) is None: + log.error( + f"{symbol} seems to be defunct") + + payload[symbol] = quote + + return payload + + # end of section to be trimmed out with #37 + ########################################### + + +@tractor.stream +async def stream_quotes( + ctx: tractor.Context, # marks this as a streaming func + symbols: List[str], + feed_type: str = 'stock', + diff_cached: bool = True, + rate: int = 3, + # feed_type: str = 'stock', +) -> AsyncGenerator[str, Dict[str, Any]]: + + async with get_cached_client('questrade') as client: + if feed_type == 'stock': + formatter = format_stock_quote + get_quotes = await stock_quoter(client, symbols) + + # do a smoke quote (note this mutates the input list and filters + # out bad symbols for now) + payload = await smoke_quote(get_quotes, list(symbols)) + else: + formatter = format_option_quote + get_quotes = await option_quoter(client, symbols) + # packetize + payload = { + quote['symbol']: quote + for quote in await get_quotes(symbols) + } + + symbol_data = await client.symbol_info(symbols) + + # function to format packets delivered to subscribers + def packetizer( + topic: str, + quotes: Dict[str, Any] + ) -> Dict[str, Any]: + """Normalize quotes by name into dicts. + """ + new = {} + for quote in quotes: + new[quote['symbol']], _ = formatter(quote, symbol_data) + + return new + + # push initial smoke quote response for client initialization + await ctx.send_yield(payload) + + from .data import stream_poll_requests + + await stream_poll_requests( + + # ``msg.pub`` required kwargs + task_name=feed_type, + ctx=ctx, + topics=symbols, + packetizer=packetizer, + + # actual func args + get_quotes=get_quotes, + diff_cached=diff_cached, + rate=rate, + ) + log.info("Terminating stream quoter task")