diff --git a/skynet/dgpu/__init__.py b/skynet/dgpu/__init__.py
index f5fcc9e..96cc303 100755
--- a/skynet/dgpu/__init__.py
+++ b/skynet/dgpu/__init__.py
@@ -1,36 +1,17 @@
 import logging
-import warnings
 
 import trio
+import urwid
 
 from hypercorn.config import Config
 from hypercorn.trio import serve
 from quart_trio import QuartTrio as Quart
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import init_tui
 from skynet.dgpu.daemon import WorkerDaemon
 from skynet.dgpu.network import NetConnector
 
 
-def setup_logging_for_tui(level):
-    warnings.filterwarnings("ignore")
-
-    logger = logging.getLogger()
-    logger.setLevel(level)
-
-    fh = logging.FileHandler('dgpu.log')
-    fh.setLevel(level)
-
-    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
-    fh.setFormatter(formatter)
-
-    logger.addHandler(fh)
-
-    for handler in logger.handlers:
-        if isinstance(handler, logging.StreamHandler):
-            logger.removeHandler(handler)
-
-
 async def open_dgpu_node(config: dict) -> None:
     '''
     Open a top level "GPU mgmt daemon", keep the
@@ -43,11 +24,10 @@ async def open_dgpu_node(config: dict) -> None:
 
     tui = None
     if config['tui']:
-        setup_logging_for_tui(logging.INFO)
-        tui = WorkerMonitor()
+        tui = init_tui()
 
-    conn = NetConnector(config, tui=tui)
-    daemon = WorkerDaemon(conn, config, tui=tui)
+    conn = NetConnector(config)
+    daemon = WorkerDaemon(conn, config)
 
     api: Quart|None = None
     if 'api_bind' in config:
@@ -71,5 +51,5 @@ async def open_dgpu_node(config: dict) -> None:
             # block until cancelled
             await daemon.serve_forever()
 
-    except *urwid.ExitMainLoop in ex_group:
+    except *urwid.ExitMainLoop:
         ...
diff --git a/skynet/dgpu/compute.py b/skynet/dgpu/compute.py
index ec50054..a027dc8 100755
--- a/skynet/dgpu/compute.py
+++ b/skynet/dgpu/compute.py
@@ -12,7 +12,7 @@ from contextlib import contextmanager as cm
 import trio
 import torch
 
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.dgpu.errors import (
     DGPUComputeError,
     DGPUInferenceCancelled,
@@ -108,8 +108,7 @@ def compute_one(
     method: str,
     params: dict,
     inputs: list[bytes] = [],
-    should_cancel = None,
-    tui: WorkerMonitor | None = None
+    should_cancel = None
 ):
     if method == 'diffuse':
         method = 'txt2img'
@@ -130,8 +129,7 @@ def compute_one(
         if not isinstance(step, int):
             step = args[1]
 
-        if tui:
-            tui.set_progress(step, done=total_steps)
+        maybe_update_tui(lambda tui: tui.set_progress(step, done=total_steps))
 
         if should_cancel:
             should_raise = trio.from_thread.run(should_cancel, request_id)
@@ -142,8 +140,7 @@ def compute_one(
 
             return {}
 
-    if tui:
-        tui.set_status(f'Request #{request_id}')
+    maybe_update_tui(lambda tui: tui.set_status(f'Request #{request_id}'))
 
     inference_step_wakeup(0)
 
@@ -210,7 +207,6 @@ def compute_one(
     except BaseException as err:
        raise DGPUComputeError(str(err)) from err
 
-    if tui:
-        tui.set_status('')
+    maybe_update_tui(lambda tui: tui.set_status(''))
 
     return output_hash, output
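
The compute.py change above leans on the callback-style guard that this patch adds to skynet/dgpu/tui.py further down. A minimal standalone sketch of that pattern, where FakeMonitor and the numbers are illustrative stand-ins rather than part of the patch:

_tui = None  # module-level singleton, stays None unless a TUI is initialised

def maybe_update_tui(fn):
    # run fn against the TUI only when one exists; otherwise the lambda
    # built at the call site is simply discarded
    if _tui:
        fn(_tui)

class FakeMonitor:
    def set_progress(self, step, done=100):
        print(f'{step}/{done}')

maybe_update_tui(lambda tui: tui.set_progress(3, done=10))  # headless: no-op

_tui = FakeMonitor()
maybe_update_tui(lambda tui: tui.set_progress(3, done=10))  # prints '3/10'

This keeps the torch worker code free of `if tui:` branches; when the TUI is disabled the only cost at each call site is building a small lambda that is never invoked.
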
diff --git a/skynet/dgpu/daemon.py b/skynet/dgpu/daemon.py
index 4c0bdce..88c0eed 100755
--- a/skynet/dgpu/daemon.py
+++ b/skynet/dgpu/daemon.py
@@ -17,7 +17,7 @@ from skynet.constants import (
 from skynet.dgpu.errors import (
     DGPUComputeError,
 )
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui, maybe_update_tui_async
 from skynet.dgpu.compute import maybe_load_model, compute_one
 from skynet.dgpu.network import NetConnector
 
@@ -41,11 +41,9 @@ class WorkerDaemon:
     def __init__(
         self,
         conn: NetConnector,
-        config: dict,
-        tui: WorkerMonitor | None = None
+        config: dict
     ):
         self.conn: NetConnector = conn
-        self._tui = tui
         self.auto_withdraw = (
             config['auto_withdraw']
             if 'auto_withdraw' in config else False
@@ -152,10 +150,12 @@ class WorkerDaemon:
         return app
 
     async def _update_balance(self):
-        if self._tui:
+        async def _fn(tui):
             # update balance
             balance = await self.conn.get_worker_balance()
-            self._tui.set_header_text(new_balance=f'balance: {balance}')
+            tui.set_header_text(new_balance=f'balance: {balance}')
+
+        await maybe_update_tui_async(_fn)
 
     # TODO? this func is kinda big and maybe is better at module
     # level to reduce indentation?
@@ -258,8 +258,7 @@ class WorkerDaemon:
 
         with maybe_load_model(model, mode):
             try:
-                if self._tui:
-                    self._tui.set_progress(0, done=total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(0, done=total_step))
 
                 output_type = 'png'
                 if 'output_type' in body['params']:
@@ -276,7 +275,6 @@ class WorkerDaemon:
                             mode, body['params'],
                             inputs=inputs,
                             should_cancel=self.should_cancel_work,
-                            tui=self._tui
                         )
                     )
 
@@ -285,8 +283,7 @@ class WorkerDaemon:
                         f'Unsupported backend {self.backend}'
                     )
 
-                if self._tui:
-                    self._tui.set_progress(total_step)
+                maybe_update_tui(lambda tui: tui.set_progress(total_step))
 
                 self._last_generation_ts: str = datetime.now().isoformat()
                 self._last_benchmark: list[float] = self._benchmark
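
_update_balance now pushes its whole body behind maybe_update_tui_async, so the balance query itself is skipped when no TUI is running. A hedged, self-contained sketch of that flow; StubConn and StubMonitor are stand-ins, not the real classes:

import trio

_tui = None

async def maybe_update_tui_async(fn):
    # await fn only when a TUI exists, so the wrapped work (including the
    # balance fetch) is skipped entirely in headless mode
    if _tui:
        await fn(_tui)

class StubConn:
    async def get_worker_balance(self):
        await trio.sleep(0)
        return '100.0000 GPU'

class StubMonitor:
    def set_header_text(self, new_balance=None):
        print(new_balance)

async def update_balance(conn):
    # mirrors the shape of WorkerDaemon._update_balance above
    async def _fn(tui):
        balance = await conn.get_worker_balance()
        tui.set_header_text(new_balance=f'balance: {balance}')

    await maybe_update_tui_async(_fn)

async def main():
    global _tui
    await update_balance(StubConn())   # headless: nothing fetched or printed
    _tui = StubMonitor()
    await update_balance(StubConn())   # prints 'balance: 100.0000 GPU'

trio.run(main)
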
diff --git a/skynet/dgpu/network.py b/skynet/dgpu/network.py
index 6efd871..5deb058 100755
--- a/skynet/dgpu/network.py
+++ b/skynet/dgpu/network.py
@@ -13,7 +13,7 @@ import outcome
 from PIL import Image
 from leap.cleos import CLEOS
 from leap.protocol import Asset
-from skynet.dgpu.tui import WorkerMonitor
+from skynet.dgpu.tui import maybe_update_tui
 from skynet.constants import (
     DEFAULT_IPFS_DOMAIN,
     GPU_CONTRACT_ABI,
@@ -58,7 +58,7 @@ class NetConnector:
     - CLEOS client
 
     '''
-    def __init__(self, config: dict, tui: WorkerMonitor | None = None):
+    def __init__(self, config: dict):
         # TODO, why these extra instance vars for an (unsynced)
         # copy of the `config` state?
         self.account = config['account']
@@ -82,9 +82,8 @@ class NetConnector:
         self.ipfs_domain = config['ipfs_domain']
 
         self._wip_requests = {}
-        self._tui = tui
-        if self._tui:
-            self._tui.set_header_text(new_worker_name=self.account)
+
+        maybe_update_tui(lambda tui: tui.set_header_text(new_worker_name=self.account))
 
     # blockchain helpers
 
@@ -168,8 +167,8 @@ class NetConnector:
                 n.start_soon(
                     _run_and_save, snap['requests'], req['id'],
                     self.get_status_by_request_id, req['id'])
 
-        if self._tui:
-            self._tui.network_update(snap)
+
+        maybe_update_tui(lambda tui: tui.network_update(snap))
 
         return snap
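
Because NetConnector now reports the worker name from inside __init__, the TUI has to be initialised before the connector is constructed (open_dgpu_node above does exactly that). A self-contained sketch of the ordering constraint, using stand-ins rather than the real classes:

_tui = None

def init_tui():
    global _tui
    _tui = {'worker': None}   # stand-in for WorkerMonitor
    return _tui

def maybe_update_tui(fn):
    if _tui:
        fn(_tui)

class FakeNetConnector:
    def __init__(self, account):
        self.account = account
        # mirrors the constructor-time header update added above
        maybe_update_tui(lambda tui: tui.update({'worker': self.account}))

FakeNetConnector('worker1.scd')
print(_tui)          # None: constructed before init_tui(), header never set

init_tui()
FakeNetConnector('worker1.scd')
print(_tui)          # {'worker': 'worker1.scd'}
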
diff --git a/skynet/dgpu/tui.py b/skynet/dgpu/tui.py
index 6530b58..7614d1c 100644
--- a/skynet/dgpu/tui.py
+++ b/skynet/dgpu/tui.py
@@ -1,6 +1,9 @@
-import urwid
-import trio
 import json
+import logging
+import warnings
+
+import trio
+import urwid
 
 
 class WorkerMonitor:
@@ -163,86 +166,41 @@ class WorkerMonitor:
             self.update_requests(queue)
 
 
-# # -----------------------------------------------------------------------------
-# # Example usage
-# # -----------------------------------------------------------------------------
-#
-# async def main():
-#     # Example data
-#     example_requests = [
-#         {
-#             "id": 12,
-#             "model": "black-forest-labs/FLUX.1-schnell",
-#             "prompt": "Generate an answer about quantum entanglement.",
-#             "user": "alice123",
-#             "reward": "20.0000 GPU",
-#             "workers": ["workerA", "workerB"],
-#         },
-#         {
-#             "id": 5,
-#             "model": "some-other-model/v2.0",
-#             "prompt": "A story about dragons.",
-#             "user": "bobthebuilder",
-#             "reward": "15.0000 GPU",
-#             "workers": ["workerX"],
-#         },
-#         {
-#             "id": 99,
-#             "model": "cool-model/turbo",
-#             "prompt": "Classify sentiment in these tweets.",
-#             "user": "charlie",
-#             "reward": "25.5000 GPU",
-#             "workers": ["workerOne", "workerTwo", "workerThree"],
-#         },
-#     ]
-#
-#     ui = WorkerMonitor()
-#
-#     async def progress_task():
-#         # Fill from 0% to 100%
-#         for pct in range(101):
-#             ui.set_progress(
-#                 current=pct,
-#                 status_str=f"Request #1234 ({pct}%)"
-#             )
-#             await trio.sleep(0.05)
-#         # Reset to 0
-#         ui.set_progress(
-#             current=0,
-#             status_str="Starting again..."
-#         )
-#
-#     async def update_data_task():
-#         await trio.sleep(3)  # Wait a bit, then update requests
-#         new_data = [{
-#             "id": 101,
-#             "model": "new-model/v1.0",
-#             "prompt": "Say hi to the world.",
-#             "user": "eve",
-#             "reward": "50.0000 GPU",
-#             "workers": ["workerFresh", "workerPower"],
-#         }]
-#         ui.update_requests(new_data)
-#         ui.set_header_text(new_worker_name="NewNodeName",
-#                            new_balance="balance: 12345.6789 GPU")
-#
-#     try:
-#         async with trio.open_nursery() as nursery:
-#             # Run the TUI
-#             nursery.start_soon(ui.run_teadown_on_exit, nursery)
-#
-#             ui.update_requests(example_requests)
-#             ui.set_header_text(
-#                 new_worker_name="worker1.scd",
-#                 new_balance="balance: 12345.6789 GPU"
-#             )
-#             # Start background tasks
-#             nursery.start_soon(progress_task)
-#             nursery.start_soon(update_data_task)
-#
-#     except *KeyboardInterrupt as ex_group:
-#         ...
-#
-#
-# if __name__ == "__main__":
-#     trio.run(main)
+def setup_logging_for_tui(level):
+    warnings.filterwarnings("ignore")
+
+    logger = logging.getLogger()
+    logger.setLevel(level)
+
+    fh = logging.FileHandler('dgpu.log')
+    fh.setLevel(level)
+
+    formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s")
+    fh.setFormatter(formatter)
+
+    logger.addHandler(fh)
+
+    for handler in logger.handlers:
+        if isinstance(handler, logging.StreamHandler):
+            logger.removeHandler(handler)
+
+
+_tui = None
+def init_tui():
+    global _tui
+    assert not _tui
+    setup_logging_for_tui(logging.INFO)
+    _tui = WorkerMonitor()
+    return _tui
+
+
+def maybe_update_tui(fn):
+    global _tui
+    if _tui:
+        fn(_tui)
+
+
+async def maybe_update_tui_async(fn):
+    global _tui
+    if _tui:
+        await fn(_tui)
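
setup_logging_for_tui moves here and is now called from init_tui, so log output lands in dgpu.log instead of the stdout that urwid owns (and init_tui asserts it only ever runs once). A standalone sketch of the same redirection; the function and file names are illustrative, and the sketch iterates over a copy of the handler list and keeps the handler it just attached, since FileHandler is itself a StreamHandler subclass:

import logging

def redirect_root_logger_to_file(path, level=logging.INFO):
    logger = logging.getLogger()
    logger.setLevel(level)

    fh = logging.FileHandler(path)
    fh.setLevel(level)
    fh.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    logger.addHandler(fh)

    # drop console handlers so nothing is printed over the TUI screen,
    # but keep the file handler we just added
    for handler in list(logger.handlers):
        if isinstance(handler, logging.StreamHandler) and handler is not fh:
            logger.removeHandler(handler)

redirect_root_logger_to_file('example.log')
logging.info('lands in example.log, not on the terminal')
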