Begin adding TUI

guilles_counter_review
Guillermo Rodriguez 2025-02-05 15:35:40 -03:00
parent 93ee65087f
commit b3dc7c1074
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
8 changed files with 391 additions and 22 deletions

View File

@@ -61,6 +61,7 @@ cuda = [
"basicsr>=1.4.2,<2",
"realesrgan>=0.3.0,<0.4",
"sentencepiece>=0.2.0",
"urwid>=2.6.16",
]
[tool.uv]

View File

@@ -1,4 +1,5 @@
import logging
import warnings

import urwid
import trio
@@ -6,11 +7,31 @@ from hypercorn.config import Config
from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import ModelMngr
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector
def setup_logging_for_tui(level):
warnings.filterwarnings("ignore")
logger = logging.getLogger()
logger.setLevel(level)
fh = logging.FileHandler('dgpu.log')
fh.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
# drop any stream handlers: urwid owns the terminal, so records printed
# to stdout/stderr would be drawn over the TUI screen
for handler in logger.handlers:
    if isinstance(handler, logging.StreamHandler):
        logger.removeHandler(handler)
async def open_dgpu_node(config: dict) -> None:
'''
Open a top level "GPU mgmt daemon", keep the
@@ -18,13 +39,17 @@ async def open_dgpu_node(config: dict) -> None:
and *maybe* serve a `hypercorn` web API.
'''
# suppress logs from httpx (logs url + status after every query)
logging.getLogger("httpx").setLevel(logging.WARNING)
conn = NetConnector(config)
mm = ModelMngr(config)
daemon = WorkerDaemon(mm, conn, config)
tui = None
if config['tui']:
setup_logging_for_tui(logging.INFO)
tui = WorkerMonitor()
conn = NetConnector(config, tui=tui)
mm = ModelMngr(config, tui=tui)
daemon = WorkerDaemon(mm, conn, config, tui=tui)
api: Quart|None = None
if 'api_bind' in config:
@@ -35,6 +60,8 @@ async def open_dgpu_node(config: dict) -> None:
tn: trio.Nursery
async with trio.open_nursery() as tn:
tn.start_soon(daemon.snap_updater_task)
if tui:
tn.start_soon(tui.run)
# TODO, consider a more explicit `as hypercorn_serve`
# to clarify?
@@ -42,5 +69,9 @@ async def open_dgpu_node(config: dict) -> None:
logging.info(f'serving api @ {config["api_bind"]}')
tn.start_soon(serve, api, api_conf)
try:
# block until cancelled
await daemon.serve_forever()
except* urwid.ExitMainLoop:
...
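A note on that `except*` clause: `tui.run` ends by raising `urwid.ExitMainLoop`, and under trio an exception escaping a nursery task surfaces wrapped in an `ExceptionGroup`, which the `except*` syntax (Python 3.11+) filters and unwraps. A minimal sketch of the pattern, independent of this codebase, with a stand-in `Quit` exception playing the role of `urwid.ExitMainLoop`:

    import trio

    class Quit(Exception):
        '''stand-in for urwid.ExitMainLoop'''

    async def ui_task():
        await trio.sleep(0.1)
        raise Quit  # e.g. the user pressed 'q'

    async def main():
        try:
            async with trio.open_nursery() as tn:
                tn.start_soon(ui_task)
                await trio.sleep_forever()  # cancelled when ui_task raises

        except* Quit:
            pass  # Quit arrived inside an ExceptionGroup; except* unwraps it

    trio.run(main)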

View File

@@ -11,6 +11,7 @@ from hashlib import sha256
import trio
import torch
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.errors import (
DGPUComputeError,
DGPUInferenceCancelled,
@@ -72,7 +73,8 @@ class ModelMngr:
checking load state, and unloading when no-longer-needed/finished.
'''
def __init__(self, config: dict):
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
self._tui = tui
self.cache_dir = None
if 'hf_home' in config:
self.cache_dir = config['hf_home']
@@ -80,8 +82,6 @@
self._model_name: str = ''
self._model_mode: str = ''
# self.load_model(DEFAULT_INITAL_MODEL, 'txt2img')
def log_debug_info(self):
logging.debug('memory summary:')
logging.debug('\n' + torch.cuda.memory_summary())
@@ -110,6 +110,7 @@
) -> None:
logging.info(f'loading model {name}...')
self.unload_model()
self._model = pipeline_for(
name, mode, cache_dir=self.cache_dir)
self._model_mode = mode
@@ -124,11 +125,19 @@
params: dict,
inputs: list[bytes] = []
):
def maybe_cancel_work(step, *args, **kwargs):
total_steps = params['step']
def inference_step_wakeup(*args, **kwargs):
'''Callback invoked on every inference step; raises an
exception if the work needs to be cancelled mid-run.
'''
if self._should_cancel:
step = args[0]
# compat with callback_on_step_end
if not isinstance(step, int):
step = args[1]
if self._tui:
self._tui.set_progress(step, done=total_steps)
should_raise = trio.from_thread.run(self._should_cancel, request_id)
if should_raise:
logging.warning(f'CANCELLING work at step {step}')
@@ -136,7 +145,10 @@
return {}
maybe_cancel_work(0)
if self._tui:
self._tui.set_status(f'Request #{request_id}')
inference_step_wakeup(0)
output_type = 'png'
if 'output_type' in params:
@@ -157,10 +169,10 @@
prompt, guidance, step, seed, upscaler, extra_params = arguments
if 'flux' in name.lower():
extra_params['callback_on_step_end'] = maybe_cancel_work
extra_params['callback_on_step_end'] = inference_step_wakeup
else:
extra_params['callback'] = maybe_cancel_work
extra_params['callback'] = inference_step_wakeup
extra_params['callback_steps'] = 1
output = self._model(
@@ -213,4 +225,7 @@
finally:
torch.cuda.empty_cache()
if self._tui:
self._tui.set_status('')
return output_hash, output
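For reference, the `isinstance(step, int)` check in `inference_step_wakeup` papers over the two callback signatures diffusers uses: the legacy `callback` hook receives `(step, timestep, latents)`, while `callback_on_step_end` receives `(pipe, step, timestep, callback_kwargs)` and must return a (possibly empty) kwargs dict, hence the `return {}` above. A schematic of the normalization with a hypothetical helper name, not the diffusers source:

    def step_from_callback_args(*args) -> int:
        # legacy hook:  callback(step: int, timestep, latents)
        # newer hook:   callback_on_step_end(pipe, step: int, timestep, callback_kwargs)
        step = args[0]
        if not isinstance(step, int):  # args[0] is the pipeline for the newer hook
            step = args[1]
        return step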

View File

@@ -17,6 +17,7 @@ from skynet.constants import (
from skynet.dgpu.errors import (
DGPUComputeError,
)
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.compute import ModelMngr
from skynet.dgpu.network import NetConnector
@@ -41,10 +42,12 @@ class WorkerDaemon:
self,
mm: ModelMngr,
conn: NetConnector,
config: dict
config: dict,
tui: WorkerMonitor | None = None
):
self.mm: ModelMngr = mm
self.conn: NetConnector = conn
self._tui = tui
self.auto_withdraw = (
config['auto_withdraw']
if 'auto_withdraw' in config else False
@@ -150,6 +153,12 @@
return app
async def _update_balance(self):
if self._tui:
# update balance
balance = await self.conn.get_worker_balance()
self._tui.set_header_text(new_balance=f'balance: {balance}')
# TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
# -[ ] just pass `daemon: WorkerDaemon` vs. `self`
@@ -238,6 +247,8 @@
request_hash = sha256(hash_str.encode('utf-8')).hexdigest()
logging.info(f'calculated request hash: {request_hash}')
total_step = body['params']['step']
# TODO: validate request
resp = await self.conn.begin_work(rid)
@@ -246,6 +257,9 @@
else:
try:
if self._tui:
self._tui.set_progress(0, done=total_step)
output_type = 'png'
if 'output_type' in body['params']:
output_type = body['params']['output_type']
@@ -269,6 +283,9 @@
f'Unsupported backend {self.backend}'
)
if self._tui:
self._tui.set_progress(total_step)
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
self._benchmark: list[float] = []
@@ -277,6 +294,9 @@
await self.conn.submit_work(rid, request_hash, output_hash, ipfs_hash)
await self._update_balance()
except BaseException as err:
if 'network cancel' not in str(err):
logging.exception('Failed to serve model request !?\n')
@@ -294,6 +314,7 @@
# -[ ] keeps tasks-as-funcs style prominent
# -[ ] avoids so much indentation due to methods
async def serve_forever(self):
await self._update_balance()
try:
while True:
if self.auto_withdraw:
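Stepping back from the individual hunks: together with the compute-layer changes above, one request drives the progress bar through a fixed lifecycle. Schematically, using the calls as they appear in this commit:

    tui.set_progress(0, done=total_step)      # daemon, before begin_work
    tui.set_progress(step, done=total_steps)  # compute, once per inference step
    tui.set_progress(total_step)              # daemon, after the backend returns
    tui.set_status('')                        # compute, in the finally block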

View File

@@ -13,6 +13,7 @@ import outcome
from PIL import Image
from leap.cleos import CLEOS
from leap.protocol import Asset
from skynet.dgpu.tui import WorkerMonitor
from skynet.constants import (
DEFAULT_IPFS_DOMAIN,
GPU_CONTRACT_ABI,
@@ -57,7 +58,7 @@ class NetConnector:
- CLEOS client
'''
def __init__(self, config: dict):
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
# TODO, why these extra instance vars for an (unsynced)
# copy of the `config` state?
self.account = config['account']
@@ -81,6 +82,10 @@
self.ipfs_domain = config['ipfs_domain']
self._wip_requests = {}
self._tui = tui
if self._tui:
self._tui.set_header_text(new_worker_name=self.account)
# blockchain helpers
@@ -163,6 +168,9 @@
n.start_soon(
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
if self._tui:
self._tui.network_update(snap)
return snap
async def begin_work(self, request_id: int):

skynet/dgpu/tui.py 100644 (new file, +248)
View File

@@ -0,0 +1,248 @@
import urwid
import trio
import json
class WorkerMonitor:
def __init__(self):
self.requests = []
self.header_info = {}
self.palette = [
('headerbar', 'white', 'dark blue'),
('request_row', 'white', 'dark gray'),
('worker_row', 'light gray', 'black'),
('progress_normal', 'black', 'light gray'),
('progress_complete', 'black', 'dark green'),
('body', 'white', 'black'),
]
# --- Top bar (header) ---
worker_name = self.header_info.get('left', "unknown")
balance = self.header_info.get('right', "balance: unknown")
self.worker_name_widget = urwid.Text(worker_name)
self.balance_widget = urwid.Text(balance, align='right')
header = urwid.Columns([self.worker_name_widget, self.balance_widget])
header_attr = urwid.AttrMap(header, 'headerbar')
# --- Body (List of requests) ---
self.body_listbox = self._create_listbox_body(self.requests)
# --- Bottom bar (progress) ---
self.status_text = urwid.Text("Request: none", align='left')
self.progress_bar = urwid.ProgressBar(
'progress_normal',
'progress_complete',
current=0,
done=100
)
footer_cols = urwid.Columns([
('fixed', 20, self.status_text),
self.progress_bar,
])
# Build the main frame
frame = urwid.Frame(
self.body_listbox,
header=header_attr,
footer=footer_cols
)
# Set up the main loop with Trio
self.event_loop = urwid.TrioEventLoop()
self.main_loop = urwid.MainLoop(
frame,
palette=self.palette,
event_loop=self.event_loop,
unhandled_input=self._exit_on_q
)
def _create_listbox_body(self, requests):
"""
Build a ListBox (vertical list) of requests & workers using SimpleFocusListWalker.
"""
widgets = self._build_request_widgets(requests)
walker = urwid.SimpleFocusListWalker(widgets)
return urwid.ListBox(walker)
def _build_request_widgets(self, requests):
"""
Build a list of Urwid widgets (one row per request + per worker).
"""
row_widgets = []
for req in requests:
# Build a columns widget for the request row
columns = urwid.Columns([
('fixed', 5, urwid.Text(f"#{req['id']}")), # e.g. "#12"
('weight', 3, urwid.Text(req['model'])),
('weight', 3, urwid.Text(req['prompt'])),
('fixed', 13, urwid.Text(req['user'])),
('fixed', 13, urwid.Text(req['reward'])),
], dividechars=1)
# Wrap the columns with an attribute map for coloring
request_row = urwid.AttrMap(columns, 'request_row')
row_widgets.append(request_row)
# Then add each worker in its own line below
for w in req["workers"]:
worker_line = urwid.Text(f" {w}")
worker_row = urwid.AttrMap(worker_line, 'worker_row')
row_widgets.append(worker_row)
# Optional blank line after each request
row_widgets.append(urwid.Text(""))
return row_widgets
def _exit_on_q(self, key):
"""Exit the TUI on 'q' or 'Q'."""
if key in ('q', 'Q'):
raise urwid.ExitMainLoop()
async def run(self):
"""
Run the TUI in an async context (Trio).
This method blocks until the user quits (pressing q/Q).
"""
with self.main_loop.start():
await self.event_loop.run_async()
raise urwid.ExitMainLoop()
# -------------------------------------------------------------------------
# Public Methods to Update Various Parts of the UI
# -------------------------------------------------------------------------
def set_status(self, status: str):
self.status_text.set_text(status)
def set_progress(self, current, done=None):
"""
Update the bottom progress bar.
- `current`: new current progress value (int).
- `done`: max progress value (int). If None, we don't change it.
"""
if done is not None:
self.progress_bar.done = done
self.progress_bar.current = current
# no percentage math needed here: urwid's ProgressBar renders its own
# percentage text from current/done
def update_requests(self, new_requests):
"""
Replace the data in the existing ListBox with new request widgets.
"""
new_widgets = self._build_request_widgets(new_requests)
self.body_listbox.body[:] = new_widgets # replace content of the list walker
def set_header_text(self, new_worker_name=None, new_balance=None):
"""
Update the text in the header bar for worker name and/or balance.
"""
if new_worker_name is not None:
self.worker_name_widget.set_text(new_worker_name)
if new_balance is not None:
self.balance_widget.set_text(new_balance)
def network_update(self, snapshot: dict):
queue = [
{
**r,
**(json.loads(r['body'])['params']),
'workers': [s['worker'] for s in snapshot['requests'][r['id']]]
}
for r in snapshot['queue']
]
self.update_requests(queue)
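`network_update` expects the snapshot shape `NetConnector` assembles above: a `queue` list of raw request rows whose `body` is a JSON string carrying `params` (including `model` and `prompt`), plus a `requests` map from request id to its per-worker status rows. A hypothetical snapshot that would exercise it (all values invented for illustration):

    snapshot = {
        'queue': [{
            'id': 12,
            'user': 'alice123',
            'reward': '20.0000 GPU',
            'body': json.dumps({'params': {
                'model': 'black-forest-labs/FLUX.1-schnell',
                'prompt': 'an example prompt',
                'step': 28,
            }}),
        }],
        'requests': {12: [{'worker': 'workerA'}, {'worker': 'workerB'}]},
    }
    WorkerMonitor().network_update(snapshot)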
# # -----------------------------------------------------------------------------
# # Example usage
# # -----------------------------------------------------------------------------
#
# async def main():
# # Example data
# example_requests = [
# {
# "id": 12,
# "model": "black-forest-labs/FLUX.1-schnell",
# "prompt": "Generate an answer about quantum entanglement.",
# "user": "alice123",
# "reward": "20.0000 GPU",
# "workers": ["workerA", "workerB"],
# },
# {
# "id": 5,
# "model": "some-other-model/v2.0",
# "prompt": "A story about dragons.",
# "user": "bobthebuilder",
# "reward": "15.0000 GPU",
# "workers": ["workerX"],
# },
# {
# "id": 99,
# "model": "cool-model/turbo",
# "prompt": "Classify sentiment in these tweets.",
# "user": "charlie",
# "reward": "25.5000 GPU",
# "workers": ["workerOne", "workerTwo", "workerThree"],
# },
# ]
#
# ui = WorkerMonitor()
#
# async def progress_task():
# # Fill from 0% to 100%
# for pct in range(101):
# ui.set_progress(pct, done=100)
# ui.set_status(f"Request #1234 ({pct}%)")
# await trio.sleep(0.05)
# # Reset to 0
# ui.set_progress(0)
# ui.set_status("Starting again...")
#
# async def update_data_task():
# await trio.sleep(3) # Wait a bit, then update requests
# new_data = [{
# "id": 101,
# "model": "new-model/v1.0",
# "prompt": "Say hi to the world.",
# "user": "eve",
# "reward": "50.0000 GPU",
# "workers": ["workerFresh", "workerPower"],
# }]
# ui.update_requests(new_data)
# ui.set_header_text(new_worker_name="NewNodeName",
# new_balance="balance: 12345.6789 GPU")
#
# try:
# async with trio.open_nursery() as nursery:
# # Run the TUI
# nursery.start_soon(ui.run)
#
# ui.update_requests(example_requests)
# ui.set_header_text(
# new_worker_name="worker1.scd",
# new_balance="balance: 12345.6789 GPU"
# )
# # Start background tasks
# nursery.start_soon(progress_task)
# nursery.start_soon(update_data_task)
#
# except* KeyboardInterrupt as ex_group:
# ...
#
#
# if __name__ == "__main__":
# trio.run(main)

View File

@@ -7,8 +7,10 @@ import logging
import importlib
from typing import Optional
from contextlib import contextmanager
import torch
import diffusers
import numpy as np
from PIL import Image
@@ -74,12 +76,27 @@ def convert_from_bytes_and_crop(raw: bytes, max_w: int, max_h: int) -> Image:
return crop_image(convert_from_bytes_to_img(raw), max_w, max_h)
class DummyPB:
def update(self):
...
@torch.compiler.disable
@contextmanager
def dummy_progress_bar(*args, **kwargs):
yield DummyPB()
def monkey_patch_pipeline_disable_progress_bar(pipe):
pipe.progress_bar = dummy_progress_bar
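The shape of `DummyPB` mirrors how diffusers pipelines consume `self.progress_bar` internally: roughly a context manager yielding an object whose `update()` is called once per step. Schematically (simplified, not the actual diffusers source):

    # inside a pipeline's denoising loop, approximately:
    with self.progress_bar(total=num_inference_steps) as progress_bar:
        for i, t in enumerate(timesteps):
            ...  # one denoising step
            progress_bar.update()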
def pipeline_for(
model: str,
mode: str,
mem_fraction: float = 1.0,
cache_dir: str | None = None
) -> DiffusionPipeline:
diffusers.utils.logging.disable_progress_bar()
logging.info(f'pipeline_for {model} {mode}')
assert torch.cuda.is_available()
@@ -105,7 +122,9 @@ def pipeline_for(
normalized_shortname = shortname.replace('-', '_')
custom_pipeline = importlib.import_module(f'skynet.dgpu.pipes.{normalized_shortname}')
assert custom_pipeline.__model['name'] == model
return custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
pipe = custom_pipeline.pipeline_for(model, mode, mem_fraction=mem_fraction, cache_dir=cache_dir)
monkey_patch_pipeline_disable_progress_bar(pipe)
return pipe
except ImportError:
# TODO, uhh why not warn/error log this?
@@ -121,7 +140,6 @@ def pipeline_for(
logging.warn(f'model requires {req_mem} but card has {mem_gb}, model will run slower..')
params = {
'safety_checker': None,
'torch_dtype': torch.float16,
'cache_dir': cache_dir,
'variant': 'fp16',
@@ -130,6 +148,7 @@
match shortname:
case 'stable':
params['revision'] = 'fp16'
params['safety_checker'] = None
torch.cuda.set_per_process_memory_fraction(mem_fraction)
@@ -167,6 +186,8 @@
pipe = pipe.to('cuda')
monkey_patch_pipeline_disable_progress_bar(pipe)
return pipe

uv.lock (24 changes)
View File

@@ -2261,6 +2261,7 @@ cuda = [
{ name = "torchvision" },
{ name = "transformers" },
{ name = "triton" },
{ name = "urwid" },
{ name = "xformers" },
]
dev = [
@@ -2312,6 +2313,7 @@ cuda = [
{ name = "torchvision", specifier = "==0.20.1+cu121", index = "https://download.pytorch.org/whl/cu121" },
{ name = "transformers", specifier = "==4.48.0" },
{ name = "triton", specifier = "==3.1.0", index = "https://download.pytorch.org/whl/cu121" },
{ name = "urwid", specifier = ">=2.6.16" },
{ name = "xformers", specifier = ">=0.0.29,<0.0.30" },
]
dev = [
@@ -2626,6 +2628,28 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/c8/19/4ec628951a74043532ca2cf5d97b7b14863931476d117c471e8e2b1eb39f/urllib3-2.3.0-py3-none-any.whl", hash = "sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df", size = 128369 },
]
[[package]]
name = "urwid"
version = "2.6.16"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "typing-extensions" },
{ name = "wcwidth" },
]
sdist = { url = "https://files.pythonhosted.org/packages/98/21/ad23c9e961b2d36d57c63686a6f86768dd945d406323fb58c84f09478530/urwid-2.6.16.tar.gz", hash = "sha256:93ad239939e44c385e64aa00027878b9e5c486d59e855ec8ab5b1e1adcdb32a2", size = 848179 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/54/cb/271a4f5a1bf4208dbdc96d85b9eae744cf4e5e11ac73eda76dc98c8fd2d7/urwid-2.6.16-py3-none-any.whl", hash = "sha256:de14896c6df9eb759ed1fd93e0384a5279e51e0dde8f621e4083f7a8368c0797", size = 297196 },
]
[[package]]
name = "wcwidth"
version = "0.2.13"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/6c/63/53559446a878410fc5a5974feb13d31d78d752eb18aeba59c7fef1af7598/wcwidth-0.2.13.tar.gz", hash = "sha256:72ea0c06399eb286d978fdedb6923a9eb47e1c486ce63e9b4e64fc18303972b5", size = 101301 }
wheels = [
{ url = "https://files.pythonhosted.org/packages/fd/84/fd2ba7aafacbad3c4201d395674fc6348826569da3c0937e75505ead3528/wcwidth-0.2.13-py2.py3-none-any.whl", hash = "sha256:3da69048e4540d84af32131829ff948f1e022c1c6bdb8d6102117aac784f6859", size = 34166 },
]
[[package]]
name = "websocket-client"
version = "1.8.0"