# piker: trading gear for hackers # Copyright (C) Tyler Goodlet (in stewardship for pikers) # This program is free software: you can redistribute it and/or modify # it under the terms of the GNU Affero General Public License as published by # the Free Software Foundation, either version 3 of the License, or # (at your option) any later version. # This program is distributed in the hope that it will be useful, # but WITHOUT ANY WARRANTY; without even the implied warranty of # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the # GNU Affero General Public License for more details. # You should have received a copy of the GNU Affero General Public License # along with this program. If not, see . ''' Core API client machinery; mostly sane/useful wrapping around `ib_insync`.. ''' from __future__ import annotations from contextlib import ( asynccontextmanager as acm, contextmanager as cm, ) from contextlib import AsyncExitStack from dataclasses import ( asdict, astuple, ) from datetime import datetime from functools import ( partial, ) import itertools from math import isnan import asyncio from pprint import pformat import inspect import time from typing import ( Any, Callable, ) from types import SimpleNamespace from bidict import bidict import trio import tractor from tractor import to_asyncio from pendulum import ( from_timestamp, DateTime, Duration, duration as mk_duration, ) from eventkit import Event from ib_insync import ( client as ib_client, IB, Contract, ContractDetails, Crypto, Commodity, Forex, Future, ContFuture, Stock, Order, Ticker, BarDataList, Position, Fill, # Execution, # CommissionReport, Wrapper, RequestError, ) import numpy as np # TODO: in hindsight, probably all imports should be # non-relative for backends so that non-builting backends # can be easily modelled after this style B) from piker import config from piker.accounting import MktPair from .symbols import ( con2fqme, parse_patt2fqme, _adhoc_symbol_map, _exch_skip_list, _futes_venues, ) from ._util import ( log, # only for the ib_sync internal logging get_logger, ) _bar_load_dtype: list[tuple[str, type]] = [ # NOTE XXX: only part that's diff # from our default fields where # time is normally an int. # TODO: can we just cast to this # at np.ndarray load time? ('time', float), ('open', float), ('high', float), ('low', float), ('close', float), ('volume', float), ('count', int), ] # Broker specific ohlc schema which includes a vwap field _ohlc_dtype: list[tuple[str, type]] = _bar_load_dtype.copy() _ohlc_dtype.insert( 0, ('index', int), ) _time_units = { 's': ' sec', 'm': ' mins', 'h': ' hours', } _bar_sizes = { 1: '1 Sec', 60: '1 min', 60*60: '1 hour', 24*60*60: '1 day', } _show_wap_in_history: bool = False # overrides to sidestep pretty questionable design decisions in # ``ib_insync``: class NonShittyWrapper(Wrapper): def tcpDataArrived(self): """Override time stamps to be floats for now. """ # use a ns int to store epoch time instead of datetime self.lastTime = time.time_ns() for ticker in self.pendingTickers: ticker.rtTime = None ticker.ticks = [] ticker.tickByTicks = [] ticker.domTicks = [] self.pendingTickers = set() def execDetails( self, reqId: int, contract: Contract, execu, ): """ Get rid of datetime on executions. """ # this is the IB server's execution time supposedly # https://interactivebrokers.github.io/tws-api/classIBApi_1_1Execution.html#a2e05cace0aa52d809654c7248e052ef2 execu.time = execu.time.timestamp() return super().execDetails(reqId, contract, execu) class NonShittyIB(IB): ''' The beginning of overriding quite a few decisions in this lib. - Don't use datetimes - Don't use named tuples ''' def __init__(self): # override `ib_insync` internal loggers so we can see wtf # it's doing.. self._logger = get_logger( 'ib_insync.ib', ) self._createEvents() # XXX: just to override this wrapper self.wrapper = NonShittyWrapper(self) self.client = ib_client.Client(self.wrapper) self.client._logger = get_logger( 'ib_insync.client', ) # self.errorEvent += self._onError self.client.apiEnd += self.disconnectedEvent _enters = 0 def bars_to_np(bars: list) -> np.ndarray: ''' Convert a "bars list thing" (``BarDataList`` type from ibis) into a numpy struct array. ''' # TODO: maybe rewrite this faster with ``numba`` np_ready = [] for bardata in bars: ts = bardata.date.timestamp() t = astuple(bardata)[:7] np_ready.append((ts, ) + t[1:7]) nparr = np.array( np_ready, dtype=_bar_load_dtype, ) assert nparr['time'][0] == bars[0].date.timestamp() assert nparr['time'][-1] == bars[-1].date.timestamp() return nparr # NOTE: pacing violations exist for higher sample rates: # https://interactivebrokers.github.io/tws-api/historical_limitations.html#pacing_violations # Also see note on duration limits being lifted on 1m+ periods, # but they say "use with discretion": # https://interactivebrokers.github.io/tws-api/historical_limitations.html#non-available_hd _samplings: dict[int, tuple[str, str]] = { 1: ( # ib strs '1 secs', f'{int(2e3)} S', mk_duration(seconds=2e3), ), # TODO: benchmark >1 D duration on query to see if # throughput can be made faster during backfilling. 60: ( # ib strs '1 min', '2 D', mk_duration(days=2), ), } @cm def remove_handler_on_err( event: Event, handler: Callable, ) -> None: try: yield except trio.BrokenResourceError: # XXX: eventkit's ``Event.emit()`` for whatever redic # reason will catch and ignore regular exceptions # resulting in tracebacks spammed to console.. # Manually do the dereg ourselves. log.exception(f'Disconnected from {event} updates') event.disconnect(handler) class Client: ''' IB wrapped for our broker backend API. Note: this client requires running inside an ``asyncio`` loop. ''' # keyed by fqmes _contracts: dict[str, Contract] = {} # keyed by conId _cons: dict[str, Contract] = {} # for going between ib and piker types _cons2mkts: bidict[Contract, MktPair] = bidict({}) def __init__( self, ib: IB, config: dict[str, Any], ) -> None: # stash `brokers.toml` config on client for user settings # as needed throughout this backend (eg. vnc sockaddr). self.conf = config # NOTE: the ib.client here is "throttled" to 45 rps by default self.ib: IB = ib self.ib.RaiseRequestErrors: bool = True # self._acnt_names: set[str] = {} self._acnt_names: list[str] = [] @property def acnts(self) -> list[str]: # return list(self._acnt_names) return self._acnt_names def __repr__(self) -> str: return ( f'<{type(self).__name__}(' f'ib={self.ib} ' f'acnts={self.acnts}' # TODO: we need to mask out acnt-#s and other private # infos if we're going to console this! # f' |_.conf:\n' # f' {pformat(self.conf)}\n' ')>' ) async def get_fills(self) -> list[Fill]: ''' Return list of rents `Fills` from trading session. In theory this can be configured for dumping clears from multiple days but can't member where to set that.. ''' fills: list[Fill] = self.ib.fills() return fills async def orders(self) -> list[Order]: return await self.ib.reqAllOpenOrdersAsync( apiOnly=False, ) async def bars( self, fqme: str, # EST in ISO 8601 format is required... below is EPOCH start_dt: datetime | str = "1970-01-01T00:00:00.000000-05:00", end_dt: datetime | str = "", # ohlc sample period in seconds sample_period_s: int = 1, # optional "duration of time" equal to the # length of the returned history frame. duration: str | None = None, **kwargs, ) -> tuple[BarDataList, np.ndarray, Duration]: ''' Retreive OHLCV bars for a fqme over a range to the present. ''' # See API docs here: # https://interactivebrokers.github.io/tws-api/historical_data.html bars_kwargs = {'whatToShow': 'TRADES'} bars_kwargs.update(kwargs) ( bar_size, ib_duration_str, default_dt_duration, ) = _samplings[sample_period_s] dt_duration: Duration = ( duration or default_dt_duration ) # TODO: maybe remove all this? global _enters if not end_dt: end_dt = '' _enters += 1 contract: Contract = (await self.find_contracts(fqme))[0] bars_kwargs.update(getattr(contract, 'bars_kwargs', {})) kwargs: dict[str, Any] = dict( contract=contract, endDateTime=end_dt, formatDate=2, # OHLC sampling values: # 1 secs, 5 secs, 10 secs, 15 secs, 30 secs, 1 min, 2 mins, # 3 mins, 5 mins, 10 mins, 15 mins, 20 mins, 30 mins, # 1 hour, 2 hours, 3 hours, 4 hours, 8 hours, 1 day, 1W, 1M barSizeSetting=bar_size, # time history length values format: # ``durationStr=integer{SPACE}unit (S|D|W|M|Y)`` durationStr=ib_duration_str, # always use extended hours useRTH=False, # restricted per contract type **bars_kwargs, # whatToShow='MIDPOINT', # whatToShow='TRADES', ) bars = await self.ib.reqHistoricalDataAsync( **kwargs, ) query_info: str = ( f'REQUESTING IB history BARS\n' f' ------ - ------\n' f'dt_duration: {dt_duration}\n' f'ib_duration_str: {ib_duration_str}\n' f'bar_size: {bar_size}\n' f'fqme: {fqme}\n' f'actor-global _enters: {_enters}\n' f'kwargs: {pformat(kwargs)}\n' ) # tail case if no history for range or none prior. # NOTE: there's actually 3 cases here to handle (and # this should be read alongside the implementation of # `.reqHistoricalDataAsync()`): # - a timeout occurred in which case insync internals return # an empty list thing with bars.clear()... # - no data exists for the period likely due to # a weekend, holiday or other non-trading period prior to # ``end_dt`` which exceeds the ``duration``, # - LITERALLY this is the start of the mkt's history! if not bars: # TODO: figure out wut's going on here. # TODO: is this handy, a sync requester for tinkering # with empty frame cases? # def get_hist(): # return self.ib.reqHistoricalData(**kwargs) # import pdbp # pdbp.set_trace() log.critical( 'STUPID IB SAYS NO HISTORY\n\n' + query_info ) # TODO: we could maybe raise ``NoData`` instead if we # rewrite the method in the first case? # right now there's no way to detect a timeout.. return [], np.empty(0), dt_duration log.info(query_info) # NOTE XXX: ensure minimum duration in bars? # => recursively call this method until we get at least as # many bars such that they sum in aggregate to the the # desired total time (duration) at most. # - if you query over a gap and get no data # that may short circuit the history if ( # XXX XXX XXX # => WHY DID WE EVEN NEED THIS ORIGINALLY!? <= # XXX XXX XXX False and end_dt ): nparr: np.ndarray = bars_to_np(bars) times: np.ndarray = nparr['time'] first: float = times[0] tdiff: float = times[-1] - first if ( # len(bars) * sample_period_s) < dt_duration.in_seconds() tdiff < dt_duration.in_seconds() # and False ): end_dt: DateTime = from_timestamp(first) log.warning( f'Frame result was shorter then {dt_duration}!?\n' 'Recursing for more bars:\n' f'end_dt: {end_dt}\n' f'dt_duration: {dt_duration}\n' ) ( r_bars, r_arr, r_duration, ) = await self.bars( fqme, start_dt=start_dt, end_dt=end_dt, sample_period_s=sample_period_s, # TODO: make a table for Duration to # the ib str values in order to use this? # duration=duration, ) r_bars.extend(bars) bars = r_bars nparr: np.ndarray = bars_to_np(bars) # timestep should always be at least as large as the # period step. tdiff: np.ndarray = np.diff(nparr['time']) to_short: np.ndarray = tdiff < sample_period_s if (to_short).any(): # raise ValueError( log.error( f'OHLC frame for {sample_period_s} has {to_short.size} ' 'time steps which are shorter then expected?!"' ) # OOF: this will break teardown? # -[ ] check if it's greenback # -[ ] why tf are we leaking shm entries.. # -[ ] make a test on the debugging asyncio testing # branch.. # breakpoint() return ( bars, nparr, dt_duration, ) async def con_deats( self, contracts: list[Contract], ) -> dict[str, ContractDetails]: futs: list[asyncio.Future] = [] for con in contracts: exch: str = con.primaryExchange or con.exchange if ( exch and exch not in _exch_skip_list ): futs.append(self.ib.reqContractDetailsAsync(con)) # batch request all details try: results: list[ContractDetails] = await asyncio.gather(*futs) except RequestError as err: msg: str = err.message if ( 'No security definition' in msg ): log.warning(f'{msg}: {contracts}') return {} raise # one set per future result details: dict[str, ContractDetails] = {} for details_set in results: # XXX: if there is more then one entry in the details list # then the contract is so called "ambiguous". for d in details_set: # nested dataclass we probably don't need and that won't # IPC serialize.. d.secIdList = '' key, calc_price = con2fqme(d.contract) details[key] = d return details async def search_contracts( self, pattern: str, upto: int = 3, # how many contracts to search "up to" ) -> dict[str, ContractDetails]: ''' Search for ``Contract``s matching provided ``str`` pattern. Return a dictionary of ``upto`` entries worth of ``ContractDetails``. ''' descrs: list[ContractDetails] = ( await self.ib.reqMatchingSymbolsAsync(pattern) ) if descrs is None: return {} return await self.con_deats( # limit to first ``upto`` entries [d.contract for d in descrs[:upto]] ) async def search_symbols( self, pattern: str, # how many contracts to search "up to" upto: int = 16, asdicts: bool = True, ) -> dict[str, ContractDetails]: # TODO add search though our adhoc-locally defined symbol set # for futes/cmdtys/ try: results = await self.search_contracts( pattern, upto=upto, ) except ConnectionError: return {} dict_results: dict[str, dict] = {} for key, deats in results.copy().items(): tract = deats.contract sym = tract.symbol sectype = tract.secType deats_dict = asdict(deats) if sectype == 'IND': results.pop(key) key = f'{sym}.IND' results[key] = tract # exch = tract.exchange # XXX: add back one of these to get the weird deadlock # on the debugger from root without the latest # maybe_wait_for_debugger() fix in the `open_context()` # exit. # assert 0 # if con.exchange not in _exch_skip_list: exch = tract.exchange if exch not in _exch_skip_list: # try to lookup any contracts from our adhoc set # since often the exchange/venue is named slightly # different (eg. BRR.CMECRYPTO` instead of just # `.CME`). info = _adhoc_symbol_map.get(sym) if info: con_kwargs, bars_kwargs = info exch = con_kwargs['exchange'] # try get all possible contracts for symbol as per, # https://interactivebrokers.github.io/tws-api/basic_contracts.html#fut con = Future( symbol=sym, exchange=exch, ) # TODO: make this work, think it's something to do # with the qualify flag. # cons = await self.find_contracts( # contract=con, # err_on_qualify=False, # ) # if cons: all_deats = await self.con_deats([con]) results |= all_deats for key in all_deats: dict_results[key] = asdict(all_deats[key]) # forex pairs elif sectype == 'CASH': results.pop(key) dst, src = tract.localSymbol.split('.') pair_key = "/".join([dst, src]) exch = tract.exchange.lower() key = f'{pair_key}.{exch}' results[key] = tract # XXX: again seems to trigger the weird tractor # bug with the debugger.. # assert 0 dict_results[key] = deats_dict return dict_results async def get_fute( self, symbol: str, exchange: str, expiry: str = '', front: bool = False, ) -> Contract: ''' Get an unqualifed contract for the current "continous" future. ''' # it's the "front" contract returned here if front: con = (await self.ib.qualifyContractsAsync( ContFuture(symbol, exchange=exchange) ))[0] else: con = (await self.ib.qualifyContractsAsync( Future( symbol, exchange=exchange, lastTradeDateOrContractMonth=expiry, ) ))[0] return con # TODO: is this a better approach? # @async_lifo_cache() async def get_con( self, conid: int, ) -> Contract: try: return self._cons[conid] except KeyError: con: Contract = await self.ib.qualifyContractsAsync( Contract(conId=conid) ) self._cons[str(conid)] = con[0] return con async def find_contracts( self, pattern: str | None = None, contract: Contract | None = None, qualify: bool = True, err_on_qualify: bool = True, ) -> Contract: if pattern is not None: symbol, currency, exch, expiry = parse_patt2fqme( pattern, ) sectype: str = '' exch: str = exch.upper() else: assert contract symbol: str = contract.symbol sectype: str = contract.secType exch: str = contract.exchange or contract.primaryExchange expiry: str = contract.lastTradeDateOrContractMonth currency: str = contract.currency # contract searching stage # ------------------------ # futes, ensure exch/venue is uppercase for matching # our adhoc set. if exch.upper() in _futes_venues: if expiry: # get the "front" contract con = await self.get_fute( symbol=symbol, exchange=exch, expiry=expiry, ) else: # get the "front" contract con = await self.get_fute( symbol=symbol, exchange=exch, front=True, ) elif ( exch in {'IDEALPRO'} or sectype == 'CASH' ): pair: str = symbol if '/' in symbol: src, dst = symbol.split('/') pair: str = ''.join([src, dst]) con = Forex( pair=pair, currency='', ) con.bars_kwargs = {'whatToShow': 'MIDPOINT'} # commodities elif exch == 'CMDTY': # eg. XAUUSD.CMDTY con_kwargs, bars_kwargs = _adhoc_symbol_map[symbol.upper()] con = Commodity(**con_kwargs) con.bars_kwargs = bars_kwargs # crypto$ elif exch == 'PAXOS': # btc.paxos con = Crypto( symbol=symbol, currency=currency, ) # stonks else: # TODO: metadata system for all these exchange rules.. primaryExchange = '' if exch in ('PURE', 'TSE'): # non-yankee currency = 'CAD' # stupid ib... primaryExchange = exch exch = 'SMART' else: # XXX: order is super important here since # a primary == 'SMART' won't ever work. primaryExchange = exch exch = 'SMART' con = Stock( symbol=symbol, exchange=exch, primaryExchange=primaryExchange, currency=currency, ) exch = 'SMART' if not exch else exch contracts: list[Contract] = [con] if qualify: try: contracts: list[Contract] = ( await self.ib.qualifyContractsAsync(con) ) except RequestError as err: msg = err.message if ( 'No security definition' in msg and not err_on_qualify ): log.warning( f'Could not find def for {con}') return None else: raise if not contracts: raise ValueError(f"No contract could be found {con}") # pack all contracts into cache for tract in contracts: exch: str = ( tract.primaryExchange or tract.exchange or exch ) pattern: str = f'{symbol}.{exch.lower()}' expiry: str = tract.lastTradeDateOrContractMonth # add an entry with expiry suffix if available if expiry: pattern += f'.{expiry}' # since pos update msgs will always have the full fqme # with suffix? pattern += '.ib' # directly cache the input pattern to the output # contract match as well as by the IB-internal conId. self._contracts[pattern] = tract self._cons[str(tract.conId)] = tract return contracts async def maybe_get_head_time( self, fqme: str, ) -> datetime | None: ''' Return the first datetime stamp for `fqme` or `None` on request failure. ''' try: head_dt: datetime = await self.get_head_time(fqme=fqme) return head_dt except RequestError: log.warning(f'Unable to get head time: {fqme} ?') return None async def get_head_time( self, fqme: str, ) -> datetime: ''' Return the first datetime stamp for ``contract``. ''' contract = (await self.find_contracts(fqme))[0] return await self.ib.reqHeadTimeStampAsync( contract, whatToShow='TRADES', useRTH=False, formatDate=2, # timezone aware UTC datetime ) async def get_sym_details( self, fqme: str, ) -> tuple[ Contract, ContractDetails, ]: ''' Return matching contracts for a given ``fqme: str`` including ``Contract`` and matching ``ContractDetails``. ''' contract: Contract = (await self.find_contracts(fqme))[0] details: ContractDetails = ( await self.ib.reqContractDetailsAsync(contract) )[0] return contract, details async def get_quote( self, contract: Contract, timeout: float = 1, tries: int = 100, raise_on_timeout: bool = False, ) -> Ticker | None: ''' Return a single (snap) quote for symbol. ''' ticker: Ticker = self.ib.reqMktData( contract, snapshot=True, ) ready: ticker.TickerUpdateEvent = ticker.updateEvent # ensure a last price gets filled in before we deliver quote timeouterr: Exception | None = None warnset: bool = False for _ in range(tries): # wait for a first update(Event) indicatingn a # live quote feed. if isnan(ticker.last): try: tkr = await asyncio.wait_for( ready, timeout=timeout, ) if tkr: break except TimeoutError as err: timeouterr = err await asyncio.sleep(0.01) continue else: if not warnset: log.warning( f'Quote req timed out..maybe venue is closed?\n' f'{asdict(contract)}' ) warnset = True else: log.info( 'Got first quote for contract\n' f'{contract}\n' ) break else: if timeouterr and raise_on_timeout: import pdbp pdbp.set_trace() raise timeouterr if not warnset: log.warning( f'Contract {contract} is not returning a quote ' 'it may be outside trading hours?' ) warnset = True return None return ticker # async to be consistent for the client proxy, and cuz why not. def submit_limit( self, oid: str, # ignored since doesn't support defining your own symbol: str, price: float, action: str, size: int, account: str, # if blank the "default" tws account is used # XXX: by default 0 tells ``ib_insync`` methods that there is no # existing order so ask the client to create a new one (which it # seems to do by allocating an int counter - collision prone..) reqid: int = None, ) -> int: ''' Place an order and return integer request id provided by client. Relevant docs: - https://interactivebrokers.github.io/tws-api/order_limitations.html ''' try: con: Contract = self._contracts[symbol] except KeyError: # require that the symbol has been previously cached by # a data feed request - ensure we aren't making orders # against non-known prices. raise RuntimeError("Can not order {symbol}, no live feed?") try: trade = self.ib.placeOrder( con, Order( orderId=reqid or 0, # stupid api devs.. action=action.upper(), # BUY/SELL # lookup the literal account number by name here. account=account, orderType='LMT', lmtPrice=price, totalQuantity=size, outsideRth=True, optOutSmartRouting=True, # TODO: need to understand this setting better as # it pertains to shit ass mms.. routeMarketableToBbo=True, designatedLocation='SMART', # TODO: make all orders GTC? # https://interactivebrokers.github.io/tws-api/classIBApi_1_1Order.html#a95539081751afb9980f4c6bd1655a6ba # goodTillDate=f"yyyyMMdd-HH:mm:ss", ), ) except AssertionError: # errrg insync.. log.warning(f'order for {reqid} already complete?') # will trigger an error in ems request handler task. return None # ib doesn't support setting your own id outside # their own weird client int counting ids.. return trade.order.orderId def submit_cancel( self, reqid: str, ) -> None: """Send cancel request for order id ``oid``. """ self.ib.cancelOrder( Order( orderId=reqid, clientId=self.ib.client.clientId, ) ) def inline_errors( self, to_trio: trio.abc.SendChannel, ) -> None: ''' Setup error relay to the provided ``trio`` mem chan such that trio tasks can retreive and parse ``asyncio``-side API request errors. ''' def push_err( reqId: int, errorCode: int, errorString: str, con: Contract, ) -> None: reason: str = errorString if reqId == -1: # it's a general event? key: str = 'event' log.info(errorString) else: key: str = 'error' log.error(errorString) try: to_trio.send_nowait(( key, { 'type': key, 'reqid': reqId, 'reason': reason, 'error_code': errorCode, 'contract': con, } )) except trio.BrokenResourceError: # XXX: eventkit's ``Event.emit()`` for whatever redic # reason will catch and ignore regular exceptions # resulting in tracebacks spammed to console.. # Manually do the dereg ourselves. log.exception('Disconnected from errorEvent updates') self.ib.errorEvent.disconnect(push_err) self.ib.errorEvent.connect(push_err) api_err = self.ib.client.apiError def report_api_err(msg: str) -> None: with remove_handler_on_err( api_err, report_api_err, ): to_trio.send_nowait(( 'error', msg, )) api_err.clear() # drop msg history api_err.connect(report_api_err) def positions( self, account: str = '', ) -> list[Position]: """ Retrieve position info for ``account``. """ return self.ib.positions(account=account) # per-actor API ep caching _client_cache: dict[tuple[str, int], Client] = {} _scan_ignore: set[tuple[str, int]] = set() def get_config() -> dict[str, Any]: conf, path = config.load( conf_name='brokers', ) section = conf.get('ib') accounts = section.get('accounts') if not accounts: raise ValueError( 'brokers.toml -> `ib.accounts` must be defined\n' f'location: {path}' ) names = list(accounts.keys()) accts = section['accounts'] = bidict(accounts) log.info( f'{path} defines {len(accts)} account aliases:\n' f'{pformat(names)}\n' ) if section is None: log.warning(f'No config section found for ib in {path}') return {} return section _accounts2clients: dict[str, Client] = {} @acm async def load_aio_clients( host: str = '127.0.0.1', port: int = None, client_id: int = 6116, # the API TCP in `ib_insync` connection can be flaky af so instead # retry a few times to get the client going.. connect_retries: int = 3, connect_timeout: float = 10, disconnect_on_exit: bool = True, ) -> dict[str, Client]: ''' Return an ``ib_insync.IB`` instance wrapped in our client API. Client instances are cached for later use. TODO: consider doing this with a ctx mngr eventually? ''' global _accounts2clients, _client_cache, _scan_ignore conf: dict[str, Any] = get_config() ib = None client = None # attempt to get connection info from config; if no .toml entry # exists, we try to load from a default localhost connection. localhost = '127.0.0.1' host, hosts = conf.get('host'), conf.get('hosts') if not (hosts or host): host = localhost if not hosts: hosts = [host] elif host and hosts: raise ValueError( 'Specify only one of `host` or `hosts` in `brokers.toml` config') try_ports = conf.get( 'ports', # default order is to check for gw first [4002, 7497] ) if isinstance(try_ports, dict): log.warning( '`ib.ports` in `brokers.toml` should be a `list` NOT a `dict`' ) try_ports = list(try_ports.values()) _err = None accounts_def: dict[str, str] = config.load_accounts(['ib']) ports = try_ports if port is None else [port] combos = list(itertools.product(hosts, ports)) accounts_found: dict[str, Client] = {} # (re)load any and all clients that can be found # from connection details in ``brokers.toml``. for host, port in combos: sockaddr = (host, port) maybe_client = _client_cache.get(sockaddr) if ( sockaddr in _scan_ignore or ( maybe_client and maybe_client.ib.isConnected() ) ): continue ib: IB = NonShittyIB() for i in range(connect_retries): try: log.info( 'Trying `ib_async` connect\n' f'{host}: {port}\n' f'clientId: {client_id}\n' f'timeout: {connect_timeout}\n' ) await ib.connectAsync( host, port, clientId=client_id + i, # this timeout is sensitive on windows and will # fail without a good "timeout error" so be # careful. timeout=connect_timeout, ) # create and cache client client = Client(ib=ib, config=conf) # update all actor-global caches log.runtime( f'Connected and caching `Client` @ {sockaddr!r}' ) _client_cache[sockaddr] = client break except ( ConnectionRefusedError, # TODO: if trying to scan for remote api clients # pretty sure we need to catch this, though it # definitely needs a shorter timeout since it hangs # for like 5s.. asyncio.exceptions.TimeoutError, OSError, ) as ce: _err = ce message: str = ( f'Failed to connect on {host}:{port} after {i} tries with\n' f'{ib.client.apiError.value()!r}\n\n' 'Retrying with a new client id..\n' ) log.runtime(message) else: # XXX report loudly if we never established after all # re-tries log.warning(message) # Pre-collect all accounts available for this # connection and map account names to this client # instance. for value in ib.accountValues(): acct_number: str = value.account acnt_alias: str = accounts_def.inverse.get(acct_number) if not acnt_alias: # TODO: should we constuct the below reco-ex from # the existing config content? _, path = config.load( conf_name='brokers', ) raise ValueError( 'No alias in account section for account!\n' f'Please add an acnt alias entry to your {path}\n' 'For example,\n\n' '[ib.accounts]\n' 'margin = {accnt_number!r}\n' '^^^^^^ <- you need this part!\n\n' 'This ensures `piker` will not leak private acnt info ' 'to console output by default!\n' ) # surjection of account names to operating clients. if acnt_alias not in accounts_found: accounts_found[acnt_alias] = client # client._acnt_names.add(acnt_alias) client._acnt_names.append(acnt_alias) if accounts_found: log.info( f'Loaded accounts for api client\n\n' f'{pformat(accounts_found)}\n' ) # XXX: why aren't we just updating this directy above # instead of using the intermediary `accounts_found`? _accounts2clients.update(accounts_found) # if we have no clients after the scan loop then error out. if not _client_cache: raise ConnectionError( 'No ib APIs could be found scanning @:\n' f'{pformat(combos)}\n' 'Check your `brokers.toml` and/or network' ) from _err try: yield _accounts2clients finally: # TODO: for re-scans we'll want to not teardown clients which # are up and stable right? if disconnect_on_exit: for acct, client in _accounts2clients.items(): log.info(f'Disconnecting {acct}@{client}') client.ib.disconnect() _client_cache.pop((host, port), None) async def load_clients_for_trio( from_trio: asyncio.Queue, to_trio: trio.abc.SendChannel, ) -> None: ''' Pure async mngr proxy to ``load_aio_clients()``. This is a bootstrap entrypoing to call from a ``tractor.to_asyncio.open_channel_from()``. ''' async with load_aio_clients( disconnect_on_exit=False, ) as accts2clients: to_trio.send_nowait(accts2clients) # TODO: maybe a sync event to wait on instead? await asyncio.sleep(float('inf')) @acm async def open_client_proxies() -> tuple[ dict[str, MethodProxy], dict[str, Client], ]: async with ( tractor.trionics.maybe_open_context( acm_func=tractor.to_asyncio.open_channel_from, kwargs={'target': load_clients_for_trio}, # lock around current actor task access # TODO: maybe this should be the default in tractor? key=tractor.current_actor().uid, ) as (cache_hit, (clients, _)), AsyncExitStack() as stack ): if cache_hit: log.info(f'Re-using cached clients: {clients}') proxies = {} for acct_name, client in clients.items(): proxy = await stack.enter_async_context( open_client_proxy(client), ) proxies[acct_name] = proxy yield proxies, clients def get_preferred_data_client( clients: dict[str, Client], ) -> tuple[str, Client]: ''' Load and return the (first found) `Client` instance that is preferred and should be used for data by iterating, in priority order, the ``ib.prefer_data_account: list[str]`` account names in the users ``brokers.toml`` file. ''' conf = get_config() data_accounts: list[str] = conf['prefer_data_account'] for name in data_accounts: client = clients.get(f'ib.{name}') if client: return name, client else: raise ValueError( 'No preferred data client could be found:\n' f'{data_accounts}' ) class MethodProxy: def __init__( self, chan: to_asyncio.LinkedTaskChannel, event_table: dict[str, trio.Event], asyncio_ns: SimpleNamespace, ) -> None: self.chan = chan self.event_table = event_table self._aio_ns = asyncio_ns async def _run_method( self, *, meth: str = None, **kwargs ) -> Any: ''' Make a ``Client`` method call by requesting through the ``tractor.to_asyncio`` layer. ''' chan = self.chan await chan.send((meth, kwargs)) while not chan.closed(): # send through method + ``kwargs: dict`` as pair msg = await chan.receive() # TODO: implement reconnect functionality like # in our `.data._web_bs.NoBsWs` # try: # msg = await chan.receive() # except ConnectionError: # self.reset() # print(f'NEXT MSG: {msg}') # TODO: py3.10 ``match:`` syntax B) if 'result' in msg: res = msg.get('result') return res elif 'exception' in msg: err = msg.get('exception') raise err elif 'error' in msg: etype, emsg = msg log.warning(f'IB error relay: {emsg}') continue else: log.warning(f'UNKNOWN IB MSG: {msg}') def status_event( self, pattern: str, ) -> dict[str, Any] | trio.Event: ev = self.event_table.get(pattern) if not ev or ev.is_set(): # print(f'inserting new data reset event item') ev = self.event_table[pattern] = trio.Event() return ev async def wait_for_data_reset(self) -> None: ''' Send hacker hot keys to ib program and wait for the event that declares the data feeds to be back up before unblocking. ''' ... async def open_aio_client_method_relay( from_trio: asyncio.Queue, to_trio: trio.abc.SendChannel, client: Client, event_consumers: dict[str, trio.Event], ) -> None: # sync with `open_client_proxy()` caller to_trio.send_nowait(client) # TODO: separate channel for error handling? client.inline_errors(to_trio) # relay all method requests to ``asyncio``-side client and deliver # back results while not to_trio._closed: msg: tuple[str, dict] | dict | None = await from_trio.get() match msg: case None: # termination sentinel log.info('asyncio `Client` method-proxy SHUTDOWN!') break case (meth_name, kwargs): meth_name, kwargs = msg meth = getattr(client, meth_name) try: resp = await meth(**kwargs) # echo the msg back to_trio.send_nowait({'result': resp}) except ( RequestError, # TODO: relay all errors to trio? # BaseException, ) as err: to_trio.send_nowait({'exception': err}) case {'error': content}: to_trio.send_nowait({'exception': content}) case _: raise ValueError(f'Unhandled msg {msg}') @acm async def open_client_proxy( client: Client, ) -> MethodProxy: event_table = {} async with ( to_asyncio.open_channel_from( open_aio_client_method_relay, client=client, event_consumers=event_table, ) as (first, chan), trio.open_nursery() as relay_n, ): assert isinstance(first, Client) proxy = MethodProxy( chan, event_table, asyncio_ns=first, ) # mock all remote methods on ib ``Client``. for name, method in inspect.getmembers( Client, predicate=inspect.isfunction, ): if '_' == name[0]: continue setattr(proxy, name, partial(proxy._run_method, meth=name)) async def relay_events(): async with chan.subscribe() as msg_stream: async for msg in msg_stream: if 'event' not in msg: continue # if 'event' in msg: # wake up any system event waiters. etype, status_msg = msg reason = status_msg['reason'] ev = proxy.event_table.pop(reason, None) if ev and ev.statistics().tasks_waiting: log.info(f'Relaying ib status message: {msg}') ev.set() continue relay_n.start_soon(relay_events) yield proxy # terminate asyncio side task await chan.send(None) @acm async def get_client( is_brokercheck: bool = False, **kwargs, ) -> Client: ''' Init the ``ib_insync`` client in another actor and return a method proxy to it. ''' # hack for `piker brokercheck ib`.. if is_brokercheck: yield Client return from .feed import open_data_client # TODO: the IPC via portal relay layer for when this current # actor isn't in aio mode. async with open_data_client() as proxy: yield proxy