Refactor the TUI to a functional style

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:48:57 -03:00
parent cd028d15e7
commit d8f243df9b
5 changed files with 68 additions and 138 deletions


@@ -1,36 +1,17 @@
import logging
import warnings
import trio
import urwid
from hypercorn.config import Config
from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector
def setup_logging_for_tui(level):
warnings.filterwarnings("ignore")
logger = logging.getLogger()
logger.setLevel(level)
fh = logging.FileHandler('dgpu.log')
fh.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
logger.removeHandler(handler)
async def open_dgpu_node(config: dict) -> None:
'''
Open a top level "GPU mgmt daemon", keep the
@@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
tui = None
if config['tui']:
setup_logging_for_tui(logging.INFO)
tui = WorkerMonitor()
tui = init_tui()
conn = NetConnector(config, tui=tui)
daemon = WorkerDaemon(conn, config, tui=tui)
conn = NetConnector(config)
daemon = WorkerDaemon(conn, config)
api: Quart|None = None
if 'api_bind' in config:
@@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
# block until cancelled
await daemon.serve_forever()
except *urwid.ExitMainLoop in ex_group:
except *urwid.ExitMainLoop:
...
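Assembled from the hunks above, the daemon entrypoint now reads roughly as below. This is a sketch, not the full file: the nursery and the optional Quart API wiring between construction and serve_forever() are elided.

import urwid

from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector


async def open_dgpu_node(config: dict) -> None:
    # init_tui() now builds the WorkerMonitor and reroutes logging to
    # dgpu.log in one step; nothing downstream takes the instance anymore
    tui = None
    if config['tui']:
        tui = init_tui()

    conn = NetConnector(config)
    daemon = WorkerDaemon(conn, config)

    try:
        # ... nursery / API setup elided ...
        # block until cancelled
        await daemon.serve_forever()

    except *urwid.ExitMainLoop:
        ...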


@@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
import trio
import torch
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.tui import maybe_update_tui
from skynet.dgpu.errors import (
DGPUComputeError,
DGPUInferenceCancelled,
@ -108,8 +108,7 @@ def compute_one(
method: str,
params: dict,
inputs: list[bytes] = [],
should_cancel = None,
tui: WorkerMonitor | None = None
should_cancel = None
):
if method == 'diffuse':
method = 'txt2img'
@@ -130,8 +129,7 @@ def compute_one(
if not isinstance(step, int):
step = args[1]
if tui:
tui.set_progress(step, done=total_steps)
maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
if should_cancel:
should_raise = trio.from_thread.run(should_cancel, request_id)
@@ -142,8 +140,7 @@ def compute_one(
return {}
if tui:
tui.set_status(f'Request #{request_id}')
maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
inference_step_wakeup(0)
@@ -210,7 +207,6 @@ def compute_one(
except BaseException as err:
raise DGPUComputeError(str(err)) from err
if tui:
tui.set_status('')
maybe_update_tui(lambda tui: tui.set_status(''))
return output_hash, output
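The same change runs through the compute path: instead of threading a tui argument into compute_one and guarding every update with an if-tui check, call sites hand maybe_update_tui a callback, and the helper silently does nothing when no TUI was ever initialized. A minimal sketch of that call-site shape; make_step_callback is a hypothetical wrapper, only the maybe_update_tui calls mirror the diff:

from skynet.dgpu.tui import maybe_update_tui


def make_step_callback(request_id: int, total_steps: int):
    # hypothetical helper, shown only to illustrate the new call-site shape
    maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))

    def on_step(step: int):
        # no tui parameter and no guard; a no-op when the TUI is disabled
        maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))

    return on_step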


@@ -17,7 +17,7 @@ from skynet.constants import (
from skynet.dgpu.errors import (
DGPUComputeError,
)
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector
@@ -41,11 +41,9 @@ class WorkerDaemon:
def __init__(
self,
conn: NetConnector,
config: dict,
tui: WorkerMonitor | None = None
config: dict
):
self.conn: NetConnector = conn
self._tui = tui
self.auto_withdraw = (
config['auto_withdraw']
if 'auto_withdraw' in config else False
@@ -152,10 +150,12 @@ class WorkerDaemon:
return app
async def _update_balance(self):
if self._tui:
async def _fn(tui):
# update balance
balance = await self.conn.get_worker_balance()
self._tui.set_header_text(new_balance=f'balance: {balance}')
tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn)
# TODO? this func is kinda big and maybe is better at module
# level to reduce indentation?
@@ -258,8 +258,7 @@ class WorkerDaemon:
with maybe_load_model(model, mode):
try:
if self._tui:
self._tui.set_progress(0, done=total_step)
maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
output_type = 'png'
if 'output_type' in body['params']:
@@ -276,7 +275,6 @@ class WorkerDaemon:
mode, body['params'],
inputs=inputs,
should_cancel=self.should_cancel_work,
tui=self._tui
)
)
@@ -285,8 +283,7 @@ class WorkerDaemon:
f'Unsupported backend {self.backend}'
)
if self._tui:
self._tui.set_progress(total_step)
maybe_update_tui(lambda tui: tui.set_progress(total_step))
self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark
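Updates that themselves need to await something get the async variant: _update_balance wraps the balance fetch in a closure, so the RPC only happens when a TUI is actually attached. The same helper works from any background task; a hypothetical periodic refresher, reusing only calls that appear in this diff (get_worker_balance, set_header_text), could look like this:

import trio

from skynet.dgpu.tui import maybe_update_tui_async


async def balance_refresher(conn, period: float = 30.0):
    # hypothetical task: keep the header balance fresh while a TUI is up;
    # when no TUI was initialized the closure never runs, so no RPC is made
    async def _update(tui):
        balance = await conn.get_worker_balance()
        tui.set_header_text(new_balance=f'balance: {balance}')

    while True:
        await maybe_update_tui_async(_update)
        await trio.sleep(period)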


@@ -13,7 +13,7 @@ import outcome
from PIL import Image
from leap.cleos import CLEOS
from leap.protocol import Asset
from skynet.dgpu.tui import WorkerMonitor
from skynet.dgpu.tui import maybe_update_tui
from skynet.constants import (
DEFAULT_IPFS_DOMAIN,
GPU_CONTRACT_ABI,
@@ -58,7 +58,7 @@ class NetConnector:
- CLEOS client
'''
def __init__(self, config: dict, tui: WorkerMonitor | None = None):
def __init__(self, config: dict):
# TODO, why these extra instance vars for an (unsynced)
# copy of the `config` state?
self.account = config['account']
@@ -82,9 +82,8 @@ class NetConnector:
self.ipfs_domain = config['ipfs_domain']
self._wip_requests = {}
self._tui = tui
if self._tui:
self._tui.set_header_text(new_worker_name=self.account)
maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
# blockchain helpers
@@ -168,8 +167,8 @@ class NetConnector:
n.start_soon(
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
if self._tui:
self._tui.network_update(snap)
maybe_update_tui(lambda tui: tui.network_update(snap))
return snap


@@ -1,6 +1,9 @@
import urwid
import trio
import json
import logging
import warnings
import trio
import urwid
class WorkerMonitor:
@@ -163,86 +166,41 @@ class WorkerMonitor:
self.update_requests(queue)
# # -----------------------------------------------------------------------------
# # Example usage
# # -----------------------------------------------------------------------------
#
# async def main():
# # Example data
# example_requests = [
# {
# "id": 12,
# "model": "black-forest-labs/FLUX.1-schnell",
# "prompt": "Generate an answer about quantum entanglement.",
# "user": "alice123",
# "reward": "20.0000 GPU",
# "workers": ["workerA", "workerB"],
# },
# {
# "id": 5,
# "model": "some-other-model/v2.0",
# "prompt": "A story about dragons.",
# "user": "bobthebuilder",
# "reward": "15.0000 GPU",
# "workers": ["workerX"],
# },
# {
# "id": 99,
# "model": "cool-model/turbo",
# "prompt": "Classify sentiment in these tweets.",
# "user": "charlie",
# "reward": "25.5000 GPU",
# "workers": ["workerOne", "workerTwo", "workerThree"],
# },
# ]
#
# ui = WorkerMonitor()
#
# async def progress_task():
# # Fill from 0% to 100%
# for pct in range(101):
# ui.set_progress(
# current=pct,
# status_str=f"Request #1234 ({pct}%)"
# )
# await trio.sleep(0.05)
# # Reset to 0
# ui.set_progress(
# current=0,
# status_str="Starting again..."
# )
#
# async def update_data_task():
# await trio.sleep(3) # Wait a bit, then update requests
# new_data = [{
# "id": 101,
# "model": "new-model/v1.0",
# "prompt": "Say hi to the world.",
# "user": "eve",
# "reward": "50.0000 GPU",
# "workers": ["workerFresh", "workerPower"],
# }]
# ui.update_requests(new_data)
# ui.set_header_text(new_worker_name="NewNodeName",
# new_balance="balance: 12345.6789 GPU")
#
# try:
# async with trio.open_nursery() as nursery:
# # Run the TUI
# nursery.start_soon(ui.run_teadown_on_exit, nursery)
#
# ui.update_requests(example_requests)
# ui.set_header_text(
# new_worker_name="worker1.scd",
# new_balance="balance: 12345.6789 GPU"
# )
# # Start background tasks
# nursery.start_soon(progress_task)
# nursery.start_soon(update_data_task)
#
# except *KeyboardInterrupt as ex_group:
# ...
#
#
# if __name__ == "__main__":
# trio.run(main)
def setup_logging_for_tui(level):
warnings.filterwarnings("ignore")
logger = logging.getLogger()
logger.setLevel(level)
fh = logging.FileHandler('dgpu.log')
fh.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
logger.removeHandler(handler)
_tui = None
def init_tui():
global _tui
assert not _tui
setup_logging_for_tui(logging.INFO)
_tui = WorkerMonitor()
return _tui
def maybe_update_tui(fn):
global _tui
if _tui:
fn(_tui)
async def maybe_update_tui_async(fn):
global _tui
if _tui:
await fn(_tui)
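Taken together, the module now acts as an optional singleton: init_tui() creates the one WorkerMonitor (and reroutes logging to dgpu.log), while the two maybe_* helpers turn every call site into a no-op when the TUI is disabled. A minimal usage sketch, assuming only the WorkerMonitor methods shown earlier in this diff and leaving the actual urwid main loop out:

import trio

from skynet.dgpu.tui import (
    init_tui,
    maybe_update_tui,
    maybe_update_tui_async,
)


async def main(enable_tui: bool):
    if enable_tui:
        # from here on, log records go to dgpu.log instead of the terminal
        init_tui()

    # sync update: only runs if init_tui() was called above
    maybe_update_tui(
        lambda tui: tui.set_header_text(new_worker_name='worker1.scd'))

    # async update: the closure may await; skipped entirely without a TUI
    async def _set_status(tui):
        await trio.sleep(0)  # stand-in for real async work
        tui.set_status('idle')

    await maybe_update_tui_async(_set_status)


if __name__ == '__main__':
    trio.run(main, True)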