skynet/skynet/dgpu/tui.py

212 lines
6.5 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

import json
import logging
import warnings
import trio
import urwid
from skynet.config import DgpuConfig as Config
class WorkerMonitor:
    """
    Urwid terminal UI made of three parts:

    - a header bar with two text fields (worker name on the left, balance
      on the right),
    - a scrollable body listing requests and, below each request, the
      workers attached to it,
    - a footer with a status label and a progress bar.

    The UI is wired to a `urwid.TrioEventLoop`; call `run()` from Trio code
    to start it.
    """

    def __init__(self):
        self.requests = []
        self.header_info = {}

        self.palette = [
            ('headerbar', 'white', 'dark blue'),
            ('request_row', 'white', 'dark gray'),
            ('worker_row', 'light gray', 'black'),
            ('progress_normal', 'black', 'light gray'),
            ('progress_complete', 'black', 'dark green'),
            ('body', 'white', 'black'),
        ]

        # --- Top bar (header) ---
        # header_info is empty at this point, so both lookups yield their
        # fallback text; set_header_text() replaces it later.
        worker_name = self.header_info.get('left', "unknown")
        balance = self.header_info.get('right', "balance: unknown")

        self.worker_name_widget = urwid.Text(worker_name)
        self.balance_widget = urwid.Text(balance, align='right')

        header = urwid.Columns([self.worker_name_widget, self.balance_widget])
        header_attr = urwid.AttrMap(header, 'headerbar')

        # --- Body (List of requests) ---
        self.body_listbox = self._create_listbox_body(self.requests)

        # --- Bottom bar (progress) ---
        self.status_text = urwid.Text("Request: none", align='left')
        self.progress_bar = urwid.ProgressBar(
            'progress_normal',
            'progress_complete',
            current=0,
            done=100
        )

        footer_cols = urwid.Columns([
            ('fixed', 20, self.status_text),
            self.progress_bar,
        ])

        # Build the main frame
        frame = urwid.Frame(
            self.body_listbox,
            header=header_attr,
            footer=footer_cols
        )

        # Set up the main loop with Trio
        self.event_loop = urwid.TrioEventLoop()
        self.main_loop = urwid.MainLoop(
            frame,
            palette=self.palette,
            event_loop=self.event_loop,
            unhandled_input=self._exit_on_q
        )

    def _create_listbox_body(self, requests):
        """
        Build a ListBox (vertical list) of requests & workers using
        SimpleFocusListWalker.
        """
        widgets = self._build_request_widgets(requests)
        walker = urwid.SimpleFocusListWalker(widgets)
        return urwid.ListBox(walker)

    def _build_request_widgets(self, requests):
        """
        Build a list of Urwid widgets (one row per request + per worker).

        Each request is a mapping with the keys 'id', 'model', 'user',
        'reward' and 'workers'; 'prompt' is optional and is shown as
        'UPSCALE' when absent.
        """
        row_widgets = []

        for req in requests:
            # Requests without a prompt are labelled 'UPSCALE'
            prompt = req.get('prompt', 'UPSCALE')

            # Build a columns widget for the request row
            columns = urwid.Columns([
                ('fixed', 5, urwid.Text(f"#{req['id']}")),  # e.g. "#12"
                ('weight', 3, urwid.Text(req['model'])),
                ('weight', 3, urwid.Text(prompt)),
                ('fixed', 13, urwid.Text(req['user'])),
                ('fixed', 13, urwid.Text(req['reward'])),
            ], dividechars=1)

            # Wrap the columns with an attribute map for coloring
            request_row = urwid.AttrMap(columns, 'request_row')
            row_widgets.append(request_row)

            # Then add each worker in its own line below
            for w in req["workers"]:
                worker_line = urwid.Text(f" {w}")
                worker_row = urwid.AttrMap(worker_line, 'worker_row')
                row_widgets.append(worker_row)

            # Blank line after each request
            row_widgets.append(urwid.Text(""))

        return row_widgets

    def _exit_on_q(self, key):
        """Exit the TUI on 'q' or 'Q'."""
        if key in ('q', 'Q'):
            raise urwid.ExitMainLoop()

    async def run(self):
        """
        Run the TUI in an async context (Trio).
        This method blocks until the user quits (pressing q/Q).

        Raises `urwid.ExitMainLoop` once the event loop has returned.
        """
        with self.main_loop.start():
            await self.event_loop.run_async()

        # NOTE(review): raised unconditionally after the loop ends --
        # presumably so the caller's task group sees the exit; confirm
        # before changing.
        raise urwid.ExitMainLoop()

    # -------------------------------------------------------------------------
    # Public Methods to Update Various Parts of the UI
    # -------------------------------------------------------------------------

    def set_status(self, status: str):
        """Replace the text of the footer status label."""
        self.status_text.set_text(status)

    def set_progress(self, current, done=None):
        """
        Update the bottom progress bar.

        - `current`: new current progress value (int).
        - `done`: max progress value (int). If None, we don't change it.
        """
        if done is not None:
            self.progress_bar.done = done

        self.progress_bar.current = current

    def update_requests(self, new_requests):
        """
        Replace the data in the existing ListBox with new request widgets.
        """
        new_widgets = self._build_request_widgets(new_requests)
        # Slice-assign so the existing list walker is updated in place
        self.body_listbox.body[:] = new_widgets

    def set_header_text(self, new_worker_name=None, new_balance=None):
        """
        Update the text in the header bar for worker name and/or balance.
        An argument left as None keeps the current text of that field.
        """
        if new_worker_name is not None:
            self.worker_name_widget.set_text(new_worker_name)
        if new_balance is not None:
            self.balance_widget.set_text(new_balance)

    def network_update(self, snapshot: dict):
        """
        Rebuild the request list from a network snapshot.

        For every entry `r` of `snapshot['queue']` a display dict is built
        from, in increasing precedence: the entry's own keys, the 'params'
        object of its JSON-encoded 'body', and a 'workers' list holding the
        'worker' field of each item in `snapshot['requests'][r['id']]`.
        """
        queue = [
            {
                **r,
                **(json.loads(r['body'])['params']),
                'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
            }
            for r in snapshot['queue']
        ]
        self.update_requests(queue)
def setup_logging_for_tui(config: Config):
    """
    Send root-logger output to `config.log_file` and detach console
    stream handlers (presumably so log lines do not draw over the TUI).

    Reads `config.log_level` (a level name; unknown names fall back to
    WARNING) and `config.log_file` (path handed to `logging.FileHandler`).
    Also silences all Python warnings.
    """
    warnings.filterwarnings("ignore")

    level = getattr(logging, config.log_level.upper(), logging.WARNING)

    logger = logging.getLogger()
    logger.setLevel(level)

    fh = logging.FileHandler(config.log_file)
    fh.setLevel(level)

    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
    fh.setFormatter(formatter)

    logger.addHandler(fh)

    # Iterate over a copy: removeHandler() mutates logger.handlers, and
    # removing while iterating the live list skips the following handler.
    for handler in list(logger.handlers):
        # FileHandler subclasses StreamHandler, so it must be excluded
        # explicitly or the file handler added above would be removed too.
        if isinstance(handler, logging.FileHandler):
            continue
        if isinstance(handler, logging.StreamHandler):
            logger.removeHandler(handler)
# Module-level TUI instance; None until init_tui() has created one.
_tui: WorkerMonitor | None = None
def init_tui(config: Config):
    """
    Create the module-level WorkerMonitor and return it.

    Logging is reconfigured via setup_logging_for_tui(config) before the
    monitor is built. Asserts that no monitor has been created yet.
    """
    global _tui
    assert not _tui

    setup_logging_for_tui(config)

    monitor = WorkerMonitor()
    _tui = monitor
    return monitor
def maybe_update_tui(fn):
    """
    Call `fn` with the module-level TUI instance, if one has been set.

    Does nothing when no TUI is active; the return value of `fn` is
    discarded.
    """
    if not _tui:
        return
    fn(_tui)
async def maybe_update_tui_async(fn):
    """
    Await `fn(tui)` with the module-level TUI instance, if one has been set.

    Does nothing when no TUI is active; the result of `fn` is discarded.
    """
    if not _tui:
        return
    await fn(_tui)