Refactoring the TUI to a functional style

guilles_counter_review
Guillermo Rodriguez 2025-02-05 19:48:57 -03:00
parent cd028d15e7
commit d8f243df9b
No known key found for this signature in database
GPG Key ID: 002CC5F1E6BDA53E
5 changed files with 68 additions and 138 deletions

View File

@ -1,36 +1,17 @@
import logging import logging
import warnings
import trio import trio
import urwid
from hypercorn.config import Config from hypercorn.config import Config
from hypercorn.trio import serve from hypercorn.trio import serve
from quart_trio import QuartTrio as Quart from quart_trio import QuartTrio as Quart
from skynet.dgpu.tui import WorkerMonitor from skynet.dgpu.tui import init_tui
from skynet.dgpu.daemon import WorkerDaemon from skynet.dgpu.daemon import WorkerDaemon
from skynet.dgpu.network import NetConnector from skynet.dgpu.network import NetConnector
def setup_logging_for_tui(level):
warnings.filterwarnings("ignore")
logger = logging.getLogger()
logger.setLevel(level)
fh = logging.FileHandler('dgpu.log')
fh.setLevel(level)
formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
fh.setFormatter(formatter)
logger.addHandler(fh)
for handler in logger.handlers:
if isinstance(handler, logging.StreamHandler):
logger.removeHandler(handler)
async def open_dgpu_node(config: dict) -> None: async def open_dgpu_node(config: dict) -> None:
''' '''
Open a top level "GPU mgmt daemon", keep the Open a top level "GPU mgmt daemon", keep the
@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
tui = None tui = None
if config['tui']: if config['tui']:
setup_logging_for_tui(logging.INFO) tui = init_tui()
tui = WorkerMonitor()
conn = NetConnector(config, tui=tui) conn = NetConnector(config)
daemon = WorkerDaemon(conn, config, tui=tui) daemon = WorkerDaemon(conn, config)
api: Quart|None = None api: Quart|None = None
if 'api_bind' in config: if 'api_bind' in config:
@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
# block until cancelled # block until cancelled
await daemon.serve_forever() await daemon.serve_forever()
except *urwid.ExitMainLoop in ex_group: except *urwid.ExitMainLoop:
... ...

View File

@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
import trio import trio
import torch import torch
from skynet.dgpu.tui import WorkerMonitor from skynet.dgpu.tui import maybe_update_tui
from skynet.dgpu.errors import ( from skynet.dgpu.errors import (
DGPUComputeError, DGPUComputeError,
DGPUInferenceCancelled, DGPUInferenceCancelled,
@ -108,8 +108,7 @@ def compute_one(
method: str, method: str,
params: dict, params: dict,
inputs: list[bytes] = [], inputs: list[bytes] = [],
should_cancel = None, should_cancel = None
tui: WorkerMonitor | None = None
): ):
if method == 'diffuse': if method == 'diffuse':
method = 'txt2img' method = 'txt2img'
@ -130,8 +129,7 @@ def compute_one(
if not isinstance(step, int): if not isinstance(step, int):
step = args[1] step = args[1]
if tui: maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
tui.set_progress(step, done=total_steps)
if should_cancel: if should_cancel:
should_raise = trio.from_thread.run(should_cancel, request_id) should_raise = trio.from_thread.run(should_cancel, request_id)
@ -142,8 +140,7 @@ def compute_one(
return {} return {}
if tui: maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
tui.set_status(f'Request #{request_id}')
inference_step_wakeup(0) inference_step_wakeup(0)
@ -210,7 +207,6 @@ def compute_one(
except BaseException as err: except BaseException as err:
raise DGPUComputeError(str(err)) from err raise DGPUComputeError(str(err)) from err
if tui: maybe_update_tui(lambda tui: tui.set_status(''))
tui.set_status('')
return output_hash, output return output_hash, output

View File

@ -17,7 +17,7 @@ from skynet.constants import (
from skynet.dgpu.errors import ( from skynet.dgpu.errors import (
DGPUComputeError, DGPUComputeError,
) )
from skynet.dgpu.tui import WorkerMonitor from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
from skynet.dgpu.compute import maybe_load_model, compute_one from skynet.dgpu.compute import maybe_load_model, compute_one
from skynet.dgpu.network import NetConnector from skynet.dgpu.network import NetConnector
@ -41,11 +41,9 @@ class WorkerDaemon:
def __init__( def __init__(
self, self,
conn: NetConnector, conn: NetConnector,
config: dict, config: dict
tui: WorkerMonitor | None = None
): ):
self.conn: NetConnector = conn self.conn: NetConnector = conn
self._tui = tui
self.auto_withdraw = ( self.auto_withdraw = (
config['auto_withdraw'] config['auto_withdraw']
if 'auto_withdraw' in config else False if 'auto_withdraw' in config else False
@ -152,10 +150,12 @@ class WorkerDaemon:
return app return app
async def _update_balance(self): async def _update_balance(self):
if self._tui: async def _fn(tui):
# update balance # update balance
balance = await self.conn.get_worker_balance() balance = await self.conn.get_worker_balance()
self._tui.set_header_text(new_balance=f'balance: {balance}') tui.set_header_text(new_balance=f'balance: {balance}')
await maybe_update_tui_async(_fn)
# TODO? this func is kinda big and maybe is better at module # TODO? this func is kinda big and maybe is better at module
# level to reduce indentation? # level to reduce indentation?
@ -258,8 +258,7 @@ class WorkerDaemon:
with maybe_load_model(model, mode): with maybe_load_model(model, mode):
try: try:
if self._tui: maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
self._tui.set_progress(0, done=total_step)
output_type = 'png' output_type = 'png'
if 'output_type' in body['params']: if 'output_type' in body['params']:
@ -276,7 +275,6 @@ class WorkerDaemon:
mode, body['params'], mode, body['params'],
inputs=inputs, inputs=inputs,
should_cancel=self.should_cancel_work, should_cancel=self.should_cancel_work,
tui=self._tui
) )
) )
@ -285,8 +283,7 @@ class WorkerDaemon:
f'Unsupported backend {self.backend}' f'Unsupported backend {self.backend}'
) )
if self._tui: maybe_update_tui(lambda tui: tui.set_progress(total_step))
self._tui.set_progress(total_step)
self._last_generation_ts: str = datetime.now().isoformat() self._last_generation_ts: str = datetime.now().isoformat()
self._last_benchmark: list[float] = self._benchmark self._last_benchmark: list[float] = self._benchmark

View File

@ -13,7 +13,7 @@ import outcome
from PIL import Image from PIL import Image
from leap.cleos import CLEOS from leap.cleos import CLEOS
from leap.protocol import Asset from leap.protocol import Asset
from skynet.dgpu.tui import WorkerMonitor from skynet.dgpu.tui import maybe_update_tui
from skynet.constants import ( from skynet.constants import (
DEFAULT_IPFS_DOMAIN, DEFAULT_IPFS_DOMAIN,
GPU_CONTRACT_ABI, GPU_CONTRACT_ABI,
@ -58,7 +58,7 @@ class NetConnector:
- CLEOS client - CLEOS client
''' '''
def __init__(self, config: dict, tui: WorkerMonitor | None = None): def __init__(self, config: dict):
# TODO, why these extra instance vars for an (unsynced) # TODO, why these extra instance vars for an (unsynced)
# copy of the `config` state? # copy of the `config` state?
self.account = config['account'] self.account = config['account']
@ -82,9 +82,8 @@ class NetConnector:
self.ipfs_domain = config['ipfs_domain'] self.ipfs_domain = config['ipfs_domain']
self._wip_requests = {} self._wip_requests = {}
self._tui = tui
if self._tui: maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
self._tui.set_header_text(new_worker_name=self.account)
# blockchain helpers # blockchain helpers
@ -168,8 +167,8 @@ class NetConnector:
n.start_soon( n.start_soon(
_run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id']) _run_and_save, snap['requests'], req['id'], self.get_status_by_request_id, req['id'])
if self._tui:
self._tui.network_update(snap) maybe_update_tui(lambda tui: tui.network_update(snap))
return snap return snap

View File

@ -1,6 +1,9 @@
import urwid
import trio
import json import json
import logging
import warnings
import trio
import urwid
class WorkerMonitor: class WorkerMonitor:
@ -163,86 +166,41 @@ class WorkerMonitor:
self.update_requests(queue) self.update_requests(queue)
# # ----------------------------------------------------------------------------- def setup_logging_for_tui(level):
# # Example usage warnings.filterwarnings("ignore")
# # -----------------------------------------------------------------------------
# logger = logging.getLogger()
# async def main(): logger.setLevel(level)
# # Example data
# example_requests = [ fh = logging.FileHandler('dgpu.log')
# { fh.setLevel(level)
# "id": 12,
# "model": "black-forest-labs/FLUX.1-schnell", formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
# "prompt": "Generate an answer about quantum entanglement.", fh.setFormatter(formatter)
# "user": "alice123",
# "reward": "20.0000 GPU", logger.addHandler(fh)
# "workers": ["workerA", "workerB"],
# }, for handler in logger.handlers:
# { if isinstance(handler, logging.StreamHandler):
# "id": 5, logger.removeHandler(handler)
# "model": "some-other-model/v2.0",
# "prompt": "A story about dragons.",
# "user": "bobthebuilder", _tui = None
# "reward": "15.0000 GPU", def init_tui():
# "workers": ["workerX"], global _tui
# }, assert not _tui
# { setup_logging_for_tui(logging.INFO)
# "id": 99, _tui = WorkerMonitor()
# "model": "cool-model/turbo", return _tui
# "prompt": "Classify sentiment in these tweets.",
# "user": "charlie",
# "reward": "25.5000 GPU", def maybe_update_tui(fn):
# "workers": ["workerOne", "workerTwo", "workerThree"], global _tui
# }, if _tui:
# ] fn(_tui)
#
# ui = WorkerMonitor()
# async def maybe_update_tui_async(fn):
# async def progress_task(): global _tui
# # Fill from 0% to 100% if _tui:
# for pct in range(101): await fn(_tui)
# ui.set_progress(
# current=pct,
# status_str=f"Request #1234 ({pct}%)"
# )
# await trio.sleep(0.05)
# # Reset to 0
# ui.set_progress(
# current=0,
# status_str="Starting again..."
# )
#
# async def update_data_task():
# await trio.sleep(3) # Wait a bit, then update requests
# new_data = [{
# "id": 101,
# "model": "new-model/v1.0",
# "prompt": "Say hi to the world.",
# "user": "eve",
# "reward": "50.0000 GPU",
# "workers": ["workerFresh", "workerPower"],
# }]
# ui.update_requests(new_data)
# ui.set_header_text(new_worker_name="NewNodeName",
# new_balance="balance: 12345.6789 GPU")
#
# try:
# async with trio.open_nursery() as nursery:
# # Run the TUI
# nursery.start_soon(ui.run_teadown_on_exit, nursery)
#
# ui.update_requests(example_requests)
# ui.set_header_text(
# new_worker_name="worker1.scd",
# new_balance="balance: 12345.6789 GPU"
# )
# # Start background tasks
# nursery.start_soon(progress_task)
# nursery.start_soon(update_data_task)
#
# except *KeyboardInterrupt as ex_group:
# ...
#
#
# if __name__ == "__main__":
# trio.run(main)