Compare commits


No commits in common. "multi_symbol_input" and "310_plus" have entirely different histories.

90 changed files with 7878 additions and 17447 deletions

View File

@ -3,9 +3,10 @@ name: CI
on: on:
# Triggers the workflow on push or pull request events but only for the master branch # Triggers the workflow on push or pull request events but only for the master branch
pull_request:
push: push:
branches: [ master ] branches: [ master ]
pull_request:
branches: [ master ]
# Allows you to run this workflow manually from the Actions tab # Allows you to run this workflow manually from the Actions tab
workflow_dispatch: workflow_dispatch:
@ -13,27 +14,6 @@ on:
jobs: jobs:
# test that we can generate a software distribution and install it
# thus avoiding missing file issues after packaging.
sdist-linux:
name: 'sdist'
runs-on: ubuntu-latest
steps:
- name: Checkout
uses: actions/checkout@v3
- name: Setup python
uses: actions/setup-python@v2
with:
python-version: '3.10'
- name: Build sdist
run: python setup.py sdist --formats=zip
- name: Install sdist from .zips
run: python -m pip install dist/*.zip
testing: testing:
name: 'install + test-suite' name: 'install + test-suite'
runs-on: ubuntu-latest runs-on: ubuntu-latest

View File

@ -50,8 +50,3 @@ prefer_data_account = [
paper = "XX0000000" paper = "XX0000000"
margin = "X0000000" margin = "X0000000"
ira = "X0000000" ira = "X0000000"
[deribit]
key_id = 'XXXXXXXX'
key_secret = 'Xx_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx'

View File

@ -3,12 +3,11 @@
version: "3.5" version: "3.5"
services: services:
ib_gw_paper: ib-gateway:
# other image tags available: # other image tags available:
# https://github.com/waytrade/ib-gateway-docker#supported-tags # https://github.com/waytrade/ib-gateway-docker#supported-tags
# image: waytrade/ib-gateway:981.3j image: waytrade/ib-gateway:981.3j
image: waytrade/ib-gateway:1012.2i restart: always
restart: 'no' # restart on boot whenever there's a crash or the user clicks
network_mode: 'host' network_mode: 'host'
volumes: volumes:
@ -40,12 +39,14 @@ services:
# this compose file which looks something like: # this compose file which looks something like:
# TWS_USERID='myuser' # TWS_USERID='myuser'
# TWS_PASSWORD='guest' # TWS_PASSWORD='guest'
# TRADING_MODE=paper (or live)
# VNC_SERVER_PASSWORD='diggity'
environment: environment:
TWS_USERID: ${TWS_USERID} TWS_USERID: ${TWS_USERID}
TWS_PASSWORD: ${TWS_PASSWORD} TWS_PASSWORD: ${TWS_PASSWORD}
TRADING_MODE: 'paper' TRADING_MODE: ${TRADING_MODE:-paper}
VNC_SERVER_PASSWORD: 'doggy' VNC_SERVER_PASSWORD: ${VNC_SERVER_PASSWORD:-}
VNC_SERVER_PORT: '3003'
# ports: # ports:
# - target: 4002 # - target: 4002
@ -61,40 +62,3 @@ services:
# - "127.0.0.1:4001:4001" # - "127.0.0.1:4001:4001"
# - "127.0.0.1:4002:4002" # - "127.0.0.1:4002:4002"
# - "127.0.0.1:5900:5900" # - "127.0.0.1:5900:5900"
# ib_gw_live:
# image: waytrade/ib-gateway:1012.2i
# restart: no
# network_mode: 'host'
# volumes:
# - type: bind
# source: ./jts_live.ini
# target: /root/jts/jts.ini
# # don't let ibc clobber this file for
# # the main reason of not having a stupid
# # timezone set..
# read_only: true
# # force our own ibc config
# - type: bind
# source: ./ibc.ini
# target: /root/ibc/config.ini
# # force our noop script - socat isn't needed in host mode.
# - type: bind
# source: ./fork_ports_delayed.sh
# target: /root/scripts/fork_ports_delayed.sh
# # force our noop script - socat isn't needed in host mode.
# - type: bind
# source: ./run_x11_vnc.sh
# target: /root/scripts/run_x11_vnc.sh
# read_only: true
# # NOTE: to fill these out, define an `.env` file in the same dir as
# # this compose file which looks something like:
# environment:
# TRADING_MODE: 'live'
# VNC_SERVER_PASSWORD: 'doggy'
# VNC_SERVER_PORT: '3004'

View File

@ -188,7 +188,7 @@ AcceptNonBrokerageAccountWarning=yes
# #
# The default value is 60. # The default value is 60.
LoginDialogDisplayTimeout=20 LoginDialogDisplayTimeout = 60
@ -292,7 +292,7 @@ ExistingSessionDetectedAction=primary
# be set dynamically at run-time: most users will never need it, # be set dynamically at run-time: most users will never need it,
# so don't use it unless you know you need it. # so don't use it unless you know you need it.
; OverrideTwsApiPort=4002 OverrideTwsApiPort=4002
# Read-only Login # Read-only Login

View File

@ -1,33 +0,0 @@
[IBGateway]
ApiOnly=true
LocalServerPort=4001
# NOTE: must be set if using IBC's "reject" mode
TrustedIPs=127.0.0.1
; RemoteHostOrderRouting=ndc1.ibllc.com
; WriteDebug=true
; RemotePortOrderRouting=4001
; useRemoteSettings=false
; tradingMode=p
; Steps=8
; colorPalletName=dark
# window geo, this may be useful for sending `xdotool` commands?
; MainWindow.Width=1986
; screenHeight=3960
[Logon]
Locale=en
# most markets are oriented around this zone
# so might as well hard code it.
TimeZone=America/New_York
UseSSL=true
displayedproxymsg=1
os_titlebar=true
s3store=true
useRemoteSettings=false
[Communication]
ctciAutoEncrypt=true
Region=usr
; Peer=cdc1.ibllc.com:4001

View File

@ -1,35 +1,16 @@
#!/bin/sh #!/bin/sh
# start vnc server and listen for connections
# on port specced in `$VNC_SERVER_PORT`
# start VNC server
x11vnc \ x11vnc \
-listen 127.0.0.1 \ -ncache_cr \
-allow 127.0.0.1 \ -listen localhost \
-rfbport "${VNC_SERVER_PORT}" \
-display :1 \ -display :1 \
-forever \ -forever \
-shared \ -shared \
-logappend /var/log/x11vnc.log \
-bg \ -bg \
-nowf \
-noxdamage \
-noxfixes \
-no6 \
-noipv6 \ -noipv6 \
-autoport 3003 \
# can't use this because of ``asyncvnc`` issue:
# -nowcr \
# TODO: can't use this because of ``asyncvnc`` issue:
# https://github.com/barneygale/asyncvnc/issues/1 # https://github.com/barneygale/asyncvnc/issues/1
# -passwd 'ibcansmbz' # -passwd 'ibcansmbz'
# XXX: optional graphics caching flags that seem to rekt the overlay
# of the 2 gw windows? When running a single gateway
# this seems to maybe optimize some memory usage?
# -ncache_cr \
# -ncache \
# NOTE: this will prevent logs from going to the console.
# -logappend /var/log/x11vnc.log \
# where to start allocating ports
# -autoport "${VNC_SERVER_PORT}" \

View File

@ -18,10 +18,3 @@
piker: trading gear for hackers. piker: trading gear for hackers.
""" """
from ._daemon import open_piker_runtime
from .data.feed import open_feed
__all__ = [
'open_piker_runtime',
'open_feed',
]

View File

@ -18,27 +18,16 @@
Structured, daemon tree service management. Structured, daemon tree service management.
""" """
from __future__ import annotations from typing import Optional, Union, Callable, Any
import os from contextlib import asynccontextmanager as acm
from typing import (
Optional,
Callable,
Any,
ClassVar,
)
from contextlib import (
asynccontextmanager as acm,
)
from collections import defaultdict from collections import defaultdict
import tractor from pydantic import BaseModel
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
import tractor
from .log import ( from .log import get_logger, get_console_log
get_logger,
get_console_log,
)
from .brokers import get_brokermod from .brokers import get_brokermod
@ -46,118 +35,28 @@ log = get_logger(__name__)
_root_dname = 'pikerd' _root_dname = 'pikerd'
_default_registry_host: str = '127.0.0.1' _registry_addr = ('127.0.0.1', 6116)
_default_registry_port: int = 6116 _tractor_kwargs: dict[str, Any] = {
_default_reg_addr: tuple[str, int] = ( # use a different registry addr then tractor's default
_default_registry_host, 'arbiter_addr': _registry_addr
_default_registry_port, }
)
# NOTE: this value is set as an actor-global once the first capable
# endpoint spawns a `pikerd` service tree.
_registry: Registry | None = None
class Registry:
addr: None | tuple[str, int] = None
# TODO: table of uids to sockaddrs
peers: dict[
tuple[str, str],
tuple[str, int],
] = {}
_tractor_kwargs: dict[str, Any] = {}
@acm
async def open_registry(
addr: None | tuple[str, int] = None,
ensure_exists: bool = True,
) -> tuple[str, int]:
global _tractor_kwargs
actor = tractor.current_actor()
uid = actor.uid
if (
Registry.addr is not None
and addr
):
raise RuntimeError(
f'`{uid}` registry addr already bound @ {Registry.addr}'
)
was_set: bool = False
if (
not tractor.is_root_process()
and Registry.addr is None
):
Registry.addr = actor._arb_addr
if (
ensure_exists
and Registry.addr is None
):
raise RuntimeError(
f"`{uid}` registry should already exist bug doesn't?"
)
if (
Registry.addr is None
):
was_set = True
Registry.addr = addr or _default_reg_addr
_tractor_kwargs['arbiter_addr'] = Registry.addr
try:
yield Registry.addr
finally:
# XXX: always clear the global addr if we set it so that the
# next (set of) calls will apply whatever new one is passed
# in.
if was_set:
Registry.addr = None
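# NOTE: a minimal usage sketch of the `Registry`/`open_registry` pair
# above; `piker._daemon` as the import path and the `tractor` root
# actor providing the runtime are assumptions:
import trio
import tractor

from piker._daemon import open_registry


async def registry_demo() -> None:
    async with (
        tractor.open_root_actor(name='registry_demo'),
        open_registry(
            ('127.0.0.1', 6116),  # matches `_default_reg_addr`
            ensure_exists=False,  # first caller binds the addr
        ) as reg_addr,
    ):
        # nested calls simply discover the already-bound addr
        async with open_registry() as addr:
            assert addr is reg_addr


trio.run(registry_demo)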
def get_tractor_runtime_kwargs() -> dict[str, Any]:
'''
Deliver ``tractor`` related runtime variables in a `dict`.
'''
return _tractor_kwargs
_root_modules = [ _root_modules = [
__name__, __name__,
'piker.clearing._ems', 'piker.clearing._ems',
'piker.clearing._client', 'piker.clearing._client',
'piker.data._sampling',
] ]
# TODO: factor this into a ``tractor.highlevel`` extension class Services(BaseModel):
# pack for the library.
class Services:
actor_n: tractor._supervise.ActorNursery actor_n: tractor._supervise.ActorNursery
service_n: trio.Nursery service_n: trio.Nursery
debug_mode: bool # tractor sub-actor debug mode flag debug_mode: bool # tractor sub-actor debug mode flag
service_tasks: dict[ service_tasks: dict[str, tuple[trio.CancelScope, tractor.Portal]] = {}
str,
tuple[ class Config:
trio.CancelScope, arbitrary_types_allowed = True
tractor.Portal,
trio.Event,
]
] = {}
locks = defaultdict(trio.Lock)
@classmethod
async def start_service_task( async def start_service_task(
self, self,
name: str, name: str,
@ -176,12 +75,7 @@ class Services:
''' '''
async def open_context_in_task( async def open_context_in_task(
task_status: TaskStatus[ task_status: TaskStatus[
tuple[ trio.CancelScope] = trio.TASK_STATUS_IGNORED,
trio.CancelScope,
trio.Event,
Any,
]
] = trio.TASK_STATUS_IGNORED,
) -> Any: ) -> Any:
@ -193,173 +87,143 @@ class Services:
) as (ctx, first): ) as (ctx, first):
# unblock once the remote context has started # unblock once the remote context has started
complete = trio.Event() task_status.started((cs, first))
task_status.started((cs, complete, first))
log.info( log.info(
f'`pikerd` service {name} started with value {first}' f'`pikerd` service {name} started with value {first}'
) )
try: try:
# wait on any context's return value # wait on any context's return value
# and any final portal result from the
# sub-actor.
ctx_res = await ctx.result() ctx_res = await ctx.result()
except tractor.ContextCancelled:
# NOTE: blocks indefinitely until cancelled return await self.cancel_service(name)
# either by error from the target context else:
# function or by being cancelled here by the # wait on any error from the sub-actor
# surrounding cancel scope. # NOTE: this will block indefinitely until
# cancelled either by error from the target
# context function or by being cancelled here by
# the surrounding cancel scope
return (await portal.result(), ctx_res) return (await portal.result(), ctx_res)
finally: cs, first = await self.service_n.start(open_context_in_task)
await portal.cancel_actor()
complete.set()
self.service_tasks.pop(name)
cs, complete, first = await self.service_n.start(open_context_in_task)
# store the cancel scope and portal for later cancellation or # store the cancel scope and portal for later cancellation or
# restart if needed. # restart if needed.
self.service_tasks[name] = (cs, portal, complete) self.service_tasks[name] = (cs, portal)
return cs, first return cs, first
@classmethod # TODO: per service cancellation by scope, we aren't using this
# anywhere right?
async def cancel_service( async def cancel_service(
self, self,
name: str, name: str,
) -> Any: ) -> Any:
log.info(f'Cancelling `pikerd` service {name}')
cs, portal = self.service_tasks[name]
# XXX: not entirely sure why this is required,
# and should probably be better fine tuned in
# ``tractor``?
cs.cancel()
return await portal.cancel_actor()
_services: Optional[Services] = None
@acm
async def open_pikerd(
start_method: str = 'trio',
loglevel: Optional[str] = None,
# XXX: you should pretty much never want debug mode
# for data daemons when running in production.
debug_mode: bool = False,
) -> Optional[tractor._portal.Portal]:
''' '''
Cancel the service task and actor for the given ``name``. Start a root piker daemon whose lifetime extends indefinitely
until cancelled.
A root actor nursery is created which can be used to create and keep
alive underlying services (see below).
''' '''
log.info(f'Cancelling `pikerd` service {name}') global _services
cs, portal, complete = self.service_tasks[name] assert _services is None
cs.cancel()
await complete.wait() # XXX: this may open a root actor as well
assert name not in self.service_tasks, \ async with (
f'Service task for {name} not terminated?'
# passed through to ``open_root_actor``
arbiter_addr=_registry_addr,
name=_root_dname,
loglevel=loglevel,
debug_mode=debug_mode,
start_method=start_method,
# TODO: eventually we should be able to avoid
# having the root have more than permissions to
# spawn other specialized daemons I think?
enable_modules=_root_modules,
) as _,
tractor.open_nursery() as actor_nursery,
):
async with trio.open_nursery() as service_nursery:
# # setup service mngr singleton instance
# async with AsyncExitStack() as stack:
# assign globally for future daemon/task creation
_services = Services(
actor_n=actor_nursery,
service_n=service_nursery,
debug_mode=debug_mode,
)
yield _services
@acm @acm
async def open_piker_runtime( async def open_piker_runtime(
name: str, name: str,
enable_modules: list[str] = [], enable_modules: list[str] = [],
start_method: str = 'trio',
loglevel: Optional[str] = None, loglevel: Optional[str] = None,
# XXX NOTE XXX: you should pretty much never want debug mode # XXX: you should pretty much never want debug mode
# for data daemons when running in production. # for data daemons when running in production.
debug_mode: bool = False, debug_mode: bool = False,
registry_addr: None | tuple[str, int] = None, ) -> Optional[tractor._portal.Portal]:
# TODO: once we have `rsyscall` support we will read a config
# and spawn the service tree distributed per that.
start_method: str = 'trio',
tractor_kwargs: dict = {},
) -> tuple[
tractor.Actor,
tuple[str, int],
]:
''' '''
Start a piker actor whose runtime will automatically sync with Start a piker actor whose runtime will automatically
existing piker actors on the local link based on configuration. sync with existing piker actors on the local network
based on configuration.
Can be called from a subactor or any program that needs to start
a root actor.
''' '''
try: global _services
# check for existing runtime assert _services is None
actor = tractor.current_actor().uid
except tractor._exceptions.NoRuntime:
registry_addr = registry_addr or _default_reg_addr
# XXX: this may open a root actor as well
async with ( async with (
tractor.open_root_actor( tractor.open_root_actor(
# passed through to ``open_root_actor`` # passed through to ``open_root_actor``
arbiter_addr=registry_addr, arbiter_addr=_registry_addr,
name=name, name=name,
loglevel=loglevel, loglevel=loglevel,
debug_mode=debug_mode, debug_mode=debug_mode,
start_method=start_method, start_method=start_method,
# TODO: eventually we should be able to avoid
# having the root have more than permissions to # having the root have more than permissions to
# spawn other specialized daemons I think?
enable_modules=enable_modules,
**tractor_kwargs,
) as _,
open_registry(registry_addr, ensure_exists=False) as addr,
):
yield (
tractor.current_actor(),
addr,
)
else:
async with open_registry(registry_addr) as addr:
yield (
actor,
addr,
)
@acm
async def open_pikerd(
loglevel: str | None = None,
# XXX: you should pretty much never want debug mode
# for data daemons when running in production.
debug_mode: bool = False,
registry_addr: None | tuple[str, int] = None,
) -> Services:
'''
Start a root piker daemon whose lifetime extends indefinitely until
cancelled.
A root actor nursery is created which can be used to create and keep
alive underlying services (see below).
'''
async with (
open_piker_runtime(
name=_root_dname,
# TODO: eventually we should be able to avoid # TODO: eventually we should be able to avoid
# having the root have more than permissions to # having the root have more than permissions to
# spawn other specialized daemons I think? # spawn other specialized daemons I think?
enable_modules=_root_modules, enable_modules=_root_modules,
) as _,
loglevel=loglevel,
debug_mode=debug_mode,
registry_addr=registry_addr,
) as (root_actor, reg_addr),
tractor.open_nursery() as actor_nursery,
trio.open_nursery() as service_nursery,
): ):
assert root_actor.accept_addr == reg_addr yield tractor.current_actor()
# assign globally for future daemon/task creation
Services.actor_n = actor_nursery
Services.service_n = service_nursery
Services.debug_mode = debug_mode
try:
yield Services
finally:
# TODO: is this more clever/efficient?
# if 'samplerd' in Services.service_tasks:
# await Services.cancel_service('samplerd')
service_nursery.cancel_scope.cancel()
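# NOTE: with this refactor `open_pikerd()` yields the `Services` type
# itself (all class-level state); a hedged driver sketch, assuming
# `piker._daemon` as the import path:
import trio

from piker._daemon import open_pikerd


async def pikerd_demo() -> None:
    async with open_pikerd(loglevel='info') as services:
        # tasks started via `Services.start_service_task()` land in
        # the class-level table:
        print(services.service_tasks)  # -> {}


trio.run(pikerd_demo)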
@acm @acm
@ -368,89 +232,61 @@ async def maybe_open_runtime(
**kwargs, **kwargs,
) -> None: ) -> None:
''' """
Start the ``tractor`` runtime (a root actor) if none exists. Start the ``tractor`` runtime (a root actor) if none exists.
''' """
name = kwargs.pop('name') settings = _tractor_kwargs
settings.update(kwargs)
if not tractor.current_actor(err_on_no_runtime=False): if not tractor.current_actor(err_on_no_runtime=False):
async with open_piker_runtime( async with tractor.open_root_actor(
name,
loglevel=loglevel, loglevel=loglevel,
**kwargs, **settings,
) as (_, addr): ):
yield addr, yield
else: else:
async with open_registry() as addr: yield
yield addr
@acm @acm
async def maybe_open_pikerd( async def maybe_open_pikerd(
loglevel: Optional[str] = None, loglevel: Optional[str] = None,
registry_addr: None | tuple = None,
**kwargs, **kwargs,
) -> tractor._portal.Portal | ClassVar[Services]: ) -> Union[tractor._portal.Portal, Services]:
''' """If no ``pikerd`` daemon-root-actor can be found start it and
If no ``pikerd`` daemon-root-actor can be found start it and
yield up (we should probably figure out returning a portal to self yield up (we should probably figure out returning a portal to self
though). though).
''' """
if loglevel: if loglevel:
get_console_log(loglevel) get_console_log(loglevel)
# subtle, we must have the runtime up here or portal lookup will fail # subtle, we must have the runtime up here or portal lookup will fail
query_name = kwargs.pop('name', f'piker_query_{os.getpid()}') async with maybe_open_runtime(loglevel, **kwargs):
# TODO: if we need to make the query part faster we could not init async with tractor.find_actor(_root_dname) as portal:
# an actor runtime and instead just hit the socket? # assert portal is not None
# from tractor._ipc import _connect_chan, Channel if portal is not None:
# async with _connect_chan(host, port) as chan:
# async with open_portal(chan) as arb_portal:
# yield arb_portal
async with (
open_piker_runtime(
name=query_name,
registry_addr=registry_addr,
loglevel=loglevel,
**kwargs,
) as _,
tractor.find_actor(
_root_dname,
arbiter_sockaddr=registry_addr,
) as portal
):
# connect to any existing daemon presuming
# its registry socket was selected.
if (
portal is not None
):
yield portal yield portal
return return
# presume pikerd role since no daemon could be found at # presume pikerd role since no daemon could be found at
# configured address # configured address
async with open_pikerd( async with open_pikerd(
loglevel=loglevel, loglevel=loglevel,
debug_mode=kwargs.get('debug_mode', False), debug_mode=kwargs.get('debug_mode', False),
registry_addr=registry_addr,
) as service_manager: ) as _:
# in the case where we're starting up the # in the case where we're starting up the
# tractor-piker runtime stack in **this** process # tractor-piker runtime stack in **this** process
# we return no portal to self. # we return no portal to self.
assert service_manager yield None
yield service_manager
# `brokerd` enabled modules # brokerd enabled modules
# NOTE: keeping this list as small as possible is part of our caps-sec
# model and should be treated with utmost care!
_data_mods = [ _data_mods = [
'piker.brokers.core', 'piker.brokers.core',
'piker.brokers.data', 'piker.brokers.data',
@ -460,17 +296,20 @@ _data_mods = [
] ]
class Brokerd:
locks = defaultdict(trio.Lock)
@acm @acm
async def find_service( async def find_service(
service_name: str, service_name: str,
) -> tractor.Portal | None: ) -> Optional[tractor.Portal]:
async with open_registry() as reg_addr:
log.info(f'Scanning for service `{service_name}`') log.info(f'Scanning for service `{service_name}`')
# attach to existing daemon by name if possible # attach to existing daemon by name if possible
async with tractor.find_actor( async with tractor.find_actor(
service_name, service_name,
arbiter_sockaddr=reg_addr, arbiter_sockaddr=_registry_addr,
) as maybe_portal: ) as maybe_portal:
yield maybe_portal yield maybe_portal
@ -478,15 +317,14 @@ async def find_service(
async def check_for_service( async def check_for_service(
service_name: str, service_name: str,
) -> None | tuple[str, int]: ) -> bool:
''' '''
Service daemon "liveness" predicate. Service daemon "liveness" predicate.
''' '''
async with open_registry(ensure_exists=False) as reg_addr:
async with tractor.query_actor( async with tractor.query_actor(
service_name, service_name,
arbiter_sockaddr=reg_addr, arbiter_sockaddr=_registry_addr,
) as sockaddr: ) as sockaddr:
return sockaddr return sockaddr
@ -498,8 +336,6 @@ async def maybe_spawn_daemon(
service_task_target: Callable, service_task_target: Callable,
spawn_args: dict[str, Any], spawn_args: dict[str, Any],
loglevel: Optional[str] = None, loglevel: Optional[str] = None,
singleton: bool = False,
**kwargs, **kwargs,
) -> tractor.Portal: ) -> tractor.Portal:
@ -520,7 +356,7 @@ async def maybe_spawn_daemon(
# serialize access to this section to avoid # serialize access to this section to avoid
# 2 or more tasks racing to create a daemon # 2 or more tasks racing to create a daemon
lock = Services.locks[service_name] lock = Brokerd.locks[service_name]
await lock.acquire() await lock.acquire()
async with find_service(service_name) as portal: async with find_service(service_name) as portal:
@ -531,9 +367,6 @@ async def maybe_spawn_daemon(
log.warning(f"Couldn't find any existing {service_name}") log.warning(f"Couldn't find any existing {service_name}")
# TODO: really shouldn't the actor spawning be part of the service
# starting method `Services.start_service()` ?
# ask root ``pikerd`` daemon to spawn the daemon we need if # ask root ``pikerd`` daemon to spawn the daemon we need if
# pikerd is not live we now become the root of the # pikerd is not live we now become the root of the
# process tree # process tree
@ -544,6 +377,7 @@ async def maybe_spawn_daemon(
) as pikerd_portal: ) as pikerd_portal:
if pikerd_portal is None:
# we are the root and thus are `pikerd` # we are the root and thus are `pikerd`
# so spawn the target service directly by calling # so spawn the target service directly by calling
# the provided target routine. # the provided target routine.
@ -551,9 +385,7 @@ async def maybe_spawn_daemon(
# do the right things to setup both a sub-actor **and** call # do the right things to setup both a sub-actor **and** call
# the ``_Services`` api from above to start the top level # the ``_Services`` api from above to start the top level
# service task for that actor. # service task for that actor.
started: bool await service_task_target(**spawn_args)
if pikerd_portal is None:
started = await service_task_target(**spawn_args)
else: else:
# tell the remote `pikerd` to start the target, # tell the remote `pikerd` to start the target,
@ -562,14 +394,11 @@ async def maybe_spawn_daemon(
# non-blocking and the target task will persist running # non-blocking and the target task will persist running
# on `pikerd` after the client requesting its start # on `pikerd` after the client requesting its start
# disconnects. # disconnects.
started = await pikerd_portal.run( await pikerd_portal.run(
service_task_target, service_task_target,
**spawn_args, **spawn_args,
) )
if started:
log.info(f'Service {service_name} started!')
async with tractor.wait_for_actor(service_name) as portal: async with tractor.wait_for_actor(service_name) as portal:
lock.release() lock.release()
yield portal yield portal
@ -592,6 +421,9 @@ async def spawn_brokerd(
extra_tractor_kwargs = getattr(brokermod, '_spawn_kwargs', {}) extra_tractor_kwargs = getattr(brokermod, '_spawn_kwargs', {})
tractor_kwargs.update(extra_tractor_kwargs) tractor_kwargs.update(extra_tractor_kwargs)
global _services
assert _services
# ask `pikerd` to spawn a new sub-actor and manage it under its # ask `pikerd` to spawn a new sub-actor and manage it under its
# actor nursery # actor nursery
modpath = brokermod.__name__ modpath = brokermod.__name__
@ -604,18 +436,18 @@ async def spawn_brokerd(
subpath = f'{modpath}.{submodname}' subpath = f'{modpath}.{submodname}'
broker_enable.append(subpath) broker_enable.append(subpath)
portal = await Services.actor_n.start_actor( portal = await _services.actor_n.start_actor(
dname, dname,
enable_modules=_data_mods + broker_enable, enable_modules=_data_mods + broker_enable,
loglevel=loglevel, loglevel=loglevel,
debug_mode=Services.debug_mode, debug_mode=_services.debug_mode,
**tractor_kwargs **tractor_kwargs
) )
# non-blocking setup of brokerd service nursery # non-blocking setup of brokerd service nursery
from .data import _setup_persistent_brokerd from .data import _setup_persistent_brokerd
await Services.start_service_task( await _services.start_service_task(
dname, dname,
portal, portal,
_setup_persistent_brokerd, _setup_persistent_brokerd,
@ -661,21 +493,24 @@ async def spawn_emsd(
""" """
log.info('Spawning emsd') log.info('Spawning emsd')
portal = await Services.actor_n.start_actor( global _services
assert _services
portal = await _services.actor_n.start_actor(
'emsd', 'emsd',
enable_modules=[ enable_modules=[
'piker.clearing._ems', 'piker.clearing._ems',
'piker.clearing._client', 'piker.clearing._client',
], ],
loglevel=loglevel, loglevel=loglevel,
debug_mode=Services.debug_mode, # set by pikerd flag debug_mode=_services.debug_mode, # set by pikerd flag
**extra_tractor_kwargs **extra_tractor_kwargs
) )
# non-blocking setup of clearing service # non-blocking setup of clearing service
from .clearing._ems import _setup_persistent_emsd from .clearing._ems import _setup_persistent_emsd
await Services.start_service_task( await _services.start_service_task(
'emsd', 'emsd',
portal, portal,
_setup_persistent_emsd, _setup_persistent_emsd,
@ -702,3 +537,25 @@ async def maybe_open_emsd(
) as portal: ) as portal:
yield portal yield portal
# TODO: ideally we can start the tsdb "on demand" but it's
# probably going to require "rootless" docker, at least if we don't
# want to expect the user to start ``pikerd`` with root perms all the
# time.
# async def maybe_open_marketstored(
# loglevel: Optional[str] = None,
# **kwargs,
# ) -> tractor._portal.Portal: # noqa
# async with maybe_spawn_daemon(
# 'marketstored',
# service_task_target=spawn_emsd,
# spawn_args={'loglevel': loglevel},
# loglevel=loglevel,
# **kwargs,
# ) as portal:
# yield portal

View File

@ -18,10 +18,7 @@
Profiling wrappers for internal libs. Profiling wrappers for internal libs.
""" """
import os
import sys
import time import time
from time import perf_counter
from functools import wraps from functools import wraps
# NOTE: you can pass a flag to enable this: # NOTE: you can pass a flag to enable this:
@ -47,184 +44,3 @@ def timeit(fn):
return res return res
return wrapper return wrapper
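# NOTE: hedged usage sketch of the `timeit` wrapper above (it times
# the wrapped call and reports the elapsed duration on return);
# `piker._profile` is an assumed import path:
from piker._profile import timeit


@timeit
def crunch() -> int:
    return sum(range(1_000_000))


crunch()  # elapsed time is reported by the wrapper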
# Modified version of ``pyqtgraph.debug.Profiler`` that
# core seems hesitant to land in:
# https://github.com/pyqtgraph/pyqtgraph/pull/2281
class Profiler(object):
'''
Simple profiler allowing measurement of multiple time intervals.
By default, profilers are disabled. To enable profiling, set the
environment variable `PYQTGRAPHPROFILE` to a comma-separated list of
fully-qualified names of profiled functions.
Calling a profiler registers a message (defaulting to an increasing
counter) that contains the time elapsed since the last call. When the
profiler is about to be garbage-collected, the messages are passed to the
outer profiler if one is running, or printed to stdout otherwise.
If `delayed` is set to False, messages are immediately printed instead.
Example:
def function(...):
profiler = Profiler()
... do stuff ...
profiler('did stuff')
... do other stuff ...
profiler('did other stuff')
# profiler is garbage-collected and flushed at function end
If this function is a method of class C, setting `PYQTGRAPHPROFILE` to
"C.function" (without the module name) will enable this profiler.
For regular functions, use the qualified name of the function, stripping
only the initial "pyqtgraph." prefix from the module.
'''
_profilers = os.environ.get("PYQTGRAPHPROFILE", None)
_profilers = _profilers.split(",") if _profilers is not None else []
_depth = 0
# NOTE: without this defined at the class level
# you won't see appropriately "nested" sub-profiler
# instance calls.
_msgs = []
# set this flag to disable all or individual profilers at runtime
disable = False
class DisabledProfiler(object):
def __init__(self, *args, **kwds):
pass
def __call__(self, *args):
pass
def finish(self):
pass
def mark(self, msg=None):
pass
_disabledProfiler = DisabledProfiler()
def __new__(
cls,
msg=None,
disabled='env',
delayed=True,
ms_threshold: float = 0.0,
):
"""Optionally create a new profiler based on caller's qualname.
``ms_threshold`` can be set to value in ms for which, if the
total measured time of the lifetime of this profiler is **less
than** this value, then no profiling messages will be printed.
Setting ``delayed=False`` disables this feature since messages
are emitted immediately.
"""
if (
disabled is True
or (
disabled == 'env'
and len(cls._profilers) == 0
)
):
return cls._disabledProfiler
# determine the qualified name of the caller function
caller_frame = sys._getframe(1)
try:
caller_object_type = type(caller_frame.f_locals["self"])
except KeyError: # we are in a regular function
qualifier = caller_frame.f_globals["__name__"].split(".", 1)[-1]
else: # we are in a method
qualifier = caller_object_type.__name__
func_qualname = qualifier + "." + caller_frame.f_code.co_name
if disabled == 'env' and func_qualname not in cls._profilers:
# don't do anything
return cls._disabledProfiler
# create an actual profiling object
cls._depth += 1
obj = super(Profiler, cls).__new__(cls)
obj._name = msg or func_qualname
obj._delayed = delayed
obj._markCount = 0
obj._finished = False
obj._firstTime = obj._lastTime = perf_counter()
obj._mt = ms_threshold
obj._newMsg("> Entering " + obj._name)
return obj
def __call__(self, msg=None):
"""Register or print a new message with timing information.
"""
if self.disable:
return
if msg is None:
msg = str(self._markCount)
self._markCount += 1
newTime = perf_counter()
ms = (newTime - self._lastTime) * 1000
self._newMsg(" %s: %0.4f ms", msg, ms)
self._lastTime = newTime
def mark(self, msg=None):
self(msg)
def _newMsg(self, msg, *args):
msg = " " * (self._depth - 1) + msg
if self._delayed:
self._msgs.append((msg, args))
else:
print(msg % args)
def __del__(self):
self.finish()
def finish(self, msg=None):
"""Add a final message; flush the message list if no parent profiler.
"""
if self._finished or self.disable:
return
self._finished = True
if msg is not None:
self(msg)
tot_ms = (perf_counter() - self._firstTime) * 1000
self._newMsg(
"< Exiting %s, total time: %0.4f ms",
self._name,
tot_ms,
)
if tot_ms < self._mt:
# print(f'{tot_ms} < {self._mt}, clearing')
# NOTE: this list **must** be an instance var to avoid
# deleting common messages during GC I think?
self._msgs.clear()
# else:
# print(f'{tot_ms} > {self._mt}, not clearing')
# XXX: why is this needed?
# don't we **want to show** nested profiler messages?
if self._msgs: # and self._depth < 1:
# if self._msgs:
print("\n".join([m[0] % m[1] for m in self._msgs]))
# clear all entries
self._msgs.clear()
# type(self)._msgs = []
type(self)._depth -= 1
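# NOTE: runnable sketch per the `Profiler` docstring above: enable it
# via `PYQTGRAPHPROFILE` *before* the class body is executed on import
# (the `piker._profile` module path is an assumption):
import os

os.environ['PYQTGRAPHPROFILE'] = '__main__.crunch'

from piker._profile import Profiler


def crunch() -> int:
    profiler = Profiler()  # registered as '__main__.crunch'
    total = sum(range(1_000_000))
    profiler('summed first range')
    total += sum(range(2_000_000))
    profiler('summed second range')
    return total  # messages flush when the profiler is GC'd


crunch()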

View File

@ -26,21 +26,10 @@ asks.init('trio')
__brokers__ = [ __brokers__ = [
'binance', 'binance',
'questrade',
'robinhood',
'ib', 'ib',
'kraken', 'kraken',
# broken but used to work
# 'questrade',
# 'robinhood',
# TODO: we should get on these stat!
# alpaca
# wstrade
# iex
# deribit
# kucoin
# bitso
] ]

View File

@ -33,23 +33,15 @@ import asks
from fuzzywuzzy import process as fuzzy from fuzzywuzzy import process as fuzzy
import numpy as np import numpy as np
import tractor import tractor
from pydantic.dataclasses import dataclass
from pydantic import BaseModel
import wsproto import wsproto
from .._cacheables import open_cached_client from .._cacheables import open_cached_client
from ._util import ( from ._util import resproc, SymbolNotFound
resproc, from ..log import get_logger, get_console_log
SymbolNotFound, from ..data import ShmArray
DataUnavailable, from ..data._web_bs import open_autorecon_ws, NoBsWs
)
from ..log import (
get_logger,
get_console_log,
)
from ..data.types import Struct
from ..data._web_bs import (
open_autorecon_ws,
NoBsWs,
)
log = get_logger(__name__) log = get_logger(__name__)
@ -87,14 +79,12 @@ _show_wap_in_history = False
# https://binance-docs.github.io/apidocs/spot/en/#exchange-information # https://binance-docs.github.io/apidocs/spot/en/#exchange-information
class Pair(Struct, frozen=True): class Pair(BaseModel):
symbol: str symbol: str
status: str status: str
baseAsset: str baseAsset: str
baseAssetPrecision: int baseAssetPrecision: int
cancelReplaceAllowed: bool
allowTrailingStop: bool
quoteAsset: str quoteAsset: str
quotePrecision: int quotePrecision: int
quoteAssetPrecision: int quoteAssetPrecision: int
@ -110,21 +100,18 @@ class Pair(Struct, frozen=True):
isSpotTradingAllowed: bool isSpotTradingAllowed: bool
isMarginTradingAllowed: bool isMarginTradingAllowed: bool
defaultSelfTradePreventionMode: str
allowedSelfTradePreventionModes: list[str]
filters: list[dict[str, Union[str, int, float]]] filters: list[dict[str, Union[str, int, float]]]
permissions: list[str] permissions: list[str]
class OHLC(Struct): @dataclass
''' class OHLC:
Description of the flattened OHLC quote format. """Description of the flattened OHLC quote format.
For schema details see: For schema details see:
https://binance-docs.github.io/apidocs/spot/en/#kline-candlestick-streams https://binance-docs.github.io/apidocs/spot/en/#kline-candlestick-streams
''' """
time: int time: int
open: float open: float
@ -147,9 +134,7 @@ class OHLC(Struct):
# convert datetime obj timestamp to unixtime in milliseconds # convert datetime obj timestamp to unixtime in milliseconds
def binance_timestamp( def binance_timestamp(when):
when: datetime
) -> int:
return int((when.timestamp() * 1000) + (when.microsecond / 1000)) return int((when.timestamp() * 1000) + (when.microsecond / 1000))
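# NOTE: sanity check of the ms conversion above, using a whole-second
# input (microsecond == 0) so both column variants agree:
import pendulum

when = pendulum.datetime(2023, 1, 2, 3, 4, 5)  # tz defaults to UTC
assert binance_timestamp(when) == 1_672_628_645_000  # ms since epoch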
@ -188,7 +173,7 @@ class Client:
params = {} params = {}
if sym is not None: if sym is not None:
sym = sym.lower() sym = sym.upper()
params = {'symbol': sym} params = {'symbol': sym}
resp = await self._api( resp = await self._api(
@ -245,7 +230,7 @@ class Client:
) -> dict: ) -> dict:
if end_dt is None: if end_dt is None:
end_dt = pendulum.now('UTC').add(minutes=1) end_dt = pendulum.now('UTC')
if start_dt is None: if start_dt is None:
start_dt = end_dt.start_of( start_dt = end_dt.start_of(
@ -275,7 +260,6 @@ class Client:
for i, bar in enumerate(bars): for i, bar in enumerate(bars):
bar = OHLC(*bar) bar = OHLC(*bar)
bar.typecast()
row = [] row = []
for j, (name, ftype) in enumerate(_ohlc_dtype[1:]): for j, (name, ftype) in enumerate(_ohlc_dtype[1:]):
@ -303,7 +287,7 @@ async def get_client() -> Client:
# validation type # validation type
class AggTrade(Struct): class AggTrade(BaseModel):
e: str # Event type e: str # Event type
E: int # Event time E: int # Event time
s: str # Symbol s: str # Symbol
@ -357,9 +341,7 @@ async def stream_messages(ws: NoBsWs) -> AsyncGenerator[NoBsWs, dict]:
elif msg.get('e') == 'aggTrade': elif msg.get('e') == 'aggTrade':
# NOTE: this is purely for a definition, ``msgspec.Struct`` # validate
# does not runtime-validate until you decode/encode.
# see: https://jcristharif.com/msgspec/structs.html#type-validation
msg = AggTrade(**msg) msg = AggTrade(**msg)
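# NOTE: to make the note above concrete: `msgspec.Struct` construction
# skips validation, while decoding against the type enforces it; a toy
# struct (not the full `AggTrade` schema):
import msgspec


class Tick(msgspec.Struct):
    p: str  # price (binance sends strings)
    q: str  # quantity


bad = Tick(p=1.0, q=2.0)  # wrong types, yet no error raised
try:
    # ...but a decode round-trip *does* type-check each field:
    msgspec.json.decode(b'{"p": 1.0, "q": 2.0}', type=Tick)
except msgspec.ValidationError as err:
    print(err)  # e.g. Expected `str`, got `float` - at `$.p`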
# TODO: type out and require this quote format # TODO: type out and require this quote format
@ -370,8 +352,8 @@ async def stream_messages(ws: NoBsWs) -> AsyncGenerator[NoBsWs, dict]:
'brokerd_ts': time.time(), 'brokerd_ts': time.time(),
'ticks': [{ 'ticks': [{
'type': 'trade', 'type': 'trade',
'price': float(msg.p), 'price': msg.p,
'size': float(msg.q), 'size': msg.q,
'broker_ts': msg.T, 'broker_ts': msg.T,
}], }],
} }
@ -402,39 +384,41 @@ async def open_history_client(
async with open_cached_client('binance') as client: async with open_cached_client('binance') as client:
async def get_ohlc( async def get_ohlc(
timeframe: float, end_dt: Optional[datetime] = None,
end_dt: datetime | None = None, start_dt: Optional[datetime] = None,
start_dt: datetime | None = None,
) -> tuple[ ) -> tuple[
np.ndarray, np.ndarray,
datetime, # start datetime, # start
datetime, # end datetime, # end
]: ]:
if timeframe != 60:
raise DataUnavailable('Only 1m bars are supported')
array = await client.bars( array = await client.bars(
symbol, symbol,
start_dt=start_dt, start_dt=start_dt,
end_dt=end_dt, end_dt=end_dt,
) )
times = array['time'] start_dt = pendulum.from_timestamp(array[0]['time'])
if ( end_dt = pendulum.from_timestamp(array[-1]['time'])
end_dt is None
):
inow = round(time.time())
if (inow - times[-1]) > 60:
await tractor.breakpoint()
start_dt = pendulum.from_timestamp(times[0])
end_dt = pendulum.from_timestamp(times[-1])
return array, start_dt, end_dt return array, start_dt, end_dt
yield get_ohlc, {'erlangs': 3, 'rate': 3} yield get_ohlc, {'erlangs': 3, 'rate': 3}
async def backfill_bars(
sym: str,
shm: ShmArray, # type: ignore # noqa
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
) -> None:
"""Fill historical bars into shared mem / storage afap.
"""
with trio.CancelScope() as cs:
async with open_cached_client('binance') as client:
bars = await client.bars(symbol=sym)
shm.push(bars)
task_status.started(cs)
async def stream_quotes( async def stream_quotes(
send_chan: trio.abc.SendChannel, send_chan: trio.abc.SendChannel,
@ -464,20 +448,12 @@ async def stream_quotes(
d = cache[sym.upper()] d = cache[sym.upper()]
syminfo = Pair(**d) # validation syminfo = Pair(**d) # validation
si = sym_infos[sym] = syminfo.to_dict() si = sym_infos[sym] = syminfo.dict()
filters = {}
for entry in syminfo.filters:
ftype = entry['filterType']
filters[ftype] = entry
# XXX: after manually inspecting the response format we # XXX: after manually inspecting the response format we
# just directly pick out the info we need # just directly pick out the info we need
si['price_tick_size'] = float( si['price_tick_size'] = float(syminfo.filters[0]['tickSize'])
filters['PRICE_FILTER']['tickSize'] si['lot_tick_size'] = float(syminfo.filters[2]['stepSize'])
)
si['lot_tick_size'] = float(
filters['LOT_SIZE']['stepSize']
)
si['asset_type'] = 'crypto' si['asset_type'] = 'crypto'
symbol = symbols[0] symbol = symbols[0]
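# NOTE: keying by `filterType` is more robust than the old positional
# `filters[0]`/`filters[2]` lookups since binance doesn't guarantee
# filter order; entries trimmed from the `exchangeInfo` docs (values
# illustrative):
filters_resp = [
    {'filterType': 'PRICE_FILTER', 'tickSize': '0.01000000'},
    {'filterType': 'LOT_SIZE', 'stepSize': '0.00001000'},
]
filters = {
    entry['filterType']: entry
    for entry in filters_resp
}
assert float(filters['PRICE_FILTER']['tickSize']) == 0.01
assert float(filters['LOT_SIZE']['stepSize']) == 1e-05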
@ -519,7 +495,6 @@ async def stream_quotes(
subs.append("{sym}@bookTicker") subs.append("{sym}@bookTicker")
# unsub from all pairs on teardown # unsub from all pairs on teardown
if ws.connected():
await ws.send_msg({ await ws.send_msg({
"method": "UNSUBSCRIBE", "method": "UNSUBSCRIBE",
"params": subs, "params": subs,

View File

@ -39,148 +39,6 @@ _config_dir = click.get_app_dir('piker')
_watchlists_data_path = os.path.join(_config_dir, 'watchlists.json') _watchlists_data_path = os.path.join(_config_dir, 'watchlists.json')
OK = '\033[92m'
WARNING = '\033[93m'
FAIL = '\033[91m'
ENDC = '\033[0m'
def print_ok(s: str, **kwargs):
print(OK + s + ENDC, **kwargs)
def print_error(s: str, **kwargs):
print(FAIL + s + ENDC, **kwargs)
def get_method(client, meth_name: str):
print(f'checking client for method \'{meth_name}\'...', end='', flush=True)
method = getattr(client, meth_name, None)
assert method
print_ok('found!')
return method
async def run_method(client, meth_name: str, **kwargs):
method = get_method(client, meth_name)
print('running...', end='', flush=True)
result = await method(**kwargs)
print_ok(f'done! result: {type(result)}')
return result
async def run_test(broker_name: str):
brokermod = get_brokermod(broker_name)
total = 0
passed = 0
failed = 0
print(f'getting client...', end='', flush=True)
if not hasattr(brokermod, 'get_client'):
print_error('fail! no \'get_client\' context manager found.')
return
async with brokermod.get_client(is_brokercheck=True) as client:
print_ok(f'done! inside client context.')
# check for methods present on brokermod
method_list = [
'backfill_bars',
'get_client',
'trades_dialogue',
'open_history_client',
'open_symbol_search',
'stream_quotes',
]
for method in method_list:
print(
f'checking brokermod for method \'{method}\'...',
end='', flush=True)
if not hasattr(brokermod, method):
print_error(f'fail! method \'{method}\' not found.')
failed += 1
else:
print_ok('done!')
passed += 1
total += 1
# check for methods present on brokermod.Client and their
# results
# for private methods, only check that they're present
method_list = [
'get_balances',
'get_assets',
'get_trades',
'get_xfers',
'submit_limit',
'submit_cancel',
'search_symbols',
]
for method_name in method_list:
try:
get_method(client, method_name)
passed += 1
except AssertionError:
print_error(f'fail! method \'{method_name}\' not found.')
failed += 1
total += 1
# check for methods present on brokermod.Client and their
# results
syms = await run_method(client, 'symbol_info')
total += 1
if len(syms) == 0:
raise BaseException('Empty Symbol list?')
passed += 1
first_sym = tuple(syms.keys())[0]
method_list = [
('cache_symbols', {}),
('search_symbols', {'pattern': first_sym[:-1]}),
('bars', {'symbol': first_sym})
]
for method_name, method_kwargs in method_list:
try:
await run_method(client, method_name, **method_kwargs)
passed += 1
except AssertionError:
print_error(f'fail! method \'{method_name}\' not found.')
failed += 1
total += 1
print(f'total: {total}, passed: {passed}, failed: {failed}')
@cli.command()
@click.argument('broker', nargs=1, required=True)
@click.pass_obj
def brokercheck(config, broker):
'''
Test broker apis for completeness.
'''
async def bcheck_main():
async with maybe_spawn_brokerd(broker) as portal:
await portal.run(run_test, broker)
await portal.cancel_actor()
trio.run(run_test, broker)
@cli.command() @cli.command()
@click.option('--keys', '-k', multiple=True, @click.option('--keys', '-k', multiple=True,
help='Return results only for these keys') help='Return results only for these keys')
@ -335,8 +193,6 @@ def contracts(ctx, loglevel, broker, symbol, ids):
brokermod = get_brokermod(broker) brokermod = get_brokermod(broker)
get_console_log(loglevel) get_console_log(loglevel)
contracts = trio.run(partial(core.contracts, brokermod, symbol)) contracts = trio.run(partial(core.contracts, brokermod, symbol))
if not ids: if not ids:
# just print out expiry dates which can be used with # just print out expiry dates which can be used with

View File

@ -1,70 +0,0 @@
``deribit`` backend
-------------------
pretty good liquidity crypto derivatives venue; uses a custom json-rpc-over-ws
API for client methods, then `cryptofeed` for data streams.
status
******
- supports option charts
- no order support yet
config
******
In order to get order mode support your ``brokers.toml``
needs to have something like the following:
.. code:: toml
[deribit]
key_id = 'XXXXXXXX'
key_secret = 'Xx_XxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXxXx'
To obtain an api id and secret you need to create an account, which can be a
real market account over at:
- deribit.com (requires KYC for deposit address)
Or a testnet account over at:
- test.deribit.com
For testnet, once the account is created, here is how you deposit fake crypto to
try it out:
1) Go to Wallet:
.. figure:: assets/0_wallet.png
:align: center
:target: assets/0_wallet.png
:alt: wallet page
2) Then click on the ellipsis menu and select deposit
.. figure:: assets/1_wallet_select_deposit.png
:align: center
:target: assets/1_wallet_select_deposit.png
:alt: wallet deposit page
3) This will take you to the deposit address page
.. figure:: assets/2_gen_deposit_addr.png
:align: center
:target: assets/2_gen_deposit_addr.png
:alt: generate deposit address page
4) After clicking generate you should see the address, copy it and go to the
`coin faucet <https://test.deribit.com/dericoin/BTC/deposit>`_ and send fake
coins to that address.
.. figure:: assets/3_deposit_address.png
:align: center
:target: assets/3_deposit_address.png
:alt: generated address
5) Back in the deposit address page you should see the deposit in your history
.. figure:: assets/4_wallet_deposit_history.png
:align: center
:target: assets/4_wallet_deposit_history.png
:alt: wallet deposit history

View File

@ -1,65 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Deribit backend.
'''
from piker.log import get_logger
log = get_logger(__name__)
from .api import (
get_client,
)
from .feed import (
open_history_client,
open_symbol_search,
stream_quotes,
backfill_bars
)
# from .broker import (
# trades_dialogue,
# norm_trade_records,
# )
__all__ = [
'get_client',
# 'trades_dialogue',
'open_history_client',
'open_symbol_search',
'stream_quotes',
# 'norm_trade_records',
]
# tractor RPC enable arg
__enable_modules__: list[str] = [
'api',
'feed',
# 'broker',
]
# passed to ``tractor.ActorNursery.start_actor()``
_spawn_kwargs = {
'infect_asyncio': True,
}
# annotation to let backend agnostic code
# know if ``brokerd`` should be spawned with
# ``tractor``'s aio mode.
_infect_asyncio: bool = True

View File

@ -1,672 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Deribit backend.
'''
import json
import time
import asyncio
from contextlib import asynccontextmanager as acm, AsyncExitStack
from functools import partial
from datetime import datetime
from typing import Any, Optional, Iterable, Callable
import pendulum
import asks
import trio
from trio_typing import Nursery, TaskStatus
from fuzzywuzzy import process as fuzzy
import numpy as np
from piker.data.types import Struct
from piker.data._web_bs import (
NoBsWs,
open_autorecon_ws,
open_jsonrpc_session
)
from .._util import resproc
from piker import config
from piker.log import get_logger
from tractor.trionics import (
broadcast_receiver,
BroadcastReceiver,
maybe_open_context
)
from tractor import to_asyncio
from cryptofeed import FeedHandler
from cryptofeed.defines import (
DERIBIT,
L1_BOOK, TRADES,
OPTION, CALL, PUT,
FILLS, ORDER_INFO
)
from cryptofeed.symbols import Symbol
log = get_logger(__name__)
_spawn_kwargs = {
'infect_asyncio': True,
}
_url = 'https://www.deribit.com'
_ws_url = 'wss://www.deribit.com/ws/api/v2'
_testnet_ws_url = 'wss://test.deribit.com/ws/api/v2'
# Broker specific ohlc schema (rest)
_ohlc_dtype = [
('index', int),
('time', int),
('open', float),
('high', float),
('low', float),
('close', float),
('volume', float),
('bar_wap', float), # will be zeroed by sampler if not filled
]
class JSONRPCResult(Struct):
# NOTE: msgspec requires fields without defaults to come first
id: int
usIn: int
usOut: int
usDiff: int
testnet: bool
jsonrpc: str = '2.0'
result: Optional[dict] = None
error: Optional[dict] = None
class JSONRPCChannel(Struct):
method: str
params: dict
jsonrpc: str = '2.0'
class KLinesResult(Struct):
close: list[float]
cost: list[float]
high: list[float]
low: list[float]
open: list[float]
status: str
ticks: list[int]
volume: list[float]
class Trade(Struct):
trade_seq: int
trade_id: str
timestamp: int
tick_direction: int
price: float
mark_price: float
iv: float
instrument_name: str
index_price: float
direction: str
amount: float
combo_trade_id: Optional[int] = 0
combo_id: Optional[str] = ''
class LastTradesResult(Struct):
trades: list[Trade]
has_more: bool
# convert datetime obj timestamp to unixtime in milliseconds
def deribit_timestamp(when):
return int((when.timestamp() * 1000) + (when.microsecond / 1000))
def str_to_cb_sym(name: str) -> Symbol:
base, strike_price, expiry_date, option_type = name.split('-')
quote = base
if option_type == 'put':
option_type = PUT
elif option_type == 'call':
option_type = CALL
else:
raise Exception("Couldn\'t parse option type")
return Symbol(
base, quote,
type=OPTION,
strike_price=strike_price,
option_type=option_type,
expiry_date=expiry_date,
expiry_normalize=False)
def piker_sym_to_cb_sym(name: str) -> Symbol:
base, expiry_date, strike_price, option_type = tuple(
name.upper().split('-'))
quote = base
if option_type == 'P':
option_type = PUT
elif option_type == 'C':
option_type = CALL
else:
raise Exception("Couldn\'t parse option type")
return Symbol(
base, quote,
type=OPTION,
strike_price=strike_price,
option_type=option_type,
expiry_date=expiry_date.upper())
def cb_sym_to_deribit_inst(sym: Symbol):
# cryptofeed normalized
cb_norm = ['F', 'G', 'H', 'J', 'K', 'M', 'N', 'Q', 'U', 'V', 'X', 'Z']
# deribit specific
months = ['JAN', 'FEB', 'MAR', 'APR', 'MAY', 'JUN', 'JUL', 'AUG', 'SEP', 'OCT', 'NOV', 'DEC']
exp = sym.expiry_date
# YYMDD
# 01234
year, month, day = (
exp[:2], months[cb_norm.index(exp[2:3])], exp[3:])
otype = 'C' if sym.option_type == CALL else 'P'
return f'{sym.base}-{day}{month}{year}-{sym.strike_price}-{otype}'
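# NOTE: hedged round-trip example of the converters above: cryptofeed
# normalizes deribit expiries to 'YYMDD' using futures month codes
# ('U' == September), so '23U15' should map back to '15SEP23'
# (assuming `Symbol` preserves the fields as passed):
sym = str_to_cb_sym('BTC-25000-23U15-call')
print(cb_sym_to_deribit_inst(sym))  # -> 'BTC-15SEP23-25000-C'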
def get_config() -> dict[str, Any]:
conf, path = config.load()
section = conf.get('deribit')
# TODO: document why we send this; basically it's the logging params cryptofeed expects
conf['log'] = {}
conf['log']['disabled'] = True
if section is None:
log.warning(f'No config section found for deribit in {path}')
return conf
class Client:
def __init__(self, json_rpc: Callable) -> None:
self._pairs: dict[str, Any] = None
config = get_config().get('deribit', {})
if ('key_id' in config) and ('key_secret' in config):
self._key_id = config['key_id']
self._key_secret = config['key_secret']
else:
self._key_id = None
self._key_secret = None
self.json_rpc = json_rpc
@property
def currencies(self):
return ['btc', 'eth', 'sol', 'usd']
async def get_balances(self, kind: str = 'option') -> dict[str, float]:
"""Return the set of positions for this account
by symbol.
"""
balances = {}
for currency in self.currencies:
resp = await self.json_rpc(
'private/get_positions', params={
'currency': currency.upper(),
'kind': kind})
balances[currency] = resp.result
return balances
async def get_assets(self) -> dict[str, float]:
"""Return the set of asset balances for this account
by symbol.
"""
balances = {}
for currency in self.currencies:
resp = await self.json_rpc(
'private/get_account_summary', params={
'currency': currency.upper()})
balances[currency] = resp.result['balance']
return balances
async def submit_limit(
self,
symbol: str,
price: float,
action: str,
size: float
) -> dict:
"""Place an order
"""
params = {
'instrument_name': symbol.upper(),
'amount': size,
'type': 'limit',
'price': price,
}
resp = await self.json_rpc(
f'private/{action}', params)
return resp.result
async def submit_cancel(self, oid: str):
"""Send cancel request for order id
"""
resp = await self.json_rpc(
'private/cancel', {'order_id': oid})
return resp.result
async def symbol_info(
self,
instrument: Optional[str] = None,
currency: str = 'btc', # BTC, ETH, SOL, USDC
kind: str = 'option',
expired: bool = False
) -> dict[str, Any]:
"""Get symbol info for the exchange.
"""
if self._pairs:
return self._pairs
# will retrieve all symbols by default
params = {
'currency': currency.upper(),
'kind': kind,
'expired': str(expired).lower()
}
resp = await self.json_rpc('public/get_instruments', params)
results = resp.result
instruments = {
item['instrument_name'].lower(): item
for item in results
}
if instrument is not None:
return instruments[instrument]
else:
return instruments
async def cache_symbols(
self,
) -> dict:
if not self._pairs:
self._pairs = await self.symbol_info()
return self._pairs
async def search_symbols(
self,
pattern: str,
limit: int = 30,
) -> dict[str, Any]:
data = await self.symbol_info()
matches = fuzzy.extractBests(
pattern,
data,
score_cutoff=35,
limit=limit
)
# repack in dict form
return {item[0]['instrument_name'].lower(): item[0]
for item in matches}
async def bars(
self,
symbol: str,
start_dt: Optional[datetime] = None,
end_dt: Optional[datetime] = None,
limit: int = 1000,
as_np: bool = True,
) -> dict:
instrument = symbol
if end_dt is None:
end_dt = pendulum.now('UTC')
if start_dt is None:
start_dt = end_dt.start_of(
'minute').subtract(minutes=limit)
start_time = deribit_timestamp(start_dt)
end_time = deribit_timestamp(end_dt)
# https://docs.deribit.com/#public-get_tradingview_chart_data
resp = await self.json_rpc(
'public/get_tradingview_chart_data',
params={
'instrument_name': instrument.upper(),
'start_timestamp': start_time,
'end_timestamp': end_time,
'resolution': '1'
})
result = KLinesResult(**resp.result)
new_bars = []
for i in range(len(result.close)):
row = [
(start_time + (i * (60 * 1000))) / 1000.0, # time
result.open[i],
result.high[i],
result.low[i],
result.close[i],
result.volume[i],
0
]
new_bars.append((i,) + tuple(row))
array = np.array(new_bars, dtype=_ohlc_dtype) if as_np else new_bars
return array
async def last_trades(
self,
instrument: str,
count: int = 10
):
resp = await self.json_rpc(
'public/get_last_trades_by_instrument',
params={
'instrument_name': instrument,
'count': count
})
return LastTradesResult(**resp.result)
@acm
async def get_client(
is_brokercheck: bool = False
) -> Client:
async with (
trio.open_nursery() as n,
open_jsonrpc_session(
_testnet_ws_url, dtype=JSONRPCResult) as json_rpc
):
client = Client(json_rpc)
_refresh_token: Optional[str] = None
_access_token: Optional[str] = None
async def _auth_loop(
task_status: TaskStatus = trio.TASK_STATUS_IGNORED
):
"""Background task that adquires a first access token and then will
refresh the access token while the nursery isn't cancelled.
https://docs.deribit.com/?python#authentication-2
"""
renew_time = 10
access_scope = 'trade:read_write'
_expiry_time = time.time()
got_access = False
nonlocal _refresh_token
nonlocal _access_token
while True:
if _expiry_time - time.time() < renew_time:
# if we are close to token expiry time
if _refresh_token is not None:
# if we already have a refresh token we don't need to send
# the secret
params = {
'grant_type': 'refresh_token',
'refresh_token': _refresh_token,
'scope': access_scope
}
else:
# we don't have refresh token, send secret to initialize
params = {
'grant_type': 'client_credentials',
'client_id': client._key_id,
'client_secret': client._key_secret,
'scope': access_scope
}
resp = await json_rpc('public/auth', params)
result = resp.result
_expiry_time = time.time() + result['expires_in']
_refresh_token = result['refresh_token']
if 'access_token' in result:
_access_token = result['access_token']
if not got_access:
# first time this loop runs we must indicate task is
# started, we have auth
got_access = True
task_status.started()
else:
await trio.sleep(renew_time / 2)
# if we have client creds, launch the auth loop
if client._key_id is not None:
await n.start(_auth_loop)
await client.cache_symbols()
yield client
n.cancel_scope.cancel()
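
A hedged end-to-end sketch of composing the client (connects to the
testnet endpoint per ``_testnet_ws_url``; instrument name illustrative):

.. code:: python

    import trio

    async def main():
        async with get_client() as client:
            # symbols are cached on entry so lookups are local
            pairs = await client.cache_symbols()
            print(f'{len(pairs)} instruments cached')

            last = await client.last_trades('BTC-PERPETUAL', count=3)
            for trade in last.trades:
                print(trade)

    trio.run(main)
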
@acm
async def open_feed_handler():
fh = FeedHandler(config=get_config())
yield fh
await to_asyncio.run_task(fh.stop_async)
@acm
async def maybe_open_feed_handler() -> FeedHandler:
async with maybe_open_context(
acm_func=open_feed_handler,
key='feedhandler',
) as (cache_hit, fh):
yield fh
async def aio_price_feed_relay(
fh: FeedHandler,
instrument: Symbol,
from_trio: asyncio.Queue,
to_trio: trio.abc.SendChannel,
) -> None:
async def _trade(data: dict, receipt_timestamp):
to_trio.send_nowait(('trade', {
'symbol': cb_sym_to_deribit_inst(
str_to_cb_sym(data.symbol)).lower(),
'last': data,
'broker_ts': time.time(),
'data': data.to_dict(),
'receipt': receipt_timestamp
}))
async def _l1(data: dict, receipt_timestamp):
to_trio.send_nowait(('l1', {
'symbol': cb_sym_to_deribit_inst(
str_to_cb_sym(data.symbol)).lower(),
'ticks': [
{'type': 'bid',
'price': float(data.bid_price), 'size': float(data.bid_size)},
{'type': 'bsize',
'price': float(data.bid_price), 'size': float(data.bid_size)},
{'type': 'ask',
'price': float(data.ask_price), 'size': float(data.ask_size)},
{'type': 'asize',
'price': float(data.ask_price), 'size': float(data.ask_size)}
]
}))
fh.add_feed(
DERIBIT,
channels=[TRADES, L1_BOOK],
symbols=[piker_sym_to_cb_sym(instrument)],
callbacks={
TRADES: _trade,
L1_BOOK: _l1
})
if not fh.running:
fh.run(
start_loop=False,
install_signal_handlers=False)
# sync with trio
to_trio.send_nowait(None)
await asyncio.sleep(float('inf'))
@acm
async def open_price_feed(
instrument: str
) -> trio.abc.ReceiveStream:
async with maybe_open_feed_handler() as fh:
async with to_asyncio.open_channel_from(
partial(
aio_price_feed_relay,
fh,
instrument
)
) as (first, chan):
yield chan
@acm
async def maybe_open_price_feed(
instrument: str
) -> trio.abc.ReceiveStream:
# TODO: add a predicate to maybe_open_context
async with maybe_open_context(
acm_func=open_price_feed,
kwargs={
'instrument': instrument
},
key=f'{instrument}-price',
) as (cache_hit, feed):
if cache_hit:
yield broadcast_receiver(feed, 10)
else:
yield feed
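
A consumer-side sketch (instrument name illustrative); the ``(typ,
quote)`` tuples match what ``aio_price_feed_relay()`` emits above:

.. code:: python

    async with maybe_open_price_feed('btc-perpetual') as feed:
        async for typ, quote in feed:
            if typ == 'trade':
                print('trade:', quote['symbol'], quote['broker_ts'])
            elif typ == 'l1':
                print('l1 ticks:', quote['ticks'])
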
async def aio_order_feed_relay(
fh: FeedHandler,
instrument: Symbol,
from_trio: asyncio.Queue,
to_trio: trio.abc.SendChannel,
) -> None:
async def _fill(data: dict, receipt_timestamp):
    # TODO: relay fill events upstream; stubbed with a
    # breakpoint for interactive debugging
    breakpoint()

async def _order_info(data: dict, receipt_timestamp):
    # TODO: relay order status events upstream; stubbed with a
    # breakpoint for interactive debugging
    breakpoint()
fh.add_feed(
DERIBIT,
channels=[FILLS, ORDER_INFO],
symbols=[instrument.upper()],
callbacks={
FILLS: _fill,
ORDER_INFO: _order_info,
})
if not fh.running:
fh.run(
start_loop=False,
install_signal_handlers=False)
# sync with trio
to_trio.send_nowait(None)
await asyncio.sleep(float('inf'))
@acm
async def open_order_feed(
instrument: str
) -> trio.abc.ReceiveStream:
async with maybe_open_feed_handler() as fh:
async with to_asyncio.open_channel_from(
partial(
aio_order_feed_relay,
fh,
instrument
)
) as (first, chan):
yield chan
@acm
async def maybe_open_order_feed(
instrument: str
) -> trio.abc.ReceiveStream:
# TODO: add a predicate to maybe_open_context
async with maybe_open_context(
acm_func=open_order_feed,
kwargs={
    'instrument': instrument
},
key=f'{instrument}-order',
) as (cache_hit, feed):
if cache_hit:
yield broadcast_receiver(feed, 10)
else:
yield feed

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

Binary file not shown.

View File

@ -1,185 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Deribit backend.
'''
from contextlib import asynccontextmanager as acm
from datetime import datetime
from typing import Any, Optional, Callable
import time
import trio
from trio_typing import TaskStatus
import pendulum
from fuzzywuzzy import process as fuzzy
import numpy as np
import tractor
from piker._cacheables import open_cached_client
from piker.log import get_logger, get_console_log
from piker.data import ShmArray
from piker.brokers._util import (
BrokerError,
DataUnavailable,
)
from cryptofeed import FeedHandler
from cryptofeed.defines import (
DERIBIT, L1_BOOK, TRADES, OPTION, CALL, PUT
)
from cryptofeed.symbols import Symbol
from .api import (
Client, Trade,
get_config,
str_to_cb_sym, piker_sym_to_cb_sym, cb_sym_to_deribit_inst,
maybe_open_price_feed
)
_spawn_kwargs = {
'infect_asyncio': True,
}
log = get_logger(__name__)
@acm
async def open_history_client(
instrument: str,
) -> tuple[Callable, int]:
# TODO implement history getter for the new storage layer.
async with open_cached_client('deribit') as client:
async def get_ohlc(
end_dt: Optional[datetime] = None,
start_dt: Optional[datetime] = None,
) -> tuple[
np.ndarray,
datetime, # start
datetime, # end
]:
array = await client.bars(
instrument,
start_dt=start_dt,
end_dt=end_dt,
)
if len(array) == 0:
raise DataUnavailable
start_dt = pendulum.from_timestamp(array[0]['time'])
end_dt = pendulum.from_timestamp(array[-1]['time'])
return array, start_dt, end_dt
yield get_ohlc, {'erlangs': 3, 'rate': 3}
async def stream_quotes(
send_chan: trio.abc.SendChannel,
symbols: list[str],
feed_is_live: trio.Event,
loglevel: str = None,
# startup sync
task_status: TaskStatus[tuple[dict, dict]] = trio.TASK_STATUS_IGNORED,
) -> None:
# XXX: required to propagate ``tractor`` loglevel to piker logging
get_console_log(loglevel or tractor.current_actor().loglevel)
sym = symbols[0]
async with (
open_cached_client('deribit') as client,
send_chan as send_chan
):
init_msgs = {
# pass back token, and bool, signalling if we're the writer
# and that history has been written
sym: {
'symbol_info': {
'asset_type': 'option',
'price_tick_size': 0.0005
},
'shm_write_opts': {'sum_tick_vml': False},
'fqsn': sym,
},
}
nsym = piker_sym_to_cb_sym(sym)
async with maybe_open_price_feed(sym) as stream:
cache = await client.cache_symbols()
last_trades = (await client.last_trades(
cb_sym_to_deribit_inst(nsym), count=1)).trades
if len(last_trades) == 0:
last_trade = None
async for typ, quote in stream:
if typ == 'trade':
last_trade = Trade(**(quote['data']))
break
else:
last_trade = Trade(**(last_trades[0]))
first_quote = {
'symbol': sym,
'last': last_trade.price,
'brokerd_ts': last_trade.timestamp,
'ticks': [{
'type': 'trade',
'price': last_trade.price,
'size': last_trade.amount,
'broker_ts': last_trade.timestamp
}]
}
task_status.started((init_msgs, first_quote))
feed_is_live.set()
async for typ, quote in stream:
topic = quote['symbol']
await send_chan.send({topic: quote})
@tractor.context
async def open_symbol_search(
ctx: tractor.Context,
) -> Client:
async with open_cached_client('deribit') as client:
# load all symbols locally for fast search
cache = await client.cache_symbols()
await ctx.started()
async with ctx.open_stream() as stream:
async for pattern in stream:
# repack in dict form
await stream.send(
await client.search_symbols(pattern))

View File

@ -1,134 +0,0 @@
``ib`` backend
--------------
more or less the "everything broker" for traditional and international
markets. they are the "go to" provider for automated retail trading
and we interface to their APIs using the `ib_insync` project.
status
******
current support is *production grade* and both real-time data and order
management should be correct and fast. this backend is used by core devs
for live trading.
currently there is not yet full support for:
- options charting and trading
- paxos based crypto rt feeds and trading
config
******
In order to get order mode support your ``brokers.toml``
needs to have something like the following:
.. code:: toml
[ib]
hosts = [
"127.0.0.1",
]
# TODO: when we eventually spawn gateways in our
# container, we can just dynamically allocate these
# using IBC.
ports = [
4002,
4003,
4006,
4001,
7497,
]
# XXX: for a paper account the flex web query service
# is not supported so you have to manually download
# an XML report and put it in a location that can be
# accessed by the ``brokerd.ib`` backend code for parsing.
flex_token = '1111111111111111'
flex_trades_query_id = '6969696' # live accounts only?
# 3rd party web-api token
# (XXX: not sure if this works yet)
trade_log_token = '111111111111111'
# when clients are being scanned this determines
# which clients are preferred to be used for data feeds
# based on account names which are detected as active
# on each client.
prefer_data_account = [
# this has to be first in order to make data work with dual paper + live
'main',
'algopaper',
]
[ib.accounts]
main = 'U69696969'
algopaper = 'DU9696969'
If everything works correctly you should see any current positions
loaded in the pps pane on chart load and you should also be able to
check your trade records in the file::
<piker_conf_dir>/ledgers/trades_ib_algopaper.toml
An example ledger file will have entries written verbatim from the
trade events schema:
.. code:: toml
["0000e1a7.630f5e5a.01.01"]
secType = "FUT"
conId = 515416577
symbol = "MNQ"
lastTradeDateOrContractMonth = "20221216"
strike = 0.0
right = ""
multiplier = "2"
exchange = "GLOBEX"
primaryExchange = ""
currency = "USD"
localSymbol = "MNQZ2"
tradingClass = "MNQ"
includeExpired = false
secIdType = ""
secId = ""
comboLegsDescrip = ""
comboLegs = []
execId = "0000e1a7.630f5e5a.01.01"
time = 1661972086.0
acctNumber = "DU69696969"
side = "BOT"
shares = 1.0
price = 12372.75
permId = 441472655
clientId = 6116
orderId = 985
liquidation = 0
cumQty = 1.0
avgPrice = 12372.75
orderRef = ""
evRule = ""
evMultiplier = 0.0
modelCode = ""
lastLiquidity = 1
broker_time = 1661972086.0
name = "ib"
commission = 0.57
realizedPNL = 243.41
yield_ = 0.0
yieldRedemptionDate = 0
listingExchange = "GLOBEX"
date = "2022-08-31T18:54:46+00:00"
your ``pps.toml`` file will have position entries like:
.. code:: toml
[ib.algopaper."mnq.globex.20221216"]
size = -1.0
ppu = 12423.630576923071
bsuid = 515416577
expiry = "2022-12-16T00:00:00+00:00"
clears = [
{ dt = "2022-08-31T18:54:46+00:00", ppu = 12423.630576923071, accum_size = -19.0, price = 12372.75, size = 1.0, cost = 0.57, tid = "0000e1a7.630f5e5a.01.01" },
]

View File

@ -20,10 +20,15 @@ Interactive Brokers API backend.
Sub-modules within break into the core functionalities: Sub-modules within break into the core functionalities:
- ``broker.py`` part for orders / trading endpoints - ``broker.py`` part for orders / trading endpoints
- ``feed.py`` for real-time data feed endpoints - ``data.py`` for real-time data feed endpoints
- ``api.py`` for the core API machinery which is ``trio``-ized
- ``client.py`` for the core API machinery which is ``trio``-ized
wrapping around ``ib_insync``. wrapping around ``ib_insync``.
- ``report.py`` for the hackery to build manual pp calcs
to avoid ib's absolute bullshit FIFO style position
tracking..
""" """
from .api import ( from .api import (
get_client, get_client,
@ -33,10 +38,7 @@ from .feed import (
open_symbol_search, open_symbol_search,
stream_quotes, stream_quotes,
) )
from .broker import ( from .broker import trades_dialogue
trades_dialogue,
norm_trade_records,
)
__all__ = [ __all__ = [
'get_client', 'get_client',

View File

@ -29,7 +29,6 @@ import itertools
from math import isnan from math import isnan
from typing import ( from typing import (
Any, Any,
Optional,
Union, Union,
) )
import asyncio import asyncio
@ -39,30 +38,16 @@ import time
from types import SimpleNamespace from types import SimpleNamespace
from bidict import bidict
import trio import trio
import tractor import tractor
from tractor import to_asyncio from tractor import to_asyncio
import pendulum from ib_insync.wrapper import RequestError
import ib_insync as ibis from ib_insync.contract import Contract, ContractDetails
from ib_insync.contract import (
Contract,
ContractDetails,
Option,
)
from ib_insync.order import Order from ib_insync.order import Order
from ib_insync.ticker import Ticker from ib_insync.ticker import Ticker
from ib_insync.objects import ( from ib_insync.objects import Position
BarDataList, import ib_insync as ibis
Position, from ib_insync.wrapper import Wrapper
Fill,
Execution,
CommissionReport,
)
from ib_insync.wrapper import (
Wrapper,
RequestError,
)
from ib_insync.client import Client as ib_Client from ib_insync.client import Client as ib_Client
import numpy as np import numpy as np
@ -80,11 +65,26 @@ _time_units = {
'h': ' hours', 'h': ' hours',
} }
_bar_sizes = { _time_frames = {
1: '1 Sec', '1s': '1 Sec',
60: '1 min', '5s': '5 Sec',
60*60: '1 hour', '30s': '30 Sec',
24*60*60: '1 day', '1m': 'OneMinute',
'2m': 'TwoMinutes',
'3m': 'ThreeMinutes',
'4m': 'FourMinutes',
'5m': 'FiveMinutes',
'10m': 'TenMinutes',
'15m': 'FifteenMinutes',
'20m': 'TwentyMinutes',
'30m': 'HalfHour',
'1h': 'OneHour',
'2h': 'TwoHours',
'4h': 'FourHours',
'D': 'OneDay',
'W': 'OneWeek',
'M': 'OneMonth',
'Y': 'OneYear',
} }
_show_wap_in_history: bool = False _show_wap_in_history: bool = False
@ -155,118 +155,70 @@ class NonShittyIB(ibis.IB):
self.client.apiEnd += self.disconnectedEvent self.client.apiEnd += self.disconnectedEvent
_futes_venues = (
'GLOBEX',
'NYMEX',
'CME',
'CMECRYPTO',
'COMEX',
'CMDTY', # special name case..
'CBOT', # (treasury) yield futures
)
_adhoc_futes_set = {
# equities
'nq.cme',
'mnq.cme', # micro
'es.cme',
'mes.cme', # micro
# cypto$
'brr.cme',
'ethusdrr.cme',
# agriculture
'he.comex', # lean hogs
'le.comex', # live cattle (geezers)
'gf.comex', # feeder cattle (younguns)
# raw
'lb.comex', # random len lumber
# metals
# https://misc.interactivebrokers.com/cstools/contract_info/v3.10/index.php?action=Conid%20Info&wlId=IB&conid=69067924
'xauusd.cmdty', # london gold spot ^
'gc.comex',
'mgc.comex', # micro
# oil & gas
'cl.comex',
'xagusd.cmdty', # silver spot
'ni.comex', # silver futes
'qi.comex', # mini-silver futes
# treasury yields
# etfs by duration:
# SHY -> IEI -> IEF -> TLT
'zt.cbot', # 2y
'z3n.cbot', # 3y
'zf.cbot', # 5y
'zn.cbot', # 10y
'zb.cbot', # 30y
# (micros of above)
'2yy.cbot',
'5yy.cbot',
'10y.cbot',
'30y.cbot',
}
# taken from list here:
# https://www.interactivebrokers.com/en/trading/products-spot-currencies.php
_adhoc_fiat_set = set((
'USD, AED, AUD, CAD,'
'CHF, CNH, CZK, DKK,'
'EUR, GBP, HKD, HUF,'
'ILS, JPY, MXN, NOK,'
'NZD, PLN, RUB, SAR,'
'SEK, SGD, TRY, ZAR'
).split(' ,')
)
# map of symbols to contract ids # map of symbols to contract ids
_adhoc_symbol_map = { _adhoc_cmdty_data_map = {
# https://misc.interactivebrokers.com/cstools/contract_info/v3.10/index.php?action=Conid%20Info&wlId=IB&conid=69067924 # https://misc.interactivebrokers.com/cstools/contract_info/v3.10/index.php?action=Conid%20Info&wlId=IB&conid=69067924
# NOTE: some cmdtys/metals don't have trade data like gold/usd: # NOTE: some cmdtys/metals don't have trade data like gold/usd:
# https://groups.io/g/twsapi/message/44174 # https://groups.io/g/twsapi/message/44174
'XAUUSD': ({'conId': 69067924}, {'whatToShow': 'MIDPOINT'}), 'XAUUSD': ({'conId': 69067924}, {'whatToShow': 'MIDPOINT'}),
} }
for qsn in _adhoc_futes_set:
sym, venue = qsn.split('.')
assert venue.upper() in _futes_venues, f'{venue}'
_adhoc_symbol_map[sym.upper()] = (
{'exchange': venue},
{},
)
_futes_venues = (
'GLOBEX',
'NYMEX',
'CME',
'CMECRYPTO',
)
_adhoc_futes_set = {
# equities
'nq.globex',
'mnq.globex',
'es.globex',
'mes.globex',
# cypto$
'brr.cmecrypto',
'ethusdrr.cmecrypto',
# agriculture
'he.globex', # lean hogs
'le.globex', # live cattle (geezers)
'gf.globex', # feeder cattle (younguns)
# raw
'lb.globex', # random len lumber
# metals
'xauusd.cmdty', # gold spot
'gc.nymex',
'mgc.nymex',
'xagusd.cmdty', # silver spot
'ni.nymex', # silver futes
'qi.comex', # mini-silver futes
}
# exchanges we don't support at the moment due to not knowing # exchanges we don't support at the moment due to not knowing
# how to do symbol-contract lookup correctly likely due # how to do symbol-contract lookup correctly likely due
# to not having the data feeds subscribed. # to not having the data feeds subscribed.
_exch_skip_list = { _exch_skip_list = {
'ASX', # aussie stocks 'ASX', # aussie stocks
'MEXI', # mexican stocks 'MEXI', # mexican stocks
'VALUE', # no idea
# no idea
'VALUE',
'FUNDSERV',
'SWB2',
'PSE',
} }
# https://misc.interactivebrokers.com/cstools/contract_info/v3.10/index.php?action=Conid%20Info&wlId=IB&conid=69067924
_enters = 0 _enters = 0
def bars_to_np(bars: list) -> np.ndarray: def bars_to_np(bars: list) -> np.ndarray:
''' '''
Convert a "bars list thing" (``BarDataList`` type from ibis) Convert a "bars list thing" (``BarsList`` type from ibis)
into a numpy struct array. into a numpy struct array.
''' '''
@ -286,27 +238,6 @@ def bars_to_np(bars: list) -> np.ndarray:
return nparr return nparr
# NOTE: pacing violations exist for higher sample rates:
# https://interactivebrokers.github.io/tws-api/historical_limitations.html#pacing_violations
# Also see note on duration limits being lifted on 1m+ periods,
# but they say "use with discretion":
# https://interactivebrokers.github.io/tws-api/historical_limitations.html#non-available_hd
_samplings: dict[int, tuple[str, str]] = {
1: (
'1 secs',
f'{int(2e3)} S',
pendulum.duration(seconds=2e3),
),
# TODO: benchmark >1 D duration on query to see if
# throughput can be made faster during backfilling.
60: (
'1 min',
'1 D',
pendulum.duration(days=1),
),
}
class Client: class Client:
''' '''
IB wrapped for our broker backend API. IB wrapped for our broker backend API.
@ -330,29 +261,27 @@ class Client:
# NOTE: the ib.client here is "throttled" to 45 rps by default # NOTE: the ib.client here is "throttled" to 45 rps by default
async def trades(self) -> dict[str, Any]: async def trades(
''' self,
Return list of trade-fills from current session in ``dict``. # api_only: bool = False,
''' ) -> dict[str, Any]:
fills: list[Fill] = self.ib.fills()
norm_fills: list[dict] = [] # orders = await self.ib.reqCompletedOrdersAsync(
# apiOnly=api_only
# )
fills = await self.ib.reqExecutionsAsync()
norm_fills = []
for fill in fills: for fill in fills:
fill = fill._asdict() # namedtuple fill = fill._asdict() # namedtuple
for key, val in fill.items(): for key, val in fill.copy().items():
match val: if isinstance(val, Contract):
case Contract() | Execution() | CommissionReport():
fill[key] = asdict(val) fill[key] = asdict(val)
norm_fills.append(fill) norm_fills.append(fill)
return norm_fills return norm_fills
async def orders(self) -> list[Order]:
return await self.ib.reqAllOpenOrdersAsync(
apiOnly=False,
)
async def bars( async def bars(
self, self,
fqsn: str, fqsn: str,
@ -361,55 +290,52 @@ class Client:
start_dt: Union[datetime, str] = "1970-01-01T00:00:00.000000-05:00", start_dt: Union[datetime, str] = "1970-01-01T00:00:00.000000-05:00",
end_dt: Union[datetime, str] = "", end_dt: Union[datetime, str] = "",
# ohlc sample period in seconds sample_period_s: str = 1, # ohlc sample period
sample_period_s: int = 1, period_count: int = int(2e3), # <- max per 1s sample query
# optional "duration of time" equal to the ) -> list[dict[str, Any]]:
# length of the returned history frame.
duration: Optional[str] = None,
**kwargs,
) -> tuple[BarDataList, np.ndarray, pendulum.Duration]:
''' '''
Retrieve OHLCV bars for a fqsn over a range to the present. Retrieve OHLCV bars for a fqsn over a range to the present.
''' '''
# See API docs here:
# https://interactivebrokers.github.io/tws-api/historical_data.html
bars_kwargs = {'whatToShow': 'TRADES'} bars_kwargs = {'whatToShow': 'TRADES'}
bars_kwargs.update(kwargs)
bar_size, duration, dt_duration = _samplings[sample_period_s]
global _enters global _enters
# log.info(f'REQUESTING BARS {_enters} @ end={end_dt}') # log.info(f'REQUESTING BARS {_enters} @ end={end_dt}')
print( print(f'REQUESTING BARS {_enters} @ end={end_dt}')
f"REQUESTING {duration}'s worth {bar_size} BARS\n"
f'{_enters} @ end={end_dt}"'
)
if not end_dt: if not end_dt:
end_dt = '' end_dt = ''
_enters += 1 _enters += 1
contract = (await self.find_contracts(fqsn))[0] contract = await self.find_contract(fqsn)
bars_kwargs.update(getattr(contract, 'bars_kwargs', {})) bars_kwargs.update(getattr(contract, 'bars_kwargs', {}))
# _min = min(2000*100, count)
bars = await self.ib.reqHistoricalDataAsync( bars = await self.ib.reqHistoricalDataAsync(
contract, contract,
endDateTime=end_dt, endDateTime=end_dt,
formatDate=2, formatDate=2,
# time history length values format:
# ``durationStr=integer{SPACE}unit (S|D|W|M|Y)``
# OHLC sampling values: # OHLC sampling values:
# 1 secs, 5 secs, 10 secs, 15 secs, 30 secs, 1 min, 2 mins, # 1 secs, 5 secs, 10 secs, 15 secs, 30 secs, 1 min, 2 mins,
# 3 mins, 5 mins, 10 mins, 15 mins, 20 mins, 30 mins, # 3 mins, 5 mins, 10 mins, 15 mins, 20 mins, 30 mins,
# 1 hour, 2 hours, 3 hours, 4 hours, 8 hours, 1 day, 1W, 1M # 1 hour, 2 hours, 3 hours, 4 hours, 8 hours, 1 day, 1W, 1M
barSizeSetting=bar_size, # barSizeSetting='1 secs',
# time history length values format: # durationStr='{count} S'.format(count=15000 * 5),
# ``durationStr=integer{SPACE}unit (S|D|W|M|Y)`` # durationStr='{count} D'.format(count=1),
durationStr=duration, # barSizeSetting='5 secs',
durationStr='{count} S'.format(count=period_count),
# barSizeSetting='5 secs',
barSizeSetting='1 secs',
# barSizeSetting='1 min',
# always use extended hours # always use extended hours
useRTH=False, useRTH=False,
@ -420,21 +346,11 @@ class Client:
# whatToShow='TRADES', # whatToShow='TRADES',
) )
if not bars: if not bars:
# NOTE: there's 2 cases here to handle (and this should be # TODO: raise underlying error here
# read alongside the implementation of raise ValueError(f"No bars retreived for {fqsn}?")
# ``.reqHistoricalDataAsync()``):
# - no data is returned for the period likely due to
# a weekend, holiday or other non-trading period prior to
# ``end_dt`` which exceeds the ``duration``,
# - a timeout occurred in which case insync internals return
# an empty list thing with bars.clear()...
return [], np.empty(0), dt_duration
# TODO: we could maybe raise ``NoData`` instead if we
# rewrite the method in the first case? right now there's no
# way to detect a timeout.
nparr = bars_to_np(bars) nparr = bars_to_np(bars)
return bars, nparr, dt_duration return bars, nparr
async def con_deats( async def con_deats(
self, self,
@ -448,15 +364,7 @@ class Client:
futs.append(self.ib.reqContractDetailsAsync(con)) futs.append(self.ib.reqContractDetailsAsync(con))
# batch request all details # batch request all details
try:
results = await asyncio.gather(*futs) results = await asyncio.gather(*futs)
except RequestError as err:
msg = err.message
if (
'No security definition' in msg
):
log.warning(f'{msg}: {contracts}')
return {}
# one set per future result # one set per future result
details = {} details = {}
@ -465,11 +373,20 @@ class Client:
# XXX: if there is more then one entry in the details list # XXX: if there is more then one entry in the details list
# then the contract is so called "ambiguous". # then the contract is so called "ambiguous".
for d in details_set: for d in details_set:
con = d.contract
# nested dataclass we probably don't need and that won't key = '.'.join([
# IPC serialize.. con.symbol,
con.primaryExchange or con.exchange,
])
expiry = con.lastTradeDateOrContractMonth
if expiry:
key += f'.{expiry}'
# nested dataclass we probably don't need and that
# won't IPC serialize..
d.secIdList = '' d.secIdList = ''
key, calc_price = con2fqsn(d.contract)
details[key] = d details[key] = d
return details return details
@ -499,20 +416,17 @@ class Client:
self, self,
pattern: str, pattern: str,
# how many contracts to search "up to" # how many contracts to search "up to"
upto: int = 16, upto: int = 3,
asdicts: bool = True, asdicts: bool = True,
) -> dict[str, ContractDetails]: ) -> dict[str, ContractDetails]:
# TODO add search though our adhoc-locally defined symbol set # TODO add search though our adhoc-locally defined symbol set
# for futes/cmdtys/ # for futes/cmdtys/
try:
results = await self.search_stocks( results = await self.search_stocks(
pattern, pattern,
upto=upto, upto=upto,
) )
except ConnectionError:
return {}
for key, deats in results.copy().items(): for key, deats in results.copy().items():
@ -523,54 +437,21 @@ class Client:
if sectype == 'IND': if sectype == 'IND':
results[f'{sym}.IND'] = tract results[f'{sym}.IND'] = tract
results.pop(key) results.pop(key)
# exch = tract.exchange
# XXX: add back one of these to get the weird deadlock
# on the debugger from root without the latest
# maybe_wait_for_debugger() fix in the `open_context()`
# exit.
# assert 0
# if con.exchange not in _exch_skip_list:
exch = tract.exchange exch = tract.exchange
if exch not in _exch_skip_list:
# try to lookup any contracts from our adhoc set
# since often the exchange/venue is named slightly
# different (eg. BRR.CMECRYPTO` instead of just
# `.CME`).
info = _adhoc_symbol_map.get(sym)
if info:
con_kwargs, bars_kwargs = info
exch = con_kwargs['exchange']
if exch in _futes_venues:
# try get all possible contracts for symbol as per, # try get all possible contracts for symbol as per,
# https://interactivebrokers.github.io/tws-api/basic_contracts.html#fut # https://interactivebrokers.github.io/tws-api/basic_contracts.html#fut
con = ibis.Future( con = ibis.Future(
symbol=sym, symbol=sym,
exchange=exch, exchange=exch,
) )
# TODO: make this work, think it's something to do try:
# with the qualify flag.
# cons = await self.find_contracts(
# contract=con,
# err_on_qualify=False,
# )
# if cons:
all_deats = await self.con_deats([con]) all_deats = await self.con_deats([con])
results |= all_deats results |= all_deats
# forex pairs except RequestError as err:
elif sectype == 'CASH': log.warning(err.message)
dst, src = tract.localSymbol.split('.')
pair_key = "/".join([dst, src])
exch = tract.exchange.lower()
results[f'{pair_key}.{exch}'] = tract
results.pop(key)
# XXX: again seems to trigger the weird tractor
# bug with the debugger..
# assert 0
return results return results
@ -602,19 +483,13 @@ class Client:
return con return con
async def get_con( async def find_contract(
self,
conid: int,
) -> Contract:
return await self.ib.qualifyContractsAsync(
ibis.Contract(conId=conid)
)
def parse_patt2fqsn(
self, self,
pattern: str, pattern: str,
currency: str = 'USD',
**kwargs,
) -> tuple[str, str, str, str]: ) -> Contract:
# TODO: we can't use this currently because # TODO: we can't use this currently because
# ``wrapper.starTicker()`` currently cashes ticker instances # ``wrapper.starTicker()`` currently cashes ticker instances
@ -627,30 +502,12 @@ class Client:
# XXX UPDATE: we can probably do the tick/trades scraping # XXX UPDATE: we can probably do the tick/trades scraping
# inside our eventkit handler instead to bypass this entirely? # inside our eventkit handler instead to bypass this entirely?
currency = ''
# fqsn parsing stage
# ------------------
if '.ib' in pattern: if '.ib' in pattern:
from ..data._source import unpack_fqsn from ..data._source import unpack_fqsn
_, symbol, expiry = unpack_fqsn(pattern) broker, symbol, expiry = unpack_fqsn(pattern)
else: else:
symbol = pattern symbol = pattern
expiry = ''
# another hack for forex pairs lul.
if (
'.idealpro' in symbol
# or '/' in symbol
):
exch = 'IDEALPRO'
symbol = symbol.removesuffix('.idealpro')
if '/' in symbol:
symbol, currency = symbol.split('/')
else:
# TODO: yes, a cache..
# try: # try:
# # give the cache a go # # give the cache a go
# return self._contracts[symbol] # return self._contracts[symbol]
@ -661,80 +518,45 @@ class Client:
symbol, _, expiry = symbol.rpartition('.') symbol, _, expiry = symbol.rpartition('.')
# use heuristics to figure out contract "type" # use heuristics to figure out contract "type"
symbol, exch = symbol.upper().rsplit('.', maxsplit=1) sym, exch = symbol.upper().rsplit('.', maxsplit=1)
return symbol, currency, exch, expiry qualify: bool = True
async def find_contracts(
self,
pattern: Optional[str] = None,
contract: Optional[Contract] = None,
qualify: bool = True,
err_on_qualify: bool = True,
) -> Contract:
if pattern is not None:
symbol, currency, exch, expiry = self.parse_patt2fqsn(
pattern,
)
sectype = ''
else:
assert contract
symbol = contract.symbol
sectype = contract.secType
exch = contract.exchange or contract.primaryExchange
expiry = contract.lastTradeDateOrContractMonth
currency = contract.currency
# contract searching stage
# ------------------------
# futes # futes
if exch in _futes_venues: if exch in _futes_venues:
if expiry: if expiry:
# get the "front" contract # get the "front" contract
con = await self.get_fute( contract = await self.get_fute(
symbol=symbol, symbol=sym,
exchange=exch, exchange=exch,
expiry=expiry, expiry=expiry,
) )
else: else:
# get the "front" contract # get the "front" contract
con = await self.get_fute( contract = await self.get_fute(
symbol=symbol, symbol=sym,
exchange=exch, exchange=exch,
front=True, front=True,
) )
elif ( qualify = False
exch in ('IDEALPRO')
or sectype == 'CASH' elif exch in ('FOREX'):
): currency = ''
# if '/' in symbol: symbol, currency = sym.split('/')
# currency = ''
# symbol, currency = symbol.split('/')
con = ibis.Forex( con = ibis.Forex(
pair=''.join((symbol, currency)), symbol=symbol,
currency=currency, currency=currency,
) )
con.bars_kwargs = {'whatToShow': 'MIDPOINT'} con.bars_kwargs = {'whatToShow': 'MIDPOINT'}
# commodities # commodities
elif exch == 'CMDTY': # eg. XAUUSD.CMDTY elif exch == 'CMDTY': # eg. XAUUSD.CMDTY
con_kwargs, bars_kwargs = _adhoc_symbol_map[symbol] con_kwargs, bars_kwargs = _adhoc_cmdty_data_map[sym]
con = ibis.Commodity(**con_kwargs) con = ibis.Commodity(**con_kwargs)
con.bars_kwargs = bars_kwargs con.bars_kwargs = bars_kwargs
# crypto$
elif exch == 'PAXOS': # btc.paxos
con = ibis.Crypto(
symbol=symbol,
currency=currency,
)
# stonks # stonks
else: else:
# TODO: metadata system for all these exchange rules.. # TODO: metadata system for all these exchange rules..
@ -747,61 +569,41 @@ class Client:
exch = 'SMART' exch = 'SMART'
else: else:
# XXX: order is super important here since
# a primary == 'SMART' won't ever work.
primaryExchange = exch
exch = 'SMART' exch = 'SMART'
primaryExchange = exch
con = ibis.Stock( con = ibis.Stock(
symbol=symbol, symbol=sym,
exchange=exch, exchange=exch,
primaryExchange=primaryExchange, primaryExchange=primaryExchange,
currency=currency, currency=currency,
) )
exch = 'SMART' if not exch else exch
contracts = [con]
if qualify:
try: try:
contracts = await self.ib.qualifyContractsAsync(con) exch = 'SMART' if not exch else exch
except RequestError as err: if qualify:
msg = err.message contract = (await self.ib.qualifyContractsAsync(con))[0]
if (
'No security definition' in msg
and not err_on_qualify
):
log.warning(
f'Could not find def for {con}')
return None
else: else:
raise assert contract
if not contracts:
except IndexError:
raise ValueError(f"No contract could be found {con}") raise ValueError(f"No contract could be found {con}")
# pack all contracts into cache self._contracts[pattern] = contract
for tract in contracts:
exch: str = tract.primaryExchange or tract.exchange or exch
pattern = f'{symbol}.{exch}'
expiry = tract.lastTradeDateOrContractMonth
# add an entry with expiry suffix if available
if expiry:
pattern += f'.{expiry}'
self._contracts[pattern.lower()] = tract # add an aditional entry with expiry suffix if available
conexp = contract.lastTradeDateOrContractMonth
if conexp:
self._contracts[pattern + f'.{conexp}'] = contract
return contracts return contract
async def get_head_time( async def get_head_time(
self, self,
fqsn: str, contract: Contract,
) -> datetime: ) -> datetime:
''' """Return the first datetime stamp for ``contract``.
Return the first datetime stamp for ``contract``.
''' """
contract = (await self.find_contracts(fqsn))[0]
return await self.ib.reqHeadTimeStampAsync( return await self.ib.reqHeadTimeStampAsync(
contract, contract,
whatToShow='TRADES', whatToShow='TRADES',
@ -812,10 +614,9 @@ class Client:
async def get_sym_details( async def get_sym_details(
self, self,
symbol: str, symbol: str,
) -> tuple[Contract, Ticker, ContractDetails]: ) -> tuple[Contract, Ticker, ContractDetails]:
contract = (await self.find_contracts(symbol))[0] contract = await self.find_contract(symbol)
ticker: Ticker = self.ib.reqMktData( ticker: Ticker = self.ib.reqMktData(
contract, contract,
snapshot=True, snapshot=True,
@ -871,7 +672,9 @@ class Client:
# async to be consistent for the client proxy, and cuz why not. # async to be consistent for the client proxy, and cuz why not.
def submit_limit( def submit_limit(
self, self,
oid: str, # ignored since doesn't support defining your own # ignored since ib doesn't support defining your
# own order id
oid: str,
symbol: str, symbol: str,
price: float, price: float,
action: str, action: str,
@ -887,9 +690,6 @@ class Client:
''' '''
Place an order and return integer request id provided by client. Place an order and return integer request id provided by client.
Relevant docs:
- https://interactivebrokers.github.io/tws-api/order_limitations.html
''' '''
try: try:
contract = self._contracts[symbol] contract = self._contracts[symbol]
@ -915,9 +715,6 @@ class Client:
optOutSmartRouting=True, optOutSmartRouting=True,
routeMarketableToBbo=True, routeMarketableToBbo=True,
designatedLocation='SMART', designatedLocation='SMART',
# TODO: make all orders GTC?
# https://interactivebrokers.github.io/tws-api/classIBApi_1_1Order.html#a95539081751afb9980f4c6bd1655a6ba
# goodTillDate=f"yyyyMMdd-HH:mm:ss",
), ),
) )
except AssertionError: # errrg insync.. except AssertionError: # errrg insync..
@ -1007,76 +804,6 @@ class Client:
return self.ib.positions(account=account) return self.ib.positions(account=account)
def con2fqsn(
con: Contract,
_cache: dict[int, (str, bool)] = {}
) -> tuple[str, bool]:
'''
Convert contracts to fqsn-style strings to be used both in symbol-search
matching and as feed tokens passed to the front end data deed layer.
Previously seen contracts are cached by id.
'''
# should be real volume for this contract by default
calc_price = False
if con.conId:
try:
return _cache[con.conId]
except KeyError:
pass
suffix = con.primaryExchange or con.exchange
symbol = con.symbol
expiry = con.lastTradeDateOrContractMonth or ''
match con:
case Option():
# TODO: option symbol parsing and sane display:
symbol = con.localSymbol.replace(' ', '')
case ibis.Commodity():
# commodities and forex don't have an exchange name and
# no real volume so we have to calculate the price
suffix = con.secType
# no real volume on this tract
calc_price = True
case ibis.Forex() | ibis.Contract(secType='CASH'):
dst, src = con.localSymbol.split('.')
symbol = ''.join([dst, src])
suffix = con.exchange or 'idealpro'
# no real volume on forex feeds..
calc_price = True
if not suffix:
entry = _adhoc_symbol_map.get(
con.symbol or con.localSymbol
)
if entry:
meta, kwargs = entry
cid = meta.get('conId')
if cid:
assert con.conId == meta['conId']
suffix = meta['exchange']
# append a `.<suffix>` to the returned symbol
# key for derivatives that normally is the expiry
# date key.
if expiry:
suffix += f'.{expiry}'
fqsn_key = symbol.lower()
if suffix:
fqsn_key = '.'.join((fqsn_key, suffix)).lower()
_cache[con.conId] = fqsn_key, calc_price
return fqsn_key, calc_price
# per-actor API ep caching # per-actor API ep caching
_client_cache: dict[tuple[str, int], Client] = {} _client_cache: dict[tuple[str, int], Client] = {}
_scan_ignore: set[tuple[str, int]] = set() _scan_ignore: set[tuple[str, int]] = set()
@ -1084,23 +811,10 @@ _scan_ignore: set[tuple[str, int]] = set()
def get_config() -> dict[str, Any]: def get_config() -> dict[str, Any]:
conf, path = config.load('brokers') conf, path = config.load()
section = conf.get('ib') section = conf.get('ib')
accounts = section.get('accounts')
if not accounts:
raise ValueError(
'brokers.toml -> `ib.accounts` must be defined\n'
f'location: {path}'
)
names = list(accounts.keys())
accts = section['accounts'] = bidict(accounts)
log.info(
f'brokers.toml defines {len(accts)} accounts: '
f'{pformat(names)}'
)
if section is None: if section is None:
log.warning(f'No config section found for ib in {path}') log.warning(f'No config section found for ib in {path}')
return {} return {}
@ -1122,7 +836,6 @@ async def load_aio_clients(
# retry a few times to get the client going.. # retry a few times to get the client going..
connect_retries: int = 3, connect_retries: int = 3,
connect_timeout: float = 0.5, connect_timeout: float = 0.5,
disconnect_on_exit: bool = True,
) -> dict[str, Client]: ) -> dict[str, Client]:
''' '''
@ -1195,12 +908,6 @@ async def load_aio_clients(
# careful. # careful.
timeout=connect_timeout, timeout=connect_timeout,
) )
# create and cache client
client = Client(ib)
# update all actor-global caches
log.info(f"Caching client for {sockaddr}")
_client_cache[sockaddr] = client
break break
except ( except (
@ -1224,9 +931,21 @@ async def load_aio_clients(
log.warning( log.warning(
f'Failed to connect on {port} for {i} time, retrying...') f'Failed to connect on {port} for {i} time, retrying...')
# create and cache client
client = Client(ib)
# Pre-collect all accounts available for this # Pre-collect all accounts available for this
# connection and map account names to this client # connection and map account names to this client
# instance. # instance.
pps = ib.positions()
if pps:
for pp in pps:
accounts_found[
accounts_def.inverse[pp.account]
] = client
# if there are accounts without positions we should still
# register them for this client
for value in ib.accountValues(): for value in ib.accountValues():
acct_number = value.account acct_number = value.account
@ -1247,6 +966,10 @@ async def load_aio_clients(
f'{pformat(accounts_found)}' f'{pformat(accounts_found)}'
) )
# update all actor-global caches
log.info(f"Caching client for {sockaddr}")
_client_cache[sockaddr] = client
# XXX: why aren't we just updating this directy above # XXX: why aren't we just updating this directy above
# instead of using the intermediary `accounts_found`? # instead of using the intermediary `accounts_found`?
_accounts2clients.update(accounts_found) _accounts2clients.update(accounts_found)
@ -1264,11 +987,10 @@ async def load_aio_clients(
finally: finally:
# TODO: for re-scans we'll want to not teardown clients which # TODO: for re-scans we'll want to not teardown clients which
# are up and stable right? # are up and stable right?
if disconnect_on_exit:
for acct, client in _accounts2clients.items(): for acct, client in _accounts2clients.items():
log.info(f'Disconnecting {acct}@{client}') log.info(f'Disconnecting {acct}@{client}')
client.ib.disconnect() client.ib.disconnect()
_client_cache.pop((host, port), None) _client_cache.pop((host, port))
async def load_clients_for_trio( async def load_clients_for_trio(
@ -1297,6 +1019,9 @@ async def load_clients_for_trio(
await asyncio.sleep(float('inf')) await asyncio.sleep(float('inf'))
_proxies: dict[str, MethodProxy] = {}
@acm @acm
async def open_client_proxies() -> tuple[ async def open_client_proxies() -> tuple[
dict[str, MethodProxy], dict[str, MethodProxy],
@ -1304,6 +1029,7 @@ async def open_client_proxies() -> tuple[
]: ]:
async with ( async with (
tractor.trionics.maybe_open_context( tractor.trionics.maybe_open_context(
# acm_func=open_client_proxies,
acm_func=tractor.to_asyncio.open_channel_from, acm_func=tractor.to_asyncio.open_channel_from,
kwargs={'target': load_clients_for_trio}, kwargs={'target': load_clients_for_trio},
@ -1318,14 +1044,13 @@ async def open_client_proxies() -> tuple[
if cache_hit: if cache_hit:
log.info(f'Re-using cached clients: {clients}') log.info(f'Re-using cached clients: {clients}')
proxies = {}
for acct_name, client in clients.items(): for acct_name, client in clients.items():
proxy = await stack.enter_async_context( proxy = await stack.enter_async_context(
open_client_proxy(client), open_client_proxy(client),
) )
proxies[acct_name] = proxy _proxies[acct_name] = proxy
yield proxies, clients yield _proxies, clients
def get_preferred_data_client( def get_preferred_data_client(
@ -1474,13 +1199,11 @@ async def open_client_proxy(
event_table = {} event_table = {}
async with ( async with (
to_asyncio.open_channel_from( to_asyncio.open_channel_from(
open_aio_client_method_relay, open_aio_client_method_relay,
client=client, client=client,
event_consumers=event_table, event_consumers=event_table,
) as (first, chan), ) as (first, chan),
trio.open_nursery() as relay_n, trio.open_nursery() as relay_n,
): ):

File diff suppressed because it is too large Load Diff

View File

@ -22,7 +22,6 @@ import asyncio
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from dataclasses import asdict from dataclasses import asdict
from datetime import datetime from datetime import datetime
from functools import partial
from math import isnan from math import isnan
import time import time
from typing import ( from typing import (
@ -39,14 +38,10 @@ import tractor
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
from .._util import ( from piker.data._sharedmem import ShmArray
NoData, from .._util import SymbolNotFound, NoData
DataUnavailable,
SymbolNotFound,
)
from .api import ( from .api import (
# _adhoc_futes_set, _adhoc_futes_set,
con2fqsn,
log, log,
load_aio_clients, load_aio_clients,
ibis, ibis,
@ -107,7 +102,7 @@ async def open_data_client() -> MethodProxy:
@acm @acm
async def open_history_client( async def open_history_client(
fqsn: str, symbol: str,
) -> tuple[Callable, int]: ) -> tuple[Callable, int]:
''' '''
@ -115,75 +110,26 @@ async def open_history_client(
that takes in ``pendulum.datetime`` and returns ``numpy`` arrays. that takes in ``pendulum.datetime`` and returns ``numpy`` arrays.
''' '''
# TODO:
# - add logic to handle tradable hours and only grab
# valid bars in the range?
# - we want to avoid overrunning the underlying shm array buffer and
# we should probably calc the number of calls to make depending on
# that until we have the `marketstore` daemon in place in which case
# the shm size will be driven by user config and available sys
# memory.
async with open_data_client() as proxy: async with open_data_client() as proxy:
max_timeout: float = 2.
mean: float = 0
count: int = 0
head_dt: None | datetime = None
if (
# fx cons seem to not provide this endpoint?
'idealpro' not in fqsn
):
try:
head_dt = await proxy.get_head_time(fqsn=fqsn)
except RequestError:
head_dt = None
async def get_hist( async def get_hist(
timeframe: float,
end_dt: Optional[datetime] = None, end_dt: Optional[datetime] = None,
start_dt: Optional[datetime] = None, start_dt: Optional[datetime] = None,
) -> tuple[np.ndarray, str]: ) -> tuple[np.ndarray, str]:
nonlocal max_timeout, mean, count
query_start = time.time() out, fails = await get_bars(proxy, symbol, end_dt=end_dt)
out, timedout = await get_bars(
proxy,
fqsn,
timeframe,
end_dt=end_dt,
)
latency = time.time() - query_start
if (
not timedout
# and latency <= max_timeout
):
count += 1
mean += latency / count
print(
f'HISTORY FRAME QUERY LATENCY: {latency}\n'
f'mean: {mean}'
)
if ( # TODO: add logic here to handle tradable hours and only grab
out is None # valid bars in the range
): if out is None:
# could be trying to retreive bars over weekend # could be trying to retreive bars over weekend
log.error(f"Can't grab bars starting at {end_dt}!?!?") log.error(f"Can't grab bars starting at {end_dt}!?!?")
raise NoData( raise NoData(
f'{end_dt}', f'{end_dt}',
# frame_size=2000, frame_size=2000,
) )
if (
end_dt
and head_dt
and end_dt <= head_dt
):
raise DataUnavailable(f'First timestamp is {head_dt}')
bars, bars_array, first_dt, last_dt = out bars, bars_array, first_dt, last_dt = out
# volume cleaning since there's -ve entries, # volume cleaning since there's -ve entries,
@ -198,7 +144,7 @@ async def open_history_client(
# quite sure why.. needs some tinkering and probably # quite sure why.. needs some tinkering and probably
# a lookthrough of the ``ib_insync`` machinery, for eg. maybe # a lookthrough of the ``ib_insync`` machinery, for eg. maybe
# we have to do the batch queries on the `asyncio` side? # we have to do the batch queries on the `asyncio` side?
yield get_hist, {'erlangs': 1, 'rate': 3} yield get_hist, {'erlangs': 1, 'rate': 6}
_pacing: str = ( _pacing: str = (
@ -207,19 +153,96 @@ _pacing: str = (
) )
async def wait_on_data_reset( async def get_bars(
proxy: MethodProxy, proxy: MethodProxy,
reset_type: str = 'data', fqsn: str,
timeout: float = 16,
task_status: TaskStatus[ # blank to start which tells ib to look up the latest datum
tuple[ end_dt: str = '',
trio.CancelScope,
trio.Event,
]
] = trio.TASK_STATUS_IGNORED,
) -> bool:
) -> (dict, np.ndarray):
'''
Retrieve historical data from a ``trio``-side task using
a ``MethodProxy``.
'''
fails = 0
bars: Optional[list] = None
first_dt: datetime = None
last_dt: datetime = None
if end_dt:
last_dt = pendulum.from_timestamp(end_dt.timestamp())
for _ in range(10):
try:
out = await proxy.bars(
fqsn=fqsn,
end_dt=end_dt,
)
if out:
bars, bars_array = out
else:
await tractor.breakpoint()
if bars_array is None:
raise SymbolNotFound(fqsn)
first_dt = pendulum.from_timestamp(
bars[0].date.timestamp())
last_dt = pendulum.from_timestamp(
bars[-1].date.timestamp())
time = bars_array['time']
assert time[-1] == last_dt.timestamp()
assert time[0] == first_dt.timestamp()
log.info(
f'{len(bars)} bars retrieved for {first_dt} -> {last_dt}'
)
return (bars, bars_array, first_dt, last_dt), fails
except RequestError as err:
msg = err.message
# why do we always need to rebind this?
# _err = err
if 'No market data permissions for' in msg:
# TODO: signalling for no permissions searches
raise NoData(
f'Symbol: {fqsn}',
)
elif (
err.code == 162
and 'HMDS query returned no data' in err.message
):
# XXX: this is now done in the storage mgmt layer
# and we shouldn't implicitly decrement the frame dt
# index since the upper layer may be doing so
# concurrently and we don't want to be delivering frames
# that weren't asked for.
log.warning(
f'NO DATA found ending @ {end_dt}\n'
)
# try to decrement start point and look further back
# end_dt = last_dt = last_dt.subtract(seconds=2000)
raise NoData(
f'Symbol: {fqsn}',
frame_size=2000,
)
elif _pacing in msg:
log.warning(
'History throttle rate reached!\n'
'Resetting farms with `ctrl-alt-f` hack\n'
)
# TODO: we might have to put a task lock around this # TODO: we might have to put a task lock around this
# method.. # method..
hist_ev = proxy.status_event( hist_ev = proxy.status_event(
@ -235,259 +258,144 @@ async def wait_on_data_reset(
# live_ev = proxy.status_event( # live_ev = proxy.status_event(
# 'Market data farm connection is OK:usfuture' # 'Market data farm connection is OK:usfuture'
# ) # )
# try to wait on the reset event(s) to arrive, a timeout # try to wait on the reset event(s) to arrive, a timeout
# will trigger a retry up to 6 times (for now). # will trigger a retry up to 6 times (for now).
tries: int = 2
timeout: float = 10
done = trio.Event() # try 3 time with a data reset then fail over to
with trio.move_on_after(timeout) as cs: # a connection reset.
for i in range(1, tries):
task_status.started((cs, done))
log.warning('Sending DATA RESET request') log.warning('Sending DATA RESET request')
res = await data_reset_hack(reset_type=reset_type) await data_reset_hack(reset_type='data')
if not res: with trio.move_on_after(timeout) as cs:
log.warning(
'NO VNC DETECTED!\n'
'Manually press ctrl-alt-f on your IB java app'
)
done.set()
return False
# TODO: not sure if waiting on other events
# is all that useful here or not.
# - in theory you could wait on one of the ones above first
# to verify the reset request was sent?
# - we need the same for real-time quote feeds which can
# sometimes flake out and stop delivering..
for name, ev in [ for name, ev in [
# TODO: not sure if waiting on other events
# is all that useful here or not. in theory
# you could wait on one of the ones above
# first to verify the reset request was
# sent?
('history', hist_ev), ('history', hist_ev),
]: ]:
await ev.wait() await ev.wait()
log.info(f"{name} DATA RESET") log.info(f"{name} DATA RESET")
done.set() break
return True
if cs.cancel_called: if cs.cancelled_caught:
fails += 1
log.warning( log.warning(
'Data reset task canceled?' f'Data reset {name} timeout, retrying {i}.'
) )
done.set()
return False
_data_resetter_task: trio.Task | None = None
async def get_bars(
proxy: MethodProxy,
fqsn: str,
timeframe: int,
# blank to start which tells ib to look up the latest datum
end_dt: str = '',
# TODO: make this more dynamic based on measured frame rx latency?
# how long before we trigger a feed reset (seconds)
feed_reset_timeout: float = 3,
# how many days to subtract before giving up on further
# history queries for instrument, presuming that most don't
# not trade for a week XD
max_nodatas: int = 6,
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
) -> (dict, np.ndarray):
'''
Retrieve historical data from a ``trio``-side task using
a ``MethodProxy``.
'''
global _data_resetter_task
nodatas_count: int = 0
data_cs: trio.CancelScope | None = None
result: tuple[
ibis.objects.BarDataList,
np.ndarray,
datetime,
datetime,
] | None = None
result_ready = trio.Event()
async def query():
nonlocal result, data_cs, end_dt, nodatas_count
while True:
try:
out = await proxy.bars(
fqsn=fqsn,
end_dt=end_dt,
sample_period_s=timeframe,
# ideally we cancel the request just before we
# cancel on the ``trio``-side and trigger a data
# reset hack.. the problem is there's no way (with
# current impl) to detect a cancel case.
# timeout=timeout,
)
if out is None:
raise NoData(f'{end_dt}')
bars, bars_array, dt_duration = out
if not bars:
log.warning(
f'History is blank for {dt_duration} from {end_dt}'
)
end_dt -= dt_duration
continue continue
else:
if bars_array is None: log.warning('Sending CONNECTION RESET')
raise SymbolNotFound(fqsn) await data_reset_hack(reset_type='connection')
first_dt = pendulum.from_timestamp( with trio.move_on_after(timeout) as cs:
bars[0].date.timestamp()) for name, ev in [
# TODO: not sure if waiting on other events
# is all that useful here or not. in theory
# you could wait on one of the ones above
# first to verify the reset request was
# sent?
('history', hist_ev),
]:
await ev.wait()
log.info(f"{name} DATA RESET")
last_dt = pendulum.from_timestamp( if cs.cancelled_caught:
bars[-1].date.timestamp()) fails += 1
log.warning('Data CONNECTION RESET timeout!?')
time = bars_array['time']
assert time[-1] == last_dt.timestamp()
assert time[0] == first_dt.timestamp()
log.info(
f'{len(bars)} bars retrieved {first_dt} -> {last_dt}'
)
if data_cs:
data_cs.cancel()
result = (bars, bars_array, first_dt, last_dt)
# signal data reset loop parent task
result_ready.set()
return result
except RequestError as err:
msg = err.message
if 'No market data permissions for' in msg:
# TODO: signalling for no permissions searches
raise NoData(
f'Symbol: {fqsn}',
)
elif err.code == 162:
if (
'HMDS query returned no data' in msg
):
# XXX: this is now done in the storage mgmt
# layer and we shouldn't implicitly decrement
# the frame dt index since the upper layer may
# be doing so concurrently and we don't want to
# be delivering frames that weren't asked for.
# try to decrement start point and look further back
# end_dt = end_dt.subtract(seconds=2000)
logmsg = "SUBTRACTING DAY from DT index"
if end_dt is not None:
end_dt = end_dt.subtract(days=1)
elif end_dt is None:
end_dt = pendulum.now().subtract(days=1)
log.warning(
f'NO DATA found ending @ {end_dt}\n'
+ logmsg
)
if nodatas_count >= max_nodatas:
raise DataUnavailable(
f'Presuming {fqsn} has no further history '
f'after {max_nodatas} tries..'
)
nodatas_count += 1
continue
elif 'API historical data query cancelled' in err.message:
log.warning(
'Query cancelled by IB (:eyeroll:):\n'
f'{err.message}'
)
continue
elif (
'Trading TWS session is connected from a different IP'
in err.message
):
log.warning("ignoring ip address warning")
continue
# XXX: more or less same as above timeout case
elif _pacing in msg:
log.warning(
'History throttle rate reached!\n'
'Resetting farms with `ctrl-alt-f` hack\n'
)
# cancel any existing reset task
if data_cs:
data_cs.cancel()
# spawn new data reset task
data_cs, reset_done = await nurse.start(
partial(
wait_on_data_reset,
proxy,
timeout=float('inf'),
reset_type='connection'
)
)
continue
else: else:
raise raise
# TODO: make this global across all history task/requests return None, None
# such that simultaneous symbol queries don't try data resettingn # else: # throttle wasn't fixed so error out immediately
# too fast.. # raise _err
unset_resetter: bool = False
async with trio.open_nursery() as nurse:
# start history request that we allow
# to run indefinitely until a result is acquired
nurse.start_soon(query)
# start history reset loop which waits up to the timeout async def backfill_bars(
# for a result before triggering a data feed reset.
while not result_ready.is_set():
with trio.move_on_after(feed_reset_timeout): fqsn: str,
await result_ready.wait() shm: ShmArray, # type: ignore # noqa
break
if _data_resetter_task: # TODO: we want to avoid overrunning the underlying shm array buffer
# don't double invoke the reset hack if another # and we should probably calc the number of calls to make depending
# requester task already has it covered. # on that until we have the `marketstore` daemon in place in which
# case the shm size will be driven by user config and available sys
# memory.
count: int = 16,
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
) -> None:
'''
Fill historical bars into shared mem / storage afap.
TODO: avoid pacing constraints:
https://github.com/pikers/piker/issues/128
'''
# last_dt1 = None
last_dt = None
with trio.CancelScope() as cs:
async with open_data_client() as proxy:
out, fails = await get_bars(proxy, fqsn)
if out is None:
raise RuntimeError("Could not pull current history?!")
(first_bars, bars_array, first_dt, last_dt) = out
vlm = bars_array['volume']
vlm[vlm < 0] = 0
last_dt = first_dt
# write historical data to buffer
shm.push(bars_array)
task_status.started(cs)
i = 0
while i < count:
out, fails = await get_bars(proxy, fqsn, end_dt=first_dt)
if out is None:
# could be trying to retreive bars over weekend
# TODO: add logic here to handle tradable hours and
# only grab valid bars in the range
log.error(f"Can't grab bars starting at {first_dt}!?!?")
# XXX: get_bars() should internally decrement dt by
# 2k seconds and try again.
continue continue
else:
_data_resetter_task = trio.lowlevel.current_task()
unset_resetter = True
# spawn new data reset task (first_bars, bars_array, first_dt, last_dt) = out
data_cs, reset_done = await nurse.start( # last_dt1 = last_dt
partial( # last_dt = first_dt
wait_on_data_reset,
proxy,
timeout=float('inf'),
)
)
# sync wait on reset to complete
await reset_done.wait()
_data_resetter_task = None if unset_resetter else _data_resetter_task # volume cleaning since there's -ve entries,
return result, data_cs is not None # wood luv to know what crookery that is..
vlm = bars_array['volume']
vlm[vlm < 0] = 0
# TODO we should probably dig into forums to see what peeps
# think this data "means" and then use it as an indicator of
# sorts? dinkus has mentioned that $vlms for the day dont'
# match other platforms nor the summary stat tws shows in
# the monitor - it's probably worth investigating.
shm.push(bars_array, prepend=True)
i += 1
asset_type_map = { asset_type_map = {
@ -505,7 +413,6 @@ asset_type_map = {
'WAR': 'warrant', 'WAR': 'warrant',
'IOPT': 'warran', 'IOPT': 'warran',
'BAG': 'bag', 'BAG': 'bag',
'CRYPTO': 'crypto', # bc it's diff then fiat?
# 'NEWS': 'news', # 'NEWS': 'news',
} }
@ -545,9 +452,7 @@ async def _setup_quote_stream(
to_trio.send_nowait(None) to_trio.send_nowait(None)
async with load_aio_clients( async with load_aio_clients() as accts2clients:
disconnect_on_exit=False,
) as accts2clients:
caccount_name, client = get_preferred_data_client(accts2clients) caccount_name, client = get_preferred_data_client(accts2clients)
contract = contract or (await client.find_contract(symbol)) contract = contract or (await client.find_contract(symbol))
ticker: Ticker = client.ib.reqMktData(contract, ','.join(opts)) ticker: Ticker = client.ib.reqMktData(contract, ','.join(opts))
@ -593,11 +498,10 @@ async def _setup_quote_stream(
# Manually do the dereg ourselves. # Manually do the dereg ourselves.
teardown() teardown()
except trio.WouldBlock: except trio.WouldBlock:
# log.warning( log.warning(
# f'channel is blocking symbol feed for {symbol}?' f'channel is blocking symbol feed for {symbol}?'
# f'\n{to_trio.statistics}' f'\n{to_trio.statistics}'
# ) )
pass
# except trio.WouldBlock: # except trio.WouldBlock:
# # for slow debugging purposes to avoid clobbering prompt # # for slow debugging purposes to avoid clobbering prompt
@ -627,8 +531,7 @@ async def open_aio_quote_stream(
from_aio = _quote_streams.get(symbol) from_aio = _quote_streams.get(symbol)
if from_aio: if from_aio:
# if we already have a cached feed deliver a rx side clone # if we already have a cached feed deliver a rx side clone to consumer
# to consumer
async with broadcast_receiver( async with broadcast_receiver(
from_aio, from_aio,
2**6, 2**6,
@ -650,17 +553,38 @@ async def open_aio_quote_stream(
# TODO: cython/mypyc/numba this! # TODO: cython/mypyc/numba this!
# or we can at least cache a majority of the values
# except for the ones we expect to change?..
def normalize( def normalize(
ticker: Ticker, ticker: Ticker,
calc_price: bool = False calc_price: bool = False
) -> dict: ) -> dict:
# should be real volume for this contract by default
calc_price = False
# check for special contract types # check for special contract types
con = ticker.contract con = ticker.contract
fqsn, calc_price = con2fqsn(con) if type(con) in (
ibis.Commodity,
ibis.Forex,
):
# commodities and forex don't have an exchange name and
# no real volume so we have to calculate the price
suffix = con.secType
# no real volume on this tract
calc_price = True
else:
suffix = con.primaryExchange
if not suffix:
suffix = con.exchange
# append a `.<suffix>` to the returned symbol
# key for derivatives that normally is the expiry
# date key.
expiry = con.lastTradeDateOrContractMonth
if expiry:
suffix += f'.{expiry}'
# convert named tuples to dicts so we send usable keys # convert named tuples to dicts so we send usable keys
new_ticks = [] new_ticks = []
@ -692,7 +616,9 @@ def normalize(
# generate fqsn with possible specialized suffix # generate fqsn with possible specialized suffix
# for derivatives, note the lowercase. # for derivatives, note the lowercase.
data['symbol'] = data['fqsn'] = fqsn data['symbol'] = data['fqsn'] = '.'.join(
(con.symbol, suffix)
).lower()
# convert named tuples to dicts for transport # convert named tuples to dicts for transport
tbts = data.get('tickByTicks') tbts = data.get('tickByTicks')
@ -757,13 +683,6 @@ async def stream_quotes(
# TODO: more consistent field translation # TODO: more consistent field translation
atype = syminfo['asset_type'] = asset_type_map[syminfo['secType']] atype = syminfo['asset_type'] = asset_type_map[syminfo['secType']]
if atype in {
'forex',
'index',
'commodity',
}:
syminfo['no_vlm'] = True
# for stocks it seems TWS reports too small a tick size # for stocks it seems TWS reports too small a tick size
# such that you can't submit orders with that granularity? # such that you can't submit orders with that granularity?
min_tick = 0.01 if atype == 'stock' else 0 min_tick = 0.01 if atype == 'stock' else 0
@ -790,9 +709,9 @@ async def stream_quotes(
}, },
} }
return init_msgs, syminfo return init_msgs
init_msgs, syminfo = mk_init_msgs() init_msgs = mk_init_msgs()
# TODO: we should instead spawn a task that waits on a feed to start # TODO: we should instead spawn a task that waits on a feed to start
# and let it wait indefinitely..instead of this hard coded stuff. # and let it wait indefinitely..instead of this hard coded stuff.
@ -801,14 +720,7 @@ async def stream_quotes(
# it might be outside regular trading hours so see if we can at # it might be outside regular trading hours so see if we can at
# least grab history. # least grab history.
if ( if isnan(first_ticker.last):
isnan(first_ticker.last)
and type(first_ticker.contract) not in (
ibis.Commodity,
ibis.Forex,
ibis.Crypto,
)
):
task_status.started((init_msgs, first_quote)) task_status.started((init_msgs, first_quote))
# it's not really live but this will unblock # it's not really live but this will unblock
@ -819,77 +731,41 @@ async def stream_quotes(
await trio.sleep_forever() await trio.sleep_forever()
return # we never expect feed to come up? return # we never expect feed to come up?
cs: Optional[trio.CancelScope] = None async with open_aio_quote_stream(
startup: bool = True
while (
startup
or cs.cancel_called
):
with trio.CancelScope() as cs:
async with (
trio.open_nursery() as nurse,
open_aio_quote_stream(
symbol=sym, symbol=sym,
contract=con, contract=con,
) as stream, ) as stream:
):
# ugh, clear ticks since we've consumed them # ugh, clear ticks since we've consumed them
# (ahem, ib_insync is stateful trash) # (ahem, ib_insync is stateful trash)
first_ticker.ticks = [] first_ticker.ticks = []
# only on first entry at feed boot up
if startup:
startup = False
task_status.started((init_msgs, first_quote)) task_status.started((init_msgs, first_quote))
# start a stream restarter task which monitors the
# data feed event.
async def reset_on_feed():
# TODO: this seems to be suppressed from the
# traceback in ``tractor``?
# assert 0
rt_ev = proxy.status_event(
'Market data farm connection is OK:usfarm'
)
await rt_ev.wait()
cs.cancel() # cancel called should now be set
nurse.start_soon(reset_on_feed)
async with aclosing(stream): async with aclosing(stream):
                if syminfo.get('no_vlm', False):
                    # generally speaking these feeds don't
                    # include vlm data.
                    atype = syminfo['asset_type']
                    log.info(
                        f'No-vlm {sym}@{atype}, skipping quote poll'
                    )
                else:
                    # wait for real volume on feed (trading might be
                    # closed)

                if type(first_ticker.contract) not in (
                    ibis.Commodity,
                    ibis.Forex
                ):
                    # wait for real volume on feed (trading might be closed)
while True: while True:
ticker = await stream.receive() ticker = await stream.receive()
# for a real volume contract we wait for # the first
# the first "real" trade to take place # "real" trade to take place
if ( if (
# not calc_price # not calc_price
# and not ticker.rtTime # and not ticker.rtTime
not ticker.rtTime not ticker.rtTime
): ):
# spin consuming tickers until we # spin consuming tickers until we get a real
# get a real market datum # market datum
log.debug(f"New unsent ticker: {ticker}") log.debug(f"New unsent ticker: {ticker}")
continue continue
else: else:
log.debug("Received first volume tick") log.debug("Received first real volume tick")
# ugh, clear ticks since we've # ugh, clear ticks since we've consumed them
# consumed them (ahem, ib_insync is # (ahem, ib_insync is truly stateful trash)
# truly stateful trash)
ticker.ticks = [] ticker.ticks = []
# XXX: this works because we don't use # XXX: this works because we don't use
@ -905,9 +781,7 @@ async def stream_quotes(
# last = time.time() # last = time.time()
async for ticker in stream: async for ticker in stream:
quote = normalize(ticker) quote = normalize(ticker)
fqsn = quote['fqsn'] await send_chan.send({quote['fqsn']: quote})
# print(f'sending {fqsn}:\n{quote}')
await send_chan.send({fqsn: quote})
# ugh, clear ticks since we've consumed them # ugh, clear ticks since we've consumed them
ticker.ticks = [] ticker.ticks = []
@ -931,9 +805,6 @@ async def data_reset_hack(
successful. successful.
- other OS support? - other OS support?
- integration with ``ib-gw`` run in docker + Xorg? - integration with ``ib-gw`` run in docker + Xorg?
- is it possible to offer a local server that can be accessed by
a client? Would sure be handy for running native java blobs
that need to be wrangled.
''' '''
@ -964,10 +835,7 @@ async def data_reset_hack(
client.mouse.click() client.mouse.click()
client.keyboard.press('Ctrl', 'Alt', key) # keys are stacked client.keyboard.press('Ctrl', 'Alt', key) # keys are stacked
try:
await tractor.to_asyncio.run_task(vnc_click_hack) await tractor.to_asyncio.run_task(vnc_click_hack)
except OSError:
return False
# we don't really need the ``xdotool`` approach any more B) # we don't really need the ``xdotool`` approach any more B)
return True return True
@ -982,30 +850,14 @@ async def open_symbol_search(
# TODO: load user defined symbol set locally for fast search? # TODO: load user defined symbol set locally for fast search?
await ctx.started({}) await ctx.started({})
async with ( async with open_data_client() as proxy:
open_client_proxies() as (proxies, clients),
open_data_client() as data_proxy,
):
async with ctx.open_stream() as stream: async with ctx.open_stream() as stream:
# select a non-history client for symbol search to lighten
# the load in the main data node.
proxy = data_proxy
for name, proxy in proxies.items():
if proxy is data_proxy:
continue
break
ib_client = proxy._aio_ns.ib
log.info(f'Using {ib_client} for symbol search')
last = time.time() last = time.time()
async for pattern in stream:
log.info(f'received {pattern}')
now = time.time()
# this causes tractor hang... async for pattern in stream:
# assert 0 log.debug(f'received {pattern}')
now = time.time()
assert pattern, 'IB can not accept blank search pattern' assert pattern, 'IB can not accept blank search pattern'
@ -1019,14 +871,7 @@ async def open_symbol_search(
except trio.WouldBlock: except trio.WouldBlock:
pass pass
if ( if not pattern or pattern.isspace():
not pattern
or pattern.isspace()
# XXX: not sure if this is a bad assumption but it
# seems to make search snappier?
or len(pattern) < 1
):
log.warning('empty pattern received, skipping..') log.warning('empty pattern received, skipping..')
# TODO: *BUG* if nothing is returned here the client # TODO: *BUG* if nothing is returned here the client
@ -1041,7 +886,7 @@ async def open_symbol_search(
continue continue
log.info(f'searching for {pattern}') log.debug(f'searching for {pattern}')
last = time.time() last = time.time()
@ -1050,16 +895,8 @@ async def open_symbol_search(
stock_results = [] stock_results = []
async def stash_results(target: Awaitable[list]): async def stash_results(target: Awaitable[list]):
try: stock_results.extend(await target)
results = await target
except tractor.trionics.Lagged:
print("IB SYM-SEARCH OVERRUN?!?")
return
stock_results.extend(results)
for i in range(10):
with trio.move_on_after(3) as cs:
async with trio.open_nursery() as sn: async with trio.open_nursery() as sn:
sn.start_soon( sn.start_soon(
stash_results, stash_results,
@ -1072,26 +909,17 @@ async def open_symbol_search(
# trigger async request # trigger async request
await trio.sleep(0) await trio.sleep(0)
if cs.cancelled_caught: # match against our ad-hoc set immediately
log.warning( adhoc_matches = fuzzy.extractBests(
f'Search timeout? {proxy._aio_ns.ib.client}' pattern,
list(_adhoc_futes_set),
score_cutoff=90,
) )
continue log.info(f'fuzzy matched adhocs: {adhoc_matches}')
else: adhoc_match_results = {}
break if adhoc_matches:
# TODO: do we need to pull contract details?
# # match against our ad-hoc set immediately adhoc_match_results = {i[0]: {} for i in adhoc_matches}
# adhoc_matches = fuzzy.extractBests(
# pattern,
# list(_adhoc_futes_set),
# score_cutoff=90,
# )
# log.info(f'fuzzy matched adhocs: {adhoc_matches}')
# adhoc_match_results = {}
# if adhoc_matches:
# # TODO: do we need to pull contract details?
# adhoc_match_results = {i[0]: {} for i in
# adhoc_matches}
log.debug(f'fuzzy matching stocks {stock_results}') log.debug(f'fuzzy matching stocks {stock_results}')
stock_matches = fuzzy.extractBests( stock_matches = fuzzy.extractBests(
@ -1100,8 +928,7 @@ async def open_symbol_search(
score_cutoff=50, score_cutoff=50,
) )
# matches = adhoc_match_results | { matches = adhoc_match_results | {
matches = {
item[0]: {} for item in stock_matches item[0]: {} for item in stock_matches
} }
# TODO: we used to deliver contract details # TODO: we used to deliver contract details

File diff suppressed because it is too large

@ -1,64 +0,0 @@
``kraken`` backend
------------------
though they don't have the most liquidity of all the cexes they sure are
accommodating to those of us who appreciate a little ``xmr``.
status
******
current support is *production grade* and both real-time data and order
management should be correct and fast. this backend is used by core devs
for live trading.
config
******
In order to get order mode support your ``brokers.toml``
needs to have something like the following:
.. code:: toml
[kraken]
accounts.spot = 'spot'
key_descr = "spot"
api_key = "69696969696969696696969696969696969696969696969696969696"
secret = "BOOBSBOOBSBOOBSBOOBSBOOBSSMBZ69696969696969669969696969696"
If everything works correctly you should see any current positions
loaded in the pps pane on chart load and you should also be able to
check your trade records in the file::
<pikerk_conf_dir>/ledgers/trades_kraken_spot.toml
An example ledger file will have entries written verbatim from the
trade events schema:
.. code:: toml
[TFJBKK-SMBZS-VJ4UWS]
ordertxid = "SMBZSA-7CNQU-3HWLNJ"
postxid = "SMBZSE-M7IF5-CFI7LT"
pair = "XXMRZEUR"
time = 1655691993.4133966
type = "buy"
ordertype = "limit"
price = "103.97000000"
cost = "499.99999977"
fee = "0.80000000"
vol = "4.80907954"
margin = "0.00000000"
misc = ""
your ``pps.toml`` file will have position entries like,
.. code:: toml
[kraken.spot."xmreur.kraken"]
size = 4.80907954
ppu = 103.97000000
bsuid = "XXMRZEUR"
clears = [
{ tid = "TFJBKK-SMBZS-VJ4UWS", cost = 0.8, price = 103.97, size = 4.80907954, dt = "2022-05-20T02:26:33.413397+00:00" },
]
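
Since both the ledger and ``pps.toml`` are plain TOML you can audit
them with stock tooling. A quick sketch (editor example, assuming
python 3.11+ for the stdlib ``tomllib`` and the default linux config
dir) which totals fees and volume from the ledger:

.. code:: python

    import tomllib
    from pathlib import Path

    path = Path('~/.config/piker/ledgers/trades_kraken_spot.toml')
    with open(path.expanduser(), 'rb') as f:
        ledger = tomllib.load(f)

    # each top level entry is keyed by the kraken trade id
    fees = sum(float(entry['fee']) for entry in ledger.values())
    vol = sum(float(entry['vol']) for entry in ledger.values())
    print(f'{len(ledger)} fills, fees: {fees}, volume: {vol}')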


@ -1,61 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Kraken backend.
Sub-modules within break into the core functionalities:
- ``broker.py`` part for orders / trading endpoints
- ``feed.py`` for real-time data feed endpoints
- ``api.py`` for the core API machinery which is ``trio``-ized
wrapping around the kraken web API.
'''
from piker.log import get_logger
log = get_logger(__name__)
from .api import (
get_client,
)
from .feed import (
open_history_client,
open_symbol_search,
stream_quotes,
)
from .broker import (
trades_dialogue,
norm_trade_records,
)
__all__ = [
'get_client',
'trades_dialogue',
'open_history_client',
'open_symbol_search',
'stream_quotes',
'norm_trade_records',
]
# tractor RPC enable arg
__enable_modules__: list[str] = [
'api',
'feed',
'broker',
]


@ -1,536 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Kraken web API wrapping.
'''
from contextlib import asynccontextmanager as acm
from datetime import datetime
import itertools
from typing import (
Any,
Optional,
Union,
)
import time
from bidict import bidict
import pendulum
import asks
from fuzzywuzzy import process as fuzzy
import numpy as np
import urllib.parse
import hashlib
import hmac
import base64
import trio
from piker import config
from piker.brokers._util import (
resproc,
SymbolNotFound,
BrokerError,
DataThrottle,
)
from piker.pp import Transaction
from . import log
# <uri>/<version>/
_url = 'https://api.kraken.com/0'
# Broker specific ohlc schema which includes a vwap field
_ohlc_dtype = [
('index', int),
('time', int),
('open', float),
('high', float),
('low', float),
('close', float),
('volume', float),
('count', int),
('bar_wap', float),
]
# UI components allow this to be declared such that additional
# (historical) fields can be exposed.
ohlc_dtype = np.dtype(_ohlc_dtype)
_show_wap_in_history = True
_symbol_info_translation: dict[str, str] = {
'tick_decimals': 'pair_decimals',
}
def get_config() -> dict[str, Any]:
conf, path = config.load()
section = conf.get('kraken')
if section is None:
log.warning(f'No config section found for kraken in {path}')
return {}
return section
def get_kraken_signature(
urlpath: str,
data: dict[str, Any],
secret: str
) -> str:
postdata = urllib.parse.urlencode(data)
encoded = (str(data['nonce']) + postdata).encode()
message = urlpath.encode() + hashlib.sha256(encoded).digest()
mac = hmac.new(base64.b64decode(secret), message, hashlib.sha512)
sigdigest = base64.b64encode(mac.digest())
return sigdigest.decode()
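
# XXX editor sketch (hypothetical, not part of this file): the signing
# recipe above end-to-end; the endpoint and key material are dummies.
def _example_signature() -> str:
    data = {
        'nonce': str(int(1000 * time.time())),
        'asset': 'XXBT',
    }
    return get_kraken_signature(
        urlpath='/0/private/WithdrawStatus',
        data=data,
        # a made up (but validly base64 encoded) secret
        secret=base64.b64encode(b'not-a-real-secret').decode(),
    )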
class InvalidKey(ValueError):
'''
EAPI:Invalid key
This error is returned when the API key used for the call is
either expired or disabled, please review the API key in your
Settings -> API tab of account management or generate a new one
and update your application.
'''
class Client:
# global symbol normalization table
_ntable: dict[str, str] = {}
_atable: bidict[str, str] = bidict()
def __init__(
self,
config: dict[str, str],
name: str = '',
api_key: str = '',
secret: str = ''
) -> None:
self._sesh = asks.Session(connections=4)
self._sesh.base_location = _url
self._sesh.headers.update({
'User-Agent':
'krakenex/2.1.0 (+https://github.com/veox/python3-krakenex)'
})
self.conf: dict[str, str] = config
self._pairs: list[str] = []
self._name = name
self._api_key = api_key
self._secret = secret
@property
def pairs(self) -> dict[str, Any]:
if self._pairs is None:
raise RuntimeError(
"Make sure to run `cache_symbols()` on startup!"
)
# retrieve and cache all symbols
return self._pairs
async def _public(
self,
method: str,
data: dict,
) -> dict[str, Any]:
resp = await self._sesh.post(
path=f'/public/{method}',
json=data,
timeout=float('inf')
)
return resproc(resp, log)
async def _private(
self,
method: str,
data: dict,
uri_path: str
) -> dict[str, Any]:
headers = {
'Content-Type':
'application/x-www-form-urlencoded',
'API-Key':
self._api_key,
'API-Sign':
get_kraken_signature(uri_path, data, self._secret)
}
resp = await self._sesh.post(
path=f'/private/{method}',
data=data,
headers=headers,
timeout=float('inf')
)
return resproc(resp, log)
async def endpoint(
self,
method: str,
data: dict[str, Any]
) -> dict[str, Any]:
uri_path = f'/0/private/{method}'
data['nonce'] = str(int(1000*time.time()))
return await self._private(method, data, uri_path)
async def get_balances(
self,
) -> dict[str, float]:
'''
Return the set of asset balances for this account
by symbol.
'''
resp = await self.endpoint(
'Balance',
{},
)
by_bsuid = resp['result']
return {
self._atable[sym].lower(): float(bal)
for sym, bal in by_bsuid.items()
}
async def get_assets(self) -> dict[str, dict]:
resp = await self._public('Assets', {})
return resp['result']
async def cache_assets(self) -> None:
assets = self.assets = await self.get_assets()
for bsuid, info in assets.items():
self._atable[bsuid] = info['altname']
async def get_trades(
self,
fetch_limit: int = 10,
) -> dict[str, Any]:
'''
Get the trades (aka cleared orders) history from the rest endpoint:
https://docs.kraken.com/rest/#operation/getTradeHistory
'''
ofs = 0
trades_by_id: dict[str, Any] = {}
for i in itertools.count():
if i >= fetch_limit:
break
# increment 'ofs' pagination offset
ofs = i*50
resp = await self.endpoint(
'TradesHistory',
{'ofs': ofs},
)
by_id = resp['result']['trades']
trades_by_id.update(by_id)
# we can get up to 50 results per query
if (
len(by_id) < 50
):
err = resp.get('error')
if err:
raise BrokerError(err)
# we know we received the max amount of
# trade results so there may be more history.
# catch the end of the trades
count = resp['result']['count']
break
# sanity check on update
assert count == len(trades_by_id.values())
return trades_by_id
async def get_xfers(
self,
asset: str,
src_asset: str = '',
) -> dict[str, Transaction]:
'''
Get asset balance transfer transactions.
Currently only withdrawals are supported.
'''
xfers: list[dict] = (await self.endpoint(
'WithdrawStatus',
{'asset': asset},
))['result']
# eg. resp schema:
# 'result': [{'method': 'Bitcoin', 'aclass': 'currency', 'asset':
# 'XXBT', 'refid': 'AGBJRMB-JHD2M4-NDI3NR', 'txid':
# 'b95d66d3bb6fd76cbccb93f7639f99a505cb20752c62ea0acc093a0e46547c44',
# 'info': 'bc1qc8enqjekwppmw3g80p56z5ns7ze3wraqk5rl9z',
# 'amount': '0.00300726', 'fee': '0.00001000', 'time':
# 1658347714, 'status': 'Success'}]}
trans: dict[str, Transaction] = {}
for entry in xfers:
# look up the normalized name
asset = self._atable[entry['asset']].lower()
# XXX: this is in the asset units (likely) so it isn't
# quite the same as a commissions cost necessarily..)
cost = float(entry['fee'])
tran = Transaction(
fqsn=asset + '.kraken',
tid=entry['txid'],
dt=pendulum.from_timestamp(entry['time']),
bsuid=f'{asset}{src_asset}',
size=-1*(
float(entry['amount'])
+
cost
),
# since this will be treated as a "sell" it
# shouldn't be needed to compute the be price.
price='NaN',
# XXX: see note above
cost=0,
)
trans[tran.tid] = tran
return trans
async def submit_limit(
self,
symbol: str,
price: float,
action: str,
size: float,
reqid: str = None,
validate: bool = False # set True test call without a real submission
) -> dict:
'''
Place an order and return integer request id provided by client.
'''
# Build common data dict for common keys from both endpoints
data = {
"pair": symbol,
"price": str(price),
"validate": validate
}
if reqid is None:
# Build order data for kraken api
data |= {
"ordertype": "limit",
"type": action,
"volume": str(size),
}
return await self.endpoint('AddOrder', data)
else:
# Edit order data for kraken api
data["txid"] = reqid
return await self.endpoint('EditOrder', data)
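
# XXX editor note (hypothetical usage): kraken's ``validate`` flag lets
# you dry-run a submission without clearing it, eg. (values lifted from
# the readme example):
#
#   await client.submit_limit(
#       symbol='XXMRZEUR',
#       price=103.97,
#       action='buy',
#       size=4.8,
#       validate=True,  # no real order is placed
#   )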
async def submit_cancel(
self,
reqid: str,
) -> dict:
'''
Send cancel request for order id ``reqid``.
'''
# txid is a transaction id given by kraken
return await self.endpoint('CancelOrder', {"txid": reqid})
async def symbol_info(
self,
pair: Optional[str] = None,
) -> dict[str, dict[str, str]]:
if pair is not None:
pairs = {'pair': pair}
else:
pairs = None # get all pairs
resp = await self._public('AssetPairs', pairs)
err = resp['error']
if err:
symbolname = pairs['pair'] if pair else None
raise SymbolNotFound(f'{symbolname}.kraken')
pairs = resp['result']
if pair is not None:
_, data = next(iter(pairs.items()))
return data
else:
return pairs
async def cache_symbols(
self,
) -> dict:
if not self._pairs:
self._pairs = await self.symbol_info()
ntable = {}
for restapikey, info in self._pairs.items():
ntable[restapikey] = ntable[info['wsname']] = info['altname']
self._ntable.update(ntable)
return self._pairs
async def search_symbols(
self,
pattern: str,
limit: int = None,
) -> dict[str, Any]:
if self._pairs is not None:
data = self._pairs
else:
data = await self.symbol_info()
matches = fuzzy.extractBests(
pattern,
data,
score_cutoff=50,
)
# repack in dict form
return {item[0]['altname']: item[0] for item in matches}
async def bars(
self,
symbol: str = 'XBTUSD',
# UTC 2017-07-02 12:53:20
since: Optional[Union[int, datetime]] = None,
count: int = 720, # <- max allowed per query
as_np: bool = True,
) -> dict:
if since is None:
since = pendulum.now('UTC').start_of('minute').subtract(
minutes=count).timestamp()
elif isinstance(since, int):
since = pendulum.from_timestamp(since).timestamp()
else: # presumably a pendulum datetime
since = since.timestamp()
# UTC 2017-07-02 12:53:20 is oldest seconds value
since = str(max(1499000000, int(since)))
json = await self._public(
'OHLC',
data={
'pair': symbol,
'since': since,
},
)
try:
res = json['result']
res.pop('last')
bars = next(iter(res.values()))
new_bars = []
first = bars[0]
last_nz_vwap = first[-3]
if last_nz_vwap == 0:
# use close if vwap is zero
last_nz_vwap = first[-4]
# convert all fields to native types
for i, bar in enumerate(bars):
# normalize weird zero-ed vwap values..cmon kraken..
# indicates vwap didn't change since last bar
vwap = float(bar.pop(-3))
if vwap != 0:
last_nz_vwap = vwap
if vwap == 0:
vwap = last_nz_vwap
# re-insert vwap as the last of the fields
bar.append(vwap)
new_bars.append(
(i,) + tuple(
ftype(bar[j]) for j, (name, ftype) in enumerate(
_ohlc_dtype[1:]
)
)
)
array = np.array(new_bars, dtype=_ohlc_dtype) if as_np else bars
return array
except KeyError:
errmsg = json['error'][0]
if 'not found' in errmsg:
raise SymbolNotFound(errmsg + f': {symbol}')
elif 'Too many requests' in errmsg:
raise DataThrottle(f'{symbol}')
else:
raise BrokerError(errmsg)
@classmethod
def normalize_symbol(
cls,
ticker: str
) -> str:
'''
Normalize symbol names to a 3x3 pair from the global
definition map which we build out from the data retrieved from
the 'AssetPairs' endpoint, see methods above.
'''
ticker = cls._ntable[ticker]
return ticker.lower()
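
# XXX editor note (hypothetical values): both rest-api keys and wsnames
# normalize to the same lowercased altname via the table above, eg:
#
#   Client.normalize_symbol('XXMRZEUR')  # -> 'xmreur'
#   Client.normalize_symbol('XMR/EUR')   # -> 'xmreur'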
@acm
async def get_client() -> Client:
conf = get_config()
if conf:
client = Client(
conf,
name=conf['key_descr'],
api_key=conf['api_key'],
secret=conf['secret']
)
else:
client = Client({})
# at startup, load all symbols, and asset info in
# batch requests.
async with trio.open_nursery() as nurse:
nurse.start_soon(client.cache_assets)
await client.cache_symbols()
yield client
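
# XXX editor sketch (hypothetical, not part of this file): typical
# read-only usage of the client factory above.
async def _example_client_usage() -> None:
    async with get_client() as client:
        # kraken only ever hands back (up to) the last 720 1m bars
        bars = await client.bars('XBTUSD')
        print(bars[-1])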

File diff suppressed because it is too large

@ -1,500 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Real-time and historical data feed endpoints.
'''
from contextlib import asynccontextmanager as acm
from datetime import datetime
from typing import (
Any,
Optional,
Callable,
)
import time
from async_generator import aclosing
from fuzzywuzzy import process as fuzzy
import numpy as np
import pendulum
from trio_typing import TaskStatus
import tractor
import trio
from piker._cacheables import open_cached_client
from piker.brokers._util import (
BrokerError,
DataThrottle,
DataUnavailable,
)
from piker.log import get_console_log
from piker.data import ShmArray
from piker.data.types import Struct
from piker.data._web_bs import open_autorecon_ws, NoBsWs
from . import log
from .api import (
Client,
)
# https://www.kraken.com/features/api#get-tradable-pairs
class Pair(Struct):
altname: str # alternate pair name
wsname: str # WebSocket pair name (if available)
aclass_base: str # asset class of base component
base: str # asset id of base component
aclass_quote: str # asset class of quote component
quote: str # asset id of quote component
lot: str # volume lot size
cost_decimals: int
costmin: float
pair_decimals: int # scaling decimal places for pair
lot_decimals: int # scaling decimal places for volume
# amount to multiply lot volume by to get currency volume
lot_multiplier: float
# array of leverage amounts available when buying
leverage_buy: list[int]
# array of leverage amounts available when selling
leverage_sell: list[int]
# fee schedule array in [volume, percent fee] tuples
fees: list[tuple[int, float]]
# maker fee schedule array in [volume, percent fee] tuples (if on
# maker/taker)
fees_maker: list[tuple[int, float]]
fee_volume_currency: str # volume discount currency
margin_call: str # margin call level
margin_stop: str # stop-out/liquidation margin level
ordermin: float # minimum order volume for pair
tick_size: float # min price step size
status: str
short_position_limit: float
long_position_limit: float
class OHLC(Struct):
'''
Description of the flattened OHLC quote format.
For schema details see:
https://docs.kraken.com/websockets/#message-ohlc
'''
chan_id: int # internal kraken id
chan_name: str # eg. ohlc-1 (name-interval)
pair: str # fx pair
time: float # Begin time of interval, in seconds since epoch
etime: float # End time of interval, in seconds since epoch
open: float # Open price of interval
high: float # High price within interval
low: float # Low price within interval
close: float # Close price of interval
vwap: float # Volume weighted average price within interval
volume: float # Accumulated volume **within interval**
count: int # Number of trades within interval
# (sampled) generated tick data
ticks: list[Any] = []
async def stream_messages(
ws: NoBsWs,
):
'''
Message stream parser and heartbeat handler.
Deliver ws subscription messages as well as handle heartbeat logic
through a single async generator.
'''
too_slow_count = last_hb = 0
while True:
with trio.move_on_after(5) as cs:
msg = await ws.recv_msg()
# trigger reconnection if heartbeat is laggy
if cs.cancelled_caught:
too_slow_count += 1
if too_slow_count > 20:
log.warning(
"Heartbeat is too slow, resetting ws connection")
await ws._connect()
too_slow_count = 0
continue
match msg:
case {'event': 'heartbeat'}:
now = time.time()
delay = now - last_hb
last_hb = now
# XXX: why tf is this not printing without --tl flag?
log.debug(f"Heartbeat after {delay}")
# print(f"Heartbeat after {delay}")
continue
case _:
# passthrough sub msgs
yield msg
async def process_data_feed_msgs(
ws: NoBsWs,
):
'''
Parse and pack data feed messages.
'''
async for msg in stream_messages(ws):
match msg:
case {
'errorMessage': errmsg
}:
raise BrokerError(errmsg)
case {
'event': 'subscriptionStatus',
} as sub:
log.info(
'WS subscription is active:\n'
f'{sub}'
)
continue
case [
chan_id,
*payload_array,
chan_name,
pair
]:
if 'ohlc' in chan_name:
ohlc = OHLC(
chan_id,
chan_name,
pair,
*payload_array[0]
)
ohlc.typecast()
yield 'ohlc', ohlc
elif 'spread' in chan_name:
bid, ask, ts, bsize, asize = map(
float, payload_array[0])
# TODO: really makes you think IB has a horrible API...
quote = {
'symbol': pair.replace('/', ''),
'ticks': [
{'type': 'bid', 'price': bid, 'size': bsize},
{'type': 'bsize', 'price': bid, 'size': bsize},
{'type': 'ask', 'price': ask, 'size': asize},
{'type': 'asize', 'price': ask, 'size': asize},
],
}
yield 'l1', quote
# elif 'book' in msg[-2]:
# chan_id, *payload_array, chan_name, pair = msg
# print(msg)
case _:
print(f'UNHANDLED MSG: {msg}')
# yield msg
def normalize(
ohlc: OHLC,
) -> dict:
quote = ohlc.to_dict()
quote['broker_ts'] = quote['time']
quote['brokerd_ts'] = time.time()
quote['symbol'] = quote['pair'] = quote['pair'].replace('/', '')
quote['last'] = quote['close']
quote['bar_wap'] = ohlc.vwap
# seriously eh? what's with this non-symmetry everywhere
# in subscription systems...
# XXX: piker style is always lowercases symbols.
topic = quote['pair'].replace('/', '').lower()
# print(quote)
return topic, quote
@acm
async def open_history_client(
symbol: str,
) -> tuple[Callable, int]:
# TODO implement history getter for the new storage layer.
async with open_cached_client('kraken') as client:
# lol, kraken won't send any more than the "last"
# 720 1m bars.. so we have to just ignore further
# requests of this type..
queries: int = 0
async def get_ohlc(
timeframe: float,
end_dt: Optional[datetime] = None,
start_dt: Optional[datetime] = None,
) -> tuple[
np.ndarray,
datetime, # start
datetime, # end
]:
nonlocal queries
if (
queries > 0
or timeframe != 60
):
raise DataUnavailable(
'Only a single query for 1m bars supported')
count = 0
while count <= 3:
try:
array = await client.bars(
symbol,
since=end_dt,
)
count += 1
queries += 1
break
except DataThrottle:
log.warning(f'kraken OHLC throttle for {symbol}')
await trio.sleep(1)
start_dt = pendulum.from_timestamp(array[0]['time'])
end_dt = pendulum.from_timestamp(array[-1]['time'])
return array, start_dt, end_dt
yield get_ohlc, {'erlangs': 1, 'rate': 1}
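
# XXX editor sketch (hypothetical, not part of this file): how a
# backfill consumer would presumably drive the history client above;
# only a single 1m (timeframe=60) query is honored per the guard.
async def _example_history_usage() -> None:
    async with open_history_client('XBTUSD') as (get_ohlc, limits):
        array, start_dt, end_dt = await get_ohlc(60)
        print(f'{len(array)} bars spanning {start_dt} -> {end_dt}')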
async def stream_quotes(
send_chan: trio.abc.SendChannel,
symbols: list[str],
feed_is_live: trio.Event,
loglevel: str = None,
# backend specific
sub_type: str = 'ohlc',
# startup sync
task_status: TaskStatus[tuple[dict, dict]] = trio.TASK_STATUS_IGNORED,
) -> None:
'''
Subscribe for ohlc stream of quotes for ``pairs``.
``pairs`` must be formatted <crypto_symbol>/<fiat_symbol>.
'''
# XXX: required to propagate ``tractor`` loglevel to piker logging
get_console_log(loglevel or tractor.current_actor().loglevel)
ws_pairs = {}
sym_infos = {}
async with open_cached_client('kraken') as client, send_chan as send_chan:
# keep client cached for real-time section
for sym in symbols:
# transform to upper since piker style is always lower
sym = sym.upper()
sym_info = await client.symbol_info(sym)
try:
si = Pair(**sym_info) # validation
except TypeError:
fields_diff = set(sym_info) - set(Pair.__struct_fields__)
raise TypeError(
f'Missing msg fields {fields_diff}'
)
syminfo = si.to_dict()
syminfo['price_tick_size'] = 1 / 10**si.pair_decimals
syminfo['lot_tick_size'] = 1 / 10**si.lot_decimals
syminfo['asset_type'] = 'crypto'
sym_infos[sym] = syminfo
ws_pairs[sym] = si.wsname
symbol = symbols[0].lower()
init_msgs = {
# pass back token, and bool, signalling if we're the writer
# and that history has been written
symbol: {
'symbol_info': sym_infos[sym],
'shm_write_opts': {'sum_tick_vml': False},
'fqsn': sym,
},
}
@acm
async def subscribe(ws: NoBsWs):
# XXX: setup subs
# https://docs.kraken.com/websockets/#message-subscribe
# specific logic for this in kraken's sync client:
# https://github.com/krakenfx/kraken-wsclient-py/blob/master/kraken_wsclient_py/kraken_wsclient_py.py#L188
ohlc_sub = {
'event': 'subscribe',
'pair': list(ws_pairs.values()),
'subscription': {
'name': 'ohlc',
'interval': 1,
},
}
# TODO: we want to eventually allow unsubs which should
# be completely fine to request from a separate task
# since internally the ws methods appear to be FIFO
# locked.
await ws.send_msg(ohlc_sub)
# trade data (aka L1)
l1_sub = {
'event': 'subscribe',
'pair': list(ws_pairs.values()),
'subscription': {
'name': 'spread',
# 'depth': 10}
},
}
# pull a first quote and deliver
await ws.send_msg(l1_sub)
yield
# unsub from all pairs on teardown
if ws.connected():
await ws.send_msg({
'pair': list(ws_pairs.values()),
'event': 'unsubscribe',
'subscription': ['ohlc', 'spread'],
})
# XXX: do we need to ack the unsub?
# await ws.recv_msg()
# see the tips on reconnection logic:
# https://support.kraken.com/hc/en-us/articles/360044504011-WebSocket-API-unexpected-disconnections-from-market-data-feeds
ws: NoBsWs
async with (
open_autorecon_ws(
'wss://ws.kraken.com/',
fixture=subscribe,
) as ws,
aclosing(process_data_feed_msgs(ws)) as msg_gen,
):
# pull a first quote and deliver
typ, ohlc_last = await anext(msg_gen)
topic, quote = normalize(ohlc_last)
task_status.started((init_msgs, quote))
# lol, only "closes" when they're margin squeezing clients ;P
feed_is_live.set()
# keep start of last interval for volume tracking
last_interval_start = ohlc_last.etime
# start streaming
async for typ, ohlc in msg_gen:
if typ == 'ohlc':
# TODO: can get rid of all this by using
# ``trades`` subscription...
# generate tick values to match time & sales pane:
# https://trade.kraken.com/charts/KRAKEN:BTC-USD?period=1m
volume = ohlc.volume
# new OHLC sample interval
if ohlc.etime > last_interval_start:
last_interval_start = ohlc.etime
tick_volume = volume
else:
# this is the tick volume *within the interval*
tick_volume = volume - ohlc_last.volume
ohlc_last = ohlc
last = ohlc.close
if tick_volume:
ohlc.ticks.append({
'type': 'trade',
'price': last,
'size': tick_volume,
})
topic, quote = normalize(ohlc)
elif typ == 'l1':
quote = ohlc
topic = quote['symbol'].lower()
await send_chan.send({topic: quote})
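
# XXX editor note (worked example, hypothetical numbers): kraken's ohlc
# stream reports *cumulative* volume per 1m interval, so the per-tick
# trade size above is the delta between consecutive updates:
#
#   update 1: etime=..:01:00, volume=10.0 -> tick_volume = 10.0 (new interval)
#   update 2: etime=..:01:00, volume=12.5 -> tick_volume = 12.5 - 10.0 = 2.5
#   update 3: etime=..:02:00, volume=0.7  -> tick_volume = 0.7 (interval rolled)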
@tractor.context
async def open_symbol_search(
ctx: tractor.Context,
) -> Client:
async with open_cached_client('kraken') as client:
# load all symbols locally for fast search
cache = await client.cache_symbols()
await ctx.started(cache)
async with ctx.open_stream() as stream:
async for pattern in stream:
matches = fuzzy.extractBests(
pattern,
cache,
score_cutoff=50,
)
# repack in dict form
await stream.send(
{item[0]['altname']: item[0]
for item in matches}
)
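
# XXX editor note (hypothetical sketch): ``fuzzy.extractBests()`` over
# a dict scores the *values* and returns (value, score, key) triples,
# roughly:
#
#   fuzzy.extractBests('xmr', cache, score_cutoff=50)
#   # -> [({'altname': 'XMREUR', 'wsname': 'XMR/EUR', ...}, 90, 'XXMRZEUR'), ...]
#
# which is why matches are repacked keyed by ``item[0]['altname']`` above.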


@ -18,9 +18,3 @@
Market machinery for order executions, book, management. Market machinery for order executions, book, management.
""" """
from ._client import open_ems
__all__ = [
'open_ems',
]


@ -22,10 +22,54 @@ from enum import Enum
from typing import Optional from typing import Optional
from bidict import bidict from bidict import bidict
from pydantic import BaseModel, validator
from ..data._source import Symbol from ..data._source import Symbol
from ..data.types import Struct from ._messages import BrokerdPosition, Status
from ..pp import Position
class Position(BaseModel):
'''
Basic pp (personal position) model with attached fills history.
This type should be IPC wire ready?
'''
symbol: Symbol
# last size and avg entry price
size: float
avg_price: float # TODO: contextual pricing
# ordered record of known constituent trade messages
fills: list[Status] = []
def update_from_msg(
self,
msg: BrokerdPosition,
) -> None:
# XXX: better place to do this?
symbol = self.symbol
lot_size_digits = symbol.lot_size_digits
avg_price, size = (
round(msg['avg_price'], ndigits=symbol.tick_size_digits),
round(msg['size'], ndigits=lot_size_digits),
)
self.avg_price = avg_price
self.size = size
@property
def dsize(self) -> float:
'''
The "dollar" size of the pp, normally in trading (fiat) unit
terms.
'''
return self.avg_price * self.size
_size_units = bidict({ _size_units = bidict({
@ -40,9 +84,34 @@ SizeUnit = Enum(
) )
class Allocator(Struct): class Allocator(BaseModel):
class Config:
validate_assignment = True
copy_on_model_validation = False
arbitrary_types_allowed = True
# required to get the account validator lookup working?
extra = 'allow'
underscore_attrs_are_private = False
symbol: Symbol symbol: Symbol
account: Optional[str] = 'paper'
# TODO: for enums this clearly doesn't fucking work, you can't set
# a default at startup by passing in a `dict` but yet you can set
# that value through assignment..for wtv cucked reason.. honestly, pure
# unintuitive garbage.
size_unit: str = 'currency'
_size_units: dict[str, Optional[str]] = _size_units
@validator('size_unit', pre=True)
def maybe_lookup_key(cls, v):
# apply the corresponding enum key for the text "description" value
if v not in _size_units:
return _size_units.inverse[v]
assert v in _size_units
return v
# TODO: if we ever want to support non-uniform entry-slot-proportion # TODO: if we ever want to support non-uniform entry-slot-proportion
# "sizes" # "sizes"
@ -51,28 +120,6 @@ class Allocator(Struct):
units_limit: float units_limit: float
currency_limit: float currency_limit: float
slots: int slots: int
account: Optional[str] = 'paper'
_size_units: bidict[str, Optional[str]] = _size_units
# TODO: for enums this clearly doesn't fucking work, you can't set
# a default at startup by passing in a `dict` but yet you can set
# that value through assignment..for wtv cucked reason.. honestly, pure
# unintuitive garbage.
_size_unit: str = 'currency'
@property
def size_unit(self) -> str:
return self._size_unit
@size_unit.setter
def size_unit(self, v: str) -> Optional[str]:
if v not in _size_units:
v = _size_units.inverse[v]
assert v in _size_units
self._size_unit = v
return v
def step_sizes( def step_sizes(
self, self,
@ -93,13 +140,10 @@ class Allocator(Struct):
else: else:
return self.units_limit return self.units_limit
def limit_info(self) -> tuple[str, float]:
return self.size_unit, self.limit()
def next_order_info( def next_order_info(
self, self,
# we only need a startup size for exit calcs, we can then # we only need a startup size for exit calcs, we can the
# determine how large slots should be if the initial pp size was # determine how large slots should be if the initial pp size was
# larger than the current live one, and the live one is smaller # larger than the current live one, and the live one is smaller
# than the initial config settings. # than the initial config settings.
@ -129,7 +173,7 @@ class Allocator(Struct):
l_sub_pp = self.units_limit - abs_live_size l_sub_pp = self.units_limit - abs_live_size
elif size_unit == 'currency': elif size_unit == 'currency':
live_cost_basis = abs_live_size * live_pp.ppu live_cost_basis = abs_live_size * live_pp.avg_price
slot_size = currency_per_slot / price slot_size = currency_per_slot / price
l_sub_pp = (self.currency_limit - live_cost_basis) / price l_sub_pp = (self.currency_limit - live_cost_basis) / price
@ -140,14 +184,12 @@ class Allocator(Struct):
# an entry (adding-to or starting a pp) # an entry (adding-to or starting a pp)
if ( if (
action == 'buy' and live_size > 0 or
action == 'sell' and live_size < 0 or
live_size == 0 live_size == 0
or (action == 'buy' and live_size > 0)
or action == 'sell' and live_size < 0
): ):
order_size = min(
slot_size, order_size = min(slot_size, l_sub_pp)
max(l_sub_pp, 0),
)
# an exit (removing-from or going to net-zero pp) # an exit (removing-from or going to net-zero pp)
else: else:
@ -163,7 +205,7 @@ class Allocator(Struct):
if size_unit == 'currency': if size_unit == 'currency':
# compute the "projected" limit's worth of units at the # compute the "projected" limit's worth of units at the
# current pp (weighted) price: # current pp (weighted) price:
slot_size = currency_per_slot / live_pp.ppu slot_size = currency_per_slot / live_pp.avg_price
else: else:
slot_size = u_per_slot slot_size = u_per_slot
@ -202,12 +244,7 @@ class Allocator(Struct):
if order_size < slot_size: if order_size < slot_size:
# compute a fractional slots size to display # compute a fractional slots size to display
slots_used = self.slots_used( slots_used = self.slots_used(
Position( Position(symbol=sym, size=order_size, avg_price=price)
symbol=sym,
size=order_size,
ppu=price,
bsuid=sym,
)
) )
return { return {
@ -234,8 +271,8 @@ class Allocator(Struct):
abs_pp_size = abs(pp.size) abs_pp_size = abs(pp.size)
if self.size_unit == 'currency': if self.size_unit == 'currency':
# live_currency_size = size or (abs_pp_size * pp.ppu) # live_currency_size = size or (abs_pp_size * pp.avg_price)
live_currency_size = abs_pp_size * pp.ppu live_currency_size = abs_pp_size * pp.avg_price
prop = live_currency_size / self.currency_limit prop = live_currency_size / self.currency_limit
else: else:
@ -247,6 +284,14 @@ class Allocator(Struct):
return round(prop * self.slots) return round(prop * self.slots)
_derivs = (
'future',
'continuous_future',
'option',
'futures_option',
)
def mk_allocator( def mk_allocator(
symbol: Symbol, symbol: Symbol,
@ -255,7 +300,7 @@ def mk_allocator(
# default allocation settings # default allocation settings
defaults: dict[str, float] = { defaults: dict[str, float] = {
'account': None, # select paper by default 'account': None, # select paper by default
# 'size_unit': 'currency', 'size_unit': 'currency',
'units_limit': 400, 'units_limit': 400,
'currency_limit': 5e3, 'currency_limit': 5e3,
'slots': 4, 'slots': 4,
@ -273,9 +318,42 @@ def mk_allocator(
'currency_limit': 6e3, 'currency_limit': 6e3,
'slots': 6, 'slots': 6,
} }
defaults.update(user_def) defaults.update(user_def)
return Allocator( alloc = Allocator(
symbol=symbol, symbol=symbol,
**defaults, **defaults,
) )
asset_type = symbol.type_key
# specific configs by asset class / type
if asset_type in _derivs:
# since it's harder to know how currency "applies" in this case
# given leverage properties
alloc.size_unit = '# units'
# set units limit to slots size thus making the next
# entry step 1.0
alloc.units_limit = alloc.slots
# if the current position is already greater then the limit
# settings, increase the limit to the current position
if alloc.size_unit == 'currency':
startup_size = startup_pp.size * startup_pp.avg_price
if startup_size > alloc.currency_limit:
alloc.currency_limit = round(startup_size, ndigits=2)
else:
startup_size = abs(startup_pp.size)
if startup_size > alloc.units_limit:
alloc.units_limit = startup_size
if asset_type in _derivs:
alloc.slots = alloc.units_limit
return alloc
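
# XXX editor note (worked example, hypothetical numbers): with the
# defaults above (currency_limit=5e3, slots=4) each entry "slot" is
# worth 5000 / 4 = 1250 in fiat terms, so at a price of 100.0 the
# per-entry order size comes out to 1250 / 100.0 = 12.5 units
# (ie. ``slot_size = currency_per_slot / price``).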


@ -18,32 +18,26 @@
Orders and execution client API. Orders and execution client API.
""" """
from __future__ import annotations
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager as acm
from typing import Dict
from pprint import pformat from pprint import pformat
from typing import TYPE_CHECKING from dataclasses import dataclass, field
import trio import trio
import tractor import tractor
from tractor.trionics import broadcast_receiver from tractor.trionics import broadcast_receiver
from ..log import get_logger from ..log import get_logger
from ..data.types import Struct from ._ems import _emsd_main
from .._daemon import maybe_open_emsd from .._daemon import maybe_open_emsd
from ._messages import Order, Cancel from ._messages import Order, Cancel
from ..brokers import get_brokermod
if TYPE_CHECKING:
from ._messages import (
BrokerdPosition,
Status,
)
log = get_logger(__name__) log = get_logger(__name__)
class OrderBook(Struct): @dataclass
class OrderBook:
'''EMS-client-side order book ctl and tracking. '''EMS-client-side order book ctl and tracking.
A style similar to "model-view" is used here where this api is A style similar to "model-view" is used here where this api is
@ -58,18 +52,20 @@ class OrderBook(Struct):
# mem channels used to relay order requests to the EMS daemon # mem channels used to relay order requests to the EMS daemon
_to_ems: trio.abc.SendChannel _to_ems: trio.abc.SendChannel
_from_order_book: trio.abc.ReceiveChannel _from_order_book: trio.abc.ReceiveChannel
_sent_orders: dict[str, Order] = {}
_sent_orders: Dict[str, Order] = field(default_factory=dict)
_ready_to_receive: trio.Event = trio.Event()
def send( def send(
self, self,
msg: Order | dict, msg: Order,
) -> dict: ) -> dict:
self._sent_orders[msg.oid] = msg self._sent_orders[msg.oid] = msg
self._to_ems.send_nowait(msg) self._to_ems.send_nowait(msg.dict())
return msg return msg
def send_update( def update(
self, self,
uuid: str, uuid: str,
@ -77,8 +73,9 @@ class OrderBook(Struct):
) -> dict: ) -> dict:
cmd = self._sent_orders[uuid] cmd = self._sent_orders[uuid]
msg = cmd.copy(update=data) msg = cmd.dict()
self._sent_orders[uuid] = msg msg.update(data)
self._sent_orders[uuid] = Order(**msg)
self._to_ems.send_nowait(msg) self._to_ems.send_nowait(msg)
return cmd return cmd
@ -86,18 +83,12 @@ class OrderBook(Struct):
"""Cancel an order (or alert) in the EMS. """Cancel an order (or alert) in the EMS.
""" """
cmd = self._sent_orders.get(uuid) cmd = self._sent_orders[uuid]
if not cmd:
log.error(
f'Unknown order {uuid}!?\n'
f'Maybe there is a stale entry or line?\n'
f'You should report this as a bug!'
)
msg = Cancel( msg = Cancel(
oid=uuid, oid=uuid,
symbol=cmd.symbol, symbol=cmd.symbol,
) )
self._to_ems.send_nowait(msg) self._to_ems.send_nowait(msg.dict())
_orders: OrderBook = None _orders: OrderBook = None
@ -158,35 +149,21 @@ async def relay_order_cmds_from_sync_code(
book = get_orders() book = get_orders()
async with book._from_order_book.subscribe() as orders_stream: async with book._from_order_book.subscribe() as orders_stream:
async for cmd in orders_stream: async for cmd in orders_stream:
sym = cmd.symbol if cmd['symbol'] == symbol_key:
msg = pformat(cmd) log.info(f'Send order cmd:\n{pformat(cmd)}')
if sym == symbol_key:
log.info(f'Send order cmd:\n{msg}')
# send msg over IPC / wire # send msg over IPC / wire
await to_ems_stream.send(cmd) await to_ems_stream.send(cmd)
else:
log.warning(
f'Ignoring unmatched order cmd for {sym} != {symbol_key}:'
f'\n{msg}'
)
@acm @acm
async def open_ems( async def open_ems(
fqsn: str, fqsn: str,
mode: str = 'live',
) -> tuple[ ) -> (
OrderBook, OrderBook,
tractor.MsgStream, tractor.MsgStream,
dict[ dict,
# brokername, acctid ):
tuple[str, str],
list[BrokerdPosition],
],
list[str],
dict[str, Status],
]:
''' '''
Spawn an EMS daemon and begin sending orders and receiving Spawn an EMS daemon and begin sending orders and receiving
alerts. alerts.
@ -229,35 +206,18 @@ async def open_ems(
async with maybe_open_emsd(broker) as portal: async with maybe_open_emsd(broker) as portal:
mod = get_brokermod(broker)
if (
not getattr(mod, 'trades_dialogue', None)
or mode == 'paper'
):
mode = 'paper'
from ._ems import _emsd_main
async with ( async with (
# connect to emsd # connect to emsd
portal.open_context( portal.open_context(
_emsd_main, _emsd_main,
fqsn=fqsn, fqsn=fqsn,
exec_mode=mode,
) as ( ) as (ctx, (positions, accounts)),
ctx,
(
positions,
accounts,
dialogs,
)
),
# open 2-way trade command stream # open 2-way trade command stream
ctx.open_stream() as trades_stream, ctx.open_stream() as trades_stream,
): ):
# start sync code order msg delivery task
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
n.start_soon( n.start_soon(
relay_order_cmds_from_sync_code, relay_order_cmds_from_sync_code,
@ -265,10 +225,4 @@ async def open_ems(
trades_stream trades_stream
) )
yield ( yield book, trades_stream, positions, accounts
book,
trades_stream,
positions,
accounts,
dialogs,
)

File diff suppressed because it is too large

@ -1,5 +1,5 @@
# piker: trading gear for hackers # piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers) # Copyright (C) Tyler Goodlet (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -15,160 +15,108 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
Clearing sub-system message and protocols. Clearing system messagingn types and protocols.
""" """
# from collections import ( from typing import Optional, Union
# ChainMap,
# deque, # TODO: try out just encoding/send direction for now?
# ) # import msgspec
from typing import ( from pydantic import BaseModel
Optional,
Literal,
)
from ..data._source import Symbol from ..data._source import Symbol
from ..data.types import Struct
# TODO: a composite for tracking msg flow on 2-legged
# dialogs.
# class Dialog(ChainMap):
# '''
# Msg collection abstraction to easily track the state changes of
# a msg flow in one high level, query-able and immutable construct.
# The main use case is to query data from a (long-running)
# msg-transaction-sequence
# '''
# def update(
# self,
# msg,
# ) -> None:
# self.maps.insert(0, msg.to_dict())
# def flatten(self) -> dict:
# return dict(self)
# TODO: ``msgspec`` stuff worth paying attention to:
# - schema evolution:
# https://jcristharif.com/msgspec/usage.html#schema-evolution
# - for eg. ``BrokerdStatus``, instead just have separate messages?
# - use literals for a common msg determined by diff keys?
# - https://jcristharif.com/msgspec/usage.html#literal
# --------------
# Client -> emsd # Client -> emsd
# --------------
class Order(Struct):
# TODO: ideally we can combine these 2 fields into
# 1 and just use the size polarity to determine a buy/sell.
# i would like to see this become more like
# https://jcristharif.com/msgspec/usage.html#literal
# action: Literal[
# 'live',
# 'dark',
# 'alert',
# ]
action: Literal[
'buy',
'sell',
'alert',
]
# determines whether the create execution
# will be submitted to the ems or directly to
# the backend broker
exec_mode: Literal[
'dark',
'live',
# 'paper', no right?
]
# internal ``emdsd`` unique "order id"
oid: str # uuid4
symbol: str | Symbol
account: str # should we set a default as '' ?
price: float
size: float # -ve is "sell", +ve is "buy"
brokers: Optional[list[str]] = []
class Cancel(Struct): class Cancel(BaseModel):
''' '''Cancel msg for removing a dark (ems triggered) or
Cancel msg for removing a dark (ems triggered) or
broker-submitted (live) trigger/order. broker-submitted (live) trigger/order.
''' '''
action: str = 'cancel'
oid: str # uuid4 oid: str # uuid4
symbol: str symbol: str
action: str = 'cancel'
# -------------- class Order(BaseModel):
action: str # {'buy', 'sell', 'alert'}
# internal ``emdsd`` unique "order id"
oid: str # uuid4
symbol: Union[str, Symbol]
account: str # should we set a default as '' ?
price: float
size: float
brokers: list[str]
# Assigned once initial ack is received
# ack_time_ns: Optional[int] = None
# determines whether the create execution
# will be submitted to the ems or directly to
# the backend broker
exec_mode: str # {'dark', 'live', 'paper'}
class Config:
# just for pre-loading a ``Symbol`` when used
# in the order mode staging process
arbitrary_types_allowed = True
# don't copy this model instance when used in
# a recursive model
copy_on_model_validation = False
# Client <- emsd # Client <- emsd
# --------------
# update msgs from ems which relay state change info # update msgs from ems which relay state change info
# from the active clearing engine. # from the active clearing engine.
class Status(Struct):
time_ns: int class Status(BaseModel):
oid: str # uuid4 ems-order dialog id
resp: Literal[
'pending', # acked by broker but not yet open
'open',
'dark_open', # dark/algo triggered order is open in ems clearing loop
'triggered', # above triggered order sent to brokerd, or an alert closed
'closed', # fully cleared all size/units
'fill', # partial execution
'canceled',
'error',
]
name: str = 'status' name: str = 'status'
oid: str # uuid4
time_ns: int
# {
# 'dark_submitted',
# 'dark_cancelled',
# 'dark_triggered',
# 'broker_submitted',
# 'broker_cancelled',
# 'broker_executed',
# 'broker_filled',
# 'broker_errored',
# 'alert_submitted',
# 'alert_triggered',
# }
resp: str # "response", see above
# symbol: str
# trigger info
trigger_price: Optional[float] = None
# price: float
# broker: Optional[str] = None
# this maps normally to the ``BrokerdOrder.reqid`` below, an id # this maps normally to the ``BrokerdOrder.reqid`` below, an id
# normally allocated internally by the backend broker routing system # normally allocated internally by the backend broker routing system
reqid: Optional[int | str] = None broker_reqid: Optional[Union[int, str]] = None
# the (last) source order/request msg if provided # for relaying backend msg data "through" the ems layer
# (eg. the Order/Cancel which causes this msg) and
# acts as a back-reference to the corresponding
# request message which was the source of this msg.
req: Order | None = None
# XXX: better design/name here?
# flag that can be set to indicate a message for an order
# event that wasn't originated by piker's emsd (eg. some external
# trading system which does its own order control but that you
# might want to "track" using piker UIs/systems).
src: Optional[str] = None
# set when a cancel request msg was set for this order flow dialog
# but the brokerd dialog isn't yet in a cancelled state.
cancel_called: bool = False
# for relaying a boxed brokerd-dialog-side msg data "through" the
# ems layer to clients.
brokerd_msg: dict = {} brokerd_msg: dict = {}
# ---------------
# emsd -> brokerd # emsd -> brokerd
# ---------------
# requests *sent* from ems to respective backend broker daemon # requests *sent* from ems to respective backend broker daemon
class BrokerdCancel(Struct): class BrokerdCancel(BaseModel):
action: str = 'cancel'
oid: str # piker emsd order id oid: str # piker emsd order id
time_ns: int time_ns: int
@ -179,39 +127,34 @@ class BrokerdCancel(Struct):
# for setting a unique order id then this value will be relayed back # for setting a unique order id then this value will be relayed back
# on the emsd order request stream as the ``BrokerdOrderAck.reqid`` # on the emsd order request stream as the ``BrokerdOrderAck.reqid``
# field # field
reqid: Optional[int | str] = None reqid: Optional[Union[int, str]] = None
action: str = 'cancel'
class BrokerdOrder(Struct): class BrokerdOrder(BaseModel):
action: str # {buy, sell}
oid: str oid: str
account: str account: str
time_ns: int time_ns: int
symbol: str # fqsn
price: float
size: float
# TODO: if we instead rely on a +ve/-ve size to determine
# the action we more or less don't need this field right?
action: str = '' # {buy, sell}
# "broker request id": broker specific/internal order id if this is # "broker request id": broker specific/internal order id if this is
# None, creates a new order otherwise if the id is valid the backend # None, creates a new order otherwise if the id is valid the backend
# api must modify the existing matching order. If the broker allows # api must modify the existing matching order. If the broker allows
# for setting a unique order id then this value will be relayed back # for setting a unique order id then this value will be relayed back
# on the emsd order request stream as the ``BrokerdOrderAck.reqid`` # on the emsd order request stream as the ``BrokerdOrderAck.reqid``
# field # field
reqid: Optional[int | str] = None reqid: Optional[Union[int, str]] = None
symbol: str # symbol.<providername> ?
price: float
size: float
# ---------------
# emsd <- brokerd # emsd <- brokerd
# ---------------
# requests *received* to ems from broker backend # requests *received* to ems from broker backend
class BrokerdOrderAck(Struct):
class BrokerdOrderAck(BaseModel):
''' '''
Immediate response to a brokerd order request providing the broker Immediate response to a brokerd order request providing the broker
specific unique order id so that the EMS can associate this specific unique order id so that the EMS can associate this
@ -219,35 +162,42 @@ class BrokerdOrderAck(Struct):
``.oid`` (which is a uuid4). ``.oid`` (which is a uuid4).
''' '''
name: str = 'ack'
# defined and provided by backend # defined and provided by backend
reqid: int | str reqid: Union[int, str]
# emsd id originally sent in matching request msg # emsd id originally sent in matching request msg
oid: str oid: str
account: str = '' account: str = ''
name: str = 'ack'
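# NOTE: one way (a hedged sketch, the mapping name is hypothetical) the
# ems side can use this ack to tie the broker-allocated ``reqid`` back
# to its own dialog ``oid``:
from bidict import bidict

reqids2oids: bidict = bidict()

def on_ack(ack) -> None:
    # later Brokerd* msgs only carry ``reqid``; this one-to-one map
    # recovers the originating ems order id.
    reqids2oids[ack.reqid] = ack.oid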
class BrokerdStatus(Struct): class BrokerdStatus(BaseModel):
reqid: int | str
time_ns: int
status: Literal[
'open',
'canceled',
'fill',
'pending',
'error',
]
account: str
name: str = 'status' name: str = 'status'
reqid: Union[int, str]
time_ns: int
# XXX: should be best effort set for every update
account: str = ''
# {
# 'submitted',
# 'cancelled',
# 'filled',
# }
status: str
filled: float = 0.0 filled: float = 0.0
reason: str = '' reason: str = ''
remaining: float = 0.0 remaining: float = 0.0
# external: bool = False # XXX: better design/name here?
# flag that can be set to indicate a message for an order
# event that wasn't originated by piker's emsd (eg. some external
# trading system which does its own order control but that you
# might want to "track" using piker UIs/systems).
external: bool = False
# XXX: not required schema as of yet # XXX: not required schema as of yet
broker_details: dict = { broker_details: dict = {
@ -255,57 +205,59 @@ class BrokerdStatus(Struct):
} }
class BrokerdFill(Struct): class BrokerdFill(BaseModel):
''' '''
A single message indicating a "fill-details" event from the broker A single message indicating a "fill-details" event from the broker
if available. if available.
''' '''
name: str = 'fill'
reqid: Union[int, str]
time_ns: int
# order execution related
action: str
size: float
price: float
broker_details: dict = {} # meta-data (eg. commissions etc.)
# brokerd timestamp required for order mode arrow placement on x-axis # brokerd timestamp required for order mode arrow placement on x-axis
# TODO: maybe int if we force ns? # TODO: maybe int if we force ns?
# we need to normalize this somehow since backends will use their # we need to normalize this somehow since backends will use their
# own format and likely across many disparate epoch clocks... # own format and likely across many disparate epoch clocks...
broker_time: float broker_time: float
reqid: int | str
time_ns: int
# order execution related
size: float
price: float
name: str = 'fill'
action: Optional[str] = None
broker_details: dict = {} # meta-data (eg. commissions etc.)
class BrokerdError(Struct): class BrokerdError(BaseModel):
''' '''
Optional error type that can be relayed to emsd for error handling. Optional error type that can be relayed to emsd for error handling.
This is still a TODO thing since we're not sure how to employ it yet. This is still a TODO thing since we're not sure how to employ it yet.
''' '''
name: str = 'error'
oid: str oid: str
symbol: str
reason: str
# if no brokerd order request was actually submitted (eg. we errored # if no brokerd order request was actually submitted (eg. we errored
# at the ``pikerd`` layer) then there will be no ``reqid`` allocated. # at the ``pikerd`` layer) then there will be no ``reqid`` allocated.
reqid: Optional[int | str] = None reqid: Optional[Union[int, str]] = None
name: str = 'error' symbol: str
reason: str
broker_details: dict = {} broker_details: dict = {}
class BrokerdPosition(Struct): class BrokerdPosition(BaseModel):
'''Position update event from brokerd. '''Position update event from brokerd.
''' '''
name: str = 'position'
broker: str broker: str
account: str account: str
symbol: str symbol: str
currency: str
size: float size: float
avg_price: float avg_price: float
currency: str = ''
name: str = 'position'

View File

@ -18,71 +18,54 @@
Fake trading for forward testing. Fake trading for forward testing.
""" """
from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from datetime import datetime from datetime import datetime
from operator import itemgetter from operator import itemgetter
import itertools
import time import time
from typing import ( from typing import Tuple, Optional, Callable
Any,
Optional,
Callable,
)
import uuid import uuid
from bidict import bidict from bidict import bidict
import pendulum
import trio import trio
import tractor import tractor
from dataclasses import dataclass
from .. import data from .. import data
from ..data._source import Symbol
from ..data.types import Struct
from ..pp import (
Position,
Transaction,
)
from ..data._normalize import iterticks from ..data._normalize import iterticks
from ..data._source import unpack_fqsn from ..data._source import unpack_fqsn
from ..log import get_logger from ..log import get_logger
from ._messages import ( from ._messages import (
BrokerdCancel, BrokerdCancel, BrokerdOrder, BrokerdOrderAck, BrokerdStatus,
BrokerdOrder, BrokerdFill, BrokerdPosition, BrokerdError
BrokerdOrderAck,
BrokerdStatus,
BrokerdFill,
BrokerdPosition,
BrokerdError,
) )
log = get_logger(__name__) log = get_logger(__name__)
class PaperBoi(Struct): @dataclass
''' class PaperBoi:
Emulates a broker order client providing approximately the same API """
and delivering an order-event response stream but with methods for Emulates a broker order client providing the same API and
delivering an order-event response stream but with methods for
triggering desired events based on forward testing engine triggering desired events based on forward testing engine
requirements (eg open, closed, fill msgs). requirements.
''' """
broker: str broker: str
ems_trades_stream: tractor.MsgStream ems_trades_stream: tractor.MsgStream
# map of paper "live" orders which can be used # map of paper "live" orders which can be used
# to simulate fills based on paper engine settings # to simulate fills based on paper engine settings
_buys: defaultdict[str, bidict] _buys: bidict
_sells: defaultdict[str, bidict] _sells: bidict
_reqids: bidict _reqids: bidict
_positions: dict[str, Position] _positions: dict[str, BrokerdPosition]
_trade_ledger: dict[str, Any]
# init edge case L1 spread # init edge case L1 spread
last_ask: tuple[float, float] = (float('inf'), 0) # price, size last_ask: Tuple[float, float] = (float('inf'), 0) # price, size
last_bid: tuple[float, float] = (0, 0) last_bid: Tuple[float, float] = (0, 0)
async def submit_limit( async def submit_limit(
self, self,
@ -92,24 +75,27 @@ class PaperBoi(Struct):
action: str, action: str,
size: float, size: float,
reqid: Optional[str], reqid: Optional[str],
) -> int: ) -> int:
''' """Place an order and return integer request id provided by client.
Place an order and return integer request id provided by client.
"""
is_modify: bool = False
if reqid is None:
reqid = str(uuid.uuid4())
else:
# order is already existing, this is a modify
(oid, symbol, action, old_price) = self._reqids[reqid]
assert old_price != price
is_modify = True
# register order internally
self._reqids[reqid] = (oid, symbol, action, price)
'''
if action == 'alert': if action == 'alert':
# bypass all fill simulation # bypass all fill simulation
return reqid return reqid
entry = self._reqids.get(reqid)
if entry:
# order is already existing, this is a modify
(oid, symbol, action, old_price) = entry
else:
# register order internally
self._reqids[reqid] = (oid, symbol, action, price)
# TODO: net latency model # TODO: net latency model
# we checkpoint here quickly particularly # we checkpoint here quickly particularly
# for dark orders since we want the dark_executed # for dark orders since we want the dark_executed
@ -121,18 +107,15 @@ class PaperBoi(Struct):
size = -size size = -size
msg = BrokerdStatus( msg = BrokerdStatus(
status='open', status='submitted',
# account=f'paper_{self.broker}',
account='paper',
reqid=reqid, reqid=reqid,
broker=self.broker,
time_ns=time.time_ns(), time_ns=time.time_ns(),
filled=0.0, filled=0.0,
reason='paper_trigger', reason='paper_trigger',
remaining=size, remaining=size,
broker_details={'name': 'paperboi'},
) )
await self.ems_trades_stream.send(msg) await self.ems_trades_stream.send(msg.dict())
# if we're already a clearing price simulate an immediate fill # if we're already a clearing price simulate an immediate fill
if ( if (
@ -140,28 +123,28 @@ class PaperBoi(Struct):
) or ( ) or (
action == 'sell' and (clear_price := self.last_bid[0]) >= price action == 'sell' and (clear_price := self.last_bid[0]) >= price
): ):
await self.fake_fill( await self.fake_fill(symbol, clear_price, size, action, reqid, oid)
symbol,
clear_price,
size,
action,
reqid,
oid,
)
# register this submission as a paper live order
else: else:
# set the simulated order in the respective table for lookup # register this submission as a paper live order
# and trigger by the simulated clearing task normally
# running ``simulate_fills()``. # submit order to book simulation fill loop
if action == 'buy': if action == 'buy':
orders = self._buys orders = self._buys
elif action == 'sell': elif action == 'sell':
orders = self._sells orders = self._sells
# {symbol -> bidict[oid, (<price data>)]} # set the simulated order in the respective table for lookup
orders[symbol][oid] = (price, size, reqid, action) # and trigger by the simulated clearing task normally
# running ``simulate_fills()``.
if is_modify:
# remove any existing order for the old price
orders[symbol].pop((oid, old_price))
# buys/sells: (symbol -> (price -> order))
orders.setdefault(symbol, {})[(oid, price)] = (size, reqid, action)
return reqid return reqid
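# NOTE: a sketch of the (new, left-column) paper-book layout keyed by
# ems ``oid`` per symbol; the symbol and order values are invented.
from collections import defaultdict
from bidict import bidict

_buys: defaultdict[str, bidict] = defaultdict(bidict)
_buys['xbtusd.kraken']['<oid>'] = (42_000.0, 0.1, '<reqid>', 'buy')

# a modify can re-assign the same oid key; a cancel pops by id alone:
_buys['xbtusd.kraken'].pop('<oid>', None)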
@ -174,26 +157,26 @@ class PaperBoi(Struct):
oid, symbol, action, price = self._reqids[reqid] oid, symbol, action, price = self._reqids[reqid]
if action == 'buy': if action == 'buy':
self._buys[symbol].pop(oid, None) self._buys[symbol].pop((oid, price))
elif action == 'sell': elif action == 'sell':
self._sells[symbol].pop(oid, None) self._sells[symbol].pop((oid, price))
# TODO: net latency model # TODO: net latency model
await trio.sleep(0.05) await trio.sleep(0.05)
msg = BrokerdStatus( msg = BrokerdStatus(
status='canceled', status='cancelled',
account='paper', oid=oid,
reqid=reqid, reqid=reqid,
broker=self.broker,
time_ns=time.time_ns(), time_ns=time.time_ns(),
broker_details={'name': 'paperboi'},
) )
await self.ems_trades_stream.send(msg) await self.ems_trades_stream.send(msg.dict())
async def fake_fill( async def fake_fill(
self, self,
fqsn: str, symbol: str,
price: float, price: float,
size: float, size: float,
action: str, # one of {'buy', 'sell'} action: str, # one of {'buy', 'sell'}
@ -207,21 +190,21 @@ class PaperBoi(Struct):
remaining: float = 0, remaining: float = 0,
) -> None: ) -> None:
''' """Pretend to fill a broker order @ price and size.
Pretend to fill a broker order @ price and size.
''' """
# TODO: net latency model # TODO: net latency model
await trio.sleep(0.05) await trio.sleep(0.05)
fill_time_ns = time.time_ns()
fill_time_s = time.time()
fill_msg = BrokerdFill( msg = BrokerdFill(
reqid=reqid, reqid=reqid,
time_ns=fill_time_ns, time_ns=time.time_ns(),
action=action, action=action,
size=size, size=size,
price=price, price=price,
broker_time=datetime.now().timestamp(), broker_time=datetime.now().timestamp(),
broker_details={ broker_details={
'paper_info': { 'paper_info': {
@ -231,67 +214,79 @@ class PaperBoi(Struct):
'name': self.broker + '_paper', 'name': self.broker + '_paper',
}, },
) )
log.info(f'Fake filling order:\n{fill_msg}') await self.ems_trades_stream.send(msg.dict())
await self.ems_trades_stream.send(fill_msg)
self._trade_ledger.update(fill_msg.to_dict())
if order_complete: if order_complete:
msg = BrokerdStatus( msg = BrokerdStatus(
reqid=reqid, reqid=reqid,
time_ns=time.time_ns(), time_ns=time.time_ns(),
# account=f'paper_{self.broker}',
account='paper', status='filled',
status='closed',
filled=size, filled=size,
remaining=0 if order_complete else remaining, remaining=0 if order_complete else remaining,
)
await self.ems_trades_stream.send(msg)
# lookup any existing position action=action,
key = fqsn.removesuffix(f'.{self.broker}')
pp = self._positions.setdefault(
fqsn,
Position(
Symbol(
key=key,
broker_info={self.broker: {}},
),
size=size,
ppu=price,
bsuid=key,
)
)
t = Transaction(
fqsn=fqsn,
tid=oid,
size=size, size=size,
price=price, price=price,
cost=0, # TODO: cost model
dt=pendulum.from_timestamp(fill_time_s),
bsuid=key,
)
pp.add_clear(t)
pp_msg = BrokerdPosition( broker_details={
'paper_info': {
'oid': oid,
},
'name': self.broker,
},
)
await self.ems_trades_stream.send(msg.dict())
# lookup any existing position
token = f'{symbol}.{self.broker}'
pp_msg = self._positions.setdefault(
token,
BrokerdPosition(
broker=self.broker, broker=self.broker,
account='paper', account='paper',
symbol=fqsn, symbol=symbol,
# TODO: we need to look up the asset currency from # TODO: we need to look up the asset currency from
# broker info. i guess for crypto this can be # broker info. i guess for crypto this can be
# inferred from the pair? # inferred from the pair?
currency='', currency='',
size=pp.size, size=0.0,
avg_price=pp.ppu, avg_price=0,
)
) )
await self.ems_trades_stream.send(pp_msg) # "avg position price" calcs
# TODO: eventually it'd be nice to have a small set of routines
# to do this stuff from a sequence of cleared orders to enable
# so called "contextual positions".
new_size = size + pp_msg.size
# old size minus the new size gives us size differential with
# +ve -> increase in pp size
# -ve -> decrease in pp size
size_diff = abs(new_size) - abs(pp_msg.size)
if new_size == 0:
pp_msg.avg_price = 0
elif size_diff > 0:
# only update the "average position price" when the position
# size increases not when it decreases (i.e. the position is
# being made smaller)
pp_msg.avg_price = (
abs(size) * price + pp_msg.avg_price * abs(pp_msg.size)
) / abs(new_size)
pp_msg.size = new_size
await self.ems_trades_stream.send(pp_msg.dict())
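# NOTE: the avg-position-price rule implemented above, restated as a
# tiny pure function with a worked check: the avg only moves when the
# absolute position size grows; reductions keep it unchanged.
def update_pp(cur_size, cur_avg, fill_size, fill_price):
    new_size = cur_size + fill_size
    if new_size == 0:
        return 0.0, 0.0
    if abs(new_size) > abs(cur_size):  # position grew
        cur_avg = (
            abs(fill_size) * fill_price + cur_avg * abs(cur_size)
        ) / abs(new_size)
    return new_size, cur_avg

# buy 1 @ 100 then 1 @ 110 -> avg 105; selling 1 leaves avg at 105
assert update_pp(*update_pp(1, 100.0, 1, 110.0), -1, 120.0) == (1, 105.0)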
async def simulate_fills( async def simulate_fills(
quote_stream: tractor.MsgStream, # noqa quote_stream: 'tractor.ReceiveStream', # noqa
client: PaperBoi, client: PaperBoi,
) -> None: ) -> None:
# TODO: more machinery to better simulate real-world market things: # TODO: more machinery to better simulate real-world market things:
@ -311,116 +306,61 @@ async def simulate_fills(
# this stream may eventually contain multiple symbols # this stream may eventually contain multiple symbols
async for quotes in quote_stream: async for quotes in quote_stream:
for sym, quote in quotes.items(): for sym, quote in quotes.items():
for tick in iterticks( for tick in iterticks(
quote, quote,
# dark order price filter(s) # dark order price filter(s)
types=('ask', 'bid', 'trade', 'last') types=('ask', 'bid', 'trade', 'last')
): ):
tick_price = tick['price'] # print(tick)
tick_price = tick.get('price')
ttype = tick['type']
buys: bidict[str, tuple] = client._buys[sym] if ttype in ('ask',):
iter_buys = reversed(sorted(
buys.values(),
key=itemgetter(0),
))
def buy_on_ask(our_price):
return tick_price <= our_price
sells: bidict[str, tuple] = client._sells[sym]
iter_sells = sorted(
sells.values(),
key=itemgetter(0)
)
def sell_on_bid(our_price):
return tick_price >= our_price
match tick:
# on an ask queue tick, only clear buy entries
case {
'price': tick_price,
'type': 'ask',
}:
client.last_ask = ( client.last_ask = (
tick_price, tick_price,
tick.get('size', client.last_ask[1]), tick.get('size', client.last_ask[1]),
) )
iter_entries = zip( orders = client._buys.get(sym, {})
iter_buys,
itertools.repeat(buy_on_ask) book_sequence = reversed(
) sorted(orders.keys(), key=itemgetter(1)))
def pred(our_price):
return tick_price < our_price
elif ttype in ('bid',):
# on a bid queue tick, only clear sell entries
case {
'price': tick_price,
'type': 'bid',
}:
client.last_bid = ( client.last_bid = (
tick_price, tick_price,
tick.get('size', client.last_bid[1]), tick.get('size', client.last_bid[1]),
) )
iter_entries = zip( orders = client._sells.get(sym, {})
iter_sells, book_sequence = sorted(orders.keys(), key=itemgetter(1))
itertools.repeat(sell_on_bid)
)
# TODO: fix this block, though it definitely def pred(our_price):
# costs a lot more CPU-wise return tick_price > our_price
# - doesn't seem like clears are happening still on
# "resting" limit orders?
case {
'price': tick_price,
'type': ('trade' | 'last'),
}:
# in the clearing price / last price case we
# want to iterate both sides of our book for
# clears since we don't know which direction the
# price is going to move (especially with HFT)
# and thus we simply interleave both sides (buys
# and sells) until one side clears and then
# break until the next tick?
def interleave():
for pair in zip(
iter_buys,
iter_sells,
):
for order_info, pred in zip(
pair,
itertools.cycle([buy_on_ask, sell_on_bid]),
):
yield order_info, pred
iter_entries = interleave() elif ttype in ('trade', 'last'):
# TODO: simulate actual book queues and our orders
# NOTE: all other (non-clearable) tick event types # place in it, might require full L2 data?
# - we don't want to spin the simulated clear loop
# below unnecessarily and further don't want to pop
# simulated live orders prematurely.
case _:
continue continue
# iterate all potentially clearable book prices # iterate book prices descending
# in FIFO order per side. for oid, our_price in book_sequence:
for order_info, pred in iter_entries: if pred(our_price):
(our_price, size, reqid, action) = order_info
# print(order_info) # retrieve order info
clearable = pred(our_price) (size, reqid, action) = orders.pop((oid, our_price))
if clearable:
# pop and retrieve order info
oid = {
'buy': buys,
'sell': sells
}[action].inverse.pop(order_info)
# clearing price would have filled entirely # clearing price would have filled entirely
await client.fake_fill( await client.fake_fill(
fqsn=sym, symbol=sym,
# todo slippage to determine fill price # todo slippage to determine fill price
price=tick_price, price=tick_price,
size=size, size=size,
@ -428,6 +368,9 @@ async def simulate_fills(
reqid=reqid, reqid=reqid,
oid=oid, oid=oid,
) )
else:
# prices are iterated in sorted order so we're done
break
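# NOTE: the clearing predicates above in isolation (a sketch with
# invented book data): an ``ask`` tick may only clear resting buys
# priced at-or-above it.
resting_buys = {
    'a': (42_100.0, 1.0, 7, 'buy'),  # oid -> (price, size, reqid, action)
    'b': (41_900.0, 1.0, 8, 'buy'),
}
tick_price = 42_000.0  # incoming ask

cleared = [
    oid for oid, (our_price, *_) in resting_buys.items()
    if tick_price <= our_price  # ``buy_on_ask()``
]
assert cleared == ['a']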
async def handle_order_requests( async def handle_order_requests(
@ -437,83 +380,68 @@ async def handle_order_requests(
) -> None: ) -> None:
request_msg: dict # order_request: dict
async for request_msg in ems_order_stream: async for request_msg in ems_order_stream:
match request_msg:
case {'action': ('buy' | 'sell')}:
order = BrokerdOrder(**request_msg)
account = order.account
# error on bad inputs action = request_msg['action']
reason = None
if action in {'buy', 'sell'}:
account = request_msg['account']
if account != 'paper': if account != 'paper':
reason = f'No account found:`{account}` (paper only)?' log.error(
'This is a paper account, only a `paper` selection is valid'
elif order.size == 0: )
reason = 'Invalid size: 0'
if reason:
log.error(reason)
await ems_order_stream.send(BrokerdError( await ems_order_stream.send(BrokerdError(
oid=order.oid, oid=request_msg['oid'],
symbol=order.symbol, symbol=request_msg['symbol'],
reason=reason, reason=f'Paper only. No account found: `{account}` ?',
)) ).dict())
continue continue
reqid = order.reqid or str(uuid.uuid4()) # validate
order = BrokerdOrder(**request_msg)
# deliver ack that order has been submitted to broker routing
await ems_order_stream.send(
BrokerdOrderAck(
oid=order.oid,
reqid=reqid,
)
)
# call our client api to submit the order # call our client api to submit the order
reqid = await client.submit_limit( reqid = await client.submit_limit(
oid=order.oid, oid=order.oid,
symbol=f'{order.symbol}.{client.broker}', symbol=order.symbol,
price=order.price, price=order.price,
action=order.action, action=order.action,
size=order.size, size=order.size,
# XXX: by default 0 tells ``ib_insync`` methods that # XXX: by default 0 tells ``ib_insync`` methods that
# there is no existing order so ask the client to create # there is no existing order so ask the client to create
# a new one (which it seems to do by allocating an int # a new one (which it seems to do by allocating an int
# counter - collision prone..) # counter - collision prone..)
reqid=reqid, reqid=order.reqid,
) )
log.info(f'Submitted paper LIMIT {reqid}:\n{order}')
case {'action': 'cancel'}: # deliver ack that order has been submitted to broker routing
await ems_order_stream.send(
BrokerdOrderAck(
# ems order request id
oid=order.oid,
# broker specific request id
reqid=reqid,
).dict()
)
elif action == 'cancel':
msg = BrokerdCancel(**request_msg) msg = BrokerdCancel(**request_msg)
await client.submit_cancel( await client.submit_cancel(
reqid=msg.reqid reqid=msg.reqid
) )
case _: else:
log.error(f'Unknown order command: {request_msg}') log.error(f'Unknown order command: {request_msg}')
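# NOTE: the ``match`` dispatch above distilled into a standalone sketch
# (requires Python 3.10+); msgs without a recognized ``action`` fall
# through to the wildcard case.
def classify(request_msg: dict) -> str:
    match request_msg:
        case {'action': ('buy' | 'sell')}:
            return 'submit'
        case {'action': 'cancel'}:
            return 'cancel'
        case _:
            return 'unknown'

assert classify({'action': 'buy', 'size': 1.0}) == 'submit'
assert classify({'action': 'cancel', 'reqid': 7}) == 'cancel'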
_reqids: bidict[str, tuple] = {}
_buys: defaultdict[
str, # symbol
bidict[
str, # oid
tuple[float, float, str, str], # order info
]
] = defaultdict(bidict)
_sells: defaultdict[
str, # symbol
bidict[
str, # oid
tuple[float, float, str, str], # order info
]
] = defaultdict(bidict)
_positions: dict[str, Position] = {}
@tractor.context @tractor.context
async def trades_dialogue( async def trades_dialogue(
@ -523,62 +451,42 @@ async def trades_dialogue(
loglevel: str = None, loglevel: str = None,
) -> None: ) -> None:
tractor.log.get_console_log(loglevel) tractor.log.get_console_log(loglevel)
async with ( async with (
data.open_feed( data.open_feed(
[fqsn], [fqsn],
loglevel=loglevel, loglevel=loglevel,
) as feed, ) as feed,
): ):
pp_msgs: list[BrokerdPosition] = []
pos: Position
token: str # f'{symbol}.{self.broker}'
for token, pos in _positions.items():
pp_msgs.append(BrokerdPosition(
broker=broker,
account='paper',
symbol=pos.symbol.front_fqsn(),
size=pos.size,
avg_price=pos.ppu,
))
# TODO: load paper positions per broker from .toml config file # TODO: load paper positions per broker from .toml config file
# and pass as symbol to position data mapping: ``dict[str, dict]`` # and pass as symbol to position data mapping: ``dict[str, dict]``
await ctx.started(( # await ctx.started(all_positions)
pp_msgs, await ctx.started(({}, {'paper',}))
['paper'],
))
async with ( async with (
ctx.open_stream() as ems_stream, ctx.open_stream() as ems_stream,
trio.open_nursery() as n, trio.open_nursery() as n,
): ):
client = PaperBoi( client = PaperBoi(
broker, broker,
ems_stream, ems_stream,
_buys=_buys, _buys={},
_sells=_sells, _sells={},
_reqids=_reqids, _reqids={},
# TODO: load paper positions from ``positions.toml`` # TODO: load paper positions from ``positions.toml``
_positions=_positions, _positions={},
# TODO: load positions from ledger file
_trade_ledger={},
) )
n.start_soon( n.start_soon(handle_order_requests, client, ems_stream)
handle_order_requests,
client,
ems_stream,
)
# paper engine simulator clearing task # paper engine simulator clearing task
await simulate_fills(feed.streams[broker], client) await simulate_fills(feed.stream, client)
@asynccontextmanager @asynccontextmanager
@ -603,7 +511,6 @@ async def open_paperboi(
# (we likely don't need more then one proc for basic # (we likely don't need more then one proc for basic
# simulated order clearing) # simulated order clearing)
if portal is None: if portal is None:
log.info('Starting new paper-engine actor')
portal = await tn.start_actor( portal = await tn.start_actor(
service_name, service_name,
enable_modules=[__name__] enable_modules=[__name__]
@ -616,4 +523,5 @@ async def open_paperboi(
loglevel=loglevel, loglevel=loglevel,
) as (ctx, first): ) as (ctx, first):
yield ctx, first yield ctx, first

View File

@ -27,35 +27,25 @@ import tractor
from ..log import get_console_log, get_logger, colorize_json from ..log import get_console_log, get_logger, colorize_json
from ..brokers import get_brokermod from ..brokers import get_brokermod
from .._daemon import ( from .._daemon import _tractor_kwargs
_default_registry_host,
_default_registry_port,
)
from .. import config from .. import config
log = get_logger('cli') log = get_logger('cli')
DEFAULT_BROKER = 'questrade'
@click.command() @click.command()
@click.option('--loglevel', '-l', default='warning', help='Logging level') @click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option('--tl', is_flag=True, help='Enable tractor logging') @click.option('--tl', is_flag=True, help='Enable tractor logging')
@click.option('--pdb', is_flag=True, help='Enable tractor debug mode') @click.option('--pdb', is_flag=True, help='Enable tractor debug mode')
@click.option('--host', '-h', default=None, help='Host addr to bind') @click.option('--host', '-h', default='127.0.0.1', help='Host address to bind')
@click.option('--port', '-p', default=None, help='Port number to bind')
@click.option( @click.option(
'--tsdb', '--tsdb',
is_flag=True, is_flag=True,
help='Enable local ``marketstore`` instance' help='Enable local ``marketstore`` instance'
) )
def pikerd( def pikerd(loglevel, host, tl, pdb, tsdb):
loglevel: str,
host: str,
port: int,
tl: bool,
pdb: bool,
tsdb: bool,
):
''' '''
Spawn the piker broker-daemon. Spawn the piker broker-daemon.
@ -72,21 +62,12 @@ def pikerd(
"\n" "\n"
)) ))
reg_addr: None | tuple[str, int] = None
if host or port:
reg_addr = (
host or _default_registry_host,
int(port) or _default_registry_port,
)
async def main(): async def main():
async with ( async with (
open_pikerd( open_pikerd(
loglevel=loglevel, loglevel=loglevel,
debug_mode=pdb, debug_mode=pdb,
registry_addr=reg_addr,
), # normally delivers a ``Services`` handle ), # normally delivers a ``Services`` handle
trio.open_nursery() as n, trio.open_nursery() as n,
): ):
@ -102,9 +83,9 @@ def pikerd(
) )
log.info( log.info(
f'`marketstored` up!\n' f'`marketstore` up!\n'
f'pid: {pid}\n' f'`marketstored` pid: {pid}\n'
f'container id: {cid[:12]}\n' f'docker container id: {cid}\n'
f'config: {pformat(config)}' f'config: {pformat(config)}'
) )
@ -116,46 +97,25 @@ def pikerd(
@click.group(context_settings=config._context_defaults) @click.group(context_settings=config._context_defaults)
@click.option( @click.option(
'--brokers', '-b', '--brokers', '-b',
default=None, default=[DEFAULT_BROKER],
multiple=True, multiple=True,
help='Broker backend to use' help='Broker backend to use'
) )
@click.option('--loglevel', '-l', default='warning', help='Logging level') @click.option('--loglevel', '-l', default='warning', help='Logging level')
@click.option('--tl', is_flag=True, help='Enable tractor logging') @click.option('--tl', is_flag=True, help='Enable tractor logging')
@click.option('--configdir', '-c', help='Configuration directory') @click.option('--configdir', '-c', help='Configuration directory')
@click.option('--host', '-h', default=None, help='Host addr to bind')
@click.option('--port', '-p', default=None, help='Port number to bind')
@click.pass_context @click.pass_context
def cli( def cli(ctx, brokers, loglevel, tl, configdir):
ctx: click.Context,
brokers: list[str],
loglevel: str,
tl: bool,
configdir: str,
host: str,
port: int,
) -> None:
if configdir is not None: if configdir is not None:
assert os.path.isdir(configdir), f"`{configdir}` is not a valid path" assert os.path.isdir(configdir), f"`{configdir}` is not a valid path"
config._override_config_dir(configdir) config._override_config_dir(configdir)
ctx.ensure_object(dict) ctx.ensure_object(dict)
if not brokers: if len(brokers) == 1:
# (try to) load all (supposedly) supported data/broker backends brokermods = [get_brokermod(brokers[0])]
from piker.brokers import __brokers__ else:
brokers = __brokers__
brokermods = [get_brokermod(broker) for broker in brokers] brokermods = [get_brokermod(broker) for broker in brokers]
assert brokermods
reg_addr: None | tuple[str, int] = None
if host or port:
reg_addr = (
host or _default_registry_host,
int(port) or _default_registry_port,
)
ctx.obj.update({ ctx.obj.update({
'brokers': brokers, 'brokers': brokers,
@ -165,7 +125,6 @@ def cli(
'log': get_console_log(loglevel), 'log': get_console_log(loglevel),
'confdir': config._config_dir, 'confdir': config._config_dir,
'wl_path': config._watchlists_data_path, 'wl_path': config._watchlists_data_path,
'registry_addr': reg_addr,
}) })
# allow enabling same loglevel in ``tractor`` machinery # allow enabling same loglevel in ``tractor`` machinery
@ -175,40 +134,29 @@ def cli(
@cli.command() @cli.command()
@click.option('--tl', is_flag=True, help='Enable tractor logging') @click.option('--tl', is_flag=True, help='Enable tractor logging')
@click.argument('ports', nargs=-1, required=False) @click.argument('names', nargs=-1, required=False)
@click.pass_obj @click.pass_obj
def services(config, tl, ports): def services(config, tl, names):
from .._daemon import (
open_piker_runtime,
_default_registry_port,
_default_registry_host,
)
host = _default_registry_host
if not ports:
ports = [_default_registry_port]
async def list_services(): async def list_services():
nonlocal host
async with ( async with tractor.get_arbiter(
open_piker_runtime( *_tractor_kwargs['arbiter_addr']
name='service_query', ) as portal:
loglevel=config['loglevel'] if tl else None,
),
tractor.get_arbiter(
host=host,
port=ports[0]
) as portal
):
registry = await portal.run_from_ns('self', 'get_registry') registry = await portal.run_from_ns('self', 'get_registry')
json_d = {} json_d = {}
for key, socket in registry.items(): for key, socket in registry.items():
# name, uuid = uid
host, port = socket host, port = socket
json_d[key] = f'{host}:{port}' json_d[key] = f'{host}:{port}'
click.echo(f"{colorize_json(json_d)}") click.echo(f"{colorize_json(json_d)}")
trio.run(list_services) tractor.run(
list_services,
name='service_query',
loglevel=config['loglevel'] if tl else None,
arbiter_addr=_tractor_kwargs['arbiter_addr'],
)
def _load_clis() -> None: def _load_clis() -> None:

View File

@ -21,7 +21,6 @@ Broker configuration mgmt.
import platform import platform
import sys import sys
import os import os
from os import path
from os.path import dirname from os.path import dirname
import shutil import shutil
from typing import Optional from typing import Optional
@ -112,7 +111,6 @@ if _parent_user:
_conf_names: set[str] = { _conf_names: set[str] = {
'brokers', 'brokers',
'pps',
'trades', 'trades',
'watchlists', 'watchlists',
} }
@ -149,21 +147,19 @@ def get_conf_path(
conf_name: str = 'brokers', conf_name: str = 'brokers',
) -> str: ) -> str:
''' """Return the default config path normally under
Return the top-level default config path normally under ``~/.config/piker`` on linux.
``~/.config/piker`` on linux for a given ``conf_name``, the config
name.
Contains files such as: Contains files such as:
- brokers.toml - brokers.toml
- pp.toml
- watchlists.toml - watchlists.toml
- trades.toml
# maybe coming soon ;) # maybe coming soon ;)
- signals.toml - signals.toml
- strats.toml - strats.toml
''' """
assert conf_name in _conf_names assert conf_name in _conf_names
fn = _conf_fn_w_ext(conf_name) fn = _conf_fn_w_ext(conf_name)
return os.path.join( return os.path.join(
@ -177,7 +173,7 @@ def repodir():
Return the abspath to the repo directory. Return the abspath to the repo directory.
''' '''
dirpath = path.abspath( dirpath = os.path.abspath(
# we're 3 levels down in **this** module file # we're 3 levels down in **this** module file
dirname(dirname(os.path.realpath(__file__))) dirname(dirname(os.path.realpath(__file__)))
) )
@ -186,9 +182,7 @@ def repodir():
def load( def load(
conf_name: str = 'brokers', conf_name: str = 'brokers',
path: str = None, path: str = None
**tomlkws,
) -> (dict, str): ) -> (dict, str):
''' '''
@ -196,10 +190,6 @@ def load(
''' '''
path = path or get_conf_path(conf_name) path = path or get_conf_path(conf_name)
if not os.path.isdir(_config_dir):
os.mkdir(_config_dir)
if not os.path.isfile(path): if not os.path.isfile(path):
fn = _conf_fn_w_ext(conf_name) fn = _conf_fn_w_ext(conf_name)
@ -212,11 +202,8 @@ def load(
# if one exists. # if one exists.
if os.path.isfile(template): if os.path.isfile(template):
shutil.copyfile(template, path) shutil.copyfile(template, path)
else:
with open(path, 'r'):
pass # touch it
config = toml.load(path, **tomlkws) config = toml.load(path)
log.debug(f"Read config file {path}") log.debug(f"Read config file {path}")
return config, path return config, path
@ -225,7 +212,6 @@ def write(
config: dict, # toml config as dict config: dict, # toml config as dict
name: str = 'brokers', name: str = 'brokers',
path: str = None, path: str = None,
**toml_kwargs,
) -> None: ) -> None:
''' '''
@ -249,14 +235,11 @@ def write(
f"{path}" f"{path}"
) )
with open(path, 'w') as cf: with open(path, 'w') as cf:
return toml.dump( return toml.dump(config, cf)
config,
cf,
**toml_kwargs,
)
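# NOTE: a hedged read-modify-write round-trip through the helpers
# above; the 'kraken'/'key_descr' section and key names are invented
# for illustration.
from piker import config

conf, path = config.load('brokers')
conf.setdefault('kraken', {})['key_descr'] = 'api_0'
config.write(conf, name='brokers', path=path)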
def load_accounts( def load_accounts(
providers: Optional[list[str]] = None providers: Optional[list[str]] = None
) -> bidict[str, Optional[str]]: ) -> bidict[str, Optional[str]]:

View File

@ -22,12 +22,6 @@ and storing data from your brokers as well as
sharing live streams over a network. sharing live streams over a network.
""" """
import tractor
import trio
from ..log import (
get_console_log,
)
from ._normalize import iterticks from ._normalize import iterticks
from ._sharedmem import ( from ._sharedmem import (
maybe_open_shm_array, maybe_open_shm_array,
@ -38,6 +32,7 @@ from ._sharedmem import (
) )
from .feed import ( from .feed import (
open_feed, open_feed,
_setup_persistent_brokerd,
) )
@ -49,40 +44,5 @@ __all__ = [
'attach_shm_array', 'attach_shm_array',
'open_shm_array', 'open_shm_array',
'get_shm_token', 'get_shm_token',
'_setup_persistent_brokerd',
] ]
@tractor.context
async def _setup_persistent_brokerd(
ctx: tractor.Context,
brokername: str,
) -> None:
'''
Allocate a actor-wide service nursery in ``brokerd``
such that feeds can be run in the background persistently by
the broker backend as needed.
'''
get_console_log(tractor.current_actor().loglevel)
from .feed import (
_bus,
get_feed_bus,
)
global _bus
assert not _bus
async with trio.open_nursery() as service_nursery:
# assign a nursery to the feeds bus for spawning
# background tasks from clients
get_feed_bus(brokername, service_nursery)
# unblock caller
await ctx.started()
# we pin this task to keep the feeds manager active until the
# parent actor decides to tear it down
await trio.sleep_forever()

View File

@ -37,13 +37,8 @@ from docker.models.containers import Container as DockerContainer
from docker.errors import ( from docker.errors import (
DockerException, DockerException,
APIError, APIError,
# ContainerError,
)
import requests
from requests.exceptions import (
ConnectionError,
ReadTimeout,
) )
from requests.exceptions import ConnectionError, ReadTimeout
from ..log import get_logger, get_console_log from ..log import get_logger, get_console_log
from .. import config from .. import config
@ -55,8 +50,8 @@ class DockerNotStarted(Exception):
'Prolly you dint start da daemon bruh' 'Prolly you dint start da daemon bruh'
class ApplicationLogError(Exception): class ContainerError(RuntimeError):
'App in container reported an error in logs' 'Error reported via app-container logging level'
@acm @acm
@ -101,9 +96,9 @@ async def open_docker(
# not perms? # not perms?
raise raise
# finally: finally:
# if client: if client:
# client.close() client.close()
class Container: class Container:
@ -161,7 +156,7 @@ class Container:
# print(f'level: {level}') # print(f'level: {level}')
if level in ('error', 'fatal'): if level in ('error', 'fatal'):
raise ApplicationLogError(msg) raise ContainerError(msg)
if patt in msg: if patt in msg:
return True return True
@ -190,29 +185,12 @@ class Container:
if 'is not running' in err.explanation: if 'is not running' in err.explanation:
return False return False
def hard_kill(self, start: float) -> None:
delay = time.time() - start
# get out the big guns, bc apparently marketstore
# doesn't actually know how to terminate gracefully
# :eyeroll:...
log.error(
f'SIGKILL-ing: {self.cntr.id} after {delay}s\n'
)
self.try_signal('SIGKILL')
self.cntr.wait(
timeout=3,
condition='not-running',
)
async def cancel( async def cancel(
self, self,
stop_msg: str, stop_msg: str,
hard_kill: bool = False,
) -> None: ) -> None:
cid = self.cntr.id cid = self.cntr.id
# first try a graceful cancel # first try a graceful cancel
log.cancel( log.cancel(
f'SIGINT cancelling container: {cid}\n' f'SIGINT cancelling container: {cid}\n'
@ -221,25 +199,15 @@ class Container:
self.try_signal('SIGINT') self.try_signal('SIGINT')
start = time.time() start = time.time()
for _ in range(6): for _ in range(30):
with trio.move_on_after(0.5) as cs: with trio.move_on_after(0.5) as cs:
log.cancel('polling for CNTR logs...') cs.shield = True
try:
await self.process_logs_until(stop_msg) await self.process_logs_until(stop_msg)
except ApplicationLogError:
hard_kill = True
else:
# if we aren't cancelled on above checkpoint then we
# assume we read the expected stop msg and
# terminated.
break
if cs.cancelled_caught: # if we aren't cancelled on above checkpoint then we
# on timeout just try a hard kill after # assume we read the expected stop msg and terminated.
# a quick container sync-wait. break
hard_kill = True
try: try:
log.info(f'Polling for container shutdown:\n{cid}') log.info(f'Polling for container shutdown:\n{cid}')
@ -250,7 +218,6 @@ class Container:
condition='not-running', condition='not-running',
) )
# graceful exit if we didn't time out
break break
except ( except (
@ -262,22 +229,24 @@ class Container:
except ( except (
docker.errors.APIError, docker.errors.APIError,
ConnectionError, ConnectionError,
requests.exceptions.ConnectionError,
trio.Cancelled,
): ):
log.exception('Docker connection failure') log.exception('Docker connection failure')
self.hard_kill(start) break
raise
except trio.Cancelled:
log.exception('trio cancelled...')
self.hard_kill(start)
else: else:
hard_kill = True delay = time.time() - start
log.error(
f'Failed to kill container {cid} after {delay}s\n'
'sending SIGKILL..'
)
# get out the big guns, bc apparently marketstore
# doesn't actually know how to terminate gracefully
# :eyeroll:...
self.try_signal('SIGKILL')
self.cntr.wait(
timeout=3,
condition='not-running',
)
if hard_kill:
self.hard_kill(start)
else:
log.cancel(f'Container stopped: {cid}') log.cancel(f'Container stopped: {cid}')
@ -320,12 +289,14 @@ async def open_ahabd(
)) ))
try: try:
# TODO: we might eventually want a proxy-style msg-prot here # TODO: we might eventually want a proxy-style msg-prot here
# to allow remote control of containers without needing # to allow remote control of containers without needing
# callers to have root perms? # callers to have root perms?
await trio.sleep_forever() await trio.sleep_forever()
finally: finally:
with trio.CancelScope(shield=True):
await cntr.cancel(stop_msg) await cntr.cancel(stop_msg)

View File

@ -56,7 +56,7 @@ def iterticks(
sig = ( sig = (
time, time,
tick['price'], tick['price'],
tick.get('size') tick['size']
) )
if ttype == 'dark_trade': if ttype == 'dark_trade':

View File

@ -20,96 +20,53 @@ financial data flows.
""" """
from __future__ import annotations from __future__ import annotations
from collections import ( from collections import Counter
Counter,
defaultdict,
)
from contextlib import asynccontextmanager as acm
import time import time
from typing import ( from typing import TYPE_CHECKING, Optional, Union
AsyncIterator,
TYPE_CHECKING,
)
import tractor import tractor
from tractor.trionics import (
maybe_open_nursery,
)
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
from ..log import ( from ..log import get_logger
get_logger,
get_console_log,
)
from .._daemon import maybe_spawn_daemon
if TYPE_CHECKING: if TYPE_CHECKING:
from ._sharedmem import ( from ._sharedmem import ShmArray
ShmArray,
)
from .feed import _FeedsBus from .feed import _FeedsBus
log = get_logger(__name__) log = get_logger(__name__)
# highest frequency sample step is 1 second by default, though in class sampler:
# the future we may want to support shorter periods or a dynamic style
# tick-event stream.
_default_delay_s: float = 1.0
class Sampler:
''' '''
Global sampling engine registry. Global sampling engine registry.
Manages state for sampling events, shm incrementing and Manages state for sampling events, shm incrementing and
sample period logic. sample period logic.
This non-instantiated type is meant to be a singleton within
a `samplerd` actor-service spawned once by the user wishing to
time-step sample real-time quote feeds, see
``._daemon.maybe_open_samplerd()`` and the below
``register_with_sampler()``.
''' '''
service_nursery: None | trio.Nursery = None
# TODO: we could stick these in a composed type to avoid # TODO: we could stick these in a composed type to avoid
# angering the "i hate module scoped variables crowd" (yawn). # angering the "i hate module scoped variables crowd" (yawn).
ohlcv_shms: dict[float, list[ShmArray]] = {} ohlcv_shms: dict[int, list[ShmArray]] = {}
# holds one-task-per-sample-period tasks which are spawned as-needed by # holds one-task-per-sample-period tasks which are spawned as-needed by
# data feed requests with a given detected time step usually from # data feed requests with a given detected time step usually from
# history loading. # history loading.
incr_task_cs: trio.CancelScope | None = None incrementers: dict[int, trio.CancelScope] = {}
# holds all the ``tractor.Context`` remote subscriptions for # holds all the ``tractor.Context`` remote subscriptions for
# a particular sample period increment event: all subscribers are # a particular sample period increment event: all subscribers are
# notified on a step. # notified on a step.
# subscribers: dict[int, list[tractor.MsgStream]] = {} subscribers: dict[int, tractor.Context] = {}
subscribers: defaultdict[
float,
list[
float,
set[tractor.MsgStream]
],
] = defaultdict(
lambda: [
round(time.time()),
set(),
]
)
@classmethod
async def increment_ohlc_buffer( async def increment_ohlc_buffer(
self, delay_s: int,
period_s: float,
task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[trio.CancelScope] = trio.TASK_STATUS_IGNORED,
): ):
''' '''
Task which inserts new bars into the provided shared memory array Task which inserts new bars into the provided shared memory array
every ``period_s`` seconds. every ``delay_s`` seconds.
This task fulfills 2 purposes: This task fulfills 2 purposes:
- it takes the subscribed set of shm arrays and increments them - it takes the subscribed set of shm arrays and increments them
@ -121,143 +78,100 @@ class Sampler:
the underlying buffers will actually be incremented. the underlying buffers will actually be incremented.
''' '''
# # wait for brokerd to signal we should start sampling
# await shm_incrementing(shm_token['shm_name']).wait()
# TODO: right now we'll spin printing bars if the last time stamp is # TODO: right now we'll spin printing bars if the last time stamp is
# before a large period of no market activity. Likely the best way # before a large period of no market activity. Likely the best way
# to solve this is to make this task aware of the instrument's # to solve this is to make this task aware of the instrument's
# tradable hours? # tradable hours?
total_s: float = 0 # total seconds counted # adjust delay to compensate for trio processing time
ad = period_s - 0.001 # compensate for trio processing time ad = min(sampler.ohlcv_shms.keys()) - 0.001
total_s = 0 # total seconds counted
lowest = min(sampler.ohlcv_shms.keys())
lowest_shm = sampler.ohlcv_shms[lowest][0]
ad = lowest - 0.001
with trio.CancelScope() as cs: with trio.CancelScope() as cs:
# register this time period step as active # register this time period step as active
sampler.incrementers[delay_s] = cs
task_status.started(cs) task_status.started(cs)
# sample step loop:
# includes broadcasting to all connected consumers on every
# new sample step as well incrementing any registered
# buffers by registered sample period.
while True: while True:
# TODO: do we want to support dynamically
# adding a "lower" lowest increment period?
await trio.sleep(ad) await trio.sleep(ad)
total_s += period_s total_s += lowest
# increment all subscribed shm arrays # increment all subscribed shm arrays
# TODO: # TODO:
# - this in ``numba`` # - this in ``numba``
# - just lookup shms for this step instead of iterating? # - just lookup shms for this step instead of iterating?
for delay_s, shms in sampler.ohlcv_shms.items():
i_epoch = round(time.time()) if total_s % delay_s != 0:
broadcasted: set[float] = set()
# print(f'epoch: {i_epoch} -> REGISTRY {self.ohlcv_shms}')
for shm_period_s, shms in self.ohlcv_shms.items():
# short-circuit on any not-ready because slower sample
# rate consuming shm buffers.
if total_s % shm_period_s != 0:
# print(f'skipping `{shm_period_s}s` sample update')
continue continue
# update last epoch stamp for this period group
if shm_period_s not in broadcasted:
sub_pair = self.subscribers[shm_period_s]
sub_pair[0] = i_epoch
broadcasted.add(shm_period_s)
# TODO: ``numba`` this! # TODO: ``numba`` this!
for shm in shms: for shm in shms:
# print(f'UPDATE {shm_period_s}s STEP for {shm.token}') # TODO: in theory we could make this faster by copying the
# "last" readable value into the underlying larger buffer's
# next value and then incrementing the counter instead of
# using ``.push()``?
# append new entry to buffer thus "incrementing" # append new entry to buffer thus "incrementing" the bar
# the bar
array = shm.array array = shm.array
last = array[-1:][shm._write_fields].copy() last = array[-1:][shm._write_fields].copy()
# (index, t, close) = last[0][['index', 'time', 'close']]
(t, close) = last[0][['time', 'close']]
# guard against startup backfilling races where # this copies non-std fields (eg. vwap) from the last datum
# the buffer has not yet been filled. last[
if not last.size: ['time', 'volume', 'open', 'high', 'low', 'close']
continue ][0] = (t + delay_s, 0, close, close, close, close)
(t, close) = last[0][[
'time',
'close',
]]
next_t = t + shm_period_s
if shm_period_s <= 1:
next_t = i_epoch
# this copies non-std fields (eg. vwap) from the
# last datum
last[[
'time',
'open',
'high',
'low',
'close',
'volume',
]][0] = (
# epoch timestamp
next_t,
# OHLC
close,
close,
close,
close,
0, # vlm
)
# TODO: in theory we could make this faster by
# copying the "last" readable value into the
# underlying larger buffer's next value and then
# incrementing the counter instead of using
# ``.push()``?
# write to the buffer # write to the buffer
shm.push(last) shm.push(last)
# broadcast increment msg to all updated subs per period await broadcast(delay_s, shm=lowest_shm)
for shm_period_s in broadcasted:
await self.broadcast(
period_s=shm_period_s,
time_stamp=i_epoch,
)
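# NOTE: the modulo gating above restated as a pure function (a sketch):
# with a 1s base step, which registered sample periods are due at an
# elapsed ``total_s``?
def due_periods(total_s: int, periods: list[float]) -> list[float]:
    return [p for p in periods if total_s % p == 0]

assert due_periods(60, [1, 5, 60]) == [1, 5, 60]
assert due_periods(61, [1, 5, 60]) == [1]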
@classmethod
async def broadcast(
self,
period_s: float,
time_stamp: float | None = None,
) -> None: async def broadcast(
delay_s: int,
shm: Optional[ShmArray] = None,
) -> None:
''' '''
Broadcast the period size and last index step value to all Broadcast the given ``shm: ShmArray``'s buffer index step to any
subscribers for a given sample period. subscribers for a given sample period.
The sent msg will include the first and last index which slice into
the buffer's non-empty data.
''' '''
pair = self.subscribers[period_s] subs = sampler.subscribers.get(delay_s, ())
last_ts, subs = pair first = last = -1
if shm is None:
periods = sampler.ohlcv_shms.keys()
# if this is an update triggered by a history update there
# might not actually be any sampling bus setup since there's
# no "live feed" active yet.
if periods:
lowest = min(periods)
shm = sampler.ohlcv_shms[lowest][0]
first = shm._first.value
last = shm._last.value
task = trio.lowlevel.current_task()
log.debug(
f'SUBS {self.subscribers}\n'
f'PAIR {pair}\n'
f'TASK: {task}: {id(task)}\n'
f'broadcasting {period_s} -> {last_ts}\n'
# f'consumers: {subs}'
)
borked: set[tractor.MsgStream] = set()
for stream in subs: for stream in subs:
try: try:
await stream.send({ await stream.send({
'index': time_stamp or last_ts, 'first': first,
'period': period_s, 'last': last,
'index': last,
}) })
except ( except (
trio.BrokenResourceError, trio.BrokenResourceError,
@ -266,9 +180,6 @@ class Sampler:
log.error( log.error(
f'{stream._ctx.chan.uid} dropped connection' f'{stream._ctx.chan.uid} dropped connection'
) )
borked.add(stream)
for stream in borked:
try: try:
subs.remove(stream) subs.remove(stream)
except ValueError: except ValueError:
@ -276,255 +187,53 @@ class Sampler:
f'{stream._ctx.chan.uid} sub already removed!?' f'{stream._ctx.chan.uid} sub already removed!?'
) )
@classmethod
async def broadcast_all(self) -> None:
for period_s in self.subscribers:
await self.broadcast(period_s)
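# NOTE: a sketch of the per-period subscriber registry shape consumed by
# the two broadcast methods above (new, left-column side): each period
# maps to a mutable [last_epoch, set_of_streams] pair.
import time
from collections import defaultdict

subscribers = defaultdict(lambda: [round(time.time()), set()])
last_ts, subs = subscribers[1.0]  # lazily creates the 1s period entry
step_msg = {'index': last_ts, 'period': 1.0}  # what each sub receives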
@tractor.context @tractor.context
async def register_with_sampler( async def iter_ohlc_periods(
ctx: tractor.Context, ctx: tractor.Context,
period_s: float, delay_s: int,
shms_by_period: dict[float, dict] | None = None,
open_index_stream: bool = True, # open a 2way stream for sample step msgs?
sub_for_broadcasts: bool = True, # sampler side to send step updates?
) -> None: ) -> None:
'''
Subscribe to OHLC sampling "step" events: when the time
aggregation period increments, this event stream emits an index
event.
get_console_log(tractor.current_actor().loglevel) '''
incr_was_started: bool = False # add our subscription
subs = sampler.subscribers.setdefault(delay_s, [])
try: await ctx.started()
async with maybe_open_nursery(
Sampler.service_nursery
) as service_nursery:
# init startup, create (actor-)local service nursery and start
# increment task
Sampler.service_nursery = service_nursery
# always ensure a period subs entry exists
last_ts, subs = Sampler.subscribers[float(period_s)]
async with trio.Lock():
if Sampler.incr_task_cs is None:
Sampler.incr_task_cs = await service_nursery.start(
Sampler.increment_ohlc_buffer,
1.,
)
incr_was_started = True
# insert the base 1s period (for OHLC style sampling) into
# the increment buffer set to update and shift every second.
if shms_by_period is not None:
from ._sharedmem import (
attach_shm_array,
_Token,
)
for period in shms_by_period:
# load and register shm handles
shm_token_msg = shms_by_period[period]
shm = attach_shm_array(
_Token.from_msg(shm_token_msg),
readonly=False,
)
shms_by_period[period] = shm
Sampler.ohlcv_shms.setdefault(period, []).append(shm)
assert Sampler.ohlcv_shms
# unblock caller
await ctx.started(set(Sampler.ohlcv_shms.keys()))
if open_index_stream:
try:
async with ctx.open_stream() as stream: async with ctx.open_stream() as stream:
if sub_for_broadcasts: subs.append(stream)
subs.add(stream)
# except broadcast requests from the subscriber try:
async for msg in stream: # stream and block until cancelled
if msg == 'broadcast_all':
await Sampler.broadcast_all()
finally:
if sub_for_broadcasts:
subs.remove(stream)
else:
# if no shms are passed in we just wait until cancelled
# by caller.
await trio.sleep_forever() await trio.sleep_forever()
finally: finally:
# TODO: why tf isn't this working? try:
if shms_by_period is not None: subs.remove(stream)
for period, shm in shms_by_period.items(): except ValueError:
Sampler.ohlcv_shms[period].remove(shm) log.error(
f'iOHLC step stream was already dropped {ctx.chan.uid}?'
if incr_was_started:
Sampler.incr_task_cs.cancel()
Sampler.incr_task_cs = None
async def spawn_samplerd(
loglevel: str | None = None,
**extra_tractor_kwargs
) -> bool:
'''
Daemon-side service task: start a sampling daemon for common step
updates, increment-count writes and stream broadcasting.
'''
from piker._daemon import Services
dname = 'samplerd'
log.info(f'Spawning `{dname}`')
# singleton lock creation of ``samplerd`` since we only ever want
# one daemon per ``pikerd`` proc tree.
# TODO: make this built-into the service api?
async with Services.locks[dname + '_singleton']:
if dname not in Services.service_tasks:
portal = await Services.actor_n.start_actor(
dname,
enable_modules=[
'piker.data._sampling',
],
loglevel=loglevel,
debug_mode=Services.debug_mode, # set by pikerd flag
**extra_tractor_kwargs
) )
await Services.start_service_task(
dname,
portal,
register_with_sampler,
period_s=1,
sub_for_broadcasts=False,
)
return True
return False
@acm
async def maybe_open_samplerd(
loglevel: str | None = None,
**kwargs,
) -> tractor._portal.Portal: # noqa
'''
Client-side helper to maybe startup the ``samplerd`` service
under the ``pikerd`` tree.
'''
dname = 'samplerd'
async with maybe_spawn_daemon(
dname,
service_task_target=spawn_samplerd,
spawn_args={'loglevel': loglevel},
loglevel=loglevel,
**kwargs,
) as portal:
yield portal
@acm
async def open_sample_stream(
period_s: float,
shms_by_period: dict[float, dict] | None = None,
open_index_stream: bool = True,
sub_for_broadcasts: bool = True,
cache_key: str | None = None,
allow_new_sampler: bool = True,
) -> AsyncIterator[dict[str, float]]:
'''
Subscribe to OHLC sampling "step" events: when the time aggregation
period increments, this event stream emits an index event.
This is a client-side endpoint that does all the work of ensuring
the `samplerd` actor is up and that multi-consumer tasks are given
a broadcast stream when possible.
'''
# TODO: wrap this manager with the following to make it cached
# per client-multitasks entry.
# maybe_open_context(
# acm_func=partial(
# portal.open_context,
# register_with_sampler,
# ),
# key=cache_key or period_s,
# )
# if cache_hit:
# # add a new broadcast subscription for the quote stream
# # if this feed is likely already in use
# async with istream.subscribe() as bistream:
# yield bistream
# else:
async with (
# XXX: this should be singleton on a host,
# a lone broker-daemon per provider should be
# created for all practical purposes
maybe_open_samplerd() as portal,
portal.open_context(
register_with_sampler,
**{
'period_s': period_s,
'shms_by_period': shms_by_period,
'open_index_stream': open_index_stream,
'sub_for_broadcasts': sub_for_broadcasts,
},
) as (ctx, first)
):
async with (
ctx.open_stream() as istream,
# TODO: we don't need this task-bcasting right?
# istream.subscribe() as istream,
):
yield istream
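# NOTE: a hedged consumer-side usage sketch for the endpoint above; the
# msg fields follow ``Sampler.broadcast()`` and the task name is made up.
async def watch_steps():
    async with open_sample_stream(period_s=1.) as istream:
        async for msg in istream:
            print(msg['period'], msg['index'])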
async def sample_and_broadcast( async def sample_and_broadcast(
bus: _FeedsBus, # noqa bus: _FeedsBus, # noqa
rt_shm: ShmArray, shm: ShmArray,
hist_shm: ShmArray,
quote_stream: trio.abc.ReceiveChannel, quote_stream: trio.abc.ReceiveChannel,
brokername: str, brokername: str,
sum_tick_vlm: bool = True, sum_tick_vlm: bool = True,
) -> None: ) -> None:
'''
`brokerd`-side task which writes the latest sampled datum to shm.
This task is meant to run in the same actor (mem space) as the
`brokerd` real-time quote feed which is being sampled to
a ``ShmArray`` buffer.
'''
log.info("Started shared mem bar writer") log.info("Started shared mem bar writer")
overruns = Counter() overruns = Counter()
# iterate stream delivered by broker # iterate stream delivered by broker
async for quotes in quote_stream: async for quotes in quote_stream:
# print(quotes)
# TODO: ``numba`` this! # TODO: ``numba`` this!
for broker_symbol, quote in quotes.items(): for broker_symbol, quote in quotes.items():
# TODO: in theory you can send the IPC msg *before* writing # TODO: in theory you can send the IPC msg *before* writing
@ -548,9 +257,6 @@ async def sample_and_broadcast(
last = tick['price'] last = tick['price']
# more compact inline-way to do this assignment
# to both buffers?
for shm in [rt_shm, hist_shm]:
# update last entry # update last entry
# benchmarked in the 4-5 us range # benchmarked in the 4-5 us range
o, high, low, v = shm.array[-1][ o, high, low, v = shm.array[-1][
@ -587,29 +293,29 @@ async def sample_and_broadcast(
volume, volume,
) )
# TODO: PUT THIS IN A ``_FeedsBus.broadcast()`` method!
# XXX: we need to be very cautious here that no # XXX: we need to be very cautious here that no
# context-channel is left lingering which doesn't have # context-channel is left lingering which doesn't have
# a far end receiver actor-task. In such a case you can # end up triggering backpressure which will
# end up triggering backpressure which will # eventually block this producer end of the feed and
# eventually block this producer end of the feed and # eventually block this producer end of the feed and
# thus other consumers still attached. # thus other consumers still attached.
sub_key: str = broker_symbol.lower()
subs: list[ subs: list[
tuple[ tuple[
tractor.MsgStream | trio.MemorySendChannel, Union[tractor.MsgStream, trio.MemorySendChannel],
float | None, # tick throttle in Hz tractor.Context,
Optional[float], # tick throttle in Hz
] ]
] = bus.get_subs(sub_key) ] = bus._subscribers[broker_symbol.lower()]
# NOTE: by default the broker backend doesn't append # NOTE: by default the broker backend doesn't append
# it's own "name" into the fqsn schema (but maybe it # it's own "name" into the fqsn schema (but maybe it
# should?) so we have to manually generate the correct # should?) so we have to manually generate the correct
# key here. # key here.
fqsn = f'{broker_symbol}.{brokername}' bsym = f'{broker_symbol}.{brokername}'
lags: int = 0 lags: int = 0
for (stream, tick_throttle) in subs.copy(): for (stream, ctx, tick_throttle) in subs:
try: try:
with trio.move_on_after(0.2) as cs: with trio.move_on_after(0.2) as cs:
if tick_throttle: if tick_throttle:
@ -617,39 +323,47 @@ async def sample_and_broadcast(
# pushes to the ``uniform_rate_send()`` below. # pushes to the ``uniform_rate_send()`` below.
try: try:
stream.send_nowait( stream.send_nowait(
(fqsn, quote) (bsym, quote)
) )
except trio.WouldBlock: except trio.WouldBlock:
overruns[sub_key] += 1
ctx = stream._ctx
chan = ctx.chan chan = ctx.chan
if ctx:
log.warning( log.warning(
f'Feed OVERRUN {sub_key}' f'Feed overrun {bus.brokername} ->'
f'@{bus.brokername} -> \n' f'{chan.uid} !!!'
f'feed @ {chan.uid}\n'
f'throttle = {tick_throttle} Hz'
) )
else:
if overruns[sub_key] > 6: key = id(stream)
overruns[key] += 1
log.warning(
f'Feed overrun {broker_symbol}'
f'@{bus.brokername} -> '
f'feed @ {tick_throttle} Hz'
)
if overruns[key] > 6:
# TODO: should we check for the # TODO: should we check for the
# context being cancelled? this # context being cancelled? this
# could happen but the # could happen but the
# channel-ipc-pipe is still up. # channel-ipc-pipe is still up.
if ( if not chan.connected():
not chan.connected()
or ctx._cancel_called
):
log.warning( log.warning(
'Dropping broken consumer:\n' 'Dropping broken consumer:\n'
f'{sub_key}:' f'{broker_symbol}:'
f'{ctx.cid}@{chan.uid}' f'{ctx.cid}@{chan.uid}'
) )
await stream.aclose() await stream.aclose()
raise trio.BrokenResourceError raise trio.BrokenResourceError
else:
log.warning(
'Feed getting overrun bro!\n'
f'{broker_symbol}:'
f'{ctx.cid}@{chan.uid}'
)
continue
else: else:
await stream.send( await stream.send(
{fqsn: quote} {bsym: quote}
) )
if cs.cancelled_caught: if cs.cancelled_caught:
@ -662,7 +376,6 @@ async def sample_and_broadcast(
trio.ClosedResourceError, trio.ClosedResourceError,
trio.EndOfChannel, trio.EndOfChannel,
): ):
ctx = stream._ctx
chan = ctx.chan chan = ctx.chan
if ctx: if ctx:
log.warning( log.warning(
@ -678,69 +391,20 @@ async def sample_and_broadcast(
# so far seems like no since this should all # so far seems like no since this should all
# be single-threaded. Doing it anyway though # be single-threaded. Doing it anyway though
# since there seems to be some kinda race.. # since there seems to be some kinda race..
bus.remove_subs( try:
sub_key, subs.remove((stream, tick_throttle))
{(stream, tick_throttle)}, except ValueError:
log.error(
f'Stream was already removed from subs!?\n'
f'{broker_symbol}:'
f'{ctx.cid}@{chan.uid}'
) )
# a working tick-type-classes template
_tick_groups = {
'clears': {'trade', 'dark_trade', 'last'},
'bids': {'bid', 'bsize'},
'asks': {'ask', 'asize'},
}
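As a rough illustration of how this template can be consumed (a hypothetical helper, not present in this diff):

def classify_tick(tick: dict) -> str | None:
    # map a tick msg's `type` field to its group name,
    # eg. 'trade' -> 'clears', 'bid' -> 'bids'.
    ttype = tick.get('type')
    for group, types in _tick_groups.items():
        if ttype in types:
            return group
    return None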
def frame_ticks(
first_quote: dict,
last_quote: dict,
ticks_by_type: dict,
) -> None:
# append quotes since last iteration into the last quote's
# tick array/buffer.
ticks = last_quote.get('ticks')
# TODO: once we decide to get fancy really we should
# have a shared mem tick buffer that is just
# continually filled and the UI just reads from it
# at its display rate.
if ticks:
# TODO: do we need this any more or can we just
# expect the receiver to unwind the below
# `ticks_by_type: dict`?
# => unwinding would potentially require a
# `dict[str, set | list]` instead with an
# included `'types' field which is an (ordered)
# set of tick type fields in the order which
# types arrived?
first_quote['ticks'].extend(ticks)
# XXX: build a tick-by-type table of lists
# of tick messages. This allows for less
# iteration on the receiver side by allowing for
# a single "latest tick event" look up by
# indexing the last entry in each sub-list.
# tbt = {
# 'types': ['bid', 'asize', 'last', .. '<type_n>'],
# 'bid': [tick0, tick1, tick2, .., tickn],
# 'asize': [tick0, tick1, tick2, .., tickn],
# 'last': [tick0, tick1, tick2, .., tickn],
# ...
# '<type_n>': [tick0, tick1, tick2, .., tickn],
# }
# append in reverse FIFO order for in-order iteration on
# receiver side.
for tick in ticks:
ttype = tick['type']
ticks_by_type[ttype].append(tick)
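The payoff of the by-type table is the single "latest tick" lookup mentioned in the comments above; a consumer-side sketch (helper name hypothetical):

def latest_tick(
    ticks_by_type: dict[str, list[dict]],
    ttype: str,
) -> dict | None:
    # ticks are appended in FIFO order so the last entry of
    # each sub-list is the most recent event of that type.
    ticks = ticks_by_type.get(ttype)
    return ticks[-1] if ticks else None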
# TODO: a less naive throttler, here's some snippets: # TODO: a less naive throttler, here's some snippets:
# token bucket by njs: # token bucket by njs:
# https://gist.github.com/njsmith/7ea44ec07e901cb78ebe1dd8dd846cb9 # https://gist.github.com/njsmith/7ea44ec07e901cb78ebe1dd8dd846cb9
async def uniform_rate_send( async def uniform_rate_send(
rate: float, rate: float,
@ -751,9 +415,6 @@ async def uniform_rate_send(
) -> None: ) -> None:
# try not to error-out on overruns of the subscribed (chart) client
stream._ctx._backpressure = True
# TODO: compute the approx overhead latency per cycle # TODO: compute the approx overhead latency per cycle
left_to_sleep = throttle_period = 1/rate - 0.000616 left_to_sleep = throttle_period = 1/rate - 0.000616
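For intuition, the constant is a measured ~616us of per-cycle overhead; a worked example at a typical display rate:

# eg. rate = 60 Hz:
#   throttle_period = 1/60 - 0.000616 ~= 0.01605s
# ie. each send cycle budgets roughly 16ms after overhead.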
@ -763,12 +424,6 @@ async def uniform_rate_send(
diff = 0 diff = 0
task_status.started() task_status.started()
ticks_by_type: defaultdict[
str,
list[dict],
] = defaultdict(list)
clear_types = _tick_groups['clears']
while True: while True:
@ -787,17 +442,34 @@ async def uniform_rate_send(
if not first_quote: if not first_quote:
first_quote = last_quote first_quote = last_quote
# first_quote['tbt'] = ticks_by_type
if (throttle_period - diff) > 0: if (throttle_period - diff) > 0:
# received a quote but the send cycle period hasn't yet # received a quote but the send cycle period hasn't yet
# expired we aren't supposed to send yet so append # expired we aren't supposed to send yet so append
# to the tick frame. # to the tick frame.
frame_ticks(
first_quote, # append quotes since last iteration into the last quote's
last_quote, # tick array/buffer.
ticks_by_type, ticks = last_quote.get('ticks')
)
# XXX: idea for frame type data structure we could
# use on the wire instead of a simple list?
# frames = {
# 'index': ['type_a', 'type_c', 'type_n', 'type_n'],
# 'type_a': [tick0, tick1, tick2, .., tickn],
# 'type_b': [tick0, tick1, tick2, .., tickn],
# 'type_c': [tick0, tick1, tick2, .., tickn],
# ...
# 'type_n': [tick0, tick1, tick2, .., tickn],
# }
# TODO: once we decide to get fancy really we should
# have a shared mem tick buffer that is just
# continually filled and the UI just reads from it # continually filled and the UI just reads from it
# at its display rate. # at its display rate.
if ticks:
first_quote['ticks'].extend(ticks)
# send cycle isn't due yet so continue waiting # send cycle isn't due yet so continue waiting
continue continue
@ -814,35 +486,12 @@ async def uniform_rate_send(
# received quote ASAP. # received quote ASAP.
sym, first_quote = await quote_stream.receive() sym, first_quote = await quote_stream.receive()
frame_ticks(
first_quote,
first_quote,
ticks_by_type,
)
# we have a quote already so send it now. # we have a quote already so send it now.
with trio.move_on_after(throttle_period) as cs:
while (
not set(ticks_by_type).intersection(clear_types)
):
try:
sym, last_quote = await quote_stream.receive()
except trio.EndOfChannel:
log.exception(f"feed for {stream} ended?")
break
frame_ticks(
first_quote,
last_quote,
ticks_by_type,
)
# measured_rate = 1 / (time.time() - last_send) # measured_rate = 1 / (time.time() - last_send)
# log.info( # log.info(
# f'`{sym}` throttled send hz: {round(measured_rate, ndigits=1)}' # f'`{sym}` throttled send hz: {round(measured_rate, ndigits=1)}'
# ) # )
first_quote['tbt'] = ticks_by_type
# TODO: now if only we could sync this to the display # TODO: now if only we could sync this to the display
# rate timing exactly lul # rate timing exactly lul
@ -868,4 +517,3 @@ async def uniform_rate_send(
first_quote = last_quote = None first_quote = last_quote = None
diff = 0 diff = 0
last_send = time.time() last_send = time.time()
ticks_by_type.clear()

View File

@ -1,5 +1,5 @@
# piker: trading gear for hackers # piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers) # Copyright (C) Tyler Goodlet (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify # This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by # it under the terms of the GNU Affero General Public License as published by
@ -27,14 +27,13 @@ from multiprocessing.shared_memory import SharedMemory, _USE_POSIX
if _USE_POSIX: if _USE_POSIX:
from _posixshmem import shm_unlink from _posixshmem import shm_unlink
# import msgspec
import numpy as np
from numpy.lib import recfunctions as rfn
import tractor import tractor
import numpy as np
from pydantic import BaseModel
from numpy.lib import recfunctions as rfn
from ..log import get_logger from ..log import get_logger
from ._source import base_iohlc_dtype from ._source import base_iohlc_dtype
from .types import Struct
log = get_logger(__name__) log = get_logger(__name__)
@ -50,11 +49,7 @@ _rt_buffer_start = int((_days_worth - 1) * _secs_in_day)
def cuckoff_mantracker(): def cuckoff_mantracker():
'''
Disable all ``multiprocessing`` "resource tracking" machinery since
it's an absolute multi-threaded mess of non-SC madness.
'''
from multiprocessing import resource_tracker as mantracker from multiprocessing import resource_tracker as mantracker
# Tell the "resource tracker" thing to fuck off. # Tell the "resource tracker" thing to fuck off.
@ -112,39 +107,36 @@ class SharedInt:
log.warning(f'Shm for {name} already unlinked?') log.warning(f'Shm for {name} already unlinked?')
class _Token(Struct, frozen=True): class _Token(BaseModel):
''' '''
Internal representation of a shared memory "token" Internal representation of a shared memory "token"
which can be used to key a system wide post shm entry. which can be used to key a system wide post shm entry.
''' '''
class Config:
frozen = True
shm_name: str # this serves as a "key" value shm_name: str # this serves as a "key" value
shm_first_index_name: str shm_first_index_name: str
shm_last_index_name: str shm_last_index_name: str
dtype_descr: tuple dtype_descr: tuple
size: int # in struct-array index / row terms
@property @property
def dtype(self) -> np.dtype: def dtype(self) -> np.dtype:
return np.dtype(list(map(tuple, self.dtype_descr))).descr return np.dtype(list(map(tuple, self.dtype_descr))).descr
def as_msg(self): def as_msg(self):
return self.to_dict() return self.dict()
@classmethod @classmethod
def from_msg(cls, msg: dict) -> _Token: def from_msg(cls, msg: dict) -> _Token:
if isinstance(msg, _Token): if isinstance(msg, _Token):
return msg return msg
# TODO: native struct decoding
# return _token_dec.decode(msg)
msg['dtype_descr'] = tuple(map(tuple, msg['dtype_descr'])) msg['dtype_descr'] = tuple(map(tuple, msg['dtype_descr']))
return _Token(**msg) return _Token(**msg)
# _token_dec = msgspec.msgpack.Decoder(_Token)
# TODO: this api? # TODO: this api?
# _known_tokens = tractor.ActorVar('_shm_tokens', {}) # _known_tokens = tractor.ActorVar('_shm_tokens', {})
# _known_tokens = tractor.ContextStack('_known_tokens', ) # _known_tokens = tractor.ContextStack('_known_tokens', )
@ -163,7 +155,6 @@ def get_shm_token(key: str) -> _Token:
def _make_token( def _make_token(
key: str, key: str,
size: int,
dtype: Optional[np.dtype] = None, dtype: Optional[np.dtype] = None,
) -> _Token: ) -> _Token:
''' '''
@ -176,8 +167,7 @@ def _make_token(
shm_name=key, shm_name=key,
shm_first_index_name=key + "_first", shm_first_index_name=key + "_first",
shm_last_index_name=key + "_last", shm_last_index_name=key + "_last",
dtype_descr=tuple(np.dtype(dtype).descr), dtype_descr=np.dtype(dtype).descr
size=size,
) )
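A round-trip sketch of the token msg API on this branch (the key and size values are hypothetical):

token = _make_token(
    key='example.ohlcv',   # hypothetical shm key
    size=1000,
    dtype=base_iohlc_dtype,
)
msg = token.as_msg()       # plain dict, safe to ship over IPC
assert _Token.from_msg(msg) == token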
@ -229,7 +219,6 @@ class ShmArray:
shm_first_index_name=self._first._shm.name, shm_first_index_name=self._first._shm.name,
shm_last_index_name=self._last._shm.name, shm_last_index_name=self._last._shm.name,
dtype_descr=tuple(self._array.dtype.descr), dtype_descr=tuple(self._array.dtype.descr),
size=self._len,
) )
@property @property
@ -444,7 +433,7 @@ class ShmArray:
def open_shm_array( def open_shm_array(
key: Optional[str] = None, key: Optional[str] = None,
size: int = _default_size, # see above size: int = _default_size,
dtype: Optional[np.dtype] = None, dtype: Optional[np.dtype] = None,
readonly: bool = False, readonly: bool = False,
@ -475,8 +464,7 @@ def open_shm_array(
token = _make_token( token = _make_token(
key=key, key=key,
size=size, dtype=dtype
dtype=dtype,
) )
# create single entry arrays for storing an first and last indices # create single entry arrays for storing an first and last indices
@ -528,15 +516,15 @@ def open_shm_array(
# "unlink" created shm on process teardown by # "unlink" created shm on process teardown by
# pushing teardown calls onto actor context stack # pushing teardown calls onto actor context stack
stack = tractor.current_actor().lifetime_stack tractor._actor._lifetime_stack.callback(shmarr.close)
stack.callback(shmarr.close) tractor._actor._lifetime_stack.callback(shmarr.destroy)
stack.callback(shmarr.destroy)
return shmarr return shmarr
def attach_shm_array( def attach_shm_array(
token: tuple[str, str, tuple[str, str]], token: tuple[str, str, tuple[str, str]],
size: int = _default_size,
readonly: bool = True, readonly: bool = True,
) -> ShmArray: ) -> ShmArray:
@ -575,7 +563,7 @@ def attach_shm_array(
raise _err raise _err
shmarr = np.ndarray( shmarr = np.ndarray(
(token.size,), (size,),
dtype=token.dtype, dtype=token.dtype,
buffer=shm.buf buffer=shm.buf
) )
@ -614,8 +602,8 @@ def attach_shm_array(
if key not in _known_tokens: if key not in _known_tokens:
_known_tokens[key] = token _known_tokens[key] = token
# "close" attached shm on actor teardown # "close" attached shm on process teardown
tractor.current_actor().lifetime_stack.callback(sha.close) tractor._actor._lifetime_stack.callback(sha.close)
return sha return sha
@ -643,7 +631,6 @@ def maybe_open_shm_array(
use ``attach_shm_array``. use ``attach_shm_array``.
''' '''
size = kwargs.pop('size', _default_size)
try: try:
# see if we already know this key # see if we already know this key
token = _known_tokens[key] token = _known_tokens[key]
@ -651,11 +638,7 @@ def maybe_open_shm_array(
except KeyError: except KeyError:
log.warning(f"Could not find {key} in shms cache") log.warning(f"Could not find {key} in shms cache")
if dtype: if dtype:
token = _make_token( token = _make_token(key, dtype)
key,
size=size,
dtype=dtype,
)
try: try:
return attach_shm_array(token=token, **kwargs), False return attach_shm_array(token=token, **kwargs), False
except FileNotFoundError: except FileNotFoundError:

View File

@ -23,8 +23,7 @@ import decimal
from bidict import bidict from bidict import bidict
import numpy as np import numpy as np
from pydantic import BaseModel
from .types import Struct
# from numba import from_dtype # from numba import from_dtype
@ -127,7 +126,7 @@ def unpack_fqsn(fqsn: str) -> tuple[str, str, str]:
) )
class Symbol(Struct): class Symbol(BaseModel):
''' '''
I guess this is some kinda container thing for dealing with I guess this is some kinda container thing for dealing with
all the different meta-data formats from brokers? all the different meta-data formats from brokers?
@ -153,7 +152,9 @@ class Symbol(Struct):
info: dict[str, Any], info: dict[str, Any],
suffix: str = '', suffix: str = '',
) -> Symbol: # XXX: like wtf..
# ) -> 'Symbol':
) -> None:
tick_size = info.get('price_tick_size', 0.01) tick_size = info.get('price_tick_size', 0.01)
lot_tick_size = info.get('lot_tick_size', 0.0) lot_tick_size = info.get('lot_tick_size', 0.0)
@ -174,7 +175,9 @@ class Symbol(Struct):
fqsn: str, fqsn: str,
info: dict[str, Any], info: dict[str, Any],
) -> Symbol: # XXX: like wtf..
# ) -> 'Symbol':
) -> None:
broker, key, suffix = unpack_fqsn(fqsn) broker, key, suffix = unpack_fqsn(fqsn)
return cls.from_broker_info( return cls.from_broker_info(
broker, broker,
@ -218,10 +221,6 @@ class Symbol(Struct):
else: else:
return (key, broker) return (key, broker)
@property
def fqsn(self) -> str:
return '.'.join(self.tokens()).lower()
def front_fqsn(self) -> str: def front_fqsn(self) -> str:
''' '''
fqsn = "fully qualified symbol name" fqsn = "fully qualified symbol name"
@ -241,7 +240,7 @@ class Symbol(Struct):
''' '''
tokens = self.tokens() tokens = self.tokens()
fqsn = '.'.join(map(str.lower, tokens)) fqsn = '.'.join(tokens)
return fqsn return fqsn
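A quick sketch of the fqsn round-trip these methods implement, assuming the `<key>.<broker>` scheme used here (symbol values hypothetical):

broker, key, suffix = unpack_fqsn('btcusdt.binance')
# -> ('binance', 'btcusdt', '')
sym = Symbol.from_fqsn('btcusdt.binance', info={})
assert sym.front_fqsn() == 'btcusdt.binance'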
def iterfqsns(self) -> list[str]: def iterfqsns(self) -> list[str]:

View File

@ -18,24 +18,13 @@
ToOlS fOr CoPInG wITh "tHE wEB" protocols. ToOlS fOr CoPInG wITh "tHE wEB" protocols.
""" """
from contextlib import ( from contextlib import asynccontextmanager, AsyncExitStack
asynccontextmanager,
AsyncExitStack,
)
from itertools import count
from types import ModuleType from types import ModuleType
from typing import ( from typing import Any, Callable, AsyncGenerator
Any,
Optional,
Callable,
AsyncGenerator,
Iterable,
)
import json import json
import trio import trio
import trio_websocket import trio_websocket
from wsproto.utilities import LocalProtocolError
from trio_websocket._impl import ( from trio_websocket._impl import (
ConnectionClosed, ConnectionClosed,
DisconnectionTimeout, DisconnectionTimeout,
@ -46,53 +35,43 @@ from trio_websocket._impl import (
from ..log import get_logger from ..log import get_logger
from .types import Struct
log = get_logger(__name__) log = get_logger(__name__)
class NoBsWs: class NoBsWs:
''' """Make ``trio_websocket`` sockets stay up no matter the bs.
Make ``trio_websocket`` sockets stay up no matter the bs.
You can provide a ``fixture`` async-context-manager which will be """
entered/exited around each reconnect operation.
'''
recon_errors = ( recon_errors = (
ConnectionClosed, ConnectionClosed,
DisconnectionTimeout, DisconnectionTimeout,
ConnectionRejected, ConnectionRejected,
HandshakeError, HandshakeError,
ConnectionTimeout, ConnectionTimeout,
LocalProtocolError,
) )
def __init__( def __init__(
self, self,
url: str, url: str,
token: str,
stack: AsyncExitStack, stack: AsyncExitStack,
fixture: Optional[Callable] = None, fixture: Callable,
serializer: ModuleType = json serializer: ModuleType = json,
): ):
self.url = url self.url = url
self.token = token
self.fixture = fixture self.fixture = fixture
self._stack = stack self._stack = stack
self._ws: 'WebSocketConnection' = None # noqa self._ws: 'WebSocketConnection' = None # noqa
# TODO: is there some method we can call
# on the underlying `._ws` to get this?
self._connected: bool = False
async def _connect( async def _connect(
self, self,
tries: int = 1000, tries: int = 1000,
) -> None: ) -> None:
self._connected = False
while True: while True:
try: try:
await self._stack.aclose() await self._stack.aclose()
except self.recon_errors: except (DisconnectionTimeout, RuntimeError):
await trio.sleep(0.5) await trio.sleep(0.5)
else: else:
break break
@ -103,18 +82,19 @@ class NoBsWs:
self._ws = await self._stack.enter_async_context( self._ws = await self._stack.enter_async_context(
trio_websocket.open_websocket_url(self.url) trio_websocket.open_websocket_url(self.url)
) )
if self.fixture is not None:
# rerun user code fixture # rerun user code fixture
if self.token == '':
ret = await self._stack.enter_async_context( ret = await self._stack.enter_async_context(
self.fixture(self) self.fixture(self)
) )
else:
ret = await self._stack.enter_async_context(
self.fixture(self, self.token)
)
assert ret is None assert ret is None
log.info(f'Connection success: {self.url}') log.info(f'Connection success: {self.url}')
self._connected = True
return self._ws return self._ws
except self.recon_errors as err: except self.recon_errors as err:
@ -124,15 +104,11 @@ class NoBsWs:
f'{type(err)}...retry attempt {i}' f'{type(err)}...retry attempt {i}'
) )
await trio.sleep(0.5) await trio.sleep(0.5)
self._connected = False
continue continue
else: else:
log.exception('ws connection fail...') log.exception('ws connection fail...')
raise last_err raise last_err
def connected(self) -> bool:
return self._connected
async def send_msg( async def send_msg(
self, self,
data: Any, data: Any,
@ -152,26 +128,21 @@ class NoBsWs:
except self.recon_errors: except self.recon_errors:
await self._connect() await self._connect()
def __aiter__(self):
return self
async def __anext__(self):
return await self.recv_msg()
@asynccontextmanager @asynccontextmanager
async def open_autorecon_ws( async def open_autorecon_ws(
url: str, url: str,
# TODO: proper type annot smh # TODO: proper type annot smh
fixture: Optional[Callable] = None, fixture: Callable,
# used for authenticated websockets
token: str = '',
) -> AsyncGenerator[tuple[...], NoBsWs]: ) -> AsyncGenerator[tuple[...], NoBsWs]:
"""Apparently we can QoS for all sorts of reasons..so catch em. """Apparently we can QoS for all sorts of reasons..so catch em.
""" """
async with AsyncExitStack() as stack: async with AsyncExitStack() as stack:
ws = NoBsWs(url, stack, fixture=fixture) ws = NoBsWs(url, token, stack, fixture=fixture)
await ws._connect() await ws._connect()
try: try:
@ -179,114 +150,3 @@ async def open_autorecon_ws(
finally: finally:
await stack.aclose() await stack.aclose()
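A usage sketch for the auto-reconnecting ws; the endpoint and payload are hypothetical. The `fixture` is re-entered on every reconnect so subscriptions survive drops:

@asynccontextmanager
async def resub_trades(ws: NoBsWs):
    # (re)subscribe on each (re)connect
    await ws.send_msg({'op': 'subscribe', 'channel': 'trades'})
    yield

async def consume_trades():
    async with open_autorecon_ws(
        'wss://api.example.com/ws',
        fixture=resub_trades,
    ) as ws:
        while True:
            print(await ws.recv_msg())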
'''
JSONRPC response-request style machinery for transparent multiplexing of msgs
over a NoBsWs.
'''
class JSONRPCResult(Struct):
id: int
jsonrpc: str = '2.0'
result: Optional[dict] = None
error: Optional[dict] = None
@asynccontextmanager
async def open_jsonrpc_session(
url: str,
start_id: int = 0,
response_type: type = JSONRPCResult,
request_type: Optional[type] = None,
request_hook: Optional[Callable] = None,
error_hook: Optional[Callable] = None,
) -> Callable[[str, dict], dict]:
async with (
trio.open_nursery() as n,
open_autorecon_ws(url) as ws
):
rpc_id: Iterable = count(start_id)
rpc_results: dict[int, dict] = {}
async def json_rpc(method: str, params: dict) -> dict:
'''
Perform a JSON-RPC call and wait for the result; raise an
exception if an error field is present in the response.
'''
msg = {
'jsonrpc': '2.0',
'id': next(rpc_id),
'method': method,
'params': params
}
_id = msg['id']
rpc_results[_id] = {
'result': None,
'event': trio.Event()
}
await ws.send_msg(msg)
await rpc_results[_id]['event'].wait()
ret = rpc_results[_id]['result']
del rpc_results[_id]
if ret.error is not None:
raise Exception(json.dumps(ret.error, indent=4))
return ret
async def recv_task():
'''
Receives every ws message and stores it in its corresponding
result field, then sets the event to wake up the original
sender tasks. Also receives responses to requests originating
from the server side.
'''
async for msg in ws:
match msg:
case {
'result': _,
'id': mid,
} if res_entry := rpc_results.get(mid):
res_entry['result'] = response_type(**msg)
res_entry['event'].set()
case {
'result': _,
'id': mid,
} if not rpc_results.get(mid):
log.warning(
f'Unexpected ws msg: {json.dumps(msg, indent=4)}'
)
case {
'method': _,
'params': _,
}:
log.debug(f'Received\n{msg}')
if request_hook:
await request_hook(request_type(**msg))
case {
'error': error
}:
log.warning(f'Received\n{error}')
if error_hook:
await error_hook(response_type(**msg))
case _:
log.warning(f'Unhandled JSON-RPC msg!?\n{msg}')
n.start_soon(recv_task)
yield json_rpc
n.cancel_scope.cancel()
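A corresponding caller sketch (the url and method are hypothetical for whatever JSON-RPC server is targeted):

async def get_server_time() -> dict:
    async with open_jsonrpc_session('wss://api.example.com/ws') as json_rpc:
        res = await json_rpc('public/get_time', {})
        # errors raise inside `json_rpc()`, so `res.result` is valid here
        return res.result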

File diff suppressed because it is too large

View File

@ -1,321 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
abstractions for organizing, managing and generally operating on
real-time data processing data-structures.
"Streams, flumes, cascades and flows.."
"""
from __future__ import annotations
from contextlib import asynccontextmanager as acm
from functools import partial
from typing import (
AsyncIterator,
TYPE_CHECKING,
)
import tractor
from tractor.trionics import (
maybe_open_context,
)
import pendulum
import numpy as np
from .types import Struct
from ._source import (
Symbol,
)
from ._sharedmem import (
attach_shm_array,
ShmArray,
_Token,
)
from ._sampling import (
open_sample_stream,
)
if TYPE_CHECKING:
from pyqtgraph import PlotItem
from .feed import Feed
# TODO: ideas for further abstractions as per
# https://github.com/pikers/piker/issues/216 and
# https://github.com/pikers/piker/issues/270:
# - a ``Cascade`` would be the minimal "connection" of 2 ``Flumes``
# as per circuit parlance:
# https://en.wikipedia.org/wiki/Two-port_network#Cascade_connection
# - could cover the combination of our `FspAdmin` and the
# backend `.fsp._engine` related machinery to "connect" one flume
# to another?
# - a (financial signal) ``Flow`` would be a "collection" of such
# minimal cascades. Some engineering-based jargon concepts:
# - https://en.wikipedia.org/wiki/Signal_chain
# - https://en.wikipedia.org/wiki/Daisy_chain_(electrical_engineering)
# - https://en.wikipedia.org/wiki/Audio_signal_flow
# - https://en.wikipedia.org/wiki/Digital_signal_processing#Implementation
# - https://en.wikipedia.org/wiki/Dataflow_programming
# - https://en.wikipedia.org/wiki/Signal_programming
# - https://en.wikipedia.org/wiki/Incremental_computing
class Flume(Struct):
'''
Composite reference type which points to all the addressing handles
and other meta-data necessary for the read, measure and management
of a set of real-time updated data flows.
Can be thought of as a "flow descriptor" or "flow frame" which
describes the high level properties of a set of data flows that can
be used seamlessly across process-memory boundaries.
Each instance's sub-components normally includes:
- a msg oriented quote stream provided via an IPC transport
- history and real-time shm buffers which are both real-time
updated and backfilled.
- associated startup indexing information related to both buffer
real-time-append and historical prepend addresses.
- low level APIs to read and measure the updated data and manage
queuing properties.
'''
symbol: Symbol
first_quote: dict
_rt_shm_token: _Token
# optional since some data flows won't have a "downsampled" history
# buffer/stream (eg. FSPs).
_hist_shm_token: _Token | None = None
# private shm refs loaded dynamically from tokens
_hist_shm: ShmArray | None = None
_rt_shm: ShmArray | None = None
stream: tractor.MsgStream | None = None
izero_hist: int = 0
izero_rt: int = 0
throttle_rate: int | None = None
# TODO: do we need this really if we can pull the `Portal` from
# ``tractor``'s internals?
feed: Feed | None = None
@property
def rt_shm(self) -> ShmArray:
if self._rt_shm is None:
self._rt_shm = attach_shm_array(
token=self._rt_shm_token,
readonly=True,
)
return self._rt_shm
@property
def hist_shm(self) -> ShmArray:
if self._hist_shm_token is None:
raise RuntimeError(
'No shm token has been set for the history buffer?'
)
if (
self._hist_shm is None
):
self._hist_shm = attach_shm_array(
token=self._hist_shm_token,
readonly=True,
)
return self._hist_shm
async def receive(self) -> dict:
return await self.stream.receive()
@acm
async def index_stream(
self,
delay_s: float = 1,
) -> AsyncIterator[int]:
if not self.feed:
raise RuntimeError('This flume is not part of any ``Feed``?')
# TODO: maybe a public (property) API for this in ``tractor``?
portal = self.stream._ctx._portal
assert portal
# XXX: this should be singleton on a host,
# a lone broker-daemon per provider should be
# created for all practical purposes
async with open_sample_stream(float(delay_s)) as stream:
yield stream
def get_ds_info(
self,
) -> tuple[float, float, float]:
'''
Compute the "downsampling" ratio info between the historical shm
buffer and the real-time (HFT) one.
Return a tuple of the fast sample period, historical sample
period and ratio between them.
'''
times = self.hist_shm.array['time']
end = pendulum.from_timestamp(times[-1])
start = pendulum.from_timestamp(times[times != times[-1]][-1])
hist_step_size_s = (end - start).seconds
times = self.rt_shm.array['time']
end = pendulum.from_timestamp(times[-1])
start = pendulum.from_timestamp(times[times != times[-1]][-1])
rt_step_size_s = (end - start).seconds
ratio = hist_step_size_s / rt_step_size_s
return (
rt_step_size_s,
hist_step_size_s,
ratio,
)
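For example, with a 1s real-time buffer and a 60s history buffer this would report:

# rt_step_size_s, hist_step_size_s, ratio = flume.get_ds_info()
# -> (1.0, 60.0, 60.0), ie. history is downsampled 60x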
# TODO: get native msgspec decoding for these workin
def to_msg(self) -> dict:
msg = self.to_dict()
msg['symbol'] = msg['symbol'].to_dict()
# can't serialize the stream or feed objects, it's expected
# you'll have a ref to it since this msg should be rxed on
# a stream on whatever far end IPC..
msg.pop('stream')
msg.pop('feed')
return msg
@classmethod
def from_msg(cls, msg: dict) -> dict:
symbol = Symbol(**msg.pop('symbol'))
return cls(
symbol=symbol,
**msg,
)
def get_index(
self,
time_s: float,
) -> int:
'''
Return the array shm-buffer index for the given epoch time.
'''
array = self.rt_shm.array
times = array['time']
mask = (times >= time_s)
if any(mask):
return array['index'][mask][0]
# fall back to just the latest index
return array['index'][-1]
def slice_from_time(
self,
array: np.ndarray,
start_t: float,
stop_t: float,
timeframe_s: int = 1,
return_data: bool = False,
) -> np.ndarray:
'''
Slice an input struct array providing only datums
"in view" of this chart.
'''
arr = {
1: self.rt_shm.array,
60: self.hist_shm.array,
}[timeframe_s]
times = arr['time']
index = array['index']
# use advanced indexing to map the
# time range to the index range.
mask = (
(times >= start_t)
&
(times < stop_t)
)
# TODO: if we can ensure each time field has a uniform
# step we can instead do some arithmetic to determine
# the equivalent index like we used to?
# return array[
# lbar - ifirst:
# (rbar - ifirst) + 1
# ]
i_by_t = index[mask]
i_0 = i_by_t[0]
abs_slc = slice(
i_0,
i_by_t[-1],
)
# slice data by offset from the first index
# available in the passed datum set.
read_slc = slice(
0,
i_by_t[-1] - i_0,
)
if not return_data:
return (
abs_slc,
read_slc,
)
# also return the readable data from the timerange
return (
abs_slc,
read_slc,
arr[mask],
)
def view_data(
self,
plot: PlotItem,
timeframe_s: int = 1,
) -> np.ndarray:
# get far-side x-indices plot view
vr = plot.viewRect()
(
abs_slc,
buf_slc,
iv_arr,
) = self.slice_from_time(
start_t=vr.left(),
stop_t=vr.right(),
timeframe_s=timeframe_s,
return_data=True,
)
return iv_arr
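A sketch of the cross-actor hand-off the msg methods above enable (variable names assumed):

# sender side: serialize the flow descriptor (live refs are dropped)
msg = flume.to_msg()

# receiver side: rehydrate and lazily (re)attach shm via the tokens
flume2 = Flume.from_msg(msg)
ohlc = flume2.rt_shm.array   # attaches on first access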

View File

@ -37,8 +37,8 @@ import time
from math import isnan from math import isnan
from bidict import bidict from bidict import bidict
from msgspec.msgpack import encode, decode import msgpack
# import pyqtgraph as pg import pyqtgraph as pg
import numpy as np import numpy as np
import tractor import tractor
from trio_websocket import open_websocket_url from trio_websocket import open_websocket_url
@ -56,7 +56,6 @@ if TYPE_CHECKING:
from .feed import maybe_open_feed from .feed import maybe_open_feed
from ..log import get_logger, get_console_log from ..log import get_logger, get_console_log
from .._profile import Profiler
log = get_logger(__name__) log = get_logger(__name__)
@ -132,10 +131,7 @@ def start_marketstore(
mktsdir = os.path.join(config._config_dir, 'marketstore') mktsdir = os.path.join(config._config_dir, 'marketstore')
# create dirs when dne # create when dne
if not os.path.isdir(config._config_dir):
os.mkdir(config._config_dir)
if not os.path.isdir(mktsdir): if not os.path.isdir(mktsdir):
os.mkdir(mktsdir) os.mkdir(mktsdir)
@ -391,54 +387,50 @@ class Storage:
async def load( async def load(
self, self,
fqsn: str, fqsn: str,
timeframe: int,
) -> tuple[ ) -> tuple[
np.ndarray, # timeframe sampled array-series dict[int, np.ndarray], # timeframe (in secs) to series
Optional[datetime], # first dt Optional[datetime], # first dt
Optional[datetime], # last dt Optional[datetime], # last dt
]: ]:
first_tsdb_dt, last_tsdb_dt = None, None first_tsdb_dt, last_tsdb_dt = None, None
hist = await self.read_ohlcv( tsdb_arrays = await self.read_ohlcv(
fqsn, fqsn,
# on first load we don't need to pull the max # on first load we don't need to pull the max
# history per request size worth. # history per request size worth.
limit=3000, limit=3000,
timeframe=timeframe,
) )
log.info(f'Loaded tsdb history {hist}') log.info(f'Loaded tsdb history {tsdb_arrays}')
if len(hist): if tsdb_arrays:
times = hist['Epoch'] fastest = list(tsdb_arrays.values())[0]
times = fastest['Epoch']
first, last = times[0], times[-1] first, last = times[0], times[-1]
first_tsdb_dt, last_tsdb_dt = map( first_tsdb_dt, last_tsdb_dt = map(
pendulum.from_timestamp, [first, last] pendulum.from_timestamp, [first, last]
) )
return ( return tsdb_arrays, first_tsdb_dt, last_tsdb_dt
hist, # array-data
first_tsdb_dt, # start of query-frame
last_tsdb_dt, # most recent
)
async def read_ohlcv( async def read_ohlcv(
self, self,
fqsn: str, fqsn: str,
timeframe: int | str, timeframe: Optional[Union[int, str]] = None,
end: Optional[int] = None, end: Optional[int] = None,
limit: int = int(800e3), limit: int = int(800e3),
) -> np.ndarray: ) -> tuple[
MarketstoreClient,
Union[dict, np.ndarray]
]:
client = self.client client = self.client
syms = await client.list_symbols() syms = await client.list_symbols()
if fqsn not in syms: if fqsn not in syms:
return {} return {}
# use the provided timeframe or 1s by default tfstr = tf_in_1s[1]
tfstr = tf_in_1s.get(timeframe, tf_in_1s[1])
params = Params( params = Params(
symbols=fqsn, symbols=fqsn,
@ -452,72 +444,58 @@ class Storage:
limit=limit, limit=limit,
) )
if timeframe is None:
log.info(f'starting {fqsn} tsdb granularity scan..')
# loop through and try to find highest granularity
for tfstr in tf_in_1s.values():
try: try:
log.info(f'querying for {tfstr}@{fqsn}')
params.set('timeframe', tfstr)
result = await client.query(params) result = await client.query(params)
except purerpc.grpclib.exceptions.UnknownError as err: break
# indicate there is no history for this timeframe
log.exception( except purerpc.grpclib.exceptions.UnknownError:
f'Unknown mkts QUERY error: {params}\n' # XXX: this is already logged by the container and
f'{err.args}' # thus shows up through `marketstored` logs relay.
) # log.warning(f'{tfstr}@{fqsn} not found')
continue
else:
return {} return {}
else:
result = await client.query(params)
# TODO: it turns out column access on recarrays is actually slower: # TODO: it turns out column access on recarrays is actually slower:
# https://jakevdp.github.io/PythonDataScienceHandbook/02.09-structured-data-numpy.html#RecordArrays:-Structured-Arrays-with-a-Twist # https://jakevdp.github.io/PythonDataScienceHandbook/02.09-structured-data-numpy.html#RecordArrays:-Structured-Arrays-with-a-Twist
# it might make sense to make these structured arrays? # it might make sense to make these structured arrays?
data_set = result.by_symbols()[fqsn] # Fill out a `numpy` array-results map
array = data_set.array arrays = {}
for fqsn, data_set in result.by_symbols().items():
arrays.setdefault(fqsn, {})[
tf_in_1s.inverse[data_set.timeframe]
] = data_set.array
# XXX: ensure sample rate is as expected return arrays[fqsn][timeframe] if timeframe else arrays[fqsn]
time = data_set.array['Epoch']
if len(time) > 1:
time_step = time[-1] - time[-2]
ts = tf_in_1s.inverse[data_set.timeframe]
if time_step != ts:
log.warning(
f'MKTS BUG: wrong timeframe loaded: {time_step}'
'YOUR DATABASE LIKELY CONTAINS BAD DATA FROM AN OLD BUG'
f'WIPING HISTORY FOR {ts}s'
)
await self.delete_ts(fqsn, timeframe)
# try reading again..
return await self.read_ohlcv(
fqsn,
timeframe,
end,
limit,
)
return array
async def delete_ts( async def delete_ts(
self, self,
key: str, key: str,
timeframe: Optional[Union[int, str]] = None, timeframe: Optional[Union[int, str]] = None,
fmt: str = 'OHLCV',
) -> bool: ) -> bool:
client = self.client client = self.client
syms = await client.list_symbols() syms = await client.list_symbols()
print(syms) print(syms)
if key not in syms: # if key not in syms:
raise KeyError(f'`{key}` table key not found in\n{syms}?') # raise KeyError(f'`{fqsn}` table key not found?')
tbk = mk_tbk(( return await client.destroy(tbk=key)
key,
tf_in_1s.get(timeframe, tf_in_1s[60]),
fmt,
))
return await client.destroy(tbk=tbk)
async def write_ohlcv( async def write_ohlcv(
self, self,
fqsn: str, fqsn: str,
ohlcv: np.ndarray, ohlcv: np.ndarray,
timeframe: int,
append_and_duplicate: bool = True, append_and_duplicate: bool = True,
limit: int = int(800e3), limit: int = int(800e3),
@ -541,18 +519,17 @@ class Storage:
m, r = divmod(len(mkts_array), limit) m, r = divmod(len(mkts_array), limit)
tfkey = tf_in_1s[timeframe]
for i in range(m, 1): for i in range(m, 1):
to_push = mkts_array[i-1:i*limit] to_push = mkts_array[i-1:i*limit]
# write to db # write to db
resp = await self.client.write( resp = await self.client.write(
to_push, to_push,
tbk=f'{fqsn}/{tfkey}/OHLCV', tbk=f'{fqsn}/1Sec/OHLCV',
# NOTE: will append duplicates # NOTE: will append duplicates
# for the same timestamp-index. # for the same timestamp-index.
# TODO: pre-deduplicate? # TODO: pre deduplicate?
isvariablelength=append_and_duplicate, isvariablelength=append_and_duplicate,
) )
@ -571,7 +548,7 @@ class Storage:
# write to db # write to db
resp = await self.client.write( resp = await self.client.write(
to_push, to_push,
tbk=f'{fqsn}/{tfkey}/OHLCV', tbk=f'{fqsn}/1Sec/OHLCV',
# NOTE: will append duplicates # NOTE: will append duplicates
# for the same timestamp-index. # for the same timestamp-index.
@ -600,7 +577,6 @@ class Storage:
# def delete_range(self, start_dt, end_dt) -> None: # def delete_range(self, start_dt, end_dt) -> None:
# ... # ...
@acm @acm
async def open_storage_client( async def open_storage_client(
fqsn: str, fqsn: str,
@ -650,7 +626,7 @@ async def tsdb_history_update(
# * the original data feed arch blurb: # * the original data feed arch blurb:
# - https://github.com/pikers/piker/issues/98 # - https://github.com/pikers/piker/issues/98
# #
profiler = Profiler( profiler = pg.debug.Profiler(
disabled=False, # not pg_profile_enabled(), disabled=False, # not pg_profile_enabled(),
delayed=False, delayed=False,
) )
@ -662,35 +638,34 @@ async def tsdb_history_update(
[fqsn], [fqsn],
start_stream=False, start_stream=False,
) as feed, ) as (feed, stream),
): ):
profiler(f'opened feed for {fqsn}') profiler(f'opened feed for {fqsn}')
# to_append = feed.hist_shm.array to_append = feed.shm.array
# to_prepend = None to_prepend = None
if fqsn: if fqsn:
flume = feed.flumes[fqsn] symbol = feed.symbols.get(fqsn)
symbol = flume.symbol
if symbol: if symbol:
fqsn = symbol.fqsn fqsn = symbol.front_fqsn()
# diff db history with shm and only write the missing portions # diff db history with shm and only write the missing portions
# ohlcv = flume.hist_shm.array ohlcv = feed.shm.array
# TODO: use pg profiler # TODO: use pg profiler
# for secs in (1, 60): tsdb_arrays = await storage.read_ohlcv(fqsn)
# tsdb_array = await storage.read_ohlcv( # hist diffing
# fqsn, if tsdb_arrays:
# timeframe=timeframe, for secs in (1, 60):
# ) ts = tsdb_arrays.get(secs)
# # hist diffing: if ts is not None and len(ts):
# # these aren't currently used but can be referenced from # these aren't currently used but can be referenced from
# # within the embedded ipython shell below. # within the embedded ipython shell below.
# to_append = ohlcv[ohlcv['time'] > ts['Epoch'][-1]] to_append = ohlcv[ohlcv['time'] > ts['Epoch'][-1]]
# to_prepend = ohlcv[ohlcv['time'] < ts['Epoch'][0]] to_prepend = ohlcv[ohlcv['time'] < ts['Epoch'][0]]
# profiler('Finished db arrays diffs') profiler('Finished db arrays diffs')
syms = await storage.client.list_symbols() syms = await storage.client.list_symbols()
log.info(f'Existing tsdb symbol set:\n{pformat(syms)}') log.info(f'Existing tsdb symbol set:\n{pformat(syms)}')
@ -799,13 +774,12 @@ async def stream_quotes(
async with open_websocket_url(f'ws://{host}:{port}/ws') as ws: async with open_websocket_url(f'ws://{host}:{port}/ws') as ws:
# send subs topics to server # send subs topics to server
resp = await ws.send_message( resp = await ws.send_message(
msgpack.dumps({'streams': list(tbks.values())})
encode({'streams': list(tbks.values())})
) )
log.info(resp) log.info(resp)
async def recv() -> dict[str, Any]: async def recv() -> dict[str, Any]:
return decode((await ws.get_message()), encoding='utf-8') return msgpack.loads((await ws.get_message()), encoding='utf-8')
streams = (await recv())['streams'] streams = (await recv())['streams']
log.info(f"Subscribed to {streams}") log.info(f"Subscribed to {streams}")

View File

@ -1,88 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Guillermo Rodriguez (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Built-in (extension) types.
"""
import sys
from typing import Optional
from pprint import pformat
import msgspec
class Struct(
msgspec.Struct,
# https://jcristharif.com/msgspec/structs.html#tagged-unions
# tag='pikerstruct',
# tag=True,
):
'''
A "human friendlier" (aka repl buddy) struct subtype.
'''
def to_dict(self) -> dict:
return {
f: getattr(self, f)
for f in self.__struct_fields__
}
# Lul, doesn't seem to work that well..
# def __repr__(self):
# # only turn on pprint when we detect a python REPL
# # at runtime B)
# if (
# hasattr(sys, 'ps1')
# # TODO: check if we're in pdb
# ):
# return self.pformat()
# return super().__repr__()
def pformat(self) -> str:
return f'Struct({pformat(self.to_dict())})'
def copy(
self,
update: Optional[dict] = None,
) -> msgspec.Struct:
'''
Validate-typecast all self-defined fields and return a copy of us
with all such fields.
This is kinda like the default behaviour in `pydantic.BaseModel`.
'''
if update:
for k, v in update.items():
setattr(self, k, v)
# roundtrip serialize to validate
return msgspec.msgpack.Decoder(
type=type(self)
).decode(
msgspec.msgpack.Encoder().encode(self)
)
def typecast(
self,
# fields: Optional[list[str]] = None,
) -> None:
for fname, ftype in self.__annotations__.items():
setattr(self, fname, ftype(getattr(self, fname)))
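A small usage sketch (the struct and field names are hypothetical):

class Point(Struct):
    x: float
    y: float

p = Point(x=1, y=2.5)
p.typecast()                  # coerce fields to their annotations
assert isinstance(p.x, float)

# `.copy()` validates by round-tripping through msgpack
p2 = p.copy(update={'y': 3})
assert p2.y == 3.0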

View File

@ -78,8 +78,7 @@ class Fsp:
# + the consuming fsp *to* the consumers output # + the consuming fsp *to* the consumers output
# shm flow. # shm flow.
_flow_registry: dict[ _flow_registry: dict[
tuple[_Token, str], tuple[_Token, str], _Token,
tuple[_Token, Optional[ShmArray]],
] = {} ] = {}
def __init__( def __init__(
@ -121,6 +120,7 @@ class Fsp:
): ):
return self.func(*args, **kwargs) return self.func(*args, **kwargs)
# TODO: lru_cache this? prettty sure it'll work?
def get_shm( def get_shm(
self, self,
src_shm: ShmArray, src_shm: ShmArray,
@ -131,27 +131,12 @@ class Fsp:
for this "instance" of a signal processor for for this "instance" of a signal processor for
the given ``key``. the given ``key``.
The destination shm "token" and array are cached if possible to
minimize multiple stdlib/system calls.
''' '''
dst_token, maybe_array = self._flow_registry[ dst_token = self._flow_registry[
(src_shm._token, self.name) (src_shm._token, self.name)
] ]
if maybe_array is None: shm = attach_shm_array(dst_token)
self._flow_registry[ return shm
(src_shm._token, self.name)
] = (
dst_token,
# "cache" the ``ShmArray`` such that
# we call the underlying "attach" code as few
# times as possible as per:
# - https://github.com/pikers/piker/issues/359
# - https://github.com/pikers/piker/issues/332
maybe_array := attach_shm_array(dst_token)
)
return maybe_array
def fsp( def fsp(
@ -199,10 +184,7 @@ def maybe_mk_fsp_shm(
# TODO: load output types from `Fsp` # TODO: load output types from `Fsp`
# - should `index` be a required internal field? # - should `index` be a required internal field?
fsp_dtype = np.dtype( fsp_dtype = np.dtype(
[('index', int)] [('index', int)] +
+
[('time', float)]
+
[(field_name, float) for field_name in target.outputs] [(field_name, float) for field_name in target.outputs]
) )

View File

@ -21,13 +21,12 @@ core task logic for processing chains
from dataclasses import dataclass from dataclasses import dataclass
from functools import partial from functools import partial
from typing import ( from typing import (
AsyncIterator, AsyncIterator, Callable, Optional,
Callable,
Optional,
Union, Union,
) )
import numpy as np import numpy as np
import pyqtgraph as pg
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
import tractor import tractor
@ -36,22 +35,14 @@ from tractor.msg import NamespacePath
from ..log import get_logger, get_console_log from ..log import get_logger, get_console_log
from .. import data from .. import data
from ..data import attach_shm_array from ..data import attach_shm_array
from ..data.feed import ( from ..data.feed import Feed
Flume,
Feed,
)
from ..data._sharedmem import ShmArray from ..data._sharedmem import ShmArray
from ..data._sampling import (
_default_delay_s,
open_sample_stream,
)
from ..data._source import Symbol from ..data._source import Symbol
from ._api import ( from ._api import (
Fsp, Fsp,
_load_builtins, _load_builtins,
_Token, _Token,
) )
from .._profile import Profiler
log = get_logger(__name__) log = get_logger(__name__)
@ -86,7 +77,7 @@ async def filter_quotes_by_sym(
async def fsp_compute( async def fsp_compute(
symbol: Symbol, symbol: Symbol,
flume: Flume, feed: Feed,
quote_stream: trio.abc.ReceiveChannel, quote_stream: trio.abc.ReceiveChannel,
src: ShmArray, src: ShmArray,
@ -99,7 +90,7 @@ async def fsp_compute(
) -> None: ) -> None:
profiler = Profiler( profiler = pg.debug.Profiler(
delayed=False, delayed=False,
disabled=True disabled=True
) )
@ -114,17 +105,16 @@ async def fsp_compute(
filter_quotes_by_sym(fqsn, quote_stream), filter_quotes_by_sym(fqsn, quote_stream),
# XXX: currently the ``ohlcv`` arg # XXX: currently the ``ohlcv`` arg
flume.rt_shm, feed.shm,
) )
# HISTORY COMPUTE PHASE # Conduct a single iteration of fsp with historical bars input
# conduct a single iteration of fsp with historical bars input # and get historical output
# and get historical output.
history_output: Union[ history_output: Union[
dict[str, np.ndarray], # multi-output case dict[str, np.ndarray], # multi-output case
np.ndarray, # single output case np.ndarray, # single output case
] ]
history_output = await anext(out_stream) history_output = await out_stream.__anext__()
func_name = func.__name__ func_name = func.__name__
profiler(f'{func_name} generated history') profiler(f'{func_name} generated history')
@ -136,13 +126,9 @@ async def fsp_compute(
# each respective field. # each respective field.
fields = getattr(dst.array.dtype, 'fields', None).copy() fields = getattr(dst.array.dtype, 'fields', None).copy()
fields.pop('index') fields.pop('index')
history_by_field: Optional[np.ndarray] = None history: Optional[np.ndarray] = None # TODO: nptyping here!
src_time = src.array['time']
if ( if fields and len(fields) > 1 and fields:
fields and
len(fields) > 1
):
if not isinstance(history_output, dict): if not isinstance(history_output, dict):
raise ValueError( raise ValueError(
f'`{func_name}` is a multi-output FSP and should yield a ' f'`{func_name}` is a multi-output FSP and should yield a '
@ -153,7 +139,7 @@ async def fsp_compute(
if key in history_output: if key in history_output:
output = history_output[key] output = history_output[key]
if history_by_field is None: if history is None:
if output is None: if output is None:
length = len(src.array) length = len(src.array)
@ -163,7 +149,7 @@ async def fsp_compute(
# using the first output, determine # using the first output, determine
# the length of the struct-array that # the length of the struct-array that
# will be pushed to shm. # will be pushed to shm.
history_by_field = np.zeros( history = np.zeros(
length, length,
dtype=dst.array.dtype dtype=dst.array.dtype
) )
@ -171,7 +157,7 @@ async def fsp_compute(
if output is None: if output is None:
continue continue
history_by_field[key] = output history[key] = output
# single-key output stream # single-key output stream
else: else:
@ -180,13 +166,11 @@ async def fsp_compute(
f'`{func_name}` is a single output FSP and should yield an ' f'`{func_name}` is a single output FSP and should yield an '
'`np.ndarray` for history' '`np.ndarray` for history'
) )
history_by_field = np.zeros( history = np.zeros(
len(history_output), len(history_output),
dtype=dst.array.dtype dtype=dst.array.dtype
) )
history_by_field[func_name] = history_output history[func_name] = history_output
history_by_field['time'] = src_time[-len(history_by_field):]
# TODO: XXX: # TODO: XXX:
# THERE'S A BIG BUG HERE WITH THE `index` field since we're # THERE'S A BIG BUG HERE WITH THE `index` field since we're
@ -203,10 +187,7 @@ async def fsp_compute(
# TODO: can we use this `start` flag instead of the manual # TODO: can we use this `start` flag instead of the manual
# setting above? # setting above?
index = dst.push( index = dst.push(history, start=first)
history_by_field,
start=first,
)
profiler(f'{func_name} pushed history') profiler(f'{func_name} pushed history')
profiler.finish() profiler.finish()
@ -232,14 +213,8 @@ async def fsp_compute(
log.debug(f"{func_name}: {processed}") log.debug(f"{func_name}: {processed}")
key, output = processed key, output = processed
# dst.array[-1][key] = output index = src.index
dst.array[[key, 'time']][-1] = ( dst.array[-1][key] = output
output,
# TODO: what about pushing ``time.time_ns()``
# in which case we'll need to round at the graphics
# processing / sampling layer?
src.array[-1]['time']
)
# NOTE: for now we aren't streaming this to the consumer # NOTE: for now we aren't streaming this to the consumer
# stream latest array index entry which basically just acts # stream latest array index entry which basically just acts
@ -250,7 +225,6 @@ async def fsp_compute(
# N-consumers who subscribe for the real-time output, # N-consumers who subscribe for the real-time output,
# which we'll likely want to implement using local-mem # which we'll likely want to implement using local-mem
# chans for the fan out? # chans for the fan out?
# index = src.index
# if attach_stream: # if attach_stream:
# await client_stream.send(index) # await client_stream.send(index)
@ -287,7 +261,7 @@ async def cascade(
destination shm array buffer. destination shm array buffer.
''' '''
profiler = Profiler( profiler = pg.debug.Profiler(
delayed=False, delayed=False,
disabled=False disabled=False
) )
@ -310,10 +284,9 @@ async def cascade(
# TODO: ugh i hate this wind/unwind to list over the wire # TODO: ugh i hate this wind/unwind to list over the wire
# but not sure how else to do it. # but not sure how else to do it.
for (token, fsp_name, dst_token) in shm_registry: for (token, fsp_name, dst_token) in shm_registry:
Fsp._flow_registry[( Fsp._flow_registry[
_Token.from_msg(token), (_Token.from_msg(token), fsp_name)
fsp_name, ] = _Token.from_msg(dst_token)
)] = _Token.from_msg(dst_token), None
fsp: Fsp = reg.get( fsp: Fsp = reg.get(
NamespacePath(ns_path) NamespacePath(ns_path)
@ -325,7 +298,6 @@ async def cascade(
raise ValueError(f'Unknown fsp target: {ns_path}') raise ValueError(f'Unknown fsp target: {ns_path}')
# open a data feed stream with requested broker # open a data feed stream with requested broker
feed: Feed
async with data.feed.maybe_open_feed( async with data.feed.maybe_open_feed(
[fqsn], [fqsn],
@ -335,13 +307,14 @@ async def cascade(
# needs to get throttled the ticks we generate. # needs to get throttled the ticks we generate.
# tick_throttle=60, # tick_throttle=60,
) as feed: ) as (feed, quote_stream):
symbol = feed.symbols[fqsn]
flume = feed.flumes[fqsn]
symbol = flume.symbol
assert src.token == flume.rt_shm.token
profiler(f'{func}: feed up') profiler(f'{func}: feed up')
assert src.token == feed.shm.token
# last_len = new_len = len(src.array)
func_name = func.__name__ func_name = func.__name__
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
@ -351,8 +324,8 @@ async def cascade(
fsp_compute, fsp_compute,
symbol=symbol, symbol=symbol,
flume=flume, feed=feed,
quote_stream=flume.stream, quote_stream=quote_stream,
# shm # shm
src=src, src=src,
@ -388,7 +361,7 @@ async def cascade(
) -> tuple[TaskTracker, int]: ) -> tuple[TaskTracker, int]:
# TODO: adopt an incremental update engine/approach # TODO: adopt an incremental update engine/approach
# where possible here eventually! # where possible here eventually!
log.info(f're-syncing fsp {func_name} to source') log.debug(f're-syncing fsp {func_name} to source')
tracker.cs.cancel() tracker.cs.cancel()
await tracker.complete.wait() await tracker.complete.wait()
tracker, index = await n.start(fsp_target) tracker, index = await n.start(fsp_target)
@ -401,16 +374,14 @@ async def cascade(
'key': dst_shm_token, 'key': dst_shm_token,
'first': dst._first.value, 'first': dst._first.value,
'last': dst._last.value, 'last': dst._last.value,
} }})
})
return tracker, index return tracker, index
def is_synced( def is_synced(
src: ShmArray, src: ShmArray,
dst: ShmArray dst: ShmArray
) -> tuple[bool, int, int]: ) -> tuple[bool, int, int]:
''' '''Predicate to determine if a destination FSP
Predicate to determine if a destination FSP
output array is aligned to its source array. output array is aligned to its source array.
''' '''
@ -419,15 +390,16 @@ async def cascade(
return not ( return not (
# the source is likely backfilling and we must # the source is likely backfilling and we must
# sync history calculations # sync history calculations
len_diff > 2 len_diff > 2 or
# we aren't step synced to the source and may be # we aren't step synced to the source and may be
# leading/lagging by a step # leading/lagging by a step
or step_diff > 1 step_diff > 1 or
or step_diff < 0 step_diff < 0
), step_diff, len_diff ), step_diff, len_diff
async def poll_and_sync_to_step( async def poll_and_sync_to_step(
tracker: TaskTracker, tracker: TaskTracker,
src: ShmArray, src: ShmArray,
dst: ShmArray, dst: ShmArray,
@ -446,23 +418,18 @@ async def cascade(
# detect sample period step for subscription to increment # detect sample period step for subscription to increment
# signal # signal
times = src.array['time'] times = src.array['time']
if len(times) > 1: delay_s = times[-1] - times[times != times[-1]][-1]
last_ts = times[-1]
delay_s = float(last_ts - times[times != last_ts][-1])
else:
# our default "HFT" sample rate.
delay_s = _default_delay_s
# sub and increment the underlying shared memory buffer # Increment the underlying shared memory buffer on every
# on every step msg received from the global `samplerd` # "increment" msg received from the underlying data feed.
# service. async with feed.index_stream(
async with open_sample_stream(float(delay_s)) as istream: int(delay_s)
) as istream:
profiler(f'{func_name}: sample stream up') profiler(f'{func_name}: sample stream up')
profiler.finish() profiler.finish()
async for i in istream: async for _ in istream:
# print(f'FSP incrementing {i}')
# respawn the compute task if the source # respawn the compute task if the source
# array has been updated such that we compute # array has been updated such that we compute
@ -491,23 +458,3 @@ async def cascade(
last = array[-1:].copy() last = array[-1:].copy()
dst.push(last) dst.push(last)
# sync with source buffer's time step
src_l2 = src.array[-2:]
src_li, src_lt = src_l2[-1][['index', 'time']]
src_2li, src_2lt = src_l2[-2][['index', 'time']]
dst._array['time'][src_li] = src_lt
dst._array['time'][src_2li] = src_2lt
# last2 = dst.array[-2:]
# if (
# last2[-1]['index'] != src_li
# or last2[-2]['index'] != src_2li
# ):
# dstl2 = list(last2)
# srcl2 = list(src_l2)
# print(
# # f'{dst.token}\n'
# f'src: {srcl2}\n'
# f'dst: {dstl2}\n'
# )

View File

@ -234,7 +234,7 @@ async def flow_rates(
# FSPs, user input, and possibly any general event stream in # FSPs, user input, and possibly any general event stream in
# real-time. Hint: ideally implemented with caching until mutated # real-time. Hint: ideally implemented with caching until mutated
# ;) # ;)
period: 'Param[int]' = 1, # noqa period: 'Param[int]' = 6, # noqa
# TODO: support other means by providing a map # TODO: support other means by providing a map
# to weights `partial()`-ed with `wma()`? # to weights `partial()`-ed with `wma()`?
@ -268,7 +268,8 @@ async def flow_rates(
'dark_dvlm_rate': None, 'dark_dvlm_rate': None,
} }
quote = await anext(source) # TODO: 3.10 do ``anext()``
quote = await source.__anext__()
# ltr = 0 # ltr = 0
# lvr = 0 # lvr = 0

View File

@ -1,998 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
'''
Personal/Private position parsing, calculating, summarizing in a way
that doesn't try to cuk most humans who prefer to not lose their moneys..
(looking at you `ib` and dirt-bird friends)
'''
from contextlib import contextmanager as cm
from pprint import pformat
import os
from os import path
from math import copysign
import re
import time
from typing import (
Any,
Iterator,
Optional,
Union,
)
import pendulum
from pendulum import datetime, now
import tomli
import toml
from . import config
from .brokers import get_brokermod
from .clearing._messages import BrokerdPosition, Status
from .data._source import Symbol
from .log import get_logger
from .data.types import Struct
log = get_logger(__name__)
@cm
def open_trade_ledger(
broker: str,
account: str,
) -> dict:
'''
Idempotently create and read in a trade log file from the
``<configuration_dir>/ledgers/`` directory.
Files are named per broker account of the form
``trades_<brokername>_<accountname>.toml``. The ``accountname`` here is the
name as defined in the user's ``brokers.toml`` config.
'''
ldir = path.join(config._config_dir, 'ledgers')
if not path.isdir(ldir):
os.makedirs(ldir)
fname = f'trades_{broker}_{account}.toml'
tradesfile = path.join(ldir, fname)
if not path.isfile(tradesfile):
log.info(
f'Creating new local trades ledger: {tradesfile}'
)
with open(tradesfile, 'w') as cf:
pass # touch
with open(tradesfile, 'rb') as cf:
start = time.time()
ledger = tomli.load(cf)
print(f'Ledger load took {time.time() - start}s')
cpy = ledger.copy()
try:
yield cpy
finally:
if cpy != ledger:
# TODO: show diff output?
# https://stackoverflow.com/questions/12956957/print-diff-of-python-dictionaries
print(f'Updating ledger for {tradesfile}:\n')
ledger.update(cpy)
# we write on close the mutated ledger data
with open(tradesfile, 'w') as cf:
toml.dump(ledger, cf)
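A minimal usage sketch of the context manager above; the broker/account names and entry fields are placeholders, not real `brokers.toml` values. Any mutation of the yielded copy is flushed back to the TOML file on exit:
# hypothetical broker/account names; the tid key and field values
# below are illustrative only.
with open_trade_ledger('kraken', 'spot') as ledger:
    ledger['deadbeef-tid'] = {
        'price': 10.0,
        'size': 1.0,
        'cost': 0.01,
        'dt': '2022-06-30T12:00:00+00:00',
    }
# on exit the mutated copy differs from the on-disk table, so the
# full ledger is re-dumped via ``toml.dump()``.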
class Transaction(Struct, frozen=True):
# TODO: should this be ``.to`` (see below)?
fqsn: str
tid: Union[str, int] # unique transaction id
size: float
price: float
cost: float # commissions or other additional costs
dt: datetime
expiry: Optional[datetime] = None
# optional key normally derived from the broker
# backend which ensures the instrument-symbol this record
# is for is truly unique.
bsuid: Optional[Union[str, int]] = None
# optional fqsn for the source "asset"/money symbol?
# from: Optional[str] = None
def iter_by_dt(
clears: dict[str, Any],
) -> Iterator[tuple[str, dict]]:
'''
Iterate entries of a ``clears: dict`` table sorted by each
entry's recorded datetime, presumably set at the ``'dt'`` field.
'''
for tid, data in sorted(
list(clears.items()),
key=lambda item: item[1]['dt'],
):
yield tid, data
class Position(Struct):
'''
Basic pp (personal/piker position) model with attached clearing
transaction history.
'''
symbol: Symbol
# can be +ve or -ve for long/short
size: float
# "breakeven price" above or below which pnl moves above and below
# zero for the entirety of the current "trade state".
ppu: float
# unique backend symbol id
bsuid: str
split_ratio: Optional[int] = None
# ordered record of known constituent trade messages
clears: dict[
Union[str, int, Status], # trade id
dict[str, Any], # transaction history summaries
] = {}
first_clear_dt: Optional[datetime] = None
expiry: Optional[datetime] = None
def to_dict(self) -> dict:
return {
f: getattr(self, f)
for f in self.__struct_fields__
}
def to_pretoml(self) -> tuple[str, dict]:
'''
Prep this position's data contents for export to toml including
re-structuring of the ``.clears`` table to an array of
inline-subtables for better ``pps.toml`` compactness.
'''
d = self.to_dict()
clears = d.pop('clears')
expiry = d.pop('expiry')
if self.split_ratio is None:
d.pop('split_ratio')
# should be obvious from clears/event table
d.pop('first_clear_dt')
# TODO: we need to figure out how to have one top level
# listing venue here even when the backend isn't providing
# it via the trades ledger..
# drop symbol obj in serialized form
s = d.pop('symbol')
fqsn = s.front_fqsn()
if self.expiry is None:
d.pop('expiry', None)
elif expiry:
d['expiry'] = str(expiry)
toml_clears_list = []
# reverse sort so latest clears are at top of section?
for tid, data in iter_by_dt(clears):
inline_table = toml.TomlDecoder().get_empty_inline_table()
# serialize datetime to parsable `str`
inline_table['dt'] = str(data['dt'])
# insert optional clear fields in column order
for k in ['ppu', 'accum_size']:
val = data.get(k)
if val:
inline_table[k] = val
# insert required fields
for k in ['price', 'size', 'cost']:
inline_table[k] = data[k]
inline_table['tid'] = tid
toml_clears_list.append(inline_table)
d['clears'] = toml_clears_list
return fqsn, d
def ensure_state(self) -> None:
'''
Audit the `.size` and `.ppu` local instance vars against the
clears table calculations, return the calc-ed values if they
differ, and log warnings to console.
'''
clears = list(self.clears.values())
self.first_clear_dt = min(entry['dt'] for entry in clears)
last_clear = clears[-1]
csize = self.calc_size()
accum = last_clear['accum_size']
if not self.expired():
if (
csize != accum
and csize != round(accum * (self.split_ratio or 1))
):
raise ValueError(f'Size mismatch: {csize}')
else:
assert csize == 0, 'Contract is expired but non-zero size?'
if self.size != csize:
log.warning(
'Position state mismatch:\n'
f'{self.size} => {csize}'
)
self.size = csize
cppu = self.calc_ppu()
ppu = last_clear['ppu']
if (
cppu != ppu
and self.split_ratio is not None
# handle any split info entered (for now) manually by user
and cppu != (ppu / self.split_ratio)
):
raise ValueError(f'PPU mismatch: {cppu}')
if self.ppu != cppu:
log.warning(
'Position state mismatch:\n'
f'{self.ppu} => {cppu}'
)
self.ppu = cppu
def update_from_msg(
self,
msg: BrokerdPosition,
) -> None:
# XXX: better place to do this?
symbol = self.symbol
lot_size_digits = symbol.lot_size_digits
ppu, size = (
round(
msg['avg_price'],
ndigits=symbol.tick_size_digits
),
round(
msg['size'],
ndigits=lot_size_digits
),
)
self.ppu = ppu
self.size = size
@property
def dsize(self) -> float:
'''
The "dollar" size of the pp, normally in trading (fiat) unit
terms.
'''
return self.ppu * self.size
# TODO: idea: "real LIFO" dynamic positioning.
# - when a trade takes place where the pnl for
# the (set of) trade(s) is below the breakeven price
# it may be that the trader took a +ve pnl on a short(er)
# term trade in the same account.
# - in this case we could recalc the be price to
# be reverted back to its prior value before the nearest term
# trade was opened?
# def lifo_price() -> float:
# ...
def iter_clears(self) -> Iterator[tuple[str, dict]]:
'''
Iterate the internally managed ``.clears: dict`` table in
datetime-stamped order.
'''
return iter_by_dt(self.clears)
def calc_ppu(
self,
# include transaction cost in breakeven price
# and presume the worst case of the same cost
# to exit this transaction (even though in reality
# it will be dynamic based on exit strategy).
cost_scalar: float = 2,
) -> float:
'''
Compute the "price-per-unit" price for the given non-zero sized
rolling position.
The recurrence relation which computes this (exponential) mean,
for each new clear that **increases** the accumulative position
size, is:
ppu[-1] = (
ppu[-2] * accum_size[-2]
+
ppu[-1] * size
) / accum_size[-1]
where `cost_basis` for the current step is simply the price
* size of the most recent clearing transaction.
'''
asize_h: list[float] = [] # historical accumulative size
ppu_h: list[float] = [] # historical price-per-unit
tid: str
entry: dict[str, Any]
for (tid, entry) in self.iter_clears():
clear_size = entry['size']
clear_price = entry['price']
last_accum_size = asize_h[-1] if asize_h else 0
accum_size = last_accum_size + clear_size
accum_sign = copysign(1, accum_size)
sign_change: bool = False
if accum_size == 0:
ppu_h.append(0)
asize_h.append(0)
continue
# test if the pp somehow went "past" a net zero size state
# resulting in a change of the "sign" of the size (+ve for
# long, -ve for short).
sign_change = (
copysign(1, last_accum_size) + accum_sign == 0
and last_accum_size != 0
)
# since we passed the net-zero-size state the new size
# after sum should be the remaining size in the new
# "direction" (aka, long vs. short) for this clear.
if sign_change:
clear_size = accum_size
abs_diff = abs(accum_size)
asize_h.append(0)
ppu_h.append(0)
else:
# old size minus the new size gives us size diff with
# +ve -> increase in pp size
# -ve -> decrease in pp size
abs_diff = abs(accum_size) - abs(last_accum_size)
# XXX: LIFO breakeven price update. only an increase in size
# of the position contributes to the breakeven price,
# a decrease does not (i.e. the position is being made
# smaller).
# abs_clear_size = abs(clear_size)
abs_new_size = abs(accum_size)
if abs_diff > 0:
cost_basis = (
# cost basis for this clear
clear_price * abs(clear_size)
+
# transaction cost
accum_sign * cost_scalar * entry['cost']
)
if asize_h:
size_last = abs(asize_h[-1])
cb_last = ppu_h[-1] * size_last
ppu = (cost_basis + cb_last) / abs_new_size
else:
ppu = cost_basis / abs_new_size
ppu_h.append(ppu)
asize_h.append(accum_size)
else:
# on "exit" clears from a given direction,
# only the size changes not the price-per-unit
# need to be updated since the ppu remains constant
# and gets weighted by the new size.
asize_h.append(accum_size)
ppu_h.append(ppu_h[-1])
final_ppu = ppu_h[-1] if ppu_h else 0
# handle any split info entered (for now) manually by user
if self.split_ratio is not None:
final_ppu /= self.split_ratio
return final_ppu
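A hand-worked sketch of the recurrence with invented clears (transaction costs zeroed for clarity): it mirrors only the increase-only branch above, which is the part that moves the ppu.
# hand-worked check of the ppu recurrence:
#   buy 2 @ 100 -> accum_size = 2, ppu = 200 / 2       = 100
#   buy 2 @ 110 -> accum_size = 4, ppu = (220 + 200)/4 = 105
#   sell 1      -> accum_size = 3, ppu unchanged       = 105
clears = [(100.0, 2.0), (110.0, 2.0), (105.0, -1.0)]  # (price, size)
asize, ppu = 0.0, 0.0
for price, size in clears:
    new_asize = asize + size
    if abs(new_asize) > abs(asize):  # position grew: re-weight ppu
        ppu = (price * abs(size) + ppu * abs(asize)) / abs(new_asize)
    asize = new_asize
assert (asize, ppu) == (3.0, 105.0)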
def expired(self) -> bool:
'''
Predicate which checks if the contract/instrument is past its expiry.
'''
return bool(self.expiry) and self.expiry < now()
def calc_size(self) -> float:
'''
Calculate the unit size of this position in the destination
asset using the clears/trade event table; zero if expired.
'''
size: float = 0
# time-expired pps (normally derivatives) are "closed"
# and have a zero size.
if self.expired():
return 0
for tid, entry in self.clears.items():
size += entry['size']
if self.split_ratio is not None:
size = round(size * self.split_ratio)
return size
def minimize_clears(
self,
) -> dict[str, dict]:
'''
Minimize the position's clears entries by removing
all transactions before the last net zero size to avoid
unnecessary history irrelevant to the current pp state.
'''
size: float = 0
clears_since_zero: list[tuple[str, dict]] = []
# TODO: we might just want to always do this when iterating
# a ledger? keep a state of the last net-zero and only do the
# full iterate when no state was stashed?
# scan for the last "net zero" position by iterating
# transactions until the next net-zero size, rinse, repeat.
for tid, clear in self.clears.items():
size += clear['size']
clears_since_zero.append((tid, clear))
if size == 0:
clears_since_zero.clear()
self.clears = dict(clears_since_zero)
return self.clears
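The trimming rule above in miniature (clear sizes invented): once the running sum hits zero, everything before that point is irrelevant to the open position.
# sketch: a sequence that nets to zero mid-stream; only the clears
# *after* the last net-zero point describe the open position.
sizes = [1.0, -1.0, 2.0, 1.0]  # running sum: 1, 0, 2, 3
running, since_zero = 0.0, []
for i, s in enumerate(sizes):
    running += s
    since_zero.append(i)
    if running == 0:
        since_zero.clear()
assert since_zero == [2, 3]  # the first two clears can be dropped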
def add_clear(
self,
t: Transaction,
) -> dict:
'''
Update clearing table and populate rolling ppu and accumulative
size in both the clears entry and local attrs state.
'''
clear = self.clears[t.tid] = {
'cost': t.cost,
'price': t.price,
'size': t.size,
'dt': t.dt,
}
# TODO: compute these incrementally instead
# of re-looping through each time resulting in O(n**2)
# behaviour..?
# NOTE: we compute these **after** adding the entry in order to
# make the recurrence relation math work inside
# ``.calc_size()``.
self.size = clear['accum_size'] = self.calc_size()
self.ppu = clear['ppu'] = self.calc_ppu()
return clear
def suggest_split(self) -> float:
...
class PpTable(Struct):
brokername: str
acctid: str
pps: dict[str, Position]
conf: Optional[dict] = {}
def update_from_trans(
self,
trans: dict[str, Transaction],
cost_scalar: float = 2,
) -> dict[str, Position]:
pps = self.pps
updated: dict[str, Position] = {}
# lifo update all pps from records
for tid, t in trans.items():
pp = pps.setdefault(
t.bsuid,
# if no existing pp, allocate fresh one.
Position(
Symbol.from_fqsn(
t.fqsn,
info={},
),
size=0.0,
ppu=0.0,
bsuid=t.bsuid,
expiry=t.expiry,
)
)
clears = pp.clears
if clears:
first_clear_dt = pp.first_clear_dt
# don't do updates for ledger records we already have
# included in the current pps state.
if (
t.tid in clears
or first_clear_dt and t.dt < first_clear_dt
):
# NOTE: likely you'll see repeats of the same
# ``Transaction`` passed in here if/when you are restarting
# a ``brokerd.ib`` where the API will re-report trades from
# the current session, so we need to make sure we don't
# "double count" these in pp calculations.
continue
# update clearing table
pp.add_clear(t)
updated[t.bsuid] = pp
# minimize clears tables and update sizing.
for bsuid, pp in updated.items():
pp.ensure_state()
return updated
def dump_active(
self,
) -> tuple[
dict[str, Position],
dict[str, Position]
]:
'''
Iterate all tabulated positions, render active positions to
a ``dict`` format amenable to serialization (via TOML) and drop
from state (``.pps``) as well as return in a ``dict`` all
``Position``s which have recently closed.
'''
# NOTE: newly closed positions are also important to report/return
# since a consumer, like an order mode UI ;), might want to react
# based on the closure (for example removing the breakeven line
# and clearing the entry from any lists/monitors).
closed_pp_objs: dict[str, Position] = {}
open_pp_objs: dict[str, Position] = {}
pp_objs = self.pps
for bsuid in list(pp_objs):
pp = pp_objs[bsuid]
# XXX: debug hook for size mismatches
# qqqbsuid = 320227571
# if bsuid == qqqbsuid:
# breakpoint()
pp.ensure_state()
if (
# "net-zero" is a "closed" position
pp.size == 0
# time-expired pps (normally derivatives) are "closed"
or (pp.expiry and pp.expiry < now())
):
# for expired cases
pp.size = 0
# NOTE: we DO NOT pop the pp here since it can still be
# used to check for duplicate clears that may come in as
# new transaction from some backend API and need to be
# ignored; the closed positions won't be written to the
# ``pps.toml`` since ``pp_active_entries`` above is what's
# written.
closed_pp_objs[bsuid] = pp
else:
open_pp_objs[bsuid] = pp
return open_pp_objs, closed_pp_objs
def to_toml(
self,
) -> dict[str, Any]:
active, closed = self.dump_active()
# ONLY dict-serialize all active positions; those that are closed
# we don't store in the ``pps.toml``.
to_toml_dict = {}
for bsuid, pos in active.items():
# keep the minimal amount of clears that make up this
# position since the last net-zero state.
pos.minimize_clears()
pos.ensure_state()
# serialize to pre-toml form
fqsn, asdict = pos.to_pretoml()
log.info(f'Updating active pp: {fqsn}')
# XXX: ugh, it's cuz we push the section under
# the broker name.. maybe we need to rethink this?
brokerless_key = fqsn.removeprefix(f'{self.brokername}.')
to_toml_dict[brokerless_key] = asdict
return to_toml_dict
def write_config(self) -> None:
'''
Write the current position table to the user's ``pps.toml``.
'''
# TODO: show diff output?
# https://stackoverflow.com/questions/12956957/print-diff-of-python-dictionaries
print(f'Updating ``pps.toml`` for {self.brokername}.{self.acctid}:\n')
# active, closed_pp_objs = table.dump_active()
pp_entries = self.to_toml()
self.conf[self.brokername][self.acctid] = pp_entries
# TODO: why tf haven't they already done this for inline
# tables smh..
enc = PpsEncoder(preserve=True)
# table_bs_type = type(toml.TomlDecoder().get_empty_inline_table())
enc.dump_funcs[
toml.decoder.InlineTableDict
] = enc.dump_inline_table
config.write(
self.conf,
'pps',
encoder=enc,
)
def load_pps_from_ledger(
brokername: str,
acctname: str,
# post normalization filter on ledger entries to be processed
filter_by: Optional[list[dict]] = None,
) -> tuple[
dict[str, Transaction],
dict[str, Position],
]:
'''
Open a ledger file by broker name and account, read in and
process any trade records into our normalized ``Transaction``
form, update the equivalent ``PpTable``, and deliver the two
bsuid-mapped dict-sets of the transactions and pps.
'''
with (
open_trade_ledger(brokername, acctname) as ledger,
open_pps(brokername, acctname) as table,
):
if not ledger:
# null case, no ledger file with content
return {}, {}
mod = get_brokermod(brokername)
src_records: dict[str, Transaction] = mod.norm_trade_records(ledger)
if filter_by:
records = {}
bsuids = set(filter_by)
for tid, r in src_records.items():
if r.bsuid in bsuids:
records[tid] = r
else:
records = src_records
updated = table.update_from_trans(records)
return records, updated
# TODO: instead see if we can hack tomli and tomli-w to do the same:
# - https://github.com/hukkin/tomli
# - https://github.com/hukkin/tomli-w
class PpsEncoder(toml.TomlEncoder):
'''
Special "styled" encoder that makes a ``pps.toml`` redable and
compact by putting `.clears` tables inline and everything else
flat-ish.
'''
separator = ','
def dump_list(self, v):
'''
Dump an inline list with a newline after every element and
with consideration for denoted inline table types.
'''
retval = "[\n"
for u in v:
if isinstance(u, toml.decoder.InlineTableDict):
out = self.dump_inline_table(u)
else:
out = str(self.dump_value(u))
retval += " " + out + "," + "\n"
retval += "]"
return retval
def dump_inline_table(self, section):
"""Preserve inline table in its compact syntax instead of expanding
into subsection.
https://github.com/toml-lang/toml#user-content-inline-table
"""
val_list = []
for k, v in section.items():
# if isinstance(v, toml.decoder.InlineTableDict):
if isinstance(v, dict):
val = self.dump_inline_table(v)
else:
val = str(self.dump_value(v))
val_list.append(k + " = " + val)
retval = "{ " + ", ".join(val_list) + " }"
return retval
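For reference, a clears entry rendered in this inline style lands on one line of the ``pps.toml`` array, roughly like the following (all values hypothetical):
clears = [
    { dt = "2022-06-30T12:00:00+00:00", ppu = 105.0, accum_size = 4.0, price = 110.0, size = 2.0, cost = 0.01, tid = "deadbeef-tid" },
]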
def dump_sections(self, o, sup):
retstr = ""
if sup != "" and sup[-1] != ".":
sup += '.'
retdict = self._dict()
arraystr = ""
for section in o:
qsection = str(section)
value = o[section]
if not re.match(r'^[A-Za-z0-9_-]+$', section):
qsection = toml.encoder._dump_str(section)
# arrayoftables = False
if (
self.preserve
and isinstance(value, toml.decoder.InlineTableDict)
):
retstr += (
qsection
+
" = "
+
self.dump_inline_table(o[section])
+
'\n' # only on the final terminating left brace
)
# XXX: this code i'm pretty sure is just blatantly bad
# and/or wrong..
# if isinstance(o[section], list):
# for a in o[section]:
# if isinstance(a, dict):
# arrayoftables = True
# if arrayoftables:
# for a in o[section]:
# arraytabstr = "\n"
# arraystr += "[[" + sup + qsection + "]]\n"
# s, d = self.dump_sections(a, sup + qsection)
# if s:
# if s[0] == "[":
# arraytabstr += s
# else:
# arraystr += s
# while d:
# newd = self._dict()
# for dsec in d:
# s1, d1 = self.dump_sections(d[dsec], sup +
# qsection + "." +
# dsec)
# if s1:
# arraytabstr += ("[" + sup + qsection +
# "." + dsec + "]\n")
# arraytabstr += s1
# for s1 in d1:
# newd[dsec + "." + s1] = d1[s1]
# d = newd
# arraystr += arraytabstr
elif isinstance(value, dict):
retdict[qsection] = o[section]
elif o[section] is not None:
retstr += (
qsection
+
" = "
+
str(self.dump_value(o[section]))
)
# if not isinstance(value, dict):
if not isinstance(value, toml.decoder.InlineTableDict):
# inline tables should not contain newlines:
# https://toml.io/en/v1.0.0#inline-table
retstr += '\n'
else:
raise ValueError(value)
retstr += arraystr
return (retstr, retdict)
@cm
def open_pps(
brokername: str,
acctid: str,
write_on_exit: bool = True,
) -> PpTable:
'''
Read out broker-specific position entries from the
incremental update file: ``pps.toml``.
'''
conf, path = config.load('pps')
brokersection = conf.setdefault(brokername, {})
pps = brokersection.setdefault(acctid, {})
# TODO: ideally we can pass in an existing
# pps state to this right? such that we
# don't have to do a ledger reload all the
# time.. a couple ideas I can think of,
# - mirror this in some client side actor which
# does the actual ledger updates (say the paper
# engine proc if we decide to always spawn it?),
# - do diffs against updates from the ledger writer
# actor and the in-mem state here?
pp_objs = {}
table = PpTable(
brokername,
acctid,
pp_objs,
conf=conf,
)
# unmarshal/load ``pps.toml`` config entries into object form
# and update `PpTable` obj entries.
for fqsn, entry in pps.items():
bsuid = entry['bsuid']
# convert clears sub-tables (only in this form
# for toml re-presentation) back into a master table.
clears_list = entry['clears']
# index clears entries in "object" form by tid in a top
# level dict instead of a list (as is presented in our
# ``pps.toml``).
clears = pp_objs.setdefault(bsuid, {})
# TODO: should we make a ``Struct`` for clear/event entries?
# convert "clear events table" from the toml config (list of
# a dicts) and load it into object form for use in position
# processing of new clear events.
trans: list[Transaction] = []
for clears_table in clears_list:
tid = clears_table.pop('tid')
dtstr = clears_table['dt']
dt = pendulum.parse(dtstr)
clears_table['dt'] = dt
trans.append(Transaction(
fqsn=bsuid,
bsuid=bsuid,
tid=tid,
size=clears_table['size'],
price=clears_table['price'],
cost=clears_table['cost'],
dt=dt,
))
clears[tid] = clears_table
size = entry['size']
# TODO: remove, but handle old field name for now
ppu = entry.get('ppu', entry.get('be_price', 0))
split_ratio = entry.get('split_ratio')
expiry = entry.get('expiry')
if expiry:
expiry = pendulum.parse(expiry)
pp = pp_objs[bsuid] = Position(
Symbol.from_fqsn(fqsn, info={}),
size=size,
ppu=ppu,
split_ratio=split_ratio,
expiry=expiry,
bsuid=entry['bsuid'],
)
# XXX: super critical, we need to be sure to include
# all pps.toml clears to avoid reusing clears that were
# already included in the current incremental update
# state, since today's records may have already been
# processed!
for t in trans:
pp.add_clear(t)
# audit entries loaded from toml
pp.ensure_state()
try:
yield table
finally:
if write_on_exit:
table.write_config()
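A read-only usage sketch of the context manager above; the 'ib'/'margin' names are placeholders for whatever is in the user's config, and ``write_on_exit=False`` avoids rewriting the ``pps.toml``:
# placeholder broker/account names; inspection only, no file write.
with open_pps('ib', 'margin', write_on_exit=False) as table:
    for bsuid, pos in table.pps.items():
        # ``front_fqsn()`` per the ``Symbol`` usage above
        print(pos.symbol.front_fqsn(), pos.size, pos.ppu)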
if __name__ == '__main__':
import sys
args = sys.argv
assert len(args) > 1, 'Specify account(s) from `brokers.toml`'
args = args[1:]
for acctid in args:
broker, name = acctid.split('.')
trans, updated_pps = load_pps_from_ledger(broker, name)
print(
f'Processing transactions into pps for {broker}:{acctid}\n'
f'{pformat(trans)}\n\n'
f'{pformat(updated_pps)}'
)

View File

@ -32,22 +32,16 @@ def mk_marker_path(
style: str, style: str,
) -> QGraphicsPathItem: ) -> QGraphicsPathItem:
''' """Add a marker to be displayed on the line wrapped in a ``QGraphicsPathItem``
Add a marker to be displayed on the line wrapped in ready to be placed using scene coordinates (not view).
a ``QGraphicsPathItem`` ready to be placed using scene coordinates
(not view).
**Arguments** **Arguments**
style String indicating the style of marker to add: style String indicating the style of marker to add:
``'<|'``, ``'|>'``, ``'>|'``, ``'|<'``, ``'<|>'``, ``'<|'``, ``'|>'``, ``'>|'``, ``'|<'``, ``'<|>'``,
``'>|<'``, ``'^'``, ``'v'``, ``'o'`` ``'>|<'``, ``'^'``, ``'v'``, ``'o'``
size Size of the marker in pixels.
This code is taken nearly verbatim from the """
`InfiniteLine.addMarker()` method but does not attempt to be aware
of low(er) level graphics controls and expects for the output
polygon to be applied to a ``QGraphicsPathItem``.
'''
path = QtGui.QPainterPath() path = QtGui.QPainterPath()
if style == 'o': if style == 'o':
@ -93,8 +87,7 @@ def mk_marker_path(
class LevelMarker(QGraphicsPathItem): class LevelMarker(QGraphicsPathItem):
''' '''An arrow marker path graphic which redraws itself
An arrow marker path graphic which redraws itself
to the specified view coordinate level on each paint cycle. to the specified view coordinate level on each paint cycle.
''' '''
@ -111,8 +104,7 @@ class LevelMarker(QGraphicsPathItem):
# get polygon and scale # get polygon and scale
super().__init__() super().__init__()
# self.setScale(size, size) self.scale(size, size)
self.setScale(size)
# internally generates path # internally generates path
self._style = None self._style = None
@ -122,7 +114,6 @@ class LevelMarker(QGraphicsPathItem):
self.get_level = get_level self.get_level = get_level
self._on_paint = on_paint self._on_paint = on_paint
self.scene_x = lambda: chart.marker_right_points()[1] self.scene_x = lambda: chart.marker_right_points()[1]
self.level: float = 0 self.level: float = 0
self.keep_in_view = keep_in_view self.keep_in_view = keep_in_view
@ -158,9 +149,12 @@ class LevelMarker(QGraphicsPathItem):
def w(self) -> float: def w(self) -> float:
return self.path_br().width() return self.path_br().width()
def position_in_view(self) -> None: def position_in_view(
''' self,
Show a pp off-screen indicator for a level label. # level: float,
) -> None:
'''Show a pp off-screen indicator for a level label.
This is like in fps games where you have a gps "nav" indicator This is like in fps games where you have a gps "nav" indicator
but your teammate is outside the range of view, except in 2D, on but your teammate is outside the range of view, except in 2D, on
@ -168,6 +162,7 @@ class LevelMarker(QGraphicsPathItem):
''' '''
level = self.get_level() level = self.get_level()
view = self.chart.getViewBox() view = self.chart.getViewBox()
vr = view.state['viewRange'] vr = view.state['viewRange']
ymn, ymx = vr[1] ymn, ymx = vr[1]
@ -191,6 +186,7 @@ class LevelMarker(QGraphicsPathItem):
) )
elif level < ymn: # pin to bottom of view elif level < ymn: # pin to bottom of view
self.setPos( self.setPos(
QPointF( QPointF(
x, x,
@ -215,8 +211,7 @@ class LevelMarker(QGraphicsPathItem):
w: QtWidgets.QWidget w: QtWidgets.QWidget
) -> None: ) -> None:
''' '''Core paint which we override to always update
Core paint which we override to always update
our marker position in scene coordinates from a our marker position in scene coordinates from a
view coordinate "level". view coordinate "level".
@ -240,12 +235,11 @@ def qgo_draw_markers(
right_offset: float, right_offset: float,
) -> float: ) -> float:
''' """Paint markers in ``pg.GraphicsItem`` style by first
Paint markers in ``pg.GraphicsItem`` style by first
removing the view transform for the painter, drawing the markers removing the view transform for the painter, drawing the markers
in scene coords, then restoring the view coords. in scene coords, then restoring the view coords.
''' """
# paint markers in native coordinate system # paint markers in native coordinate system
orig_tr = p.transform() orig_tr = p.transform()

View File

@ -19,16 +19,15 @@ Main app startup and run.
''' '''
from functools import partial from functools import partial
from types import ModuleType
from PyQt5.QtCore import QEvent from PyQt5.QtCore import QEvent
import trio import trio
from .._daemon import maybe_spawn_brokerd from .._daemon import maybe_spawn_brokerd
from ..brokers import get_brokermod
from . import _event from . import _event
from ._exec import run_qtractor from ._exec import run_qtractor
from ..data.feed import install_brokerd_search from ..data.feed import install_brokerd_search
from ..data._source import unpack_fqsn
from . import _search from . import _search
from ._chart import GodWidget from ._chart import GodWidget
from ..log import get_logger from ..log import get_logger
@ -37,26 +36,27 @@ log = get_logger(__name__)
async def load_provider_search( async def load_provider_search(
brokermod: str,
broker: str,
loglevel: str, loglevel: str,
) -> None: ) -> None:
name = brokermod.name log.info(f'loading brokerd for {broker}..')
log.info(f'loading brokerd for {name}..')
async with ( async with (
maybe_spawn_brokerd( maybe_spawn_brokerd(
name, broker,
loglevel=loglevel loglevel=loglevel
) as portal, ) as portal,
install_brokerd_search( install_brokerd_search(
portal, portal,
brokermod, get_brokermod(broker),
), ),
): ):
# keep search engine stream up until cancelled # keep search engine stream up until cancelled
await trio.sleep_forever() await trio.sleep_forever()
@ -66,8 +66,8 @@ async def _async_main(
# implicit required argument provided by ``qtractor_run()`` # implicit required argument provided by ``qtractor_run()``
main_widget: GodWidget, main_widget: GodWidget,
syms: list[str], sym: str,
brokers: dict[str, ModuleType], brokernames: str,
loglevel: str, loglevel: str,
) -> None: ) -> None:
@ -78,8 +78,6 @@ async def _async_main(
""" """
from . import _display from . import _display
from ._pg_overrides import _do_overrides
_do_overrides()
godwidget = main_widget godwidget = main_widget
@ -99,11 +97,6 @@ async def _async_main(
sbar = godwidget.window.status_bar sbar = godwidget.window.status_bar
starting_done = sbar.open_status('starting ze sexy chartz') starting_done = sbar.open_status('starting ze sexy chartz')
needed_brokermods: dict[str, ModuleType] = {}
for fqsn in syms:
brokername, *_ = unpack_fqsn(fqsn)
needed_brokermods[brokername] = brokers[brokername]
async with ( async with (
trio.open_nursery() as root_n, trio.open_nursery() as root_n,
): ):
@ -114,14 +107,18 @@ async def _async_main(
# setup search widget and focus main chart view at startup # setup search widget and focus main chart view at startup
# search widget is a singleton alongside the godwidget # search widget is a singleton alongside the godwidget
search = _search.SearchWidget(godwidget=godwidget) search = _search.SearchWidget(godwidget=godwidget)
# search.bar.unfocus() search.bar.unfocus()
# godwidget.hbox.addWidget(search)
godwidget.hbox.addWidget(search)
godwidget.search = search godwidget.search = search
symbol, _, provider = sym.rpartition('.')
# this internally starts a ``display_symbol_data()`` task above # this internally starts a ``display_symbol_data()`` task above
order_mode_ready = await godwidget.load_symbols( order_mode_ready = await godwidget.load_symbol(
fqsns=syms, provider,
loglevel=loglevel, symbol,
loglevel
) )
# spin up a search engine for the local cached symbol set # spin up a search engine for the local cached symbol set
@ -138,12 +135,8 @@ async def _async_main(
): ):
# load other providers into search **after** # load other providers into search **after**
# the chart's select cache # the chart's select cache
for brokername, mod in needed_brokermods.items(): for broker in brokernames:
root_n.start_soon( root_n.start_soon(load_provider_search, broker, loglevel)
load_provider_search,
mod,
loglevel,
)
await order_mode_ready.wait() await order_mode_ready.wait()
@ -172,8 +165,8 @@ async def _async_main(
def _main( def _main(
syms: list[str], sym: str,
brokermods: list[ModuleType], brokernames: [str],
piker_loglevel: str, piker_loglevel: str,
tractor_kwargs, tractor_kwargs,
) -> None: ) -> None:
@ -184,11 +177,7 @@ def _main(
''' '''
run_qtractor( run_qtractor(
func=_async_main, func=_async_main,
args=( args=(sym, brokernames, piker_loglevel),
syms, main_widget=GodWidget,
{mod.name: mod for mod in brokermods},
piker_loglevel,
),
main_widget_type=GodWidget,
tractor_kwargs=tractor_kwargs, tractor_kwargs=tractor_kwargs,
) )

View File

@ -18,7 +18,6 @@
Chart axes graphics and behavior. Chart axes graphics and behavior.
""" """
from __future__ import annotations
from functools import lru_cache from functools import lru_cache
from typing import Optional, Callable from typing import Optional, Callable
from math import floor from math import floor
@ -28,7 +27,6 @@ import pyqtgraph as pg
from PyQt5 import QtCore, QtGui, QtWidgets from PyQt5 import QtCore, QtGui, QtWidgets
from PyQt5.QtCore import QPointF from PyQt5.QtCore import QPointF
from . import _pg_overrides as pgo
from ..data._source import float_digits from ..data._source import float_digits
from ._label import Label from ._label import Label
from ._style import DpiAwareFont, hcolor, _font from ._style import DpiAwareFont, hcolor, _font
@ -41,17 +39,12 @@ class Axis(pg.AxisItem):
''' '''
A better axis that sizes tick contents considering font size. A better axis that sizes tick contents considering font size.
Also includes tick values lru caching, originally proposed but never
accepted upstream:
https://github.com/pyqtgraph/pyqtgraph/pull/2160
''' '''
def __init__( def __init__(
self, self,
plotitem: pgo.PlotItem, linkedsplits,
typical_max_str: str = '100 000.000', typical_max_str: str = '100 000.000',
text_color: str = 'bracket', text_color: str = 'bracket',
lru_cache_tick_strings: bool = True,
**kwargs **kwargs
) -> None: ) -> None:
@ -63,32 +56,27 @@ class Axis(pg.AxisItem):
# XXX: pretty sure this makes things slower # XXX: pretty sure this makes things slower
# self.setCacheMode(QtWidgets.QGraphicsItem.DeviceCoordinateCache) # self.setCacheMode(QtWidgets.QGraphicsItem.DeviceCoordinateCache)
self.pi = plotitem self.linkedsplits = linkedsplits
self._dpi_font = _font self._dpi_font = _font
self.setTickFont(_font.font) self.setTickFont(_font.font)
font_size = self._dpi_font.font.pixelSize() font_size = self._dpi_font.font.pixelSize()
style_conf = {
'textFillLimits': [(0, 0.5)],
'tickFont': self._dpi_font.font,
}
text_offset = None
if self.orientation in ('bottom',): if self.orientation in ('bottom',):
text_offset = floor(0.25 * font_size) text_offset = floor(0.25 * font_size)
elif self.orientation in ('left', 'right'): elif self.orientation in ('left', 'right'):
text_offset = floor(font_size / 2) text_offset = floor(font_size / 2)
if text_offset: self.setStyle(**{
style_conf.update({ 'textFillLimits': [(0, 0.5)],
'tickFont': self._dpi_font.font,
# offset of text *away from* axis line in px # offset of text *away from* axis line in px
# use approx. half the font pixel size (height) # use approx. half the font pixel size (height)
'tickTextOffset': text_offset, 'tickTextOffset': text_offset,
}) })
self.setStyle(**style_conf)
self.setTickFont(_font.font) self.setTickFont(_font.font)
# NOTE: this is for surrounding "border" # NOTE: this is for surrounding "border"
@ -103,37 +91,6 @@ class Axis(pg.AxisItem):
# size the pertinent axis dimension to a "typical value" # size the pertinent axis dimension to a "typical value"
self.size_to_values() self.size_to_values()
# NOTE: requires override ``.tickValues()`` method seen below.
if lru_cache_tick_strings:
self.tickStrings = lru_cache(
maxsize=2**20
)(self.tickStrings)
# axis "sticky" labels
self._stickies: dict[str, YAxisLabel] = {}
# NOTE: only overridden to cast tick value entries into tuples
# for use with the lru caching.
def tickValues(
self,
minVal: float,
maxVal: float,
size: int,
) -> list[tuple[float, tuple[str]]]:
'''
Repack tick values into tuples for lru caching.
'''
ticks = []
for scalar, values in super().tickValues(minVal, maxVal, size):
ticks.append((
scalar,
tuple(values), # this
))
return ticks
@property @property
def text_color(self) -> str: def text_color(self) -> str:
return self._text_color return self._text_color
@ -149,40 +106,6 @@ class Axis(pg.AxisItem):
def txt_offsets(self) -> tuple[int, int]: def txt_offsets(self) -> tuple[int, int]:
return tuple(self.style['tickTextOffset']) return tuple(self.style['tickTextOffset'])
def add_sticky(
self,
pi: pgo.PlotItem,
name: None | str = None,
digits: None | int = 2,
# axis_name: str = 'right',
bg_color='bracket',
) -> YAxisLabel:
# if the sticky is for our symbol
# use the tick size precision for display
name = name or pi.name
digits = digits or 2
# TODO: ``._ysticks`` should really be an attr on each
# ``PlotItem`` now, instead of the (containing, because of
# overlays) widget?
# add y-axis "last" value label
sticky = self._stickies[name] = YAxisLabel(
pi=pi,
parent=self,
# TODO: pass this from symbol data
digits=digits,
opacity=1,
bg_color=bg_color,
)
pi.sigRangeChanged.connect(sticky.update_on_resize)
# pi.addItem(sticky)
# pi.addItem(last)
return sticky
class PriceAxis(Axis): class PriceAxis(Axis):
@ -299,9 +222,7 @@ class DynamicDateAxis(Axis):
) -> list[str]: ) -> list[str]:
# XX: ARGGGGG AG:LKSKDJF:LKJSDFD chart = self.linkedsplits.chart
chart = self.pi.chart_widget
flow = chart._flows[chart.name] flow = chart._flows[chart.name]
shm = flow.shm shm = flow.shm
bars = shm.array bars = shm.array
@ -568,7 +489,7 @@ class XAxisLabel(AxisLabel):
class YAxisLabel(AxisLabel): class YAxisLabel(AxisLabel):
_y_margin: int = 4 _y_margin = 4
text_flags = ( text_flags = (
QtCore.Qt.AlignLeft QtCore.Qt.AlignLeft
@ -579,19 +500,19 @@ class YAxisLabel(AxisLabel):
def __init__( def __init__(
self, self,
pi: pgo.PlotItem, chart,
*args, *args,
**kwargs **kwargs
) -> None: ) -> None:
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._pi = pi self._chart = chart
pi.sigRangeChanged.connect(self.update_on_resize)
chart.sigRangeChanged.connect(self.update_on_resize)
self._last_datum = (None, None) self._last_datum = (None, None)
self.x_offset = 0
# pull text offset from the parent axis # pull text offset from the parent axis
if getattr(self._parent, 'txt_offsets', False): if getattr(self._parent, 'txt_offsets', False):
self.x_offset, y_offset = self._parent.txt_offsets() self.x_offset, y_offset = self._parent.txt_offsets()
@ -610,8 +531,7 @@ class YAxisLabel(AxisLabel):
value: float, # data for text value: float, # data for text
# on odd dimension and/or adds nice black line # on odd dimension and/or adds nice black line
x_offset: int = 0, x_offset: Optional[int] = None
) -> None: ) -> None:
# this is read inside ``.paint()`` # this is read inside ``.paint()``
@ -657,7 +577,7 @@ class YAxisLabel(AxisLabel):
self._last_datum = (index, value) self._last_datum = (index, value)
self.update_label( self.update_label(
self._pi.mapFromView(QPointF(index, value)), self._chart.mapFromView(QPointF(index, value)),
value value
) )

File diff suppressed because it is too large Load Diff

View File

@ -223,20 +223,14 @@ def ds_m4(
assert frames >= (xrange / uppx) assert frames >= (xrange / uppx)
# call into ``numba`` # call into ``numba``
( nb, i_win, y_out = _m4(
nb,
x_out,
y_out,
ymn,
ymx,
) = _m4(
x, x,
y, y,
frames, frames,
# TODO: see func below.. # TODO: see func below..
# x_out, # i_win,
# y_out, # y_out,
# first index in x data to start at # first index in x data to start at
@ -249,11 +243,10 @@ def ds_m4(
# filter out any overshoot in the input allocation arrays by # filter out any overshoot in the input allocation arrays by
# removing zero-ed tail entries which should start at a certain # removing zero-ed tail entries which should start at a certain
# index. # index.
x_out = x_out[x_out != 0] i_win = i_win[i_win != 0]
y_out = y_out[:x_out.size] y_out = y_out[:i_win.size]
# print(f'M4 output ymn, ymx: {ymn},{ymx}') return nb, i_win, y_out
return nb, x_out, y_out, ymn, ymx
@jit( @jit(
@ -267,8 +260,8 @@ def _m4(
frames: int, frames: int,
# TODO: using this approach, having the ``.zeros()`` alloc lines # TODO: using this approach by having the ``.zeros()`` alloc lines
# below in pure python, there were segs faults and alloc crashes.. # below, in put python was causing segs faults and alloc crashes..
# we might need to see how it behaves with shm arrays and consider # we might need to see how it behaves with shm arrays and consider
# allocating them once at startup? # allocating them once at startup?
@ -281,22 +274,14 @@ def _m4(
x_start: int, x_start: int,
step: float, step: float,
) -> tuple[ ) -> int:
int, # nbins = len(i_win)
np.ndarray, # count = len(xs)
np.ndarray,
float,
float,
]:
'''
Implementation of the m4 algorithm in ``numba``:
http://www.vldb.org/pvldb/vol7/p797-jugel.pdf
'''
# these are pre-allocated and mutated by ``numba`` # these are pre-allocated and mutated by ``numba``
# code in-place. # code in-place.
y_out = np.zeros((frames, 4), ys.dtype) y_out = np.zeros((frames, 4), ys.dtype)
x_out = np.zeros(frames, xs.dtype) i_win = np.zeros(frames, xs.dtype)
bincount = 0 bincount = 0
x_left = x_start x_left = x_start
@ -310,34 +295,24 @@ def _m4(
# set all bins in the left-most entry to the starting left-most x value # set all bins in the left-most entry to the starting left-most x value
# (aka a row broadcast). # (aka a row broadcast).
x_out[bincount] = x_left i_win[bincount] = x_left
# set all y-values to the first value passed in. # set all y-values to the first value passed in.
y_out[bincount] = ys[0] y_out[bincount] = ys[0]
# full input y-data mx and mn
mx: float = -np.inf
mn: float = np.inf
# compute OHLC style max / min values per window sized x-frame.
for i in range(len(xs)): for i in range(len(xs)):
x = xs[i] x = xs[i]
y = ys[i] y = ys[i]
if x < x_left + step: # the current window "step" is [bin, bin+1) if x < x_left + step: # the current window "step" is [bin, bin+1)
ymn = y_out[bincount, 1] = min(y, y_out[bincount, 1]) y_out[bincount, 1] = min(y, y_out[bincount, 1])
ymx = y_out[bincount, 2] = max(y, y_out[bincount, 2]) y_out[bincount, 2] = max(y, y_out[bincount, 2])
y_out[bincount, 3] = y y_out[bincount, 3] = y
mx = max(mx, ymx)
mn = min(mn, ymn)
else: else:
# Find the next bin # Find the next bin
while x >= x_left + step: while x >= x_left + step:
x_left += step x_left += step
bincount += 1 bincount += 1
x_out[bincount] = x_left i_win[bincount] = x_left
y_out[bincount] = y y_out[bincount] = y
return bincount, x_out, y_out, mn, mx return bincount, i_win, y_out
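Stripped of the ``numba`` specifics, the M4 scheme keeps only the first/min/max/last y-values per fixed-width x-window, which is enough to rasterize the same pixels. A pure-Python sketch; the window step, sample values, and the `m4_bins` helper name are all invented for illustration:
# pure-python sketch of M4 binning: per fixed-width x-window keep
# (first, min, max, last) of the y-values.
def m4_bins(xs, ys, step):
    bins, x_left = [], xs[0]
    first = mn = mx = last = ys[0]
    for x, y in zip(xs, ys):
        if x >= x_left + step:  # close the current window
            bins.append((x_left, first, mn, mx, last))
            while x >= x_left + step:
                x_left += step
            first = mn = mx = last = y
        else:
            mn, mx, last = min(mn, y), max(mx, y), y
    bins.append((x_left, first, mn, mx, last))
    return bins

assert m4_bins([0, 1, 2, 3], [1.0, 3.0, 2.0, 5.0], step=2) == [
    (0, 1.0, 1.0, 3.0, 3.0),
    (2, 2.0, 2.0, 5.0, 5.0),
]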

View File

@ -18,13 +18,8 @@
Mouse interaction graphics Mouse interaction graphics
""" """
from __future__ import annotations
from functools import partial from functools import partial
from typing import ( from typing import Optional, Callable
Optional,
Callable,
TYPE_CHECKING,
)
import inspect import inspect
import numpy as np import numpy as np
@ -41,12 +36,6 @@ from ._style import (
from ._axes import YAxisLabel, XAxisLabel from ._axes import YAxisLabel, XAxisLabel
from ..log import get_logger from ..log import get_logger
if TYPE_CHECKING:
from ._chart import (
ChartPlotWidget,
LinkedSplits,
)
log = get_logger(__name__) log = get_logger(__name__)
@ -69,7 +58,7 @@ class LineDot(pg.CurvePoint):
curve: pg.PlotCurveItem, curve: pg.PlotCurveItem,
index: int, index: int,
plot: ChartPlotWidget, # type: ingore # noqa plot: 'ChartPlotWidget', # type: ingore # noqa
pos=None, pos=None,
color: str = 'default_light', color: str = 'default_light',
@ -162,7 +151,7 @@ class ContentsLabel(pg.LabelItem):
def __init__( def __init__(
self, self,
# chart: ChartPlotWidget, # noqa # chart: 'ChartPlotWidget', # noqa
view: pg.ViewBox, view: pg.ViewBox,
anchor_at: str = ('top', 'right'), anchor_at: str = ('top', 'right'),
@ -255,7 +244,7 @@ class ContentsLabels:
''' '''
def __init__( def __init__(
self, self,
linkedsplits: LinkedSplits, # type: ignore # noqa linkedsplits: 'LinkedSplits', # type: ignore # noqa
) -> None: ) -> None:
@ -300,7 +289,7 @@ class ContentsLabels:
def add_label( def add_label(
self, self,
chart: ChartPlotWidget, # type: ignore # noqa chart: 'ChartPlotWidget', # type: ignore # noqa
name: str, name: str,
anchor_at: tuple[str, str] = ('top', 'left'), anchor_at: tuple[str, str] = ('top', 'left'),
update_func: Callable = ContentsLabel.update_from_value, update_func: Callable = ContentsLabel.update_from_value,
@ -327,7 +316,7 @@ class Cursor(pg.GraphicsObject):
def __init__( def __init__(
self, self,
linkedsplits: LinkedSplits, # noqa linkedsplits: 'LinkedSplits', # noqa
digits: int = 0 digits: int = 0
) -> None: ) -> None:
@ -336,8 +325,6 @@ class Cursor(pg.GraphicsObject):
self.linked = linkedsplits self.linked = linkedsplits
self.graphics: dict[str, pg.GraphicsObject] = {} self.graphics: dict[str, pg.GraphicsObject] = {}
self.xaxis_label: Optional[XAxisLabel] = None
self.always_show_xlabel: bool = True
self.plots: list['PlotChartWidget'] = [] # type: ignore # noqa self.plots: list['PlotChartWidget'] = [] # type: ignore # noqa
self.active_plot = None self.active_plot = None
self.digits: int = digits self.digits: int = digits
@ -398,7 +385,7 @@ class Cursor(pg.GraphicsObject):
def add_plot( def add_plot(
self, self,
plot: ChartPlotWidget, # noqa plot: 'ChartPlotWidget', # noqa
digits: int = 0, digits: int = 0,
) -> None: ) -> None:
@ -418,7 +405,7 @@ class Cursor(pg.GraphicsObject):
hl.hide() hl.hide()
yl = YAxisLabel( yl = YAxisLabel(
pi=plot.plotItem, chart=plot,
# parent=plot.getAxis('right'), # parent=plot.getAxis('right'),
parent=plot.pi_overlay.get_axis(plot.plotItem, 'right'), parent=plot.pi_overlay.get_axis(plot.plotItem, 'right'),
digits=digits or self.digits, digits=digits or self.digits,
@ -482,7 +469,7 @@ class Cursor(pg.GraphicsObject):
def add_curve_cursor( def add_curve_cursor(
self, self,
plot: ChartPlotWidget, # noqa plot: 'ChartPlotWidget', # noqa
curve: 'PlotCurveItem', # noqa curve: 'PlotCurveItem', # noqa
) -> LineDot: ) -> LineDot:
@ -504,29 +491,17 @@ class Cursor(pg.GraphicsObject):
log.debug(f"{(action, plot.name)}") log.debug(f"{(action, plot.name)}")
if action == 'Enter': if action == 'Enter':
self.active_plot = plot self.active_plot = plot
plot.linked.godwidget._active_cursor = self
# show horiz line and y-label # show horiz line and y-label
self.graphics[plot]['hl'].show() self.graphics[plot]['hl'].show()
self.graphics[plot]['yl'].show() self.graphics[plot]['yl'].show()
if ( else: # Leave
not self.always_show_xlabel
and not self.xaxis_label.isVisible()
):
self.xaxis_label.show()
# Leave: hide horiz line and y-label # hide horiz line and y-label
else:
self.graphics[plot]['hl'].hide() self.graphics[plot]['hl'].hide()
self.graphics[plot]['yl'].hide() self.graphics[plot]['yl'].hide()
if (
not self.always_show_xlabel
and self.xaxis_label.isVisible()
):
self.xaxis_label.hide()
def mouseMoved( def mouseMoved(
self, self,
coords: tuple[QPointF], # noqa coords: tuple[QPointF], # noqa
@ -615,10 +590,6 @@ class Cursor(pg.GraphicsObject):
left_axis_width += left.width() left_axis_width += left.width()
# map back to abs (label-local) coordinates # map back to abs (label-local) coordinates
if (
self.always_show_xlabel
or self.xaxis_label.isVisible()
):
self.xaxis_label.update_label( self.xaxis_label.update_label(
abs_pos=( abs_pos=(
plot.mapFromView(QPointF(vl_x, iy)) - plot.mapFromView(QPointF(vl_x, iy)) -

View File

@ -28,7 +28,10 @@ from PyQt5.QtWidgets import QGraphicsItem
from PyQt5.QtCore import ( from PyQt5.QtCore import (
Qt, Qt,
QLineF, QLineF,
QSizeF,
QRectF, QRectF,
# QRect,
QPointF,
) )
from PyQt5.QtGui import ( from PyQt5.QtGui import (
QPainter, QPainter,
@ -41,7 +44,6 @@ from ._style import hcolor
# ds_m4, # ds_m4,
# ) # )
from ..log import get_logger from ..log import get_logger
from .._profile import Profiler
log = get_logger(__name__) log = get_logger(__name__)
@ -86,9 +88,9 @@ class Curve(pg.GraphicsObject):
''' '''
# sub-type customization methods # sub-type customization methods
declare_paintables: Optional[Callable] = None sub_br: Optional[Callable] = None
sub_paint: Optional[Callable] = None sub_paint: Optional[Callable] = None
# sub_br: Optional[Callable] = None declare_paintables: Optional[Callable] = None
def __init__( def __init__(
self, self,
@ -137,7 +139,9 @@ class Curve(pg.GraphicsObject):
# self.last_step_pen = pg.mkPen(hcolor(color), width=2) # self.last_step_pen = pg.mkPen(hcolor(color), width=2)
self.last_step_pen = pg.mkPen(pen, width=2) self.last_step_pen = pg.mkPen(pen, width=2)
# self._last_line: Optional[QLineF] = None
self._last_line = QLineF() self._last_line = QLineF()
self._last_w: float = 1
# flat-top style histogram-like discrete curve # flat-top style histogram-like discrete curve
# self._step_mode: bool = step_mode # self._step_mode: bool = step_mode
@ -226,8 +230,8 @@ class Curve(pg.GraphicsObject):
self.path.clear() self.path.clear()
if self.fast_path: if self.fast_path:
self.fast_path.clear() # self.fast_path.clear()
# self.fast_path = None self.fast_path = None
@cm @cm
def reset_cache(self) -> None: def reset_cache(self) -> None:
@ -247,65 +251,77 @@ class Curve(pg.GraphicsObject):
self.boundingRect = self._path_br self.boundingRect = self._path_br
return self._path_br() return self._path_br()
# Qt docs: https://doc.qt.io/qt-5/qgraphicsitem.html#boundingRect
def _path_br(self): def _path_br(self):
''' '''
Post init ``.boundingRect()``. Post init ``.boundingRect()``.
''' '''
# profiler = Profiler( # hb = self.path.boundingRect()
# msg=f'Curve.boundingRect(): `{self._name}`', hb = self.path.controlPointRect()
# disabled=not pg_profile_enabled(), hb_size = hb.size()
# ms_threshold=ms_slower_then,
fp = self.fast_path
if fp:
fhb = fp.controlPointRect()
hb_size = fhb.size() + hb_size
# print(f'hb_size: {hb_size}')
# if self._last_step_rect:
# hb_size += self._last_step_rect.size()
# if self._line:
# br = self._last_step_rect.bottomRight()
# tl = QPointF(
# # self._vr[0],
# # hb.topLeft().y(),
# # 0,
# # hb_size.height() + 1
# ) # )
pr = self.path.controlPointRect()
hb_tl, hb_br = (
pr.topLeft(),
pr.bottomRight(),
)
mn_y = hb_tl.y()
mx_y = hb_br.y()
most_left = hb_tl.x()
most_right = hb_br.x()
# profiler('calc path vertices')
# TODO: if/when we get fast path appends working in the # br = self._last_step_rect.bottomRight()
# `Renderer`, then we might need to actually use this..
# fp = self.fast_path
# if fp:
# fhb = fp.controlPointRect()
# # hb_size = fhb.size() + hb_size
# br = pr.united(fhb)
# XXX: *was* a way to allow sub-types to extend the w = hb_size.width()
# boundingrect calc, but in the one use case for a step curve h = hb_size.height()
# doesn't seem like we need it as long as the last line segment
# is drawn as it is?
# sbr = self.sub_br
# if sbr:
# # w, h = self.sub_br(w, h)
# sub_br = sbr()
# br = br.united(sub_br)
sbr = self.sub_br
if sbr:
w, h = self.sub_br(w, h)
else:
# assume plain line graphic and use # assume plain line graphic and use
# default unit step in each direction. # default unit step in each direction.
ll = self._last_line
y1, y2 = ll.y1(), ll.y2()
x1, x2 = ll.x1(), ll.x2()
ymn = min(y1, y2, mn_y) # only on a plain line do we include
ymx = max(y1, y2, mx_y) # an extra index step's worth of width
most_left = min(x1, x2, most_left) # since in the step case the end of the curve
most_right = max(x1, x2, most_right) # actually terminates earlier so we don't need
# profiler('calc last line vertices') # this for the last step.
w += self._last_w
# ll = self._last_line
h += 1 # ll.y2() - ll.y1()
return QRectF( # br = QPointF(
most_left, # self._vr[-1],
ymn, # # tl.x() + w,
most_right - most_left + 1, # tl.y() + h,
ymx, # )
br = QRectF(
# top left
# hb.topLeft()
# tl,
QPointF(hb.topLeft()),
# br,
# total size
# QSizeF(hb_size)
# hb_size,
QSizeF(w, h)
) )
# print(f'bounding rect: {br}')
return br
def paint( def paint(
self, self,
@ -315,7 +331,7 @@ class Curve(pg.GraphicsObject):
) -> None: ) -> None:
profiler = Profiler( profiler = pg.debug.Profiler(
msg=f'Curve.paint(): `{self._name}`', msg=f'Curve.paint(): `{self._name}`',
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
ms_threshold=ms_slower_then, ms_threshold=ms_slower_then,
@ -323,7 +339,7 @@ class Curve(pg.GraphicsObject):
sub_paint = self.sub_paint sub_paint = self.sub_paint
if sub_paint: if sub_paint:
sub_paint(p) sub_paint(p, profiler)
p.setPen(self.last_step_pen) p.setPen(self.last_step_pen)
p.drawLine(self._last_line) p.drawLine(self._last_line)
@ -433,34 +449,36 @@ class StepCurve(Curve):
y = src_data[array_key] y = src_data[array_key]
x_last = x[-1] x_last = x[-1]
x_2last = x[-2]
y_last = y[-1] y_last = y[-1]
step_size = x_last - x_2last
half_step = step_size / 2
# lol, commenting this makes step curves # lol, commenting this makes step curves
# all "black" for me :eyeroll:.. # all "black" for me :eyeroll:..
self._last_line = QLineF( self._last_line = QLineF(
x_2last, 0, x_last - w, 0,
x_last, 0, x_last + w, 0,
) )
self._last_step_rect = QRectF( self._last_step_rect = QRectF(
x_last - half_step, 0, x_last - w, 0,
step_size, y_last, x_last + w, y_last,
) )
return x, y return x, y
def sub_paint( def sub_paint(
self, self,
p: QPainter, p: QPainter,
profiler: pg.debug.Profiler,
) -> None: ) -> None:
# p.drawLines(*tuple(filter(bool, self._last_step_lines))) # p.drawLines(*tuple(filter(bool, self._last_step_lines)))
# p.drawRect(self._last_step_rect) # p.drawRect(self._last_step_rect)
p.fillRect(self._last_step_rect, self._brush) p.fillRect(self._last_step_rect, self._brush)
profiler('.fillRect()')
# def sub_br( def sub_br(
# self, self,
# parent_br: QRectF | None = None, path_w: float,
# ) -> QRectF: path_h: float,
# return self._last_step_rect
) -> (float, float):
# passthrough
return path_w, path_h

File diff suppressed because it is too large Load Diff

View File

@ -18,27 +18,11 @@
Higher level annotation editors. Higher level annotation editors.
""" """
from __future__ import annotations from dataclasses import dataclass, field
from collections import defaultdict from typing import Optional
from typing import (
Optional,
TYPE_CHECKING
)
import pyqtgraph as pg import pyqtgraph as pg
from pyqtgraph import ( from pyqtgraph import ViewBox, Point, QtCore, QtGui
ViewBox,
Point,
QtCore,
QtWidgets,
)
from PyQt5.QtGui import (
QColor,
)
from PyQt5.QtWidgets import (
QLabel,
)
from pyqtgraph import functions as fn from pyqtgraph import functions as fn
from PyQt5.QtCore import QPointF from PyQt5.QtCore import QPointF
import numpy as np import numpy as np
@ -46,34 +30,28 @@ import numpy as np
from ._style import hcolor, _font from ._style import hcolor, _font
from ._lines import LevelLine from ._lines import LevelLine
from ..log import get_logger from ..log import get_logger
from ..data.types import Struct
if TYPE_CHECKING:
from ._chart import GodWidget
log = get_logger(__name__) log = get_logger(__name__)
class ArrowEditor(Struct): @dataclass
class ArrowEditor:
godw: GodWidget = None # type: ignore # noqa chart: 'ChartPlotWidget' # noqa
_arrows: dict[str, list[pg.ArrowItem]] = {} _arrows: field(default_factory=dict)
def add( def add(
self, self,
plot: pg.PlotItem,
uid: str, uid: str,
x: float, x: float,
y: float, y: float,
color='default', color='default',
pointing: Optional[str] = None, pointing: Optional[str] = None,
) -> pg.ArrowItem: ) -> pg.ArrowItem:
''' """Add an arrow graphic to view at given (x, y).
Add an arrow graphic to view at given (x, y).
''' """
angle = { angle = {
'up': 90, 'up': 90,
'down': -90, 'down': -90,
@ -96,25 +74,25 @@ class ArrowEditor(Struct):
brush=pg.mkBrush(hcolor(color)), brush=pg.mkBrush(hcolor(color)),
) )
arrow.setPos(x, y) arrow.setPos(x, y)
self._arrows.setdefault(uid, []).append(arrow)
self._arrows[uid] = arrow
# render to view # render to view
plot.addItem(arrow) self.chart.plotItem.addItem(arrow)
return arrow return arrow
def remove(self, arrow) -> bool: def remove(self, arrow) -> bool:
for linked in self.godw.iter_linked(): self.chart.plotItem.removeItem(arrow)
linked.chart.plotItem.removeItem(arrow)
class LineEditor(Struct): @dataclass
''' class LineEditor:
The great editor of linez. '''The great editor of linez.
''' '''
godw: GodWidget = None # type: ignore # noqa chart: 'ChartPlotWidget' = None # type: ignore # noqa
_order_lines: defaultdict[str, LevelLine] = defaultdict(list) _order_lines: dict[str, LevelLine] = field(default_factory=dict)
_active_staged_line: LevelLine = None _active_staged_line: LevelLine = None
def stage_line( def stage_line(
@ -122,11 +100,11 @@ class LineEditor(Struct):
line: LevelLine, line: LevelLine,
) -> LevelLine: ) -> LevelLine:
''' """Stage a line at the current chart's cursor position
Stage a line at the current chart's cursor position
and return it. and return it.
''' """
# add a "staged" cursor-tracking line to view # add a "staged" cursor-tracking line to view
# and cache it in a var # and cache it in a var
if self._active_staged_line: if self._active_staged_line:
@ -137,25 +115,17 @@ class LineEditor(Struct):
return line return line
def unstage_line(self) -> LevelLine: def unstage_line(self) -> LevelLine:
''' """Inverse of ``.stage_line()``.
Inverse of ``.stage_line()``.
''' """
cursor = self.godw.get_cursor() # chart = self.chart._cursor.active_plot
if not cursor: # # chart.setCursor(QtCore.Qt.ArrowCursor)
return None cursor = self.chart.linked.cursor
# delete "staged" cursor tracking line from view # delete "staged" cursor tracking line from view
line = self._active_staged_line line = self._active_staged_line
if line: if line:
try:
cursor._trackers.remove(line) cursor._trackers.remove(line)
except KeyError:
# when the current cursor doesn't have said line
# registered (probably means that user held order mode
# key while panning to another view) then we just
# ignore the remove error.
pass
line.delete() line.delete()
self._active_staged_line = None self._active_staged_line = None
@ -163,58 +133,55 @@ class LineEditor(Struct):
# show the crosshair y line and label # show the crosshair y line and label
cursor.show_xhair() cursor.show_xhair()
def submit_lines( def submit_line(
self, self,
lines: list[LevelLine], line: LevelLine,
uuid: str, uuid: str,
) -> LevelLine: ) -> LevelLine:
# staged_line = self._active_staged_line staged_line = self._active_staged_line
# if not staged_line: if not staged_line:
# raise RuntimeError("No line is currently staged!?") raise RuntimeError("No line is currently staged!?")
# for now, until submission response arrives # for now, until submission response arrives
for line in lines:
line.hide_labels() line.hide_labels()
# register for later lookup/deletion # register for later lookup/deletion
self._order_lines[uuid] += lines self._order_lines[uuid] = line
return lines return line
def commit_line(self, uuid: str) -> list[LevelLine]: def commit_line(self, uuid: str) -> LevelLine:
''' """Commit a "staged line" to view.
Commit a "staged line" to view.
Submits the line graphic under the cursor as a (new) permanent Submits the line graphic under the cursor as a (new) permanent
graphic in view. graphic in view.
''' """
lines = self._order_lines[uuid] try:
if lines: line = self._order_lines[uuid]
for line in lines: except KeyError:
log.warning(f'No line for {uuid} could be found?')
return
else:
line.show_labels() line.show_labels()
line.hide_markers()
log.debug(f'Level active for level: {line.value()}')
# TODO: other flashy things to indicate the order is active # TODO: other flashy things to indicate the order is active
return lines log.debug(f'Level active for level: {line.value()}')
return line
def lines_under_cursor(self) -> list[LevelLine]: def lines_under_cursor(self) -> list[LevelLine]:
''' """Get the line(s) under the cursor position.
Get the line(s) under the cursor position.
''' """
# Delete any hoverable under the cursor # Delete any hoverable under the cursor
return self.godw.get_cursor()._hovered return self.chart.linked.cursor._hovered
def all_lines(self) -> list[LevelLine]: def all_lines(self) -> tuple[LevelLine]:
all_lines = [] return tuple(self._order_lines.values())
for lines in list(self._order_lines.values()):
all_lines.extend(lines)
return all_lines
def remove_line( def remove_line(
self, self,
@ -229,30 +196,29 @@ class LineEditor(Struct):
''' '''
# try to look up line from our registry # try to look up line from our registry
lines = self._order_lines.pop(uuid, None) line = self._order_lines.pop(uuid, line)
if lines: if line:
cursor = self.godw.get_cursor()
if cursor:
for line in lines:
# if hovered remove from cursor set # if hovered remove from cursor set
cursor = self.chart.linked.cursor
hovered = cursor._hovered hovered = cursor._hovered
if line in hovered: if line in hovered:
hovered.remove(line) hovered.remove(line)
log.debug(f'deleting {line} with oid: {uuid}')
line.delete()
# make sure the xhair doesn't get left off # make sure the xhair doesn't get left off
# just because we never got an un-hover event # just because we never got an un-hover event
cursor.show_xhair() cursor.show_xhair()
log.debug(f'deleting {line} with oid: {uuid}')
line.delete()
else: else:
log.warning(f'Could not find line for {line}') log.warning(f'Could not find line for {line}')
return lines return line
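One side of the `LineEditor` rewrite swaps the per-order registry from a single `LevelLine` per uuid to a `defaultdict(list)`, so `submit_lines()` can accumulate several lines under one order id. A minimal standalone sketch of that pattern (plain strings stand in for `LevelLine`s):

from collections import defaultdict

_order_lines: defaultdict[str, list[str]] = defaultdict(list)

def submit_lines(lines: list[str], uuid: str) -> list[str]:
    _order_lines[uuid] += lines          # append, don't overwrite
    return lines

submit_lines(['line-a', 'line-b'], 'oid-1')
submit_lines(['line-c'], 'oid-1')
assert _order_lines['oid-1'] == ['line-a', 'line-b', 'line-c']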
class SelectRect(QtWidgets.QGraphicsRectItem): class SelectRect(QtGui.QGraphicsRectItem):
def __init__( def __init__(
self, self,
@ -261,12 +227,12 @@ class SelectRect(QtWidgets.QGraphicsRectItem):
) -> None: ) -> None:
super().__init__(0, 0, 1, 1) super().__init__(0, 0, 1, 1)
# self.rbScaleBox = QGraphicsRectItem(0, 0, 1, 1) # self.rbScaleBox = QtGui.QGraphicsRectItem(0, 0, 1, 1)
self.vb = viewbox self.vb = viewbox
self._chart: 'ChartPlotWidget' = None # noqa self._chart: 'ChartPlotWidget' = None # noqa
# override selection box color # override selection box color
color = QColor(hcolor(color)) color = QtGui.QColor(hcolor(color))
self.setPen(fn.mkPen(color, width=1)) self.setPen(fn.mkPen(color, width=1))
color.setAlpha(66) color.setAlpha(66)
self.setBrush(fn.mkBrush(color)) self.setBrush(fn.mkBrush(color))
@ -274,7 +240,7 @@ class SelectRect(QtWidgets.QGraphicsRectItem):
self.hide() self.hide()
self._label = None self._label = None
label = self._label = QLabel() label = self._label = QtGui.QLabel()
label.setTextFormat(0) # markdown label.setTextFormat(0) # markdown
label.setFont(_font.font) label.setFont(_font.font)
label.setMargin(0) label.setMargin(0)
@ -311,8 +277,8 @@ class SelectRect(QtWidgets.QGraphicsRectItem):
# TODO: get bg color working # TODO: get bg color working
palette.setColor( palette.setColor(
self._label.backgroundRole(), self._label.backgroundRole(),
# QColor(chart.backgroundBrush()), # QtGui.QColor(chart.backgroundBrush()),
QColor(hcolor('papas_special')), QtGui.QColor(hcolor('papas_special')),
) )
def update_on_resize(self, vr, r): def update_on_resize(self, vr, r):
@ -360,7 +326,7 @@ class SelectRect(QtWidgets.QGraphicsRectItem):
self.setPos(r.topLeft()) self.setPos(r.topLeft())
self.resetTransform() self.resetTransform()
self.setRect(r) self.scale(r.width(), r.height())
self.show() self.show()
y1, y2 = start_pos.y(), end_pos.y() y1, y2 = start_pos.y(), end_pos.y()

View File

@ -18,11 +18,11 @@
Qt event proxying and processing using ``trio`` mem chans. Qt event proxying and processing using ``trio`` mem chans.
""" """
from contextlib import asynccontextmanager as acm from contextlib import asynccontextmanager, AsyncExitStack
from typing import Callable from typing import Callable
from pydantic import BaseModel
import trio import trio
from tractor.trionics import gather_contexts
from PyQt5 import QtCore from PyQt5 import QtCore
from PyQt5.QtCore import QEvent, pyqtBoundSignal from PyQt5.QtCore import QEvent, pyqtBoundSignal
from PyQt5.QtWidgets import QWidget from PyQt5.QtWidgets import QWidget
@ -30,8 +30,6 @@ from PyQt5.QtWidgets import (
QGraphicsSceneMouseEvent as gs_mouse, QGraphicsSceneMouseEvent as gs_mouse,
) )
from ..data.types import Struct
MOUSE_EVENTS = { MOUSE_EVENTS = {
gs_mouse.GraphicsSceneMousePress, gs_mouse.GraphicsSceneMousePress,
@ -45,10 +43,13 @@ MOUSE_EVENTS = {
# TODO: maybe consider some constrained ints down the road? # TODO: maybe consider some constrained ints down the road?
# https://pydantic-docs.helpmanual.io/usage/types/#constrained-types # https://pydantic-docs.helpmanual.io/usage/types/#constrained-types
class KeyboardMsg(Struct): class KeyboardMsg(BaseModel):
'''Unpacked Qt keyboard event data. '''Unpacked Qt keyboard event data.
''' '''
class Config:
arbitrary_types_allowed = True
event: QEvent event: QEvent
etype: int etype: int
key: int key: int
@ -56,13 +57,16 @@ class KeyboardMsg(Struct):
txt: str txt: str
def to_tuple(self) -> tuple: def to_tuple(self) -> tuple:
return tuple(self.to_dict().values()) return tuple(self.dict().values())
class MouseMsg(Struct): class MouseMsg(BaseModel):
'''Unpacked Qt keyboard event data. '''Unpacked Qt keyboard event data.
''' '''
class Config:
arbitrary_types_allowed = True
event: QEvent event: QEvent
etype: int etype: int
button: int button: int
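The `Struct`-based replacement for these pydantic models drops the runtime-validation config (`arbitrary_types_allowed`) in favor of plain typed records. A rough sketch of the equivalent using stock `msgspec` (ints stand in for the Qt event types; `msgspec.structs.asdict()` plays the role of the custom `.to_dict()`):

import msgspec

class KeyboardMsg(msgspec.Struct):
    etype: int
    key: int
    mods: int
    txt: str

    def to_tuple(self) -> tuple:
        return tuple(msgspec.structs.asdict(self).values())

msg = KeyboardMsg(etype=6, key=65, mods=0, txt='a')
print(msg.to_tuple())   # -> (6, 65, 0, 'a')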
@ -156,7 +160,7 @@ class EventRelay(QtCore.QObject):
return False return False
@acm @asynccontextmanager
async def open_event_stream( async def open_event_stream(
source_widget: QWidget, source_widget: QWidget,
@ -182,7 +186,7 @@ async def open_event_stream(
source_widget.removeEventFilter(kc) source_widget.removeEventFilter(kc)
@acm @asynccontextmanager
async def open_signal_handler( async def open_signal_handler(
signal: pyqtBoundSignal, signal: pyqtBoundSignal,
@ -207,7 +211,7 @@ async def open_signal_handler(
yield yield
@acm @asynccontextmanager
async def open_handlers( async def open_handlers(
source_widgets: list[QWidget], source_widgets: list[QWidget],
@ -216,14 +220,16 @@ async def open_handlers(
**kwargs, **kwargs,
) -> None: ) -> None:
async with ( async with (
trio.open_nursery() as n, trio.open_nursery() as n,
gather_contexts([ AsyncExitStack() as stack,
open_event_stream(widget, event_types, **kwargs)
for widget in source_widgets
]) as streams,
): ):
for widget, event_recv_stream in zip(source_widgets, streams): for widget in source_widgets:
event_recv_stream = await stack.enter_async_context(
open_event_stream(widget, event_types, **kwargs)
)
n.start_soon(async_handler, widget, event_recv_stream) n.start_soon(async_handler, widget, event_recv_stream)
yield yield
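Both variants of `open_handlers()` solve the same problem: entering a variable number of async context managers before starting handler tasks. One uses `tractor.trionics.gather_contexts()`, the other enters them one-by-one on an `AsyncExitStack`. A self-contained sketch of the exit-stack pattern (`open_stream` is a stand-in for `open_event_stream()`):

from contextlib import asynccontextmanager, AsyncExitStack
import trio

@asynccontextmanager
async def open_stream(name: str):
    # stand-in for wiring an event filter and yielding a mem chan
    yield f'stream-{name}'

async def main():
    widgets = ['chart', 'search', 'order-pane']
    async with AsyncExitStack() as stack:
        streams = [
            await stack.enter_async_context(open_stream(w))
            for w in widgets
        ]
        # all streams stay open until the stack unwinds
        print(streams)

trio.run(main)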

View File

@ -20,24 +20,16 @@ Trio - Qt integration
Run ``trio`` in guest mode on top of the Qt event loop. Run ``trio`` in guest mode on top of the Qt event loop.
All global Qt runtime settings are mostly defined here. All global Qt runtime settings are mostly defined here.
""" """
from __future__ import annotations from typing import Tuple, Callable, Dict, Any
from typing import (
Callable,
Any,
Type,
TYPE_CHECKING,
)
import platform import platform
import traceback import traceback
# Qt specific # Qt specific
import PyQt5 # noqa import PyQt5 # noqa
from PyQt5.QtWidgets import ( import pyqtgraph as pg
QWidget, from pyqtgraph import QtGui
QMainWindow,
QApplication,
)
from PyQt5 import QtCore from PyQt5 import QtCore
# from PyQt5.QtGui import QLabel, QStatusBar
from PyQt5.QtCore import ( from PyQt5.QtCore import (
pyqtRemoveInputHook, pyqtRemoveInputHook,
Qt, Qt,
@ -45,19 +37,15 @@ from PyQt5.QtCore import (
) )
import qdarkstyle import qdarkstyle
from qdarkstyle import DarkPalette from qdarkstyle import DarkPalette
# import qdarkgraystyle # TODO: play with it # import qdarkgraystyle
import trio import trio
from outcome import Error from outcome import Error
from .._daemon import ( from .._daemon import maybe_open_pikerd, _tractor_kwargs
maybe_open_pikerd,
get_tractor_runtime_kwargs,
)
from ..log import get_logger from ..log import get_logger
from ._pg_overrides import _do_overrides from ._pg_overrides import _do_overrides
from . import _style from . import _style
log = get_logger(__name__) log = get_logger(__name__)
# pyqtgraph global config # pyqtgraph global config
@ -84,18 +72,17 @@ if platform.system() == "Windows":
def run_qtractor( def run_qtractor(
func: Callable, func: Callable,
args: tuple, args: Tuple,
main_widget_type: Type[QWidget], main_widget: QtGui.QWidget,
tractor_kwargs: dict[str, Any] = {}, tractor_kwargs: Dict[str, Any] = {},
window_type: QMainWindow = None, window_type: QtGui.QMainWindow = None,
) -> None: ) -> None:
# avoids annoying message when entering debugger from qt loop # avoids annoying message when entering debugger from qt loop
pyqtRemoveInputHook() pyqtRemoveInputHook()
app = QApplication.instance() app = QtGui.QApplication.instance()
if app is None: if app is None:
app = QApplication([]) app = PyQt5.QtWidgets.QApplication([])
# TODO: we might not need this if it's desired # TODO: we might not need this if it's desired
# to cancel the tractor machinery on Qt loop # to cancel the tractor machinery on Qt loop
@ -169,11 +156,11 @@ def run_qtractor(
# hook into app focus change events # hook into app focus change events
app.focusChanged.connect(window.on_focus_change) app.focusChanged.connect(window.on_focus_change)
instance = main_widget_type() instance = main_widget()
instance.window = window instance.window = window
# override tractor's defaults # override tractor's defaults
tractor_kwargs.update(get_tractor_runtime_kwargs()) tractor_kwargs.update(_tractor_kwargs)
# define tractor entrypoint # define tractor entrypoint
async def main(): async def main():
@ -191,7 +178,7 @@ def run_qtractor(
# restrict_keyboard_interrupt_to_checkpoints=True, # restrict_keyboard_interrupt_to_checkpoints=True,
) )
window.godwidget: GodWidget = instance window.main_widget = main_widget
window.setCentralWidget(instance) window.setCentralWidget(instance)
if is_windows: if is_windows:
window.configure_to_desktop() window.configure_to_desktop()

View File

@ -25,10 +25,13 @@ incremental update.
from __future__ import annotations from __future__ import annotations
from typing import ( from typing import (
Optional, Optional,
Callable,
Union,
) )
import msgspec import msgspec
import numpy as np import numpy as np
from numpy.lib import recfunctions as rfn
import pyqtgraph as pg import pyqtgraph as pg
from PyQt5.QtGui import QPainterPath from PyQt5.QtGui import QPainterPath
from PyQt5.QtCore import QLineF from PyQt5.QtCore import QLineF
@ -36,16 +39,14 @@ from PyQt5.QtCore import QLineF
from ..data._sharedmem import ( from ..data._sharedmem import (
ShmArray, ShmArray,
) )
from ..data.feed import Flume
from .._profile import ( from .._profile import (
pg_profile_enabled, pg_profile_enabled,
# ms_slower_then, # ms_slower_then,
) )
from ._pathops import ( from ._pathops import (
IncrementalFormatter, gen_ohlc_qpath,
OHLCBarsFmtr, # Plain OHLC renderer ohlc_to_line,
OHLCBarsAsCurveFmtr, # OHLC converted to line to_step_format,
StepCurveFmtr, # "step" curve (like for vlm)
xy_downsample, xy_downsample,
) )
from ._ohlc import ( from ._ohlc import (
@ -58,12 +59,70 @@ from ._curve import (
FlattenedOHLC, FlattenedOHLC,
) )
from ..log import get_logger from ..log import get_logger
from .._profile import Profiler
log = get_logger(__name__) log = get_logger(__name__)
# class FlowsTable(msgspec.Struct):
# '''
# Data-AGGRegate: high level API onto multiple (categorized)
# ``Flow``s with high level processing routines for
# multi-graphics computations and display.
# '''
# flows: dict[str, np.ndarray] = {}
def update_ohlc_to_line(
src_shm: ShmArray,
array_key: str,
src_update: np.ndarray,
slc: slice,
ln: int,
first: int,
last: int,
is_append: bool,
) -> np.ndarray:
fields = ['open', 'high', 'low', 'close']
return (
rfn.structured_to_unstructured(src_update[fields]),
slc,
)
def ohlc_flat_to_xy(
r: Renderer,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray,
np.ndarray,
str,
]:
# TODO: in the case of an existing ``.update_xy()``
# should we be passing in array as an xy arrays tuple?
# 2 more datum-indexes to capture zero at end
x_flat = r.x_data[r._xy_first:r._xy_last]
y_flat = r.y_data[r._xy_first:r._xy_last]
# slice to view
ivl, ivr = vr
x_iv_flat = x_flat[ivl:ivr]
y_iv_flat = y_flat[ivl:ivr]
# reshape to 1d for graphics rendering
y_iv = y_iv_flat.reshape(-1)
x_iv = x_iv_flat.reshape(-1)
return x_iv, y_iv, 'all'
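`update_ohlc_to_line()` and `ohlc_flat_to_xy()` together implement the "flattened OHLC" curve: the four price fields per bar are viewed as an (n, 4) float array and then reshaped to 1d for rendering. A tiny sketch of that transform:

import numpy as np
from numpy.lib import recfunctions as rfn

ohlc = np.array(
    [(10.0, 12.0, 9.0, 11.0), (11.0, 13.0, 10.5, 12.5)],
    dtype=[('open', 'f8'), ('high', 'f8'), ('low', 'f8'), ('close', 'f8')],
)
flat2d = rfn.structured_to_unstructured(ohlc[['open', 'high', 'low', 'close']])
y_1d = flat2d.reshape(-1)   # -> [o0, h0, l0, c0, o1, h1, l1, c1]
print(flat2d.shape, y_1d)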
def render_baritems( def render_baritems(
flow: Flow, flow: Flow,
graphics: BarItems, graphics: BarItems,
@ -71,7 +130,7 @@ def render_baritems(
int, int, np.ndarray, int, int, np.ndarray,
int, int, np.ndarray, int, int, np.ndarray,
], ],
profiler: Profiler, profiler: pg.debug.Profiler,
**kwargs, **kwargs,
) -> None: ) -> None:
@ -95,24 +154,21 @@ def render_baritems(
r = self._src_r r = self._src_r
if not r: if not r:
show_bars = True show_bars = True
# OHLC bars path renderer # OHLC bars path renderer
r = self._src_r = Renderer( r = self._src_r = Renderer(
flow=self, flow=self,
fmtr=OHLCBarsFmtr( format_xy=gen_ohlc_qpath,
shm=flow.shm, last_read=read,
flow=flow,
_last_read=read,
),
) )
ds_curve_r = Renderer( ds_curve_r = Renderer(
flow=self, flow=self,
fmtr=OHLCBarsAsCurveFmtr( last_read=read,
shm=flow.shm,
flow=flow, # incr update routines
_last_read=read, allocate_xy=ohlc_to_line,
), update_xy=update_ohlc_to_line,
format_xy=ohlc_flat_to_xy,
) )
curve = FlattenedOHLC( curve = FlattenedOHLC(
@ -196,6 +252,77 @@ def render_baritems(
) )
def update_step_xy(
src_shm: ShmArray,
array_key: str,
y_update: np.ndarray,
slc: slice,
ln: int,
first: int,
last: int,
is_append: bool,
) -> np.ndarray:
# for a step curve we slice from one datum prior
# to the current "update slice" to get the previous
# "level".
if is_append:
start = max(last - 1, 0)
end = src_shm._last.value
new_y = src_shm._array[start:end][array_key]
slc = slice(start, end)
else:
new_y = y_update
return (
np.broadcast_to(
new_y[:, None], (new_y.size, 2),
),
slc,
)
def step_to_xy(
r: Renderer,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray,
np.ndarray,
str,
]:
# 2 more datum-indexes to capture zero at end
x_step = r.x_data[r._xy_first:r._xy_last+2]
y_step = r.y_data[r._xy_first:r._xy_last+2]
lasts = array[['index', array_key]]
last = lasts[array_key][-1]
y_step[-1] = last
# slice out in-view data
ivl, ivr = vr
ys_iv = y_step[ivl:ivr+1]
xs_iv = x_step[ivl:ivr+1]
# flatten to 1d
y_iv = ys_iv.reshape(ys_iv.size)
x_iv = xs_iv.reshape(xs_iv.size)
# print(
# f'ys_iv : {ys_iv[-s:]}\n'
# f'y_iv: {y_iv[-s:]}\n'
# f'xs_iv: {xs_iv[-s:]}\n'
# f'x_iv: {x_iv[-s:]}\n'
# )
return x_iv, y_iv, 'all'
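`update_step_xy()` and `step_to_xy()` maintain the step-curve variant of the same idea: each scalar level is broadcast to an (n, 2) pair so that, flattened, every datum yields a flat segment. For example:

import numpy as np

y = np.array([1.0, 3.0, 2.0])
y2 = np.broadcast_to(y[:, None], (y.size, 2))
print(y2)              # [[1. 1.] [3. 3.] [2. 2.]]
print(y2.reshape(-1))  # [1. 1. 3. 3. 2. 2.]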
class Flow(msgspec.Struct): # , frozen=True): class Flow(msgspec.Struct): # , frozen=True):
''' '''
(Financial Signal-)Flow compound type which wraps a real-time (Financial Signal-)Flow compound type which wraps a real-time
@ -209,18 +336,15 @@ class Flow(msgspec.Struct): # , frozen=True):
''' '''
name: str name: str
plot: pg.PlotItem plot: pg.PlotItem
_shm: ShmArray graphics: Union[Curve, BarItems]
flume: Flume
graphics: Curve | BarItems
# for tracking y-mn/mx for y-axis auto-ranging
yrange: tuple[float, float] = None
# in some cases a flow may want to change its # in some cases a flow may want to change its
# graphical "type" or, "form" when downsampling, to # graphical "type" or, "form" when downsampling,
# start this is only ever an interpolation line. # normally this is just a plain line.
ds_graphics: Optional[Curve] = None ds_graphics: Optional[Curve] = None
_shm: ShmArray
is_ohlc: bool = False is_ohlc: bool = False
render: bool = True # toggle for display loop render: bool = True # toggle for display loop
@ -253,20 +377,19 @@ class Flow(msgspec.Struct): # , frozen=True):
# TODO: remove this and only allow setting through # TODO: remove this and only allow setting through
# private ``._shm`` attr? # private ``._shm`` attr?
# @shm.setter @shm.setter
# def shm(self, shm: ShmArray) -> ShmArray: def shm(self, shm: ShmArray) -> ShmArray:
# self._shm = shm self._shm = shm
def maxmin( def maxmin(
self, self,
lbar: int, lbar: int,
rbar: int, rbar: int,
) -> Optional[tuple[float, float]]: ) -> tuple[float, float]:
''' '''
Compute the cached max and min y-range values for a given Compute the cached max and min y-range values for a given
x-range determined by ``lbar`` and ``rbar`` or ``None`` x-range determined by ``lbar`` and ``rbar``.
if no range can be determined (yet).
''' '''
rkey = (lbar, rbar) rkey = (lbar, rbar)
@ -276,8 +399,9 @@ class Flow(msgspec.Struct): # , frozen=True):
shm = self.shm shm = self.shm
if shm is None: if shm is None:
return None mxmn = None
else: # new block for profiling?..
arr = shm.array arr = shm.array
# build relative indexes into shm array # build relative indexes into shm array
@ -290,11 +414,7 @@ class Flow(msgspec.Struct): # , frozen=True):
] ]
if not slice_view.size: if not slice_view.size:
return None mxmn = None
elif self.yrange:
mxmn = self.yrange
# print(f'{self.name} M4 maxmin: {mxmn}')
else: else:
if self.is_ohlc: if self.is_ohlc:
@ -307,10 +427,9 @@ class Flow(msgspec.Struct): # , frozen=True):
yhigh = np.max(view) yhigh = np.max(view)
mxmn = ylow, yhigh mxmn = ylow, yhigh
# print(f'{self.name} MANUAL maxmin: {mxmin}')
# cache result for input range if mxmn is not None:
assert mxmn # cache new mxmn result
self._mxmns[rkey] = mxmn self._mxmns[rkey] = mxmn
return mxmn return mxmn
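Both versions of `Flow.maxmin()` memoize the y-range per x-range key; they differ only in how `None` and cached M4 results are handled. The core caching pattern, reduced to a sketch over a plain 1d array:

import numpy as np

_mxmns: dict[tuple[int, int], tuple[float, float]] = {}

def maxmin(arr: np.ndarray, lbar: int, rbar: int):
    key = (lbar, rbar)
    if key in _mxmns:
        return _mxmns[key]
    view = arr[lbar:rbar + 1]
    if not view.size:
        return None
    mxmn = float(view.min()), float(view.max())
    _mxmns[key] = mxmn   # cache result for this input range
    return mxmn

print(maxmin(np.arange(10.), 2, 5))   # -> (2.0, 5.0)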
@ -322,15 +441,9 @@ class Flow(msgspec.Struct): # , frozen=True):
''' '''
vr = self.plot.viewRect() vr = self.plot.viewRect()
return ( return int(vr.left()), int(vr.right())
vr.left(),
vr.right(),
)
def datums_range( def datums_range(self) -> tuple[
self,
index_field: str = 'index',
) -> tuple[
int, int, int, int, int, int int, int, int, int, int, int
]: ]:
''' '''
@ -338,8 +451,6 @@ class Flow(msgspec.Struct): # , frozen=True):
''' '''
l, r = self.view_range() l, r = self.view_range()
l = round(l)
r = round(r)
# TODO: avoid this and have shm passed # TODO: avoid this and have shm passed
# in earlier. # in earlier.
@ -360,23 +471,15 @@ class Flow(msgspec.Struct): # , frozen=True):
def read( def read(
self, self,
array_field: Optional[str] = None, array_field: Optional[str] = None,
index_field: str = 'index',
) -> tuple[ ) -> tuple[
int, int, np.ndarray, int, int, np.ndarray,
int, int, np.ndarray, int, int, np.ndarray,
]: ]:
''' # read call
Read the underlying shm array buffer and
return the data plus indexes for the first
and last
which has been written to.
'''
# readable data
array = self.shm.array array = self.shm.array
indexes = array[index_field] indexes = array['index']
ifirst = indexes[0] ifirst = indexes[0]
ilast = indexes[-1] ilast = indexes[-1]
@ -408,7 +511,7 @@ class Flow(msgspec.Struct): # , frozen=True):
render: bool = True, render: bool = True,
array_key: Optional[str] = None, array_key: Optional[str] = None,
profiler: Optional[Profiler] = None, profiler: Optional[pg.debug.Profiler] = None,
do_append: bool = True, do_append: bool = True,
**kwargs, **kwargs,
@ -419,7 +522,7 @@ class Flow(msgspec.Struct): # , frozen=True):
render to graphics. render to graphics.
''' '''
profiler = Profiler( profiler = pg.debug.Profiler(
msg=f'Flow.update_graphics() for {self.name}', msg=f'Flow.update_graphics() for {self.name}',
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
ms_threshold=4, ms_threshold=4,
@ -444,14 +547,9 @@ class Flow(msgspec.Struct): # , frozen=True):
slice_to_head: int = -1 slice_to_head: int = -1
should_redraw: bool = False should_redraw: bool = False
should_line: bool = False
rkwargs = {} rkwargs = {}
# TODO: probably specialize ``Renderer`` types instead of should_line = False
# these logic checks?
# - put these blocks into a `.load_renderer()` meth?
# - consider a OHLCRenderer, StepCurveRenderer, Renderer?
r = self._src_r
if isinstance(graphics, BarItems): if isinstance(graphics, BarItems):
# XXX: special case where we change out graphics # XXX: special case where we change out graphics
# to a line after a certain uppx threshold. # to a line after a certain uppx threshold.
@ -471,34 +569,14 @@ class Flow(msgspec.Struct): # , frozen=True):
should_redraw = changed_to_line or not should_line should_redraw = changed_to_line or not should_line
self._in_ds = should_line self._in_ds = should_line
elif not r:
if isinstance(graphics, StepCurve):
r = self._src_r = Renderer(
flow=self,
fmtr=StepCurveFmtr(
shm=self.shm,
flow=self,
_last_read=read,
),
)
# TODO: append logic inside ``.render()`` isn't
# correct yet for step curves.. remove this to see it.
should_redraw = True
slice_to_head = -2
else: else:
r = self._src_r r = self._src_r
if not r: if not r:
# just using for ``.diff()`` atm.. # just using for ``.diff()`` atm..
r = self._src_r = Renderer( r = self._src_r = Renderer(
flow=self, flow=self,
fmtr=IncrementalFormatter( # TODO: rename this to something with ohlc
shm=self.shm, last_read=read,
flow=self,
_last_read=read,
),
) )
# ``Curve`` derivative case(s): # ``Curve`` derivative case(s):
@ -510,6 +588,19 @@ class Flow(msgspec.Struct): # , frozen=True):
should_ds: bool = r._in_ds should_ds: bool = r._in_ds
showing_src_data: bool = not r._in_ds showing_src_data: bool = not r._in_ds
# step_mode = getattr(graphics, '_step_mode', False)
step_mode = isinstance(graphics, StepCurve)
if step_mode:
r.allocate_xy = to_step_format
r.update_xy = update_step_xy
r.format_xy = step_to_xy
# TODO: append logic inside ``.render()`` isn't
# correct yet for step curves.. remove this to see it.
should_redraw = True
slice_to_head = -2
# downsampling incremental state checking # downsampling incremental state checking
# check for and set std m4 downsample conditions # check for and set std m4 downsample conditions
uppx = graphics.x_uppx() uppx = graphics.x_uppx()
@ -537,13 +628,10 @@ class Flow(msgspec.Struct): # , frozen=True):
# source data so we clear our path data in prep # source data so we clear our path data in prep
# to generate a new one from original source data. # to generate a new one from original source data.
new_sample_rate = True new_sample_rate = True
showing_src_data = True
should_ds = False should_ds = False
should_redraw = True should_redraw = True
showing_src_data = True
# reset yrange to be computed from source data
self.yrange = None
# MAIN RENDER LOGIC: # MAIN RENDER LOGIC:
# - determine in view data and redraw on range change # - determine in view data and redraw on range change
# - determine downsampling ops if needed # - determine downsampling ops if needed
@ -569,10 +657,6 @@ class Flow(msgspec.Struct): # , frozen=True):
**rkwargs, **rkwargs,
) )
if showing_src_data:
# print(f"{self.name} SHOWING SOURCE")
# reset yrange to be computed from source data
self.yrange = None
if not out: if not out:
log.warning(f'{self.name} failed to render!?') log.warning(f'{self.name} failed to render!?')
@ -580,29 +664,25 @@ class Flow(msgspec.Struct): # , frozen=True):
path, data, reset = out path, data, reset = out
# if self.yrange:
# print(f'flow {self.name} yrange from m4: {self.yrange}')
# XXX: SUPER UGGGHHH... without this we get stale cache # XXX: SUPER UGGGHHH... without this we get stale cache
# graphics that don't update until you downsampler again.. # graphics that don't update until you downsampler again..
# reset = False if reset:
# if reset: with graphics.reset_cache():
# with graphics.reset_cache(): # assign output paths to graphics obj
# # assign output paths to graphics obj graphics.path = r.path
# graphics.path = r.path graphics.fast_path = r.fast_path
# graphics.fast_path = r.fast_path
# # XXX: we don't need this right? # XXX: we don't need this right?
# # graphics.draw_last_datum( # graphics.draw_last_datum(
# # path, # path,
# # src_array, # src_array,
# # data, # data,
# # reset, # reset,
# # array_key, # array_key,
# # ) # )
# # graphics.update() # graphics.update()
# # profiler('.update()') # profiler('.update()')
# else: else:
# assign output paths to graphics obj # assign output paths to graphics obj
graphics.path = r.path graphics.path = r.path
graphics.fast_path = r.fast_path graphics.fast_path = r.fast_path
@ -689,10 +769,51 @@ class Flow(msgspec.Struct): # , frozen=True):
g.update() g.update()
def by_index_and_key(
renderer: Renderer,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray,
np.ndarray,
np.ndarray,
]:
return array['index'], array[array_key], 'all'
class Renderer(msgspec.Struct): class Renderer(msgspec.Struct):
flow: Flow flow: Flow
fmtr: IncrementalFormatter # last array view read
last_read: Optional[tuple] = None
# default just returns index, and named array from data
format_xy: Callable[
[np.ndarray, str],
tuple[np.ndarray]
] = by_index_and_key
# optional pre-graphics xy formatted data which
# is incrementally updated in sync with the source data.
allocate_xy: Optional[Callable[
[int, slice],
tuple[np.ndarray, np.ndarray]
]] = None
update_xy: Optional[Callable[
[int, slice], None]
] = None
x_data: Optional[np.ndarray] = None
y_data: Optional[np.ndarray] = None
# indexes which slice into the above arrays (which are allocated
# based on source data shm input size) and allow retrieving
# incrementally updated data.
_xy_first: int = 0
_xy_last: int = 0
# output graphics rendering, the main object # output graphics rendering, the main object
# processed in ``QGraphicsObject.paint()`` # processed in ``QGraphicsObject.paint()``
@ -714,11 +835,58 @@ class Renderer(msgspec.Struct):
_last_uppx: float = 0 _last_uppx: float = 0
_in_ds: bool = False _in_ds: bool = False
# incremental update state(s)
_last_vr: Optional[tuple[float, float]] = None
_last_ivr: Optional[tuple[float, float]] = None
def diff(
self,
new_read: tuple[np.ndarray],
) -> tuple[
np.ndarray,
np.ndarray,
]:
(
last_xfirst,
last_xlast,
last_array,
last_ivl,
last_ivr,
last_in_view,
) = self.last_read
# TODO: can the renderer just call ``Flow.read()`` directly?
# unpack latest source data read
(
xfirst,
xlast,
array,
ivl,
ivr,
in_view,
) = new_read
# compute the length diffs between the first/last index entry in
# the input data and the last indexes we have on record from the
# last time we updated the curve index.
prepend_length = int(last_xfirst - xfirst)
append_length = int(xlast - last_xlast)
# blah blah blah
# do diffing for prepend, append and last entry
return (
slice(xfirst, last_xfirst),
prepend_length,
append_length,
slice(last_xlast, xlast),
)
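The new `Renderer.diff()` reduces incremental updates to two integers computed from the first/last index of consecutive reads. Distilled to a standalone sketch:

def diff(last_read: tuple[int, int], new_read: tuple[int, int]):
    last_xfirst, last_xlast = last_read
    xfirst, xlast = new_read
    prepend_length = int(last_xfirst - xfirst)
    append_length = int(xlast - last_xlast)
    return (
        slice(xfirst, last_xfirst),   # newly prepended history
        prepend_length,
        append_length,
        slice(last_xlast, xlast),     # newly appended datums
    )

print(diff((100, 200), (90, 205)))   # 10 prepended, 5 appended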
def draw_path( def draw_path(
self, self,
x: np.ndarray, x: np.ndarray,
y: np.ndarray, y: np.ndarray,
connect: str | np.ndarray = 'all', connect: Union[str, np.ndarray] = 'all',
path: Optional[QPainterPath] = None, path: Optional[QPainterPath] = None,
redraw: bool = False, redraw: bool = False,
@ -764,7 +932,7 @@ class Renderer(msgspec.Struct):
new_read, new_read,
array_key: str, array_key: str,
profiler: Profiler, profiler: pg.debug.Profiler,
uppx: float = 1, uppx: float = 1,
# redraw and ds flags # redraw and ds flags
@ -796,54 +964,165 @@ class Renderer(msgspec.Struct):
''' '''
# TODO: can the renderer just call ``Flow.read()`` directly? # TODO: can the renderer just call ``Flow.read()`` directly?
# unpack latest source data read # unpack latest source data read
fmtr = self.fmtr
( (
_, xfirst,
_, xlast,
array, array,
ivl, ivl,
ivr, ivr,
in_view, in_view,
) = new_read ) = new_read
# xy-path data transform: convert source data to a format
# able to be passed to a `QPainterPath` rendering routine.
fmt_out = fmtr.format_to_1d(
new_read,
array_key,
profiler,
slice_to_head=slice_to_head,
read_src_from_key=read_from_key,
slice_to_inview=use_vr,
)
# no history in view case
if not fmt_out:
# XXX: this might be why the profiler only has exits?
return
( (
x_1d, pre_slice,
y_1d,
connect,
prepend_length, prepend_length,
append_length, append_length,
view_changed, post_slice,
# append_tres, ) = self.diff(new_read)
) = fmt_out if self.update_xy:
shm = self.flow.shm
if self.y_data is None:
# we first need to allocate xy data arrays
# from the source data.
assert self.allocate_xy
self.x_data, self.y_data = self.allocate_xy(
shm,
array_key,
)
self._xy_first = shm._first.value
self._xy_last = shm._last.value
profiler('allocated xy history')
if prepend_length:
y_prepend = shm._array[pre_slice]
if read_from_key:
y_prepend = y_prepend[array_key]
xy_data, xy_slice = self.update_xy(
shm,
array_key,
# this is the pre-sliced, "normally expected"
# new data that an updater would normally be
# expected to process, however in some cases (like
# step curves) the updater routine may want to do
# the source history-data reading itself, so we pass
# both here.
y_prepend,
pre_slice,
prepend_length,
self._xy_first,
self._xy_last,
is_append=False,
)
self.y_data[xy_slice] = xy_data
self._xy_first = shm._first.value
profiler(f'prepended xy history: {prepend_length}')
if append_length:
y_append = shm._array[post_slice]
if read_from_key:
y_append = y_append[array_key]
xy_data, xy_slice = self.update_xy(
shm,
array_key,
y_append,
post_slice,
append_length,
self._xy_first,
self._xy_last,
is_append=True,
)
# self.y_data[post_slice] = xy_data
# self.y_data[xy_slice or post_slice] = xy_data
self.y_data[xy_slice] = xy_data
self._xy_last = shm._last.value
profiler(f'appended xy history: {append_length}')
if use_vr:
array = in_view
# else:
# ivl, ivr = xfirst, xlast
hist = array[:slice_to_head]
# xy-path data transform: convert source data to a format
# able to be passed to a `QPainterPath` rendering routine.
if not len(hist):
return
x_out, y_out, connect = self.format_xy(
self,
# TODO: hist here should be the pre-sliced
# x/y_data in the case where allocate_xy is
# defined?
hist,
array_key,
(ivl, ivr),
)
profiler('sliced input arrays')
if (
use_vr
):
# if a view range is passed, plan to draw the
# source output that's "in view" of the chart.
view_range = (ivl, ivr)
# print(f'{self._name} vr: {view_range}')
profiler(f'view range slice {view_range}')
vl, vr = view_range
zoom_or_append = False
last_vr = self._last_vr
last_ivr = self._last_ivr or (vl, vr)
# incremental in-view data update.
if last_vr:
# relative slice indices
lvl, lvr = last_vr
# abs slice indices
al, ar = last_ivr
# left_change = abs(x_iv[0] - al) >= 1
# right_change = abs(x_iv[-1] - ar) >= 1
if (
# likely a zoom view change
(vr - lvr) > 2 or vl < lvl
# append / prepend update
# we had an append update where the view range
# didn't change but the data-viewed (shifted)
# underneath, so we need to redraw.
# or left_change and right_change and last_vr == view_range
# not (left_change and right_change) and ivr
# (
# or abs(x_iv[ivr] - livr) > 1
):
zoom_or_append = True
self._last_vr = view_range
if len(x_out):
self._last_ivr = x_out[0], x_out[slice_to_head]
# redraw conditions # redraw conditions
if ( if (
prepend_length > 0 prepend_length > 0
or new_sample_rate or new_sample_rate
or view_changed
# NOTE: comment this to try and make "append paths"
# work below..
or append_length > 0 or append_length > 0
or zoom_or_append
): ):
should_redraw = True should_redraw = True
@ -865,21 +1144,18 @@ class Renderer(msgspec.Struct):
elif should_ds and uppx > 1: elif should_ds and uppx > 1:
x_1d, y_1d, ymn, ymx = xy_downsample( x_out, y_out = xy_downsample(
x_1d, x_out,
y_1d, y_out,
uppx, uppx,
) )
self.flow.yrange = ymn, ymx
# print(f'{self.flow.name} post ds: ymn, ymx: {ymn},{ymx}')
reset = True reset = True
profiler(f'FULL PATH downsample redraw={should_ds}') profiler(f'FULL PATH downsample redraw={should_ds}')
self._in_ds = True self._in_ds = True
path = self.draw_path( path = self.draw_path(
x=x_1d, x=x_out,
y=y_1d, y=y_out,
connect=connect, connect=connect,
path=path, path=path,
redraw=True, redraw=True,
@ -894,6 +1170,7 @@ class Renderer(msgspec.Struct):
# TODO: get this piecewise prepend working - right now it's # TODO: get this piecewise prepend working - right now it's
# giving heck on vwap... # giving heck on vwap...
# elif prepend_length: # elif prepend_length:
# breakpoint()
# prepend_path = pg.functions.arrayToQPath( # prepend_path = pg.functions.arrayToQPath(
# x[0:prepend_length], # x[0:prepend_length],
@ -910,22 +1187,18 @@ class Renderer(msgspec.Struct):
elif ( elif (
append_length > 0 append_length > 0
and do_append and do_append
and not should_redraw
): ):
print(f'{array_key} append len: {append_length}') # print(f'{array_key} append len: {append_length}')
# new_x = x_1d[-append_length - 2:] # slice_to_head] new_x = x_out[-append_length - 2:] # slice_to_head]
# new_y = y_1d[-append_length - 2:] # slice_to_head] new_y = y_out[-append_length - 2:] # slice_to_head]
profiler('sliced append path') profiler('sliced append path')
# (
# x_1d,
# y_1d,
# connect,
# ) = append_tres
profiler( profiler(
f'diffed array input, append_length={append_length}' f'diffed array input, append_length={append_length}'
) )
# if should_ds and uppx > 1: # if should_ds:
# new_x, new_y = xy_downsample( # new_x, new_y = xy_downsample(
# new_x, # new_x,
# new_y, # new_y,
@ -934,15 +1207,14 @@ class Renderer(msgspec.Struct):
# profiler(f'fast path downsample redraw={should_ds}') # profiler(f'fast path downsample redraw={should_ds}')
append_path = self.draw_path( append_path = self.draw_path(
x=x_1d, x=new_x,
y=y_1d, y=new_y,
connect=connect, connect=connect,
path=fast_path, path=fast_path,
) )
profiler('generated append qpath') profiler('generated append qpath')
if use_fpath: if use_fpath:
# print(f'{self.flow.name}: FAST PATH')
# an attempt at trying to make append-updates faster.. # an attempt at trying to make append-updates faster..
if fast_path is None: if fast_path is None:
fast_path = append_path fast_path = append_path
@ -952,12 +1224,7 @@ class Renderer(msgspec.Struct):
size = fast_path.capacity() size = fast_path.capacity()
profiler(f'connected fast path w size: {size}') profiler(f'connected fast path w size: {size}')
print( # print(f"append_path br: {append_path.boundingRect()}")
f"append_path br: {append_path.boundingRect()}\n"
f"path size: {size}\n"
f"append_path len: {append_path.length()}\n"
f"fast_path len: {fast_path.length()}\n"
)
# graphics.path.moveTo(new_x[0], new_y[0]) # graphics.path.moveTo(new_x[0], new_y[0])
# path.connectPath(append_path) # path.connectPath(append_path)
@ -971,4 +1238,10 @@ class Renderer(msgspec.Struct):
self.path = path self.path = path
self.fast_path = fast_path self.fast_path = fast_path
# TODO: eventually maybe we can implement some kind of
# transform on the ``QPainterPath`` that will more or less
# detect the diff in "elements" terms?
# update diff state since we've now rendered paths.
self.last_read = new_read
return self.path, array, reset return self.path, array, reset
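Under both branches, `draw_path()` ultimately leans on pyqtgraph's array-to-path conversion. A minimal demonstration of that primitive (`connect` may be 'all', 'pairs', 'finite' or a per-point array, matching the `connect` argument threaded through above):

import numpy as np
import pyqtgraph as pg

x = np.arange(5, dtype=float)
y = np.array([1.0, 2.0, 1.5, 3.0, 2.5])

# build a QPainterPath from x/y arrays; this is what the renderer
# caches in `self.path` / `self.fast_path`
path = pg.functions.arrayToQPath(x, y, connect='all')
print(path.elementCount())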

View File

@ -619,7 +619,7 @@ class FillStatusBar(QProgressBar):
# color: #19232D; # color: #19232D;
# width: 10px; # width: 10px;
self.setRange(0, int(slots)) self.setRange(0, slots)
self.setValue(value) self.setValue(value)
@ -644,7 +644,7 @@ def mk_fill_status_bar(
# TODO: calc this height from the ``ChartnPane`` # TODO: calc this height from the ``ChartnPane``
chart_h = round(parent_pane.height() * 5/8) chart_h = round(parent_pane.height() * 5/8)
bar_h = chart_h * 0.375*0.9 bar_h = chart_h * 0.375
# TODO: once things are sized to screen # TODO: once things are sized to screen
bar_label_font_size = label_font_size or _font.px_size - 2 bar_label_font_size = label_font_size or _font.px_size - 2

View File

@ -27,13 +27,12 @@ from itertools import cycle
from typing import Optional, AsyncGenerator, Any from typing import Optional, AsyncGenerator, Any
import numpy as np import numpy as np
import msgspec from pydantic import create_model
import tractor import tractor
import pyqtgraph as pg import pyqtgraph as pg
import trio import trio
from trio_typing import TaskStatus from trio_typing import TaskStatus
from piker.data.types import Struct
from ._axes import PriceAxis from ._axes import PriceAxis
from .._cacheables import maybe_open_context from .._cacheables import maybe_open_context
from ..calc import humanize from ..calc import humanize
@ -42,8 +41,6 @@ from ..data._sharedmem import (
_Token, _Token,
try_read, try_read,
) )
from ..data.feed import Flume
from ..data._source import Symbol
from ._chart import ( from ._chart import (
ChartPlotWidget, ChartPlotWidget,
LinkedSplits, LinkedSplits,
@ -53,18 +50,14 @@ from ._forms import (
mk_form, mk_form,
open_form_input_handling, open_form_input_handling,
) )
from ..fsp._api import ( from ..fsp._api import maybe_mk_fsp_shm, Fsp
maybe_mk_fsp_shm,
Fsp,
)
from ..fsp import cascade from ..fsp import cascade
from ..fsp._volume import ( from ..fsp._volume import (
# tina_vwap, tina_vwap,
dolla_vlm, dolla_vlm,
flow_rates, flow_rates,
) )
from ..log import get_logger from ..log import get_logger
from .._profile import Profiler
log = get_logger(__name__) log = get_logger(__name__)
@ -112,8 +105,7 @@ def update_fsp_chart(
# sub-charts reference it under different 'named charts'. # sub-charts reference it under different 'named charts'.
# read from last calculated value and update any label # read from last calculated value and update any label
last_val_sticky = chart.plotItem.getAxis( last_val_sticky = chart._ysticks.get(graphics_name)
'right')._stickies.get(graphics_name)
if last_val_sticky: if last_val_sticky:
last = last_row[array_key] last = last_row[array_key]
last_val_sticky.update_from_data(-1, last) last_val_sticky.update_from_data(-1, last)
@ -161,13 +153,12 @@ async def open_fsp_sidepane(
) )
# https://pydantic-docs.helpmanual.io/usage/models/#dynamic-model-creation # https://pydantic-docs.helpmanual.io/usage/models/#dynamic-model-creation
FspConfig = msgspec.defstruct( FspConfig = create_model(
"Point", 'FspConfig',
[('name', name)] + list(params.items()), name=name,
bases=(Struct,), **params,
) )
model = FspConfig(name=name, **params) sidepane.model = FspConfig()
sidepane.model = model
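The sidepane config swap replaces pydantic's `create_model()` with `msgspec.defstruct()` for building a struct type at runtime from the fsp's parameters. A small sketch with made-up fields:

import msgspec

FspConfig = msgspec.defstruct(
    'FspConfig',
    [('name', str), ('period', int, 14)],
)
cfg = FspConfig(name='flow_rates')
print(cfg.name, cfg.period)   # -> flow_rates 14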
# just a logger for now until we get fsp configs up and running. # just a logger for now until we get fsp configs up and running.
async def settings_change( async def settings_change(
@ -197,7 +188,7 @@ async def open_fsp_actor_cluster(
from tractor._clustering import open_actor_cluster from tractor._clustering import open_actor_cluster
# profiler = Profiler( # profiler = pg.debug.Profiler(
# delayed=False, # delayed=False,
# disabled=False # disabled=False
# ) # )
@ -214,12 +205,12 @@ async def open_fsp_actor_cluster(
async def run_fsp_ui( async def run_fsp_ui(
linkedsplits: LinkedSplits, linkedsplits: LinkedSplits,
flume: Flume, shm: ShmArray,
started: trio.Event, started: trio.Event,
target: Fsp, target: Fsp,
conf: dict[str, dict], conf: dict[str, dict],
loglevel: str, loglevel: str,
# profiler: Profiler, # profiler: pg.debug.Profiler,
# _quote_throttle_rate: int = 58, # _quote_throttle_rate: int = 58,
) -> None: ) -> None:
@ -251,11 +242,9 @@ async def run_fsp_ui(
else: else:
chart = linkedsplits.subplots[overlay_with] chart = linkedsplits.subplots[overlay_with]
shm = flume.rt_shm
chart.draw_curve( chart.draw_curve(
name, name=name,
shm, shm=shm,
flume,
overlay=True, overlay=True,
color='default_light', color='default_light',
array_key=name, array_key=name,
@ -265,9 +254,8 @@ async def run_fsp_ui(
else: else:
# create a new sub-chart widget for this fsp # create a new sub-chart widget for this fsp
chart = linkedsplits.add_plot( chart = linkedsplits.add_plot(
name, name=name,
shm, shm=shm,
flume,
array_key=name, array_key=name,
sidepane=sidepane, sidepane=sidepane,
@ -357,9 +345,6 @@ async def run_fsp_ui(
# last = time.time() # last = time.time()
# TODO: maybe this should be our ``Flow`` type since it maps
# one flume to the next? The machinery for task/actor mgmt should
# be part of the instantiation API?
class FspAdmin: class FspAdmin:
''' '''
Client API for orchestrating FSP actors and displaying Client API for orchestrating FSP actors and displaying
@ -371,7 +356,7 @@ class FspAdmin:
tn: trio.Nursery, tn: trio.Nursery,
cluster: dict[str, tractor.Portal], cluster: dict[str, tractor.Portal],
linked: LinkedSplits, linked: LinkedSplits,
flume: Flume, src_shm: ShmArray,
) -> None: ) -> None:
self.tn = tn self.tn = tn
@ -383,11 +368,7 @@ class FspAdmin:
tuple[tractor.MsgStream, ShmArray] tuple[tractor.MsgStream, ShmArray]
] = {} ] = {}
self._flow_registry: dict[_Token, str] = {} self._flow_registry: dict[_Token, str] = {}
self.src_shm = src_shm
# TODO: make this a `.src_flume` and add
# a `dst_flume`?
# (=> but then wouldn't this be the most basic `Flow`?)
self.flume = flume
def rr_next_portal(self) -> tractor.Portal: def rr_next_portal(self) -> tractor.Portal:
name, portal = next(self._rr_next_actor) name, portal = next(self._rr_next_actor)
@ -400,7 +381,7 @@ class FspAdmin:
complete: trio.Event, complete: trio.Event,
started: trio.Event, started: trio.Event,
fqsn: str, fqsn: str,
dst_fsp_flume: Flume, dst_shm: ShmArray,
conf: dict, conf: dict,
target: Fsp, target: Fsp,
loglevel: str, loglevel: str,
@ -421,10 +402,9 @@ class FspAdmin:
# data feed key # data feed key
fqsn=fqsn, fqsn=fqsn,
# TODO: pass `Flume.to_msg()`s here?
# mems # mems
src_shm_token=self.flume.rt_shm.token, src_shm_token=self.src_shm.token,
dst_shm_token=dst_fsp_flume.rt_shm.token, dst_shm_token=dst_shm.token,
# target # target
ns_path=ns_path, ns_path=ns_path,
@ -441,14 +421,12 @@ class FspAdmin:
ctx.open_stream() as stream, ctx.open_stream() as stream,
): ):
dst_fsp_flume.stream: tractor.MsgStream = stream
# register output data # register output data
self._registry[ self._registry[
(fqsn, ns_path) (fqsn, ns_path)
] = ( ] = (
stream, stream,
dst_fsp_flume.rt_shm, dst_shm,
complete complete
) )
@ -462,9 +440,7 @@ class FspAdmin:
# if the chart isn't hidden try to update # if the chart isn't hidden try to update
# the data on screen. # the data on screen.
if not self.linked.isHidden(): if not self.linked.isHidden():
log.debug( log.debug(f'Re-syncing graphics for fsp: {ns_path}')
f'Re-syncing graphics for fsp: {ns_path}'
)
self.linked.graphics_cycle( self.linked.graphics_cycle(
trigger_all=True, trigger_all=True,
prepend_update_index=info['first'], prepend_update_index=info['first'],
@ -483,9 +459,9 @@ class FspAdmin:
worker_name: Optional[str] = None, worker_name: Optional[str] = None,
loglevel: str = 'info', loglevel: str = 'info',
) -> (Flume, trio.Event): ) -> (ShmArray, trio.Event):
fqsn = self.flume.symbol.fqsn fqsn = self.linked.symbol.front_fqsn()
# allocate an output shm array # allocate an output shm array
key, dst_shm, opened = maybe_mk_fsp_shm( key, dst_shm, opened = maybe_mk_fsp_shm(
@ -493,36 +469,16 @@ class FspAdmin:
target=target, target=target,
readonly=True, readonly=True,
) )
self._flow_registry[
portal = self.cluster.get(worker_name) or self.rr_next_portal() (self.src_shm._token, target.name)
provider_tag = portal.channel.uid ] = dst_shm._token
symbol = Symbol(
key=key,
broker_info={
provider_tag: {'asset_type': 'fsp'},
},
)
dst_fsp_flume = Flume(
symbol=symbol,
_rt_shm_token=dst_shm.token,
first_quote={},
# set to 0 presuming for now that we can't load
# FSP history (though we should eventually).
izero_hist=0,
izero_rt=0,
)
self._flow_registry[(
self.flume.rt_shm._token,
target.name
)] = dst_shm._token
# if not opened: # if not opened:
# raise RuntimeError( # raise RuntimeError(
# f'Already started FSP `{fqsn}:{func_name}`' # f'Already started FSP `{fqsn}:{func_name}`'
# ) # )
portal = self.cluster.get(worker_name) or self.rr_next_portal()
complete = trio.Event() complete = trio.Event()
started = trio.Event() started = trio.Event()
self.tn.start_soon( self.tn.start_soon(
@ -531,13 +487,13 @@ class FspAdmin:
complete, complete,
started, started,
fqsn, fqsn,
dst_fsp_flume, dst_shm,
conf, conf,
target, target,
loglevel, loglevel,
) )
return dst_fsp_flume, started return dst_shm, started
async def open_fsp_chart( async def open_fsp_chart(
self, self,
@ -549,7 +505,7 @@ class FspAdmin:
) -> (trio.Event, ChartPlotWidget): ) -> (trio.Event, ChartPlotWidget):
flume, started = await self.start_engine_task( shm, started = await self.start_engine_task(
target, target,
conf, conf,
loglevel, loglevel,
@ -561,7 +517,7 @@ class FspAdmin:
run_fsp_ui, run_fsp_ui,
self.linked, self.linked,
flume, shm,
started, started,
target, target,
@ -575,7 +531,7 @@ class FspAdmin:
@acm @acm
async def open_fsp_admin( async def open_fsp_admin(
linked: LinkedSplits, linked: LinkedSplits,
flume: Flume, src_shm: ShmArray,
**kwargs, **kwargs,
) -> AsyncGenerator[dict, dict[str, tractor.Portal]]: ) -> AsyncGenerator[dict, dict[str, tractor.Portal]]:
@ -596,7 +552,7 @@ async def open_fsp_admin(
tn, tn,
cluster_map, cluster_map,
linked, linked,
flume, src_shm,
) )
try: try:
yield admin yield admin
@ -610,7 +566,7 @@ async def open_fsp_admin(
async def open_vlm_displays( async def open_vlm_displays(
linked: LinkedSplits, linked: LinkedSplits,
flume: Flume, ohlcv: ShmArray,
dvlm: bool = True, dvlm: bool = True,
task_status: TaskStatus[ChartPlotWidget] = trio.TASK_STATUS_IGNORED, task_status: TaskStatus[ChartPlotWidget] = trio.TASK_STATUS_IGNORED,
@ -632,8 +588,6 @@ async def open_vlm_displays(
sig = inspect.signature(flow_rates.func) sig = inspect.signature(flow_rates.func)
params = sig.parameters params = sig.parameters
ohlcv: ShmArray = flume.rt_shm
async with ( async with (
open_fsp_sidepane( open_fsp_sidepane(
linked, { linked, {
@ -653,7 +607,7 @@ async def open_vlm_displays(
} }
}, },
) as sidepane, ) as sidepane,
open_fsp_admin(linked, flume) as admin, open_fsp_admin(linked, ohlcv) as admin,
): ):
# TODO: support updates # TODO: support updates
# period_field = sidepane.fields['period'] # period_field = sidepane.fields['period']
@ -664,12 +618,9 @@ async def open_vlm_displays(
# built-in vlm which we plot ASAP since it's # built-in vlm which we plot ASAP since it's
# usually data provided directly with OHLC history. # usually data provided directly with OHLC history.
shm = ohlcv shm = ohlcv
ohlc_chart = linked.chart
chart = linked.add_plot( chart = linked.add_plot(
name='volume', name='volume',
shm=shm, shm=shm,
flume=flume,
array_key='volume', array_key='volume',
sidepane=sidepane, sidepane=sidepane,
@ -682,36 +633,26 @@ async def open_vlm_displays(
# the curve item internals are pretty convoluted. # the curve item internals are pretty convoluted.
style='step', style='step',
) )
# back-link the volume chart to trigger y-autoranging
# in the ohlc (parent) chart.
ohlc_chart.view.enable_auto_yrange(
src_vb=chart.view,
)
# force 0 to always be in view # force 0 to always be in view
def multi_maxmin( def multi_maxmin(
names: list[str], names: list[str],
) -> tuple[float, float]: ) -> tuple[float, float]:
'''
Flows "group" maxmin loop; assumes all named flows
are in the same co-domain and thus can be sorted
as one set.
Iterates all the named flows and calls the chart
api to find their range values and return.
TODO: really we should probably have a more built-in API
for this?
'''
mx = 0 mx = 0
for name in names: for name in names:
ymn, ymx = chart.maxmin(name=name)
mx = max(mx, ymx) mxmn = chart.maxmin(name=name)
if mxmn:
ymax = mxmn[1]
if ymax > mx:
mx = ymax
return 0, mx return 0, mx
chart.view.maxmin = partial(multi_maxmin, names=['volume'])
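The group-maxmin hook pins zero as the lower bound and takes the max upper bound across all named flows, then gets bound to the view via `functools.partial`. Sketched standalone (a dict of precomputed ranges stands in for `chart.maxmin()`):

from functools import partial

def multi_maxmin(ranges: dict, names: list[str]) -> tuple[float, float]:
    mx = 0.0
    for name in names:
        ymn, ymx = ranges[name]
        mx = max(mx, ymx)   # 0 is always kept in view
    return 0.0, mx

ranges = {'volume': (0.0, 1e6), 'dark_vlm': (0.0, 2.5e6)}
maxmin = partial(multi_maxmin, ranges, names=['volume', 'dark_vlm'])
print(maxmin())   # -> (0.0, 2500000.0)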
# TODO: fix the x-axis label issue where if you put # TODO: fix the x-axis label issue where if you put
# the axis on the left it's totally not lined up... # the axis on the left it's totally not lined up...
# show volume units value on LHS (for dinkus) # show volume units value on LHS (for dinkus)
@ -725,8 +666,7 @@ async def open_vlm_displays(
assert chart.name != linked.chart.name assert chart.name != linked.chart.name
# sticky only on sub-charts atm # sticky only on sub-charts atm
last_val_sticky = chart.plotItem.getAxis( last_val_sticky = chart._ysticks[chart.name]
'right')._stickies.get(chart.name)
# read from last calculated value # read from last calculated value
value = shm.array['volume'][-1] value = shm.array['volume'][-1]
@ -749,7 +689,7 @@ async def open_vlm_displays(
tasks_ready = [] tasks_ready = []
# spawn and overlay $ vlm on the same subchart # spawn and overlay $ vlm on the same subchart
dvlm_flume, started = await admin.start_engine_task( dvlm_shm, started = await admin.start_engine_task(
dolla_vlm, dolla_vlm,
{ # fsp engine conf { # fsp engine conf
@ -796,8 +736,6 @@ async def open_vlm_displays(
}, },
) )
dvlm_pi.hideAxis('left')
dvlm_pi.hideAxis('bottom')
# all to be overlayed curve names # all to be overlayed curve names
fields = [ fields = [
'dolla_vlm', 'dolla_vlm',
@ -838,7 +776,6 @@ async def open_vlm_displays(
) -> None: ) -> None:
for name in names: for name in names:
if 'dark' in name: if 'dark' in name:
color = dark_vlm_color color = dark_vlm_color
elif 'rate' in name: elif 'rate' in name:
@ -846,13 +783,9 @@ async def open_vlm_displays(
else: else:
color = 'bracket' color = 'bracket'
assert isinstance(shm, ShmArray) curve, _ = chart.draw_curve(
assert isinstance(flume, Flume) name=name,
shm=shm,
flow = chart.draw_curve(
name,
shm,
flume,
array_key=name, array_key=name,
overlay=pi, overlay=pi,
color=color, color=color,
@ -865,20 +798,20 @@ async def open_vlm_displays(
# specially store ref to shm for lookup in display loop # specially store ref to shm for lookup in display loop
# since only a placeholder of `None` is entered in # since only a placeholder of `None` is entered in
# ``.draw_curve()``. # ``.draw_curve()``.
# flow = chart._flows[name] flow = chart._flows[name]
assert flow.plot is pi assert flow.plot is pi
chart_curves( chart_curves(
fields, fields,
dvlm_pi, dvlm_pi,
dvlm_flume.rt_shm, dvlm_shm,
step_mode=True, step_mode=True,
) )
# spawn flow rates fsp **ONLY AFTER** the 'dolla_vlm' fsp is # spawn flow rates fsp **ONLY AFTER** the 'dolla_vlm' fsp is
# up since this one depends on it. # up since this one depends on it.
fr_flume, started = await admin.start_engine_task( fr_shm, started = await admin.start_engine_task(
flow_rates, flow_rates,
{ # fsp engine conf { # fsp engine conf
'func_name': 'flow_rates', 'func_name': 'flow_rates',
@ -891,7 +824,7 @@ async def open_vlm_displays(
# chart_curves( # chart_curves(
# dvlm_rate_fields, # dvlm_rate_fields,
# dvlm_pi, # dvlm_pi,
# fr_flume.rt_shm, # fr_shm,
# ) # )
# TODO: is there a way to "sync" the dual axes such that only # TODO: is there a way to "sync" the dual axes such that only
@ -934,12 +867,11 @@ async def open_vlm_displays(
# keep both regular and dark vlm in view # keep both regular and dark vlm in view
names=trade_rate_fields, names=trade_rate_fields,
) )
tr_pi.hideAxis('bottom')
chart_curves( chart_curves(
trade_rate_fields, trade_rate_fields,
tr_pi, tr_pi,
fr_flume.rt_shm, fr_shm,
# step_mode=True, # step_mode=True,
# dashed line to represent "individual trades" being # dashed line to represent "individual trades" being
@ -973,7 +905,7 @@ async def open_vlm_displays(
async def start_fsp_displays( async def start_fsp_displays(
linked: LinkedSplits, linked: LinkedSplits,
flume: Flume, ohlcv: ShmArray,
group_status_key: str, group_status_key: str,
loglevel: str, loglevel: str,
@ -1008,7 +940,7 @@ async def start_fsp_displays(
# }, # },
# }, # },
} }
profiler = Profiler( profiler = pg.debug.Profiler(
delayed=False, delayed=False,
disabled=False disabled=False
) )
@ -1016,10 +948,7 @@ async def start_fsp_displays(
async with ( async with (
# NOTE: this admin internally opens an actor cluster # NOTE: this admin internally opens an actor cluster
open_fsp_admin( open_fsp_admin(linked, ohlcv) as admin,
linked,
flume,
) as admin,
): ):
statuses = [] statuses = []
for target, conf in fsp_conf.items(): for target, conf in fsp_conf.items():

View File

@ -33,7 +33,6 @@ import numpy as np
import trio import trio
from ..log import get_logger from ..log import get_logger
from .._profile import Profiler
from .._profile import pg_profile_enabled, ms_slower_then from .._profile import pg_profile_enabled, ms_slower_then
# from ._style import _min_points_to_show # from ._style import _min_points_to_show
from ._editors import SelectRect from ._editors import SelectRect
@ -142,16 +141,13 @@ async def handle_viewmode_kb_inputs(
Qt.Key_Space, Qt.Key_Space,
} }
): ):
godw = view._chart.linked.godwidget view._chart.linked.godwidget.search.focus()
godw.hist_linked.resize_sidepanes(from_linked=godw.rt_linked)
godw.search.focus()
# esc and ctrl-c # esc and ctrl-c
if key == Qt.Key_Escape or (ctrl and key == Qt.Key_C): if key == Qt.Key_Escape or (ctrl and key == Qt.Key_C):
# ctrl-c as cancel # ctrl-c as cancel
# https://forum.qt.io/topic/532/how-to-catch-ctrl-c-on-a-widget/9 # https://forum.qt.io/topic/532/how-to-catch-ctrl-c-on-a-widget/9
view.select_box.clear() view.select_box.clear()
view.linked.focus()
# cancel order or clear graphics # cancel order or clear graphics
if key == Qt.Key_C or key == Qt.Key_Delete: if key == Qt.Key_C or key == Qt.Key_Delete:
@ -182,17 +178,17 @@ async def handle_viewmode_kb_inputs(
if key in pressed: if key in pressed:
pressed.remove(key) pressed.remove(key)
# QUERY/QUOTE MODE # QUERY/QUOTE MODE #
# ----------------
if {Qt.Key_Q}.intersection(pressed): if {Qt.Key_Q}.intersection(pressed):
view.linked.cursor.in_query_mode = True view.linkedsplits.cursor.in_query_mode = True
else: else:
view.linked.cursor.in_query_mode = False view.linkedsplits.cursor.in_query_mode = False
# SELECTION MODE # SELECTION MODE
# -------------- # --------------
if shift: if shift:
if view.state['mouseMode'] == ViewBox.PanMode: if view.state['mouseMode'] == ViewBox.PanMode:
view.setMouseMode(ViewBox.RectMode) view.setMouseMode(ViewBox.RectMode)
@ -213,27 +209,18 @@ async def handle_viewmode_kb_inputs(
# ORDER MODE # ORDER MODE
# ---------- # ----------
# live vs. dark trigger + an action {buy, sell, alert} # live vs. dark trigger + an action {buy, sell, alert}
order_keys_pressed = ORDER_MODE.intersection(pressed) order_keys_pressed = ORDER_MODE.intersection(pressed)
if order_keys_pressed: if order_keys_pressed:
# TODO: it seems like maybe the composition should be # show the pp size label
# reversed here? Like, maybe we should have the nav have order_mode.current_pp.show()
# access to the pos state and then make encapsulated logic
# that shows the right stuff on screen instead or order mode
# and position-related abstractions doing this?
# show the pp size label only if there is
# a non-zero pos existing
tracker = order_mode.current_pp
if tracker.live_pp.size:
tracker.nav.show()
# TODO: show pp config mini-params in status bar widget # TODO: show pp config mini-params in status bar widget
# mode.pp_config.show() # mode.pp_config.show()
trigger_type: str = 'dark'
if ( if (
# 's' for "submit" to activate "live" order # 's' for "submit" to activate "live" order
Qt.Key_S in pressed or Qt.Key_S in pressed or
@ -241,6 +228,9 @@ async def handle_viewmode_kb_inputs(
): ):
trigger_type: str = 'live' trigger_type: str = 'live'
else:
trigger_type: str = 'dark'
# order mode trigger "actions" # order mode trigger "actions"
if Qt.Key_D in pressed: # for "damp eet" if Qt.Key_D in pressed: # for "damp eet"
action = 'sell' action = 'sell'
@ -269,8 +259,8 @@ async def handle_viewmode_kb_inputs(
Qt.Key_S in pressed or Qt.Key_S in pressed or
order_keys_pressed or order_keys_pressed or
Qt.Key_O in pressed Qt.Key_O in pressed
) ) and
and key in NUMBER_LINE key in NUMBER_LINE
): ):
# hot key to set order slots size. # hot key to set order slots size.
# change edit field to current number line value, # change edit field to current number line value,
@ -288,7 +278,7 @@ async def handle_viewmode_kb_inputs(
else: # none active else: # none active
# hide pp label # hide pp label
order_mode.current_pp.nav.hide_info() order_mode.current_pp.hide_info()
# if none are pressed, remove "staged" level # if none are pressed, remove "staged" level
# line under cursor position # line under cursor position
@ -329,6 +319,7 @@ async def handle_viewmode_mouse(
): ):
# when in order mode, submit execution # when in order mode, submit execution
# msg.event.accept() # msg.event.accept()
# breakpoint()
view.order_mode.submit_order() view.order_mode.submit_order()
@ -345,6 +336,16 @@ class ChartView(ViewBox):
''' '''
mode_name: str = 'view' mode_name: str = 'view'
# "relay events" for making overlaid views work.
# NOTE: these MUST be defined here (and can't be monkey patched
# on later) due to signal construction requiring refs to be
# in place during the run of meta-class machinery.
mouseDragEventRelay = QtCore.Signal(object, object, object)
wheelEventRelay = QtCore.Signal(object, object, object)
event_relay_source: 'Optional[ViewBox]' = None
relays: dict[str, QtCore.Signal] = {}
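The NOTE above reflects a real Qt constraint; a stand-alone sketch (not from either branch) of why the signals must be declared in the class body:

from PyQt5.QtCore import QObject, pyqtSignal

class Relay(QObject):
    # registered with the QMetaObject at class-creation time
    ping = pyqtSignal(object)

r = Relay()
r.ping.connect(lambda ev: print('got', ev))
r.ping.emit('hello')  # works: the class attr became a bound signal

# a signal assigned *after* class creation is never registered with the
# meta-object, so it can't be bound or emitted on instances:
# Relay.pong = pyqtSignal(object)  # too late for the meta-class machinery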
def __init__( def __init__(
self, self,
@ -374,7 +375,7 @@ class ChartView(ViewBox):
y=True, y=True,
) )
self.linked = None self.linkedsplits = None
self._chart: 'ChartPlotWidget' = None # noqa self._chart: 'ChartPlotWidget' = None # noqa
# add our selection box annotator # add our selection box annotator
@ -396,11 +397,8 @@ class ChartView(ViewBox):
''' '''
if self._ic is None: if self._ic is None:
try:
self.chart.pause_all_feeds() self.chart.pause_all_feeds()
self._ic = trio.Event() self._ic = trio.Event()
except RuntimeError:
pass
def signal_ic( def signal_ic(
self, self,
@ -413,12 +411,9 @@ class ChartView(ViewBox):
''' '''
if self._ic: if self._ic:
try:
self._ic.set() self._ic.set()
self._ic = None self._ic = None
self.chart.resume_all_feeds() self.chart.resume_all_feeds()
except RuntimeError:
pass
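The `start_ic()`/`signal_ic()` pair above is an interaction gate around feed pausing; a minimal sketch of the same pattern with assumed names:

import trio

class InteractionGate:
    '''Pause data feeds during user interaction, resume afterwards.'''

    def __init__(self):
        self._ic: trio.Event | None = None

    def start(self) -> None:
        if self._ic is None:
            # chart.pause_all_feeds() would be called here
            self._ic = trio.Event()

    def stop(self) -> None:
        if self._ic:
            self._ic.set()  # wake any tasks waiting on the event
            self._ic = None
            # chart.resume_all_feeds() would be called here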
@asynccontextmanager @asynccontextmanager
async def open_async_input_handler( async def open_async_input_handler(
@ -468,7 +463,7 @@ class ChartView(ViewBox):
self, self,
ev, ev,
axis=None, axis=None,
# relayed_from: ChartView = None, relayed_from: ChartView = None,
): ):
''' '''
Override "center-point" location for scrolling. Override "center-point" location for scrolling.
@ -479,20 +474,13 @@ class ChartView(ViewBox):
TODO: PR a method into ``pyqtgraph`` to make this configurable TODO: PR a method into ``pyqtgraph`` to make this configurable
''' '''
linked = self.linked
if (
not linked
):
# print(f'{self.name} not linked but relay from {relayed_from.name}')
return
if axis in (0, 1): if axis in (0, 1):
mask = [False, False] mask = [False, False]
mask[axis] = self.state['mouseEnabled'][axis] mask[axis] = self.state['mouseEnabled'][axis]
else: else:
mask = self.state['mouseEnabled'][:] mask = self.state['mouseEnabled'][:]
chart = self.linked.chart chart = self.linkedsplits.chart
# don't zoom more than the min points setting # don't zoom more than the min points setting
l, lbar, rbar, r = chart.bars_range() l, lbar, rbar, r = chart.bars_range()
@ -605,20 +593,9 @@ class ChartView(ViewBox):
self, self,
ev, ev,
axis: Optional[int] = None, axis: Optional[int] = None,
# relayed_from: ChartView = None, relayed_from: ChartView = None,
) -> None: ) -> None:
# if relayed_from:
# print(f'PAN: {self.name} -> RELAYED FROM: {relayed_from.name}')
# NOTE since in the overlay case axes are already
# "linked" any x-range change will already be mirrored
# in all overlaid ``PlotItems``, so we need to simply
# ignore the signal here since otherwise we get N-calls
# from N-overlays resulting in an "accelerated" feeling
# panning motion instead of the expect linear shift.
# if relayed_from:
# return
pos = ev.pos() pos = ev.pos()
lastPos = ev.lastPos() lastPos = ev.lastPos()
@ -692,10 +669,7 @@ class ChartView(ViewBox):
# XXX: WHY # XXX: WHY
ev.accept() ev.accept()
try:
self.start_ic() self.start_ic()
except RuntimeError:
pass
# if self._ic is None: # if self._ic is None:
# self.chart.pause_all_feeds() # self.chart.pause_all_feeds()
# self._ic = trio.Event() # self._ic = trio.Event()
@ -787,7 +761,7 @@ class ChartView(ViewBox):
''' '''
name = self.name name = self.name
# print(f'YRANGE ON {name}') # print(f'YRANGE ON {name}')
profiler = Profiler( profiler = pg.debug.Profiler(
msg=f'`ChartView._set_yrange()`: `{name}`', msg=f'`ChartView._set_yrange()`: `{name}`',
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
ms_threshold=ms_slower_then, ms_threshold=ms_slower_then,
@ -856,33 +830,29 @@ class ChartView(ViewBox):
) -> None: ) -> None:
''' '''
Assign callbacks for rescaling and resampling y-axis data Assign callback for rescaling y-axis automatically
automatically based on data contents and ``ViewBox`` state. based on data contents and ``ViewBox`` state.
''' '''
if src_vb is None: if src_vb is None:
src_vb = self src_vb = self
# widget-UIs/splitter(s) resizing # splitter(s) resizing
src_vb.sigResized.connect(self._set_yrange) src_vb.sigResized.connect(self._set_yrange)
# re-sampling trigger:
# TODO: a smarter way to avoid calling this needlessly? # TODO: a smarter way to avoid calling this needlessly?
# 2 things i can think of: # 2 things i can think of:
# - register downsample-able graphics specially and only # - register downsample-able graphics specially and only
# iterate those. # iterate those.
# - only register this when certain downsample-able graphics are # - only register this when certain downsampleable graphics are
# "added to scene". # "added to scene".
src_vb.sigRangeChangedManually.connect( src_vb.sigRangeChangedManually.connect(
self.maybe_downsample_graphics self.maybe_downsample_graphics
) )
# mouse wheel doesn't emit XRangeChanged # mouse wheel doesn't emit XRangeChanged
src_vb.sigRangeChangedManually.connect(self._set_yrange) src_vb.sigRangeChangedManually.connect(self._set_yrange)
# XXX: enabling these will cause "jittery"-ness
# on zoom where sharp diffs in the y-range will
# not re-size right away until a new sample update?
# if src_vb is not self:
# src_vb.sigXRangeChanged.connect(self._set_yrange) # src_vb.sigXRangeChanged.connect(self._set_yrange)
# src_vb.sigXRangeChanged.connect( # src_vb.sigXRangeChanged.connect(
# self.maybe_downsample_graphics # self.maybe_downsample_graphics
@ -927,7 +897,8 @@ class ChartView(ViewBox):
self, self,
autoscale_overlays: bool = True, autoscale_overlays: bool = True,
): ):
profiler = Profiler(
profiler = pg.debug.Profiler(
msg=f'ChartView.maybe_downsample_graphics() for {self.name}', msg=f'ChartView.maybe_downsample_graphics() for {self.name}',
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
@ -941,12 +912,8 @@ class ChartView(ViewBox):
# TODO: a faster single-loop-iterator way of doing this XD # TODO: a faster single-loop-iterator way of doing this XD
chart = self._chart chart = self._chart
plots = {chart.name: chart} linked = self.linkedsplits
plots = linked.subplots | {chart.name: chart}
linked = self.linked
if linked:
plots |= linked.subplots
for chart_name, chart in plots.items(): for chart_name, chart in plots.items():
for name, flow in chart._flows.items(): for name, flow in chart._flows.items():
@ -956,7 +923,6 @@ class ChartView(ViewBox):
# XXX: super important to be aware of this. # XXX: super important to be aware of this.
# or not flow.graphics.isVisible() # or not flow.graphics.isVisible()
): ):
# print(f'skipping {flow.name}')
continue continue
# pass in no array which will read and render from the last # pass in no array which will read and render from the last

View File

@ -26,7 +26,6 @@ from PyQt5.QtCore import QPointF
from ._axes import YAxisLabel from ._axes import YAxisLabel
from ._style import hcolor from ._style import hcolor
from ._pg_overrides import PlotItem
class LevelLabel(YAxisLabel): class LevelLabel(YAxisLabel):
@ -133,7 +132,7 @@ class LevelLabel(YAxisLabel):
level = self.fields['level'] level = self.fields['level']
# map "level" to local coords # map "level" to local coords
abs_xy = self._pi.mapFromView(QPointF(0, level)) abs_xy = self._chart.mapFromView(QPointF(0, level))
self.update_label( self.update_label(
abs_xy, abs_xy,
@ -150,7 +149,7 @@ class LevelLabel(YAxisLabel):
h, w = self.set_label_str(fields) h, w = self.set_label_str(fields)
if self._adjust_to_l1: if self._adjust_to_l1:
self._x_offset = self._pi.chart_widget._max_l1_line_len self._x_offset = self._chart._max_l1_line_len
self.setPos(QPointF( self.setPos(QPointF(
self._h_shift * (w + self._x_offset), self._h_shift * (w + self._x_offset),
@ -237,10 +236,10 @@ class L1Label(LevelLabel):
# Set a global "max L1 label length" so we can # Set a global "max L1 label length" so we can
# look it up on order lines and adjust their # look it up on order lines and adjust their
# labels not to overlap with it. # labels not to overlap with it.
chart = self._pi.chart_widget chart = self._chart
chart._max_l1_line_len: float = max( chart._max_l1_line_len: float = max(
chart._max_l1_line_len, chart._max_l1_line_len,
w, w
) )
return h, w return h, w
@ -252,17 +251,17 @@ class L1Labels:
""" """
def __init__( def __init__(
self, self,
plotitem: PlotItem, chart: 'ChartPlotWidget', # noqa
digits: int = 2, digits: int = 2,
size_digits: int = 3, size_digits: int = 3,
font_size: str = 'small', font_size: str = 'small',
) -> None: ) -> None:
chart = self.chart = plotitem.chart_widget self.chart = chart
raxis = plotitem.getAxis('right') raxis = chart.getAxis('right')
kwargs = { kwargs = {
'chart': plotitem, 'chart': chart,
'parent': raxis, 'parent': raxis,
'opacity': 1, 'opacity': 1,

View File

@ -18,14 +18,9 @@
Lines for orders, alerts, L2. Lines for orders, alerts, L2.
""" """
from __future__ import annotations
from functools import partial from functools import partial
from math import floor from math import floor
from typing import ( from typing import Optional, Callable
Optional,
Callable,
TYPE_CHECKING,
)
import pyqtgraph as pg import pyqtgraph as pg
from pyqtgraph import Point, functions as fn from pyqtgraph import Point, functions as fn
@ -42,9 +37,6 @@ from ..calc import humanize
from ._label import Label from ._label import Label
from ._style import hcolor, _font from ._style import hcolor, _font
if TYPE_CHECKING:
from ._cursor import Cursor
# TODO: probably worth investigating if we can # TODO: probably worth investigating if we can
# make .boundingRect() faster: # make .boundingRect() faster:
@ -92,7 +84,7 @@ class LevelLine(pg.InfiniteLine):
self._marker = None self._marker = None
self.only_show_markers_on_hover = only_show_markers_on_hover self.only_show_markers_on_hover = only_show_markers_on_hover
self.track_marker_pos: bool = False self.show_markers: bool = True # presuming the line is hovered at init
# should line go all the way to far end or leave a "margin" # should line go all the way to far end or leave a "margin"
# space for other graphics (eg. L1 book) # space for other graphics (eg. L1 book)
@ -130,9 +122,6 @@ class LevelLine(pg.InfiniteLine):
self._y_incr_mult = 1 / chart.linked.symbol.tick_size self._y_incr_mult = 1 / chart.linked.symbol.tick_size
self._right_end_sc: float = 0 self._right_end_sc: float = 0
# use px caching
self.setCacheMode(QtWidgets.QGraphicsItem.DeviceCoordinateCache)
def txt_offsets(self) -> tuple[int, int]: def txt_offsets(self) -> tuple[int, int]:
return 0, 0 return 0, 0
@ -227,23 +216,20 @@ class LevelLine(pg.InfiniteLine):
y: float y: float
) -> None: ) -> None:
''' '''Chart coordinates cursor tracking callback.
Chart coordinates cursor tracking callback.
this is called by our ``Cursor`` type once this line is set to this is called by our ``Cursor`` type once this line is set to
track the cursor: for every movement this callback is invoked to track the cursor: for every movement this callback is invoked to
reposition the line with the current view coordinates. reposition the line with the current view coordinates.
''' '''
self.movable = True self.movable = True
self.set_level(y) # implicitly calls reposition handler self.set_level(y) # implicitly calls reposition handler
def mouseDragEvent(self, ev): def mouseDragEvent(self, ev):
''' """Override the ``InfiniteLine`` handler since we need more
Override the ``InfiniteLine`` handler since we need more
detailed control and start end signalling. detailed control and start end signalling.
''' """
cursor = self._chart.linked.cursor cursor = self._chart.linked.cursor
# hide y-crosshair # hide y-crosshair
@ -295,20 +281,10 @@ class LevelLine(pg.InfiniteLine):
# show y-crosshair again # show y-crosshair again
cursor.show_xhair() cursor.show_xhair()
def get_cursor(self) -> Optional[Cursor]:
chart = self._chart
cur = chart.linked.cursor
if self in cur._hovered:
return cur
return None
def delete(self) -> None: def delete(self) -> None:
''' """Remove this line from containing chart/view/scene.
Remove this line from containing chart/view/scene.
''' """
scene = self.scene() scene = self.scene()
if scene: if scene:
for label in self._labels: for label in self._labels:
@ -322,8 +298,9 @@ class LevelLine(pg.InfiniteLine):
# remove from chart/cursor states # remove from chart/cursor states
chart = self._chart chart = self._chart
cur = self.get_cursor() cur = chart.linked.cursor
if cur:
if self in cur._hovered:
cur._hovered.remove(self) cur._hovered.remove(self)
chart.plotItem.removeItem(self) chart.plotItem.removeItem(self)
@ -331,8 +308,8 @@ class LevelLine(pg.InfiniteLine):
def mouseDoubleClickEvent( def mouseDoubleClickEvent(
self, self,
ev: QtGui.QMouseEvent, ev: QtGui.QMouseEvent,
) -> None: ) -> None:
# TODO: enter labels edit mode # TODO: enter labels edit mode
print(f'double click {ev}') print(f'double click {ev}')
@ -357,22 +334,30 @@ class LevelLine(pg.InfiniteLine):
line_end, marker_right, r_axis_x = self._chart.marker_right_points() line_end, marker_right, r_axis_x = self._chart.marker_right_points()
# (legacy) NOTE: at one point this seemed slower when moving around if self.show_markers and self.markers:
# order lines.. not sure if that's still true or why but we've
# dropped the original hacky `.paint()` transform stuff for inf p.setPen(self.pen)
# line markers now - check the git history if it needs to be qgo_draw_markers(
# reverted. self.markers,
if self._marker: self.pen.color(),
if self.track_marker_pos: p,
# make the line end at the marker's x pos vb_left,
line_end = marker_right = self._marker.pos().x() vb_right,
marker_right,
)
# marker_size = self.markers[0][2]
self._maxMarkerSize = max([m[2] / 2. for m in self.markers])
# this seems slower when moving around
# order lines.. not sure wtf is up with that.
# for now we're just using it on the position line.
elif self._marker:
# TODO: make this label update part of a scene-aware-marker # TODO: make this label update part of a scene-aware-marker
# composed annotation # composed annotation
self._marker.setPos( self._marker.setPos(
QPointF(marker_right, self.scene_y()) QPointF(marker_right, self.scene_y())
) )
if hasattr(self._marker, 'label'): if hasattr(self._marker, 'label'):
self._marker.label.update() self._marker.label.update()
@ -394,14 +379,16 @@ class LevelLine(pg.InfiniteLine):
def hide(self) -> None: def hide(self) -> None:
super().hide() super().hide()
mkr = self._marker if self._marker:
if mkr: self._marker.hide()
mkr.hide() # needed for ``order_line()`` lines currently
self._marker.label.hide()
def show(self) -> None: def show(self) -> None:
super().show() super().show()
if self._marker: if self._marker:
self._marker.show() self._marker.show()
# self._marker.label.show()
def scene_y(self) -> float: def scene_y(self) -> float:
return self.getViewBox().mapFromView( return self.getViewBox().mapFromView(
@ -434,10 +421,6 @@ class LevelLine(pg.InfiniteLine):
return path return path
@property
def marker(self) -> LevelMarker:
return self._marker
def hoverEvent(self, ev): def hoverEvent(self, ev):
''' '''
Mouse hover callback. Mouse hover callback.
@ -446,16 +429,17 @@ class LevelLine(pg.InfiniteLine):
cur = self._chart.linked.cursor cur = self._chart.linked.cursor
# hovered # hovered
if ( if (not ev.isExit()) and ev.acceptDrags(QtCore.Qt.LeftButton):
not ev.isExit()
and ev.acceptDrags(QtCore.Qt.LeftButton)
):
# if already hovered we don't need to run again # if already hovered we don't need to run again
if self.mouseHovering is True: if self.mouseHovering is True:
return return
if self.only_show_markers_on_hover: if self.only_show_markers_on_hover:
self.show_markers() self.show_markers = True
if self._marker:
self._marker.show()
# highlight if so configured # highlight if so configured
if self.highlight_on_hover: if self.highlight_on_hover:
@ -498,7 +482,11 @@ class LevelLine(pg.InfiniteLine):
cur._hovered.remove(self) cur._hovered.remove(self)
if self.only_show_markers_on_hover: if self.only_show_markers_on_hover:
self.hide_markers() self.show_markers = False
if self._marker:
self._marker.hide()
self._marker.label.hide()
if self not in cur._trackers: if self not in cur._trackers:
cur.show_xhair(y_label_level=self.value()) cur.show_xhair(y_label_level=self.value())
@ -510,15 +498,6 @@ class LevelLine(pg.InfiniteLine):
self.update() self.update()
def hide_markers(self) -> None:
if self._marker:
self._marker.hide()
self._marker.label.hide()
def show_markers(self) -> None:
if self._marker:
self._marker.show()
def level_line( def level_line(
@ -539,10 +518,9 @@ def level_line(
**kwargs, **kwargs,
) -> LevelLine: ) -> LevelLine:
''' """Convenience routine to add a styled horizontal line to a plot.
Convenience routine to add a styled horizontal line to a plot.
''' """
hl_color = color + '_light' if highlight_on_hover else color hl_color = color + '_light' if highlight_on_hover else color
line = LevelLine( line = LevelLine(
@ -724,7 +702,7 @@ def order_line(
marker = LevelMarker( marker = LevelMarker(
chart=chart, chart=chart,
style=marker_style, style=marker_style,
get_level=line.value, # callback get_level=line.value,
size=marker_size, size=marker_size,
keep_in_view=False, keep_in_view=False,
) )
@ -733,8 +711,7 @@ def order_line(
marker = line.add_marker(marker) marker = line.add_marker(marker)
# XXX: DON'T COMMENT THIS! # XXX: DON'T COMMENT THIS!
# this fixes the artifact issue! # this fixes the artifact issue! .. of course, bounding rect stuff
# .. of course, bounding rect stuff
line._maxMarkerSize = marker_size line._maxMarkerSize = marker_size
assert line._marker is marker assert line._marker is marker
@ -755,8 +732,7 @@ def order_line(
if action != 'alert': if action != 'alert':
# add a partial position label if we also added a level # add a partial position label if we also added a level marker
# marker
pp_size_label = Label( pp_size_label = Label(
view=view, view=view,
color=line.color, color=line.color,
@ -790,9 +766,9 @@ def order_line(
# XXX: without this the pp proportion label next the marker # XXX: without this the pp proportion label next the marker
# seems to lag? this is the same issue we had with position # seems to lag? this is the same issue we had with position
# lines which we handle with ``.update_graphics()``. # lines which we handle with ``.update_graphics()``.
# marker._on_paint=lambda marker: pp_size_label.update()
marker._on_paint = lambda marker: pp_size_label.update() marker._on_paint = lambda marker: pp_size_label.update()
# XXX: THIS IS AN UNTYPED MONKEY PATCH!?!?!
marker.label = label marker.label = label
# sanity check # sanity check

View File

@ -1,104 +0,0 @@
# piker: trading gear for hackers
# Copyright (C) Tyler Goodlet (in stewardship for piker0)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Notifications utils.
"""
import os
import platform
import subprocess
from typing import Optional
import trio
from ..log import get_logger
from ..clearing._messages import (
Status,
)
log = get_logger(__name__)
_dbus_uid: Optional[str] = ''
async def notify_from_ems_status_msg(
msg: Status,
duration: int = 3000,
is_subproc: bool = False,
) -> None:
'''
Send a linux desktop notification.
Handle subprocesses by discovering the dbus user id
on first call.
'''
if platform.system() != "Linux":
return
# TODO: this in another task?
# not sure if this will ever be a bottleneck,
# we probably could do graphics stuff first tho?
if is_subproc:
global _dbus_uid
su = os.environ.get('SUDO_USER')
if (
not _dbus_uid
and su
):
# TODO: use `trio` but we need to use nursery.start()
# to use pipes?
# result = await trio.run_process(
result = subprocess.run(
[
'id',
'-u',
su,
],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
# check=True
)
_dbus_uid = result.stdout.decode("utf-8").replace('\n', '')
os.environ['DBUS_SESSION_BUS_ADDRESS'] = (
f'unix:path=/run/user/{_dbus_uid}/bus'
)
result = await trio.run_process(
[
'notify-send',
'-u', 'normal',
'-t', f'{duration}',
'piker',
# TODO: add in standard fill/exec info that maybe we
# pack in a broker independent way?
f"'{msg.pformat()}'",
],
capture_stdout=True,
capture_stderr=True,
check=False,
)
if result.returncode != 0:
log.warn(f'No notification daemon installed, stderr: {result.stderr}')
log.runtime(result)
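A stand-alone usage sketch of the same `notify-send` approach (the function and message here are illustrative, not part of the removed module):

import trio

async def notify(summary: str, body: str, duration_ms: int = 3000) -> None:
    result = await trio.run_process(
        ['notify-send', '-u', 'normal', '-t', str(duration_ms), summary, body],
        capture_stdout=True,
        capture_stderr=True,
        check=False,  # a missing notification daemon shouldn't crash us
    )
    if result.returncode != 0:
        print(f'notify-send failed: {result.stderr.decode()}')

trio.run(notify, 'piker', 'fill: 100 xyz @ 10.00')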

View File

@ -25,21 +25,13 @@ from typing import (
import numpy as np import numpy as np
import pyqtgraph as pg import pyqtgraph as pg
from PyQt5 import ( from PyQt5 import QtCore, QtGui, QtWidgets
QtGui, from PyQt5.QtCore import QLineF, QPointF
QtWidgets,
)
from PyQt5.QtCore import (
QLineF,
QRectF,
)
from PyQt5.QtGui import QPainterPath from PyQt5.QtGui import QPainterPath
from .._profile import pg_profile_enabled, ms_slower_then from .._profile import pg_profile_enabled, ms_slower_then
from ._style import hcolor from ._style import hcolor
from ..log import get_logger from ..log import get_logger
from .._profile import Profiler
if TYPE_CHECKING: if TYPE_CHECKING:
from ._chart import LinkedSplits from ._chart import LinkedSplits
@ -98,8 +90,8 @@ class BarItems(pg.GraphicsObject):
self, self,
linked: LinkedSplits, linked: LinkedSplits,
plotitem: 'pg.PlotItem', # noqa plotitem: 'pg.PlotItem', # noqa
color: str = 'bracket', pen_color: str = 'bracket',
last_bar_color: str = 'original', last_bar_color: str = 'bracket',
name: Optional[str] = None, name: Optional[str] = None,
@ -108,8 +100,8 @@ class BarItems(pg.GraphicsObject):
self.linked = linked self.linked = linked
# XXX: for the mega-lulz increasing width here increases draw # XXX: for the mega-lulz increasing width here increases draw
# latency... so probably don't do it until we figure that out. # latency... so probably don't do it until we figure that out.
self._color = color self._color = pen_color
self.bars_pen = pg.mkPen(hcolor(color), width=1) self.bars_pen = pg.mkPen(hcolor(pen_color), width=1)
self.last_bar_pen = pg.mkPen(hcolor(last_bar_color), width=2) self.last_bar_pen = pg.mkPen(hcolor(last_bar_color), width=2)
self._name = name self._name = name
@ -121,13 +113,8 @@ class BarItems(pg.GraphicsObject):
# we expect the downsample curve to report this. # we expect the downsample curve to report this.
return 0 return 0
# Qt docs: https://doc.qt.io/qt-5/qgraphicsitem.html#boundingRect
def boundingRect(self): def boundingRect(self):
# profiler = Profiler( # Qt docs: https://doc.qt.io/qt-5/qgraphicsitem.html#boundingRect
# msg=f'BarItems.boundingRect(): `{self._name}`',
# disabled=not pg_profile_enabled(),
# ms_threshold=ms_slower_then,
# )
# TODO: Can we do rect caching to make this faster # TODO: Can we do rect caching to make this faster
# like `pg.PlotCurveItem` does? In theory it's just # like `pg.PlotCurveItem` does? In theory it's just
@ -147,37 +134,32 @@ class BarItems(pg.GraphicsObject):
hb.topLeft(), hb.topLeft(),
hb.bottomRight(), hb.bottomRight(),
) )
mn_y = hb_tl.y()
mx_y = hb_br.y()
most_left = hb_tl.x()
most_right = hb_br.x()
# profiler('calc path vertices')
# need to include last bar height or BR will be off # need to include last bar height or BR will be off
# OHLC line segments: [hl, o, c] mx_y = hb_br.y()
last_lines: tuple[QLineF] | None = self._last_bar_lines mn_y = hb_tl.y()
last_lines = self._last_bar_lines
if last_lines: if last_lines:
( body_line = self._last_bar_lines[0]
hl, if body_line:
o, mx_y = max(mx_y, max(body_line.y1(), body_line.y2()))
c, mn_y = min(mn_y, min(body_line.y1(), body_line.y2()))
) = last_lines
most_right = c.x2() + 1
ymx = ymn = c.y2()
if hl: return QtCore.QRectF(
y1, y2 = hl.y1(), hl.y2()
ymn = min(y1, y2)
ymx = max(y1, y2)
mx_y = max(ymx, mx_y)
mn_y = min(ymn, mn_y)
# profiler('calc last bar vertices')
return QRectF( # top left
most_left, QPointF(
hb_tl.x(),
mn_y, mn_y,
most_right - most_left + 1, ),
mx_y - mn_y,
# bottom right
QPointF(
hb_br.x() + 1,
mx_y,
)
) )
def paint( def paint(
@ -188,7 +170,7 @@ class BarItems(pg.GraphicsObject):
) -> None: ) -> None:
profiler = Profiler( profiler = pg.debug.Profiler(
disabled=not pg_profile_enabled(), disabled=not pg_profile_enabled(),
ms_threshold=ms_slower_then, ms_threshold=ms_slower_then,
) )
@ -230,15 +212,11 @@ class BarItems(pg.GraphicsObject):
# relevant fields # relevant fields
ohlc = src_data[fields] ohlc = src_data[fields]
# last_row = ohlc[-1:] last_row = ohlc[-1:]
# individual values # individual values
last_row = i, o, h, l, last = ohlc[-1] last_row = i, o, h, l, last = ohlc[-1]
# times = src_data['time']
# if times[-1] - times[-2]:
# breakpoint()
# generate new lines objects for updatable "current bar" # generate new lines objects for updatable "current bar"
self._last_bar_lines = bar_from_ohlc_row(last_row) self._last_bar_lines = bar_from_ohlc_row(last_row)
@ -269,5 +247,4 @@ class BarItems(pg.GraphicsObject):
# date / from some previous sample. It's weird though # date / from some previous sample. It's weird though
# because i've seen it do this to bars i - 3 back? # because i've seen it do this to bars i - 3 back?
# return ohlc['time'], ohlc['close']
return ohlc['index'], ohlc['close'] return ohlc['index'], ohlc['close']

View File

@ -22,9 +22,12 @@ from __future__ import annotations
from typing import ( from typing import (
Optional, Generic, Optional, Generic,
TypeVar, Callable, TypeVar, Callable,
Literal,
) )
import enum
import sys
# from pydantic import BaseModel, validator from pydantic import BaseModel, validator
from pydantic.generics import GenericModel from pydantic.generics import GenericModel
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import (
QWidget, QWidget,
@ -35,7 +38,6 @@ from ._forms import (
# FontScaledDelegate, # FontScaledDelegate,
Edit, Edit,
) )
from ..data.types import Struct
DataType = TypeVar('DataType') DataType = TypeVar('DataType')
@ -60,7 +62,7 @@ class Selection(Field[DataType], Generic[DataType]):
options: dict[str, DataType] options: dict[str, DataType]
# value: DataType = None # value: DataType = None
# @validator('value') # , always=True) @validator('value') # , always=True)
def set_value_first( def set_value_first(
cls, cls,
@ -98,7 +100,7 @@ class Edit(Field[DataType], Generic[DataType]):
widget_factory = Edit widget_factory = Edit
class AllocatorPane(Struct): class AllocatorPane(BaseModel):
account = Selection[str]( account = Selection[str](
options=dict.fromkeys( options=dict.fromkeys(

View File

@ -18,27 +18,23 @@
Charting overlay helpers. Charting overlay helpers.
''' '''
from collections import defaultdict from typing import Callable, Optional
from functools import partial
from typing import ( from pyqtgraph.Qt.QtCore import (
Callable, # QObject,
Optional, # Signal,
Qt,
# QEvent,
) )
from pyqtgraph.graphicsItems.AxisItem import AxisItem from pyqtgraph.graphicsItems.AxisItem import AxisItem
from pyqtgraph.graphicsItems.ViewBox import ViewBox from pyqtgraph.graphicsItems.ViewBox import ViewBox
# from pyqtgraph.graphicsItems.GraphicsWidget import GraphicsWidget from pyqtgraph.graphicsItems.GraphicsWidget import GraphicsWidget
from pyqtgraph.graphicsItems.PlotItem.PlotItem import PlotItem from pyqtgraph.graphicsItems.PlotItem.PlotItem import PlotItem
from pyqtgraph.Qt.QtCore import ( from pyqtgraph.Qt.QtCore import QObject, Signal, QEvent
QObject, from pyqtgraph.Qt.QtWidgets import QGraphicsGridLayout, QGraphicsLinearLayout
Signal,
QEvent, from ._interaction import ChartView
Qt,
)
from pyqtgraph.Qt.QtWidgets import (
# QGraphicsGridLayout,
QGraphicsLinearLayout,
)
__all__ = ["PlotItemOverlay"] __all__ = ["PlotItemOverlay"]
@ -84,8 +80,8 @@ class ComposedGridLayout:
``<axis_name>i`` in the layout. ``<axis_name>i`` in the layout.
The ``item: PlotItem`` passed to the constructor's grid layout is The ``item: PlotItem`` passed to the constructor's grid layout is
used verbatim as the "main plot" who's view box is given precedence used verbatim as the "main plot" who's view box is give precedence
for input handling. The main plot's axes are removed from its for input handling. The main plot's axes are removed from it's
layout and placed in the surrounding exterior layouts to allow for layout and placed in the surrounding exterior layouts to allow for
re-ordering if desired. re-ordering if desired.
@ -93,11 +89,16 @@ class ComposedGridLayout:
def __init__( def __init__(
self, self,
item: PlotItem, item: PlotItem,
grid: QGraphicsGridLayout,
reverse: bool = False, # insert items to the "center"
) -> None: ) -> None:
self.items: list[PlotItem] = [] self.items: list[PlotItem] = []
self._pi2axes: dict[ # TODO: use a ``bidict`` here? # self.grid = grid
self.reverse = reverse
# TODO: use a ``bidict`` here?
self._pi2axes: dict[
int, int,
dict[str, AxisItem], dict[str, AxisItem],
] = {} ] = {}
@ -119,13 +120,12 @@ class ComposedGridLayout:
if name in ('top', 'bottom'): if name in ('top', 'bottom'):
orient = Qt.Vertical orient = Qt.Vertical
elif name in ('left', 'right'): elif name in ('left', 'right'):
orient = Qt.Horizontal orient = Qt.Horizontal
layout.setOrientation(orient) layout.setOrientation(orient)
self.insert_plotitem(0, item) self.insert(0, item)
# insert surrounding linear layouts into the parent pi's layout # insert surrounding linear layouts into the parent pi's layout
# such that additional axes can be appended arbitrarily without # such that additional axes can be appended arbitrarily without
@ -159,7 +159,7 @@ class ComposedGridLayout:
# enter plot into list for index tracking # enter plot into list for index tracking
self.items.insert(index, plotitem) self.items.insert(index, plotitem)
def insert_plotitem( def insert(
self, self,
index: int, index: int,
plotitem: PlotItem, plotitem: PlotItem,
@ -171,9 +171,7 @@ class ComposedGridLayout:
''' '''
if index < 0: if index < 0:
raise ValueError( raise ValueError('`insert()` only supports an index >= 0')
'`.insert_plotitem()` only supports an index >= 0'
)
# add plot's axes in sequence to the embedded linear layouts # add plot's axes in sequence to the embedded linear layouts
# for each "side" thus avoiding graphics collisions. # for each "side" thus avoiding graphics collisions.
@ -222,7 +220,7 @@ class ComposedGridLayout:
return index return index
def append_plotitem( def append(
self, self,
item: PlotItem, item: PlotItem,
@ -234,7 +232,7 @@ class ComposedGridLayout:
''' '''
# for left and bottom axes we have to first remove # for left and bottom axes we have to first remove
# items and re-insert to maintain a list-order. # items and re-insert to maintain a list-order.
return self.insert_plotitem(len(self.items), item) return self.insert(len(self.items), item)
def get_axis( def get_axis(
self, self,
@ -251,16 +249,16 @@ class ComposedGridLayout:
named = self._pi2axes[name] named = self._pi2axes[name]
return named.get(index) return named.get(index)
# def pop( def pop(
# self, self,
# item: PlotItem, item: PlotItem,
# ) -> PlotItem: ) -> PlotItem:
# ''' '''
# Remove item and restack all axes in list-order. Remove item and restack all axes in list-order.
# ''' '''
# raise NotImplementedError raise NotImplementedError
# Unimplemented features TODO: # Unimplemented features TODO:
@ -281,6 +279,194 @@ class ComposedGridLayout:
# axis? # axis?
# TODO: we might want to enabled some kind of manual flag to disable
# this method wrapping during type creation? As example a user could
# definitively decide **not** to enable broadcasting support by
# setting something like ``ViewBox.disable_relays = True``?
def mk_relay_method(
signame: str,
slot: Callable[
[ViewBox,
'QEvent',
Optional[AxisItem]],
None,
],
) -> Callable[
[
ViewBox,
# lol, there isn't really a generic type thanks
# to the rewrite of Qt's event system XD
'QEvent',
'Optional[AxisItem]',
'Optional[ViewBox]', # the ``relayed_from`` arg we provide
],
None,
]:
def maybe_broadcast(
vb: 'ViewBox',
ev: 'QEvent',
axis: 'Optional[int]' = None,
relayed_from: 'ViewBox' = None,
) -> None:
'''
(soon to be) Decorator which makes an event handler
"broadcastable" to overlayed ``GraphicsWidget``s.
Adds relay signals based on the decorated handler's name
and conducts a signal broadcast of the relay signal if there
are consumers registered.
'''
# When no relay source has been set just bypass all
# the broadcast machinery.
if vb.event_relay_source is None:
ev.accept()
return slot(
vb,
ev,
axis=axis,
)
if relayed_from:
assert axis is None
# this is a relayed event and should be ignored (so it does not
# halt/short circuit the graphics-scene loop). Further the
# surrounding handler for this signal must be allowed to execute
# and get processed by **this consumer**.
# print(f'{vb.name} rx relayed from {relayed_from.name}')
ev.ignore()
return slot(
vb,
ev,
axis=axis,
)
if axis is not None:
# print(f'{vb.name} handling axis event:\n{str(ev)}')
ev.accept()
return slot(
vb,
ev,
axis=axis,
)
elif (
relayed_from is None
and vb.event_relay_source is vb # we are the broadcaster
and axis is None
):
# Broadcast case: this is a source event which will be
# relayed to attached consumers and accepted after all
# consumers complete their own handling followed by this
# routine's processing. Sequence is,
# - pre-relay to all consumers *first* - ``.emit()`` blocks
# until all downstream relay handlers have run.
# - run the source handler for **this** event and accept
# the event
# Access the "bound signal" that is created
# on the widget type as part of instantiation.
signal = getattr(vb, signame)
# print(f'{vb.name} emitting {signame}')
# TODO/NOTE: we could also just bypass a "relay" signal
# entirely and instead call the handlers manually in
# a loop? This probably is a lot simpler and also doesn't
# have any downside, and allows not touching target widget
# internals.
signal.emit(
ev,
axis,
# passing this demarks a broadcasted/relayed event
vb,
)
# accept event so no more relays are fired.
ev.accept()
# call underlying wrapped method with an extra
# ``relayed_from`` value to denote that this is a relayed
# event handling case.
return slot(
vb,
ev,
axis=axis,
)
return maybe_broadcast
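The TODO above about bypassing relay signals and "calling the handlers manually in a loop" is essentially what the left-hand branch ends up doing; a condensed sketch of that plain-loop broadcast (names assumed):

from typing import Callable

def make_broadcaster(
    src_handler: Callable,
    consumers: list[Callable],
) -> Callable:

    def broadcast(ev, axis=None, **kwargs):
        ev.accept()  # halt any further native dispatch
        # relay to every consumer *first*, mirroring the blocking
        # semantics of ``Signal.emit()``
        for handler in consumers:
            handler(ev, axis=axis, **kwargs)
        # then run the source widget's own handler
        return src_handler(ev, axis=axis, **kwargs)

    return broadcast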
# XXX: :( can't define signals **after** class compile time
# so this is not really useful.
# def mk_relay_signal(
# func,
# name: str = None,
# ) -> Signal:
# (
# args,
# varargs,
# varkw,
# defaults,
# kwonlyargs,
# kwonlydefaults,
# annotations
# ) = inspect.getfullargspec(func)
# # XXX: generate a relay signal with 1 extra
# # argument for a ``relayed_from`` kwarg. Since
# # ``'self'`` is already ignored by signals we just need
# # to count the arguments since we're adding only 1 (and
# # ``args`` will capture that).
# numargs = len(args + list(defaults))
# signal = Signal(*tuple(numargs * [object]))
# signame = name or func.__name__ + 'Relay'
# return signame, signal
def enable_relays(
widget: GraphicsWidget,
handler_names: list[str],
) -> list[Signal]:
'''
Method override helper which enables relay of a particular
``Signal`` from some chosen broadcaster widget to a set of
consumer widgets which should operate their event handlers normally
but driven instead by signals "relayed" from the broadcaster.
Mostly useful for overlaying widgets that handle user input
that you want to overlay graphically. The target ``widget`` type must
define ``QtCore.Signal``s each with a `'Relay'` suffix for each
name provided in ``handler_names: list[str]``.
'''
signals = []
for name in handler_names:
handler = getattr(widget, name)
signame = name + 'Relay'
# ensure the target widget defines a relay signal
relay = getattr(widget, signame)
widget.relays[signame] = name
signals.append(relay)
method = mk_relay_method(signame, handler)
setattr(widget, name, method)
return signals
enable_relays(
ChartView,
['wheelEvent', 'mouseDragEvent']
)
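On the right-hand branch the wiring between a root view and a consumer view reduces to a bound-signal lookup plus `connect()`; a sketch with assumed variable names:

# Signal class attrs only become *bound* signals per-instance, so the
# lookup must happen on the root ViewBox instance, not the class:
signal = getattr(root_vb, 'wheelEventRelay')
handler = getattr(consumer_vb, 'wheelEvent')
signal.connect(handler)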
class PlotItemOverlay: class PlotItemOverlay:
''' '''
A composite for managing overlaid ``PlotItem`` instances such that A composite for managing overlaid ``PlotItem`` instances such that
@ -296,18 +482,16 @@ class PlotItemOverlay:
) -> None: ) -> None:
self.root_plotitem: PlotItem = root_plotitem self.root_plotitem: PlotItem = root_plotitem
self.relay_handlers: defaultdict[
str,
list[Callable],
] = defaultdict(list)
# NOTE: required for scene layering/relaying; this guarantees vb = root_plotitem.vb
# the "root" plot receives priority for interaction vb.event_relay_source = vb # TODO: maybe change name?
# events/signals. vb.setZValue(1000) # XXX: critical for scene layering/relaying
root_plotitem.vb.setZValue(10)
self.overlays: list[PlotItem] = [] self.overlays: list[PlotItem] = []
self.layout = ComposedGridLayout(root_plotitem) self.layout = ComposedGridLayout(
root_plotitem,
root_plotitem.layout,
)
self._relays: dict[str, Signal] = {} self._relays: dict[str, Signal] = {}
def add_plotitem( def add_plotitem(
@ -315,10 +499,8 @@ class PlotItemOverlay:
plotitem: PlotItem, plotitem: PlotItem,
index: Optional[int] = None, index: Optional[int] = None,
# event/signal names which will be broadcasted to all added # TODO: we could also put the ``ViewBox.XAxis``
# (relayee) ``PlotItem``s (eg. ``ViewBox.mouseDragEvent``). # style enum here?
relay_events: list[str] = [],
# (0,), # link x # (0,), # link x
# (1,), # link y # (1,), # link y
# (0, 1), # link both # (0, 1), # link both
@ -328,155 +510,58 @@ class PlotItemOverlay:
index = index or len(self.overlays) index = index or len(self.overlays)
root = self.root_plotitem root = self.root_plotitem
# layout: QGraphicsGridLayout = root.layout
self.overlays.insert(index, plotitem) self.overlays.insert(index, plotitem)
vb: ViewBox = plotitem.vb vb: ViewBox = plotitem.vb
# mark this consumer overlay as ready to expect relayed events
# from the root plotitem.
vb.event_relay_source = root.vb
# TODO: some sane way to allow menu event broadcast XD # TODO: some sane way to allow menu event broadcast XD
# vb.setMenuEnabled(False) # vb.setMenuEnabled(False)
# wire up any relay signal(s) from the source plot to added # TODO: inside the `maybe_broadcast()` (soon to be) decorator
# "overlays". We use a plain loop instead of mucking with # we need have checks that consumers have been attached to
# re-connecting signal/slots which tends to be more invasive and # these relay signals.
# harder to implement and provides no measurable performance if link_axes != (0, 1):
# gain.
if relay_events:
for ev_name in relay_events:
relayee_handler: Callable[
[
ViewBox,
# lol, there isn't really a generic type thanks
# to the rewrite of Qt's event system XD
QEvent,
AxisItem | None, # wire up relay signals
], for relay_signal_name, handler_name in vb.relays.items():
None, # print(handler_name)
] = getattr(vb, ev_name) # XXX: Signal class attrs are bound after instantiation
# of the defining type, so we need to access that bound
sub_handlers: list[Callable] = self.relay_handlers[ev_name] # version here.
signal = getattr(root.vb, relay_signal_name)
# on the first registry of a relayed event we pop the handler = getattr(vb, handler_name)
# root's handler and override it to a custom broadcaster signal.connect(handler)
# routine.
if not sub_handlers:
src_handler = getattr(
root.vb,
ev_name,
)
def broadcast(
ev: 'QEvent',
# TODO: drop this viewbox specific input and
# allow a predicate to be passed in by user.
axis: 'Optional[int]' = None,
*,
# these are bound in by the ``partial`` below
# and ensure a unique broadcaster per event.
ev_name: str = None,
src_handler: Callable = None,
relayed_from: 'ViewBox' = None,
# remaining inputs the source handler expects
**kwargs,
) -> None:
'''
Broadcast signal or event: this is a source
event which will be relayed to attached
"relayee" plot item consumers.
The event is accepted halting any further
handlers from being triggered.
Sequence is,
- pre-relay to all consumers *first* - exactly
like how a ``Signal.emit()`` blocks until all
downstream relay handlers have run.
- run the event's source handler event
'''
ev.accept()
# broadcast first to relayees *first*. trigger
# relay of event to all consumers **before**
# processing/consumption in the source handler.
relayed_handlers = self.relay_handlers[ev_name]
assert getattr(vb, ev_name).__name__ == ev_name
# TODO: generalize as an input predicate
if axis is None:
for handler in relayed_handlers:
handler(
ev,
axis=axis,
**kwargs,
)
# run "source" widget's handler last
src_handler(
ev,
axis=axis,
)
# dynamic handler override on the publisher plot
setattr(
root.vb,
ev_name,
partial(
broadcast,
ev_name=ev_name,
src_handler=src_handler
),
)
else:
assert getattr(root.vb, ev_name)
assert relayee_handler not in sub_handlers
# append relayed-to widget's handler to relay table
sub_handlers.append(relayee_handler)
# link dim-axes to root if requested by user. # link dim-axes to root if requested by user.
# TODO: solve more-then-wanted scaled panning on click drag
# which seems to be due to broadcast. So we probably need to
# disable broadcast when axes are linked in a particular
# dimension?
for dim in link_axes: for dim in link_axes:
# link x and y axes to new view box such that the top level # link x and y axes to new view box such that the top level
# viewbox propagates to the root (and whatever other # viewbox propagates to the root (and whatever other
# plotitem overlays that have been added). # plotitem overlays that have been added).
vb.linkView(dim, root.vb) vb.linkView(dim, root.vb)
# => NOTE: in order to prevent "more-then-linear" scaled # make overlaid viewbox impossible to focus since the top
# panning moves on (for eg. click-drag) certain range change # level should handle all input and relay to overlays.
# signals (i.e. ``.sigXRangeChanged``), the user needs to be # NOTE: this was solved with the `setZValue()` above!
# careful that any broadcasted ``relay_events`` are are short
# circuited in sub-handlers (aka relayee's) implementations. As
# an example if a ``ViewBox.mouseDragEvent`` is broadcasted, the
# overlayed implementations need to be sure they either don't
# also link the x-axes (by not providing ``link_axes=(0,)``
# above) or that the relayee ``.mouseDragEvent()`` handlers are
# ready to "``return`` early" in the case that
# ``.sigXRangeChanged`` is emitted as part of linked axes.
# For more details on such signalling mechanics peek in
# ``ViewBox.linkView()``.
# make overlaid viewbox impossible to focus since the top level # TODO: we will probably want to add a "focus" api such that
# should handle all input and relay to overlays. Note that the # a new "top level" ``PlotItem`` can be selected dynamically
# "root" plot item gettingn interaction priority is configured # (and presumably the axes dynamically sorted to match).
# with the ``.setZValue()`` during init.
vb.setFlag( vb.setFlag(
vb.GraphicsItemFlag.ItemIsFocusable, vb.GraphicsItemFlag.ItemIsFocusable,
False False
) )
vb.setFocusPolicy(Qt.NoFocus) vb.setFocusPolicy(Qt.NoFocus)
# => TODO: add a "focus" api for switching the "top level"
# ``PlotItem`` dynamically.
# append-compose into the layout all axes from this plot # append-compose into the layout all axes from this plot
self.layout.insert_plotitem(index, plotitem) self.layout.insert(index, plotitem)
plotitem.setGeometry(root.vb.sceneBoundingRect()) plotitem.setGeometry(root.vb.sceneBoundingRect())
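A hypothetical end-to-end call of the left-hand branch's `add_plotitem()` (the `relay_events` and `link_axes` parameter names are taken from the code above; exact call sites may differ):

import pyqtgraph as pg

root = pg.PlotItem()
overlay = pg.PlotItem()

pio = PlotItemOverlay(root)
pio.add_plotitem(
    overlay,
    link_axes=(0,),  # link x only; see the panning caveat above
    relay_events=['wheelEvent', 'mouseDragEvent'],
)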
@ -494,7 +579,24 @@ class PlotItemOverlay:
root.vb.setFocus() root.vb.setFocus()
assert root.vb.focusWidget() assert root.vb.focusWidget()
vb.setZValue(100) # XXX: do we need this? Why would you build then destroy?
def remove_plotitem(self, plotItem: PlotItem) -> None:
'''
Remove this ``PlotItem`` from the overlayed set, making it not shown
and unable to accept input.
'''
...
# TODO: i think this would be super hot B)
def focus_item(self, plotitem: PlotItem) -> PlotItem:
'''
Apply focus to a contained PlotItem thus making it the "top level"
item in the overlay able to accept peripheral's input from the user
and responsible for zoom and panning control via its ``ViewBox``.
'''
...
def get_axis( def get_axis(
self, self,
@ -528,9 +630,8 @@ class PlotItemOverlay:
return axes return axes
# XXX: untested as of now. # TODO: i guess we need this if you want to detach existing plots
# TODO: need this as part of selecting a different root/source # dynamically? XXX: untested as of now.
# plot to rewire interaction event broadcast dynamically.
def _disconnect_all( def _disconnect_all(
self, self,
plotitem: PlotItem, plotitem: PlotItem,
@ -545,22 +646,3 @@ class PlotItemOverlay:
disconnected.append(sig) disconnected.append(sig)
return disconnected return disconnected
# XXX: do we need this? Why would you build then destroy?
# def remove_plotitem(self, plotItem: PlotItem) -> None:
# '''
# Remove this ``PlotItem`` from the overlayed set, making it not shown
# and unable to accept input.
# '''
# ...
# TODO: i think this would be super hot B)
# def focus_plotitem(self, plotitem: PlotItem) -> PlotItem:
# '''
# Apply focus to a contained PlotItem thus making it the "top level"
# item in the overlay able to accept peripheral's input from the user
# and responsible for zoom and panning control via its ``ViewBox``.
# '''
# ...

View File

@ -19,16 +19,15 @@ Super fast ``QPainterPath`` generation related operator routines.
""" """
from __future__ import annotations from __future__ import annotations
from typing import ( from typing import (
Optional, # Optional,
TYPE_CHECKING, TYPE_CHECKING,
) )
import msgspec
import numpy as np import numpy as np
from numpy.lib import recfunctions as rfn from numpy.lib import recfunctions as rfn
from numba import njit, float64, int64 # , optional from numba import njit, float64, int64 # , optional
# import pyqtgraph as pg # import pyqtgraph as pg
# from PyQt5 import QtGui from PyQt5 import QtGui
# from PyQt5.QtCore import QLineF, QPointF # from PyQt5.QtCore import QLineF, QPointF
from ..data._sharedmem import ( from ..data._sharedmem import (
@ -40,514 +39,53 @@ from ._compression import (
) )
if TYPE_CHECKING: if TYPE_CHECKING:
from ._flows import ( from ._flows import Renderer
Renderer,
Flow,
)
from .._profile import Profiler
def by_index_and_key( def xy_downsample(
renderer: Renderer, x,
array: np.ndarray, y,
array_key: str, uppx,
vr: tuple[int, int],
) -> tuple[ x_spacer: float = 0.5,
np.ndarray,
np.ndarray,
np.ndarray,
]:
return array['index'], array[array_key], 'all'
) -> tuple[np.ndarray, np.ndarray]:
class IncrementalFormatter(msgspec.Struct): # downsample whenever more then 1 pixels per datum can be shown.
''' # always refresh data bounds until we get diffing
Incrementally updating, pre-path-graphics tracking, formatter. # working properly, see above..
bins, x, y = ds_m4(
Allows tracking source data state in an updateable pre-graphics x,
``np.ndarray`` format (in local process memory) as well as y,
incrementally rendering from that format **to** 1d x/y for path uppx,
generation using ``pg.functions.arrayToQPath()``.
'''
shm: ShmArray
flow: Flow
# last read from shm (usually due to an update call)
_last_read: tuple[
int,
int,
np.ndarray
]
@property
def last_read(self) -> tuple | None:
return self._last_read
def __repr__(self) -> str:
msg = (
f'{type(self)}: ->\n\n'
f'fqsn={self.flow.name}\n'
f'shm_name={self.shm.token["shm_name"]}\n\n'
f'last_vr={self._last_vr}\n'
f'last_ivdr={self._last_ivdr}\n\n'
f'xy_nd_start={self.xy_nd_start}\n'
f'xy_nd_stop={self.xy_nd_stop}\n\n'
) )
x_nd_len = 0 # flatten output to 1d arrays suitable for path-graphics generation.
y_nd_len = 0 x = np.broadcast_to(x[:, None], y.shape)
if self.x_nd is not None: x = (x + np.array(
x_nd_len = len(self.x_nd) [-x_spacer, 0, 0, x_spacer]
y_nd_len = len(self.y_nd) )).flatten()
y = y.flatten()
msg += ( return x, y
f'x_nd_len={x_nd_len}\n'
f'y_nd_len={y_nd_len}\n'
)
return msg
def diff( @njit(
self, # TODO: for now need to construct this manually for readonly arrays, see
new_read: tuple[np.ndarray], # https://github.com/numba/numba/issues/4511
) -> tuple[
np.ndarray,
np.ndarray,
]:
(
last_xfirst,
last_xlast,
last_array,
last_ivl,
last_ivr,
last_in_view,
) = self.last_read
# TODO: can the renderer just call ``Flow.read()`` directly?
# unpack latest source data read
(
xfirst,
xlast,
array,
ivl,
ivr,
in_view,
) = new_read
# compute the length diffs between the first/last index entry in
# the input data and the last indexes we have on record from the
# last time we updated the curve index.
prepend_length = int(last_xfirst - xfirst)
append_length = int(xlast - last_xlast)
# blah blah blah
# do diffing for prepend, append and last entry
return (
slice(xfirst, last_xfirst),
prepend_length,
append_length,
slice(last_xlast, xlast),
)
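The prepend/append arithmetic in `diff()` reduces to two subtractions over the absolute index bounds; a toy run with assumed values:

last_xfirst, last_xlast = 100, 200   # index bounds from the previous read
xfirst, xlast = 98, 205              # index bounds from the new read

prepend_length = int(last_xfirst - xfirst)  # 2 rows added at the front
append_length = int(xlast - last_xlast)     # 5 rows appended at the back
pre_slice = slice(xfirst, last_xfirst)      # rows to prepend: [98, 100)
post_slice = slice(last_xlast, xlast)       # rows to append: [200, 205)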
# Incrementally updated xy ndarray formatted data, a pre-1d
# format which is updated and cached independently of the final
# pre-graphics-path 1d format.
x_nd: Optional[np.ndarray] = None
y_nd: Optional[np.ndarray] = None
# indexes which slice into the above arrays (which are allocated
# based on source data shm input size) and allow retrieving
# incrementally updated data.
xy_nd_start: int = 0
xy_nd_stop: int = 0
# TODO: eventually incrementally update 1d-pre-graphics path data?
# x_1d: Optional[np.ndarray] = None
# y_1d: Optional[np.ndarray] = None
# incremental view-change state(s) tracking
_last_vr: tuple[float, float] | None = None
_last_ivdr: tuple[float, float] | None = None
def _track_inview_range(
self,
view_range: tuple[int, int],
) -> bool:
# if a view range is passed, plan to draw the
# source ouput that's "in view" of the chart.
vl, vr = view_range
zoom_or_append = False
last_vr = self._last_vr
# incremental in-view data update.
if last_vr:
lvl, lvr = last_vr # relative slice indices
# TODO: detecting more specifically the interaction changes
# last_ivr = self._last_ivdr or (vl, vr)
# al, ar = last_ivr # abs slice indices
# left_change = abs(x_iv[0] - al) >= 1
# right_change = abs(x_iv[-1] - ar) >= 1
# likely a zoom/pan view change or data append update
if (
(vr - lvr) > 2
or vl < lvl
# append / prepend update
# we had an append update where the view range
# didn't change but the data-viewed (shifted)
# underneath, so we need to redraw.
# or left_change and right_change and last_vr == view_range
# not (left_change and right_change) and ivr
# (
# or abs(x_iv[ivr] - livr) > 1
):
zoom_or_append = True
self._last_vr = view_range
return zoom_or_append
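Concretely, the zoom-or-append test above fires like this (toy values):

lvl, lvr = 50, 150   # last view range
vl, vr = 48, 155     # new view range

zoom_or_append = (vr - lvr) > 2 or vl < lvl
# True: the right edge advanced by more than 2 and the left edge moved
# left of the previous range, so the in-view data must be re-read.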
def format_to_1d(
self,
new_read: tuple,
array_key: str,
profiler: Profiler,
slice_to_head: int = -1,
read_src_from_key: bool = True,
slice_to_inview: bool = True,
) -> tuple[
np.ndarray,
np.ndarray,
]:
shm = self.shm
(
_,
_,
array,
ivl,
ivr,
in_view,
) = new_read
(
pre_slice,
prepend_len,
append_len,
post_slice,
) = self.diff(new_read)
if self.y_nd is None:
# we first need to allocate xy data arrays
# from the source data.
self.x_nd, self.y_nd = self.allocate_xy_nd(
shm,
array_key,
)
self.xy_nd_start = shm._first.value
self.xy_nd_stop = shm._last.value
profiler('allocated xy history')
if prepend_len:
y_prepend = shm._array[pre_slice]
if read_src_from_key:
y_prepend = y_prepend[array_key]
(
new_y_nd,
y_nd_slc,
) = self.incr_update_xy_nd(
shm,
array_key,
# this is the pre-sliced, "normally expected"
# new data that an updater would normally be
# expected to process, however in some cases (like
# step curves) the updater routine may want to do
# the source history-data reading itself, so we pass
# both here.
y_prepend,
pre_slice,
prepend_len,
self.xy_nd_start,
self.xy_nd_stop,
is_append=False,
)
# y_nd_view = self.y_nd[y_nd_slc]
self.y_nd[y_nd_slc] = new_y_nd
# if read_src_from_key:
# y_nd_view[:][array_key] = new_y_nd
# else:
# y_nd_view[:] = new_y_nd
self.xy_nd_start = shm._first.value
profiler(f'prepended xy history: {prepend_len}')
if append_len:
y_append = shm._array[post_slice]
if read_src_from_key:
y_append = y_append[array_key]
(
new_y_nd,
y_nd_slc,
) = self.incr_update_xy_nd(
shm,
array_key,
y_append,
post_slice,
append_len,
self.xy_nd_start,
self.xy_nd_stop,
is_append=True,
)
# self.y_nd[post_slice] = new_y_nd
# self.y_nd[xy_slice or post_slice] = xy_data
self.y_nd[y_nd_slc] = new_y_nd
# if read_src_from_key:
# y_nd_view[:][array_key] = new_y_nd
# else:
# y_nd_view[:] = new_y_nd
self.xy_nd_stop = shm._last.value
profiler(f'appended xy history: {append_len}')
view_changed: bool = False
view_range: tuple[int, int] = (ivl, ivr)
if slice_to_inview:
view_changed = self._track_inview_range(view_range)
array = in_view
profiler(f'{self.flow.name} view range slice {view_range}')
hist = array[:slice_to_head]
# xy-path data transform: convert source data to a format
# able to be passed to a `QPainterPath` rendering routine.
if not len(hist):
# XXX: this might be why the profiler only has exits?
return
# TODO: hist here should be the pre-sliced
# x/y_data in the case where allocate_xy is
# defined?
x_1d, y_1d, connect = self.format_xy_nd_to_1d(
hist,
array_key,
view_range,
)
# app_tres = None
# if append_len:
# appended = array[-append_len-1:slice_to_head]
# app_tres = self.format_xy_nd_to_1d(
# appended,
# array_key,
# (
# view_range[1] - append_len + slice_to_head,
# view_range[1]
# ),
# )
# # assert (len(appended) - 1) == append_len
# # assert len(appended) == append_len
# print(
# f'{self.flow.name} APPEND LEN: {append_len}\n'
# f'{self.flow.name} APPENDED: {appended}\n'
# f'{self.flow.name} app_tres: {app_tres}\n'
# )
# update the last "in view data range"
if len(x_1d):
self._last_ivdr = x_1d[0], x_1d[slice_to_head]
# TODO: eventually maybe we can implement some kind of
# transform on the ``QPainterPath`` that will more or less
# detect the diff in "elements" terms?
# update diff state since we've now rendered paths.
self._last_read = new_read
profiler('.format_to_1d()')
return (
x_1d,
y_1d,
connect,
prepend_len,
append_len,
view_changed,
# app_tres,
)
###############################
# Sub-type override interface #
###############################
# optional pre-graphics xy formatted data which
# is incrementally updated in sync with the source data.
# XXX: was ``.allocate_xy()``
def allocate_xy_nd(
self,
src_shm: ShmArray,
data_field: str,
index_field: str = 'index',
) -> tuple[
np.ndarray, # x
np.ndarray # y
]:
'''
Convert the structured-array ``src_shm`` format to
an equivalently shaped (and field-less) ``np.ndarray``.
E.g. a 4-field x N struct-array => (N, 4)
'''
y_nd = src_shm._array[data_field].copy()
x_nd = src_shm._array[index_field].copy()
return x_nd, y_nd
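A quick standalone demo of this default allocation (hypothetical data, numpy only): copy one field of the structured source into a flat y array and the 'index' field into x.
import numpy as np
src = np.zeros(4, dtype=[('index', 'i8'), ('close', 'f8')])
src['index'] = np.arange(4)
src['close'] = [1., 2., 3., 4.]
x_nd = src['index'].copy()  # as ``allocate_xy_nd()`` does above
y_nd = src['close'].copy()
assert x_nd.shape == y_nd.shape == (4,)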
# XXX: was ``.update_xy()``
def incr_update_xy_nd(
self,
src_shm: ShmArray,
data_field: str,
new_from_src: np.ndarray, # portion of source that was updated
read_slc: slice,
ln: int, # len of updated
nd_start: int,
nd_stop: int,
is_append: bool,
index_field: str = 'index',
) -> tuple[
np.ndarray,
slice,
]:
# write pushed data to flattened copy
new_y_nd = new_from_src
# XXX
# TODO: this should be returned and written by caller!
# XXX
# generate same-valued-per-row x support based on y shape
if index_field != 'index':
self.x_nd[read_slc, :] = new_from_src[index_field]
return new_y_nd, read_slc
# XXX: was ``.format_xy()``
def format_xy_nd_to_1d(
self,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray, # 1d x
np.ndarray, # 1d y
np.ndarray | str, # connection array/style
]:
'''
Default xy-nd array to 1d pre-graphics-path render routine.
Return single-field column data verbatim.
'''
return (
array['index'],
array[array_key],
# 1d connection array or style-key to
# ``pg.functions.arrayToQPath()``
'all',
)
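The returned triple ultimately feeds a ``QPainterPath`` generator; in ``pyqtgraph`` terms the 'all' style means "connect every consecutive sample". A sketch (assumes ``pyqtgraph`` plus a Qt binding are installed; ``mkQApp()`` provides the required app instance):
import numpy as np
import pyqtgraph as pg
app = pg.mkQApp()
x = np.arange(5, dtype=float)
y = x ** 2
# connect='all' joins every consecutive point; an ndarray of 0/1
# flags can be passed instead for per-segment control.
path = pg.functions.arrayToQPath(x, y, connect='all')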
class OHLCBarsFmtr(IncrementalFormatter):
fields: list[str] = ['open', 'high', 'low', 'close']
def allocate_xy_nd(
self,
ohlc_shm: ShmArray,
data_field: str,
) -> tuple[
np.ndarray, # x
np.ndarray # y
]:
'''
Convert an input struct-array holding OHLC samples into a pair of
flattened x, y arrays with the same size (datum-wise) as the source
data.
'''
y_nd = ohlc_shm.ustruct(self.fields)
# generate a flat-interpolated x-domain
x_nd = (
np.broadcast_to(
ohlc_shm._array['index'][:, None],
(
ohlc_shm._array.size,
# 4, # only ohlc
y_nd.shape[1],
),
) + np.array([-0.5, 0, 0, 0.5])
)
assert y_nd.any()
# write pushed data to flattened copy
return (
x_nd,
y_nd,
)
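The broadcast trick above gives every bar a left/right edge on the x axis: row i becomes [i - 0.5, i, i, i + 0.5]. A tiny sanity check:
import numpy as np
index = np.arange(3)
x_nd = np.broadcast_to(
    index[:, None],
    (3, 4),
) + np.array([-0.5, 0, 0, 0.5])
assert (x_nd[1] == np.array([0.5, 1., 1., 1.5])).all()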
@staticmethod
@njit(
# TODO: for now need to construct this manually for readonly
# arrays, see https://github.com/numba/numba/issues/4511
# ntypes.tuple((float64[:], float64[:], float64[:]))( # ntypes.tuple((float64[:], float64[:], float64[:]))(
# numba_ohlc_dtype[::1], # contiguous # numba_ohlc_dtype[::1], # contiguous
# int64, # int64,
# optional(float64), # optional(float64),
# ), # ),
nogil=True nogil=True
) )
def path_arrays_from_ohlc( def path_arrays_from_ohlc(
data: np.ndarray, data: np.ndarray,
start: int64, start: int64,
bar_gap: float64 = 0.43, bar_gap: float64 = 0.43,
) -> tuple[ ) -> np.ndarray:
np.ndarray,
np.ndarray,
np.ndarray,
]:
''' '''
Generate an array of line objects from input ohlc data. Generate an array of line objects from input ohlc data.
@ -603,120 +141,80 @@ class OHLCBarsFmtr(IncrementalFormatter):
return x, y, c return x, y, c
# TODO: can we drop this frame and just use the above?
def format_xy_nd_to_1d(
self,
array: np.ndarray, def gen_ohlc_qpath(
array_key: str, r: Renderer,
data: np.ndarray,
array_key: str, # we ignore this
vr: tuple[int, int], vr: tuple[int, int],
start: int = 0, # XXX: do we need this? start: int = 0, # XXX: do we need this?
# 0.5 is no overlap between arms, 1.0 is full overlap # 0.5 is no overlap between arms, 1.0 is full overlap
w: float = 0.43, w: float = 0.43,
) -> tuple[ ) -> QtGui.QPainterPath:
np.ndarray,
np.ndarray,
np.ndarray,
]:
''' '''
More or less direct proxy to the ``numba``-fied More or less direct proxy to ``path_arrays_from_ohlc()``
``path_arrays_from_ohlc()`` (above) but with closed-in kwargs but with closed-in kwargs for line spacing.
for line spacing.
''' '''
x, y, c = self.path_arrays_from_ohlc( x, y, c = path_arrays_from_ohlc(
array, data,
start, start,
bar_gap=w, bar_gap=w,
) )
return x, y, c return x, y, c
def incr_update_xy_nd(
self,
src_shm: ShmArray, def ohlc_to_line(
ohlc_shm: ShmArray,
data_field: str, data_field: str,
fields: list[str] = ['open', 'high', 'low', 'close']
new_from_src: np.ndarray, # portion of source that was updated ) -> tuple[
read_slc: slice,
ln: int, # len of updated
nd_start: int,
nd_stop: int,
is_append: bool,
index_field: str = 'index',
) -> tuple[
np.ndarray, np.ndarray,
slice, np.ndarray,
]: ]:
# write newly pushed data to flattened copy '''
# a struct-arr is always passed in. Convert an input struct-array holding OHLC samples into a pair of
new_y_nd = rfn.structured_to_unstructured( flattened x, y arrays with the same size (datum-wise) as the source
new_from_src[self.fields] data.
'''
y_out = ohlc_shm.ustruct(fields)
first = ohlc_shm._first.value
last = ohlc_shm._last.value
# write pushed data to flattened copy
y_out[first:last] = rfn.structured_to_unstructured(
ohlc_shm.array[fields]
) )
# XXX # generate a flat-interpolated x-domain
# TODO: this should be returned and written by caller! x_out = (
# XXX np.broadcast_to(
# generate same-valued-per-row x support based on y shape ohlc_shm._array['index'][:, None],
if index_field != 'index': (
self.x_nd[read_slc, :] = new_from_src[index_field] ohlc_shm._array.size,
# 4, # only ohlc
y_out.shape[1],
),
) + np.array([-0.5, 0, 0, 0.5])
)
assert y_out.any()
return new_y_nd, read_slc return (
x_out,
y_out,
)
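``structured_to_unstructured()`` does the heavy lifting in both versions: it reinterprets the selected OHLC fields as a plain (N, 4) float array. A standalone demo:
import numpy as np
from numpy.lib import recfunctions as rfn
fields = ['open', 'high', 'low', 'close']
ohlc = np.zeros(2, dtype=[(f, 'f8') for f in fields])
ohlc[0] = (1., 2., 0.5, 1.5)
flat = rfn.structured_to_unstructured(ohlc[fields])
assert flat.shape == (2, 4)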
class OHLCBarsAsCurveFmtr(OHLCBarsFmtr): def to_step_format(
def format_xy_nd_to_1d(
self,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray,
np.ndarray,
str,
]:
# TODO: in the case of an existing ``.update_xy()``
# should we be passing in array as an xy arrays tuple?
# 2 more datum-indexes to capture zero at end
x_flat = self.x_nd[self.xy_nd_start:self.xy_nd_stop]
y_flat = self.y_nd[self.xy_nd_start:self.xy_nd_stop]
# slice to view
ivl, ivr = vr
x_iv_flat = x_flat[ivl:ivr]
y_iv_flat = y_flat[ivl:ivr]
# reshape to 1d for graphics rendering
y_iv = y_iv_flat.reshape(-1)
x_iv = x_iv_flat.reshape(-1)
return x_iv, y_iv, 'all'
class StepCurveFmtr(IncrementalFormatter):
def allocate_xy_nd(
self,
shm: ShmArray, shm: ShmArray,
data_field: str, data_field: str,
index_field: str = 'index', index_field: str = 'index',
) -> tuple[ ) -> tuple[int, np.ndarray, np.ndarray]:
np.ndarray, # x
np.ndarray # y
]:
''' '''
Convert an input 1d shm array to a "step array" format Convert an input 1d shm array to a "step array" format
for use by path graphics generation. for use by path graphics generation.
@ -736,116 +234,3 @@ class StepCurveFmtr(IncrementalFormatter):
# start y at origin level # start y at origin level
y_out[0, 0] = 0 y_out[0, 0] = 0
return x_out, y_out return x_out, y_out
def incr_update_xy_nd(
self,
src_shm: ShmArray,
array_key: str,
src_update: np.ndarray, # portion of source that was updated
slc: slice,
ln: int, # len of updated
first: int,
last: int,
is_append: bool,
) -> tuple[
np.ndarray,
slice,
]:
# for a step curve we slice from one datum prior
# to the current "update slice" to get the previous
# "level".
if is_append:
start = max(last - 1, 0)
end = src_shm._last.value
new_y = src_shm._array[start:end][array_key]
slc = slice(start, end)
else:
new_y = src_update
return (
np.broadcast_to(
new_y[:, None], (new_y.size, 2),
),
slc,
)
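The (size, 2) broadcast mirrors the step-array layout: each y level appears twice, once per step edge. For intuition:
import numpy as np
new_y = np.array([10., 11.])
levels = np.broadcast_to(new_y[:, None], (new_y.size, 2))
assert (levels == np.array([[10., 10.], [11., 11.]])).all()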
def format_xy_nd_to_1d(
self,
array: np.ndarray,
array_key: str,
vr: tuple[int, int],
) -> tuple[
np.ndarray,
np.ndarray,
str,
]:
lasts = array[['index', array_key]]
last = lasts[array_key][-1]
# 2 more datum-indexes to capture zero at end
x_step = self.x_nd[self.xy_nd_start:self.xy_nd_stop+2]
y_step = self.y_nd[self.xy_nd_start:self.xy_nd_stop+2]
y_step[-1] = last
# slice out in-view data
ivl, ivr = vr
ys_iv = y_step[ivl:ivr+1]
xs_iv = x_step[ivl:ivr+1]
# flatten to 1d
y_iv = ys_iv.reshape(ys_iv.size)
x_iv = xs_iv.reshape(xs_iv.size)
# print(
# f'ys_iv : {ys_iv[-s:]}\n'
# f'y_iv: {y_iv[-s:]}\n'
# f'xs_iv: {xs_iv[-s:]}\n'
# f'x_iv: {x_iv[-s:]}\n'
# )
return x_iv, y_iv, 'all'
def xy_downsample(
x,
y,
uppx,
x_spacer: float = 0.5,
) -> tuple[
np.ndarray,
np.ndarray,
float,
float,
]:
'''
Downsample 1D (flat ``numpy.ndarray``) arrays using M4 given an input
``uppx`` (units-per-pixel) and add space between discrete datums.
'''
# downsample whenever more than 1 pixel per datum can be shown.
# always refresh data bounds until we get diffing
# working properly, see above..
bins, x, y, ymn, ymx = ds_m4(
x,
y,
uppx,
)
# flatten output to 1d arrays suitable for path-graphics generation.
x = np.broadcast_to(x[:, None], y.shape)
x = (x + np.array(
[-x_spacer, 0, 0, x_spacer]
)).flatten()
y = y.flatten()
return x, y, ymn, ymx
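The real ``ds_m4()`` is a compiled routine, but the idea behind the 4-column output above is that M4 keeps the first/min/max/last samples per horizontal-pixel bin. A toy (hypothetical, unoptimized) reference for intuition only:
import numpy as np
def m4_toy(y: np.ndarray, bins: int) -> np.ndarray:
    # one (first, min, max, last) row per bin
    return np.array([
        (c[0], c.min(), c.max(), c[-1])
        for c in np.array_split(y, bins)
    ])
out = m4_toy(np.arange(16, dtype=float), 4)
assert out.shape == (4, 4)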

View File

@ -15,19 +15,13 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>. # along with this program. If not, see <https://www.gnu.org/licenses/>.
""" """
Customization of ``pyqtgraph`` core routines and various types normally Customization of ``pyqtgraph`` core routines to speed up our use mostly
for speedups. based on not requiring "scientific precision" for pixel perfect view
transforms.
Generally, ours does not require "scientific precision" for pixel perfect
view transforms.
""" """
from typing import Optional
import pyqtgraph as pg import pyqtgraph as pg
from ._axes import Axis
def invertQTransform(tr): def invertQTransform(tr):
"""Return a QTransform that is the inverse of *tr*. """Return a QTransform that is the inverse of *tr*.
@ -52,232 +46,3 @@ def _do_overrides() -> None:
""" """
# we don't care about potential fp issues inside Qt # we don't care about potential fp issues inside Qt
pg.functions.invertQTransform = invertQTransform pg.functions.invertQTransform = invertQTransform
pg.PlotItem = PlotItem
# NOTE: the below customized type contains all our changes on a method
# by method basis as per the diff:
# https://github.com/pyqtgraph/pyqtgraph/commit/8e60bc14234b6bec1369ff4192dbfb82f8682920#diff-a2b5865955d2ba703dbc4c35ff01aa761aa28d2aeaac5e68d24e338bc82fb5b1R500
class PlotItem(pg.PlotItem):
'''
Overrides for the core plot object mostly pertaining to overlaid
multi-view management as it relates to multi-axis management.
This object is the combination of a ``ViewBox`` and multiple
``AxisItem``s and so far we've added additional functionality and
APIs for:
- removal of axes
---
From ``pyqtgraph`` super type docs:
- Manage placement of ViewBox, AxisItems, and LabelItems
- Create and manage a list of PlotDataItems displayed inside the
ViewBox
- Implement a context menu with commonly used display and analysis
options
'''
def __init__(
self,
parent=None,
name=None,
labels=None,
title=None,
viewBox=None,
axisItems=None,
default_axes=['left', 'bottom'],
enableMenu=True,
**kargs
):
super().__init__(
parent=parent,
name=name,
labels=labels,
title=title,
viewBox=viewBox,
axisItems=axisItems,
# default_axes=default_axes,
enableMenu=enableMenu,
kargs=kargs,
)
self.name = name
self.chart_widget = None
# self.setAxisItems(
# axisItems,
# default_axes=default_axes,
# )
# NOTE: this is an entirely new method not in upstream.
def removeAxis(
self,
name: str,
unlink: bool = True,
) -> Optional[pg.AxisItem]:
"""
Remove an axis from the contained axis items
by ```name: str```.
This means the axis graphics object will be removed
from the ``.layout: QGraphicsGridLayout`` as well as unlinked
from the underlying associated ``ViewBox``.
If the ``unlink: bool`` is set to ``False`` then the axis will
stay linked to its view and will only be removed from the
layout.
If no axis with ``name: str`` is found then this is a noop.
Return the axis instance that was removed.
"""
entry = self.axes.pop(name, None)
if not entry:
return
axis = entry['item']
self.layout.removeItem(axis)
axis.scene().removeItem(axis)
if unlink:
axis.unlinkFromView()
self.update()
return axis
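Usage-wise (a comment-form sketch only, since it requires a plot already embedded in a live scene):
# given a ``plt: PlotItem`` already shown in a chart:
#
#   axis = plt.removeAxis('right')           # detach + unlink
#   assert 'right' not in plt.axes
#   assert plt.removeAxis('right') is None   # repeat call is a noop
#
# pass ``unlink=False`` to keep the axis linked to its ``ViewBox``,
# e.g. when re-parenting it into an overlay layout.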
# Why do we need to always have all axes created?
#
# I don't understand this at all.
#
# Everything seems to work if you just always apply the
# set passed to this method **EXCEPT** for some super weird reason
# the view box geometry still computes as though the space for the
# `'bottom'` axis is always there **UNLESS** you always add that
# axis but hide it?
#
# Why in tf would this be the case!?!?
def setAxisItems(
self,
# XXX: yeah yeah, i know we can't use type annots like this yet.
axisItems: Optional[dict[str, pg.AxisItem]] = None,
add_to_layout: bool = True,
default_axes: list[str] = ['left', 'bottom'],
):
"""
Override axis item setting to only show the axes named in
``default_axes`` or passed via ``axisItems``.
"""
axisItems = axisItems or {}
# XXX: wth is this even saying?!?
# Array containing visible axis items
# Also containing potentially hidden axes, but they are not
# touched so it does not matter
# visibleAxes = ['left', 'bottom']
# Note that it does not matter that this adds
# some values to visibleAxes a second time
# XXX: uhhh wat^ ..?
visibleAxes = list(default_axes) + list(axisItems.keys())
# TODO: we should probably invert the loop here to not loop the
# predefined "axis name set" and instead loop the `axisItems`
# input and lookup indices from a predefined map.
for name, pos in (
('top', (1, 1)),
('bottom', (3, 1)),
('left', (2, 0)),
('right', (2, 2))
):
if (
name in self.axes and
name in axisItems
):
# we already have an axis entry for this name
# so remove the existing entry.
self.removeAxis(name)
# elif name not in axisItems:
# # this axis entry is not provided in this call
# # so remove any old/existing entry.
# self.removeAxis(name)
# Create new axis
if name in axisItems:
axis = axisItems[name]
if axis.scene() is not None:
if (
name not in self.axes
or axis != self.axes[name]["item"]
):
raise RuntimeError(
"Can't add an axis to multiple plots. Shared axes"
" can be achieved with multiple AxisItem instances"
" and set[X/Y]Link.")
else:
# Set up new axis
# XXX: ok but why do we want to add axes for all entries
# if not desired by the user? The only reason I can see
# adding this is without it there's some weird
# ``ViewBox`` geometry bug.. where a gap for the
# 'bottom' axis is somehow left in?
# axis = pg.AxisItem(orientation=name, parent=self)
axis = Axis(
self,
orientation=name,
parent=self,
)
axis.linkToView(self.vb)
# XXX: shouldn't you already know the ``pos`` from the name?
# Oh right instead of a global map that would let you
# easily look that up, it's redefined over and over and over
# again in methods..
self.axes[name] = {'item': axis, 'pos': pos}
# NOTE: in the overlay case the axis may be added to some
# other layout and should not be added here.
if add_to_layout:
self.layout.addItem(axis, *pos)
# place axis above images at z=0, items that want to draw
# over the axes should be placed at z>=1:
axis.setZValue(0.5)
axis.setFlag(
axis.GraphicsItemFlag.ItemNegativeZStacksBehindParent
)
if name in visibleAxes:
self.showAxis(name, True)
else:
# why do we need to insert all axes to ``.axes`` and
# only hide the ones the user doesn't specify? It all
# seems to work fine without doing this except for this
# weird gap for the 'bottom' axis that always shows up
# in the view box geometry??
self.hideAxis(name)
def updateGrid(
self,
*args,
):
alpha = self.ctrl.gridAlphaSlider.value()
x = alpha if self.ctrl.xGridCheck.isChecked() else False
y = alpha if self.ctrl.yGridCheck.isChecked() else False
for name, dim in (
('top', x),
('bottom', x),
('left', y),
('right', y)
):
if name in self.axes:
self.getAxis(name).setGrid(dim)
# self.getAxis('bottom').setGrid(x)
# self.getAxis('left').setGrid(y)
# self.getAxis('right').setGrid(y)

File diff suppressed because it is too large

View File

@ -35,13 +35,9 @@ from collections import defaultdict
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial from functools import partial
from typing import ( from typing import (
Optional, Optional, Callable,
Callable, Awaitable, Sequence,
Awaitable, Any, AsyncIterator
Sequence,
Any,
AsyncIterator,
Iterator,
) )
import time import time
# from pprint import pformat # from pprint import pformat
@ -123,7 +119,7 @@ class CompleterView(QTreeView):
# TODO: size this based on DPI font # TODO: size this based on DPI font
self.setIndentation(_font.px_size) self.setIndentation(_font.px_size)
self.setUniformRowHeights(True) # self.setUniformRowHeights(True)
# self.setColumnWidth(0, 3) # self.setColumnWidth(0, 3)
# self.setVerticalBarPolicy(Qt.ScrollBarAlwaysOff) # self.setVerticalBarPolicy(Qt.ScrollBarAlwaysOff)
# self.setSizeAdjustPolicy(QAbstractScrollArea.AdjustIgnored) # self.setSizeAdjustPolicy(QAbstractScrollArea.AdjustIgnored)
@ -142,15 +138,13 @@ class CompleterView(QTreeView):
model.setHorizontalHeaderLabels(labels) model.setHorizontalHeaderLabels(labels)
self._font_size: int = 0 # pixels self._font_size: int = 0 # pixels
self._init: bool = False
async def on_pressed(self, idx: QModelIndex) -> None: async def on_pressed(self, idx: QModelIndex) -> None:
''' '''Mouse pressed on view handler.
Mouse pressed on view handler.
''' '''
search = self.parent() search = self.parent()
await search.chart_current_item() await search.chart_current_item(clear_to_cache=False)
search.focus() search.focus()
def set_font_size(self, size: int = 18): def set_font_size(self, size: int = 18):
@ -162,64 +156,56 @@ class CompleterView(QTreeView):
self.setStyleSheet(f"font: {size}px") self.setStyleSheet(f"font: {size}px")
def resize_to_results( # def resizeEvent(self, event: 'QEvent') -> None:
self, # event.accept()
w: Optional[float] = 0, # super().resizeEvent(event)
h: Optional[float] = None,
) -> None: def on_resize(self) -> None:
'''
Resize relay event from god.
'''
self.resize_to_results()
def resize_to_results(self):
model = self.model() model = self.model()
cols = model.columnCount() cols = model.columnCount()
cidx = self.selectionModel().currentIndex() # rows = model.rowCount()
rows = model.rowCount()
self.expandAll()
# compute the approx height in pixels needed to include
# all result rows in view.
row_h = rows_h = self.rowHeight(cidx) * (rows + 1)
for idx, item in self.iter_df_rows():
row_h = self.rowHeight(idx)
rows_h += row_h
# print(f'row_h: {row_h}\nrows_h: {rows_h}')
# TODO: could we just break early here on detection
# of ``rows_h >= h``?
col_w_tot = 0 col_w_tot = 0
for i in range(cols): for i in range(cols):
# only slap in a rows's height's worth
# of padding once at startup.. no idea
if (
not self._init
and row_h
):
col_w_tot = row_h
self._init = True
self.resizeColumnToContents(i) self.resizeColumnToContents(i)
col_w_tot += self.columnWidth(i) col_w_tot += self.columnWidth(i)
# NOTE: if the height `h` set here is **too large** then the win = self.window()
# resize event will perpetually trigger as the window causes win_h = win.height()
# some kind of recompute of callbacks.. so we have to ensure edit_h = self.parent().bar.height()
# it's limited. sb_h = win.statusBar().height()
if h:
h: int = round(h)
abs_mx = round(0.91 * h)
self.setMaximumHeight(abs_mx)
if rows_h <= abs_mx: # TODO: probably make this more general / less hacky
# self.setMinimumHeight(rows_h) # we should figure out the exact number of rows to allow
self.setMinimumHeight(rows_h) # inclusive of search bar and header "rows", in pixel terms.
# self.setFixedHeight(rows_h) # Eventually when we have an "info" widget below the results we
# will want space for it and likely terminating the results-view
# space **exactly on a row** would be ideal.
# if row_px > 0:
# rows = ceil(window_h / row_px) - 4
# else:
# rows = 16
# self.setFixedHeight(rows * row_px)
# self.resize(self.width(), rows * row_px)
else: # NOTE: if the height set here is **too large** then the resize
self.setMinimumHeight(abs_mx) # event will perpetually trigger as the window causes some kind
# of recompute of callbacks.. so we have to ensure it's limited.
h = win_h - (edit_h + 1.666*sb_h)
assert h > 0
self.setFixedHeight(round(h))
# dynamically size to width of longest result seen # size to width of longest result seen thus far
curr_w = self.width() # TODO: should we always dynamically scale to longest result?
if curr_w < col_w_tot: if self.width() < col_w_tot:
self.setMinimumWidth(col_w_tot) self.setFixedWidth(col_w_tot)
self.update() self.update()
@ -345,23 +331,6 @@ class CompleterView(QTreeView):
item = model.itemFromIndex(idx) item = model.itemFromIndex(idx)
yield idx, item yield idx, item
def iter_df_rows(
self,
iparent: QModelIndex = QModelIndex(),
) -> Iterator[tuple[QModelIndex, QStandardItem]]:
model = self.model()
isections = model.rowCount(iparent)
for i in range(isections):
idx = model.index(i, 0, iparent)
item = model.itemFromIndex(idx)
yield idx, item
if model.hasChildren(idx):
# recursively yield child items depth-first
yield from self.iter_df_rows(idx)
def find_section( def find_section(
self, self,
section: str, section: str,
@ -385,8 +354,7 @@ class CompleterView(QTreeView):
status_field: str = None, status_field: str = None,
) -> None: ) -> None:
''' '''Clear all result-rows from under the depth = 1 section.
Clear all result-rows from under the depth = 1 section.
''' '''
idx = self.find_section(section) idx = self.find_section(section)
@ -407,6 +375,8 @@ class CompleterView(QTreeView):
else: else:
model.setItem(idx.row(), 1, QStandardItem()) model.setItem(idx.row(), 1, QStandardItem())
self.resize_to_results()
return idx return idx
else: else:
return None return None
@ -416,26 +386,12 @@ class CompleterView(QTreeView):
section: str, section: str,
values: Sequence[str], values: Sequence[str],
clear_all: bool = False, clear_all: bool = False,
reverse: bool = False,
) -> None: ) -> None:
''' '''
Set result-rows for depth = 1 tree section ``section``. Set result-rows for depth = 1 tree section ``section``.
''' '''
if (
values
and not isinstance(values[0], str)
):
flattened: list[str] = []
for val in values:
flattened.extend(val)
values = flattened
if reverse:
values = reversed(values)
model = self.model() model = self.model()
if clear_all: if clear_all:
# XXX: rewrite the model from scratch if caller requests it # XXX: rewrite the model from scratch if caller requests it
@ -488,22 +444,9 @@ class CompleterView(QTreeView):
self.show_matches() self.show_matches()
def show_matches( def show_matches(self) -> None:
self,
wh: Optional[tuple[float, float]] = None,
) -> None:
if wh:
self.resize_to_results(*wh)
else:
# case where it's just an update from results and *NOT*
# a resize of some higher level parent-container widget.
search = self.parent()
w, h = search.space_dims()
self.resize_to_results(w=w, h=h)
self.show() self.show()
self.resize_to_results()
class SearchBar(Edit): class SearchBar(Edit):
@ -523,15 +466,18 @@ class SearchBar(Edit):
self.godwidget = godwidget self.godwidget = godwidget
super().__init__(parent, **kwargs) super().__init__(parent, **kwargs)
self.view: CompleterView = view self.view: CompleterView = view
godwidget._widgets[view.mode_name] = view
def show(self) -> None:
super().show()
self.view.show_matches()
def unfocus(self) -> None: def unfocus(self) -> None:
self.parent().hide() self.parent().hide()
self.clearFocus() self.clearFocus()
def hide(self) -> None:
if self.view: if self.view:
self.view.hide() self.view.hide()
super().hide()
class SearchWidget(QtWidgets.QWidget): class SearchWidget(QtWidgets.QWidget):
@ -550,16 +496,15 @@ class SearchWidget(QtWidgets.QWidget):
parent=None, parent=None,
) -> None: ) -> None:
super().__init__(parent) super().__init__(parent or godwidget)
# size it as we specify # size it as we specify
self.setSizePolicy( self.setSizePolicy(
QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Fixed,
QtWidgets.QSizePolicy.Fixed, QtWidgets.QSizePolicy.Expanding,
) )
self.godwidget = godwidget self.godwidget = godwidget
godwidget.reg_for_resize(self)
self.vbox = QtWidgets.QVBoxLayout(self) self.vbox = QtWidgets.QVBoxLayout(self)
self.vbox.setContentsMargins(0, 4, 4, 0) self.vbox.setContentsMargins(0, 4, 4, 0)
@ -609,37 +554,20 @@ class SearchWidget(QtWidgets.QWidget):
self.vbox.setAlignment(self.view, Qt.AlignTop | Qt.AlignLeft) self.vbox.setAlignment(self.view, Qt.AlignTop | Qt.AlignLeft)
def focus(self) -> None: def focus(self) -> None:
self.show()
self.bar.focus()
def show_cache_entries(
self,
only: bool = False,
) -> None:
'''
Clear the search results view and show only cached (aka recently
loaded with active data) feeds in the results section.
'''
godw = self.godwidget
# first entry in the cache is the current symbol(s)
fqsns = []
for multi_fqsns in list(godw._chart_cache):
fqsns.extend(list(multi_fqsns))
if self.view.model().rowCount(QModelIndex()) == 0:
# fill cache list if nothing existing
self.view.set_section_entries( self.view.set_section_entries(
'cache', 'cache',
list(fqsns), list(reversed(self.godwidget._chart_cache)),
# remove all other completion results except for cache clear_all=True,
clear_all=only,
reverse=True,
) )
self.bar.focus()
self.show()
def get_current_item(self) -> Optional[tuple[str, str]]: def get_current_item(self) -> Optional[tuple[str, str]]:
''' '''Return the current completer tree selection as
Return the current completer tree selection as
a tuple ``(parent: str, child: str)`` if valid, else ``None``. a tuple ``(parent: str, child: str)`` if valid, else ``None``.
''' '''
@ -675,8 +603,7 @@ class SearchWidget(QtWidgets.QWidget):
clear_to_cache: bool = True, clear_to_cache: bool = True,
) -> Optional[str]: ) -> Optional[str]:
''' '''Attempt to load and switch the current selected
Attempt to load and switch the current selected
completion result to the affiliated chart app. completion result to the affiliated chart app.
Return any loaded symbol. Return any loaded symbol.
@ -687,15 +614,14 @@ class SearchWidget(QtWidgets.QWidget):
return None return None
provider, symbol = value provider, symbol = value
godw = self.godwidget chart = self.godwidget
fqsn = f'{symbol}.{provider}' log.info(f'Requesting symbol: {symbol}.{provider}')
log.info(f'Requesting symbol: {fqsn}')
# assert provider in symbol await chart.load_symbol(
await godw.load_symbols( provider,
fqsns=[fqsn], symbol,
loglevel='info', 'info',
) )
# fully qualified symbol name (SNS i guess is what we're # fully qualified symbol name (SNS i guess is what we're
@ -709,46 +635,18 @@ class SearchWidget(QtWidgets.QWidget):
# Re-order the symbol cache on the chart to display in # Re-order the symbol cache on the chart to display in
# LIFO order. this is normally only done internally by # LIFO order. this is normally only done internally by
# the chart on new symbols being loaded into memory # the chart on new symbols being loaded into memory
godw.set_chart_symbols( chart.set_chart_symbol(fqsn, chart.linkedsplits)
(fqsn,), (
godw.hist_linked, self.view.set_section_entries(
godw.rt_linked, 'cache',
) values=list(reversed(chart._chart_cache)),
)
self.show_cache_entries(only=True) # remove all other completion results except for cache
clear_all=True,
)
self.bar.focus()
return fqsn return fqsn
def space_dims(self) -> tuple[float, float]:
'''
Compute and return the "available space dimensions" for this
search widget in terms of px space for results by returning the
pair of width and height.
'''
# XXX: dun need dis rite?
# win = self.window()
# win_h = win.height()
# sb_h = win.statusBar().height()
godw = self.godwidget
hl = godw.hist_linked
edit_h = self.bar.height()
h = hl.height() - edit_h
w = hl.width()
return w, h
def on_resize(self) -> None:
'''
Resize relay event from god, resize all child widgets.
Right now this is just view to contents and/or the fast chart
height.
'''
w, h = self.space_dims()
self.bar.view.show_matches(wh=(w, h))
_search_active: trio.Event = trio.Event() _search_active: trio.Event = trio.Event()
_search_enabled: bool = False _search_enabled: bool = False
@ -784,10 +682,9 @@ async def pack_matches(
with trio.CancelScope() as cs: with trio.CancelScope() as cs:
task_status.started(cs) task_status.started(cs)
# ensure ^ status is updated # ensure ^ status is updated
results = list(await search(pattern)) results = await search(pattern)
# XXX: don't cache the cache results xD if provider != 'cache': # XXX: don't cache the cache results xD
if provider != 'cache':
matches[(provider, pattern)] = results matches[(provider, pattern)] = results
# print(f'results from {provider}: {results}') # print(f'results from {provider}: {results}')
@ -815,11 +712,10 @@ async def fill_results(
max_pause_time: float = 6/16 + 0.001, max_pause_time: float = 6/16 + 0.001,
) -> None: ) -> None:
''' """Task to search through providers and fill in possible
Task to search through providers and fill in possible
completion results. completion results.
''' """
global _search_active, _search_enabled, _searcher_cache global _search_active, _search_enabled, _searcher_cache
bar = search.bar bar = search.bar
@ -833,10 +729,6 @@ async def fill_results(
matches = defaultdict(list) matches = defaultdict(list)
has_results: defaultdict[str, set[str]] = defaultdict(set) has_results: defaultdict[str, set[str]] = defaultdict(set)
# show cached feed list at startup
search.show_cache_entries()
search.on_resize()
while True: while True:
await _search_active.wait() await _search_active.wait()
period = None period = None
@ -850,7 +742,7 @@ async def fill_results(
pattern = await recv_chan.receive() pattern = await recv_chan.receive()
period = time.time() - wait_start period = time.time() - wait_start
log.debug(f'{pattern} after {period}') print(f'{pattern} after {period}')
# during fast multiple key inputs, wait until a pause # during fast multiple key inputs, wait until a pause
# (in typing) to initiate search # (in typing) to initiate search
@ -888,9 +780,8 @@ async def fill_results(
# it hasn't already been searched with the current # it hasn't already been searched with the current
# input pattern (in which case just look up the old # input pattern (in which case just look up the old
# results). # results).
if ( if (period >= pause) and (
period >= pause provider not in already_has_results
and provider not in already_has_results
): ):
# TODO: it may make more sense TO NOT search the # TODO: it may make more sense TO NOT search the
@ -898,9 +789,7 @@ async def fill_results(
# cpu-bound. # cpu-bound.
if provider != 'cache': if provider != 'cache':
view.clear_section( view.clear_section(
provider, provider, status_field='-> searchin..')
status_field='-> searchin..',
)
await n.start( await n.start(
pack_matches, pack_matches,
@ -921,20 +810,11 @@ async def fill_results(
# re-searching it's ``dict`` since it's easier # re-searching it's ``dict`` since it's easier
# but it also causes it to be slower then cached # but it also causes it to be slower then cached
# results from other providers on occasion. # results from other providers on occasion.
if ( if results and provider != 'cache':
results
):
if provider != 'cache':
view.set_section_entries( view.set_section_entries(
section=provider, section=provider,
values=results, values=results,
) )
else:
# if provider == 'cache':
# for the cache just show what we got
# that matches
search.show_cache_entries()
else: else:
view.clear_section(provider) view.clear_section(provider)
@ -961,7 +841,8 @@ async def handle_keyboard_input(
godwidget = search.godwidget godwidget = search.godwidget
view = bar.view view = bar.view
view.set_font_size(bar.dpi_font.px_size) view.set_font_size(bar.dpi_font.px_size)
send, recv = trio.open_memory_channel(616)
send, recv = trio.open_memory_channel(16)
async with trio.open_nursery() as n: async with trio.open_nursery() as n:
@ -976,10 +857,6 @@ async def handle_keyboard_input(
) )
) )
bar.focus()
search.show_cache_entries()
await trio.sleep(0)
async for kbmsg in recv_chan: async for kbmsg in recv_chan:
event, etype, key, mods, txt = kbmsg.to_tuple() event, etype, key, mods, txt = kbmsg.to_tuple()
@ -989,21 +866,19 @@ async def handle_keyboard_input(
if mods == Qt.ControlModifier: if mods == Qt.ControlModifier:
ctl = True ctl = True
if key in ( if key in (Qt.Key_Enter, Qt.Key_Return):
Qt.Key_Enter,
Qt.Key_Return
):
_search_enabled = False
await search.chart_current_item(clear_to_cache=True) await search.chart_current_item(clear_to_cache=True)
search.show_cache_entries(only=True) _search_enabled = False
view.show_matches() continue
search.focus()
elif not ctl and not bar.text(): elif not ctl and not bar.text():
# if nothing in search text show the cache
# TODO: really should factor this somewhere..bc view.set_section_entries(
# we're doin it in another spot as well.. 'cache',
search.show_cache_entries(only=True) list(reversed(godwidget._chart_cache)),
clear_all=True,
)
continue continue
# cancel and close # cancel and close
@ -1012,7 +887,7 @@ async def handle_keyboard_input(
Qt.Key_Space, # i feel like this is the "native" one Qt.Key_Space, # i feel like this is the "native" one
Qt.Key_Alt, Qt.Key_Alt,
}: }:
bar.unfocus() search.bar.unfocus()
# kill the search and focus back on main chart # kill the search and focus back on main chart
if godwidget: if godwidget:
@ -1060,14 +935,11 @@ async def handle_keyboard_input(
if item: if item:
parent_item = item.parent() parent_item = item.parent()
# if we're in the cache section and thus the next
# selection is a cache item, switch and show it
# immediately since it should be very fast.
if parent_item and parent_item.text() == 'cache': if parent_item and parent_item.text() == 'cache':
# if it's a cache item, switch and show it immediately
await search.chart_current_item(clear_to_cache=False) await search.chart_current_item(clear_to_cache=False)
# ACTUAL SEARCH BLOCK #
# where we fuzzy complete and fill out sections.
elif not ctl: elif not ctl:
# relay to completer task # relay to completer task
_search_enabled = True _search_enabled = True
@ -1078,21 +950,13 @@ async def handle_keyboard_input(
async def search_simple_dict( async def search_simple_dict(
text: str, text: str,
source: dict, source: dict,
) -> dict[str, Any]: ) -> dict[str, Any]:
tokens = []
for key in source:
if not isinstance(key, str):
tokens.extend(key)
else:
tokens.append(key)
# search routine can be specified as a function such # search routine can be specified as a function such
# as in the case of the current app's local symbol cache # as in the case of the current app's local symbol cache
matches = fuzzy.extractBests( matches = fuzzy.extractBests(
text, text,
tokens, source.keys(),
score_cutoff=90, score_cutoff=90,
) )
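Here ``fuzzy`` is presumably ``fuzzywuzzy.process`` (an assumption from the call signature, matching this module's other call sites); a standalone sketch of the cutoff-based matching with a looser cutoff:
from fuzzywuzzy import process as fuzzy
matches = fuzzy.extractBests(
    'btcusdt',
    ['btcusdt.binance', 'ethusdt.binance', 'xbtusd.kraken'],
    score_cutoff=50,
)
# returns best-scoring (choice, score) pairs, low scores dropped
assert matches[0][0] == 'btcusdt.binance'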

View File

@ -240,12 +240,12 @@ def hcolor(name: str) -> str:
'gunmetal': '#91A3B0', 'gunmetal': '#91A3B0',
'battleship': '#848482', 'battleship': '#848482',
# default ohlc-bars/curve gray
'bracket': '#666666', # like the logo
# bluish # bluish
'charcoal': '#36454F', 'charcoal': '#36454F',
# default bars
'bracket': '#666666', # like the logo
# work well for filled polygons which want a 'bracket' feel # work well for filled polygons which want a 'bracket' feel
# going light to dark # going light to dark
'davies': '#555555', 'davies': '#555555',

View File

@ -21,29 +21,15 @@ Qt main window singletons and stuff.
import os import os
import signal import signal
import time import time
from typing import ( from typing import Callable, Optional, Union
Callable,
Optional,
Union,
)
import uuid import uuid
from pyqtgraph import QtGui
from PyQt5 import QtCore from PyQt5 import QtCore
from PyQt5.QtWidgets import ( from PyQt5.QtWidgets import QLabel, QStatusBar
QWidget,
QMainWindow,
QApplication,
QLabel,
QStatusBar,
)
from PyQt5.QtGui import (
QScreen,
QCloseEvent,
)
from ..log import get_logger from ..log import get_logger
from ._style import _font_small, hcolor from ._style import _font_small, hcolor
from ._chart import GodWidget
log = get_logger(__name__) log = get_logger(__name__)
@ -162,13 +148,12 @@ class MultiStatus:
self.bar.clearMessage() self.bar.clearMessage()
class MainWindow(QMainWindow): class MainWindow(QtGui.QMainWindow):
# XXX: for tiling wms this should scale # XXX: for tiling wms this should scale
# with the allotted window size. # with the allotted window size.
# TODO: detect for tiling and if untrue set some size? # TODO: detect for tiling and if untrue set some size?
# size = (300, 500) size = (300, 500)
godwidget: GodWidget
title = 'piker chart (ur symbol is loading bby)' title = 'piker chart (ur symbol is loading bby)'
@ -177,20 +162,17 @@ class MainWindow(QMainWindow):
# self.setMinimumSize(*self.size) # self.setMinimumSize(*self.size)
self.setWindowTitle(self.title) self.setWindowTitle(self.title)
# set by runtime after `trio` is engaged.
self.godwidget: Optional[GodWidget] = None
self._status_bar: QStatusBar = None self._status_bar: QStatusBar = None
self._status_label: QLabel = None self._status_label: QLabel = None
self._size: Optional[tuple[int, int]] = None self._size: Optional[tuple[int, int]] = None
@property @property
def mode_label(self) -> QLabel: def mode_label(self) -> QtGui.QLabel:
# init mode label # init mode label
if not self._status_label: if not self._status_label:
self._status_label = label = QLabel() self._status_label = label = QtGui.QLabel()
label.setStyleSheet( label.setStyleSheet(
f"""QLabel {{ f"""QLabel {{
color : {hcolor('gunmetal')}; color : {hcolor('gunmetal')};
@ -212,7 +194,8 @@ class MainWindow(QMainWindow):
def closeEvent( def closeEvent(
self, self,
event: QCloseEvent,
event: QtGui.QCloseEvent,
) -> None: ) -> None:
'''Cancel the root actor asap. '''Cancel the root actor asap.
@ -252,8 +235,8 @@ class MainWindow(QMainWindow):
def on_focus_change( def on_focus_change(
self, self,
last: QWidget, last: QtGui.QWidget,
current: QWidget, current: QtGui.QWidget,
) -> None: ) -> None:
@ -264,12 +247,11 @@ class MainWindow(QMainWindow):
name = getattr(current, 'mode_name', '') name = getattr(current, 'mode_name', '')
self.set_mode_name(name) self.set_mode_name(name)
def current_screen(self) -> QScreen: def current_screen(self) -> QtGui.QScreen:
''' """Get a frickin screen (if we can, gawd).
Get a frickin screen (if we can, gawd).
''' """
app = QApplication.instance() app = QtGui.QApplication.instance()
for _ in range(3): for _ in range(3):
screen = app.screenAt(self.pos()) screen = app.screenAt(self.pos())
@ -302,7 +284,7 @@ class MainWindow(QMainWindow):
''' '''
# https://stackoverflow.com/a/18975846 # https://stackoverflow.com/a/18975846
if not size and not self._size: if not size and not self._size:
# app = QApplication.instance() app = QtGui.QApplication.instance()
geo = self.current_screen().geometry() geo = self.current_screen().geometry()
h, w = geo.height(), geo.width() h, w = geo.height(), geo.width()
# use approx 1/3 of the area of the screen by default # use approx 1/3 of the area of the screen by default
@ -310,36 +292,9 @@ class MainWindow(QMainWindow):
self.resize(*size or self._size) self.resize(*size or self._size)
def resizeEvent(self, event: QtCore.QEvent) -> None:
if (
# event.spontaneous()
event.oldSize().height == event.size().height
):
event.ignore()
return
# XXX: uncomment for debugging..
# attrs = {}
# for key in dir(event):
# if key == '__dir__':
# continue
# attr = getattr(event, key)
# try:
# attrs[key] = attr()
# except TypeError:
# attrs[key] = attr
# from pprint import pformat
# print(
# f'{pformat(attrs)}\n'
# f'WINDOW RESIZE: {self.size()}\n\n'
# )
self.godwidget.on_win_resize(event)
event.accept()
# singleton app per actor # singleton app per actor
_qt_win: QMainWindow = None _qt_win: QtGui.QMainWindow = None
def main_window() -> MainWindow: def main_window() -> MainWindow:

View File

@ -46,10 +46,8 @@ def _kivy_import_hack():
@click.argument('name', nargs=1, required=True) @click.argument('name', nargs=1, required=True)
@click.pass_obj @click.pass_obj
def monitor(config, rate, name, dhost, test, tl): def monitor(config, rate, name, dhost, test, tl):
''' """Start a real-time watchlist UI
Start a real-time watchlist UI """
'''
# global opts # global opts
brokermod = config['brokermods'][0] brokermod = config['brokermods'][0]
loglevel = config['loglevel'] loglevel = config['loglevel']
@ -72,12 +70,8 @@ def monitor(config, rate, name, dhost, test, tl):
) as portal: ) as portal:
# run app "main" # run app "main"
await _async_main( await _async_main(
name, name, portal, tickers,
portal, brokermod, rate, test=test,
tickers,
brokermod,
rate,
test=test,
) )
tractor.run( tractor.run(
@ -128,7 +122,7 @@ def optschain(config, symbol, date, rate, test):
@cli.command() @cli.command()
@click.option( @click.option(
'--profile', '--profile',
# '-p', '-p',
default=None, default=None,
help='Enable pyqtgraph profiling' help='Enable pyqtgraph profiling'
) )
@ -137,14 +131,9 @@ def optschain(config, symbol, date, rate, test):
is_flag=True, is_flag=True,
help='Enable tractor debug mode' help='Enable tractor debug mode'
) )
@click.argument('symbols', nargs=-1, required=True) @click.argument('symbol', required=True)
@click.pass_obj @click.pass_obj
def chart( def chart(config, symbol, profile, pdb):
config,
symbols: list[str],
profile,
pdb: bool,
):
''' '''
Start a real-time charting UI Start a real-time charting UI
@ -155,10 +144,8 @@ def chart(
_profile._pg_profile = True _profile._pg_profile = True
_profile.ms_slower_then = float(profile) _profile.ms_slower_then = float(profile)
# Qt UI entrypoint
from ._app import _main from ._app import _main
for symbol in symbols:
if '.' not in symbol: if '.' not in symbol:
click.echo(click.style( click.echo(click.style(
f'symbol: {symbol} must have a {symbol}.<provider> suffix', f'symbol: {symbol} must have a {symbol}.<provider> suffix',
@ -166,16 +153,15 @@ def chart(
)) ))
return return
# global opts # global opts
brokernames = config['brokers'] brokernames = config['brokers']
brokermods = config['brokermods']
assert brokermods
tractorloglevel = config['tractorloglevel'] tractorloglevel = config['tractorloglevel']
pikerloglevel = config['loglevel'] pikerloglevel = config['loglevel']
_main( _main(
syms=symbols, sym=symbol,
brokermods=brokermods, brokernames=brokernames,
piker_loglevel=pikerloglevel, piker_loglevel=pikerloglevel,
tractor_kwargs={ tractor_kwargs={
'debug_mode': pdb, 'debug_mode': pdb,
@ -184,6 +170,5 @@ def chart(
'enable_modules': [ 'enable_modules': [
'piker.clearing._client' 'piker.clearing._client'
], ],
'registry_addr': config.get('registry_addr'),
}, },
) )

File diff suppressed because it is too large

View File

@ -1,12 +1,13 @@
# we require a pinned dev branch to get some edge features that # we require a pinned dev branch to get some edge features that
# are often untested in tractor's CI and/or being tested by us # are often untested in tractor's CI and/or being tested by us
# first before committing as core features in tractor's base. # first before committing as core features in tractor's base.
-e git+https://github.com/goodboy/tractor.git@piker_pin#egg=tractor -e git+https://github.com/goodboy/tractor.git@master#egg=tractor
# `pyqtgraph` peeps keep breaking, fixing, improving so might as well # `pyqtgraph` peeps keep breaking, fixing, improving so might as well
# pin this to a dev branch that we have more control over especially # pin this to a dev branch that we have more control over especially
# as more graphics stuff gets hashed out. # as more graphics stuff gets hashed out.
-e git+https://github.com/pikers/pyqtgraph.git@master#egg=pyqtgraph -e git+https://github.com/pikers/pyqtgraph.git@piker_pin#egg=pyqtgraph
# our async client for ``marketstore`` (the tsdb) # our async client for ``marketstore`` (the tsdb)
-e git+https://github.com/pikers/anyio-marketstore.git@master#egg=anyio-marketstore -e git+https://github.com/pikers/anyio-marketstore.git@master#egg=anyio-marketstore
@ -17,7 +18,4 @@
# ``asyncvnc`` for sending interactions to ib-gw inside docker # ``asyncvnc`` for sending interactions to ib-gw inside docker
-e git+https://github.com/pikers/asyncvnc.git@main#egg=asyncvnc -e git+https://github.com/pikers/asyncvnc.git@vid_passthrough#egg=asyncvnc
# ``cryptofeed`` for connecting to various crypto exchanges + custom fixes
-e git+https://github.com/pikers/cryptofeed.git@date_parsing#egg=cryptofeed

View File

@ -41,24 +41,23 @@ setup(
}, },
install_requires=[ install_requires=[
'toml', 'toml',
'tomli', # fastest pure py reader
'click', 'click',
'colorlog', 'colorlog',
'attrs', 'attrs',
'pygments', 'pygments',
'colorama', # numba traceback coloring 'colorama', # numba traceback coloring
'msgspec', # performant IPC messaging and structs 'pydantic', # structured data
# async # async
'trio', 'trio',
'trio-websocket', 'trio-websocket',
'msgspec', # performant IPC messaging
'async_generator', 'async_generator',
# from github currently (see requirements.txt) # from github currently (see requirements.txt)
# 'trimeter', # not released yet.. # 'trimeter', # not released yet..
# 'tractor', # 'tractor',
# asyncvnc, # asyncvnc,
# 'cryptofeed',
# brokers # brokers
'asks==2.4.8', 'asks==2.4.8',

View File

@ -1,15 +1,10 @@
from contextlib import asynccontextmanager as acm
import os import os
import pytest import pytest
import tractor import tractor
from piker import ( import trio
# log, from piker import log, config
config, from piker.brokers import questrade
)
from piker._daemon import (
Services,
)
def pytest_addoption(parser): def pytest_addoption(parser):
@ -19,6 +14,15 @@ def pytest_addoption(parser):
help="Use a practice API account") help="Use a practice API account")
@pytest.fixture(scope='session', autouse=True)
def loglevel(request):
orig = tractor.log._default_loglevel
level = tractor.log._default_loglevel = request.config.option.loglevel
log.get_console_log(level)
yield level
tractor.log._default_loglevel = orig
@pytest.fixture(scope='session') @pytest.fixture(scope='session')
def test_config(): def test_config():
dirname = os.path.dirname dirname = os.path.dirname
@ -33,11 +37,9 @@ def test_config():
@pytest.fixture(scope='session', autouse=True) @pytest.fixture(scope='session', autouse=True)
def confdir(request, test_config): def confdir(request, test_config):
''' """If the `--confdir` flag is not passed use the
If the `--confdir` flag is not passed use the
broker config file found in that dir. broker config file found in that dir.
"""
'''
confdir = request.config.option.confdir confdir = request.config.option.confdir
if confdir is not None: if confdir is not None:
config._override_config_dir(confdir) config._override_config_dir(confdir)
@ -45,61 +47,49 @@ def confdir(request, test_config):
return confdir return confdir
# @pytest.fixture(scope='session', autouse=True) @pytest.fixture(scope='session', autouse=True)
# def travis(confdir): def travis(confdir):
# is_travis = os.environ.get('TRAVIS', False) is_travis = os.environ.get('TRAVIS', False)
# if is_travis: if is_travis:
# # this directory is cached, see .travis.yaml # this directory is cached, see .travis.yaml
# conf_file = config.get_broker_conf_path() conf_file = config.get_broker_conf_path()
# refresh_token = os.environ['QT_REFRESH_TOKEN'] refresh_token = os.environ['QT_REFRESH_TOKEN']
# def write_with_token(token): def write_with_token(token):
# # XXX don't pass the dir path here since it may be # XXX don't pass the dir path here since it may be
# # written behind the scenes in the `confdir fixture` # written behind the scenes in the `confdir fixture`
# if not os.path.isfile(conf_file): if not os.path.isfile(conf_file):
# open(conf_file, 'w').close() open(conf_file, 'w').close()
# conf, path = config.load() conf, path = config.load()
# conf.setdefault('questrade', {}).update( conf.setdefault('questrade', {}).update(
# {'refresh_token': token, {'refresh_token': token,
# 'is_practice': 'True'} 'is_practice': 'True'}
# ) )
# config.write(conf, path) config.write(conf, path)
# async def ensure_config(): async def ensure_config():
# # try to refresh current token using cached brokers config # try to refresh current token using cached brokers config
# # if it fails fail try using the refresh token provided by the # if it fails fail try using the refresh token provided by the
# # env var and if that fails stop the test run here. # env var and if that fails stop the test run here.
# try: try:
# async with questrade.get_client(ask_user=False): async with questrade.get_client(ask_user=False):
# pass pass
# except ( except (
# FileNotFoundError, ValueError, FileNotFoundError, ValueError,
# questrade.BrokerError, questrade.QuestradeError, questrade.BrokerError, questrade.QuestradeError,
# trio.MultiError, trio.MultiError,
# ): ):
# # 3 cases: # 3 cases:
# # - config doesn't have a ``refresh_token`` k/v # - config doesn't have a ``refresh_token`` k/v
# # - cache dir does not exist yet # - cache dir does not exist yet
# # - current token is expired; take it from env var # - current token is expired; take it from env var
# write_with_token(refresh_token) write_with_token(refresh_token)
# async with questrade.get_client(ask_user=False): async with questrade.get_client(ask_user=False):
# pass pass
# # XXX ``pytest_trio`` doesn't support scope or autouse # XXX ``pytest_trio`` doesn't support scope or autouse
# trio.run(ensure_config) trio.run(ensure_config)
_ci_env: bool = os.environ.get('CI', False)
@pytest.fixture(scope='session')
def ci_env() -> bool:
'''
Detect CI environment.
'''
return _ci_env
@pytest.fixture @pytest.fixture
@ -115,61 +105,3 @@ def tmx_symbols():
@pytest.fixture @pytest.fixture
def cse_symbols(): def cse_symbols():
return ['TRUL.CN', 'CWEB.CN', 'SNN.CN'] return ['TRUL.CN', 'CWEB.CN', 'SNN.CN']
@acm
async def _open_test_pikerd(
reg_addr: tuple[str, int] | None = None,
**kwargs,
) -> tuple[
str,
int,
tractor.Portal
]:
'''
Testing helper to start up the service tree and runtime on
a different port than the default to allow testing alongside
a running stack.
'''
import random
from piker._daemon import maybe_open_pikerd
if reg_addr is None:
port = random.randint(6e3, 7e3)
reg_addr = ('127.0.0.1', port)
# try:
async with (
maybe_open_pikerd(
registry_addr=reg_addr,
**kwargs,
) as service_manager,
):
# this proc/actor is the pikerd
assert service_manager is Services
async with tractor.wait_for_actor(
'pikerd',
arbiter_sockaddr=reg_addr,
) as portal:
raddr = portal.channel.raddr
assert raddr == reg_addr
yield (
raddr[0],
raddr[1],
portal,
service_manager,
)
@pytest.fixture
def open_test_pikerd():
yield _open_test_pikerd
# TODO: teardown checks such as,
# - no leaked subprocs or shm buffers
# - all requested container services are torn down
# - certain ``tractor`` runtime state?
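A hypothetical minimal consumer of this fixture (mirroring ``test_runtime_boot`` in the services tests below):
import trio
from typing import AsyncContextManager
def test_pikerd_boots(open_test_pikerd: AsyncContextManager):
    async def main():
        async with open_test_pikerd() as (
            host,
            port,
            portal,
            services,
        ):
            # the yielded addr is the registry sockaddr
            assert (host, port) == portal.channel.raddr
    trio.run(main)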

View File

@ -1,128 +0,0 @@
'''
Data feed layer APIs, performance, msg throttling.
'''
from collections import Counter
from pprint import pprint
from typing import AsyncContextManager
import pytest
# import tractor
import trio
from piker.data import (
ShmArray,
open_feed,
)
from piker.data._source import (
unpack_fqsn,
)
@pytest.mark.parametrize(
'fqsns',
[
# binance
(100, {'btcusdt.binance', 'ethusdt.binance'}, False),
# kraken
(20, {'ethusdt.kraken', 'xbtusd.kraken'}, True),
# binance + kraken
(100, {'btcusdt.binance', 'xbtusd.kraken'}, False),
],
ids=lambda param: f'quotes={param[0]}@fqsns={param[1]}',
)
def test_multi_fqsn_feed(
open_test_pikerd: AsyncContextManager,
fqsns: set[str],
ci_env: bool
):
'''
Start a real-time data feed for each provided fqsn and pull
a few quotes then simply shut down.
'''
max_quotes, fqsns, run_in_ci = fqsns
if (
ci_env
and not run_in_ci
):
pytest.skip('Skipping CI disabled test due to feed restrictions')
brokers = set()
for fqsn in fqsns:
brokername, key, suffix = unpack_fqsn(fqsn)
brokers.add(brokername)
async def main():
async with (
open_test_pikerd(),
open_feed(
fqsns,
loglevel='info',
# TODO: ensure throttle rate is applied
# limit to at least display's FPS
# avoiding needless Qt-in-guest-mode context switches
# tick_throttle=_quote_throttle_rate,
) as feed
):
# verify shm buffers exist
for fqin in fqsns:
flume = feed.flumes[fqin]
ohlcv: ShmArray = flume.rt_shm
hist_ohlcv: ShmArray = flume.hist_shm
async with feed.open_multi_stream(brokers) as stream:
# pull the first startup quotes, one for each fqsn, and
# ensure they match each flume's startup quote value.
fqsns_copy = fqsns.copy()
with trio.fail_after(0.5):
for _ in range(1):
first_quotes = await stream.receive()
for fqsn, quote in first_quotes.items():
# XXX: TODO: WTF apparently this error will get
# suppressed and only show up in the teardown
# excgroup if we don't have the fix from
# <tractorbugurl>
# assert 0
fqsns_copy.remove(fqsn)
flume = feed.flumes[fqsn]
assert quote['last'] == flume.first_quote['last']
cntr = Counter()
with trio.fail_after(6):
async for quotes in stream:
for fqsn, quote in quotes.items():
cntr[fqsn] += 1
# await tractor.breakpoint()
flume = feed.flumes[fqsn]
ohlcv: ShmArray = flume.rt_shm
hist_ohlcv: ShmArray = flume.hist_shm
# print quote msg, rt and history
# buffer values on console.
rt_row = ohlcv.array[-1]
hist_row = hist_ohlcv.array[-1]
# last = quote['last']
# assert last == rt_row['close']
# assert last == hist_row['close']
pprint(
f'{fqsn}: {quote}\n'
f'rt_ohlc: {rt_row}\n'
f'hist_ohlc: {hist_row}\n'
)
if cntr.total() >= max_quotes:
break
assert set(cntr.keys()) == fqsns
trio.run(main)
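For reference, ``unpack_fqsn()`` splits the '<key>.<broker>' style fully-qualified symbol name used above into a 3-tuple (per the unpacking at the top of this test); a quick sketch, assuming a piker install:
from piker.data._source import unpack_fqsn
brokername, key, suffix = unpack_fqsn('btcusdt.binance')
assert (brokername, key) == ('binance', 'btcusdt')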

View File

@ -8,6 +8,7 @@ from trio.testing import trio_test
from piker.brokers import questrade as qt from piker.brokers import questrade as qt
import pytest import pytest
import tractor import tractor
from tractor.testing import tractor_test
import piker import piker
from piker.brokers import get_brokermod from piker.brokers import get_brokermod
@ -22,12 +23,6 @@ pytestmark = pytest.mark.skipif(
reason="questrade tests can only be run locally with an API key", reason="questrade tests can only be run locally with an API key",
) )
# TODO: this module was removed from tractor into it's
# tests/conftest.py, we need to rewrite the below tests
# to use the `open_pikerd_runtime()` to make these work again
# (if we're not just gonna junk em).
# from tractor.testing import tractor_test
# stock quote # stock quote
_ex_quotes = { _ex_quotes = {
@ -111,7 +106,7 @@ def match_packet(symbols, quotes, feed_type='stock'):
assert not quotes assert not quotes
# @tractor_test @tractor_test
async def test_concurrent_tokens_refresh(us_symbols, loglevel): async def test_concurrent_tokens_refresh(us_symbols, loglevel):
"""Verify that concurrent requests from mulitple tasks work alongside """Verify that concurrent requests from mulitple tasks work alongside
random token refreshing which simulates an access token expiry + refresh random token refreshing which simulates an access token expiry + refresh
@ -342,7 +337,7 @@ async def stream_stocks(feed, symbols):
'options_and_options', 'options_and_options',
], ],
) )
# @tractor_test @tractor_test
async def test_quote_streaming(tmx_symbols, loglevel, stream_what): async def test_quote_streaming(tmx_symbols, loglevel, stream_what):
"""Set up option streaming using the broker daemon. """Set up option streaming using the broker daemon.
""" """

View File

@ -1,176 +0,0 @@
'''
Actor tree daemon sub-service verifications
'''
from typing import AsyncContextManager
from contextlib import asynccontextmanager as acm
import pytest
import trio
import tractor
from piker._daemon import (
find_service,
check_for_service,
Services,
)
from piker.data import (
open_feed,
)
from piker.clearing import (
open_ems,
)
from piker.clearing._messages import (
BrokerdPosition,
Status,
)
from piker.clearing._client import (
OrderBook,
)
def test_runtime_boot(
open_test_pikerd: AsyncContextManager
):
'''
Verify we can boot the `pikerd` service stack using the
`open_test_pikerd` fixture helper and that registry address details
match up.
'''
async def main():
port = 6666
daemon_addr = ('127.0.0.1', port)
services: Services
async with (
open_test_pikerd(
reg_addr=daemon_addr,
) as (_, _, pikerd_portal, services),
tractor.wait_for_actor(
'pikerd',
arbiter_sockaddr=daemon_addr,
) as portal,
):
assert pikerd_portal.channel.raddr == daemon_addr
assert pikerd_portal.channel.raddr == portal.channel.raddr
trio.run(main)
@acm
async def ensure_service(
name: str,
sockaddr: tuple[str, int] | None = None,
) -> None:
async with find_service(name) as portal:
remote_sockaddr = portal.channel.raddr
print(f'FOUND `{name}` @ {remote_sockaddr}')
if sockaddr:
assert remote_sockaddr == sockaddr
yield portal
def test_ensure_datafeed_actors(
open_test_pikerd: AsyncContextManager
) -> None:
'''
Verify that booting a data feed starts a `brokerd`
actor and a singleton global `samplerd` and opening
an order mode in paper opens the `paperboi` service.
'''
actor_name: str = 'brokerd'
backend: str = 'kraken'
brokerd_name: str = f'{actor_name}.{backend}'
async def main():
async with (
open_test_pikerd(),
open_feed(
['xbtusdt.kraken'],
loglevel='info',
) as feed
):
# halt rt quote streams since we aren't testing them
await feed.pause()
async with (
ensure_service(brokerd_name),
ensure_service('samplerd'),
):
pass
trio.run(main)
def test_ensure_ems_in_paper_actors(
open_test_pikerd: AsyncContextManager
) -> None:
actor_name: str = 'brokerd'
backend: str = 'kraken'
brokerd_name: str = f'{actor_name}.{backend}'
async def main():
# type declares
book: OrderBook
trades_stream: tractor.MsgStream
pps: dict[str, list[BrokerdPosition]]
accounts: list[str]
dialogs: dict[str, Status]
# ensure we time out if startup is too slow.
# TODO: something like this should be our start point for
# benchmarking end-to-end startup B)
with trio.fail_after(9):
async with (
open_test_pikerd() as (_, _, _, services),
open_ems(
'xbtusdt.kraken',
mode='paper',
) as (
book,
trades_stream,
pps,
accounts,
dialogs,
),
):
# there should be no on-going positions,
# TODO: though eventually we'll want to validate against
# local ledger and `pps.toml` state ;)
assert not pps
assert not dialogs
pikerd_subservices = ['emsd', 'samplerd']
async with (
ensure_service('emsd'),
ensure_service(brokerd_name),
ensure_service(f'paperboi.{backend}'),
):
for name in pikerd_subservices:
assert name in services.service_tasks
# brokerd.kraken actor should have been started
# implicitly by the ems.
assert brokerd_name in services.service_tasks
print('ALL SERVICES STARTED, terminating..')
await services.cancel_service('emsd')
with pytest.raises(
tractor._exceptions.ContextCancelled,
) as exc_info:
trio.run(main)
cancel_msg: str = '`_emsd_main()` was remotely cancelled by its caller'
assert cancel_msg in exc_info.value.args[0]