hist_backfill_fixes: solving conc issues in the tsdb backfiller #62

Open
goodboy wants to merge 12 commits from hist_backfill_fixes into main
10 changed files with 2236 additions and 1571 deletions


@@ -1187,7 +1187,7 @@ async def load_aio_clients(
     # the API TCP in `ib_insync` connection can be flaky af so instead
     # retry a few times to get the client going..
     connect_retries: int = 3,
-    connect_timeout: float = 10,
+    connect_timeout: float = 30,  # in case a remote-host
     disconnect_on_exit: bool = True,
 ) -> dict[str, Client]:


@@ -43,7 +43,6 @@ from typing import (
 import numpy as np
 from .. import config
 from ..service import (
     check_for_service,
@@ -152,7 +151,10 @@ class StorageConnectionError(ConnectionError):
     '''
-def get_storagemod(name: str) -> ModuleType:
+def get_storagemod(
+    name: str,
+) -> ModuleType:
     mod: ModuleType = import_module(
         '.' + name,
         'piker.storage',
@@ -165,9 +167,12 @@ def get_storagemod(name: str) -> ModuleType:
 @acm
 async def open_storage_client(
-    backend: str | None = None,
-) -> tuple[ModuleType, StorageClient]:
+    backend: str|None = None,
+) -> tuple[
+    ModuleType,
+    StorageClient,
+]:
     '''
     Load the ``StorageClient`` for named backend.
@@ -267,7 +272,10 @@ async def open_tsdb_client(
     from ..data.feed import maybe_open_feed
     async with (
-        open_storage_client() as (_, storage),
+        open_storage_client() as (
+            _,
+            storage,
+        ),
         maybe_open_feed(
             [fqme],
@@ -275,7 +283,7 @@ async def open_tsdb_client(
         ) as feed,
     ):
-        profiler(f'opened feed for {fqme}')
+        profiler(f'opened feed for {fqme!r}')
         # to_append = feed.hist_shm.array
         # to_prepend = None


@@ -19,16 +19,10 @@ Storage middle-ware CLIs.
 """
 from __future__ import annotations
-# from datetime import datetime
-# from contextlib import (
-#     AsyncExitStack,
-# )
 from pathlib import Path
-from math import copysign
 import time
 from types import ModuleType
 from typing import (
-    Any,
     TYPE_CHECKING,
 )
@@ -47,7 +41,6 @@ from piker.data import (
     ShmArray,
 )
 from piker import tsp
-from piker.data._formatters import BGM
 from . import log
 from . import (
     __tsdbs__,
@@ -242,122 +235,12 @@ def anal(
     trio.run(main)
async def markup_gaps(
fqme: str,
timeframe: float,
actl: AnnotCtl,
wdts: pl.DataFrame,
gaps: pl.DataFrame,
) -> dict[int, dict]:
'''
Remote annotate time-gaps in a dt-fielded ts (normally OHLC)
with rectangles.
'''
aids: dict[int] = {}
for i in range(gaps.height):
row: pl.DataFrame = gaps[i]
# the gap's RIGHT-most bar's OPEN value
# at that time (sample) step.
iend: int = row['index'][0]
# dt: datetime = row['dt'][0]
# dt_prev: datetime = row['dt_prev'][0]
# dt_end_t: float = dt.timestamp()
# TODO: can we eventually remove this
# once we figure out why the epoch cols
# don't match?
# TODO: FIX HOW/WHY these aren't matching
# and are instead off by 4hours (EST
# vs. UTC?!?!)
# end_t: float = row['time']
# assert (
# dt.timestamp()
# ==
# end_t
# )
# the gap's LEFT-most bar's CLOSE value
# at that time (sample) step.
prev_r: pl.DataFrame = wdts.filter(
pl.col('index') == iend - 1
)
# XXX: probably a gap in the (newly sorted or de-duplicated)
# dt-df, so we might need to re-index first..
if prev_r.is_empty():
await tractor.pause()
istart: int = prev_r['index'][0]
# dt_start_t: float = dt_prev.timestamp()
# start_t: float = prev_r['time']
# assert (
# dt_start_t
# ==
# start_t
# )
# TODO: implement px-col width measure
# and ensure at least as many px-cols
# shown per rect as configured by user.
# gap_w: float = abs((iend - istart))
# if gap_w < 6:
# margin: float = 6
# iend += margin
# istart -= margin
rect_gap: float = BGM*3/8
opn: float = row['open'][0]
ro: tuple[float, float] = (
# dt_end_t,
iend + rect_gap + 1,
opn,
)
cls: float = prev_r['close'][0]
lc: tuple[float, float] = (
# dt_start_t,
istart - rect_gap, # + 1 ,
cls,
)
color: str = 'dad_blue'
diff: float = cls - opn
sgn: float = copysign(1, diff)
color: str = {
-1: 'buy_green',
1: 'sell_red',
}[sgn]
rect_kwargs: dict[str, Any] = dict(
fqme=fqme,
timeframe=timeframe,
start_pos=lc,
end_pos=ro,
color=color,
)
aid: int = await actl.add_rect(**rect_kwargs)
assert aid
aids[aid] = rect_kwargs
# tell chart to redraw all its
# graphics view layers Bo
await actl.redraw(
fqme=fqme,
timeframe=timeframe,
)
return aids
 @store.command()
 def ldshm(
     fqme: str,
     write_parquet: bool = True,
     reload_parquet_to_shm: bool = True,
+    pdb: bool = False,  # --pdb passed?
 ) -> None:
     '''
@@ -377,7 +260,7 @@ def ldshm(
         open_piker_runtime(
             'polars_boi',
             enable_modules=['piker.data._sharedmem'],
-            debug_mode=True,
+            debug_mode=pdb,
         ),
         open_storage_client() as (
             mod,
@@ -397,6 +280,9 @@ def ldshm(
         times: np.ndarray = shm.array['time']
         d1: float = float(times[-1] - times[-2])
+        d2: float = 0
+        # XXX, take a median sample rate if sufficient data
+        if times.size > 2:
-        d2: float = float(times[-2] - times[-3])
+            d2: float = float(times[-2] - times[-3])
         med: float = np.median(np.diff(times))
         if (
@ -407,7 +293,6 @@ def ldshm(
raise ValueError( raise ValueError(
f'Something is wrong with time period for {shm}:\n{times}' f'Something is wrong with time period for {shm}:\n{times}'
) )
period_s: float = float(max(d1, d2, med)) period_s: float = float(max(d1, d2, med))
null_segs: tuple = tsp.get_null_segs( null_segs: tuple = tsp.get_null_segs(
@@ -417,6 +302,8 @@ def ldshm(
         # TODO: call null-seg fixer somehow?
         if null_segs:
+            if tractor._state.is_debug_mode():
-            await tractor.pause()
+                await tractor.pause()
         # async with (
         #     trio.open_nursery() as tn,
@@ -441,9 +328,35 @@ def ldshm(
             wdts,
             deduped,
             diff,
-        ) = tsp.dedupe(
+            valid_races,
+            dq_issues,
+        ) = tsp.dedupe_ohlcv_smart(
             shm_df,
-            period=period_s,
         )
+        # Report duplicate analysis
+        if diff > 0:
+            log.info(
+                f'Removed {diff} duplicate timestamp(s)\n'
+            )
+        if valid_races is not None:
+            identical: int = (
+                valid_races
+                .filter(pl.col('identical_bars'))
+                .height
+            )
+            monotonic: int = valid_races.height - identical
+            log.info(
+                f'Valid race conditions: {valid_races.height}\n'
+                f' - Identical bars: {identical}\n'
+                f' - Volume monotonic: {monotonic}\n'
+            )
+        if dq_issues is not None:
+            log.warning(
+                f'DATA QUALITY ISSUES from provider: '
+                f'{dq_issues.height} timestamp(s)\n'
+                f'{dq_issues}\n'
+            )
         # detect gaps from in expected (uniform OHLC) sample period
@@ -460,7 +373,8 @@ def ldshm(
             # TODO: actually pull the exact duration
             # expected for each venue operational period?
-            gap_dt_unit='days',
+            # gap_dt_unit='day',
+            gap_dt_unit='day',
             gap_thresh=1,
         )
@@ -471,8 +385,11 @@ def ldshm(
         if (
             not venue_gaps.is_empty()
             or (
-                period_s < 60
-                and not step_gaps.is_empty()
+                not step_gaps.is_empty()
+                # XXX, i presume i put this bc i was guarding
+                # for ib venue gaps?
+                # and
+                # period_s < 60
             )
         ):
             # write repaired ts to parquet-file?
@@ -521,7 +438,7 @@ def ldshm(
                 do_markup_gaps: bool = True
                 if do_markup_gaps:
                     new_df: pl.DataFrame = tsp.np2pl(new)
-                    aids: dict = await markup_gaps(
+                    aids: dict = await tsp._annotate.markup_gaps(
                         fqme,
                         period_s,
                         actl,
@@ -534,8 +451,13 @@ def ldshm(
                     tf2aids[period_s] = aids
             else:
-                # allow interaction even when no ts problems.
-                assert not diff
+                # No significant gaps to handle, but may have had
+                # duplicates removed (valid race conditions are ok)
+                if diff > 0 and dq_issues is not None:
+                    log.warning(
+                        'Found duplicates with data quality issues '
+                        'but no significant time gaps!\n'
+                    )
                 await tractor.pause()
     log.info('Exiting TSP shm anal-izer!')

File diff suppressed because it is too large


@@ -578,11 +578,22 @@ def detect_time_gaps(
     # NOTE: this flag is to indicate that on this (sampling) time
     # scale we expect to only be filtering against larger venue
     # closures-scale time gaps.
+    #
+    # Map to total_ method since `dt_diff` is a duration type,
+    # not datetime - modern polars requires `total_*` methods
+    # for duration types (e.g. `total_days()` not `day()`)
+
+    # Ensure plural form for polars API (e.g. 'day' -> 'days')
+    unit_plural: str = (
+        gap_dt_unit
+        if gap_dt_unit.endswith('s')
+        else f'{gap_dt_unit}s'
+    )
+    duration_method: str = f'total_{unit_plural}'
     return step_gaps.filter(
         # Second by an arbitrary dt-unit step size
         getattr(
             pl.col('dt_diff').dt,
-            gap_dt_unit,
+            duration_method,
         )().abs() > gap_thresh
     )
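
For reference, a minimal standalone sketch (assuming a recent polars release) of why the duration accessor needs the `total_*` spelling that the new `duration_method` string builds; the frame and column names here are illustrative only:

import polars as pl
from datetime import datetime

# two timestamps a little over 2 days apart
df = pl.DataFrame({
    'dt': [datetime(2024, 1, 1), datetime(2024, 1, 3, 6)],
})
# .diff() on a datetime col yields a Duration dtype
gaps = df.with_columns(dt_diff=pl.col('dt').diff())

# Duration exposes `.dt.total_days()`, `.dt.total_hours()`, etc;
# the bare `.dt.day()` accessor only exists for Date/Datetime cols.
print(gaps.select(pl.col('dt_diff').dt.total_days()))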


@@ -0,0 +1,166 @@
# piker: trading gear for hackers
# Copyright (C) 2018-present Tyler Goodlet (in stewardship of pikers)
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""
Time-series (remote) annotation APIs.
"""
from __future__ import annotations
from math import copysign
from typing import (
Any,
TYPE_CHECKING,
)
import polars as pl
import tractor
from piker.data._formatters import BGM
from piker.storage import log
if TYPE_CHECKING:
from piker.ui._remote_ctl import AnnotCtl
async def markup_gaps(
fqme: str,
timeframe: float,
actl: AnnotCtl,
wdts: pl.DataFrame,
gaps: pl.DataFrame,
) -> dict[int, dict]:
'''
Remote annotate time-gaps in a dt-fielded ts (normally OHLC)
with rectangles.
'''
aids: dict[int] = {}
for i in range(gaps.height):
row: pl.DataFrame = gaps[i]
# the gap's RIGHT-most bar's OPEN value
# at that time (sample) step.
iend: int = row['index'][0]
# dt: datetime = row['dt'][0]
# dt_prev: datetime = row['dt_prev'][0]
# dt_end_t: float = dt.timestamp()
# TODO: can we eventually remove this
# once we figure out why the epoch cols
# don't match?
# TODO: FIX HOW/WHY these aren't matching
# and are instead off by 4hours (EST
# vs. UTC?!?!)
# end_t: float = row['time']
# assert (
# dt.timestamp()
# ==
# end_t
# )
# the gap's LEFT-most bar's CLOSE value
# at that time (sample) step.
prev_r: pl.DataFrame = wdts.filter(
pl.col('index') == iend - 1
)
# XXX: probably a gap in the (newly sorted or de-duplicated)
# dt-df, so we might need to re-index first..
dt: pl.Series = row['dt']
dt_prev: pl.Series = row['dt_prev']
if prev_r.is_empty():
# XXX, filter out any special ignore cases,
# - UNIX-epoch stamped datums
# - first row
if (
dt_prev.dt.epoch()[0] == 0
or
dt.dt.epoch()[0] == 0
):
log.warning('Skipping row with UNIX epoch timestamp ??')
continue
if wdts[0]['index'][0] == iend: # first row
log.warning('Skipping first-row (has no previous obvi) !!')
continue
# XXX, if the previous-row by shm-index is missing,
# meaning there is a missing sample (set), get the prior
# row by df index and attempt to use it?
i_wdts: pl.DataFrame = wdts.with_row_index(name='i')
i_row: int = i_wdts.filter(pl.col('index') == iend)['i'][0]
prev_row_by_i = wdts[i_row]
prev_r: pl.DataFrame = prev_row_by_i
# debug any missing pre-row
if tractor._state.is_debug_mode():
await tractor.pause()
istart: int = prev_r['index'][0]
# TODO: implement px-col width measure
# and ensure at least as many px-cols
# shown per rect as configured by user.
# gap_w: float = abs((iend - istart))
# if gap_w < 6:
# margin: float = 6
# iend += margin
# istart -= margin
rect_gap: float = BGM*3/8
opn: float = row['open'][0]
ro: tuple[float, float] = (
# dt_end_t,
iend + rect_gap + 1,
opn,
)
cls: float = prev_r['close'][0]
lc: tuple[float, float] = (
# dt_start_t,
istart - rect_gap, # + 1 ,
cls,
)
color: str = 'dad_blue'
diff: float = cls - opn
sgn: float = copysign(1, diff)
color: str = {
-1: 'buy_green',
1: 'sell_red',
}[sgn]
rect_kwargs: dict[str, Any] = dict(
fqme=fqme,
timeframe=timeframe,
start_pos=lc,
end_pos=ro,
color=color,
)
aid: int = await actl.add_rect(**rect_kwargs)
assert aid
aids[aid] = rect_kwargs
# tell chart to redraw all its
# graphics view layers Bo
await actl.redraw(
fqme=fqme,
timeframe=timeframe,
)
return aids


@@ -0,0 +1,206 @@
'''
Smart OHLCV deduplication with data quality validation.
Handles concurrent write conflicts by keeping the most complete bar
(highest volume) while detecting data quality anomalies.
'''
import polars as pl
from ._anal import with_dts
def dedupe_ohlcv_smart(
src_df: pl.DataFrame,
time_col: str = 'time',
volume_col: str = 'volume',
sort: bool = True,
) -> tuple[
pl.DataFrame, # with dts
pl.DataFrame, # deduped (keeping higher volume bars)
int, # count of dupes removed
pl.DataFrame|None, # valid race conditions
pl.DataFrame|None, # data quality violations
]:
'''
Smart OHLCV deduplication keeping most complete bars.
For duplicate timestamps, keeps bar with highest volume under
the assumption that higher volume indicates more complete/final
data from backfill vs partial live updates.
Returns
-------
Tuple of:
- wdts: original dataframe with datetime columns added
- deduped: deduplicated frame keeping highest-volume bars
- diff: number of duplicate rows removed
- valid_races: duplicates meeting expected race condition pattern
(volume monotonic, OHLC ranges valid)
- data_quality_issues: duplicates violating expected relationships
indicating provider data problems
'''
wdts: pl.DataFrame = with_dts(src_df)
# Find duplicate timestamps
dupes: pl.DataFrame = wdts.filter(
pl.col(time_col).is_duplicated()
)
if dupes.is_empty():
# No duplicates, return as-is
return (wdts, wdts, 0, None, None)
# Analyze duplicate groups for validation
dupe_analysis: pl.DataFrame = (
dupes
.sort([time_col, 'index'])
.group_by(time_col, maintain_order=True)
.agg([
pl.col('index').alias('indices'),
pl.col('volume').alias('volumes'),
pl.col('high').alias('highs'),
pl.col('low').alias('lows'),
pl.col('open').alias('opens'),
pl.col('close').alias('closes'),
pl.col('dt').first().alias('dt'),
pl.len().alias('count'),
])
)
# Validate OHLCV monotonicity for each duplicate group
def check_ohlcv_validity(row) -> dict[str, bool]:
'''
Check if duplicate bars follow expected race condition pattern.
For a valid live-update backfill race:
- volume should be monotonically increasing
- high should be monotonically non-decreasing
- low should be monotonically non-increasing
- open should be identical (fixed at bar start)
Returns dict of violation flags.
'''
vols: list = row['volumes']
highs: list = row['highs']
lows: list = row['lows']
opens: list = row['opens']
violations: dict[str, bool] = {
'volume_non_monotonic': False,
'high_decreased': False,
'low_increased': False,
'open_mismatch': False,
'identical_bars': False,
}
# Check if all bars are identical (pure duplicate)
if (
len(set(vols)) == 1
and len(set(highs)) == 1
and len(set(lows)) == 1
and len(set(opens)) == 1
):
violations['identical_bars'] = True
return violations
# Check volume monotonicity
for i in range(1, len(vols)):
if vols[i] < vols[i-1]:
violations['volume_non_monotonic'] = True
break
# Check high monotonicity (can only increase or stay same)
for i in range(1, len(highs)):
if highs[i] < highs[i-1]:
violations['high_decreased'] = True
break
# Check low monotonicity (can only decrease or stay same)
for i in range(1, len(lows)):
if lows[i] > lows[i-1]:
violations['low_increased'] = True
break
# Check open consistency (should be fixed)
if len(set(opens)) > 1:
violations['open_mismatch'] = True
return violations
# Apply validation
dupe_analysis = dupe_analysis.with_columns([
pl.struct(['volumes', 'highs', 'lows', 'opens'])
.map_elements(
check_ohlcv_validity,
return_dtype=pl.Struct([
pl.Field('volume_non_monotonic', pl.Boolean),
pl.Field('high_decreased', pl.Boolean),
pl.Field('low_increased', pl.Boolean),
pl.Field('open_mismatch', pl.Boolean),
pl.Field('identical_bars', pl.Boolean),
])
)
.alias('validity')
])
# Unnest validity struct
dupe_analysis = dupe_analysis.unnest('validity')
# Separate valid races from data quality issues
valid_races: pl.DataFrame|None = (
dupe_analysis
.filter(
# Valid if no violations OR just identical bars
~pl.col('volume_non_monotonic')
& ~pl.col('high_decreased')
& ~pl.col('low_increased')
& ~pl.col('open_mismatch')
)
)
if valid_races.is_empty():
valid_races = None
data_quality_issues: pl.DataFrame|None = (
dupe_analysis
.filter(
# Issues if any non-identical violation exists
(
pl.col('volume_non_monotonic')
| pl.col('high_decreased')
| pl.col('low_increased')
| pl.col('open_mismatch')
)
& ~pl.col('identical_bars')
)
)
if data_quality_issues.is_empty():
data_quality_issues = None
# Deduplicate: keep highest volume bar for each timestamp
deduped: pl.DataFrame = (
wdts
.sort([time_col, volume_col])
.unique(
subset=[time_col],
keep='last',
maintain_order=False,
)
)
# Re-sort by time or index
if sort:
deduped = deduped.sort(by=time_col)
diff: int = wdts.height - deduped.height
return (
wdts,
deduped,
diff,
valid_races,
data_quality_issues,
)
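
A rough usage sketch of the new helper as `ldshm` now calls it (via `tsp.dedupe_ohlcv_smart()`); the literal bars below are made up, and it assumes `with_dts()` only needs the epoch `time` column plus an `index` column:

import polars as pl
from piker import tsp

# two bars share the t=1060 stamp: a partial live update (vol=5)
# racing a later, more complete backfill write (vol=9)
src = pl.DataFrame({
    'index':  [100, 101, 101],
    'time':   [1000.0, 1060.0, 1060.0],
    'open':   [10.0, 11.0, 11.0],
    'high':   [10.5, 11.2, 11.6],
    'low':    [9.8, 10.9, 10.7],
    'close':  [10.2, 11.1, 11.4],
    'volume': [3.0, 5.0, 9.0],
})

wdts, deduped, diff, valid_races, dq_issues = tsp.dedupe_ohlcv_smart(src)

# expected per the helper's docstring: one dupe dropped, the race is
# "valid" (volume monotonic, OHLC ranges consistent, same open), and
# the surviving t=1060 bar is the higher-volume (more complete) one
assert diff == 1
assert dq_issues is None
assert deduped.filter(pl.col('time') == 1060.0)['volume'][0] == 9.0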

piker/tsp/_history.py (1506 lines, mode 100644)

File diff suppressed because it is too large


@@ -237,8 +237,8 @@ class LevelLabel(YAxisLabel):
 class L1Label(LevelLabel):
     text_flags = (
-        QtCore.Qt.TextDontClip
-        | QtCore.Qt.AlignLeft
+        QtCore.Qt.TextFlag.TextDontClip
+        | QtCore.Qt.AlignmentFlag.AlignLeft
     )

     def set_label_str(


@@ -0,0 +1,256 @@
#!/usr/bin/env python
'''
Programmatic debugging helper for `pdbp` REPL human-like
interaction but built to allow `claude` to interact with
crashes and `tractor.pause()` breakpoints alongside a human dev.
Originally written by `clauded` during a backfiller inspection
session with @goodboy trying to resolve duplicate/gappy ohlcv ts
issues discovered while testing the new `nativedb` tsdb.
Allows `claude` to run `pdb` commands and capture output in an "offline"
manner while generating output similar to what interacting with
the debug REPL directly would produce.
The use of `pexpect` is heavily based on tractor's REPL UX test
suite(s), namely various `tests/devx/test_debugger.py` patterns.
'''
import sys
import os
import time
import pexpect
from pexpect.exceptions import (
TIMEOUT,
EOF,
)
PROMPT: str = r'\(Pdb\+\)'
def expect(
child: pexpect.spawn,
patt: str,
**kwargs,
) -> None:
'''
Expect wrapper that prints last console data before failing.
'''
try:
child.expect(
patt,
**kwargs,
)
except TIMEOUT:
before: str = (
str(child.before.decode())
if isinstance(child.before, bytes)
else str(child.before)
)
print(
f'TIMEOUT waiting for pattern: {patt}\n'
f'Last seen output:\n{before}'
)
raise
def run_pdb_commands(
commands: list[str],
initial_cmd: str = 'piker store ldshm xmrusdt.usdtm.perp.binance',
timeout: int = 30,
print_output: bool = True,
) -> dict[str, str]:
'''
Spawn piker process, wait for pdb prompt, execute commands.
Returns dict mapping command -> output.
'''
results: dict[str, str] = {}
# Disable colored output for easier parsing
os.environ['PYTHON_COLORS'] = '0'
# Spawn the process
if print_output:
print(f'Spawning: {initial_cmd}')
child: pexpect.spawn = pexpect.spawn(
initial_cmd,
timeout=timeout,
encoding='utf-8',
echo=False,
)
# Wait for pdb prompt
try:
expect(child, PROMPT, timeout=timeout)
if print_output:
print('Reached pdb prompt!')
# Execute each command
for cmd in commands:
if print_output:
print(f'\n>>> {cmd}')
child.sendline(cmd)
time.sleep(0.1)
# Wait for next prompt
expect(child, PROMPT, timeout=timeout)
# Capture output (everything before the prompt)
output: str = (
str(child.before.decode())
if isinstance(child.before, bytes)
else str(child.before)
)
results[cmd] = output
if print_output:
print(output)
# Quit debugger gracefully
child.sendline('quit')
try:
child.expect(EOF, timeout=5)
except (TIMEOUT, EOF):
pass
except TIMEOUT as e:
print(f'Timeout: {e}')
if child.before:
before: str = (
str(child.before.decode())
if isinstance(child.before, bytes)
else str(child.before)
)
print(f'Buffer:\n{before}')
results['_error'] = str(e)
finally:
if child.isalive():
child.close(force=True)
return results
class InteractivePdbSession:
'''
Interactive pdb session manager for incremental debugging.
'''
def __init__(
self,
cmd: str = 'piker store ldshm xmrusdt.usdtm.perp.binance',
timeout: int = 30,
):
self.cmd: str = cmd
self.timeout: int = timeout
self.child: pexpect.spawn|None = None
self.history: list[tuple[str, str]] = []
def start(self) -> None:
'''
Start the piker process and wait for first prompt.
'''
os.environ['PYTHON_COLORS'] = '0'
print(f'Starting: {self.cmd}')
self.child = pexpect.spawn(
self.cmd,
timeout=self.timeout,
encoding='utf-8',
echo=False,
)
# Wait for initial prompt
expect(self.child, PROMPT, timeout=self.timeout)
print('Ready at pdb prompt!')
def run(
self,
cmd: str,
print_output: bool = True,
) -> str:
'''
Execute a single pdb command and return output.
'''
if not self.child or not self.child.isalive():
raise RuntimeError('Session not started or dead')
if print_output:
print(f'\n>>> {cmd}')
self.child.sendline(cmd)
time.sleep(0.1)
# Wait for next prompt
expect(self.child, PROMPT, timeout=self.timeout)
output: str = (
str(self.child.before.decode())
if isinstance(self.child.before, bytes)
else str(self.child.before)
)
self.history.append((cmd, output))
if print_output:
print(output)
return output
def quit(self) -> None:
'''
Exit the debugger and cleanup.
'''
if self.child and self.child.isalive():
self.child.sendline('quit')
try:
self.child.expect(EOF, timeout=5)
except (TIMEOUT, EOF):
pass
self.child.close(force=True)
def __enter__(self):
self.start()
return self
def __exit__(self, *args):
self.quit()
if __name__ == '__main__':
# Example inspection commands
inspect_cmds: list[str] = [
'locals().keys()',
'type(deduped)',
'deduped.shape',
(
'step_gaps.shape '
'if "step_gaps" in locals() '
'else "N/A"'
),
(
'venue_gaps.shape '
'if "venue_gaps" in locals() '
'else "N/A"'
),
]
# Allow commands from CLI args
if len(sys.argv) > 1:
inspect_cmds = sys.argv[1:]
# Interactive session example
with InteractivePdbSession() as session:
for cmd in inspect_cmds:
session.run(cmd)
print('\n=== Session Complete ===')
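
A hypothetical driver snippet for the helper above; the script's filename is not part of this diff, so the `debug_helper` import path is assumed, while the inspected locals (`wdts`, `deduped`, `diff`) are the ones available at the `ldshm` pause point per this PR:

# `debug_helper` is an assumed module name for the script above
from debug_helper import run_pdb_commands

# one-shot: spawn `piker store ldshm ...`, run a few inspection
# commands at the first `(Pdb+)` prompt, then quit
results: dict[str, str] = run_pdb_commands(
    commands=[
        'wdts.height',
        'deduped.height',
        'diff',
    ],
)
for cmd, output in results.items():
    print(f'>>> {cmd}\n{output}')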