diff --git a/piker/brokers/ib.py b/piker/brokers/ib.py index 1d01907d..51724e5a 100644 --- a/piker/brokers/ib.py +++ b/piker/brokers/ib.py @@ -355,28 +355,34 @@ class Client: # batch request all details results = await asyncio.gather(*futs) - # XXX: if there is more then one entry in the details list + # one set per future result details = {} for details_set in results: + # XXX: if there is more then one entry in the details list # then the contract is so called "ambiguous". for d in details_set: con = d.contract - unique_sym = f'{con.symbol}.{con.primaryExchange}' - as_dict = asdict(d) + key = '.'.join([ + con.symbol, + con.primaryExchange or con.exchange, + ]) + expiry = con.lastTradeDateOrContractMonth + if expiry: + key += f'.{expiry}' + # nested dataclass we probably don't need and that - # won't IPC serialize - as_dict.pop('secIdList') + # won't IPC serialize.. + d.secIdList = '' - details[unique_sym] = as_dict + details[key] = d return details async def search_stocks( self, pattern: str, - get_details: bool = False, upto: int = 3, # how many contracts to search "up to" ) -> dict[str, ContractDetails]: @@ -388,31 +394,13 @@ class Client: ''' descriptions = await self.ib.reqMatchingSymbolsAsync(pattern) - if descriptions is not None: - descrs = descriptions[:upto] - - if get_details: - deats = await self.con_deats([d.contract for d in descrs]) - return deats - - else: - results = {} - for d in descrs: - con = d.contract - # sometimes there's a weird extra suffix returned - # from search? - exch = con.primaryExchange.rsplit('.')[0] - unique_sym = f'{con.symbol}.{exch}' - expiry = con.lastTradeDateOrContractMonth - if expiry: - unique_sym += f'{expiry}' - - results[unique_sym] = {} - - return results - else: + if descriptions is None: return {} + # limit + descrs = descriptions[:upto] + return await self.con_deats([d.contract for d in descrs]) + async def search_symbols( self, pattern: str, @@ -427,36 +415,30 @@ class Client: results = await self.search_stocks( pattern, upto=upto, - get_details=True, ) - for key, contracts in results.copy().items(): - tract = contracts['contract'] - sym = tract['symbol'] + for key, deats in results.copy().items(): + + tract = deats.contract + sym = tract.symbol + sectype = tract.secType - sectype = tract['secType'] if sectype == 'IND': results[f'{sym}.IND'] = tract results.pop(key) - exch = tract['exchange'] + exch = tract.exchange if exch in _futes_venues: # try get all possible contracts for symbol as per, # https://interactivebrokers.github.io/tws-api/basic_contracts.html#fut - con = Contract( - 'FUT+CONTFUT', + con = ibis.Future( symbol=sym, exchange=exch, ) try: - possibles = await self.ib.qualifyContractsAsync(con) - for i, condict in enumerate(sorted( - map(asdict, possibles), - # sort by expiry - key=lambda con: con['lastTradeDateOrContractMonth'], - )): - expiry = condict['lastTradeDateOrContractMonth'] - results[f'{sym}.{exch}.{expiry}'] = condict + all_deats = await self.con_deats([con]) + results |= all_deats + except RequestError as err: log.warning(err.message) @@ -600,6 +582,12 @@ class Client: raise ValueError(f"No contract could be found {con}") self._contracts[pattern] = contract + + # add an aditional entry with expiry suffix if available + conexp = contract.lastTradeDateOrContractMonth + if conexp: + self._contracts[pattern + f'.{conexp}'] = contract + return contract async def get_head_time( @@ -1640,7 +1628,7 @@ async def backfill_bars( out, fails = await get_bars(proxy, fqsn, end_dt=first_dt) - if out == None: + if out is None: # could be trying to retreive bars over weekend # TODO: add logic here to handle tradable hours and # only grab valid bars in the range